前言

用matplotlib.pyplot绘制决策树。


确定x轴y轴长度

即确定树的深度和叶子节点的数量:

def getNumberOfLeafs(tree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the shape ``{feature: {value: subtree_or_leaf, ...}}``.
    Only the first key of ``tree`` is inspected. Every entry beneath it
    that is a ``dict`` (or a ``dict`` subclass) is treated as a subtree and
    counted recursively; anything else counts as one leaf.

    Args:
        tree: Non-empty mapping with the shape described above.

    Returns:
        The number of leaves reachable from ``tree``.

    Raises:
        IndexError: If ``tree`` has no keys.
    """
    nextFeature = list(tree.keys())[0]
    valuesOfFeature = tree[nextFeature]
    # isinstance (instead of comparing the type's name to "dict") also
    # recognises dict subclasses such as OrderedDict as subtrees.
    return sum(
        getNumberOfLeafs(child) if isinstance(child, dict) else 1
        for child in valuesOfFeature.values()
    )


def getDeepthOfTree(tree):
    """Return the depth of a nested-dict decision tree.

    The tree has the shape ``{feature: {value: subtree_or_leaf, ...}}``.
    Only the first key of ``tree`` is inspected. A branch that ends in a
    leaf contributes depth 1; a branch that is a ``dict`` (or a ``dict``
    subclass) contributes 1 plus the depth of that subtree.

    Args:
        tree: Non-empty mapping with the shape described above.

    Returns:
        The largest branch depth, or 0 when the first feature has no
        branches.

    Raises:
        IndexError: If ``tree`` has no keys.
    """
    depth = 0
    nextFeature = list(tree.keys())[0]
    valuesOfFeature = tree[nextFeature]
    for child in valuesOfFeature.values():
        # isinstance (instead of comparing the type's name to "dict") also
        # recognises dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            thisDepth = 1 + getDeepthOfTree(child)
        else:
            thisDepth = 1
        depth = max(depth, thisDepth)
    return depth

绘制节点

def plotNode(ax1, nodeText, centerPt, parentPt, nodeType):
    """Draw one node as an annotation connected to its parent.

    Calls ``ax1.annotate`` with ``parentPt`` as the annotated point and
    ``centerPt`` as the text position. ``nodeType`` is passed through as the
    ``bbox`` style, and the module-level ``arrow_args`` as ``arrowprops``.

    Args:
        ax1: Object providing an ``annotate`` method (a matplotlib axes).
        nodeText: Label shown inside the node box.
        centerPt: ``(x, y)`` position of the node label.
        parentPt: ``(x, y)`` position of the parent node.
        nodeType: Box-style dict for the node.
    """
    labelStyle = {
        "va": "center",  # vertical alignment
        "ha": "center",  # horizontal alignment
        "bbox": nodeType,
        "arrowprops": arrow_args,
    }
    ax1.annotate(nodeText, xy=parentPt, xytext=centerPt, **labelStyle)

构造树

def plotMidTest(ax1, centerPt, parentPt, text):
    """Write ``text`` halfway between a node and its parent.

    Args:
        ax1: Object providing a ``text(x, y, s)`` method (a matplotlib axes).
        centerPt: ``(x, y)`` position of the child node.
        parentPt: ``(x, y)`` position of the parent node.
        text: Label to place at the midpoint of the two points.
    """
    # Same formula for both coordinates: start at the child and move half
    # of the way towards the parent.
    midpoint = [
        (parentPt[axis] - centerPt[axis]) / 2 + centerPt[axis]
        for axis in (0, 1)
    ]
    ax1.text(midpoint[0], midpoint[1], text)


def plotTree(ax1, treeToPlot, myTree, parentPt, nodeText):
    """Recursively draw ``myTree`` and everything below it onto ``ax1``.

    The decision node for the first key of ``myTree`` is drawn with
    ``plotNode``, and ``nodeText`` is written on the edge to ``parentPt``
    with ``plotMidTest``. Each branch is then either drawn recursively (when
    it is a ``dict`` or a ``dict`` subclass) or drawn as a leaf.

    Args:
        ax1: Axes object handed through to ``plotNode`` / ``plotMidTest``.
        treeToPlot: Mutable layout state with the keys ``"xOff"``,
            ``"yOff"``, ``"totalW"`` and ``"totalD"``. ``"xOff"`` and
            ``"yOff"`` are updated in place while drawing; ``"yOff"`` is
            moved back up by one level step before returning.
        myTree: Tree of the shape ``{feature: {value: subtree_or_leaf}}``.
        parentPt: ``(x, y)`` position of the parent node.
        nodeText: Label for the edge between the parent and this node.
    """
    numberOfLeafs = getNumberOfLeafs(myTree)
    nextFeature = list(myTree.keys())[0]
    # Place this node horizontally above the middle of the leaves it spans.
    centerPt = (treeToPlot["xOff"] + (1.0 + float(numberOfLeafs)) / 2.0 / treeToPlot["totalW"], treeToPlot["yOff"])
    plotMidTest(ax1, centerPt, parentPt, nodeText)
    plotNode(ax1, nextFeature, centerPt, parentPt, decisionNode)
    valuesOfFeature = myTree[nextFeature]
    treeToPlot["yOff"] = treeToPlot["yOff"] - 1.0 / treeToPlot["totalD"]  # step down one level
    for value, child in valuesOfFeature.items():
        # isinstance (instead of comparing the type's name to "dict") also
        # treats dict subclasses such as OrderedDict as subtrees.
        if isinstance(child, dict):
            plotTree(ax1, treeToPlot, child, centerPt, value)
        else:
            treeToPlot["xOff"] = treeToPlot["xOff"] + 1.0 / treeToPlot["totalW"]  # advance to the next leaf slot
            plotNode(ax1, child, (treeToPlot["xOff"], treeToPlot["yOff"]), centerPt, leafNode)
            plotMidTest(ax1, (treeToPlot["xOff"], treeToPlot["yOff"]), centerPt, value)
    treeToPlot["yOff"] = treeToPlot["yOff"] + 1.0 / treeToPlot["totalD"]  # step back up one level


def createPlot(myTree):
    """Open a figure named "myTree", draw the whole tree in it and show it.

    The layout state passed to ``plotTree`` is initialised from the tree's
    leaf count (``"totalW"``) and depth (``"totalD"``), with the root drawn
    relative to the point ``(0.5, 1.0)``.

    Args:
        myTree: Tree of the shape ``{feature: {value: subtree_or_leaf}}``.
    """
    canvas = plt.figure("myTree", facecolor="white")
    canvas.clf()
    # Empty tick lists remove the x and y axis ticks.
    ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leafCount = float(getNumberOfLeafs(myTree))
    layout = {
        "totalW": leafCount,
        "totalD": float(getDeepthOfTree(myTree)),
        # Start half a leaf slot to the left; plotTree advances xOff by one
        # slot before drawing each leaf.
        "xOff": -0.5 / leafCount,
        "yOff": 1.0,
    }
    plotTree(ax1, layout, myTree, (0.5, 1.0), "")
    plt.show()

拼起来:

# Box styles for the two kinds of node: boxstyle = border shape, fc = fill shade.
decisionNode = {"boxstyle": "circle", "fc": "0.8"}
leafNode = {"boxstyle": "round", "fc": "0.8"}
# Style of the connecting line between a node and its parent.
arrow_args = {"arrowstyle": "<-"}
# Load the data set, build the tree, then draw it.
myDataSet = file2DataSet("E:/dataSet/tree/lenses.data")
myTree = createTree(myDataSet, features, valueForFeatures, labels)
createPlot(myTree)

最后,连线上的标签有点重合,把标签斜放或者调大节点之间的间距可能会好一点。


后记

hard。


机器学习

本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!

h2 jdbc 攻击
机器学习入门3