📜  Tensorflow.js tf.eye()函数(1)

📅  最后修改于: 2023-12-03 14:47:54.912000             🧑  作者: Mango

Tensorflow.js tf.eye()函数介绍

简介

tf.eye()是Tensorflow.js中的一个函数,用于创建一个指定大小的单位矩阵。

语法
tf.eye(
    numRows: number,
    numColumns: number,
    batchShape?: number[],
    dtype?: 'float32'|'int32'|'bool'|'complex64'
): tf.Tensor

参数:

  • numRows(必选):生成矩阵的行数。
  • numColumns(必选):生成矩阵的列数。
  • batchShape(可选):生成矩阵的批次形状(batch shape),默认为null
  • dtype(可选):生成矩阵的数据类型,可以是'float32''int32''bool''complex64',默认为'float32'

返回值:

  • 一个张量,其形状为[batchShape, numRows, numColumns]
示例
const eye = tf.eye(3);
// 输出:
// [
//   [1, 0, 0],
//   [0, 1, 0],
//   [0, 0, 1]
// ]
eye.print();

const eye2 = tf.eye(2, 3);
// 输出:
// [
//   [1, 0, 0],
//   [0, 1, 0],
// ]
eye2.print();

const eye3 = tf.eye(2, 2, [2]);
// 输出:
// [
//   [
//     [1, 0],
//     [0, 1]
//   ],
//   [
//     [1, 0],
//     [0, 1]
//   ]
// ]
eye3.print();
示例解释
  • 示例一:生成一个3x3单位矩阵,输出其内容并打印。
  • 示例二:生成一个2x3单位矩阵,输出其内容并打印。由于指定的列数为3,因此生成的矩阵中第3列由于没有被单位矩阵中的元素填充,因此被默认填充为0。
  • 示例三:生成一个2x2的单位矩阵,并复制一份形成一个批次形状为[2]的张量。输出其内容并打印。