📜  Tensorflow.js tf.topk()函数

📅  最后修改于: 2022-05-13 01:56:48.659000             🧑  作者: Mango

Tensorflow.js tf.topk()函数

Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。

tf.topk()函数以及最后一个维度也用于查找 k 个最大条目的值和索引。

句法:

tf.topk (x, k?, sorted?)

参数:

  • x:一维或更高的 tf.Tensor,最后一维至少为 k。
  • k:它是要查找的元素的数量。
  • 排序:它是布尔值。如果为真,则生成的 k 个元素将按值降序排序。

返回值: {值:tf.Tensor,索引:tf.Tensor}。它返回一个包含两个张量的对象,其中包含值和索引。

示例 1:

Javascript
const tf = require("@tensorflow/tfjs")
  
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a);
  
// Printing the values and indices
values.print();
indices.print();


Javascript
const tf = require("@tensorflow/tfjs")
  
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a, 3);
  
// Printing the values and indices
values.print();
indices.print();


输出:

Tensor
    [[20],
     [4 ],
     [10]]
Tensor
    [[1],
     [0],
     [2]]

示例 2:在此示例中,我们将提供参数 k,以获取最大的 k 个条目。

Javascript

const tf = require("@tensorflow/tfjs")
  
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a, 3);
  
// Printing the values and indices
values.print();
indices.print();

输出:

当我们通过 k = 3 时,我们在结果中得到 3 个最大值。

Tensor
    [[20, 3, 1],
     [4 , 3, 1],
     [10, 9, 8]]
Tensor
    [[1, 2, 0],
     [0, 1, 2],
     [2, 1, 0]]

参考: https://js.tensorflow.org/api/latest/#topk