#!/usr/bin/python3
# _*_coding:utf-8 _*_
# @Time     :2021/2/21 23:14
# @Author   :jory.d
# @File     :roc_auc.py
# @Software :PyCharm
# @Desc: Plot ROC / AUC curves for multi-class classification.
import math
import random
from pprint import pprint

import matplotlib as mpl
# mpl.use('Agg')  # Agg TkAgg
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import label_binarize

np.set_printoptions(precision=2)


def _check_inputs(label_names, y_trues, y_probs):
    """Validate the argument types shared by the public functions.

    Kept as ``assert`` statements so that callers still see
    ``AssertionError`` on bad input, as in the original code.
    """
    assert isinstance(label_names, list)
    assert isinstance(y_trues, list)
    assert isinstance(y_probs, list)
    assert len(y_trues) == len(y_probs)


def _binarize_labels(y_true, n_classes):
    """Return a one-hot matrix with one column per class index.

    ``classes`` is passed by keyword because recent scikit-learn releases
    reject it as a positional argument.  For two classes ``label_binarize``
    returns a single column (the positive class), so it is expanded to two
    columns to keep per-class indexing valid.
    """
    one_hot = label_binarize(y_true, classes=np.arange(n_classes))
    if n_classes == 2 and one_hot.shape[1] == 1:
        one_hot = np.hstack((1 - one_hot, one_hot))
    return one_hot


def get_other_metrics(label_names, y_trues, y_probs):
    """Compute per-class precision, recall and F1.

    The predicted class of each sample is the argmax over the last axis of
    ``y_probs``.

    :param label_names: list of class names (only type-checked here)
    :param y_trues: list of ground-truth class indices
    :param y_probs: list of per-sample score vectors
    :return: ``(precision, recall, f1)`` as returned by scikit-learn with
        ``average=None`` (one value per class)
    """
    _check_inputs(label_names, y_trues, y_probs)

    y_true = np.array(y_trues)
    y_pred = np.argmax(np.array(y_probs), axis=-1)

    precision = metrics.precision_score(y_true, y_pred, average=None)
    recall = metrics.recall_score(y_true, y_pred, average=None)
    f1_score = metrics.f1_score(y_true, y_pred, average=None)
    return precision, recall, f1_score


def get_cmap(N):
    """Return a function mapping each index in 0, 1, ..., N-1 to an RGBA color.

    The ``hsv`` colormap is cyclic (both ends are red), so indices are
    normalised against ``N`` rather than ``N - 1``; otherwise the first and
    the last index would receive nearly the same color.
    """
    import matplotlib.cm as cmx
    import matplotlib.colors as colors

    color_norm = colors.Normalize(vmin=0, vmax=N)
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color


