Python – tensorflow.eye()

Last modified: 2022-05-13 01:54:48.909000   Author: Mango


TensorFlow is an open-source Python library designed by Google for developing machine learning models and deep learning neural networks.

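tensorflow.eye() is used to generate an identity matrix (ones on the main diagonal, zeros everywhere else), or a batch of such matrices.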
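In TensorFlow 2.x the full signature is tf.eye(num_rows, num_columns=None, batch_shape=None, dtype=tf.dtypes.float32, name=None). num_columns defaults to num_rows, batch_shape adds leading batch dimensions, and dtype sets the element type. The sketch below assumes TensorFlow 2.x with eager execution and uses the keyword arguments to build a non-square int32 matrix. The 2x3 size is an arbitrary choice for illustration.

```python
import tensorflow as tf

# num_columns makes the matrix non-square; dtype overrides the float32 default.
# Only the main diagonal (positions [0, 0] and [1, 1]) is set to 1.
res = tf.eye(2, num_columns=3, dtype=tf.int32)

print(res.shape)           # (2, 3)
print(res.numpy().tolist())  # [[1, 0, 0], [0, 1, 0]]
```

When num_rows and num_columns differ, the result is not a true identity matrix, but it still carries ones along the main diagonal.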

Example 1:

Python3
# Importing the library
import tensorflow as tf
  
# Initializing the input
num_rows = 5
  
# Printing the input
print('num_rows:', num_rows)
  
# Build a 5x5 identity matrix (float32 by default)
res = tf.eye(num_rows)
  
# Printing the result
print('res: ', res)


Output:

num_rows: 5
res:  tf.Tensor(
[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]], shape=(5, 5), dtype=float32)
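What defines an identity matrix is that multiplying by it leaves another matrix unchanged. The following minimal check assumes TensorFlow 2.x in eager mode. The 3x2 matrix x is a hypothetical input chosen only for illustration. A 3x2 matrix needs a 3x3 identity on the left and a 2x2 identity on the right.

```python
import tensorflow as tf

# Hypothetical 3x2 input, float32 to match tf.eye's default dtype.
x = tf.constant([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])

# I(3) @ x and x @ I(2) should both reproduce x exactly.
left = tf.matmul(tf.eye(3), x)
right = tf.matmul(x, tf.eye(2))

print(bool(tf.reduce_all(tf.equal(left, x))))   # True
print(bool(tf.reduce_all(tf.equal(right, x))))  # True
```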

Example 2:

Python3

# Importing the library
import tensorflow as tf
  
# Initializing the input
num_rows = 5
num_columns = 6
batch_shape = [3]
  
# Printing the input
print('num_rows:', num_rows)
print('num_columns:', num_columns)
print('batch_shape:', batch_shape)
  
# Build a batch of 3 matrices, each 5x6 with ones on the main diagonal
res = tf.eye(num_rows, num_columns, batch_shape)
  
# Printing the result
print('res: ', res)

Output:

num_rows: 5
num_columns: 6
batch_shape: [3]
res:  tf.Tensor(
[[[1. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 1. 0.]]

 [[1. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 1. 0.]]

 [[1. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 1. 0.]]], shape=(3, 5, 6), dtype=float32)
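batch_shape prepends batch dimensions, so the result has shape batch_shape + [num_rows, num_columns]. Every slice along the batch axes holds the same matrix. The sketch below assumes TensorFlow 2.x in eager mode and compares the result with tf.broadcast_to. That comparison is just one way to reproduce the output by hand and does not necessarily reflect how tf.eye is implemented internally.

```python
import tensorflow as tf

# Same call as Example 2: three stacked 5x6 diagonal matrices.
batched = tf.eye(5, 6, batch_shape=[3])

# Reproduce it by broadcasting one 5x6 matrix along a new leading axis of size 3.
manual = tf.broadcast_to(tf.eye(5, 6), [3, 5, 6])

print(batched.shape)                                   # (3, 5, 6)
print(bool(tf.reduce_all(tf.equal(batched, manual))))  # True

# batch_shape may have several dimensions: [2, 3] + [2, 2] -> (2, 3, 2, 2)
print(tf.eye(2, batch_shape=[2, 3]).shape)             # (2, 3, 2, 2)
```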