📜  ML | Voting Classifier using Sklearn

📅  Last modified: 2022-05-13 01:55:09.093000             🧑  Author: Mango



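A voting classifier combines several models and predicts a class from their pooled outputs. Sklearn's VotingClassifier supports two kinds of voting:

  1. Hard voting: the predicted output class is the one that receives the majority of votes, i.e. the class predicted by the largest number of classifiers. Suppose three classifiers predict the output classes (A, A, B). The majority predicts A, so A is the final prediction.
  2. Soft voting: the predicted output class is the one with the highest average predicted probability. Suppose that, for some input, three models give the predicted probabilities A = (0.30, 0.47, 0.53) and B = (0.20, 0.32, 0.40). The average for class A is 0.4333 and for class B it is 0.3067, so the winner is class A because it has the highest probability averaged over the classifiers.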
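The two rules can be checked by hand on the numbers from the list above. The following is a minimal sketch in plain Python, not Sklearn's implementation; the helper names `hard_vote` and `soft_vote` are made up for illustration.

```python
from collections import Counter


def hard_vote(predictions):
    """Return the label predicted by the largest number of classifiers."""
    return Counter(predictions).most_common(1)[0][0]


def soft_vote(probabilities):
    """Return the label with the highest mean probability, plus all means."""
    averages = {label: sum(p) / len(p) for label, p in probabilities.items()}
    winner = max(averages, key=averages.get)
    return winner, averages


# Three classifiers vote (A, A, B): the majority label wins.
hard_winner = hard_vote(["A", "A", "B"])

# Per-class probabilities reported by the three models.
soft_winner, averages = soft_vote({
    "A": [0.30, 0.47, 0.53],
    "B": [0.20, 0.32, 0.40],
})

print(hard_winner)  # A
print(soft_winner, {k: round(v, 4) for k, v in averages.items()})
# A {'A': 0.4333, 'B': 0.3067}
```

Both rules pick A here, but they can disagree when a minority of classifiers is very confident and the majority is not.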


# importing libraries
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# loading iris dataset (all four features)
iris = load_iris()
X = iris.data[:, :4]
Y = iris.target

# train_test_split: hold out 20% of the samples for testing
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    Y,
                                                    test_size=0.20,
                                                    random_state=42)

# group / ensemble of models as (name, estimator) pairs
estimator = []
# lbfgs already fits a multinomial model for multiclass targets,
# so the multi_class argument is not needed
estimator.append(('LR',
                  LogisticRegression(solver='lbfgs',
                                     max_iter=200)))
# probability=True enables predict_proba, which soft voting requires
estimator.append(('SVC', SVC(gamma='auto', probability=True)))
estimator.append(('DTC', DecisionTreeClassifier()))

# Voting Classifier with hard voting
vot_hard = VotingClassifier(estimators=estimator, voting='hard')
vot_hard.fit(X_train, y_train)
y_pred = vot_hard.predict(X_test)

# using accuracy_score to measure accuracy on the test set
# (%.2f keeps the fraction; %d would truncate it to 0 or 1)
score = accuracy_score(y_test, y_pred)
print("Hard Voting Score %.2f" % score)

# Voting Classifier with soft voting
vot_soft = VotingClassifier(estimators=estimator, voting='soft')
vot_soft.fit(X_train, y_train)
y_pred = vot_soft.predict(X_test)

# using accuracy_score
score = accuracy_score(y_test, y_pred)
print("Soft Voting Score %.2f" % score)

Output:

Hard Voting Score 1.00
Soft Voting Score 1.00


Example:

Input  : 4.7, 3.2, 1.3, 0.2 (sepal length, sepal width, petal length, petal width, in cm)
Output : Iris Setosa
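A prediction for a single flower like the one above could be obtained roughly as follows. This sketch assumes the same three estimators as the listing, but for brevity refits them on the full iris data rather than on a train split. The variable names `voter`, `sample`, `predicted_index` and `predicted_name` are chosen here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

# Hard voting only needs class labels, so SVC does not need probability=True.
voter = VotingClassifier(
    estimators=[
        ("LR", LogisticRegression(max_iter=200)),
        ("SVC", SVC(gamma="auto")),
        ("DTC", DecisionTreeClassifier()),
    ],
    voting="hard",
)
# Assumption for this sketch: fit on all 150 samples instead of a train split.
voter.fit(iris.data, iris.target)

# One sample: sepal length, sepal width, petal length, petal width (cm).
sample = [[4.7, 3.2, 1.3, 0.2]]
predicted_index = voter.predict(sample)[0]

# Map the integer class index back to the species name.
predicted_name = iris.target_names[predicted_index]
print(predicted_name)  # setosa
```

predict expects a 2D array, which is why the single sample is wrapped in an outer list.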

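In practice, soft voting often gives higher accuracy because it averages the predicted probabilities of all estimators instead of counting only each estimator's top choice. On the simple iris dataset the individual models already fit the data almost perfectly, so the two voting modes produce essentially the same output here.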
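To see the averaging at work, the probabilities of the fitted sub-estimators can be averaged by hand and compared with what the soft-voting ensemble reports. This is a sketch that assumes equal weights (no `weights` argument) and the same estimators and split as the listing. The names `manual_proba` and `manual_pred` are introduced here for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)

vot_soft = VotingClassifier(
    estimators=[
        ("LR", LogisticRegression(max_iter=200)),
        ("SVC", SVC(gamma="auto", probability=True)),
        ("DTC", DecisionTreeClassifier()),
    ],
    voting="soft",
)
vot_soft.fit(X_train, y_train)

# Average the class probabilities of the fitted sub-estimators by hand.
manual_proba = np.mean(
    [clf.predict_proba(X_test) for clf in vot_soft.estimators_], axis=0
)
# The class with the highest mean probability is the soft-voting prediction.
manual_pred = vot_soft.classes_[np.argmax(manual_proba, axis=1)]

print(np.allclose(manual_proba, vot_soft.predict_proba(X_test)))  # True
print(np.array_equal(manual_pred, vot_soft.predict(X_test)))      # True
```

A decision tree grown to full depth reports probabilities of 0 or 1, so it can dominate the average when the other models are unsure. That is one reason the two voting modes can differ on harder datasets.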