tf.squeeze() - Python

Last modified: 2023-12-03 15:35:19.440000    Author: Mango

Introduction

In TensorFlow, the tf.squeeze() function removes dimensions of size 1 from a tensor. It is often used when working with neural networks to drop such singleton dimensions, for example the trailing dimension of a model output with shape (batch_size, 1).
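
As a rough illustration of the neural-network use case, the sketch below assumes TensorFlow 2.x with eager execution and uses made-up predictions and labels tensors (hypothetical values, not taken from any real model). It shows why an output of shape (batch_size, 1) is usually squeezed before being compared with labels of shape (batch_size,), and how the default axis=None can also remove a batch dimension of size 1.

```python
import tensorflow as tf

# Hypothetical model output: a regression head often returns shape (batch_size, 1).
predictions = tf.constant([[0.5], [1.5], [2.5]])  # shape (3, 1)
labels = tf.constant([1.0, 2.0, 3.0])             # shape (3,)

# Squeezing only the last axis gives shape (3,), matching the labels.
flat = tf.squeeze(predictions, axis=[-1])
print(flat.shape)  # (3,)

# Without squeezing, (3, 1) - (3,) broadcasts to (3, 3) instead of (3,).
broadcast_diff = predictions - labels
print(broadcast_diff.shape)  # (3, 3)

# Pitfall: with a batch of one, the default axis=None also drops the batch axis.
single = tf.constant([[0.5]])               # shape (1, 1)
print(tf.squeeze(single).shape)             # () -- a scalar
print(tf.squeeze(single, axis=[-1]).shape)  # (1,) -- batch axis kept
```

Naming the axis explicitly keeps the result's rank independent of the batch size.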

Syntax
tf.squeeze(input, axis=None, name=None)
Parameters
input
  • A Tensor. The input tensor from which dimensions of size 1 will be removed.
axis
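  • An optional list of ints. If specified, only the listed dimensions are squeezed. Indices start at 0, and negative indices count from the end (for example, -1 is the last dimension). Each listed dimension must have size 1; otherwise an error is raised.
  • By default (axis=None), all dimensions of size 1 are removed.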
name
  • An optional string name for the operation.
Return Value
  • A tensor of the same type as input, containing the same data, with the selected dimensions of size 1 removed.
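
A minimal sketch of how the axis parameter behaves, assuming TensorFlow 2.x in eager mode; the tensor t is an arbitrary example. Because the exact exception raised for an invalid axis can differ between eager and graph execution, the code catches both tf.errors.InvalidArgumentError and ValueError rather than relying on one of them.

```python
import tensorflow as tf

# Example tensor with size-1 dimensions at axes 0 and 2: shape (1, 3, 1, 2).
t = tf.zeros([1, 3, 1, 2])

# Default axis=None removes every size-1 dimension.
print(tf.squeeze(t).shape)             # (3, 2)

# A list of axes removes only those dimensions.
print(tf.squeeze(t, axis=[0]).shape)   # (3, 1, 2)

# Negative indices count from the end: -2 refers to axis 2 here.
print(tf.squeeze(t, axis=[-2]).shape)  # (1, 3, 2)

# Naming an axis whose size is not 1 (axis 1 has size 3) is an error.
squeezed_ok = True
try:
    tf.squeeze(t, axis=[1])
except (tf.errors.InvalidArgumentError, ValueError) as err:
    squeezed_ok = False
    print("Cannot squeeze axis 1:", type(err).__name__)
```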
Example
import tensorflow as tf

# create a tensor with extra dimensions
x = tf.constant([[[[1], [2]]]])

# use squeeze to remove the extra dimensions
y = tf.squeeze(x)

print("Shape before squeezing:", x.shape)
print("Shape after squeezing:", y.shape)

Output:

Shape before squeezing: (1, 1, 2, 1)
Shape after squeezing: (2,)

In this example, the original tensor x has dimensions of size 1 at indices 0, 1, and 3. tf.squeeze() removes all three and returns a new tensor y with shape (2,); x itself is not modified.
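
To check that squeezing changes only the shape and not the values, the following sketch reuses the tensor x from the example (again assuming TensorFlow 2.x with eager execution) and compares a partial squeeze against an equivalent tf.reshape call.

```python
import tensorflow as tf

x = tf.constant([[[[1], [2]]]])  # shape (1, 1, 2, 1), as in the example

y = tf.squeeze(x)
print(y.numpy())  # [1 2]
print(x.shape)    # (1, 1, 2, 1) -- x itself is unchanged

# Squeezing only axes 0 and 1 keeps the trailing size-1 dimension.
z = tf.squeeze(x, axis=[0, 1])
print(z.shape)    # (2, 1)

# The same result can be written as a reshape to the target shape.
r = tf.reshape(x, [2, 1])
print(bool(tf.reduce_all(z == r)))  # True
```

tf.reshape needs the target shape spelled out, whereas tf.squeeze only needs to know which size-1 axes to drop.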

Conclusion

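The tf.squeeze() function removes dimensions of size 1 from a tensor without changing its data, which is useful for making tensor shapes match before further computation. When the expected shape is known in advance, passing axis explicitly is safer than relying on the default, because axis=None also removes any dimension that happens to be 1 at runtime, such as a batch of size 1.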