📜  Tensorflow.js tf.layers.lstmCell()函数

📅  最后修改于: 2022-05-13 01:56:50.769000             🧑  作者: Mango

Tensorflow.js tf.layers.lstmCell()函数

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

layers.lstmCell()函数用于 LSTM 的 Cell 类。它与 RNN 子类是分开的。

句法:

tf.layers.lstmCell(args)

参数:该函数包含一个args对象,该对象包含以下参数:

  • 循环激活:用于激活循环步骤。
  • unitForgetBias:它是一个布尔值,用于在初始化时忘记门。
  • implementation:它是一个整数或者指定实现模式。模式 1 用于将其操作构造为大量较小的点积和加法。模式 2 用于将它们分批处理为更少、更大的操作。
  • units:无论是整数还是输出空间的维数,都是一个数字。
  • 激活:用于要使用的函数。
  • useBias:该层是否使用偏置向量是一个布尔值。
  • kernelInitializer:用于输入的线性变换。
  • 循环初始化器:用于循环状态的线性变换。
  • biasInitializer:用于偏置向量。
  • kernelRegularizer:它是一个字符串,用于应用于核权重矩阵的正则化函数。
  • 经常性正则化器:它是一个字符串,用于应用于经常性内核权重矩阵的正则化函数。
  • biasRegularizer:它是一个字符串,用于应用于偏置向量的正则化函数。
  • kernelConstraint:它是一个字符串,用于应用于核权重矩阵的约束函数。
  • 循环约束:它是一个字符串,用于应用于循环内核权重矩阵的约束函数。
  • iasConstraint:它是一个字符串,用于应用于偏置向量的约束函数。
  • dropout:它是一个介于 0 和 1 之间的数字。对于输入的线性变换,要丢弃的单位的分数。
  • reverseDropout:它是一个介于 0 和 1 之间的数字。对于循环状态的线性变换,要丢弃的单位的分数。
  • inputShape:这是一个数字,用于创建要在该层之前插入的输入层。
  • batchInputShape:这是一个数字,用于创建要在该层之前插入的输入层。
  • batchSize:它是一个用于构造batchInputShape 的数字。
  • dtype:此参数仅适用于输入层。
  • name:这是一个用于图层的字符串。
  • trainable:它是一个布尔值,用于该层的权重是否可通过拟合更新。
  • weights:层的初始权重值。
  • inputDType:用于 Legacy 支持。它不适用于新代码。

返回值:返回 LSTMCell。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the layers.lstmCell 
// function and printing the output
const cell = tf.layers.lstmCell({units: 3});
const input = tf.input({shape: [120]});
const output = cell.apply(input);
  
console.log(JSON.stringify(output.shape));


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
const cells = [
   tf.layers.lstmCell({units: 6}),
   tf.layers.lstmCell({units: 10}),
];
const rnn = tf.layers.rnn(
  {cell: cells, returnSequences: true}
);
  
// Create an input with 10 time steps 
// and a length-20 vector at each step
const input = tf.input({shape: [40, 60]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));


输出:

[null, 120]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
const cells = [
   tf.layers.lstmCell({units: 6}),
   tf.layers.lstmCell({units: 10}),
];
const rnn = tf.layers.rnn(
  {cell: cells, returnSequences: true}
);
  
// Create an input with 10 time steps 
// and a length-20 vector at each step
const input = tf.input({shape: [40, 60]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));

输出:

[null, 30, 8]

参考: https://js.tensorflow.org/api/latest/#layers.lstmCell