📜  scikit 学习决策树 - Python (1)

📅  最后修改于: 2023-12-03 14:47:18.150000             🧑  作者: Mango

Scikit 学习决策树 - Python

简介

Scikit-learn是一个Python语言的机器学习库,其提供了一系列用于分类,回归和聚类等任务的算法和工具。其中决策树算法是其中一个最为常用的算法之一。在本文中,我们将对Scikit-learn提供的决策树算法进行介绍。

安装

使用pip进行安装:

pip install -U scikit-learn
使用Scikit-learn构建决策树

本节将介绍如何使用Scikit-learn进行决策树的构建。首先,我们需要导入以下库:

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import graphviz

然后我们可以使用鸢尾花数据集进行一个示例:

# 加载数据集并分为训练和测试集
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 训练决策树
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, y_train)

# 显示决策树
dot_data = tree.export_graphviz(clf, out_file=None, class_names=iris.target_names, 
                                feature_names=iris.feature_names, filled=True, rounded=True,
                                special_characters=True)
graph = graphviz.Source(dot_data)
graph

上述代码首先加载了鸢尾花数据集并将其分为训练和测试集,然后创建了一个决策树分类器,并使用训练数据进行拟合。我们还可以使用export_graphviz函数将生成的决策树可视化。结果如下:

决策树

这里我们可以看到,petal length (cm)petal width (cm)是最重要的特征,作为根节点被选中。使用我们的测试数据集进行预测:

# 预测
y_pred = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

得到的准确性为:

Accuracy: 0.97
决策树的参数

Scikit-learn的决策树有许多可调参数,以下是一些常见的参数:

  • max_depth:树的最大深度。
  • min_samples_split:分裂内部节点所需的最小样本数。
  • min_samples_leaf:叶子节点所需的最小样本数。
  • max_features:寻找最佳分裂时要考虑的特征数。
决策树的优缺点

使用决策树模型的优点包括:

  • 简单易懂。
  • 可以同时处理数值型和类别型数据。
  • 可以捕获特征之间的非线性关系。

决策树模型的缺点包括:

  • 对噪声和过度拟合敏感,需要进行剪枝。
  • 无法处理缺失数据,需要进行处理。
  • 对于一些复杂的问题,单独使用决策树会导致准确度不高。
总结

在本文中,我们介绍了如何使用Scikit-learn构建决策树。决策树是机器学习领域中常用的一种可解释性模型。然而,决策树不适用于所有问题,需要根据实际情况进行选择。Scikit-learn提供了一系列机器学习算法和工具,适用于各种问题和场景。