• 机器学习实验六:决策树-海洋生物例子


     

     

     

     

      

    1. #创建数据集
    2. import numpy as np
    3. import pandas as pd
    4. from sklearn import tree
    5. from sklearn.tree import DecisionTreeClassifier
    6. import graphviz
    7. def createDataSet():
    8. row_data = {'no surfacing':[1,1,1,0,0],
    9. 'flippers':[1,1,0,1,1],
    10. 'fish':['yes','yes','no','no','no']}
    11. dataSet = pd.DataFrame(row_data)
    12. return dataSet
    13. def calEnt(dataSet):
    14. n = dataSet.shape[0] #数据集总行数
    15. iset = dataSet.iloc[:,-1].value_counts() #标签的所有类别
    16. p = iset/n #每一类标签所占比
    17. ent = (-p*np.log2(p)).sum() #计算信息熵
    18. return ent
    19. dataSet = createDataSet()
    20. print(calEnt(dataSet))
    21. #选择最优的列进行切分
    22. def bestSplit(dataSet):
    23. baseEnt = calEnt(dataSet) #计算原始熵
    24. bestGain = 0 #初始化信息增益
    25. axis = -1 #初始化最佳切分列,标签列
    26. for i in range(dataSet.shape[1]-1): #对特征的每一列进行循环
    27. levels= dataSet.iloc[:,i].value_counts().index #提取出当前列的所有取值
    28. ents = 0 #初始化子节点的信息熵
    29. for j in levels: #对当前列的每一个取值进行循环
    30. childSet = dataSet[dataSet.iloc[:,i]==j] #某一个子节点的dataframe
    31. ent = calEnt(childSet) #计算某一个子节点的信息熵
    32. ents += (childSet.shape[0]/dataSet.shape[0])*ent #计算当前列的信息熵
    33. print(f'第{i}{j}类的信息熵为{ents}')
    34. infoGain = baseEnt-ents #计算当前列的信息增益
    35. print(f'第{i}列的信息增益为{infoGain}')
    36. if (infoGain > bestGain):
    37. bestGain = infoGain #选择最大信息增益
    38. axis = i #最大信息增益所在列的索引
    39. return axis
    40. bestSplit(dataSet)
    41. def mySplit(dataSet,axis,value):
    42. col = dataSet.columns[axis]
    43. redataSet = dataSet.loc[dataSet[col]==value,:].drop(col,axis=1) #取切分属性值为value的数据子集,并且删除切分列
    44. return redataSet
    45. value =1
    46. axis=0
    47. mySplit(dataSet,axis,value)
    48. col = dataSet.columns[axis]
    49. dataSet.loc[dataSet[col]==value,:].drop(col,axis=1)
    50. def createTree(dataSet):
    51. featlist = list(dataSet.columns) #提取出数据集所有的列
    52. classlist = dataSet.iloc[:,-1].value_counts() #获取最后一列类标签
    53. #判断最多标签数目是否等于数据集行数,或者数据集是否只有一列
    54. if classlist[0]==dataSet.shape[0] or dataSet.shape[1] == 1:
    55. return classlist.index[0] #如果是,返回类标签
    56. axis = bestSplit(dataSet) #确定出当前最佳切分列的索引
    57. bestfeat = featlist[axis] #获取该索引对应的特征
    58. myTree = {bestfeat:{}} #采用字典嵌套的方式存储树信息
    59. del featlist[axis] #删除当前特征
    60. valuelist = set(dataSet.iloc[:,axis]) #提取最佳切分列所有属性值
    61. for value in valuelist: #对每一个属性值递归建树
    62. myTree[bestfeat][value] = createTree(mySplit(dataSet,axis,value))
    63. return myTree
    64. myTree = createTree(dataSet)
    65. #树的存储
    66. np.save('myTree.npy',myTree)
    67. #树的读取
    68. read_myTree = np.load('myTree.npy',allow_pickle=True).item()
    69. read_myTree
    70. def storeTree(inputTree,filename):
    71. import pickle
    72. fw=open(filename,'wb')
    73. pickle.dump(inputTree,fw)
    74. fw.close()
    75. storeTree(myTree,'mytree2.npy')
    76. def grabTree(filename):
    77. import pickle
    78. fr=open(filename,'rb')
    79. return pickle.load(fr)
    80. grabTree('mytree2.npy')
    81. def classify(inputTree,labels, testVec):
    82. firstStr = next(iter(inputTree)) #获取决策树第一个节点
    83. secondDict = inputTree[firstStr] #下一个字典
    84. featIndex = labels.index(firstStr) #第一个节点所在列的索引
    85. for key in secondDict.keys():
    86. if testVec[featIndex] == key:
    87. if type(secondDict[key]) == dict :
    88. classLabel = classify(secondDict[key], labels, testVec)
    89. else:
    90. classLabel = secondDict[key]
    91. return classLabel
    92. labels = list(dataSet.columns)
    93. inputTree = myTree
    94. firstStr = next(iter(inputTree))
    95. secondDict = inputTree[firstStr]
    96. def acc_classify(train,test):
    97. inputTree = createTree(train) #根据测试集生成一棵树
    98. labels = list(train.columns) #数据集所有的列名称
    99. result = []
    100. for i in range(test.shape[0]): #对测试集中每一条数据进行循环
    101. testVec = test.iloc[i,:-1] #测试集中的一个实例
    102. classLabel = classify(inputTree,labels,testVec) #预测该实例的分类
    103. result.append(classLabel) #将分类结果追加到result列表中
    104. test['predict']=result #将预测结果追加到测试集最后一列
    105. acc = (test.iloc[:,-1]==test.iloc[:,-2]).mean() #计算准确率
    106. print(f'模型预测准确率为{acc}')
    107. return test
    108. train = dataSet
    109. test = dataSet.iloc[:3,:]
    110. acc_classify(train,test)
    111. #特征
    112. Xtrain = dataSet.iloc[:,:-1]
    113. #标签
    114. Ytrain = dataSet.iloc[:,-1]
    115. labels = Ytrain.unique().tolist()
    116. Ytrain = Ytrain.apply(lambda x: labels.index(x)) #将本文转换为数字
    117. treemodel = tree.DecisionTreeClassifier(criterion='gini',max_depth=None,min_samples_leaf=1,ccp_alpha=0.0)
    118. clf2=treemodel.fit(Xtrain, Ytrain)
    119. clf2
    120. #绘制树模型
    121. clf = DecisionTreeClassifier()
    122. clf = clf.fit(Xtrain, Ytrain)
    123. tree.export_graphviz(clf)
    124. dot_data = tree.export_graphviz(clf, out_file=None)
    125. graphviz.Source(dot_data)
    126. #给图形增加标签和颜色
    127. dot_data = tree.export_graphviz(clf, out_file=None,
    128. feature_names=['no surfacing', 'flippers'],
    129. class_names=['fish', 'not fish'],
    130. filled=True, rounded=True,
    131. special_characters=True)
    132. graphviz.Source(dot_data)
    133. #利用render方法生成图形
    134. graph = graphviz.Source(dot_data)
    135. graph.render("fish")
    136. #利用render方法生成图形
    137. graph = graphviz.Source(dot_data)
    138. graph.render("fish")
    139. def getNumLeafs(myTree):
    140. numLeafs = 0 #初始化叶节点数目
    141. firstStr = next(iter(myTree)) #获得树的第一个键值,即第一个特征
    142. secondDict = myTree[firstStr] #获取下一组字典
    143. for key in secondDict.keys():
    144. if type(secondDict[key]) == dict: #测试该节点是否为字典
    145. numLeafs += getNumLeafs(secondDict[key]) #是字典,递归,循环计算新分支叶节点数
    146. else:
    147. numLeafs +=1 #不是字典,代表此结点为叶子结点
    148. return numLeafs
    149. firstStr = next(iter(myTree))
    150. getNumLeafs(myTree)
    151. def getTreeDepth(myTree):
    152. maxDepth = 0
    153. firstStr = next(iter(myTree))
    154. secondDict = myTree[firstStr]
    155. for key in secondDict.keys():
    156. if type(secondDict[key]) == dict:
    157. thisDepth = 1+getTreeDepth(secondDict[key])
    158. else:
    159. thisDepth = 1
    160. if thisDepth>maxDepth:
    161. maxDepth = thisDepth
    162. return maxDepth
    163. getTreeDepth(myTree)

    运行结果

     目前该决策树代码无法直接运行(转载时代码缩进已丢失)

     

  • 相关阅读:
    容器是什么?
    ZooKeeper 集群部署
    一文读懂字符编码ASCII、Unicode与UTF-8
    牛皮了!阿里面试官终于分享出了 2022 年最新的 java 面试题及答案
    C++11(一)
    全图化在线系统设计
    Linux学习系列--如何在Linux中进行文件的管理
    前端开发工具vscode
    牛客刷题总结——Python入门08:面向对象、正则表达式
    springcloudgateway 默认转发不生效
  • 原文地址:https://blog.csdn.net/weixin_60530224/article/details/134026435