0x2a

Don't Panic.

Steve's avatar Steve

kNN算法的简单理解与实现

在模式识别领域中,最近邻居法(KNN算法,又译K-近邻算法)是一种用于分类和回归的非参数统计方法。

K近邻算法是机器学习算法中最简单的算法之一,他的原理非常的简单直观。

算法思路

KNN是通过测量不同特征值之间的距离进行分类。它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

你可能会觉得 这都是些啥?? 别灰心 我们看个简单的例子:

如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类.

knn.jpg

显而易见的,K的取值将会影响结果的准确率

而在KNN中,一般使用距离来作为各个对象之间的相似指数 常用的有

  1. 欧氏距离
    dis1
  2. 曼哈顿距离

    坐标(x1, y1)的点P1与坐标(x2, y2)的点P2的曼哈顿距离表示为:

    dis2

KNN实例

对机器学习有所了解的同学可能有听说过MNIST数据集.

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

通常大家会使用cNN来对这个手写数据集来完成机器学习这个”hello world”级别的项目 不过我这里想说的是, 尽管最简单的kNN, 一样可以对这个问题进行求解.

示例代码

from collections import defaultdict
import struct #用作读取数据集
import numpy as np#用作矩阵计算以及图像向量化

def read_image(file_path):  
'''读取图像函数 不是主要的部分'''
    f_open=open(file_path,"rb")  
    content=f_open.read()  
    index=0  
    magic, num_images,num_rows,num_columns=struct.unpack_from(">IIII",content,index) # 以大端法读入四个unsigned int  
    print("number of images:"+str(num_images))  
    print("number of rows:"+str(num_rows))  
    print("number of columns:"+str(num_columns))  
    index+=struct.calcsize(">IIII")  
    img_piexl=[]  
    for i in range(num_images):  
        piexl_all=[]  
        for j in range(num_columns):  
            for k in range(num_rows):  
                piexl=struct.unpack_from(">B",content,index)  

                piexl=int(piexl[0])  
                if piexl<127:  
                    piexl=0  
                else:  
                    piexl=1  
                ###  
                piexl_all.append(piexl)  
                index+=struct.calcsize(">B")  
        piexl_all=np.array(piexl_all)  
        img_piexl.append(piexl_all)  
        """ 
        print(piexl_all) 
        piexl_all=piexl_all.reshape(28,28) 
        fig=plt.figure() 
        plotwindow=fig.add_subplot(111) 
        plotwindow.imshow(piexl_all,cmap="gray") 
        plt.show() #这里注释去掉可以看到图片
        """  
        if i%1000==0:  
            print(str(i)+"images have been processed")  
    f_open.close()  
    return img_piexl  

def read_label(file_path):  
    f_open=open(file_path,"rb")  
    content=f_open.read()  
    index=0  
    magic, num_items=struct.unpack_from(">II",content,index)  
    print("number of labels:"+str(num_items))  
    index+=struct.calcsize(">II")  
    label_num=[]  
    for i in range(num_items):  
        label=struct.unpack_from(">B",content,index)  
        label_num.append(int(label[0]))  
        index+=struct.calcsize(">B")  
        if i%1000==0:  
            print(str(i)+"labels have been processed!")  
    f_open.close()  
    return label_num

读取一下训练照片样本

image_train_file_path="mnist/train-images.idx3-ubyte"  
image_train_piexl = read_image(image_train_file_path)

读取测试样本

image_test_file_path="mnist/t10k-images.idx3-ubyte"  
label_test_file_path="mnist/t10k-labels.idx1-ubyte" 
image_test_piexl=read_image(image_test_file_path)  
label_test=read_label(label_test_file_path)

图像向量化并计算距离值

def calc_dis(train_image,test_image):  
    dist=np.linalg.norm(train_image-test_image)  
    return dist
def find_labels(k,train_images,train_labels,test_image):  
    all_dis = []  
    labels=defaultdict(int)  
    for i in range(len(train_images)):  
        dis = np.linalg.norm(train_images[i]-test_image)  
        all_dis.append(dis)  
    sorted_dis = np.argsort(all_dis)  
    count = 0  
    while (count < k):  
        labels[train_labels[sorted_dis[count]]]+=1  
        count += 1  
    return labels  

def knn_all(k,train_images,train_labels,test_images):  
    print("start knn_all!")  
    res=[]  
    count=0  
    for i in range(len(test_images)):  
        labels=find_labels(k,train_images,train_labels,test_images[i])  
        res.append(max(labels))  
        if count%50==0:  
            print("%d has been processed!"%(count))  
        count+=1  
    return res  

def calc_precision(res,test_labels):  
    f_res_open=open("res.txt","a+")  
    precision=0  
    for i in range(len(res)):  
        f_res_open.write("res:"+str(res[i])+"\n")  
        f_res_open.write("test:"+str(test_labels[i])+"\n")  
        if res[i]==test_labels[i]:  
            precision+=1  
    return precision/len(res)

最后就可以直接使用这个算法去算一下我们的准确率了

k=1
res=knn_all(k,image_train_piexl,label_train,image_test_piexl)  
print("precision:"+str(calc_precision(res,label_test)))

输出结果

start knn_all!
0 has been processed!
1000 has been processed!
2000 has been processed!
3000 has been processed!
4000 has been processed!
5000 has been processed!
6000 has been processed!
7000 has been processed!
8000 has been processed!
9000 has been processed!
precision:0.9562

我们发现 就算是最简单的KNN算法,用来手写图像识别 一样有95%这样不俗的准确率.

不过K近邻算法对K取值的要求还是蛮高的, 像这里 反而取1会发现他的准确率最高

KNN的缺点在于计算量较大, 需要计算测试值与每一个样本值的距离,并对他们进行比较. 所以耗费的时间还是蛮大的.

加上图片的信息率会比普通的数据要高, 所以我的辣鸡集显电脑跑了一个多小时才跑完这个简单的例子.

写在最后

虽然机器学习的学习过程中,数学基础是十分重要的, 但不可否认的是, 作为一个工科学习者, 工程思想的建立还是尤为重要的.

像本文的KNN算法, 用语言叙述起来原理是十分简单的, 但是应用到实际的例子中由于涉及到语言的语法、数据的处理等问题,会导致我们在实际的使用当中并没有那么顺利.

因此, 我们在学习机器学习的过程中, 一定要重视代码的实现, 这样不仅可以提高我们对算法的理解程度, 更可以提升我们学习的满足感,,让我们事半功倍.