📜  在 Tensorflow 中保存和加载模型(1)

📅  最后修改于: 2023-12-03 14:51:11.355000             🧑  作者: Mango

在 Tensorflow 中保存和加载模型

Tensorflow 提供了将模型保存到硬盘和从硬盘加载模型的功能。这个功能对于在训练模型时定期保存模型状态非常有用,这样在模型中断或失败时可以恢复模型。它还可以用于在不重新训练模型的情况下重新使用模型。

在 Tensorflow 中,可以使用 tf.train.Saver 实例来保存和加载模型。下面详细介绍这个功能。

保存模型

要保存模型,首先必须创建 tf.train.Saver 实例。通常,这个实例在创建模型时被创建,例如:

import tensorflow as tf

# 假设这个变量是你训练模型的一部分
my_variable = tf.Variable([1.0, 2.0], name="my_variable")

# 现在创建一个 saver 实例
saver = tf.train.Saver()

注意,Tensorflow 中所有的变量都必须有一个名字,因为这个名字是用来保存和加载变量值的。

要保存模型,可以通过 Saver.save 方法执行:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 训练你的模型
    # ...
    # 现在保存你的模型
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)

这个操作将序列化图形和所有变量值到文件 /tmp/model.ckpt。在保存模型时,你必须在 sess.run 块中运行 tf.global_variables_initializer(),以确保所有变量都被初始化。

save_path 是返回的字符串,其中包含完整的路径和文件名。这个字符串可以在稍后加载模型时使用。

加载模型

要加载模型,必须按照与保存模型相同的方式定义变量。然后在定义 Saver 实例和会话时,通过给 Saver 构造函数传递变量的列表来指定要加载的变量。

例如,如果要加载前面例子中保存的模型,可以执行:

import tensorflow as tf

# 定义与保存模型相同的变量
my_variable = tf.Variable([0.0, 0.0], name="my_variable")

# 创建一个 Saver 实例
saver = tf.train.Saver([my_variable])

# 使用 saver.restore() 方法加载模型
with tf.Session() as sess:
    # 加载以前保存的模型
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")

    # 现在开始对模型进行评估
    # ...

注意,必须在定义变量时给它们相同的名字才能正常加载模型。此外,被加载的变量的形状和类型也必须匹配。

在上面的例子中,我们指定了 my_variable 变量,但是实际上你可能有更多的变量,你可以将它们作为列表传递给构造函数,例如:

saver = tf.train.Saver([my_first_variable, my_second_variable, my_third_variable])
保存和加载模型参数

Tensorflow 还提供了一个方便的函数 tf.trainable_variables 来获取所有训练模型参数的列表。这可以用来方便地创建 Saver 实例。

例如:

import tensorflow as tf

# 创建变量
my_variable1 = tf.Variable([1.0, 2.0], name="my_variable1")
my_variable2 = tf.Variable([3.0, 4.0], name="my_variable2")

# 创建 Saver 实例
saver = tf.train.Saver(tf.trainable_variables())

# 保存模型
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)

# 加载模型
with tf.Session() as sess:
    # 定义与保存模型相同的变量
    my_variable1 = tf.Variable([0.0, 0.0], name="my_variable1")
    my_variable2 = tf.Variable([0.0, 0.0], name="my_variable2")

    # 创建 Saver 实例
    saver = tf.train.Saver(tf.trainable_variables())

    # 加载以前保存的模型
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")

在这个例子中,我们使用 tf.trainable_variables 函数来获取在变量上执行优化的所有变量的列表。然后我们将这个列表传递给 Saver 构造函数,这样就不需要手动指定每个变量了。

当我们加载模型时,我们也定义了一个与原来相同的变量,但是初始化为 0。然后我们使用 tf.trainable_variablesSaver 实例来加载模型参数。