def create_roc_auc(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """Compute per-class ROC curves with scikit-learn and plot them.

    Each class is treated one-vs-rest: column ``i`` of the one-hot labels is
    scored against column ``i`` of ``y_probs``.  The FPR/TPR arrays and the
    AUC values are printed, and the figure is written to ``png_save_path``.

    :param label_names: list of class names; its length is the class count
    :param y_trues: list of ground-truth class indices
    :param y_probs: list of per-sample score vectors, indexed [sample][class]
    :param png_save_path: where the PNG is written
    :param is_show: when true, also display the figure with ``plt.show()``
    :return: the matplotlib figure
    """
    _check_inputs(label_names, y_trues, y_probs)

    n_classes = len(label_names)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    y_true_one_hot = _binarize_labels(y_true, n_classes)

    # ROC curve and area under it, per class.
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(y_true_one_hot[:, i], y_prob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    pprint(fpr)
    pprint(tpr)
    print('AUC: {}'.format(roc_auc))

    mpl.rcParams['font.sans-serif'] = u'DejaVu Sans'  # DejaVu Sans SimHei
    mpl.rcParams['axes.unicode_minus'] = False

    fig = plt.figure()
    # White is deliberately not in the palette: it is invisible on the
    # default white background.
    palette = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
    cmap = get_cmap(n_classes)
    use_cmap = n_classes > len(palette)

    # x axis is FPR, y axis is TPR.
    for i in range(n_classes):
        curve_color = cmap(i) if use_cmap else palette[i]
        plt.plot(fpr[i], tpr[i], c=curve_color, lw=2, alpha=0.7,
                 label=u'%d AUC=%.3f' % (i, roc_auc[i]))

    # Chance-level diagonal.
    plt.plot((0, 1), (0, 1), c='#808080', lw=1, ls='--', alpha=0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    # Positional flag: the ``b=`` keyword was renamed in newer matplotlib.
    plt.grid(True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title(u'ROC curve', fontsize=17)
    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()
    return fig


def create_roc_self(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """Compute TPR/FPR by hand at fixed thresholds and plot them per class.

    For every class (one-vs-rest) and every threshold in 0.1, 0.2, ..., 1.0
    a sample is predicted positive when its score for that class is
    ``>= threshold``.  The resulting rates are printed and plotted so a
    per-class threshold can be picked during post-processing.

    :param label_names: list of class names; its length is the class count
    :param y_trues: list of ground-truth class indices
    :param y_probs: list of per-sample score vectors, indexed [sample][class]
    :param png_save_path: where the PNG is written
    :param is_show: when true, also display the figure with ``plt.show()``
    :return: the matplotlib figure
    """
    _check_inputs(label_names, y_trues, y_probs)

    n_classes = len(label_names)
    y_trues = np.array(y_trues)
    y_probs = np.array(y_probs)
    y_trues_one_hot = _binarize_labels(y_trues, n_classes)
    print(y_trues)
    print(y_trues_one_hot)

    tpr_dict, fpr_dict = {}, {}
    thresh = [i / 10 for i in range(1, 11)]

    for i in range(n_classes):
        tpr_dict[i] = []
        fpr_dict[i] = []
        is_positive = y_trues_one_hot[:, i] == 1
        is_negative = y_trues_one_hot[:, i] == 0
        class_scores = y_probs[:, i]

        for th in thresh:
            # tpr = tp / (tp + fn), fpr = fp / (tn + fp)
            predicted_positive = class_scores >= th
            tp = np.sum(predicted_positive & is_positive)
            fn = np.sum(~predicted_positive & is_positive)
            fp = np.sum(predicted_positive & is_negative)
            tn = np.sum(~predicted_positive & is_negative)
            # The small epsilon avoids division by zero when a class has no
            # positive (or no negative) samples.
            tpr = tp / (tp + fn + 1e-5)
            fpr = fp / (tn + fp + 1e-5)
            print(f'thres={th}, tpr={tpr}, fpr={fpr}')
            tpr_dict[i].append(round(tpr, 2))
            fpr_dict[i].append(round(fpr, 2))

    pprint('tpr: {}'.format(tpr_dict))
    pprint('fpr: {}'.format(fpr_dict))

    cols = 2
    # Ceiling division: round() uses banker's rounding and would drop the
    # last subplot for e.g. 1 or 5 classes.
    rows = math.ceil(n_classes / cols)

    fig = plt.figure(figsize=(12, 12), dpi=150)
    fig.suptitle('per class tpr and fpr', fontsize='xx-large')

    ax = None
    for class_idx in range(n_classes):
        ax = fig.add_subplot(rows, cols, class_idx + 1)
        ax.plot(thresh, tpr_dict[class_idx], c='b', label='tpr')
        ax.plot(thresh, fpr_dict[class_idx], c='r', label='fpr')
        ax.set_xlabel('thres', fontsize='x-large')
        ax.set_ylabel('tpr_fpr', fontsize='x-large')
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.set_yticks(np.arange(0, 1.1, 0.2))

    # One shared legend, taken from the last subplot (all subplots use the
    # same two series).
    if ax is not None:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower right', fontsize='x-large')

    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()
    return fig


# Seed NumPy's global RNG so the demo data below is reproducible.
np.random.seed(888)

if __name__ == '__main__':
    labels = ['A', 'B', 'C']
    batch_size = 100

    # Random ground truth and random scores.
    y_true = np.random.randint(0, len(labels), [batch_size]).tolist()
    y_prob = np.random.random([batch_size, len(labels)]).tolist()

    # _ = create_roc_auc(labels, y_true, y_prob, './ss1.png')
    _ = create_roc_self(labels, y_true, y_prob, './ss2.png')