📜  sklearn 绘制混淆矩阵 - Python (1)

📅  最后修改于: 2023-12-03 14:47:28.380000             🧑  作者: Mango

使用sklearn绘制混淆矩阵 - Python

混淆矩阵是衡量分类算法性能的一种可视化评估工具。sklearn提供了一个方便的函数confusion_matrix,可以帮助我们创建混淆矩阵。

导入库和数据

首先,我们需要导入需要使用的库和数据。

import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 假设我们有真实标签和预测标签的数据
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 0]
创建混淆矩阵

使用confusion_matrix函数,我们可以计算真实标签和预测标签之间的混淆矩阵。

cm = confusion_matrix(y_true, y_pred)
可视化混淆矩阵

一种常见的可视化混淆矩阵的方法是使用热力图。我们可以使用seaborn库的heatmap函数来创建这个热力图。

df_cm = pd.DataFrame(cm, index=['True 0', 'True 1'], columns=['Pred 0', 'Pred 1'])
plt.figure(figsize=(10,7))
sns.heatmap(df_cm, annot=True, fmt='d', cmap='YlGnBu')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

以上代码会生成一个可视化的混淆矩阵热力图,横轴表示预测标签,纵轴表示真实标签。

完整代码
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 假设我们有真实标签和预测标签的数据
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 0]

# 创建混淆矩阵
cm = confusion_matrix(y_true, y_pred)

# 可视化混淆矩阵
df_cm = pd.DataFrame(cm, index=['True 0', 'True 1'], columns=['Pred 0', 'Pred 1'])
plt.figure(figsize=(10,7))
sns.heatmap(df_cm, annot=True, fmt='d', cmap='YlGnBu')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

以上代码会生成一个可视化的混淆矩阵热力图,用于评估分类算法性能。

如果你想了解更多关于混淆矩阵的细节,请参阅sklearn文档的混淆矩阵部分