📜  神经网络如何用于 R 编程中的分类(1)

📅  最后修改于: 2023-12-03 15:41:03.069000             🧑  作者: Mango

神经网络如何用于 R 编程中的分类

神经网络是一种模拟人类神经系统的算法模型,它广泛应用于分类问题。R 编程语言有多个包能够实现神经网络的建模和训练,其中最常用的是 neuralnet 包。本文将介绍如何使用 neuralnet 包在 R 编程中进行分类。

数据准备

首先需要准备训练数据和测试数据。训练数据应该是带有标签 (label) 的数据集,即每个样本都对应着它所属的一类。测试数据则是没有标签的数据集,它用于测试训练出来的神经网络分类器的准确性。

在本文中,我们将使用著名的鸢尾花 (iris) 数据集,它包含了 150 个样本,每个样本包含了 4 个特征 (花萼长度、花萼宽度、花瓣长度、花瓣宽度) 和一个标签 (setosa、versicolor、virginica)。首先需要将数据集分成训练集和测试集:

library(neuralnet)

# 读入鸢尾花数据集
data(iris)

# 将数据集分为训练集和测试集
set.seed(123)
train_index <- sample(1:nrow(iris), 100)
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]
构建神经网络分类器

建立一个神经网络分类器的过程分为 3 个步骤:定义网络结构、设置训练参数、训练神经网络。

定义网络结构

首先需要定义神经网络的结构,即有几个隐层 (hidden layer),每个隐层有几个神经元。在实际应用中,如何设置这些参数需要进行良好的调参工作。在本文中,我们将定义 2 个隐层,每个隐层都有 5 个神经元。

# 定义神经网络结构
nn <- neuralnet(
  Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  data = train_data,
  hidden = c(5, 5),
  linear.output = FALSE,
  threshold = 0.01
)

在这个例子中,我们使用了 neuralnet 函数来定义神经网络结构。第一个参数是指定目标变量 (也就是标签) 与自变量之间的关系。我们使用鸢尾花的四个特征作为自变量,Species 作为目标变量。hidden 参数指定了神经网络的结构,这里设置为两个隐层,每个隐层有 5 个神经元。linear.output 参数指定是否将输出层设为线性输出,这里为 FALSE。threshold 参数指定一个阈值,当神经网络的误差小于这个阈值时终止训练过程。

设置训练参数

第二个步骤是设置训练参数。neuralnet 函数提供了一系列参数来控制训练过程,包括最大迭代次数 (maxit)、学习率 (learningrate)、正则化参数 (rep) 等。在这个例子中,我们将训练次数设置为 1000,学习率设置为 0.01:

# 设置训练参数
nn$learningrate <- 0.01
nn$maxit <- 1000
训练神经网络

最后一个步骤是训练神经网络。neuralnet 函数会自动进行反向传播算法来更新神经网络的权重和偏置 (bias),从而最小化误差。训练过程中可以使用 plot 函数来画出神经网络的误差曲线。

# 训练神经网络
nn <- neuralnet(
  Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  data = train_data,
  hidden = c(5, 5),
  linear.output = FALSE,
  threshold = 0.01
)

# 画出误差曲线
plot(nn)
预测和评估

训练完神经网络分类器后,我们可以使用它来对新的数据进行分类。在 neuralnet 包中,可以使用 compute 函数来进行预测。

# 预测测试数据
test_outputs <- compute(nn, test_data[, 1:4])$net.result

# 将神经网络的输出转换为标签
pred_labels <- apply(test_outputs, 1, which.max)

在这里,我们调用 compute 函数来预测测试数据的标签。compute 函数会输出每个样本属于每一类的概率,我们需要将其转换为具体的标签。这里使用 which.max 函数找到最大概率的那个类别。

最后,我们需要评估分类器的准确性。在这里,可以使用混淆矩阵 (confusion matrix) 来评估分类器的性能。

# 计算混淆矩阵
table(pred_labels, test_data$Species)

混淆矩阵可以用来计算各种分类性能指标,如准确率 (accuracy)、精确率 (precision)、召回率 (recall) 等。具体计算方法可以参考相关的文献和资料。

结论

神经网络广泛用于分类问题中,它可以自动地学习特征的表示,具有较好的泛化性能。在 R 编程中,neuralnet 包提供了一种实现神经网络分类器的简单方法。在使用神经网络分类器时,需要进行参数调整和性能评估,才能得到一个准确可靠的分类器。