📜  如何在 keras 中保存和加载模型 - Python (1)

📅  最后修改于: 2023-12-03 15:24:14.669000             🧑  作者: Mango

如何在 keras 中保存和加载模型 - Python

在 keras 中,我们可以通过两种方式保存已经训练好的神经网络模型:

  1. 保存模型的结构和参数
  2. 只保存模型的参数
保存模型的结构和参数

我们可以使用 model.save() 方法将模型的结构以及参数保存下来。这个方法会将模型的架构(包括每一层的参数)和训练得到的权重参数一同保存下来,具体代码如下:

from keras.models import load_model

# 假设我们已经训练好了一个模型
my_model = ...  

# 将模型保存
my_model.save('my_model.h5')

其中,my_model.h5 是你想要保存的模型文件名。

只保存模型的参数

如果我们只想保存模型的参数,可以使用 model.save_weights() 方法。这个方法只会保存模型的参数,而不会保存模型的结构,具体代码如下:

from keras.models import load_model

# 假设我们已经训练好了一个模型
my_model = ...  

# 只保存模型的参数
my_model.save_weights('my_model_weights.h5')
加载模型

当我们需要使用已经训练好的模型时,我们可以使用 load_model() 方法加载模型,具体代码如下:

from keras.models import load_model

# 加载保存好的模型
my_model = load_model('my_model.h5')

如果我们只保存了模型的参数,我们需要先定义好模型的结构,然后使用 load_weights() 方法加载参数,具体代码如下:

from keras.models import load_model

# 定义模型结构
my_model = ...  

# 加载保存好的参数
my_model.load_weights('my_model_weights.h5')

如果你想查看更多细节,请参考 Keras 的文档