Introduction
You are probably familiar with the Binary Search Tree, where each node may have two children: if a child's value is less than or equal to its parent's, it becomes the left child; otherwise it becomes the right child. A decision tree is similar to a binary search tree, with one difference: the conditions that send data to the left or right child are dynamic, different at each node, and defined by the training data.
A decision tree is easy to interpret; you can understand one just by looking at the figure above. After training, a decision tree reads like hierarchical if-else statements.
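To illustrate the "hierarchical if-else" view, here is a small hand-written sketch of what a trained tree is equivalent to. This is a hypothetical two-question tree invented for illustration, not the tree trained later in this article:

```python
# A trained decision tree is equivalent to nested if-else statements.
# Hypothetical example: predicting a flat type from two features.
def predict_flat(members, salary):
    if members >= 3:            # root node condition
        if salary >= 75000:     # true-branch child condition
            return '4BHK'
        else:
            return '3BHK'
    else:
        if salary >= 50000:     # false-branch child condition
            return '2BHK'
        else:
            return '1BHK'

print(predict_flat(4, 80000))  # -> 4BHK
print(predict_flat(2, 30000))  # -> 1BHK
```

Training a decision tree amounts to discovering these conditions automatically from data instead of writing them by hand.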
Implementation
Suppose a builder's website lets users post queries about available flats. The builder has 1BHK to 4BHK flats; below is the relevant data from previous users, which we will use to train the decision tree:
Number of Members | Family Earning | Marital Status | FAQ | Flat
1 | 30000 | Unmarried | 1 | None
2 | 50000 | Unmarried | 1 | 2BHK
4 | 70000 | Married | 3 | 3BHK
6 | 90000 | Married | 3 | 3BHK
2 | 55000 | Married | 2 | None
4 | 55000 | Married | 2 | 3BHK
3 | 60000 | Married | 2 | 2BHK
1 | 35000 | Unmarried | 1 | 2BHK
1 | 25000 | Unmarried | 2 | 1BHK
6 | 95000 | Married | 3 | 4BHK
6 | 85000 | Married | 4 | 4BHK
4 | 50000 | Married | 3 | 3BHK
3 | 50000 | Unmarried | 3 | 3BHK
4 | 80000 | Married | 3 | 2BHK
6 | 90000 | Married | 2 | 4BHK
4 | 75000 | Married | 3 | 4BHK
2 | 60000 | Married | 2 | 1BHK
From the table we can see that the number of family members is roughly proportional to the number of bedrooms, as is the family earning; marital status affects the buying choice; and a user who reads more FAQs is likely more interested. Now we need to build and train a tree so that, for a new user, it can predict whether the user will buy a 1BHK-4BHK flat or none of these. Let's create the node structure of the tree:
class decisionNode:
    def __init__(self, col=-1, value=None, results=None, trueBranch=None, falseBranch=None):
        # column index on which the split happens
        self.col = col
        # column value on which the split happens
        self.value = value
        self.results = results
        # child branch if condition is true
        self.trueBranch = trueBranch
        # child branch if condition is false
        self.falseBranch = falseBranch
The node above has two children, a true branch and a false branch, which direct the flow through the tree based on conditions. Now we will create a class decisionTree with a method that divides the training data based on a condition:
class decisionTree:
    # separate rows into two sets based on a condition
    def divideSet(self, rows, columnIndex, columnValue):
        splitFunction = None
        if isinstance(columnValue, int) or isinstance(columnValue, float):
            # condition for numeric data
            splitFunction = lambda row: row[columnIndex] >= columnValue
        else:
            # condition for string data
            splitFunction = lambda row: row[columnIndex] == columnValue
        # rows that satisfy the condition
        set1 = [row for row in rows if splitFunction(row)]
        # rows that do not satisfy the condition
        set2 = [row for row in rows if not splitFunction(row)]
        return set1, set2
This method takes a set of rows, the column index on which we want to put the split condition, and the value on which the rows should be split. For example, to split the table above on family income at 55000, columnIndex will be 1 (counting from zero) and columnValue will be 55000. We can run the method like this:
familyData = [
    [1, 30000, 'Unmarried', 1, 'None'],
    [2, 50000, 'Unmarried', 1, '2BHK'],
    [4, 70000, 'Married', 3, '3BHK'],
    [6, 90000, 'Married', 3, '3BHK'],
    [2, 55000, 'Married', 2, 'None'],
    [4, 55000, 'Married', 2, '3BHK'],
    [3, 60000, 'Married', 2, '2BHK'],
    [1, 35000, 'Unmarried', 1, '2BHK'],
    [1, 25000, 'Unmarried', 2, '1BHK'],
    [6, 95000, 'Married', 3, '4BHK'],
    [6, 85000, 'Married', 4, '4BHK'],
    [4, 50000, 'Married', 3, '3BHK'],
    [3, 50000, 'Unmarried', 3, '3BHK'],
    [4, 80000, 'Married', 3, '2BHK'],
    [6, 90000, 'Married', 2, '4BHK'],
    [4, 75000, 'Married', 3, '4BHK'],
    [2, 60000, 'Married', 2, '1BHK']
]
treeClass = decisionTree()
set1, set2 = treeClass.divideSet(familyData, 1, 55000)
print('set1:')
print(set1)
print('set2:')
print(set2)
Output:
set1:
[[4, 70000, 'Married', 3, '3BHK'], [6, 90000, 'Married', 3, '3BHK'], [2, 55000, 'Married', 2, 'None'], [4, 55000, 'Married', 2, '3BHK'], [3, 60000, 'Married', 2, '2BHK'], [6, 95000, 'Married', 3, '4BHK'], [6, 85000, 'Married', 4, '4BHK'], [4, 80000, 'Married', 3, '2BHK'], [6, 90000, 'Married', 2, '4BHK'], [4, 75000, 'Married', 3, '4BHK'], [2, 60000, 'Married', 2, '1BHK']]
set2:
[[1, 30000, 'Unmarried', 1, 'None'], [2, 50000, 'Unmarried', 1, '2BHK'], [1, 35000, 'Unmarried', 1, '2BHK'], [1, 25000, 'Unmarried', 2, '1BHK'], [4, 50000, 'Married', 3, '3BHK'], [3, 50000, 'Unmarried', 3, '3BHK']]
Above we can see that both groups contain all kinds of results (1BHK-4BHK, None). This means the sets were divided with high disorder and mixed results; instead, we must choose a condition on which the group gets divided with the least disorder.
Entropy: in information theory, entropy is the amount of disorder in a group. The formula is:

Entropy = -Σ p_i * log2(p_i)

where p_i is the probability of each distinct result in the set. To calculate entropy we first need to count the unique results in a group:
    # count unique values in the result (usually last) column
    def uniqueResults(self, rows):
        results = {}
        for row in rows:
            # result is the last column of a row, e.g. '2BHK'
            result = row[len(row)-1]
            results.setdefault(result, 0)
            results[result] += 1
        return results
And to calculate entropy:
    # calculates the entropy of a group
    def entropy(self, rows):
        import math
        uniqueRes = self.uniqueResults(rows)
        entropy = 0
        for result in uniqueRes.values():
            p = float(result) / len(rows)
            # logarithm of base 2
            entropy -= p * math.log(p, 2)
        return entropy
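To build intuition for how entropy measures disorder, here is a standalone sketch of the same formula (the function name entropy_of and the sample label lists are invented for this demonstration):

```python
import math

def entropy_of(labels):
    """Shannon entropy (base 2) of a list of result labels."""
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    total = len(labels)
    return -sum((c / total) * math.log(c / total, 2) for c in counts.values())

# A pure group has zero entropy ...
print(entropy_of(['3BHK', '3BHK', '3BHK']))          # 0 bits of disorder
# ... a 50/50 split between two results has exactly 1 bit ...
print(entropy_of(['3BHK', '3BHK', '2BHK', '2BHK']))  # 1 bit
# ... and four equally likely results give 2 bits.
print(entropy_of(['1BHK', '2BHK', '3BHK', '4BHK']))  # 2 bits
```

The more evenly the results are mixed, the higher the entropy, which is exactly the "disorder" we want a split to reduce.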
Now we implement buildTree method:
    def buildTree(self, rows, entropyFun=None):
        if len(rows) == 0:
            return decisionNode()
        if entropyFun is None:
            entropyFun = self.entropy
        current_entropy = entropyFun(rows)
        best_gain = 0
        best_criteria = None
        best_sets = None
        # leaving out the last column since it holds the result
        column_count = len(rows[0]) - 1
        for col in range(column_count):
            column_values = {}
            for row in rows:
                columnVal = row[col]
                # collect each unique value in this column
                column_values[columnVal] = 1
            for colVal in column_values:
                # try to split rows on each column value
                set1, set2 = self.divideSet(rows, col, colVal)
                # weighted average entropy of both sets
                avgEntropy = float(entropyFun(set1)*len(set1) + entropyFun(set2)*len(set2)) / len(rows)
                # gain = parent group entropy minus average entropy of the two child groups
                gain = current_entropy - avgEntropy
                # keep the split with the highest gain
                if gain > best_gain and len(set1) > 0 and len(set2) > 0:
                    best_gain = gain
                    best_criteria = (col, colVal)
                    best_sets = (set1, set2)
        if best_gain > 0:
            # recursively build both branches
            trueBranch = self.buildTree(best_sets[0], entropyFun)
            falseBranch = self.buildTree(best_sets[1], entropyFun)
            return decisionNode(col=best_criteria[0], value=best_criteria[1],
                                trueBranch=trueBranch, falseBranch=falseBranch)
        else:
            return decisionNode(results=self.uniqueResults(rows))
In the code above we calculate the best information gain, which is obtained by subtracting the weighted average entropy of the child groups from the parent group's entropy. To find the best split, we look for the highest information gain across every distinct value of every column. Once we know for which column and which value the information gain is highest (the condition on which we can split the group), we create the node's two children, trueBranch and falseBranch.
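To make the gain calculation concrete, here is a small standalone worked example; the helper names entropy_of and information_gain, and the toy label lists, are invented for this sketch:

```python
import math

def entropy_of(labels):
    """Shannon entropy (base 2) of a list of result labels."""
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return -sum((c / len(labels)) * math.log(c / len(labels), 2)
                for c in counts.values())

def information_gain(parent, set1, set2):
    """Parent entropy minus the size-weighted average entropy of the children."""
    n = len(parent)
    weighted = (entropy_of(set1) * len(set1) + entropy_of(set2) * len(set2)) / n
    return entropy_of(parent) - weighted

parent = ['2BHK', '2BHK', '3BHK', '3BHK']  # entropy = 1 bit

# A perfect split separates the classes completely: gain = 1.0
print(information_gain(parent, ['2BHK', '2BHK'], ['3BHK', '3BHK']))  # -> 1.0
# A useless split leaves both children as mixed as the parent: gain = 0.0
print(information_gain(parent, ['2BHK', '3BHK'], ['2BHK', '3BHK']))  # -> 0.0
```

buildTree performs exactly this comparison for every candidate (column, value) pair and keeps the split with the highest gain.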
We can print this tree using below method:
    # prints the tree
    def printTree(self, tree, colNames, indent=''):
        if tree.results != None:
            # print result
            print(indent + str(tree.results))
        else:
            # print node condition
            print(indent + str(colNames[tree.col]) + ':' + str(tree.value))
            # print true branch
            print(indent + 'T->')
            self.printTree(tree.trueBranch, colNames, indent + '  ')
            # print false branch
            print(indent + 'F->')
            self.printTree(tree.falseBranch, colNames, indent + '  ')
Let’s train and print the tree:
familyData = [
    [1, 30000, 'Unmarried', 1, 'None'],
    [2, 50000, 'Unmarried', 1, '2BHK'],
    [4, 70000, 'Married', 3, '3BHK'],
    [6, 90000, 'Married', 3, '3BHK'],
    [2, 55000, 'Married', 2, 'None'],
    [4, 55000, 'Married', 2, '3BHK'],
    [3, 60000, 'Married', 2, '2BHK'],
    [1, 35000, 'Unmarried', 1, '2BHK'],
    [1, 25000, 'Unmarried', 2, '1BHK'],
    [6, 95000, 'Married', 3, '4BHK'],
    [6, 85000, 'Married', 4, '4BHK'],
    [4, 50000, 'Married', 3, '3BHK'],
    [3, 50000, 'Unmarried', 3, '3BHK'],
    [4, 80000, 'Married', 3, '2BHK'],
    [6, 90000, 'Married', 2, '4BHK'],
    [4, 75000, 'Married', 3, '4BHK'],
    [2, 60000, 'Married', 2, '1BHK']
]

treeClass = decisionTree()
# train the tree
tree = treeClass.buildTree(familyData, treeClass.entropy)
# print the tree
treeClass.printTree(tree, ['Number of Members', 'Salary', 'Marital Status', 'Inquiries', 'Flat Type'])
The full tree is too big to print here; below is a sample of the printTree output:
Number of Members:3
T->
  Salary:75000
  T->
    Number of Members:6
    T->
      Salary:90000
      T->
        Salary:95000
        T->
          {'4BHK': 1}
        F->
          Inquiries:3
And finally, we will test the tree. We need a method to classify new users:
    # classify a new observation
    def classify(self, observation, tree):
        if tree.results != None:
            return tree.results
        else:
            v = observation[tree.col]
            branch = None
            if isinstance(v, int) or isinstance(v, float):
                if v >= tree.value:
                    branch = tree.trueBranch
                else:
                    branch = tree.falseBranch
            else:
                if v == tree.value:
                    branch = tree.trueBranch
                else:
                    branch = tree.falseBranch
            return self.classify(observation, branch)
Now let's classify some new users who inquire about a flat:
familyData = [
    [1, 30000, 'Unmarried', 1, 'None'],
    [2, 50000, 'Unmarried', 1, '2BHK'],
    [4, 70000, 'Married', 3, '3BHK'],
    [6, 90000, 'Married', 3, '3BHK'],
    [2, 55000, 'Married', 2, 'None'],
    [4, 55000, 'Married', 2, '3BHK'],
    [3, 60000, 'Married', 2, '2BHK'],
    [1, 35000, 'Unmarried', 1, '2BHK'],
    [1, 25000, 'Unmarried', 2, '1BHK'],
    [6, 95000, 'Married', 3, '4BHK'],
    [6, 85000, 'Married', 4, '4BHK'],
    [4, 50000, 'Married', 3, '3BHK'],
    [3, 50000, 'Unmarried', 3, '3BHK'],
    [4, 80000, 'Married', 3, '2BHK'],
    [6, 90000, 'Married', 2, '4BHK'],
    [4, 75000, 'Married', 3, '4BHK'],
    [2, 60000, 'Married', 2, '1BHK']
]

treeClass = decisionTree()
tree = treeClass.buildTree(familyData, treeClass.entropy)
print(treeClass.classify([2, 40000, 'Unmarried', 1], tree))
print(treeClass.classify([1, 10000, 'Unmarried', 1], tree))
print(treeClass.classify([6, 80000, 'Married', 1], tree))
Output:
{'2BHK': 2}
{'None': 1}
{'4BHK': 1}
Analysis
Decision trees are one of the traditional techniques in machine learning. The tree can grow large if the training data contains a lot of disorder; the number of decision nodes tends to grow with the number of distinct values across the columns. Training may take considerable time, but once trained, a reasonably balanced tree classifies a new item in roughly O(log n) time, since each observation follows a single root-to-leaf path.
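These two quantities, node count and depth, are easy to inspect. Below is a sketch assuming the decisionNode structure defined earlier (restated here so the snippet runs on its own); the helper names count_decision_nodes and depth, and the tiny hand-built tree, are invented for illustration:

```python
class decisionNode:  # minimal restatement of the node class from above
    def __init__(self, col=-1, value=None, results=None,
                 trueBranch=None, falseBranch=None):
        self.col = col
        self.value = value
        self.results = results
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch

def count_decision_nodes(node):
    """Number of internal (condition) nodes in the tree."""
    if node.results is not None:  # leaf node
        return 0
    return (1 + count_decision_nodes(node.trueBranch)
              + count_decision_nodes(node.falseBranch))

def depth(node):
    """Longest root-to-leaf path; classification cost is proportional to this."""
    if node.results is not None:
        return 0
    return 1 + max(depth(node.trueBranch), depth(node.falseBranch))

# A tiny hand-built tree: one root condition (Salary >= 55000) with two leaves.
leaf1 = decisionNode(results={'2BHK': 2})
leaf2 = decisionNode(results={'None': 1})
root = decisionNode(col=1, value=55000, trueBranch=leaf1, falseBranch=leaf2)
print(count_decision_nodes(root), depth(root))  # -> 1 1
```

Running the same helpers on a tree trained from the familyData table would show how the node count grows with the variety of values in the data.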