资源简介
treePlotter 模块是由一系列函数组成的自定义模块。本资源包含该模块的具体实现代码,用于实现树与分类结构(如决策树)的可视化。
代码片段和文件信息
# _*_ coding: UTF-8 _*_
import matplotlib.pyplot as plt
# Functions and style constants for drawing a decision tree.
# Each dict below is passed straight to matplotlib (see plotNode).
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style for decision (branch) nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # box style for leaf nodes
arrow_args = dict(arrowstyle="<-")  # arrow style for the edges between nodes
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict tree.

    The tree has the form ``{root_label: {edge_label: child, ...}}`` where
    each ``child`` is either another tree of the same form (a dict) or a
    leaf value (anything that is not a dict). Only the first key of
    ``myTree`` is treated as the root.

    Args:
        myTree: The nested-dict tree to inspect.

    Returns:
        The number of leaves; 0 when ``myTree`` is empty.
    """
    if not myTree:
        # An empty tree has no root to descend into, hence no leaves.
        return 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            # A dict child is a subtree: add all of its leaves.
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the maximum depth of a nested-dict tree.

    The tree has the form ``{root_label: {edge_label: child, ...}}`` where
    each ``child`` is either another tree of the same form (a dict) or a
    leaf value. A root whose children are all leaves has depth 1; every
    nested subtree along a path adds one more level.

    Args:
        myTree: The nested-dict tree to inspect.

    Returns:
        The depth of the deepest path; 0 when ``myTree`` is empty or the
        root has no children.
    """
    if not myTree:
        return 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            # One level for this branch plus whatever lies beneath it.
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node with an arrow linking it to its parent.

    Draws on ``createPlot.ax1``, so ``createPlot`` must have created the
    axes before this is called.

    Args:
        nodeTxt: Text shown inside the node's box.
        centerPt: ``(x, y)`` position of the node, in axes-fraction coordinates.
        parentPt: ``(x, y)`` position of the parent node, in axes-fraction
            coordinates.
        nodeType: Box style dict passed as ``bbox`` (e.g. ``decisionNode``
            or ``leafNode``).
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def plotMidText(cntrPt, parentPt, txtString):
    """Write an edge label near the midpoint between a node and its parent.

    Draws on ``createPlot.ax1``, so ``createPlot`` must have created the
    axes before this is called.

    Args:
        cntrPt: ``(x, y)`` position of the child node.
        parentPt: ``(x, y)`` position of the parent node.
        txtString: Label text to place on the edge.
    """
    # Shift the label left by a small amount per character, presumably so
    # longer labels stay visually centred on the edge.
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - len(txtString) * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    # NOTE(review): no transform is passed here, whereas plotNode positions
    # nodes in 'axes fraction' coordinates - confirm the two coordinate
    # systems coincide for the axes createPlot sets up.
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` beneath the node located at ``parentPt``.

    Layout state is kept in function attributes that ``createPlot`` must
    initialise before the first call:

    * ``plotTree.totalW`` - total number of leaves in the whole tree
    * ``plotTree.totalD`` - total depth of the whole tree
    * ``plotTree.x0ff``   - x position of the most recently drawn leaf
    * ``plotTree.y0ff``   - y position of the level currently being drawn

    Args:
        myTree: Nested-dict tree of the form
            ``{root_label: {edge_label: child, ...}}``.
        parentPt: ``(x, y)`` position of the parent node.
        nodeTxt: Label written on the edge from the parent to this subtree.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    # Centre this node horizontally over the span its own leaves will occupy.
    cntrPt = (
        plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.y0ff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level before drawing the children.
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            leafPt = (plotTree.x0ff, plotTree.y0ff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the level so the caller's remaining siblings are drawn at
    # the correct height.
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the decision tree ``inTree`` in matplotlib figure 1 and show it.

    Creates a frameless, tick-free axes stored as ``createPlot.ax1`` (used
    by ``plotNode`` and ``plotMidText``) and initialises the layout state
    on ``plotTree`` before starting the recursive drawing.

    Args:
        inTree: Nested-dict tree of the form
            ``{root_label: {edge_label: child, ...}}``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf-slot to the left of the axes: plotTree advances
    # x0ff by one slot before drawing each leaf, so the first leaf ends up
    # at 0.5 / totalW.
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
if __name__ == '__main__':
    # createPlot requires a tree argument, so draw a small sample decision
    # tree when the module is run directly.
    sampleTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    createPlot(sampleTree)
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 3023 2018-10-18 20:59 treePlotter\__init__.py
文件 3253 2018-10-18 21:10 treePlotter\__init__.pyc
目录 0 2018-10-18 21:10 treePlotter
----------- --------- ---------- ----- ----
6276 3
评论
共有 条评论