📜  Python – tensorflow.GradientTape.stop_recording()(1)

📅  最后修改于: 2023-12-03 15:34:06.795000             🧑  作者: Mango

Python – tensorflow.GradientTape.stop_recording()

介绍

在 TensorFlow 中,GradientTape API 是前向传递和反向传递中最强大的 API 之一。在这个 API 中,tf.GradientTape 类被用来自动计算梯度。tf.GradientTape API 是 TensorFlow 2.0 中的一个新功能。

一般来说 GradientTape 不会生成有效的计算图。通常情况下,在前向传递期间,GradientTape 监视所有操作,以进行自动微分计算。但是默认情况下,GradientTape 在执行前向传递时不跟踪某些操作,如在执行计算时跳出到 Python 等。如果我们希望 GradientTape 监控这些操作,则需要使用 tf.GradientTape.start_recording() 方法启用跟踪。

相对应的 tf.GradientTape.stop_recording() 方法用于停止跟踪这些操作。

语法

以下是 tf.GradientTape.stop_recording() 的语法:

stop_recording()
参数

tf.GradientTape.stop_recording() 方法不需要参数。

返回值

tf.GradientTape.stop_recording() 方法没有返回值;

例子

下面的例子将演示如何使用 tf.GradientTape.start_recording()tf.GradientTape.stop_recording() 开启关闭 tape 的记录跟踪功能。

import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(3.0)


with tf.GradientTape() as tape:
    tape.watch([x,y])
    z = x * y
    tape.stop_recording()   # stop tracking x and y here to make them constant
    k = x + y

dz_dx, dz_dy = tape.gradient(z, [x, y])
dk_dx, dk_dy = tape.gradient(k, [x, y])
print(dz_dx.numpy())  # Output: 3.0
print(dz_dy.numpy())  # Output: 2.0
print(dk_dx)          # Output: None
print(dk_dy)          # Output: None

在上面的例子中,GradientTape 在跟踪 x 和 y 下的 'z = x * y' 操作,然后使用 tf.GradientTape.stop_recording() 停止跟踪 x 和 y 在 'x + y' 操作中。接着使用 tape.gradient() 计算 z 和 k 对 x 和 y 的偏导数,我们可以看到 tape.gradient() Yields None for Constant Input x + y。

Note
  • tf.GradientTape.stop_recording() 方法需要相对应的 tf.GradientTape.start_recording() 或者处于 tf.GradientTape() 上下文范围之内才能使用。

  • 用于记录的 tf.GradientTape 上的操作是动态生成的,它们只有在每次前向传递时才被执行。

  • tf.GradientTape.stop_recording() 可以不使用而快速抛开计算图构建,可加快速度。如上例所示,我们只对 * 操作进行了跟踪计算梯度,加快了计算速度。

  • tf.GradientTape.stop_recording() 可以在启用Monitoring 或不启用Monitoring 的情况下使用。它不影响 tf.GradientTape.gradient() 的结果。