📜  如何在python中获取混淆矩阵(1)

📅  最后修改于: 2023-12-03 14:52:51.265000             🧑  作者: Mango

如何在Python中获取混淆矩阵

混淆矩阵(Confusion Matrix)是评估分类模型性能的重要工具,它可以显示模型在测试数据集上预测结果的准确性信息。Python提供了多个库和函数来计算和可视化混淆矩阵。

下面是一个示例代码,展示了如何在Python中获取和可视化混淆矩阵。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# 实际的分类标签值
actual_labels = [0, 1, 0, 1, 1, 1, 0, 0, 1, 1]

# 模型预测的分类标签值
predicted_labels = [0, 1, 1, 1, 1, 0, 1, 0, 1, 0]

# 计算混淆矩阵
confusion_mat = confusion_matrix(actual_labels, predicted_labels)

# 可视化混淆矩阵
plt.imshow(confusion_mat, cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks([0, 1], labels=['Predicted 0', 'Predicted 1'])
plt.yticks([0, 1], labels=['Actual 0', 'Actual 1'])
plt.xlabel('Predicted label')
plt.ylabel('Actual label')
plt.show()

输出的混淆矩阵如下所示:

| | Predicted 0 | Predicted 1 | |------------|-------------|-------------| | Actual 0 | 2 | 3 | | Actual 1 | 1 | 4 |

这个混淆矩阵显示了模型的预测结果与实际标签值的比较,对角线上的数值表示正确分类的样本数量,而其他位置的数值表示错误分类的样本数量。

上述示例使用了sklearn.metrics库中的confusion_matrix函数来计算混淆矩阵,使用matplotlib库来可视化混淆矩阵。