📜  线性回归(Python实现)(1)

📅  最后修改于: 2023-12-03 14:56:50.190000             🧑  作者: Mango

线性回归

线性回归是一种用于预测数值型目标变量的机器学习算法。它适用于只有一个自变量(也称为特征)的情况。线性回归的目标是通过构建一个最佳拟合直线(也称为回归线)来建立自变量与目标变量之间的关系。

Python实现

Python提供了许多可以用于线性回归的库,如scikit-learn、tensorflow、numpy等。在这里我们将使用scikit-learn库来实现线性回归模型。

准备数据

我们需要准备一个数据集来训练和测试我们的模型。这里我们使用scikit-learn自带的波士顿房价数据集(Boston Housing Dataset)。我们可以使用以下代码来读取数据集:

from sklearn.datasets import load_boston

boston_data = load_boston()
X = boston_data.data
y = boston_data.target
划分数据集

接下来,我们需要将数据集划分为训练集和测试集。训练集用于训练模型,测试集用于评估模型的性能。我们可以使用train_test_split函数来完成数据集的划分:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

这里我们将数据集划分为70%的训练集和30%的测试集。

训练模型

现在我们可以使用线性回归算法来训练模型。我们可以使用LinearRegression类来完成模型的训练:

from sklearn.linear_model import LinearRegression

lr = LinearRegression()
lr.fit(X_train, y_train)
评估模型

我们可以使用测试集来评估模型的性能。可以使用以下代码来计算模型的均方误差:

from sklearn.metrics import mean_squared_error

y_pred = lr.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print("均方误差:", mse)
可视化结果

我们可以使用matplotlib库来可视化模型的结果。以下代码可以绘制真实值和预测值之间的对比图:

import matplotlib.pyplot as plt

plt.scatter(y_test, y_pred)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=4)
plt.xlabel('真实值')
plt.ylabel('预测值')
plt.show()

这里的直线表示理想情况下真实值和预测值应该是一致的。如果模型表现良好,则散点图应该越接近该直线。

完整代码如下:

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

boston_data = load_boston()
X = boston_data.data
y = boston_data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

lr = LinearRegression()
lr.fit(X_train, y_train)

y_pred = lr.predict(X_test)
mse = mean_squared_error(y_test, y_pred)

plt.scatter(y_test, y_pred)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=4)
plt.xlabel('真实值')
plt.ylabel('预测值')
plt.show()

print("均方误差:", mse)

该代码将输出模型的均方误差,并绘制真实值和预测值之间的对比图。

小结

在这篇文章中,我们使用scikit-learn库来实现线性回归模型。我们介绍了如何准备数据集、划分数据集、训练模型、评估模型以及可视化结果。线性回归是预测数值型目标变量的一种简单而有效的机器学习算法。