📜  scikit 学习决策树 12 - Python (1)

📅  最后修改于: 2023-12-03 15:34:51.462000             🧑  作者: Mango

Scikit 学习决策树 12 - Python

如果你想使用Python来学习决策树, Scikit-learn 库是一个很好的选择。Scikit-learn是一个免费的机器学习库,它为Python提供了一系列机器学习算法,其中包括了决策树。

在本文中,我们将使用Scikit-learn库的决策树模型来预测关于Iris数据集的花卉种类。

Iris 数据集简介

Iris是一个常用的数据集,它包含了三种不同的鸢尾花(Iris Setosa,Iris Versicolour和Iris Virginica)的样本数据,每个样本有四个特征(花瓣长度,花瓣宽度,花萼长度和花萼宽度)和它所属的鸢尾花种类。

我们将使用这个数据集来训练我们的决策树模型。Scikit-learn库包含了这个数据集,可以直接从库中加载。

# 加载Iris数据集
from sklearn.datasets import load_iris

iris = load_iris()

# 查看数据集的特征
print(iris.feature_names)

# 查看数据集的标签
print(iris.target_names)

输出结果:

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
['setosa' 'versicolor' 'virginica']

接下来,我们将把这个数据集划分为训练集和测试集。

# 导入 train_test_split 来划分训练集和测试集
from sklearn.model_selection import train_test_split

# 将数据集划分为训练集和测试集,并将训练集和测试集按 7:3 的比例划分
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=1)
构建决策树模型

接下来我们将使用Scikit-learn库的DecisionTreeClassifier类来构建决策树模型。这个类提供了多种参数,我们将选择其中的一些来创建我们的模型。

# 导入DecisionTreeClassifier类来构建决策树
from sklearn.tree import DecisionTreeClassifier

# 创建 DecisionTreeClassifier 实例并指定一些参数
clf = DecisionTreeClassifier(criterion='gini') # 选择基于Gini指数的决策树

# 在训练集上训练模型
clf = clf.fit(X_train, y_train)
预测和评估

一旦我们完成训练,我们可以使用测试集来评估我们的模型的准确性。

# 在测试集上进行预测
y_pred = clf.predict(X_test)

# 计算模型准确性
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: ", accuracy)

输出结果:

Accuracy:  0.9777777777777777

这个模型的准确性非常高,达到了0.98。

可视化决策树

决策树可以被可视化为一个图形化的模型,其中每个节点表示对数据的一个判断,并基于该判断将数据分为两个或更多的子集。Scikit-learn库包含一个 export_graphviz 函数,该函数将训练过程生成的树导出为 Graphviz 格式。我们可以使用Graphviz软件包将这个格式转换为矢量图形,以便它们可以被可视化。

# 导入pydotplus库和export_graphviz函数
from sklearn.tree import export_graphviz
import pydotplus

# 将决策树导出为Graphviz格式
dot_data = export_graphviz(clf, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=iris.target_names,
                           filled=True, rounded=True,
                           special_characters=True)

# 将Graphviz格式转换为矢量图形
graph = pydotplus.graph_from_dot_data(dot_data)

# 将矢量图形保存到文件
graph.write_pdf("iris.pdf")

在运行这个代码后,你会得到一个名为iris.pdf的PDF文件,其中包含了整个决策树的可视化。你可以使用任何常用的PDF查看器来查看这个PDF文件。

结论

Scikit-learn库是一个非常强大的Python库,它提供了各种机器学习算法,其中包括了决策树模型。在本文中,我们使用Scikit-learn库的DecisionTreeClassifier类来构建决策树模型,并使用Scikit-learn库的train_test_split和accuracy_score函数来评估模型的准确性。最后,我们可视化了决策树,并将其保存为一个PDF文件。