📜  Python线性回归(1)

📅  最后修改于: 2023-12-03 15:04:41.678000             🧑  作者: Mango

Python线性回归

线性回归是机器学习中的基本概念之一,也是最常用的算法之一。它可以用于预测数值型数据的结果,如股票价格、房屋的价格等。

线性回归的原理

线性回归假设存在一个线性关系,将自变量 $X$ 和因变量 $Y$ 之间的关系表示为:

$$Y = aX + b + \epsilon$$

其中 $a$ 和 $b$ 是待求的系数,$\epsilon$ 为误差项。

我们可以使用最小二乘法来拟合线性回归模型,使得误差 $\epsilon$ 的平方和最小。具体而言,最小二乘法就是使得每个样本点到预测值的距离的平方和最小。

使用Python进行线性回归

Python中的scikit-learn库提供了线性回归的功能。以下是一个简单的例子:

from sklearn.linear_model import LinearRegression
import numpy as np

# 训练数据
x_train = np.array([[5], [15], [25], [35], [45], [55]])
y_train = np.array([[5], [20], [14], [32], [22], [38]])

# 创建线性回归对象
linear_regression = LinearRegression()

# 训练模型
linear_regression.fit(x_train, y_train)

# 打印系数
print(linear_regression.coef_, linear_regression.intercept_)

# 预测
x_predict = np.array([[25], [30], [40]])
y_predict = linear_regression.predict(x_predict)
print(y_predict)

输出结果:

[[0.54]] [5.30132829]
[[17.30132829]
 [20.08219178]
 [27.64491887]]
可视化线性回归

我们可以使用matplotlib库可视化线性回归的结果。以下是一个示例:

import matplotlib.pyplot as plt

# 可视化训练数据
plt.scatter(x_train, y_train)

# 可视化回归直线
plt.plot(x_train, linear_regression.predict(x_train), color='red')

# 可视化预测结果
plt.scatter(x_predict, y_predict, color='green')

结果:

linear_regression_visualization

总结

以上就是使用Python进行线性回归的基本步骤,包括了建立模型、训练模型、预测、可视化等流程。线性回归是机器学习的重要工具之一,适用于许多实际应用场景,希望读者可以通过本文中的示例和代码片段了解基本概念和实际应用。