📜  Tensorflow.js tf.pool()函数

📅  最后修改于: 2022-05-13 01:56:34.429000             🧑  作者: Mango

Tensorflow.js tf.pool()函数

简介: Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.pool()函数用于执行 ND 池化功能。

句法:

tf.pool(input, windowShape, poolingType, pad, dilations?, strides?)

参数:

  • 输入:指定的输入张量,其等级为 4 或等级 3,形状为:[batch, height, width, inChannels]。此外,如果等级为 3,则假定批次大小为 1。它可以是 tf.Tensor3D、tf.Tensor4D、TypedArray 或 Array 类型。
  • windowShape:规定的过滤器大小:[filterHeight, filterWidth]。它可以是 [number, number] 或 number 类型。如果filterSize是一个单数,那么 filterHeight == filterWidth。
  • poolingType:指定的池类型,可以是“max”或“avg”。
  • pad:用于填充的规定类型的算法。它的类型可以是 valid、same、number 或 conv_util.ExplicitPadding。
    1. 在这里,对于相同的步幅和步长 1,输出将具有与输入相同的大小,而与滤波器大小无关。
    2. 因为,“有效”输出应小于输入,以防过滤器大小大于 1*1×1。
  • dilations:规定的扩张率:[dilationHeight, dilationWidth] 输入值在扩张池中的高度和宽度维度上进行采样。默认值为 [1, 1]。此外,如果 dilations 是单个数字,则 dilationHeight == dilationWidth。如果它大于 1,那么步幅的所有值都应该是 1。它是可选的,并且是 [number, number], number 类型。
  • 步幅:池化的规定步幅:[strideHeight, strideWidth]。如果 strides 是一个单数,那么 strideHeight == strideWidth。它是可选的,类型为 [number, number] 或 number。

返回值:返回 tf.Tensor3D 或 tf.Tensor4D。

示例 1:

Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining input tensor
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
  
// Calling pool() method
const result = tf.pool(x, 3, 'avg', 'same', [1, 2], 1);
   
// Printing output
result.print();


Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling pool() method
tf.tensor3d([1.2, 2.1, 3.0, -4], [2, 2, 1]).pool(3,
                    'conv_util.ExplicitPadding', 1, 1).print();


输出:

Tensor
    [[[0.4444444],
      [0.6666667]],

     [[0.4444444],
      [0.6666667]]]

示例 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling pool() method
tf.tensor3d([1.2, 2.1, 3.0, -4], [2, 2, 1]).pool(3,
                    'conv_util.ExplicitPadding', 1, 1).print();

输出:

Tensor
    [[[3],
      [3]],

     [[3],
      [3]]]

参考: https://js.tensorflow.org/api/latest/#pool