机器学习入门4

前言

用matplotlib.pyplot绘制决策树。


确定x轴y轴长度

即确定叶子节点的数量（决定x轴方向的宽度）和树的深度（决定y轴方向的高度）:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def getNumberOfLeafs(tree):
    """Count the leaf labels of a nested-dict decision tree.

    The tree is expected to look like ``{feature: {value: subtree_or_label}}``:
    the first key names the feature tested at this node, and its value maps
    each feature value either to another tree of the same shape or to a leaf
    label.

    Args:
        tree: A non-empty dict in the shape above.

    Returns:
        The number of leaf labels reachable from ``tree``; the plot uses this
        as its horizontal extent (x-axis width).

    Raises:
        IndexError: If ``tree`` is empty.
    """
    root_feature = list(tree)[0]
    branches = tree[root_feature]
    count = 0
    for child in branches.values():
        # isinstance (rather than comparing the type's name to "dict") also
        # recognises dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            count += getNumberOfLeafs(child)
        else:
            count += 1
    return count


def getDeepthOfTree(tree):
    """Return the depth of a nested-dict decision tree.

    The tree is expected to look like ``{feature: {value: subtree_or_label}}``.
    Each decision level counts as 1. A root whose branches all end in leaf
    labels has depth 1, and a root with no branches has depth 0.

    Args:
        tree: A non-empty dict in the shape above.

    Returns:
        The number of decision levels on the longest root-to-leaf path; the
        plot uses this as its vertical extent (y-axis height).

    Raises:
        IndexError: If ``tree`` is empty.
    """
    root_feature = list(tree)[0]
    branches = tree[root_feature]
    depth = 0
    for child in branches.values():
        # isinstance also treats dict subclasses (e.g. OrderedDict) as
        # subtrees, unlike a comparison against the type's name.
        branch_depth = 1 + getDeepthOfTree(child) if isinstance(child, dict) else 1
        depth = max(depth, branch_depth)
    return depth

绘制节点

1
2
def plotNode(ax1, nodeText, centerPt, parentPt, nodeType):
    """Draw one tree node at ``centerPt``, connected by an arrow to ``parentPt``.

    Args:
        ax1: The matplotlib Axes to draw on.
        nodeText: Text shown inside the node box.
        centerPt: (x, y) where the node text is placed.
        parentPt: (x, y) of the parent node; the arrow connects it to this node.
        nodeType: bbox style dict for the node box (the callers in this file
            pass ``decisionNode`` or ``leafNode``).

    The arrow style is taken from the module-level ``arrow_args``.
    """
    annotate_kwargs = dict(
        xy=parentPt,
        xytext=centerPt,
        va="center",  # centre the text vertically on centerPt
        ha="center",  # centre the text horizontally on centerPt
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    ax1.annotate(nodeText, **annotate_kwargs)

构造树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def plotMidTest(ax1, centerPt, parentPt, text):
    """Write ``text`` at the midpoint between ``centerPt`` and ``parentPt``.

    In this file it labels the edge between a node and its parent with the
    feature value that leads to that node.
    """
    cx, cy = centerPt[0], centerPt[1]
    px, py = parentPt[0], parentPt[1]
    ax1.text(cx + (px - cx) / 2, cy + (py - cy) / 2, text)


def plotTree(ax1, treeToPlot, myTree, parentPt, nodeText):
    """Recursively draw ``myTree`` onto ``ax1``.

    Args:
        ax1: The matplotlib Axes to draw on.
        treeToPlot: Mutable layout state shared across the recursion, with keys
            ``"totalW"`` (leaf count of the whole tree), ``"totalD"`` (depth of
            the whole tree), ``"xOff"`` (x of the most recently placed leaf
            slot) and ``"yOff"`` (y of the level currently being drawn).
            ``xOff`` is advanced in place. ``yOff`` is lowered while children
            are drawn and restored before returning.
        myTree: Subtree to draw, shaped ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) of the parent node; the incoming arrow starts there.
        nodeText: Label for the edge from the parent to this node.
    """
    numberOfLeafs = getNumberOfLeafs(myTree)
    nextFeature = list(myTree)[0]
    # This subtree's leaves will occupy slots xOff + k / totalW for
    # k = 1..numberOfLeafs. Centre the decision node above that span.
    centerPt = (
        treeToPlot["xOff"] + (1.0 + float(numberOfLeafs)) / 2.0 / treeToPlot["totalW"],
        treeToPlot["yOff"],
    )
    plotMidTest(ax1, centerPt, parentPt, nodeText)
    plotNode(ax1, nextFeature, centerPt, parentPt, decisionNode)
    valuesOfFeature = myTree[nextFeature]
    treeToPlot["yOff"] -= 1.0 / treeToPlot["totalD"]  # descend to the children's level
    for value, child in valuesOfFeature.items():
        # isinstance keeps this subtree test consistent with dict subclasses.
        if isinstance(child, dict):
            plotTree(ax1, treeToPlot, child, centerPt, value)
        else:
            treeToPlot["xOff"] += 1.0 / treeToPlot["totalW"]  # move to the next leaf slot
            leafPt = (treeToPlot["xOff"], treeToPlot["yOff"])
            plotNode(ax1, child, leafPt, centerPt, leafNode)
            plotMidTest(ax1, leafPt, centerPt, value)
    treeToPlot["yOff"] += 1.0 / treeToPlot["totalD"]  # climb back to this node's level


def createPlot(myTree):
    """Open a figure named "myTree", draw ``myTree`` on it and show it.

    The layout is sized from the tree's leaf count (x extent) and depth
    (y extent). Drawing starts from the top centre (0.5, 1.0).
    """
    fig = plt.figure("myTree", facecolor="white")
    fig.clf()
    # Frameless subplot with no tick marks on either axis.
    no_ticks = {"xticks": [], "yticks": []}
    ax = plt.subplot(111, frameon=False, **no_ticks)
    width = float(getNumberOfLeafs(myTree))
    layout = {
        "totalW": width,
        "totalD": float(getDeepthOfTree(myTree)),
        # Start half a leaf slot left of 0, so the first leaf lands at 0.5 / totalW.
        "xOff": -0.5 / width,
        "yOff": 1.0,
    }
    plotTree(ax, layout, myTree, (0.5, 1.0), "")
    plt.show()

拼起来:

1
2
3
4
5
6
# Box style for decision (internal) nodes: boxstyle = outline shape,
# fc = face colour given as a grey-level string.
decisionNode = dict(boxstyle="circle", fc="0.8")
# Box style for leaf nodes.
leafNode = dict(boxstyle="round", fc="0.8")
# Style of the arrow drawn between a node and its parent.
arrow_args = dict(arrowstyle="<-")

if __name__ == "__main__":
    # Load data, build the tree and plot it only when run as a script, so
    # importing this module has no file-I/O or GUI side effects.
    # NOTE(review): file2DataSet, createTree, features, valueForFeatures and
    # labels are not defined in this snippet; presumably they come from an
    # earlier part of the series. Confirm they are in scope.
    myDataSet = file2DataSet("E:/dataSet/tree/lenses.data")
    myTree = createTree(myDataSet, features, valueForFeatures, labels)
    createPlot(myTree)

最后，连线上的标签有些重叠，可以考虑把标签斜放，或者适当调大节点之间的间距。


后记

hard。


机器学习入门4
http://yoursite.com/2021/12/10/机器学习入门4/
作者
Aluvion
发布于
2021年12月10日
许可协议