📜  sklearn 预测阈值 - Python (1)

📅  最后修改于: 2023-12-03 15:20:09.414000             🧑  作者: Mango

sklearn 预测阈值介绍

在使用机器学习模型进行分类时,需要设置预测阈值来判断分类的结果。Scikit-learn提供了一些方法可以帮助我们计算和调整预测阈值,以达到更好的分类效果。

阈值的概念

阈值是在分类问题中被用来进行决策的一个界限值。对于二元分类问题,通常将预测概率值大于阈值的样本标记为正样本,反之则标记为负样本。阈值的大小会影响模型的精确度、召回率和F1-score等性能指标。

阈值调整方法

Scikit-learn提供了两种常见的阈值调整方法:阈值移动和ROC曲线。

阈值移动

阈值移动法是一种简单有效的方法,通过不断调整阈值来获得更好的模型性能。我们可以使用predict_proba()方法获得模型对于每个样本的预测概率值,然后根据阈值将其转化为类别标签。通过改变阈值,可以取得不同的召回率和准确率。

from sklearn.metrics import precision_recall_curve

y_scores = model.predict_proba(X_test)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_test, y_scores)

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0, 1])

plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()

上述代码中,首先调用了predict_proba()方法获取预测概率值,然后使用precision_recall_curve()方法计算不同阈值下的精确率和召回率。最后调用plot_precision_recall_vs_threshold()方法将准确率和召回率与不同阈值对比画图展示。

ROC曲线

ROC曲线是另一种评估模型分类准确度的方法。这种方法也可以帮助优化模型的阈值。ROC曲线的横坐标是假阳性率,纵坐标是真阳性率,可以用来衡量模型总体分类性能。ROC曲线越接近左上角,代表模型的性能越好。

我们可以通过调用roc_curve()方法来获取FPR,TPR等值,然后绘制ROC曲线。

from sklearn.metrics import roc_curve, auc

fpr, tpr, thresholds = roc_curve(y_test, y_scores)
roc_auc = auc(fpr, tpr)

def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'r--')
    plt.axis([0, 1, 0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

plt.figure(figsize=(8, 6))
plot_roc_curve(fpr, tpr, roc_auc)
plt.show()

上述代码中,通过调用roc_curve()方法计算FPR和TPR,然后使用auc()方法求出曲线下的面积auc,得到ROC曲线的面积。最后调用plot_roc_curve()方法将ROC曲线画出并展示出来。

总结

本篇文章介绍了阈值的概念和使用Scikit-learn进行阈值调整的两种方法:阈值移动和ROC曲线。这些方法可以帮助优化模型性能,从而得到更好的分类结果。