Python – tensorflow.gather_nd()

Last modified: 2023-12-03 15:19:03.593000    Author: Mango

Introduction

TensorFlow is an open-source machine learning library developed and maintained by Google. It provides a wide range of operations for manipulating and transforming tensors. One such operation is tf.gather_nd(), which gathers elements or slices from a tensor at positions given by multi-dimensional index tuples.

Usage

tf.gather_nd() takes two required arguments. The first, params, is the tensor from which elements are gathered. The second, indices, is an integer tensor (int32 or int64) whose innermost dimension holds the index tuples to look up. The indices can be passed as a nested Python list, a NumPy array, or a TensorFlow tensor. The function also accepts the optional arguments batch_dims and name.
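As a quick check of the point about input types, the sketch below passes the same index pairs as a nested Python list, a NumPy array, and a tf.Tensor, and confirms that tf.gather_nd() returns the same values each time. The variable names are illustrative only; the only TensorFlow calls used are tf.constant and tf.gather_nd.

```python
import numpy as np
import tensorflow as tf

params = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# The same (row, column) pairs held in three different containers
idx_list = [[0, 0], [2, 1], [1, 2]]
idx_numpy = np.array(idx_list)                      # NumPy's default integer dtype
idx_tensor = tf.constant(idx_list, dtype=tf.int32)  # explicit int32 tensor

from_list = tf.gather_nd(params, idx_list)
from_numpy = tf.gather_nd(params, idx_numpy)
from_tensor = tf.gather_nd(params, idx_tensor)

print(from_list.numpy(), from_numpy.numpy(), from_tensor.numpy())
# [1 8 6] [1 8 6] [1 8 6]
```

All three calls are converted to an integer index tensor internally, so the choice of container is mostly a matter of convenience.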

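Each index tuple in indices addresses some number of leading dimensions of params. If the tuple's length equals the rank of params, it selects a single scalar element; if it is shorter, it selects a whole slice spanning the remaining dimensions. The output shape is therefore indices.shape[:-1] + params.shape[indices.shape[-1]:]. Only in the common case where every tuple is a full index is the rank of the result equal to the rank of indices minus one.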
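The shape rule can be checked directly. In the sketch below, params is a hypothetical 3D tensor of shape (2, 3, 4) built with tf.reshape(tf.range(24), ...), so that params[i, j, k] equals 12*i + 4*j + k. Index tuples of length 3, 2 and 1 return scalars, rows and matrices respectively, and each output shape matches indices.shape[:-1] + params.shape[depth:].

```python
import tensorflow as tf

# params[i, j, k] == 12*i + 4*j + k, shape (2, 3, 4)
params = tf.reshape(tf.range(24), (2, 3, 4))

# Tuples of length 3 address single scalars -> shape (2,)
scalars = tf.gather_nd(params, [[0, 1, 2], [1, 2, 3]])

# Tuples of length 2 address rows of length 4 -> shape (2, 4)
rows = tf.gather_nd(params, [[0, 1], [1, 2]])

# Tuples of length 1 address whole 3x4 matrices -> shape (2, 3, 4)
matrices = tf.gather_nd(params, [[1], [0]])

for name, out in [("scalars", scalars), ("rows", rows), ("matrices", matrices)]:
    print(name, tuple(out.shape))
# scalars (2,)
# rows (2, 4)
# matrices (2, 3, 4)

print(scalars.numpy())  # [ 6 23]
```

Note that the third call returns params[1] followed by params[0]: the output always follows the order of the index tuples, not the order of the data in params.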

import tensorflow as tf
import numpy as np

# A 3x3 matrix to gather from (the params argument)
tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Each row is one (row, column) index pair into `tensor`
indices = np.array([[0, 0], [2, 1], [1, 2]])

gathered_values = tf.gather_nd(tensor, indices)
print(gathered_values)

# Output
# tf.Tensor([1 8 6], shape=(3,), dtype=int32)

In the above example, tensor is a 2D tensor of shape (3, 3) and indices has shape (3, 2). Each row of indices is a full (row, column) pair, so each one selects a single scalar: [0, 0] gives 1, [2, 1] gives 8, and [1, 2] gives 6. The result has shape indices.shape[:-1] = (3,), that is, rank 1, and its values appear in the same order as the index pairs.
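One way to see what tf.gather_nd() computes in this example is to reproduce the result with NumPy advanced indexing, splitting the index pairs into an array of row numbers and an array of column numbers. This is only a cross-check for the 2D, full-index case above, not a description of how TensorFlow implements the operation.

```python
import numpy as np
import tensorflow as tf

tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
indices = np.array([[0, 0], [2, 1], [1, 2]])

gathered = tf.gather_nd(tensor, indices).numpy()

# Column 0 of `indices` holds row numbers, column 1 holds column numbers
rows, cols = indices[:, 0], indices[:, 1]
expected = tensor.numpy()[rows, cols]

print(gathered, expected)                  # [1 8 6] [1 8 6]
print(np.array_equal(gathered, expected))  # True
```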

Conclusion

tf.gather_nd() selects elements or slices from a tensor using explicit index tuples, so a single call can pull arbitrary, non-contiguous positions out of a tensor. Because the length of each index tuple determines how many leading dimensions it addresses, the same function can return scalars, rows, or larger sub-tensors, which makes it particularly useful when working with multi-dimensional tensors.
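As one illustration of a typical use, the sketch below picks one entry per row of a batch, for example the score of each example's target class. The scores and labels values are made up for the example. The index pairs are built by stacking tf.range(batch_size) alongside the labels with tf.stack, which yields one (row, class) tuple per example.

```python
import tensorflow as tf

# Hypothetical batch: 3 examples x 4 classes, plus one class label per example
scores = tf.constant([[0.1, 0.5, 0.2, 0.2],
                      [0.7, 0.1, 0.1, 0.1],
                      [0.2, 0.2, 0.1, 0.5]])
labels = tf.constant([1, 0, 3])

# Pair each row number with that row's label: [[0, 1], [1, 0], [2, 3]]
batch_idx = tf.range(tf.shape(scores)[0])
pairs = tf.stack([batch_idx, labels], axis=1)

# One full (row, column) tuple per example -> one scalar per example
picked = tf.gather_nd(scores, pairs)
print(picked.numpy())  # [0.5 0.7 0.5]
```

Building the pairs with tf.range keeps the code independent of the batch size, so the same lines work for any number of rows.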