📜  投票分类器网格搜索 - Python (1)

📅  最后修改于: 2023-12-03 14:54:36.829000             🧑  作者: Mango

投票分类器网格搜索 - Python

在机器学习中,投票分类器是一种经常使用的技术。它将多个分类器的预测组合起来,最终的预测结果是根据多数投票的结果来决定的。对于分类问题,投票分类器通常能够提高预测的准确率。

然而,在使用投票分类器时,如何选择合适的分类器,以及如何调整这些分类器的超参数,是一个需要注意的问题。在这种情况下,网格搜索技术提供了一种有效的解决方案。本文将介绍如何使用Python的GridSearchCV函数,在给定的分类器集合和超参数范围内进行网格搜索调参。

实现方法
导入需要的库
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import VotingClassifier
载入数据集

在这个例子中,我们将会使用著名的鸢尾花数据集。这个数据集包含了3种不同种类的鸢尾花,每种鸢尾花有50个样本。

data = load_iris()
X = data.data
y = data.target
数据集划分

我们将数据集分成训练集和测试集,其中训练集占75%。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
定义分类器并设置超参数范围

我们将使用SVM,随机森林和逻辑回归三种分类器来构建投票分类器。我们还将为每个分类器设置超参数范围。

# SVM
svm_clf = SVC(probability=True)
svm_param_grid = {
    'kernel': ['linear', 'rbf', 'poly', 'sigmoid'],
    'gamma': ['scale', 'auto'],
    'C': [0.1, 1, 10, 100],
}

# 随机森林
rf_clf = RandomForestClassifier()
rf_param_grid = {
    'n_estimators': [50, 100, 150],
    'max_features': ['sqrt', 'log2'],
    'max_depth': [2, 4, 6, 8],
    'criterion': ['gini', 'entropy'],
}

# 逻辑回归
lr_clf = LogisticRegression()
lr_param_grid = {
    'penalty': ['l1', 'l2', 'elasticnet', 'none'],
    'C': [0.1, 1, 10, 100],
}
搜索最佳超参数组合

我们将以上述分类器和超参数范围为基础,使用GridSearchCV函数进行网格搜索。搜索过程中,我们将以准确率为指标,执行5折交叉验证。

estimators = [('svm', svm_clf), ('rf', rf_clf), ('lr', lr_clf)]
params = [svm_param_grid, rf_param_grid, lr_param_grid]

voting_clf = VotingClassifier(estimators=estimators, voting='soft')
grid_search = GridSearchCV(voting_clf, param_grid=params, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)
输出最佳超参数

搜索完成后,我们可以通过best_params_属性查看最佳的超参数组合。

print("Best hyperparameters: ", grid_search.best_params_)
输出准确率

我们还可以使用训练好的投票分类器对测试集进行预测,并输出准确率。

y_pred = grid_search.predict(X_test)

print("Accuracy: ", (y_test == y_pred).mean())
结论

本文介绍了如何使用Python的GridSearchCV函数,在给定的分类器集合和超参数范围内进行网格搜索调参。通过搜索最佳的超参数组合,我们可以构建一个更准确的投票分类器。