Preface

Continuing on from the previous post.


matplotlib

conda could not install it, so I downloaded the wheel manually and installed it with pip:

pip install E:\dataSet\matplotlib-3.5.0-cp39-cp39-win_amd64.whl
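
A quick check that the wheel went into the active environment (a minimal sketch, assuming a standard Python setup):

import matplotlib
print(matplotlib.__version__)   # should print 3.5.0 for the wheel above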

Preparing the dataset

Here I use the contact lenses dataset referenced in the documentation.

Reading the dataset

The dataset looks like this:

1  1  1  1  1  3
2  1  1  1  2  2
3  1  1  2  1  3
4  1  1  2  2  1
5  1  2  1  1  3
6  1  2  1  2  2
7  1  2  2  1  3
8  1  2  2  2  1
9  2  1  1  1  3
10 2  1  1  2  2
11 2  1  2  1  3
12 2  1  2  2  1
13 2  2  1  1  3
14 2  2  1  2  2
15 2  2  2  1  3
16 2  2  2  2  3
17 3  1  1  1  3
18 3  1  1  2  3
19 3  1  2  1  3
20 3  1  2  2  1
21 3  2  1  1  3
22 3  2  1  2  2
23 3  2  2  1  3
24 3  2  2  2  3

According to the description:

7. Attribute Information:
    -- 3 Classes
     1 : the patient should be fitted with hard contact lenses,
     2 : the patient should be fitted with soft contact lenses,
     3 : the patient should not be fitted with contact lenses.

    1. age of the patient: (1) young, (2) pre-presbyopic, (3) presbyopic
    2. spectacle prescription:  (1) myope, (2) hypermetrope
    3. astigmatic:     (1) no, (2) yes
    4. tear production rate:  (1) reduced, (2) normal

The first column is a row index, the last column is the label, and the four columns in between are the feature values. Read it in:

def file2DataSet(filename):
    myDataSet = []
    with open(filename) as f:
        for line in f.readlines():
            # split on any run of whitespace, which handles the double spaces in the file
            listFromLine = line.strip().split()
            # drop the row index in the first column, keep the four features plus the label
            myDataSet.append(listFromLine[1:])
    return myDataSet
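
A quick sanity check of the loader, using the same path as the final script below (adjust it to wherever lenses.data lives):

myDataSet = file2DataSet("E:/dataSet/tree/lenses.data")
print(len(myDataSet))   # 24 rows
print(myDataSet[0])     # ['1', '1', '1', '1', '3']: four features plus the label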

Splitting the dataset

To build the decision tree we have to decide which feature to test at each level, i.e. what to check first. That comes down to computing the entropy of the dataset itself as well as the entropy of the dataset after splitting on each candidate feature, and then picking the feature with the largest information gain.
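
Written out, the two quantities involved are the usual definitions (stated here only for reference):

H(D) = -\sum_{k} p_k \log_2 p_k

\mathrm{Gain}(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v)

where p_k is the fraction of samples in D carrying label k, and D_v is the subset of D in which feature A takes the value v.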

def splitDataSet(dataSet, axis, value):
    # axis: which feature column to split on; value: keep only the rows where that column equals value
    returnDataSet = []
    for row in dataSet:
        if row[axis] == value:
            # remove the column that was split on
            reduceRow = row[:axis]
            reduceRow.extend(row[axis + 1:])
            returnDataSet.append(reduceRow)
    return returnDataSet
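
For example, keeping only the patients whose tear production rate (column 3 of the stripped rows) is "2", i.e. normal; a small illustration assuming myDataSet was loaded as above:

normalTear = splitDataSet(myDataSet, 3, "2")
print(len(normalTear))   # 12 of the 24 rows have a normal tear production rate
print(normalTear[0])     # ['1', '1', '1', '2']: the tear-rate column is gone, the label stays at the end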

Computing entropy

Shannon entropy:

from math import log2

def calcShannonEnt(dataSet):
    numberOfRows = len(dataSet)
    labelCounts = {}
    # count how many rows carry each label (the last column)
    for row in dataSet:
        currentLabel = row[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonCount = 0
    for label in labelCounts:
        prob = float(labelCounts[label]) / numberOfRows
        shannonCount -= prob * log2(prob)
    return shannonCount
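
On the full lenses dataset the labels split 4 / 5 / 15 (hard / soft / no lenses), so the entropy should come out at roughly 1.33 bits:

print(calcShannonEnt(myDataSet))   # about 1.326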

Choosing the feature to split on

Compute the information gain of each feature (the Shannon entropy before the split minus the weighted entropy after it) and take the largest one:

def chooseBestFeatureToSplit(dataSet):
    # the row index column was already dropped when the file was read,
    # so the features occupy columns 0 .. numberOfFeatures-1 and the label is the last column
    numberOfFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for x in range(numberOfFeatures):
        featureList = [row[x] for row in dataSet]
        uniqueValue = set(featureList)
        newEntropy = 0
        for value in uniqueValue:
            subDataSet = splitDataSet(dataSet, x, value)
            prob = float(len(subDataSet)) / len(dataSet)
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = x
    return bestFeature
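
Run on the full dataset this should return 3, the tear production rate column, which matches the root of the tree printed at the end of this post:

print(chooseBestFeatureToSplit(myDataSet))   # 3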

Building the decision tree

The first thing to settle is how leaf nodes are built. There are two cases. In the first, every sample left in the dataset carries the same label, so that label is returned directly.

In the second, there are no features left to split on and only the label column remains, which means samples with identical feature values carry different labels. Here we count the occurrences and simply return the label that appears most often:

def majorityCnt(labels):
    # count each label and return the one that appears most often
    labelsCount = {}
    for label in labels:
        if label not in labelsCount.keys():
            labelsCount[label] = 0
        labelsCount[label] += 1
    sortedLabelsCount = sorted(labelsCount.items(), key=lambda item: item[1], reverse=True)
    return sortedLabelsCount[0][0]
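
Since the labels were read from the file as strings, majorityCnt also returns a string, which is why createTree below converts it with int() before indexing into the label names:

print(majorityCnt(["3", "2", "3", "1"]))   # '3'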

Then build the tree itself. Note that when copying a list you should not write:

subLabels = labels

because that only creates a second reference to the same list rather than a copy; instead write:

subLabels = labels[:]
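
A quick illustration of the difference:

a = [1, 2]
b = a        # b is just another name for the same list
c = a[:]     # c is a new list with the same elements
b.append(3)
print(a)     # [1, 2, 3]: modified through b
print(c)     # [1, 2]: the slice copy is unaffected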

The decision tree:

def createTree(dataSet, features, valueForFeatures, labels):
    classList = [row[-1] for row in dataSet]
    if classList.count(classList[0]) == len(classList): # every remaining sample has the same label
        return labels[int(classList[0]) - 1]
    if len(dataSet[0]) == 1: # only the label column is left, no features to split on
        return labels[int(majorityCnt(classList)) - 1]
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = features[bestFeature]
    myTree = {bestFeatureLabel: {}}
    realValueForThisFeature = valueForFeatures[bestFeature]
    del(features[bestFeature])
    del(valueForFeatures[bestFeature])
    featureValues = [row[bestFeature] for row in dataSet]
    uniqueValue = set(featureValues)
    for value in uniqueValue:
        subFeatures = features[:]                    # copy, so sibling branches do not share the list
        subRealValueForFeatures = valueForFeatures[:]
        realValue = realValueForThisFeature[int(value) - 1]  # map '1'/'2'/'3' to its readable name
        myTree[bestFeatureLabel][realValue] = createTree(splitDataSet(dataSet, bestFeature, value), subFeatures, subRealValueForFeatures, labels)
    return myTree

Finally, putting it all together:

labels = [
    "hard",
    "soft",
    "not be fitted with contact lenses"
]
features = [
    "age of the patient",
    "spectacle prescription",
    "astigmatic",
    "tear production rate",
]
valueForFeatures = [
    ["young", "pre-presbyopic", "presbyopic"],
    ["myope", "hypermetrope"],
    ["no", "yes"],
    ["reduced", "normal"]
]
myDataSet = file2DataSet("E:/dataSet/tree/lenses.data")
print(createTree(myDataSet, features, valueForFeatures, labels))

The result (pretty-printed here for readability):

{
    "tear production rate": {
        "normal": {
            "astigmatic": {
                "yes": {
                    "spectacle prescription": {
                        "hypermetrope": {
                            "age of the patient": {
                                "presbyopic": "not be fitted with contact lenses",
                                "pre-presbyopic": "not be fitted with contact lenses",
                                "young": "hard"
                            }
                        },
                        "myope": "hard"
                    }
                },
                "no": {
                    "age of the patient": {
                        "presbyopic": {
                            "spectacle prescription": {
                                "hypermetrope": "soft",
                                "myope": "not be fitted with contact lenses"
                            }
                        },
                        "pre-presbyopic": "soft",
                        "young": "soft"
                    }
                }
            }
        },
        "reduced": "not be fitted with contact lenses"
    }
}
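
The tree is nothing more than nested dicts keyed by feature names and feature values, so classifying a new patient just means walking it. The classify helper below is not part of the original code, only a sketch of how the structure can be used, and it assumes the result above was stored as myTree = createTree(...):

def classify(tree, sample):
    # tree: nested dict as returned by createTree; sample: {feature name: readable value}
    featureName = next(iter(tree))                     # the feature tested at this node
    subtree = tree[featureName][sample[featureName]]   # follow the branch for this sample's value
    if isinstance(subtree, dict):
        return classify(subtree, sample)               # internal node: keep walking
    return subtree                                     # leaf: the predicted label

# e.g. a young myope with astigmatism and a normal tear production rate
print(classify(myTree, {
    "tear production rate": "normal",
    "astigmatic": "yes",
    "spectacle prescription": "myope",
    "age of the patient": "young",
}))   # 'hard' according to the tree above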

Afterword

I'll leave drawing the tree with matplotlib for another day.

