• [机器学习]-分类问题常用评价指标、混淆矩阵及ROC曲线绘制方法-python实现(多分类)


    分类问题

    分类问题是人工智能领域中最常见的一类问题之一,掌握合适的评价指标,对模型进行恰当的评价,是至关重要的。

    同样地,分割问题是像素级别的分类,除了mAcc、mIoU之外,也可以采用分类问题的一些指标来评价。

    本文对分类问题的常见评价指标进行介绍,并附上利用sklearn库的python实现。

    将从以下三个方面分别介绍:

    1. 常用评价指标
    2. 混淆矩阵绘制及评价指标计算
    3. ROC曲线绘制及AUC计算

    1. 常用评价指标

    混淆矩阵(confusion matrix)

    一般用来描述一个分类器分类的准确程度。
    根据分类器在测试数据集上的预测是否正确可以分为四种情况:

    • TP(True Positive)——将正类预测为正类数;
    • FN(False Negative)——将正类预测为负类数;
    • FP(False Positive)——将负类预测为正类数;
    • TN(True Negative)——将负类预测为负类数。
      构成一个二分类的混淆矩阵如图:
      image

    均交并比(Mean Intersection over Union,mIoU):

    语义分割的标准度量。其计算两个集合的交并比,在语义分割的问题中,这两个集合为真实值(ground truth)和预测值(predicted segmentation)。
    image

    分类问题评价指标

    二分类问题经混淆矩阵的处理后,针对不同问题,可以选用不同的指标来评价系统。

    1. Accuracy:表示预测结果的精确度,预测正确的样本数除以总样本数;
    2. Precision:准确率,表示预测结果中,预测为正样本的样本中,正确预测为正样本的概率;
    3. Sensitivity:灵敏度,表示在原始样本的正样本中,最后被正确预测为正样本的概率;
    4. Specificity:常常称作特异性,它研究的样本集是原始样本中的负样本,表示的是在这些负样本中最后被正确预测为负样本的概率;
    5. F1-score:表示的是precision和recall的调和平均评估指标。
      image

    受试者工作特征(Receiver Operating Characteristic,ROC)曲线

    ROC曲线是以真阳性率(TPR)为Y轴,以假阳性率(FPR)为X轴做的图。同样用来综合评价模型分类情况。是反映敏感性和特异性连续变量的综合指标。
    image

    AUC(Area Under Curve)

    AUC的值为ROC曲线下与x轴围成的面积,分类器的性能越接近完美,AUC的值越接近。当0.5>AUC>1时,效果优于“随机猜测”。一般情况下,模型的AUC值应当在此范围内。

    2. 混淆矩阵绘制及评价指标计算

    首先是分类器的训练,以sklearn库中的基础分类器为例

    import numpy as np
    import pandas as pd
    from sklearn.svm import SVC, LinearSVC
    from sklearn import metrics
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    
    clf = LinearSVC()
    clf.fit(train_features, train_target)
    predict = clf.predict(test_features)
    
    # 绘制混淆矩阵和评价指标计算
    cal(test_target, pred)
    
    # 获取分类score
    score = clf.decision_function(test_features)
    
    # 绘制ROC曲线和计算AUC
    paint_ROC(test_target, test_score)
    
    

    混淆矩阵的绘制和评价指标计算可以写在一起,在绘制混淆矩阵时,已经可以算出TP\TN\FP\FN的数值。

    import numpy as np
    import pandas as pd
    from sklearn.svm import SVC, LinearSVC
    from sklearn import metrics
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # 这是一个多分类问题,y_true是target,y_pred是模型预测结果,数据格式为numpy
    
    def cal(y_true, y_pred):
    
        # confusion matrix row means GT, column means predication
        name = 'save_name'
        '''画混淆矩阵'''
        mat = confusion_matrix(y_true, y_pred)
        da = pd.DataFrame(mat, index = ['0', '1', '2'])
        sns.heatmap(da, annot =True, cbar = None, cmap = 'Blues')
        plt.title(name)
        # plt.tight_layout()yt
        plt.ylabel('True Label')
        plt.xlabel('Predict Label')
        plt.show()
        plt.savefig('{}/{}.png'.format('save_path', name)) # 将混淆矩阵图片保存下来
        plt.close()
        
        '''计算指标'''
        tp = np.diagonal(mat) # 每类的tp
        gt_num = np.sum(mat, axis=1) # axis = 1 指每行 ,每类的总数
        pre_num = np.sum(mat, axis=0)
        fp = pre_num - tp
        fn = gt_num - tp
        num = np.sum(gt_num)
        num = np.repeat(num, gt_num.shape[0])
        gt_num0 = num - gt_num
        tn = gt_num0 -fp
    	
        recall = tp.astype(np.float32) / gt_num
        specificity = tn.astype(np.float32) / gt_num0
        precision = tp.astype(np.float32) / pre_num
        F1 = 2 * (precision * recall) / (precision + recall)
        acc = (tp + tn).astype(np.float32) / num
    
        print('recall:', recall, '\nmean recall:{:.4f}'.format(np.mean(recall)) )
        print('specificity:', specificity, '\nmean specificity:{:.4f}'.format(np.mean(specificity)))
        print('precision:', precision, '\nmean precision:{:.4f}'.format(np.mean(precision)))
        print('F1:', F1 , '\nmean F1:{:.4f}'.format(np.mean(F1)))
        print('acc:', acc , '\nmean acc:{:.4f}'.format(np.mean(acc)))
    
    

    混淆矩阵如图所示:

    image

    3. ROC曲线绘制及AUC计算

    import numpy as np
    import pandas as pd
    from sklearn.svm import SVC, LinearSVC
    from sklearn import metrics
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # 这是一个多分类问题(三分类),可以在一张图上绘制多条ROC曲线
    
    def paint_ROC(y_test, y_score):
    
        '''画ROC曲线'''
        plt.figure()
        # 修改颜色
        colors = ['','darkred', 'darkorange', 'cornflowerblue']
    
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        # print('label',y_test)
        # print('score', y_score)
    
        label = np.zeros((len(y_test), 3),  dtype="uint8")
        for i in range(len(y_test)):
            label[i][int(y_test[i])-1] = 1
        # print('label',label)
    
        for i in range(1,4):
            fpr[i], tpr[i], _ = metrics.roc_curve(label[:,i-1], y_score[:, i-1])
            roc_auc[i] = metrics.auc(fpr[i], tpr[i])
    
        fpr["mean"], tpr["mean"], _ = metrics.roc_curve(label.ravel(), y_score.ravel())
        roc_auc["mean"] = metrics.auc(fpr["mean"], tpr["mean"])
    
        lw = 2
        plt.plot(fpr["mean"], tpr["mean"],
             label='average, ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["mean"]),
             color='k', linewidth=lw)
    
        for i in range(1,4):
            auc = roc_auc[i]
            # 输出不同类别的FPR\TPR\AUC
            print('label: {}, fpr: {}, tpr: {}, auc: {}'.format(i, np.mean(fpr[i]), np.mean(tpr[i]), auc))
            plt.plot(fpr[i], tpr[i], color=colors[i],linestyle=':',lw = lw, label='Label = {0}, ROC curve (area = {1:0.2f})'.format(i, auc))
    
        plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
        plt.xlim([0.0, 1.05])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        # plt.title('Receiver operating characteristic example')
        plt.grid(linestyle='-.')  
        plt.grid(True)
        plt.legend(loc="lower right")
        plt.show()
        # 保存绘制好的ROC曲线
        plt.savefig('{}/{}.png'.format('save_path', 'save_name'))
        plt.close()
    
    

    ROC曲线如图所示:

    image

  • 相关阅读:
    Java开发 - Canal进阶之和Redis的数据同步
    【MySQL从入门到精通】【高级篇】(五)MySQL的SQL语句执行流程
    一线大厂软件测试面试题及答案解析,2022最强版...
    pytorch初学笔记(十四):损失函数
    【基础整理】Mapping representation 机器人所用地图种类及相关介绍
    2023湖北汽车工业学院计算机考研信息汇总
    第三课-软件升级-Stable Diffusion教程
    Java 数据结构 ---》 泛型
    拓端tecdat|GIS遥感数据可视化评估:印度河流域上部的积雪面积变化
    【PowerMockito:编写单元测试过程中原方法没有注入的属性在跑单元测试时出现空指针】
  • 原文地址:https://www.cnblogs.com/camilia/p/16697128.html