Building a Decision Tree in Python from Postgres data

This example uses a twenty-year-old data set that you can use to predict someone’s income from demographic data.

The purpose of this example is to show how to go from data in a relational database to a predictive model, and to note the problems you may encounter along the way.

One of the nice things about this data is that the authors of the paper provide the error rates of various algorithms, which we can use to smoke test our results:

       
        #   Algorithm                Error (%)
        --  -----------------------  ---------
        1   C4.5                     15.54
        2   C4.5-auto                14.46
        3   C4.5 rules               14.94
        4   Voted ID3 (0.6)          15.64
        5   Voted ID3 (0.8)          16.47
        6   T2                       16.84
        7   1R                       19.54
        8   NBTree                   14.10
        9   CN2                      16.00
        10  HOODG                    14.82
        11  FSS Naive Bayes          14.05
        12  IDTM (Decision table)    14.46
        13  Naive-Bayes              16.12
        14  Nearest-neighbor (1)     21.42
        15  Nearest-neighbor (3)     20.35
        16  OC1                      15.04

To load this dataset into Postgres, you’ll need to remove the blank lines at the end of the authors’ files, as well as the “1x0 Cross Validator” line.
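Removing those lines can be scripted. The sketch below is one way to do it in Python. It assumes the cross-validator line is the only non-data line and that it begins with a "|" character; the function names and file paths are hypothetical placeholders, not anything supplied with the dataset.

```python
from pathlib import Path


def clean_adult_lines(lines):
    """Drop blank lines and '|'-prefixed header lines.

    Assumption: the "1x0 Cross Validator" line starts with '|' and is the
    only non-data line in the file.
    """
    cleaned = []
    for line in lines:
        stripped = line.strip()
        if not stripped:
            continue  # blank line, e.g. at the end of the file
        if stripped.startswith("|"):
            continue  # header line such as the cross-validator note
        cleaned.append(line.rstrip("\r\n"))
    return cleaned


def clean_adult_file(src, dst):
    """Write a cleaned copy of src to dst (both paths are placeholders)."""
    lines = Path(src).read_text().splitlines()
    Path(dst).write_text("\n".join(clean_adult_lines(lines)) + "\n")
```

For example, clean_adult_file("adult.test", "adult_clean.test") would write a copy that COPY can load; both file names are hypothetical.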

The following SQL loads their data. If this seems contrived, it’s because I’m using it to set up test cases for a real Postgres database.

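Note that Postgres can read UNC paths, which shouldn’t be surprising, but considering Visual Studio still can’t, it’s always nice when it works. Bear in mind that COPY reads the file on the database server, so the path has to be reachable by the server process; psql’s \copy command reads the file from the client machine instead.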

drop table if exists income_trn;

create table income_trn 
	(age integer, 
	 workclass text, 
	 fnlwgt integer, 
	 education text, 
	 education_num integer, 
	 marital_status text, 
	 occupation text, 
	 relationship text, 
	 race text, 
	 sex text, 
	 capital_gain integer, 
	 capital_loss integer, 
	 hours_per_week integer, 
	 native_country text,
	 category text);

COPY income_trn 
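FROM '\\nas\Files\Data\income\adult.data' DELIMITER ',' CSV;
-- backslashes are literal in standard string literals (standard_conforming_strings is on by default since 9.1)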

drop table if exists income_test;
create table income_test 
	(age integer, 
	 workclass text, 
	 fnlwgt integer, 
	 education text, 
	 education_num integer, 
	 marital_status text, 
	 occupation text, 
	 relationship text, 
	 race text, 
	 sex text, 
	 capital_gain integer, 
	 capital_loss integer, 
	 hours_per_week integer, 
	 native_country text,
	 category text);

COPY income_test 
FROM '\\nas\Files\Data\income\adult.test' DELIMITER ',' CSV;

You can load this data into Python easily with SQLAlchemy. However, the Postgres driver I used (pg8000) seemed flaky: it threw errors like the following at random. (Edit 2/26/14: one of the pg8000 contributors fixed the issue after I posted this. Thanks!)

ProgrammingError: (ProgrammingError) 
('ERROR', '34000', 
'portal "pg8000_portal_12" does not exist') 
None None

This error can be caused by running an old version of Postgres (I’m on 9.3), but here it appeared inconsistently; it looks like the driver was reading from a closed cursor.

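from sqlalchemy import create_engine, MetaData, Table

# user "postgres", password "postgres", database "pacer"
# (Postgres treats READ UNCOMMITTED the same as READ COMMITTED)
engine = create_engine(
    "postgresql+pg8000://postgres:postgres@localhost/pacer",
    isolation_level="READ UNCOMMITTED"
)
c = engine.connect()

meta = MetaData()

# Reflect the column definitions of the two tables created above
income_trn = Table('income_trn', meta, autoload_with=engine)
income_test = Table('income_test', meta, autoload_with=engine)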

For a large data set, it would likely be worth streaming query results into the model, but this dataset is small, so I haven’t attempted that. If your data is all in one table, you’ll need a way to split it randomly (e.g. in half) into training and test sets that also performs well enough.
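When all the rows fit in memory, one simple way to do the random split is to shuffle a copy with a seeded generator and cut it at a chosen fraction. This is a minimal sketch under that assumption; the name random_split and its defaults are illustrative, and scikit-learn’s sklearn.model_selection.train_test_split does the same job if you prefer a library call.

```python
import random


def random_split(rows, test_fraction=0.5, seed=42):
    """Shuffle a copy of rows and return (train, test).

    Sketch only: assumes all rows fit in memory. The fixed seed makes the
    split reproducible from run to run.
    """
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    shuffled = list(rows)  # copy, so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]
```

Using a fixed seed matters here: if the split changes between runs, so does the measured accuracy, which makes comparisons against the published error rates noisier.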

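from sqlalchemy.sql import select

def get_data(table):
    # Fetch every row of the table into memory.
    # select(table) is the SQLAlchemy 1.4+/2.0 form; older versions used select([table])
    s = select(table)
    result = c.execute(s)
    return result.fetchall()

trn_data = get_data(income_trn)
test_data = get_data(income_test)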
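get_data pulls the whole table into memory in one call. For a table too large for that, a minimal sketch is to consume the result in fixed-size batches with fetchmany(), which both DB-API cursors and SQLAlchemy results provide. The helper name iter_batches and the batch size are illustrative, and the demo uses an in-memory SQLite table as a stand-in for Postgres. Note that fetchmany() alone doesn’t stop a driver from buffering the full result client-side; with SQLAlchemy, c.execution_options(stream_results=True) requests a server-side cursor on drivers that support one, such as psycopg2.

```python
import sqlite3


def iter_batches(result, batch_size=1000):
    """Yield lists of up to batch_size rows from a cursor-like object.

    Works with anything exposing fetchmany(), such as a DB-API cursor or a
    SQLAlchemy result.
    """
    while True:
        batch = result.fetchmany(batch_size)
        if not batch:
            break  # result exhausted
        yield batch


if __name__ == "__main__":
    # Stand-in demo: an in-memory SQLite table instead of Postgres
    conn = sqlite3.connect(":memory:")
    conn.execute("create table t (x integer)")
    conn.executemany("insert into t values (?)", [(i,) for i in range(25)])
    cur = conn.execute("select x from t order by x")
    for batch in iter_batches(cur, batch_size=10):
        print(len(batch))  # 10, then 10, then 5
```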

When this data comes in, some columns contain text labels and others contain integers (e.g. occupation vs. age). This is oddly difficult for Python machine learning libraries (or at least for scikit-learn’s decision trees), which only accept numeric features. That is a little concerning, considering that continuous data is very different from categorical data, which is drawn from a fixed set of values.
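Mapping each category to a single integer implicitly puts the categories in an order, and a tree splitting on a threshold (say, code <= 7) then groups categories by whatever order their codes happened to be assigned in. A common alternative, which this post does not use, is one-hot encoding: one 0/1 column per distinct value. scikit-learn’s OneHotEncoder and pandas’ get_dummies do this; the sketch below is plain Python, and one_hot_encode is a hypothetical helper name.

```python
def one_hot_encode(rows, categorical_cols):
    """One-hot encode the given column indexes; other columns pass through.

    Returns (encoded_rows, feature_names). Categories are collected from
    rows itself, so values that only appear at prediction time would need
    separate handling.
    """
    # Sorted distinct values per categorical column, for a stable column order
    categories = {
        col: sorted({row[col] for row in rows}) for col in categorical_cols
    }

    feature_names = []
    for col in range(len(rows[0])):
        if col in categories:
            feature_names.extend(f"{col}={v}" for v in categories[col])
        else:
            feature_names.append(str(col))

    encoded = []
    for row in rows:
        out = []
        for col, value in enumerate(row):
            if col in categories:
                # one 0/1 indicator per known category of this column
                out.extend(1 if value == v else 0 for v in categories[col])
            else:
                out.append(value)  # numeric column kept as-is
        encoded.append(out)
    return encoded, feature_names
```

The trade-off is width: a column with many distinct values, such as native_country, becomes many indicator columns.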

The library expects each row to be a list of values, and those values have to be numbers, so we can build a global dictionary that maps each distinct value to an integer code and back:

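maxVal = 0          # last integer code handed out
vals = dict()       # value -> integer code
rev_vals = dict()   # integer code -> value

def f(x):
    """Return the integer code for x, assigning a new one the first time x is seen."""
    global maxVal
    if x not in vals:
        maxVal = maxVal + 1
        vals[x] = maxVal
        rev_vals[maxVal] = x
    return vals[x]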
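Because f is applied to every attribute, integer columns such as age are re-coded too, so their numeric order is lost, and all columns share one code space. A variant that leaves numbers alone and keeps a separate code table per column might look like the sketch below; ColumnEncoder is a hypothetical name, not part of any library.

```python
class ColumnEncoder:
    """Per-column integer codes for text values; numbers pass through.

    A hypothetical alternative to the global f() above.
    """

    def __init__(self):
        self.codes = {}    # column index -> {text value: code}
        self.reverse = {}  # column index -> {code: text value}

    def encode_value(self, col, value):
        if isinstance(value, (int, float)):
            return value  # keep continuous data continuous
        table = self.codes.setdefault(col, {})
        if value not in table:
            code = len(table) + 1  # codes start at 1 within each column
            table[value] = code
            self.reverse.setdefault(col, {})[code] = value
        return table[value]

    def encode_row(self, row):
        return [self.encode_value(col, v) for col, v in enumerate(row)]

    def decode_value(self, col, code):
        return self.reverse[col][code]
```

Usage would be along the lines of enc = ColumnEncoder() followed by trn = [enc.encode_row(t[:-1]) for t in trn_data], reusing the same enc for the test rows so both sets share one code table per column.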

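Then we split each row into the attributes used to make the prediction (get_selectors) and the output we want to predict (get_predictors, which turns the income category into 0 for “<=50K” and 1 for “>50K”):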

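def get_selectors(data):
    # every column except the last (the income category), encoded with f
    return [[f(x) for x in t[0:-1]] for t in data]

def get_predictors(data):
    # label: 0 if the category contains "<" (i.e. "<=50K"), otherwise 1
    return [0 if "<" in t[14] else 1 for t in data]

trn = get_selectors(trn_data)
trn_v = get_predictors(trn_data)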
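Before trusting an accuracy figure, it helps to compare it with a majority-class baseline: the accuracy of always predicting the most common training label. A model that barely beats this isn’t learning much. This sketch assumes the 0/1 labels produced by get_predictors; majority_baseline is a hypothetical helper name.

```python
from collections import Counter


def majority_baseline(train_labels, test_labels):
    """Accuracy (in percent) of always predicting the most common training label."""
    most_common_label, _ = Counter(train_labels).most_common(1)[0]
    hits = sum(1 for label in test_labels if label == most_common_label)
    return 100.0 * hits / len(test_labels)
```

With the labels in this post, that would be majority_baseline(trn_v, test_v), once test_v has been computed from the test rows.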

The most compelling thing about this whole set-up to me is how trivially easy it is to build a model:

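from sklearn import tree

# Fit a decision tree on the encoded training rows and their 0/1 labels.
# tree.DecisionTreeRegressor() can be swapped in; both are compared below.
clf = tree.DecisionTreeClassifier()
clf = clf.fit(trn, trn_v)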

I ended up writing my own test loop, because scikit-learn’s confusion matrix calculator didn’t accept my data in the form I had it.

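test = get_selectors(test_data)
test_v = get_predictors(test_data)

# predict() takes a 2-D list of rows, so score the whole test set in one call
predictions = clf.predict(test)

testsRun = len(test_v)
testsPassed = sum(1 for p, actual in zip(predictions, test_v) if p == actual)

# percentage of test rows classified correctly
print(100.0 * testsPassed / testsRun)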

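Accuracy on the test set:

DecisionTreeClassifier: 78%
DecisionTreeRegressor: 79%

That corresponds to an error rate of about 21 to 22%, well above the 15.54% reported for C4.5 in the table above and roughly in line with the nearest-neighbor results.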
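For a binary 0/1 label, the confusion matrix is just four counts, so it is easy to compute without scikit-learn. The sketch below is a stand-in, not sklearn.metrics.confusion_matrix, which expects discrete labels; here predictions are thresholded so a regressor’s float output can be passed in directly. The 0.5 threshold and the function names are assumptions.

```python
def binary_confusion_matrix(actual, predicted, threshold=0.5):
    """Count true/false positives/negatives for 0/1 labels.

    Predictions are thresholded (assumed cut-off: 0.5), so the float output
    of a regressor is accepted as well as 0/1 class predictions.
    """
    counts = {"tp": 0, "fp": 0, "tn": 0, "fn": 0}
    for a, p in zip(actual, predicted):
        p = 1 if p >= threshold else 0
        if a == 1 and p == 1:
            counts["tp"] += 1
        elif a == 0 and p == 1:
            counts["fp"] += 1
        elif a == 0 and p == 0:
            counts["tn"] += 1
        else:
            counts["fn"] += 1  # actual 1, predicted 0
    return counts


def accuracy(counts):
    """Percentage of rows on the diagonal of the confusion matrix."""
    total = sum(counts.values())
    return 100.0 * (counts["tp"] + counts["tn"]) / total
```

Called as binary_confusion_matrix(test_v, clf.predict(test)), this also shows whether the errors lean toward missing high earners (fn) or over-predicting them (fp), which a single accuracy number hides.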

If you look at the scikit-learn documentation, all the examples have awesome charts. One thing I discovered from this is that decision trees can be very large: they may contain thousands of rules, which doesn’t lend itself to charting except in the simplest cases.
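One way to get a tree small enough to read is to cap its size with the max_depth (or max_leaf_nodes) parameter of DecisionTreeClassifier and print it as text with sklearn.tree.export_text, which is available in scikit-learn 0.21 and later. This sketch uses made-up toy rows rather than the income data, and a shallower tree may well be less accurate than an unrestricted one.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Toy rows standing in for encoded income data: [age, hours_per_week]
X = [[25, 40], [30, 20], [45, 50], [50, 60], [35, 45], [60, 40], [22, 15], [41, 55]]
y = [0, 0, 1, 1, 0, 1, 0, 1]

# Limit the depth so the printed rule set stays short
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

print("depth:", clf.get_depth(), "leaves:", clf.get_n_leaves())
print(export_text(clf, feature_names=["age", "hours_per_week"]))
```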

Reply from Tony Locke:

I’m a contributor to pg8000. I reproduced the bug you found and have released a new version of pg8000 (1.9.6) that should fix it. Let me know if you still have problems.
