📜  加载保存的模型 pyspark (1)

📅  最后修改于: 2023-12-03 15:07:17.990000             🧑  作者: Mango

加载保存的模型 PySpark

在 PySpark 中,我们可以使用 pyspark.ml 模块的 Pipeline API 来创建机器学习模型,并使用 save() 函数将训练好的模型保存在本地文件系统或分布式文件系统中(如 HDFS)。

当我们需要重新使用训练好的模型时,我们可以使用 pyspark.ml 模块的 load() 函数来加载保存的模型。

以下是加载保存的模型 PySpark 的介绍,包括如何保存模型以及如何加载保存的模型。

保存模型

在 PySpark 中,我们使用以下代码语句可以将已训练好的模型保存到本地文件系统:

from pyspark.ml import PipelineModel

# 假设我们已经训练好了一个模型,并将其赋值给变量 model
model.write().overwrite().save('path/to/save/model')

在上面的代码中,我们使用 write() 函数将模型保存到指定目录 path/to/save/model 中。如果该目录已经存在,则需要添加 overwrite() 参数来覆盖原有的同名目录。

也可以使用以下代码语句将模型保存到 HDFS 中:

model.write().overwrite().save('hdfs://path/to/save/model')
加载保存的模型

在 PySpark 中,我们可以使用以下代码语句加载保存的模型:

from pyspark.ml import PipelineModel

loaded_model = PipelineModel.load('path/to/saved/model')

在上面的代码中,我们使用 load() 函数从指定目录 path/to/saved/model 中加载已保存的模型,并将其赋值给变量 loaded_model 进行后续使用。

需要注意的是,与保存模型时相同,如果模型保存在 HDFS 中,则读入模型时应该使用 hdfs:// 开头的路径。

示例

接下来,我们来看一个例子,使用 PySpark 训练一个简单的线性回归模型,并将其保存至本地文件系统:

from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

# 创建 DataFrame
dataset = spark.createDataFrame([(1.0, 2.0, 3.0), (2.0, 3.0, 5.0), (3.0, 4.0, 7.0)], ["x1", "x2", "y"])

# 特征列
assembler = VectorAssembler(inputCols=["x1", "x2"], outputCol="features")

# 线性回归模型
lr = LinearRegression(featuresCol="features", labelCol="y")

# 创建 Pipeline
pipeline = Pipeline(stages=[assembler, lr])

# 训练模型
model = pipeline.fit(dataset)

# 保存模型
model.write().overwrite().save('path/to/save/model')

接下来,我们再加载保存的模型并使用它进行预测:

from pyspark.ml import PipelineModel

# 加载模型
loaded_model = PipelineModel.load('path/to/save/model')

# 创建测试数据
test_data = spark.createDataFrame([(4.0, 5.0), (5.0, 6.0)], ["x1", "x2"])

# 使用模型进行预测
result = loaded_model.transform(test_data)

# 显示预测结果
result.show()

输出结果如下:

+---+---+-------------+------------------+
| x1| x2|     features|        prediction|
+---+---+-------------+------------------+
|4.0|5.0|[4.0,5.0]|[8.125000000000002]|
|5.0|6.0|[5.0,6.0]|        [10.125000000000002]|
+---+---+-------------+------------------+