首页 > 编程语言 > 详细

python 根据分类结果求ROC,AUC

时间:2021-03-10 23:57:43      阅读:59      评论:0      收藏:0      [点我收藏+]
#!/usr/bin/python3
# _*_coding:utf-8 _*_

# @Time       :2021/2/21 23:14
# @Author    :jory.d
# @File       :roc_auc.py
# @Software    :PyCharm
# @Desc: 绘制多分类的ROC AUC曲线

import matplotlib as mpl

# mpl.use('Agg')  # Agg   TkAgg
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import label_binarize
import random
from pprint import pprint

np.set_printoptions(precision=2)


def get_other_metrics(label_names, y_trues, y_probs):
    """
    Compute per-class precision, recall and F1 from predicted scores.

    The predicted class of each sample is the argmax over the last axis of
    ``y_probs``. Scores are computed by scikit-learn with ``average=None``,
    i.e. one value per class.

    :param label_names: list of class names. Its length fixes the set of
        class ids to ``0 .. len(label_names) - 1`` so the returned arrays
        always line up with this list.
    :param y_trues: list of ground-truth class ids, one per sample.
    :param y_probs: list of per-sample score rows; the argmax position of a
        row is taken as the predicted class id.
    :return: tuple ``(Precision, Recall, F1_Score)``, each an array with one
        entry per class id.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)
    y_true = np.array(y_trues)
    y_pred = np.argmax(np.array(y_probs), axis=-1)

    # Without an explicit label set, scikit-learn only reports the classes
    # that actually occur in y_true / y_pred, so a missing class would make
    # the result shorter than (and misaligned with) label_names.
    class_ids = np.arange(len(label_names))

    Precision = metrics.precision_score(y_true, y_pred, labels=class_ids, average=None)
    Recall = metrics.recall_score(y_true, y_pred, labels=class_ids, average=None)
    F1_Score = metrics.f1_score(y_true, y_pred, labels=class_ids, average=None)
    return Precision, Recall, F1_Score


def get_cmap(N):
    """
    Return a function mapping each index in 0, 1, ..., N-1 to a distinct
    RGBA color taken from matplotlib's ``hsv`` colormap.

    :param N: number of distinct colors needed.
    :return: callable ``f(index) -> (r, g, b, a)``.
    """
    import matplotlib.cm as cmx
    import matplotlib.colors as colors

    # hsv is cyclic (it starts and ends on red), so normalize over [0, N]
    # rather than [0, N - 1]; otherwise index 0 and index N - 1 would be
    # given practically the same color.
    color_norm = colors.Normalize(vmin=0, vmax=N)
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color


def create_roc_auc(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute one-vs-rest ROC curves and AUC with scikit-learn and plot them.

    For every class id ``i`` the ground truth is turned into a 0/1 indicator
    ("is class i") and scored against column ``i`` of ``y_probs``. The
    per-class FPR/TPR arrays and the AUC values are printed, all curves are
    drawn into one figure, and the figure is saved as PNG.

    :param label_names: list of class names; its length is the class count.
    :param y_trues: list of ground-truth class ids, one per sample.
    :param y_probs: list of per-sample score rows; column ``i`` is used as
        the score of class ``i``.
    :param png_save_path: where the PNG is written.
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the matplotlib figure.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)

    n_classes = len(label_names)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    # Indicator matrix of shape (n_samples, n_classes). Built with numpy
    # rather than label_binarize, which returns a single column in the
    # two-class case and would break the per-class column indexing below.
    y_true_one_hot = (y_true[:, None] == np.arange(n_classes)).astype(int)

    # ROC curve and area under it for each class
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(y_true_one_hot[:, i], y_prob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    pprint(fpr)
    pprint(tpr)
    print('AUC: {}'.format(roc_auc))
    mpl.rcParams['font.sans-serif'] = 'DejaVu Sans'  # DejaVu Sans   SimHei
    mpl.rcParams['axes.unicode_minus'] = False

    fig = plt.figure()
    color = ('b', 'g', 'r', 'c', 'm', 'y', 'k', 'w')
    cmap = get_cmap(n_classes)
    for i in range(n_classes):
        # FPR is the x axis, TPR the y axis. Use the fixed palette while it
        # has enough entries, otherwise fall back to the generated colormap.
        _col = cmap(i) if n_classes > len(color) else color[i]
        plt.plot(fpr[i], tpr[i], c=_col, lw=2, alpha=0.7, label='%d AUC=%.3f' % (i, roc_auc[i]))

    # Diagonal reference line (FPR == TPR)
    plt.plot((0, 1), (0, 1), c='#808080', lw=1, ls='--', alpha=0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    # Visibility flag passed positionally: its keyword name differs between
    # matplotlib versions (``b`` vs ``visible``).
    plt.grid(True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title('ROC curve', fontsize=17)
    fig.savefig(png_save_path, format='png')
    if is_show:
        plt.show()

    return fig


def create_roc_self(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute TPR and FPR by hand for every class at thresholds 0.1 .. 1.0 and
    plot them, to help choose a per-class decision threshold.

    For class ``i`` a sample counts as predicted positive when column ``i``
    of ``y_probs`` is ``>= threshold``. Each class gets its own subplot
    (two per row) with TPR and FPR drawn against the threshold.

    :param label_names: list of class names; its length is the class count.
    :param y_trues: list of ground-truth class ids, one per sample.
    :param y_probs: list of per-sample score rows; column ``i`` is used as
        the score of class ``i``.
    :param png_save_path: where the PNG is written.
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the matplotlib figure.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)

    n_classes = len(label_names)
    y_trues = np.array(y_trues)
    y_probs = np.array(y_probs)
    # Indicator matrix of shape (n_samples, n_classes). Built with numpy
    # rather than label_binarize, which returns a single column in the
    # two-class case and would break the per-class column indexing below.
    y_trues_one_hot = (y_trues[:, None] == np.arange(n_classes)).astype(int)
    print(y_trues)
    print(y_trues_one_hot)
    tpr_dict, fpr_dict = {}, {}
    thresh = [i / 10 for i in range(1, 11)]
    for i in range(n_classes):
        tpr_dict[i] = []
        fpr_dict[i] = []
        y_true = y_trues_one_hot[:, i]  # [n,]
        y_pred_prob = y_probs[:, i]
        # TPR / FPR at each threshold from 0.1 to 1.0
        for th in thresh:
            # tpr = tp/(tp+fn), fpr = fp/(tn+fp)
            y_pred2 = np.where(y_pred_prob >= th, 1, 0)
            tp = np.sum(y_pred2[y_true == 1] == 1)
            fn = np.sum(y_pred2[y_true == 1] == 0)
            fp = np.sum(y_pred2[y_true == 0] == 1)
            tn = np.sum(y_pred2[y_true == 0] == 0)
            # The small epsilon avoids division by zero when a class has no
            # positive (or no negative) samples.
            tpr = tp / (tp + fn + 1e-5)
            fpr = fp / (tn + fp + 1e-5)
            print(f'thres={th}, tpr={tpr}, fpr={fpr}')
            tpr_dict[i].append(round(tpr, 2))
            fpr_dict[i].append(round(fpr, 2))

    pprint('tpr: {}'.format(tpr_dict))
    pprint('fpr: {}'.format(fpr_dict))

    cols = 2
    # Ceiling division: round() uses banker's rounding, so e.g. 5 classes
    # would give only 2 rows (4 subplots) and 1 class would give 0 rows.
    rows = (n_classes + cols - 1) // cols
    fig = plt.figure(figsize=(12, 12), dpi=150)
    fig.suptitle('per class tpr and fpr', fontsize='xx-large')
    ax = None
    for class_id in range(n_classes):
        # Subplots are filled row by row, so the class id is the 0-based slot.
        ax = fig.add_subplot(rows, cols, class_id + 1)
        ax.plot(thresh, tpr_dict[class_id], c='b', label='tpr')
        ax.plot(thresh, fpr_dict[class_id], c='r', label='fpr')
        ax.set_xlabel('thres', fontsize='x-large')
        ax.set_ylabel('tpr_fpr', fontsize='x-large')
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.set_yticks(np.arange(0, 1.1, 0.2))

    # All subplots share the same two series, so one figure-level legend
    # taken from the last axes is enough. Skipped when nothing was drawn.
    if ax is not None:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower right', fontsize='x-large')
    fig.savefig(png_save_path, format='png')
    if is_show:
        plt.show()

    return fig


# 计算每个class的 fpr, tpr

# Fixed seed so the random demo data below is reproducible.
np.random.seed(888)
if __name__ == '__main__':
    labels = ['A', 'B', 'C']
    batch_size = 100
    # Random ground-truth class ids and random per-class scores
    y_true = np.random.randint(0, len(labels), [batch_size]).tolist()
    y_prob = np.random.random([batch_size, len(labels)]).tolist()
    # _ = create_roc_auc(labels, y_true, y_prob, './ss1.png')
    _ = create_roc_self(labels, y_true, y_prob, './ss2.png')

 

python 根据分类结果求ROC,AUC

原文:https://www.cnblogs.com/dxscode/p/14460545.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!