📜  Tensorflow.js tf.layers.reshape()函数

📅  最后修改于: 2022-05-13 01:56:36.844000             🧑  作者: Mango

Tensorflow.js tf.layers.reshape()函数

Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

tf.layers.reshape()函数用于将输入重塑为特定形状。

句法:

tf.layers.reshape(args) 

参数:此函数将args对象作为参数,该参数可以具有以下属性:

  • targetShape:它是一个不包括批处理轴的数字。
  • inputShape:这是一个数字,用于创建要在该层之前插入的输入层。
  • batchInputShape:这是一个数字,用于创建要在该层之前插入的输入层。
  • batchSize:它是一个数字,用于构造batchInputShape。
  • dtype:该层的数据类型。
  • name:这是该层的字符串。
  • trainable:这是一个布尔值,其中该层的权重是否可以通过拟合更新。
  • weights:层的初始权重值。
  • inputDtype:用于旧版支持。不要用于新代码。

返回值:它返回重塑。

下面的示例演示了使用 tf.layers.reshape()函数对图层进行重塑。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining the tensor input elements 
const input = tf.input({shape: [2, 6]});
  
// Calling the layers.reshape ( ) function
const reshapeLayer = tf.layers.reshape({targetShape: [3, 9]});
  
// Inspect the inferred output shape of the
// Reshape layer, which equals `[null, 3, 9]`. 
// (The 1st dimension is the undermined batch size.)
console.log(JSON.stringify(
    reshapeLayer.apply(input).shape));


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining the tensor input elements 
const input = tf.input({shape: [4, 8]});
  
// Calling the layers.reshape ( ) function
const reshapeLayer = 
    tf.layers.reshape({targetShape: [4, 8]});
  
// Inspect the inferred output shape of
// the Reshape layer, which equals `[null, 4, 8]`. 
// (The 1st dimension is the undermined batch size.)
console.log(JSON.stringify(
    reshapeLayer.apply(input).shape));


输出:

[null, 3, 9]

示例 2:在此示例中,我们讨论的是图层的重塑。

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining the tensor input elements 
const input = tf.input({shape: [4, 8]});
  
// Calling the layers.reshape ( ) function
const reshapeLayer = 
    tf.layers.reshape({targetShape: [4, 8]});
  
// Inspect the inferred output shape of
// the Reshape layer, which equals `[null, 4, 8]`. 
// (The 1st dimension is the undermined batch size.)
console.log(JSON.stringify(
    reshapeLayer.apply(input).shape));

输出:

[null, 4, 8]

参考: https://js.tensorflow.org/api/latest/#layers.reshape