📜  Tensorflow.js 简介(1)

📅  最后修改于: 2023-12-03 14:47:56.187000             🧑  作者: Mango

TensorFlow.js 简介

TensorFlow.js 是由 Google 开发的用于在浏览器中训练和部署机器学习模型的库。它有许多强大的功能,包括对现有 TensorFlow 模型的加载和使用,使用 JavaScript 进行模型的训练和推理,以及对图像和音频数据的处理和分析。

安装和使用

在使用 TensorFlow.js 之前,需要使用 npm 或 yarn 进行安装。可以在 Node.js 应用程序中使用它,也可以将其引入到浏览器中。

以下是安装 TensorFlow.js 的示例代码:

# 使用 npm 安装 TensorFlow.js
npm install @tensorflow/tfjs

# 使用 yarn 安装 TensorFlow.js
yarn add @tensorflow/tfjs

要在浏览器中使用 TensorFlow.js,可以使用以下 CDN 引入代码:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.7.0"></script>
加载 TensorFlow 模型

TensorFlow.js 支持从 Keras 或 TensorFlow SavedModel 加载训练好的模型。以下是加载 TensorFlow 模型的示例代码:

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('path/to/model.json');
JavaScript 中的模型训练

TensorFlow.js 提供了一组 API,可以在 JavaScript 中训练模型。例如,通过以下代码,可以创建一个简单的线性回归模型并训练它:

import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });

const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

await model.fit(xs, ys, {
  epochs: 10,
  callbacks: {
    onEpochEnd: (_, logs) => console.log(`Loss: ${logs.loss}`),
  },
});
JavaScript 中的模型推理

TensorFlow.js 提供了可以在浏览器中执行模型推理的 API。例如,以下代码演示了如何使用 MNIST 手写数字识别模型对图像进行分类:

import * as tf from '@tensorflow/tfjs';
import { loadGraphModel } from '@tensorflow/tfjs-converter';

const model = await loadGraphModel('path/to/mnist/model.json');

const image = tf.browser.fromPixels(document.getElementById('image'));

const logits = model.predict(image);
const prediction = tf.argMax(logits, -1).dataSync()[0];

console.log(`Prediction: ${prediction}`);
结论

TensorFlow.js 是一种非常有用的工具,可以在浏览器中构建和部署机器学习模型。它提供了许多功能,可让您轻松加载现有模型、训练新模型并执行模型推理。