前言
用matplotlib.pyplot绘制决策树。
确定x轴y轴长度
即确定叶子节点的数量(决定x轴方向的宽度)和树的深度(决定y轴方向的高度):
def getNumberOfLeafs(tree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the form ``{feature: {value: subtree_or_leaf, ...}}``: the
    first key of ``tree`` is the feature tested at the root, and the dict
    stored under it maps each feature value to either another tree (a dict)
    or a leaf (anything that is not a dict).

    Args:
        tree: Non-empty dict whose first key maps to a dict of branches.

    Returns:
        The number of non-dict entries reachable through the branches.

    Raises:
        IndexError: If ``tree`` is empty.
    """
    nextFeature = list(tree)[0]
    numberOfLeafs = 0
    for child in tree[nextFeature].values():
        # isinstance (instead of comparing the type's name with "dict") also
        # recognises dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            numberOfLeafs += getNumberOfLeafs(child)
        else:
            numberOfLeafs += 1
    return numberOfLeafs
def getDeepthOfTree(tree):
    """Return the depth of a nested-dict decision tree.

    The tree has the form ``{feature: {value: subtree_or_leaf, ...}}``.  A
    tree whose branches are all leaves has depth 1; every nested subtree on
    the longest path adds one more level.

    Args:
        tree: Non-empty dict whose first key maps to a dict of branches.

    Returns:
        The length of the longest chain of decision nodes, or 0 when the
        root feature has no branches.

    Raises:
        IndexError: If ``tree`` is empty.
    """
    nextFeature = list(tree)[0]
    depth = 0
    for child in tree[nextFeature].values():
        # isinstance (instead of comparing the type's name with "dict") also
        # recognises dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            thisDepth = 1 + getDeepthOfTree(child)
        else:
            thisDepth = 1
        depth = max(depth, thisDepth)
    return depth
绘制节点
def plotNode(ax1, nodeText, centerPt, parentPt, nodeType):
    """Annotate one tree node on ``ax1``.

    Args:
        ax1: Axes-like object providing ``annotate``.
        nodeText: Text shown for the node.
        centerPt: (x, y) position of the node text, passed as ``xytext``.
        parentPt: (x, y) position of the parent node, passed as ``xy``.
        nodeType: Box style dict passed as ``bbox``.

    The connecting arrow is styled by the module-level ``arrow_args``.
    """
    ax1.annotate(
        nodeText,
        xy=parentPt,
        xytext=centerPt,
        # va = vertical alignment, ha = horizontal alignment: both centred.
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
绘制树
def plotMidTest(ax1, centerPt, parentPt, text, **textProps):
    """Write ``text`` at the midpoint of the segment centerPt-parentPt.

    Args:
        ax1: Axes-like object providing ``text(x, y, s, **kwargs)``.
        centerPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        text: Label to place on the connecting line.
        **textProps: Optional keyword arguments forwarded unchanged to
            ``ax1.text`` (for example ``rotation=30`` or ``ha="center"``),
            so callers can tilt or re-align labels that overlap.  With no
            extra arguments the call is the same as before.
    """
    xMid = (parentPt[0] - centerPt[0]) / 2 + centerPt[0]
    yMid = (parentPt[1] - centerPt[1]) / 2 + centerPt[1]
    ax1.text(xMid, yMid, text, **textProps)
def plotTree(ax1, treeToPlot, myTree, parentPt, nodeText):
    """Recursively draw ``myTree`` and the label of the edge leading to it.

    Args:
        ax1: Axes-like object the nodes and labels are drawn on.
        treeToPlot: Mutable drawing state.  Reads ``"totalW"`` and
            ``"totalD"``; reads and updates the cursor ``"xOff"`` /
            ``"yOff"``.  ``"yOff"`` is restored before returning, while
            ``"xOff"`` has advanced by one step per leaf drawn.
        myTree: Nested-dict tree ``{feature: {value: subtree_or_leaf}}``.
        parentPt: (x, y) position of the parent node.
        nodeText: Label for the edge from the parent to this subtree.
    """
    numberOfLeafs = getNumberOfLeafs(myTree)
    nextFeature = list(myTree)[0]
    # Place this decision node horizontally in the middle of the span that
    # its own leaves will occupy, at the current level.
    centerPt = (
        treeToPlot["xOff"] + (1.0 + float(numberOfLeafs)) / 2.0 / treeToPlot["totalW"],
        treeToPlot["yOff"],
    )
    plotMidTest(ax1, centerPt, parentPt, nodeText)
    plotNode(ax1, nextFeature, centerPt, parentPt, decisionNode)

    # Go down one level before drawing the children.
    treeToPlot["yOff"] = treeToPlot["yOff"] - 1.0 / treeToPlot["totalD"]
    for value, child in myTree[nextFeature].items():
        # isinstance (instead of comparing the type's name with "dict") also
        # treats dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            plotTree(ax1, treeToPlot, child, centerPt, value)
        else:
            # Move the cursor right to the slot of the next leaf.
            treeToPlot["xOff"] = treeToPlot["xOff"] + 1.0 / treeToPlot["totalW"]
            leafPt = (treeToPlot["xOff"], treeToPlot["yOff"])
            plotNode(ax1, child, leafPt, centerPt, leafNode)
            plotMidTest(ax1, leafPt, centerPt, value)
    # Go back up to the level of this node.
    treeToPlot["yOff"] = treeToPlot["yOff"] + 1.0 / treeToPlot["totalD"]
def createPlot(myTree):
    """Draw ``myTree`` in a matplotlib figure named "myTree" and show it.

    Args:
        myTree: Nested-dict tree ``{feature: {value: subtree_or_leaf}}``.
    """
    figure = plt.figure("myTree", facecolor="white")
    figure.clf()
    # Empty tick lists hide the x and y axis ticks.
    ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    totalW = float(getNumberOfLeafs(myTree))
    totalD = float(getDeepthOfTree(myTree))
    treeToPlot = {
        "totalW": totalW,
        "totalD": totalD,
        # Start half a leaf step left of 0; plotTree moves right by one
        # full step before it draws each leaf.
        "xOff": -0.5 / totalW,
        "yOff": 1.0,
    }
    plotTree(ax1, treeToPlot, myTree, (0.5, 1.0), "")
    plt.show()
拼起来:
# Box styles for the two kinds of node: boxstyle is the border shape and
# fc the fill shade.
decisionNode = {"boxstyle": "circle", "fc": "0.8"}
leafNode = {"boxstyle": "round", "fc": "0.8"}
# Style of the line connecting a node to its parent.
arrow_args = {"arrowstyle": "<-"}

myDataSet = file2DataSet("E:/dataSet/tree/lenses.data")
myTree = createTree(myDataSet, features, valueForFeatures, labels)
createPlot(myTree)
最后,连线上的标签有些重叠,把标签斜放或者调大节点间距可能会好一点。
后记
hard。