Decision Tree Analysis with an Example

What a decision tree is, how to implement and train one to classify new items, and an analysis of decision trees with a worked example


Category: Machine Learning Tags: Python, Python 3


Introduction

    You may be familiar with the binary search tree, where each node has at most two children: if a child's value is less than or equal to its parent's value it becomes the left child, otherwise it becomes the right child. A decision tree is structurally similar; the difference is that the condition deciding which branch to follow is dynamic, different at each node, and learned from the training data.

 

Fig 1: Decision Tree

 

A decision tree is easy to interpret: you can understand it just by looking at the figure above. After training, a decision tree reads like a hierarchy of if-else statements.
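For intuition, here is a minimal sketch of what a trained tree amounts to once you read it as code (the thresholds and outcomes below are made up for illustration, not learned from the data in this article):

# a trained decision tree reads like nested if-else checks
# (illustrative thresholds only, not learned from real data)
def predictFlat(members, earning):
    if members >= 3:
        if earning >= 75000:
            return '4BHK'
        return '3BHK'
    else:
        if earning >= 50000:
            return '2BHK'
        return 'None'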

Implementation

    Suppose a builder's website lets users post queries about available flats. The builder has 1BHK to 4BHK flats; below is the relevant data from previous users, which we will use to train the decision tree:

| Number of Members | Family Earning | Marital Status | FAQ | Flat |
|-------------------|----------------|----------------|-----|------|
| 1 | 30000 | Unmarried | 1 | None |
| 2 | 50000 | Unmarried | 1 | 2BHK |
| 4 | 70000 | Married | 3 | 3BHK |
| 6 | 90000 | Married | 3 | 3BHK |
| 2 | 55000 | Married | 2 | None |
| 4 | 55000 | Married | 2 | 3BHK |
| 3 | 60000 | Married | 2 | 2BHK |
| 1 | 35000 | Unmarried | 1 | 2BHK |
| 1 | 25000 | Unmarried | 2 | 1BHK |
| 6 | 95000 | Married | 3 | 4BHK |
| 6 | 85000 | Married | 4 | 4BHK |
| 4 | 50000 | Married | 3 | 3BHK |
| 3 | 50000 | Unmarried | 3 | 3BHK |
| 4 | 80000 | Married | 3 | 2BHK |
| 6 | 90000 | Married | 2 | 4BHK |
| 4 | 75000 | Married | 3 | 4BHK |
| 2 | 60000 | Married | 2 | 1BHK |

From the table we can see that the number of members in a family is roughly proportional to the number of bedrooms, as is family earning; marital status affects the buying choice; and a user who reads more FAQs is more interested. Now we need to build and train a tree so that, for new users, it can predict whether the user will buy a 1BHK-4BHK flat or none of these. Let's create the node structure of the tree:

class decisionNode:
    def __init__(self, col=-1, value=None, results=None, trueBranch=None, falseBranch=None):
        # index of the column on which the split happens
        self.col = col
        # column value on which the split happens
        self.value = value
        # dict of result counts (set only on leaf nodes)
        self.results = results
        # child subtree if the condition is true
        self.trueBranch = trueBranch
        # child subtree if the condition is false
        self.falseBranch = falseBranch

Each node has two children, a true branch and a false branch, which decide the flow through the tree based on conditions. Now we create a class decisionTree and a method that divides the training data based on a condition:

class decisionTree:

    # separate rows into two sets based on a column value
    def divideSet(self, rows, columnIndex, columnValue):
        if isinstance(columnValue, int) or isinstance(columnValue, float):
            # condition for numeric data
            splitFunction = lambda row: row[columnIndex] >= columnValue
        else:
            # condition for string data
            splitFunction = lambda row: row[columnIndex] == columnValue

        # rows that satisfy the condition
        set1 = [row for row in rows if splitFunction(row)]
        # rows that do not satisfy the condition
        set2 = [row for row in rows if not splitFunction(row)]
        return set1, set2

This method takes a set of rows, the index of the column on which we want to put the split condition, and the value on which it should split the rows. For example, to split the table above on family earning at a value of 55000, columnIndex will be 1 (counting from zero) and columnValue will be 55000. We can run the method like this:

familyData = [
[1,30000,'Unmarried',1,'None'],
[2,50000,'Unmarried',1,'2BHK'],
[4,70000,'Married',3,'3BHK'],
[6,90000,'Married',3,'3BHK'],
[2,55000,'Married',2,'None'],
[4,55000,'Married',2,'3BHK'],
[3,60000,'Married',2,'2BHK'],
[1,35000,'Unmarried',1,'2BHK'],
[1,25000,'Unmarried',2,'1BHK'],
[6,95000,'Married',3,'4BHK'],
[6,85000,'Married',4,'4BHK'],
[4,50000,'Married',3,'3BHK'],
[3,50000,'Unmarried',3,'3BHK'],
[4,80000,'Married',3,'2BHK'],
[6,90000,'Married',2,'4BHK'],
[4,75000,'Married',3,'4BHK'],
[2,60000,'Married',2,'1BHK']
]
treeClass = decisionTree()
set1, set2 = treeClass.divideSet(familyData, 1, 55000)
print('set1:')
print(set1)
print('set2:')
print(set2)

Output:

set1:
[[4, 70000, 'Married', 3, '3BHK'], [6, 90000, 'Married', 3, '3BHK'], [2, 55000, 'Married', 2, 'None'], [4, 55000, 'Married', 2, '3BHK'], [3, 60000, 'Married', 2, '2BHK'], [6, 95000, 'Married', 3, '4BHK'], [6, 85000, 'Married', 4, '4BHK'], [4, 80000, 'Married', 3, '2BHK'], [6, 90000, 'Married', 2, '4BHK'], [4, 75000, 'Married', 3, '4BHK'], [2, 60000, 'Married', 2, '1BHK']]
set2:
[[1, 30000, 'Unmarried', 1, 'None'], [2, 50000, 'Unmarried', 1, '2BHK'], [1, 35000, 'Unmarried', 1, '2BHK'], [1, 25000, 'Unmarried', 2, '1BHK'], [4, 50000, 'Married', 3, '3BHK'], [3, 50000, 'Unmarried', 3, '3BHK']]

We can see that both groups contain every kind of result (1BHK-4BHK and None). That means this split leaves a high degree of disorder in each set; we should instead choose a condition that divides the group with the least disorder.

Entropy: In information theory, entropy measures the amount of disorder in a group. The formula is:

Entropy(S) = -\sum_{i} p_i \log_2 p_i

where p_i is the probability of each distinct result in the set. To calculate entropy we first need the counts of the unique results in a group:

# count unique values in the result (last) column
def uniqueResults(self, rows):
    results = {}
    for row in rows:
        # the result is the last column of a row, e.g. '2BHK'
        result = row[len(row)-1]
        results.setdefault(result, 0)
        results[result] += 1
    return results
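As a quick sanity check, running uniqueResults on the familyData list defined earlier should count each flat type in the training data (a sketch; key order may vary with your Python version):

treeClass = decisionTree()
# counts of each distinct result in the training data
print(treeClass.uniqueResults(familyData))
# {'None': 2, '2BHK': 4, '3BHK': 5, '1BHK': 2, '4BHK': 4}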

And to calculate entropy:

# calculates the entropy of a group
def entropy(self, rows):
    import math
    uniqueRes = self.uniqueResults(rows)
    entropy = 0
    for result in uniqueRes.values():
        p = float(result)/len(rows)
        # logarithm of base 2
        entropy -= p*math.log(p, 2)

    return entropy
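To get a feel for the numbers: a group with a single distinct result is pure and has entropy 0, while the full training set mixes five outcomes. A quick sketch (the second printed value is approximate):

treeClass = decisionTree()
# a pure group: only one distinct result, entropy is 0
print(treeClass.entropy([[1,30000,'Unmarried',1,'None']]))  # 0.0
# the whole training set mixes five outcomes
print(treeClass.entropy(familyData))  # roughly 2.23 bits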

Now we implement the buildTree method:

def buildTree(self, rows, entropyFun=None):
    if len(rows) == 0:
        return
    # default to the class's own entropy method
    if entropyFun is None:
        entropyFun = self.entropy
    current_entropy = entropyFun(rows)
    best_gain = 0
    best_criteria = None
    best_sets = None

    # leave the last column out of the calculation since it is the result
    column_count = len(rows[0]) - 1
    for col in range(column_count):
        column_values = {}

        for row in rows:
            columnVal = row[col]
            # collect each unique value in this column
            column_values[columnVal] = 1

        for colVal in column_values:
            # try to split the rows on each column value
            set1, set2 = self.divideSet(rows, col, colVal)
            # calculate the weighted average entropy of both sets
            avgEntropy = float(entropyFun(set1)*len(set1) + entropyFun(set2)*len(set2))/len(rows)
            # information gain: parent entropy minus the weighted average entropy of the two groups
            gain = current_entropy - avgEntropy
            # keep the split condition with the highest gain
            if gain > best_gain and len(set1) > 0 and len(set2) > 0:
                best_gain = gain
                best_criteria = (col, colVal)
                best_sets = (set1, set2)
    if best_gain > 0:
        # recursively build the subtrees
        trueBranch = self.buildTree(best_sets[0], entropyFun)
        falseBranch = self.buildTree(best_sets[1], entropyFun)
        return decisionNode(col=best_criteria[0], value=best_criteria[1], trueBranch=trueBranch, falseBranch=falseBranch)
    else:
        # no useful split left: make a leaf holding the result counts
        return decisionNode(results=self.uniqueResults(rows))

In the code above we compute the best gain, i.e. the information gain, obtained by subtracting the weighted average entropy of the child groups from the entropy of the parent group. To find the best split we look for the highest information gain across every distinct value of every column. Once we know for which column and value the information gain is highest (the condition on which to split the group), we create the node's two children, trueBranch and falseBranch.
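Written out, the quantity the loop maximizes for each candidate condition is:

Gain(S) = Entropy(S) - \frac{|S_1|}{|S|} Entropy(S_1) - \frac{|S_2|}{|S|} Entropy(S_2)

where S is the parent group and S_1, S_2 are the rows that do and do not satisfy the condition; this matches the avgEntropy computation in the code above.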

We can print this tree using the method below:

# prints the tree
def printTree(self, tree, colNames, indent=''):
    if tree.results is not None:
        # leaf: print the result counts
        print(indent + str(tree.results))
    else:
        # print the node's split condition
        print(indent + str(colNames[tree.col]) + ':' + str(tree.value))
        # print the true branch
        print(indent + 'T->')
        self.printTree(tree.trueBranch, colNames, indent + '    ')
        # print the false branch
        print(indent + 'F->')
        self.printTree(tree.falseBranch, colNames, indent + '    ')

Let’s train and print the tree:

# familyData is the same training data defined earlier
treeClass = decisionTree()
#train the tree
tree = treeClass.buildTree(familyData, treeClass.entropy)
#print the tree
treeClass.printTree(tree, ['Number of Members','Salary','Marital Status','Inquiries', 'Flat Type'])

The full tree is too big to reproduce here; below is a sample of the printTree output:

Number of Members:3
T->
    Salary:75000
    T->
        Number of Members:6
        T->
            Salary:90000
            T->
                Salary:95000
                T->
                    {'4BHK': 1}
                F->
                    Inquiries:3
Finally, to test the tree, we write a method that classifies new users:

# classify a new observation
def classify(self, observation, tree):
    if tree.results is not None:
        # leaf node: return the result counts
        return tree.results
    else:
        v = observation[tree.col]
        if isinstance(v, int) or isinstance(v, float):
            # numeric condition
            branch = tree.trueBranch if v >= tree.value else tree.falseBranch
        else:
            # string condition
            branch = tree.trueBranch if v == tree.value else tree.falseBranch
        return self.classify(observation, branch)

Now let's classify some new users who inquire about flats:

# familyData is the same training data defined earlier
treeClass = decisionTree()
tree = treeClass.buildTree(familyData, treeClass.entropy)
print(treeClass.classify([2,40000,'Unmarried',1],tree))
print(treeClass.classify([1,10000,'Unmarried',1],tree))
print(treeClass.classify([6,80000,'Married',1],tree))

Output:

{'2BHK': 2}
{'None': 1}
{'4BHK': 1}

Each result is the leaf's dictionary of result counts: {'2BHK': 2}, for example, means the new user was routed to a leaf holding two training rows, both labelled 2BHK.

Analysis

    Decision trees are one of the traditional techniques in machine learning. The tree can grow large if the training data contains a lot of disorder; the number of decision nodes is usually proportional to the sum of distinct values across the columns. Training may take considerable time, but once the tree is built it classifies new items in time proportional to the tree's depth, roughly O(log n) for a reasonably balanced tree.

