Python – tensorflow.GradientTape.reset()

Last modified: 2022-05-13 01:54:25.538000    Author: Mango

TensorFlow is an open-source Python library designed by Google for building machine learning models and deep neural networks.

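reset() clears all information stored in the tape. It is equivalent to exiting the tape's context manager and re-entering it with a new tape. Operations recorded before the reset no longer contribute to the gradient, and tensors watched before the reset are no longer watched, so a constant such as x must be passed to watch() again after calling reset().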
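According to the TensorFlow documentation, calling reset() inside a tape context behaves like closing the tape and opening a fresh one. The minimal sketch below checks that claim and is not part of the original article: the value x = 2.0 and the expression (x*x followed by 5.0*x) are arbitrary choices. Variant A calls reset() inside one tape, Variant B splits the same computation across two separate tapes, and Variant C shows what happens when x is not watched again after the reset.

```python
import tensorflow as tf

x = tf.constant(2.0)

# Variant A: call reset() partway through a single tape context.
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x         # recorded, then discarded by reset()
    tape.reset()      # clears everything recorded so far, including the watch on x
    tape.watch(x)     # x is a constant, so it has to be watched again
    y += 5.0 * x      # only this operation is recorded
grad_reset = tape.gradient(y, x)

# Variant B: the same computation split across two separate tapes.
with tf.GradientTape() as tape:
    tape.watch(x)
    y_b = x * x
with tf.GradientTape() as tape:
    tape.watch(x)
    y_b += 5.0 * x    # y_b from the first tape is treated as a constant here
grad_two_tapes = tape.gradient(y_b, x)

# Variant C: reset() without re-watching x.
with tf.GradientTape() as tape:
    tape.watch(x)
    y_c = x * x
    tape.reset()
    y_c += 5.0 * x    # x is no longer watched, so nothing is recorded
grad_unwatched = tape.gradient(y_c, x)

print(grad_reset.numpy())      # 5.0
print(grad_two_tapes.numpy())  # 5.0
print(grad_unwatched)          # None
```

Variants A and B both yield only the derivative of 5.0*x, while Variant C returns None because the cleared tape has no recorded path from y_c back to x.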

Example 1:

Python3
# Importing the library
import tensorflow as tf

x = tf.constant(4.0)

# Using GradientTape without reset
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x * x
    y += x * x

# Computing the gradient: both terms were recorded
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x*x + x*x): ", res)

# Using GradientTape with reset
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x * x

    # Resetting the tape discards the x*x*x computation and the watch on x
    gfg.reset()

    # x must be watched again on the cleared tape
    gfg.watch(x)
    y += x * x

# Computing the gradient: only the x*x term was recorded
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x): ", res)


Output:

res(y = x*x*x + x*x):  tf.Tensor(56.0, shape=(), dtype=float32)
res(y = x*x):  tf.Tensor(8.0, shape=(), dtype=float32)
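Both printed values can be checked by hand for x = 4. Without reset() the tape records the whole expression, so both terms are differentiated:

$$y = x^3 + x^2, \qquad \frac{dy}{dx} = 3x^2 + 2x = 3(4)^2 + 2(4) = 48 + 8 = 56$$

With reset(), the x*x*x computation is no longer on the tape, so its value (64) enters y only as a constant. The value of y is still 64 + 16 = 80, but only the x*x term is differentiated:

$$\frac{dy}{dx} = \frac{d}{dx}\left(x^2\right) = 2x = 2(4) = 8$$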

Example 2:

Python3

# Importing the library
import tensorflow as tf

x = tf.constant(3.0)

# Using GradientTape without reset
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x
    y += x * x

# Computing the gradient: both x*x terms were recorded
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x + x*x): ", res)

# Using GradientTape with reset
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x

    # Resetting the tape discards the first x*x and the watch on x
    gfg.reset()

    # x must be watched again on the cleared tape
    gfg.watch(x)
    y += x

# Computing the gradient: only the "+ x" term was recorded
res = gfg.gradient(y, x)

# Printing result
print("res(y = x): ", res)

Output:

res(y = x*x + x*x):  tf.Tensor(12.0, shape=(), dtype=float32)
res(y = x):  tf.Tensor(1.0, shape=(), dtype=float32)
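The same hand check for x = 3: without reset() the tape sees both x*x terms, so

$$y = x^2 + x^2 = 2x^2, \qquad \frac{dy}{dx} = 4x = 4(3) = 12$$

With reset(), the first x*x is treated as a constant (9), and only the added x is differentiated:

$$\frac{dy}{dx} = \frac{d}{dx}(x) = 1$$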