📜  Tensorflow.js tf.layers.repeatVector()函数

📅  最后修改于: 2022-05-13 01:56:39.993000             🧑  作者: Mango

Tensorflow.js tf.layers.repeatVector()函数

Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

tf.layers.repeatVector()函数用于在新的指定维度上重复输入 n 次。它是 TensorFlow 的.js 库的内置函数。

句法:

tf.layers.repeatVector(n)

参数:

  • n:整数,指定输入将重复的次数。

返回值:返回 tf.layers.Layer

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
    
  
// Create a new model
const model = tf.sequential();
    
// Add repeatVector layer to the model
model.add(tf.layers.repeatVector({
    n: 5, inputShape: [2]}
));
    
const x = tf.tensor2d([[10, 15]]);
   
console.log(model.predict(x).shape)


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
    
// Create a new model
const model = tf.sequential();
    
// Add repeatVector layer to the model
model.add(tf.layers.repeatVector(
      {n: 8, inputShape: [2]}
));
   
const x = tf.tensor2d([[0,1]]);
   
model.predict(x).print();


输出:

1, 5, 2

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
    
// Create a new model
const model = tf.sequential();
    
// Add repeatVector layer to the model
model.add(tf.layers.repeatVector(
      {n: 8, inputShape: [2]}
));
   
const x = tf.tensor2d([[0,1]]);
   
model.predict(x).print();

输出:

Tensor
    [[[0, 1],
      [0, 1],
      [0, 1],
      [0, 1],
      [0, 1],
      [0, 1],
      [0, 1],
      [0, 1]]]

参考: https://js.tensorflow.org/api/latest/#layers.repeatVector