"""
Created on Oct 14, 2010

@author: Peter Harrington
"""
import matplotlib.pyplot as plt
# Matplotlib style dictionaries used by plotNode: decisionNode and leafNode
# are passed as the annotation ``bbox`` (the box drawn around an internal /
# leaf node label), and arrow_args is passed as its ``arrowprops``.
# ``fc`` is matplotlib's facecolor alias; "0.8" is a grayscale level.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree is a dict whose first key is the label of the root node; the
    value under that key maps branch values to either a sub-tree (another
    dict) or a leaf value (anything else).

    Args:
        myTree: nested-dict tree as described above.

    Returns:
        The number of non-dict values reachable from the root.
    """
    numLeafs = 0
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] fails on 3
    # because keys() returns a non-subscriptable view there.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        # A dict (or dict subclass) child is a sub-tree; anything else is a leaf.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    The tree is a dict whose first key is the label of the root node; the
    value under that key maps branch values to either a sub-tree (another
    dict) or a leaf value (anything else). Each branch taken counts as one
    level, so a root whose children are all leaves has depth 1.

    Args:
        myTree: nested-dict tree as described above.

    Returns:
        The maximum number of branches on any root-to-leaf path (0 if the
        root has no branches).
    """
    maxDepth = 0
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] fails on 3.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        # A dict (or dict subclass) child is a sub-tree; anything else is a leaf.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node label with an arrow linking it to its parent.

    Args:
        nodeTxt: text shown inside the node box.
        centerPt: (x, y) position of the node box, in axes-fraction coords.
        parentPt: (x, y) point the arrow connects to, in axes-fraction coords.
        nodeType: bbox style dict for the label (e.g. decisionNode/leafNode).

    Draws on ``createPlot.ax1``, a function attribute that must already have
    been set (assumed to be done by createPlot before any drawing).
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Write a branch label halfway between a node and its parent.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: label text (the branch value) to draw.

    Draws on ``createPlot.ax1`` (assumed already initialised by createPlot).
    """
    # Midpoint of the segment parentPt -> cntrPt.
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree.

    The first key of ``myTree`` is the feature the node splits on; its value
    maps branch values to sub-trees (dicts) or leaf values.

    Layout state is kept in function attributes that must be initialised
    before the first call (presumably by createPlot — not visible here):
    ``plotTree.totalW`` / ``plotTree.totalD`` (total leaves / depth of the
    whole tree) and ``plotTree.xOff`` / ``plotTree.yOff`` (current cursor).
    This function advances ``xOff`` once per leaf. It moves ``yOff`` down
    one level while drawing the children and restores it afterwards.

    Args:
        myTree: (sub-)tree to draw.
        parentPt: (x, y) position of the parent node.
        nodeTxt: branch label drawn between the parent and this node.
    """
    # The number of leaves determines how much horizontal space this subtree
    # occupies.
    numLeafs = getNumLeafs(myTree)
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] fails on 3.
    firstStr = next(iter(myTree))
    # Center this node horizontally above its own leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level for the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Sub-tree: recurse, labelling the branch with its value.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance the x cursor by one leaf slot and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor so siblings of this subtree stay on this level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# If you get a dictionary, you know it's a tree, and the first element will be another dict.
def createPlot(inTree):
fig = plt.figure(1 facecolor=‘white‘)
axprops = dict(xticks=[] yticks=[])
createPlot.ax1 = plt.subplot(111 frameon=False **axprops) #no ti
