微信登录

Matplotlib - 模型评估可视化 - 绘制混淆矩阵等

PyTorch 《Matplotlib - 模型评估可视化 - 绘制混淆矩阵等》

在深度学习中,模型训练完成后,对其进行评估是至关重要的环节。通过可视化的方式展示评估结果,能够让我们更直观地理解模型的性能。Matplotlib 是 Python 中一个强大的绘图库,结合 PyTorch 可以方便地实现各种模型评估指标的可视化,本文将重点介绍如何使用 Matplotlib 绘制混淆矩阵等常见的评估可视化图表。

1. 环境准备

在开始之前,确保你已经安装了 PyTorch 和 Matplotlib。可以使用以下命令进行安装:

  1. pip install torch matplotlib

2. 绘制混淆矩阵

2.1 混淆矩阵简介

混淆矩阵是一种常用的评估分类模型性能的工具,它可以清晰地展示模型在每个类别上的分类情况。矩阵的行表示真实类别,列表示预测类别,对角线上的元素表示正确分类的样本数,非对角线上的元素表示错误分类的样本数。

2.2 示例代码

  1. import torch
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from sklearn.metrics import confusion_matrix
  5. import seaborn as sns
  6. # 模拟真实标签和预测标签
  7. true_labels = torch.tensor([0, 1, 2, 0, 1, 2])
  8. pred_labels = torch.tensor([0, 2, 1, 0, 0, 2])
  9. # 计算混淆矩阵
  10. cm = confusion_matrix(true_labels, pred_labels)
  11. # 绘制混淆矩阵
  12. plt.figure(figsize=(8, 6))
  13. sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
  14. xticklabels=['Class 0', 'Class 1', 'Class 2'],
  15. yticklabels=['Class 0', 'Class 1', 'Class 2'])
  16. plt.xlabel('Predicted Labels')
  17. plt.ylabel('True Labels')
  18. plt.title('Confusion Matrix')
  19. plt.show()

2.3 代码解释

  • 首先,我们使用 torch.tensor 模拟了真实标签和预测标签。
  • 然后,使用 sklearn.metrics.confusion_matrix 函数计算混淆矩阵。
  • 最后,使用 seaborn.heatmap 函数绘制混淆矩阵,annot=True 表示在每个单元格中显示具体的数值,fmt='d' 表示以整数形式显示。

3. 绘制 ROC 曲线

3.1 ROC 曲线简介

ROC 曲线(Receiver Operating Characteristic Curve)是一种用于评估二分类模型性能的工具,它以假正率(False Positive Rate)为横轴,真正率(True Positive Rate)为纵轴,通过改变分类阈值来绘制曲线。曲线下的面积(AUC)越大,说明模型的性能越好。

3.2 示例代码

  1. from sklearn.metrics import roc_curve, auc
  2. # 模拟真实标签和预测概率
  3. true_labels = torch.tensor([0, 1, 0, 1])
  4. pred_probs = torch.tensor([0.1, 0.9, 0.2, 0.8])
  5. # 计算 FPR, TPR 和 AUC
  6. fpr, tpr, thresholds = roc_curve(true_labels, pred_probs)
  7. roc_auc = auc(fpr, tpr)
  8. # 绘制 ROC 曲线
  9. plt.figure(figsize=(8, 6))
  10. plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
  11. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  12. plt.xlim([0.0, 1.0])
  13. plt.ylim([0.0, 1.05])
  14. plt.xlabel('False Positive Rate')
  15. plt.ylabel('True Positive Rate')
  16. plt.title('Receiver Operating Characteristic')
  17. plt.legend(loc="lower right")
  18. plt.show()

3.3 代码解释

  • 我们使用 torch.tensor 模拟了真实标签和预测概率。
  • 然后,使用 sklearn.metrics.roc_curve 函数计算假正率、真正率和阈值,使用 sklearn.metrics.auc 函数计算 AUC。
  • 最后,使用 plt.plot 函数绘制 ROC 曲线,并添加了对角线作为参考线。

4. 总结

可视化图表 用途 相关函数
混淆矩阵 展示分类模型在每个类别上的分类情况 sklearn.metrics.confusion_matrixseaborn.heatmap
ROC 曲线 评估二分类模型的性能 sklearn.metrics.roc_curvesklearn.metrics.auc

通过使用 Matplotlib 和相关的库,我们可以方便地实现模型评估结果的可视化,从而更直观地理解模型的性能。在实际应用中,可以根据具体的需求选择合适的可视化图表进行分析。

希望本文能够帮助你掌握使用 Matplotlib 进行模型评估可视化的方法。如果你有任何疑问或建议,欢迎留言讨论。

Matplotlib - 模型评估可视化 - 绘制混淆矩阵等