1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
def plotMidTest(ax1, centerPt, parentPt, text):
    """Write ``text`` halfway between ``centerPt`` and ``parentPt``.

    Args:
        ax1: Object exposing ``text(x, y, s)``; presumably a matplotlib
            Axes -- confirm against the caller.
        centerPt: Indexable ``(x, y)`` position of the child node.
        parentPt: Indexable ``(x, y)`` position of the parent node.
        text: Label handed unchanged to ``ax1.text``.
    """
    childX, childY = centerPt[0], centerPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    # Midpoint expressed as "child + half the distance to the parent".
    ax1.text(
        (parentX - childX) / 2 + childX,
        (parentY - childY) / 2 + childY,
        text,
    )
def plotTree(ax1, treeToPlot, myTree, parentPt, nodeText):
    """Recursively draw ``myTree`` and all of its branches onto ``ax1``.

    The first key of ``myTree`` is drawn as a decision node; each entry of
    the mapping stored under that key is drawn either as a nested subtree
    (when it is a dict) or as a leaf node.

    Args:
        ax1: Drawing surface forwarded to ``plotNode`` and ``plotMidTest``.
        treeToPlot: Mutable layout state with keys ``"xOff"``, ``"yOff"``,
            ``"totalW"`` and ``"totalD"``. ``"xOff"`` is advanced by
            ``1 / totalW`` for every leaf drawn; ``"yOff"`` is lowered by
            ``1 / totalD`` while the children are drawn and raised again
            before returning.
        myTree: Dict whose first key is the node label and whose value is a
            mapping ``{branch value: subtree dict or leaf label}``.
        parentPt: ``(x, y)`` position of the parent node.
        nodeText: Label written midway between this node and ``parentPt``.

    Raises:
        IndexError: If ``myTree`` is empty (unless ``getNumberOfLeafs``
            fails first).
    """
    numberOfLeafs = getNumberOfLeafs(myTree)
    nextFeature = list(myTree)[0]
    # x position: current xOff plus half of (1 + leaf count) leaf-widths,
    # where one leaf-width is 1 / totalW.
    centerPt = (
        treeToPlot["xOff"]
        + (1.0 + float(numberOfLeafs)) / 2.0 / treeToPlot["totalW"],
        treeToPlot["yOff"],
    )
    plotMidTest(ax1, centerPt, parentPt, nodeText)
    plotNode(ax1, nextFeature, centerPt, parentPt, decisionNode)

    # totalD is never modified here, so the per-level step is loop-invariant.
    levelStep = 1.0 / treeToPlot["totalD"]
    treeToPlot["yOff"] = treeToPlot["yOff"] - levelStep
    for value, child in myTree[nextFeature].items():
        # isinstance (rather than comparing the type's name to "dict") also
        # recognises dict subclasses such as OrderedDict as subtrees.
        # NOTE(review): getNumberOfLeafs is not visible here -- confirm it
        # classifies subtrees the same way so the layout stays consistent.
        if isinstance(child, dict):
            plotTree(ax1, treeToPlot, child, centerPt, value)
        else:
            treeToPlot["xOff"] = treeToPlot["xOff"] + 1.0 / treeToPlot["totalW"]
            leafPt = (treeToPlot["xOff"], treeToPlot["yOff"])
            plotNode(ax1, child, leafPt, centerPt, leafNode)
            plotMidTest(ax1, leafPt, centerPt, value)
    # Step back up so siblings of this node are drawn at the caller's level.
    treeToPlot["yOff"] = treeToPlot["yOff"] + levelStep
def createPlot(myTree):
    """Draw ``myTree`` in a figure named ``"myTree"`` and show it.

    Clears the figure, creates one frameless subplot without ticks, seeds
    the layout state consumed by ``plotTree`` and finally calls
    ``plt.show()``.

    Args:
        myTree: Tree structure accepted by ``getNumberOfLeafs``,
            ``getDeepthOfTree`` and ``plotTree``.
    """
    # NOTE(review): plt is presumably matplotlib.pyplot -- confirm import.
    canvas = plt.figure("myTree", facecolor="white")
    canvas.clf()
    ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    leafTotal = float(getNumberOfLeafs(myTree))
    depthTotal = float(getDeepthOfTree(myTree))
    treeToPlot = {
        "totalW": leafTotal,
        "totalD": depthTotal,
        # Start half a leaf-width left of 0; plotTree adds a full
        # leaf-width (1 / totalW) before drawing each leaf.
        "xOff": -0.5 / leafTotal,
        "yOff": 1.0,
    }
    plotTree(ax1, treeToPlot, myTree, (0.5, 1.0), "")
    plt.show()
|