📜  支持向量机 svm 使用 python 数值示例 - Python (1)

📅  最后修改于: 2023-12-03 15:26:01.752000             🧑  作者: Mango

支持向量机 SVM 使用 Python 数值示例

支持向量机(Support Vector Machine, SVM)是一种二分类的模型,可以将输入的数据根据类别分为两类。SVM通过找到最优的分割超平面(hyperplane)来实现分类,其中的“最优”是由SVM的目标函数所决定的。

在介绍SVM的使用前,需要引入一些必要的库,包括NumPy和scikit-learn:

import numpy as np
from sklearn.svm import SVC
数据准备

为了演示SVM的使用,我们需要准备一些数据集。这里我们采用sklearn库中的iris数据集,该数据集包含了3种鸢尾花的特征和类别标签。

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, :2]  # 只使用前两个特征,便于在图像上可视化
y = iris.target

# 将类别标签转化为-1和1
y[y == 0] = -1
y[y == 1] = 1

我们只使用前两个特征是为了便于在图像上可视化,毕竟绘图是最能直观看出分割超平面的。

数据可视化

我们用matplotlib库的scatter函数将数据集在二维平面上作图,用不同的颜色标记不同类别的数据。

import matplotlib.pyplot as plt

plt.scatter(X[:, 0], X[:, 1], c=y)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

SVM示例数据集散点图

从图中可以看出,数据集中存在明显的分界线将两类数据分开。我们将使用SVM来找到此划分超平面。

SVM模型训练

我们使用scikit-learn库中的SVC类来训练SVM模型。下面的代码演示如何创建一个线性SVM和一个高斯核SVM,两个模型都使用默认参数。

# 创建线性SVM
linear_svm = SVC(kernel='linear', C=1).fit(X, y)

# 创建高斯核SVM
rbf_svm = SVC(kernel='rbf', gamma='auto', C=1).fit(X, y)

上述代码中的C是一个超参数,它决定了如何权衡SVM的训练误差和模型复杂度。具有更小的C值的模型将更容易被误分类,但可能会导致噪声数据的过度拟合;具有更大的C值的模型更易于正确分类,但可能会产生过度拟合,导致泛化能力下降。

SVM模型预测

SVM训练完成后,我们可以使用predict方法来预测新数据。下面的代码演示如何使用线性SVM和高斯核SVM对新样本进行分类。

# 预测
linear_y_pred = linear_svm.predict(X)
rbf_y_pred = rbf_svm.predict(X)
SVM模型评估

为了评估SVM的性能,我们可以计算它的准确率。准确率指的是SVM正确预测出测试集样本的比例。

from sklearn.metrics import accuracy_score

# 计算准确率
linear_acc = accuracy_score(y, linear_y_pred)
rbf_acc = accuracy_score(y, rbf_y_pred)

print(f"线性SVM准确率: {linear_acc}")
print(f"高斯核SVM准确率: {rbf_acc}")
分割超平面可视化

最后,我们用一些python代码将数据集和SVM的分割超平面(即决策边界)在二维平面上可视化。

# 绘制分割超平面
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

# 获得分割超平面的权重和偏置
linear_w = linear_svm.coef_[0]
linear_b = linear_svm.intercept_[0]

# 计算分割超平面的斜率和截距
linear_slope = -linear_w[0] / linear_w[1]
linear_intercept = -linear_b / linear_w[1]

# 计算支持向量
linear_sv = linear_svm.support_vectors_

# 绘制决策边界和支持向量
plt.plot([3, 9], [3*linear_slope + linear_intercept, 9*linear_slope + linear_intercept], c='r')
plt.scatter(linear_sv[:, 0], linear_sv[:, 1], s=100, c='r', marker='^', edgecolors='k')

SVM分割超平面可视化

从上图可以看出,线性SVM找到的分割超平面是将数据集中的两类数据完美分开的,而高斯核SVM找到的分割超平面则具有更高的鲁棒性,但可能存在一些分类误差。