📜  Implementation of K-Nearest Neighbors

📅  Last modified: 2021-05-06 23:29:15             🧑  Author: Mango

Prerequisite: K-Nearest Neighbors

Introduction

Suppose we are given a data set of items, each with numeric features (such as Height, Weight, Age, etc.). If the number of features is n, we can represent the items as points in an n-dimensional grid. Given a new item, we can compute its distance to every other item in the set. We pick the k nearest neighbors and see where most of those neighbors are classified. That is where we classify the new item.

So the question becomes how we compute the distances between items. The answer depends on the data set. If the values are real numbers, we usually use the Euclidean distance. If the values are categorical or binary, we usually use the Hamming distance.
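For instance, a minimal Hamming distance helper for categorical feature vectors (a sketch for illustration; this tutorial's data uses numeric features and the Euclidean distance) could look like:

```python
def HammingDistance(x, y):
    # Count the positions at which the two
    # feature vectors disagree
    return sum(1 for a, b in zip(x, y) if a != b)

# Two hypothetical items described by categorical features
print(HammingDistance(['red', 'S', 'cotton'], ['red', 'M', 'wool']))  # 2
```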

Algorithm:

Given a new item:
    1. Find distances between new item and all other items
    2. Pick the k shortest distances
    3. Pick the most common class among these k neighbors
    4. That class is where we will classify the new item
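The four steps above can be sketched compactly in Python (a standalone illustration using plain tuples; the tutorial below builds the same logic step by step with dictionaries):

```python
from collections import Counter

def knn_sketch(new_point, data, k):
    # data is a list of (point, label) pairs;
    # points are tuples of numbers.
    # 1. Distances from the new point to every item
    dists = [(sum((a - b) ** 2 for a, b in zip(new_point, p)) ** 0.5, label)
             for p, label in data]
    # 2. Keep the k shortest distances
    nearest = sorted(dists)[:k]
    # 3-4. The most common class among them wins
    return Counter(label for _, label in nearest).most_common(1)[0][0]

data = [((1.0, 1.0), 'A'), ((1.1, 0.9), 'A'), ((5.0, 5.0), 'B')]
print(knn_sketch((1.05, 1.0), data, 3))  # A
```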

Reading the Data

Let our input file have the following format:

Height, Weight, Age, Class
1.70, 65, 20, Programmer
1.90, 85, 33, Builder
1.78, 76, 31, Builder
1.73, 74, 24, Programmer
1.81, 75, 35, Builder
1.73, 70, 75, Scientist
1.80, 71, 63, Scientist
1.75, 69, 25, Programmer

Each item is one line, and under "Class" we see where that item is classified. The values under the feature names ("Height", etc.) are the values the item has for that feature. All values and feature names are separated by commas.

Place these data files in your working directory (data2 and data). Pick one and paste its contents as-is into a text file named data.txt.
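If you just want to follow along, you can also create data.txt programmatically (a convenience sketch; the rows are the sample table shown above):

```python
rows = [
    "Height, Weight, Age, Class",
    "1.70, 65, 20, Programmer",
    "1.90, 85, 33, Builder",
    "1.78, 76, 31, Builder",
    "1.73, 74, 24, Programmer",
    "1.81, 75, 35, Builder",
    "1.73, 70, 75, Scientist",
    "1.80, 71, 63, Scientist",
    "1.75, 69, 25, Programmer",
]

# Write the sample table to data.txt in the working directory
with open('data.txt', 'w') as f:
    f.write('\n'.join(rows))
```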

We will read the contents of the file (named "data.txt") and split the input into lines:

f = open('data.txt', 'r')
lines = f.read().splitlines()
f.close()

The first line of the file contains the feature names, with the keyword "Class" at the end. We want to store the feature names in a list:

# Split the first line by commas,
# remove the first element and 
# save the rest into a list. The
# list now holds the feature 
# names of the data set.
features = lines[0].split(', ')[:-1]

Then we move on to the data set itself. We save the items into a list named items, whose elements are dictionaries (one per item). The keys of these item dictionaries are the feature names, plus "Class" to hold the item's class. Finally, we shuffle the items in the list (as a safety measure, in case the items happen to be in a weird order):

from random import shuffle

items = []

for i in range(1, len(lines)):

    line = lines[i].split(', ')

    itemFeatures = {"Class": line[-1]}

    # Iterate through the features
    for j in range(len(features)):

        # Get the feature at index j
        f = features[j]

        # Convert the feature value to float;
        # the class (last element of the
        # line) is skipped
        v = float(line[j])

        # Add feature to dict
        itemFeatures[f] = v

    # Append temp dict to items
    items.append(itemFeatures)

shuffle(items)
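To make the resulting structure concrete, here is the same parsing logic applied to a single row (a standalone sketch, no file needed):

```python
features = ['Height', 'Weight', 'Age']
line = '1.70, 65, 20, Programmer'.split(', ')

# The class is the last element; the rest are features
itemFeatures = {'Class': line[-1]}
for j, f in enumerate(features):
    itemFeatures[f] = float(line[j])

print(itemFeatures)
# {'Class': 'Programmer', 'Height': 1.7, 'Weight': 65.0, 'Age': 20.0}
```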

Classifying the Data

With the data stored in items, we now start building the classifier. For the classifier we will create a new function, Classify. It takes as input the item we want to classify, the list of items, and k, the number of nearest neighbors.

If k is greater than the length of the data set, we do not proceed with the classification, since we cannot have more nearest neighbors than the total number of items in the data set. (Alternatively, we could set k to the length of items instead of returning an error message.)

if k > len(Items):

    # k is larger than list
    # length, abort
    return "k larger than list length"

We want to compute the distance between the item to be classified and every item in the training set, in the end keeping the k shortest distances. To hold the current nearest neighbors we use a list, called neighbors. Each element of neighbors holds two values: the distance from the item being classified, and the neighbor's class. We will compute distances via the generalized Euclidean formula (for n dimensions). Then we will pick the class that appears most often among the neighbors, and that will be our choice. In code:

def Classify(nItem, k, Items):
    if k > len(Items):

        # k is larger than list
        # length, abort
        return "k larger than list length"

    # Hold nearest neighbors.
    # First item is distance,
    # second class
    neighbors = []

    for item in Items:

        # Find Euclidean Distance
        distance = EuclideanDistance(nItem, item)

        # Update neighbors, either adding
        # the current item in neighbors
        # or not.
        neighbors = UpdateNeighbors(neighbors, item, distance, k)

    # Count the number of each
    # class in neighbors
    count = CalculateNeighborsClass(neighbors, k)

    # Find the max in count, aka the
    # class with the most appearances.
    return FindMax(count)

The external functions we need to implement are EuclideanDistance, UpdateNeighbors, CalculateNeighborsClass and FindMax.

Finding the Euclidean Distance

The generalized Euclidean formula for two vectors x and y is:

distance = sqrt((x_1 - y_1)^2 + (x_2 - y_2)^2 + ... + (x_n - y_n)^2)

In code:

import math

def EuclideanDistance(x, y):

    # The sum of the squared
    # differences of the elements
    S = 0

    for key in x.keys():
        S += math.pow(x[key] - y[key], 2)

    # The square root of the sum
    return math.sqrt(S)
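A quick sanity check of the function (redefined here so the snippet runs on its own), using two hypothetical items:

```python
import math

def EuclideanDistance(x, y):
    # Sum of squared differences over shared keys
    S = 0
    for key in x.keys():
        S += math.pow(x[key] - y[key], 2)
    return math.sqrt(S)

a = {'Height': 1.70, 'Weight': 65, 'Age': 20}
b = {'Height': 1.70, 'Weight': 68, 'Age': 24}
print(EuclideanDistance(a, b))  # sqrt(0 + 9 + 16) = 5.0
```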

Updating the Neighbors

We have our list of neighbors (which should have at most length k), and we want to add an item to the list with a given distance. First we check whether neighbors already has length k. If it has fewer elements, we add the item regardless of its distance (since we need to fill the list up to k before we start rejecting items). Otherwise, we check whether the item's distance is shorter than that of the item in the list with the maximum distance. If so, we replace that maximum-distance item with the new one.

To find the item with the maximum distance quickly, we keep the list sorted in ascending order. The last item in the list then has the maximum distance. We replace it with the new item and sort again.

To speed this process up, we could implement an insertion sort, where we insert the new item into its place in the list without re-sorting the whole list. The code for that, while not long, would bog down the tutorial.

def UpdateNeighbors(neighbors, item, distance, k):

    if len(neighbors) < k:

        # List is not full, add
        # new item and sort
        neighbors.append([distance, item["Class"]])
        neighbors = sorted(neighbors)
    else:

        # List is full, check if new
        # item should be entered
        if neighbors[-1][0] > distance:

            # If yes, replace the last
            # element with the new item
            neighbors[-1] = [distance, item["Class"]]
            neighbors = sorted(neighbors)

    return neighbors
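As an aside, the sorted insertion mentioned above can be had from the standard library's bisect module instead of a hand-written insertion sort; a sketch of that idea (an alternative, not the tutorial's own code):

```python
from bisect import insort

def UpdateNeighborsSorted(neighbors, item, distance, k):
    # Insert at the correct sorted position,
    # then trim the list back down to k entries
    insort(neighbors, [distance, item['Class']])
    return neighbors[:k]

neighbors = []
for dist, cls in [(3.0, 'A'), (1.0, 'B'), (2.0, 'A'), (0.5, 'C')]:
    neighbors = UpdateNeighborsSorted(neighbors, {'Class': cls}, dist, 3)

print(neighbors)  # [[0.5, 'C'], [1.0, 'B'], [2.0, 'A']]
```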

CalculateNeighborsClass

Here we compute the class that appears most often among the neighbors. For that we use another dictionary, called count, whose keys are the class names appearing in neighbors. If a key does not exist yet, we add it; otherwise we increment its value.

def CalculateNeighborsClass(neighbors, k):
    count = {}

    for i in range(k):

        if neighbors[i][1] not in count:

            # The class at the ith index
            # is not in the count dict.
            # Initialize it to 1.
            count[neighbors[i][1]] = 1
        else:

            # Found another item of class
            # c[i]. Increment its counter.
            count[neighbors[i][1]] += 1

    return count

FindMax

We feed the count dictionary built in CalculateNeighborsClass into this function, and it returns the class with the most appearances (along with its count).

def FindMax(countList):

    # Hold the max
    maximum = -1

    # Hold the classification
    classification = ""

    for key in countList.keys():

        if countList[key] > maximum:
            maximum = countList[key]
            classification = key

    return classification, maximum
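As a side note, CalculateNeighborsClass and FindMax together do what collections.Counter can do in one call; a possible shorthand (an alternative sketch, not the tutorial's code):

```python
from collections import Counter

def MajorityClass(neighbors):
    # neighbors is a list of [distance, class] pairs;
    # most_common(1) returns [(class, count)]
    count = Counter(cls for _, cls in neighbors)
    return count.most_common(1)[0]

neighbors = [[0.5, 'Builder'], [1.2, 'Programmer'], [1.3, 'Builder']]
print(MajorityClass(neighbors))  # ('Builder', 2)
```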

Conclusion

And with that, this kNN tutorial is complete.

You can now classify new items, setting k as you see fit. Usually an odd number is used for k, but that is not necessary. To classify a new item, you need to create a dictionary with keys for the feature names and the values that characterize the item. An example of classification:

newItem = {'Height': 1.74, 'Weight': 67, 'Age': 22}
print(Classify(newItem, 3, items))

The complete code of the above approach is given below:

# Python Program to illustrate
# KNN algorithm
  
# For pow and sqrt
import math 
from random import shuffle
  
###_Reading_###
def ReadData(fileName):
  
    # Read the file, splitting by lines
    f = open(fileName, 'r')
    lines = f.read().splitlines()
    f.close()
  
    # Split the first line by commas, 
    # remove the first element and save
    # the rest into a list. The list 
    # holds the feature names of the 
    # data set.
    features = lines[0].split(', ')[:-1]
  
    items = []
  
    for i in range(1, len(lines)):
          
        line = lines[i].split(', ')
  
        itemFeatures = {'Class': line[-1]}
  
        for j in range(len(features)):
              
            # Get the feature at index j
            f = features[j]  
  
            # Convert feature value to float
            v = float(line[j]) 
              
             # Add feature value to dict
            itemFeatures[f] = v
          
        items.append(itemFeatures)
  
    shuffle(items)
  
    return items
  
  
###_Auxiliary Function_###
def EuclideanDistance(x, y):
      
    # The sum of the squared differences
    # of the elements
    S = 0  
      
    for key in x.keys():
        S += math.pow(x[key] - y[key], 2)
  
    # The square root of the sum
    return math.sqrt(S)
  
def CalculateNeighborsClass(neighbors, k):
    count = {}
  
    for i in range(k):
        if neighbors[i][1] not in count:
  
            # The class at the ith index is
            # not in the count dict. 
            # Initialize it to 1.
            count[neighbors[i][1]] = 1
        else:
  
            # Found another item of class 
            # c[i]. Increment its counter.
            count[neighbors[i][1]] += 1
  
    return count
  
def FindMax(Dict):
  
    # Find max in dictionary, return 
    # max value and max index
    maximum = -1
    classification = ''
  
    for key in Dict.keys():
          
        if Dict[key] > maximum:
            maximum = Dict[key]
            classification = key
  
    return (classification, maximum)
  
  
###_Core Functions_###
def Classify(nItem, k, Items):
  
    # Hold nearest neighbours. First item
    # is distance, second class
    neighbors = []
  
    for item in Items:
  
        # Find Euclidean Distance
        distance = EuclideanDistance(nItem, item)
  
        # Update neighbors, either adding the
        # current item in neighbors or not.
        neighbors = UpdateNeighbors(neighbors, item, distance, k)
  
    # Count the number of each class 
    # in neighbors
    count = CalculateNeighborsClass(neighbors, k)
  
    # Find the max in count, aka the
    # class with the most appearances
    return FindMax(count)
  
  
def UpdateNeighbors(neighbors, item, distance, k):
    if len(neighbors) < k:
  
        # List is not full, add 
        # new item and sort
        neighbors.append([distance, item['Class']])
        neighbors = sorted(neighbors)
    else:
  
        # List is full Check if new 
        # item should be entered
        if neighbors[-1][0] > distance:
  
            # If yes, replace the 
            # last element with new item
            neighbors[-1] = [distance, item['Class']]
            neighbors = sorted(neighbors)
  
    return neighbors
  
###_Evaluation Functions_###
def K_FoldValidation(K, k, Items):
      
    if K > len(Items):
        return -1
  
    # The number of correct classifications
    correct = 0

    # The length of a fold
    l = len(Items) // K

    # The total number of classifications
    # (each fold of length l is tested once)
    total = K * l

    for i in range(K):

        # Fold i is the test set; the rest
        # of the data is the training set
        testSet = Items[i * l:(i + 1) * l]
        trainingSet = Items[:i * l] + Items[(i + 1) * l:]
  
        for item in testSet:
            itemClass = item['Class']
  
            itemFeatures = {}
  
            # Get feature values
            for key in item:
                if key != 'Class':
  
                    # If key isn't "Class", add 
                    # it to itemFeatures
                    itemFeatures[key] = item[key]
  
            # Categorize item based on
            # its feature values
            guess = Classify(itemFeatures, k, trainingSet)[0]
  
            if guess == itemClass:
  
                # Guessed correctly
                correct += 1
  
    accuracy = correct / float(total)
    return accuracy
  
  
def Evaluate(K, k, items, iterations):
  
    # Run algorithm the number of
    # iterations, pick average
    accuracy = 0
      
    for i in range(iterations):
        shuffle(items)
        accuracy += K_FoldValidation(K, k, items)
  
    print(accuracy / float(iterations))
  
  
###_Main_###
def main():
    items = ReadData('data.txt')
  
    Evaluate(5, 5, items, 100)
  
if __name__ == '__main__':
    main()

Output:

0.9375

The output varies from machine to machine. The code includes a fold-validation function; it is not part of the algorithm itself, but can be used to compute the algorithm's accuracy.