[Python Machine Learning] Lab 06: The KNN (k-Nearest Neighbors) Algorithm


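The KNN algorithm

1. The $k$-nearest-neighbor method is a basic and simple method for classification and regression. Its basic procedure is: given the training instances and an input instance, first find the $k$ training instances nearest to the input instance, then predict the class of the input instance by the majority class among those $k$ training instances.

2. The $k$-nearest-neighbor model corresponds to a partition of the feature space induced by the training data set. Once the training set, the distance metric, the value of $k$ and the classification decision rule are fixed, the result is uniquely determined. There is no approximation, and the method has no learned parameters.

3. The three elements of the $k$-nearest-neighbor method are the distance metric, the choice of $k$ and the classification decision rule. Commonly used distance metrics are the Euclidean distance and the more general $L_p$ distance. A small $k$ gives a more complex model; a large $k$ gives a simpler one. The choice of $k$ reflects a trade-off between approximation error and estimation error, and the best $k$ is usually selected by cross-validation.

The usual classification decision rule is majority voting, which corresponds to empirical risk minimization.

4. An implementation of the $k$-nearest-neighbor method has to consider how to find the $k$ nearest neighbors quickly. The kd tree is a data structure that supports fast retrieval of data in a $k$-dimensional space (here $k$ is the number of dimensions, not the number of neighbors). A kd tree is a binary tree that represents a partition of the $k$-dimensional space; each node corresponds to a hyperrectangular region of that partition. With a kd tree, most data points never need to be examined, which reduces the cost of the search.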
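To make the last point concrete, the following is a minimal sketch that compares a brute-force scan with a kd-tree query. It assumes scikit-learn's `KDTree` class is available; the random points, the query location and the `leaf_size` value are made up for illustration and are not part of the original lab.

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)      # fixed seed, illustrative data only
points = rng.random((200, 2))       # 200 made-up points in the unit square
query = np.array([[0.5, 0.5]])      # hypothetical query point

# Brute force: compute all 200 distances, then keep the 3 smallest.
all_dist = np.sqrt(((points - query) ** 2).sum(axis=1))
brute_dist = np.sort(all_dist)[:3]

# kd tree: build once, then query without scanning every point.
tree = KDTree(points, leaf_size=10)
tree_dist, tree_idx = tree.query(query, k=3)

print(brute_dist)
print(tree_dist[0])
```

Both approaches should report the same three distances; the difference is only in how many points have to be inspected to find them.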

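Preliminaries: distance metrics

In machine learning algorithms we often need to measure how similar two samples are, and the usual approach is to compute the distance between them.

Let $x$ and $y$ be two vectors; we want the distance between them.

The implementations here use NumPy: $x$ and $y$ are ndarrays <numpy.ndarray>, both of shape (N,).

$d$ is the resulting distance, a floating-point number (float).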

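(1) Euclidean distance

The Euclidean metric (also called the Euclidean distance) is the most commonly used definition of distance. It is the true distance between two points in $m$-dimensional space, or the natural length of a vector (the distance from the point to the origin). In two and three dimensions, the Euclidean distance is the actual distance between two points.

Distance formula:

$$d(x, y) = \sqrt{\sum_{i} (x_{i} - y_{i})^{2}}$$

Code implementation:

import numpy as np

def euclidean(x, y):
    return np.sqrt(np.sum((x - y)**2))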

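(2) Manhattan distance

Imagine driving through a city from one intersection to another. Is the driving distance the straight-line distance between the two points? Clearly not, unless you can drive through buildings. The actual driving distance is the "Manhattan distance", which is where the name comes from. It is also called the city block distance.

Distance formula:

$$d(x, y) = \sum_{i} |x_{i} - y_{i}|$$

Code implementation:

def manhattan_distance(x,y):
    return np.sum(np.abs(x-y))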

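(3) Chebyshev distance

In mathematics, the Chebyshev distance, or $L_\infty$ metric, is a metric on a vector space in which the distance between two points is the largest absolute difference between their coordinates. From a mathematical point of view, the Chebyshev distance is the metric induced by the uniform norm (also called the supremum norm), and it is an example of an injective (hyperconvex) metric.

Distance formula:

$$d(x, y) = \max_{i} |x_{i} - y_{i}|$$

Place a chessboard in a two-dimensional Cartesian coordinate system with squares of side length 1, the $x$ and $y$ axes parallel to the edges of the squares, and the origin at the center of one square. The number of moves a king needs to go from one square to another is then exactly the Chebyshev distance between the two squares, which is why the Chebyshev distance is also called the chessboard distance. For example, the Chebyshev distance between F6 and E2 is 4. Any square that is not on the edge of the board is at Chebyshev distance 1 from each of its eight neighbors.

Code implementation:

def chebyshev_distance(x,y):
    return np.max(np.abs(x-y))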
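The chessboard example can be checked with a few lines of code. This is a small sketch; the helper names `square_to_xy` and `king_moves` are hypothetical and chosen only for illustration.

```python
import numpy as np

def square_to_xy(square):
    """Convert a square name such as 'F6' to (file, rank) integers."""
    file_index = ord(square[0].upper()) - ord("A") + 1
    rank_index = int(square[1])
    return np.array([file_index, rank_index])

def king_moves(a, b):
    """Chebyshev distance between two squares, i.e. the number of king moves."""
    return int(np.max(np.abs(square_to_xy(a) - square_to_xy(b))))

print(king_moves("F6", "E2"))  # files differ by 1, ranks by 4
print(king_moves("D4", "E5"))  # a diagonal neighbor
```

The first call gives 4, matching the F6/E2 example, because the larger of the two coordinate differences decides the result.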

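(4) Minkowski distance

The Minkowski distance is named after the German mathematician Hermann Minkowski (1864-1909), who was born in the Russian Empire. He is also known for Minkowski space, the spacetime of special relativity made up of one time dimension and three space dimensions. That flat spacetime (zero curvature, no gravity) is consistent with the requirements of special relativity and differs from the flat space of Newtonian mechanics. It is a separate concept from the distance defined here. The Minkowski distance is a family of distances indexed by a parameter $p$. The cases $p = 1$ and $p = 2$ are the most commonly used: $p = 2$ gives the Euclidean distance and $p = 1$ gives the Manhattan distance.

In the limit as $p$ goes to infinity, we obtain the Chebyshev distance.

Distance formula:

$$d(x, y) = \left( \sum_{i} |x_{i} - y_{i}|^{p} \right)^{\frac{1}{p}}$$

Code implementation:

def minkowski(x, y, p):
    return np.sum(np.abs(x - y)**p)**(1 / p)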
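The special cases can be verified numerically. The sketch below repeats the `minkowski` function so that it runs on its own; the two vectors are made up for illustration, and $p = 50$ stands in for "large $p$".

```python
import numpy as np

def minkowski(x, y, p):
    return np.sum(np.abs(x - y) ** p) ** (1 / p)

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 0.0, 3.0])

manhattan = np.sum(np.abs(x - y))        # 3 + 2 + 0 = 5
euclid = np.sqrt(np.sum((x - y) ** 2))   # sqrt(9 + 4) = sqrt(13)
chebyshev = np.max(np.abs(x - y))        # largest coordinate difference, 3

print(minkowski(x, y, 1), manhattan)
print(minkowski(x, y, 2), euclid)
print(minkowski(x, y, 50), chebyshev)    # large p approaches the maximum
```

As $p$ grows, the largest coordinate difference dominates the sum, which is why the value approaches the Chebyshev distance.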

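(5) Hamming distance

The Hamming distance is used in error-control coding for data transmission. It is the number of positions at which two words of the same length differ; we write $d(x, y)$ for the Hamming distance between two words $x$ and $y$. For binary strings, XOR the two strings and count the 1s in the result: that count is the Hamming distance. The formula and code below divide the count by the length $N$, which gives the normalized Hamming distance (the fraction of positions that differ).

Distance formula:

$$d(x, y) = \frac{1}{N}\sum_{i} 1_{x_{i} \neq y_{i}}$$

Code implementation:

def hamming(x,y):
    return np.sum(x!=y)/len(x)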
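The relationship between the raw count, the XOR description and the normalized form can be sketched as follows. The function names `hamming_count` and `hamming_count_xor` and the two bit strings are hypothetical, chosen only for illustration.

```python
import numpy as np

def hamming_count(a, b):
    """Number of positions at which two equal-length sequences differ."""
    a, b = np.asarray(list(a)), np.asarray(list(b))
    if a.shape != b.shape:
        raise ValueError("inputs must have the same length")
    return int(np.sum(a != b))

def hamming_count_xor(m, n):
    """Same count for two non-negative integers, via XOR and a bit count."""
    return bin(m ^ n).count("1")

print(hamming_count("1011101", "1001001"))       # differing positions
print(hamming_count_xor(0b1011101, 0b1001001))   # same count via XOR
print(hamming_count("1011101", "1001001") / 7)   # normalized form
```

The two strings differ in two positions, so both counting approaches give 2 and the normalized distance is 2/7.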

(6) Cosine similarity

Cosine similarity measures how similar two vectors are by the cosine of the angle between them. The cosine of 0 degrees is 1, the cosine of any other angle is at most 1, and its minimum value is -1. The cosine of the angle between two vectors therefore tells us whether they point in roughly the same direction. When two vectors point in the same direction, the cosine similarity is 1; when the angle between them is 90°, it is 0; when they point in exactly opposite directions, it is -1. The result does not depend on the lengths of the vectors, only on their directions. Cosine similarity is most often used in positive space (all components non-negative), where its value lies between 0 and 1.

Take two dimensions as an example, and let $a$ and $b$ be two vectors whose angle $\theta$ we want to compute. The law of cosines gives the following formula, where $a$ and $b$ denote the lengths of the two vectors and $c$ the length of the third side of the triangle they form:

$$\cos\theta = \frac{a^{2} + b^{2} - c^{2}}{2ab}$$

Suppose vector $a$ is $[x_{1}, y_{1}]$ and vector $b$ is $[x_{2}, y_{2}]$. The cosine of the angle between them can be obtained from the Euclidean dot product formula:

$$\cos(\theta) = \frac{A \cdot B}{\|A\| \|B\|} = \frac{(x_{1}, y_{1}) \cdot (x_{2}, y_{2})}{\sqrt{x_{1}^{2} + y_{1}^{2}} \times \sqrt{x_{2}^{2} + y_{2}^{2}}} = \frac{x_{1}x_{2} + y_{1}y_{2}}{\sqrt{x_{1}^{2} + y_{1}^{2}} \times \sqrt{x_{2}^{2} + y_{2}^{2}}}$$

If the vectors are $n$-dimensional rather than two-dimensional, the same calculation still holds. Let $A = [A_{1}, A_{2}, \ldots, A_{n}]$ and $B = [B_{1}, B_{2}, \ldots, B_{n}]$ be two $n$-dimensional vectors. The cosine of the angle between $A$ and $B$ is:

$$\cos(\theta) = \frac{A \cdot B}{\|A\| \|B\|} = \frac{\sum_{i = 1}^{n} A_{i} \times B_{i}}{\sqrt{\sum_{i = 1}^{n} (A_{i})^{2}} \times \sqrt{\sum_{i = 1}^{n} (B_{i})^{2}}}$$

Code implementation:

import numpy as np

def square_rooted(x):
    # Euclidean norm of x
    return np.sqrt(np.sum(np.power(x,2)))
def cosine_similarity_distance(x,y):
    # Despite the name, this returns the cosine similarity, not a distance
    numerator=np.sum(np.multiply(x,y))
    denominator=square_rooted(x)*square_rooted(y)
    return numerator/denominator
print(cosine_similarity_distance([3, 45, 7, 2], [2, 54, 13, 15]))
0.9722842517123499
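The claim that cosine similarity depends only on direction, not on length, can be checked directly. This is a small sketch using `np.dot` and `np.linalg.norm`; the function name `cosine_similarity` and the test vectors are chosen here for illustration.

```python
import numpy as np

def cosine_similarity(x, y):
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

a = np.array([3.0, 4.0])
print(cosine_similarity(a, 10 * a))        # same direction, different length
print(cosine_similarity([1, 0], [0, 1]))   # orthogonal vectors
print(cosine_similarity(a, -a))            # opposite directions
```

Scaling a vector by a positive factor leaves the result at 1, orthogonal vectors give 0, and opposite vectors give -1 (up to floating-point rounding).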

Introduction to the KNN algorithm

Python implementation: traverse all data points, find the classes of the $n$ nearest points, and take a majority vote.

1 Preparing the data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter

Load the iris dataset

iris = load_iris()
iris
{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'frame': None,
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. 
"The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'filename': 'iris.csv',
 'data_module': 'sklearn.datasets.data'}
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["target"]=iris.target
df.columns=iris.feature_names+["target"]
df
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
0                  5.1               3.5                1.4               0.2       0
1                  4.9               3.0                1.4               0.2       0
2                  4.7               3.2                1.3               0.2       0
3                  4.6               3.1                1.5               0.2       0
4                  5.0               3.6                1.4               0.2       0
...                ...               ...                ...               ...     ...
145                6.7               3.0                5.2               2.3       2
146                6.3               2.5                5.0               1.9       2
147                6.5               3.0                5.2               2.0       2
148                6.2               3.4                5.4               2.3       2
149                5.9               3.0                5.1               1.8       2

150 rows × 5 columns

df.head()
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
0                  5.1               3.5                1.4               0.2       0
1                  4.9               3.0                1.4               0.2       0
2                  4.7               3.2                1.3               0.2       0
3                  4.6               3.1                1.5               0.2       0
4                  5.0               3.6                1.4               0.2       0

Visualize the sepal length and sepal width

# Use the first 100 rows (classes 0 and 1) for the plot
plt.figure(figsize=(12, 8))
plt.scatter(df[:50]["sepal length (cm)"], df[:50]["sepal width (cm)"], label='0')
plt.scatter(df[50:100]["sepal length (cm)"], df[50:100]["sepal width (cm)"], label='1')
plt.xlabel('sepal length', fontsize=18)
plt.ylabel('sepal width', fontsize=18)
plt.legend()
plt.show()

[Figure: scatter plot of sepal length against sepal width for classes 0 and 1]

2 Splitting into training and test data

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(df.iloc[:100,:2].values,df.iloc[:100,-1].values)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))
X_train,y_train
(array([[5. , 3.3],
        [4.6, 3.4],
        [5.2, 4.1],
        [5.7, 2.8],
        [5.1, 3.4],
        [4.8, 3. ],
        [5.9, 3.2],
        [5.7, 3.8],
        [4.8, 3.4],
        [5.3, 3.7],
        [5.1, 3.8],
        [5.5, 2.4],
        [6. , 2.2],
        [5.5, 4.2],
        [5.5, 2.6],
        [5.4, 3.4],
        [4.4, 2.9],
        [6. , 2.9],
        [5.8, 2.7],
        [4.4, 3.2],
        [5.6, 2.9],
        [5.8, 2.7],
        [6.7, 3.1],
        [6. , 2.7],
        [5.7, 2.9],
        [4.6, 3.2],
        [4.9, 3.1],
        [7. , 3.2],
        [4.7, 3.2],
        [5.1, 2.5],
        [6.3, 2.3],
        [4.6, 3.1],
        [6.4, 3.2],
        [6.6, 3. ],
        [4.6, 3.6],
        [5.5, 2.4],
        [5.6, 3. ],
        [5.1, 3.7],
        [6.1, 2.8],
        [5.6, 2.7],
        [4.8, 3.1],
        [4.8, 3. ],
        [5. , 3.5],
        [6.2, 2.2],
        [6. , 3.4],
        [5.1, 3.3],
        [5.4, 3.9],
        [5.7, 2.6],
        [6.7, 3.1],
        [4.5, 2.3],
        [4.8, 3.4],
        [4.9, 2.4],
        [5.8, 4. ],
        [5. , 3. ],
        [6.6, 2.9],
        [6.1, 2.9],
        [5. , 3.5],
        [6.8, 2.8],
        [5. , 2.3],
        [5.4, 3. ],
        [4.3, 3. ],
        [4.9, 3.1],
        [4.9, 3. ],
        [5.1, 3.8],
        [5.1, 3.5],
        [5.5, 2.5],
        [5. , 3.6],
        [5. , 3.4],
        [5.4, 3.9],
        [5.1, 3.8],
        [5.1, 3.5],
        [5.2, 3.5],
        [5.8, 2.6],
        [6.4, 2.9],
        [6.1, 2.8]]),
 array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1,
        1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 1, 1, 1]))
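The split above is random, so rerunning it produces a different training set each time. As a hedged sketch, passing `random_state` makes the split reproducible and `stratify` keeps the two classes balanced in both parts. The seed value 0 and the `_s` variable suffix are arbitrary choices made here for illustration, not part of the original lab.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris.data[:100, :2]     # first two features of classes 0 and 1
y = iris.target[:100]

# random_state fixes the shuffle; stratify keeps the class ratio in both parts.
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)
print(X_train_s.shape, X_test_s.shape)
print((y_test_s == 0).sum(), (y_test_s == 1).sum())
```

With 100 samples and `test_size=0.25` the shapes match the 75/25 split shown above, and the test set contains nearly equal numbers of each class.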

3 Predicting the label of a sample from the labels of its K nearest neighbors

# Set the number of neighbors
from collections import Counter
K=3
KNN_x=[]
for i in range(X_train.shape[0]):
    if len(KNN_x)<K:
        KNN_x.append((euclidean(X_test[0],X_train[i]),y_train[i]))
KNN_x
[(0.6324555320336757, 0), (0.9219544457292889, 0), (1.3999999999999995, 0)]
count=Counter([item[1] for item in KNN_x])
count
Counter({0: 3})
count.items()
dict_items([(0, 3)])
sorted(count.items(),key=lambda x:x[1])[-1][0]
0
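# Return the predicted label of a single sample x (uses the global K defined above)
def calcu_distance_return(x,X_train,y_train):
    KNN_x=[]
    # Traverse every sample in the training set
    for i in range(X_train.shape[0]):
        dist=euclidean(x,X_train[i])
        if len(KNN_x)<K:
            KNN_x.append((dist,y_train[i]))
        else:
            # Sort by distance so that the farthest current candidate is last
            KNN_x.sort(key=lambda item:item[0])
            # Keep the new sample only if it is closer than the farthest candidate
            if dist<KNN_x[-1][0]:
                KNN_x[-1]=(dist,y_train[i])
    # Majority vote among the labels of the K nearest neighbors
    knn_label=[item[1] for item in KNN_x]
    counter_knn=Counter(knn_label)
    return sorted(counter_knn.items(),key=lambda item:item[1])[-1][0]
# Predict the labels of the whole test set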
def predict(X_test):
    y_pred=np.zeros(X_test.shape[0])
    for i in range(X_test.shape[0]):
        y_hat_i=calcu_distance_return(X_test[i],X_train,y_train) 
        y_pred[i]=y_hat_i
    return y_pred
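The loop-based search can also be written in vectorized form. The following is a minimal sketch under stated assumptions: it uses brute-force Euclidean distances with NumPy broadcasting and `np.argpartition`, the function name `knn_predict` is hypothetical, and the six demo points are made up. It is one possible alternative, not the lab's own implementation.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, X_query, k=3):
    """Brute-force k-NN with Euclidean distance and majority voting."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train)
    # Pairwise distances, shape (n_query, n_train), via broadcasting.
    diff = X_query[:, None, :] - X_train[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=2))
    # Indices of the k smallest distances in each row (unordered within the k).
    nearest = np.argpartition(dist, k - 1, axis=1)[:, :k]
    # Majority vote among the k neighbor labels.
    return np.array([Counter(y_train[row]).most_common(1)[0][0] for row in nearest])

# Tiny made-up data set: two well-separated clusters.
X_demo = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                   [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])
y_demo = np.array([0, 0, 0, 1, 1, 1])
print(knn_predict(X_demo, y_demo, np.array([[0.05, 0.05], [5.0, 5.1]]), k=3))
```

The first query point lies in the cluster labeled 0 and the second in the cluster labeled 1, so the predictions are 0 and 1. This version holds the full distance matrix in memory, so it suits small data sets such as the one in this lab.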

4 Computing the accuracy

# Output the predictions
y_pred= predict(X_test).astype("int32")
y_pred
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 0])
y_test=y_test.astype("int32")
y_test
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 0])
# Compute the accuracy
np.sum(y_pred==y_test)/X_test.shape[0]
1.0

Trying scikit-learn

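sklearn.neighbors.KNeighborsClassifier
  • n_neighbors: number of neighbors, i.e. the value of k; default 5

  • p: distance metric (the power parameter of the Minkowski distance)

  • algorithm: nearest-neighbor search algorithm, one of {'auto', 'ball_tree', 'kd_tree', 'brute'}

  • weights: how the neighbors are weighted

  • n_neighbors : int, optional (default = 5)
    Number of neighbors used by default for kneighbors queries. This is the k of k-NN: the k nearest points are selected.

  • weights : str or callable, optional (default = 'uniform')
    The default is uniform. The value can be uniform, distance, or a user-defined function. With uniform, all neighbors have equal weight. With distance, the weights are unequal: closer points have more influence than more distant ones. A user-defined function receives an array of distances and returns an array of weights with the same shape.

  • algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
    Algorithm used for fast k-nearest-neighbor search. The default, auto, lets the estimator choose a suitable search algorithm itself. The user can also specify ball_tree, kd_tree or brute explicitly. brute is brute-force search, i.e. a linear scan, which becomes very time-consuming when the training set is large. kd_tree builds a kd tree, a tree-shaped data structure that stores the data for fast retrieval. A kd tree is a binary tree built by splitting at the median; each node is a hyperrectangle, and it is efficient when the number of dimensions is below about 20. The ball tree was invented to overcome the failure of kd trees in high dimensions; it partitions the sample space by a centroid C and a radius r, and each node is a hypersphere.

  • leaf_size : int, optional (default = 30)
    The default is 30. This is the leaf size passed to the kd tree or ball tree, not the size of the whole tree. It affects the speed of building and querying the tree, as well as the memory needed to store it. The best value depends on the nature of the problem.

  • p : int, optional (default = 2)
    The distance formula. In the previous section we used the Euclidean distance. There are other metrics as well, such as the Manhattan distance. The default value 2 selects the Euclidean distance; setting it to 1 selects the Manhattan distance.

  • metric : str or callable, default 'minkowski'
    The distance metric. The default is minkowski, which with p=2 is the Euclidean distance (Euclidean metric).

  • metric_params : dict, optional (default = None)
    Additional keyword arguments for the distance function. This can be ignored; the default None is fine.

  • n_jobs : int or None, optional (default = None)
    Parallelism setting: the number of parallel jobs used for the neighbor search. None normally means 1. If it is -1, all CPU cores are used.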
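A few of these parameters can be tried out as follows. This is a hedged sketch: it fits on the full iris data purely for illustration, and the particular parameter values are arbitrary choices rather than recommendations.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

brute = KNeighborsClassifier(n_neighbors=5, algorithm="brute").fit(X, y)
kd = KNeighborsClassifier(n_neighbors=5, algorithm="kd_tree", leaf_size=30).fit(X, y)

# Both search strategies should report the same neighbor distances,
# up to floating-point rounding.
d_brute, _ = brute.kneighbors(X[:10])
d_kd, _ = kd.kneighbors(X[:10])
print(np.allclose(d_brute, d_kd, atol=1e-5))

# Manhattan distance (p=1) with distance-weighted voting.
weighted = KNeighborsClassifier(n_neighbors=5, weights="distance", p=1).fit(X, y)
print(weighted.predict(X[:3]))
```

The `algorithm` setting changes how the neighbors are found, not which neighbors are nearest, whereas `p` and `weights` can change the predictions themselves.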

# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=4)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score
knn.score(X_test,y_test)
1.0

Trying other numbers of neighbors

# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=2)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score
knn.score(X_test,y_test)
1.0
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=6)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score
knn.score(X_test,y_test)
1.0
#5 Search for the best number of neighbors K; the range of K tried here is 1 to 10
from sklearn.model_selection import train_test_split
def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
    best_model=knn=KNeighborsClassifier(1)
    X_train_set,X_val,y_train_set,y_val=train_test_split(X_train,y_train,random_state=0)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        knn.fit(X_train_set,y_train_set)
        score=round(knn.score(X_val,y_val),2)
        print(score,num)
        if score>best_score:
            best_k=num
            best_score=score
            best_model=knn
    return best_k,best_score,best_model

best_k,best_score,best_model=getBestK(X_train,y_train,11)
0.95 1
0.95 2
0.95 3
0.95 4
0.95 5
1.0 6
1.0 7
1.0 8
1.0 9
1.0 10
#5 Check the score of the selected model on the test set
best_model.score(X_test,y_test)
1.0

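The k chosen above was selected on the validation set of a single split of the training set, so the choice is partly a matter of chance. The k with the highest validation score may therefore perform poorly on the test set.

#6 Try cross-validation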
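One common way to reduce this dependence on a single split is to average the validation score over many splits. The following is a minimal sketch using scikit-learn's `cross_val_score` with `RepeatedKFold`. For brevity it is an assumption of this sketch that all 100 samples of the two-class subset are used; in the lab's setting only the training portion would be passed in, with the test set held back.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X = iris.data[:100, :2]     # same two features and two classes as above
y = iris.target[:100]

rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)
mean_scores = {}
for k in range(1, 11):
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=rkf)
    mean_scores[k] = scores.mean()   # average over 5 x 10 = 50 folds

best_k = max(mean_scores, key=mean_scores.get)
print(best_k, round(mean_scores[best_k], 3))
```

Because each k is scored on 50 different validation folds, a single lucky or unlucky split has much less influence on which k is selected.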
from sklearn.model_selection import RepeatedKFold
rkf=RepeatedKFold(n_repeats=10,n_splits=5,random_state=42)
for i,(train_index,test_index) in enumerate(rkf.split(X_train)):
    print("train_index",train_index)
    print("test_index",test_index)
#     print("New training data:",X_train[train_index],y_train[train_index])
#     print("New validation data:",X_train[test_index],y_train[test_index])
train_index [ 1  2  3  5  6  7  8 11 13 14 15 16 17 19 20 21 22 23 24 25 26 27 29 30
 31 32 33 36 37 38 39 40 41 43 44 45 46 47 48 50 51 52 53 54 55 56 57 58
 59 60 62 65 66 67 68 70 71 72 73 74]
test_index [ 0  4  9 10 12 18 28 34 35 42 49 61 63 64 69]
train_index [ 0  1  2  3  4  6  8  9 10 11 12 13 14 15 17 18 19 20 21 23 24 25 26 27
 28 29 32 34 35 36 37 38 41 42 43 46 48 49 50 51 52 53 54 55 57 59 60 61
 62 63 64 65 67 68 69 70 71 72 73 74]
test_index [ 5  7 16 22 30 31 33 39 40 44 45 47 56 58 66]
train_index [ 0  1  2  4  5  7  9 10 11 12 14 15 16 18 20 21 22 23 24 26 27 28 29 30
 31 32 33 34 35 37 39 40 41 42 43 44 45 46 47 48 49 51 52 55 56 57 58 59
 60 61 63 64 65 66 67 68 69 70 71 73]
test_index [ 3  6  8 13 17 19 25 36 38 50 53 54 62 72 74]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 16 17 18 19 20 21 22 23 25 28
 29 30 31 33 34 35 36 37 38 39 40 42 44 45 47 49 50 51 52 53 54 56 58 59
 60 61 62 63 64 65 66 69 70 71 72 74]
test_index [11 15 24 26 27 32 41 43 46 48 55 57 67 68 73]
train_index [ 0  3  4  5  6  7  8  9 10 11 12 13 15 16 17 18 19 22 24 25 26 27 28 30
 31 32 33 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 50 53 54 55 56 57
 58 61 62 63 64 66 67 68 69 72 73 74]
test_index [ 1  2 14 20 21 23 29 37 51 52 59 60 65 70 71]
train_index [ 0  2  3  4  6  7  8  9 10 11 12 13 14 16 18 19 21 22 23 24 25 26 27 28
 30 32 33 34 35 36 37 38 39 40 41 42 43 44 47 48 50 52 53 54 55 56 57 58
 59 61 62 64 65 66 67 68 70 71 72 73]
test_index [ 1  5 15 17 20 29 31 45 46 49 51 60 63 69 74]
train_index [ 0  1  2  4  5  6  7  8 10 11 13 14 15 16 17 20 21 22 23 25 26 27 28 29
 31 32 33 34 35 36 38 39 40 41 43 44 45 46 49 50 51 52 53 54 55 56 57 59
 60 61 62 63 64 65 66 69 70 71 73 74]
test_index [ 3  9 12 18 19 24 30 37 42 47 48 58 67 68 72]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 23 24 25 27 28
 29 30 31 32 34 37 38 40 41 42 43 44 45 46 47 48 49 50 51 52 56 57 58 59
 60 62 63 64 65 67 68 69 70 72 73 74]
test_index [ 2 13 21 22 26 33 35 36 39 53 54 55 61 66 71]
train_index [ 0  1  2  3  5  7  8  9 10 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28
 29 30 31 33 35 36 37 39 40 42 43 44 45 46 47 48 49 51 52 53 54 55 58 59
 60 61 63 64 66 67 68 69 71 72 73 74]
test_index [ 4  6 11 16 27 32 34 38 41 50 56 57 62 65 70]
train_index [ 1  2  3  4  5  6  9 11 12 13 15 16 17 18 19 20 21 22 24 26 27 29 30 31
 32 33 34 35 36 37 38 39 41 42 45 46 47 48 49 50 51 53 54 55 56 57 58 60
 61 62 63 65 66 67 68 69 70 71 72 74]
test_index [ 0  7  8 10 14 23 25 28 40 43 44 52 59 64 73]
train_index [ 0  1  2  3  4  5  7  8 10 11 14 16 18 19 20 21 22 23 24 25 26 27 28 29
 31 32 35 36 38 39 40 41 42 43 45 46 47 48 49 50 51 52 53 54 55 56 57 58
 61 62 63 64 66 67 68 69 71 72 73 74]
test_index [ 6  9 12 13 15 17 30 33 34 37 44 59 60 65 70]
train_index [ 0  1  2  5  6  7  8  9 11 12 13 14 15 16 17 18 20 22 23 26 27 29 30 31
 32 33 34 36 37 38 40 41 43 44 45 47 48 50 51 53 54 55 56 57 58 59 60 61
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 3  4 10 19 21 24 25 28 35 39 42 46 49 52 62]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 17 19 21 23 24 25 26 27
 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 46 49 50 51 52 53 59
 60 61 62 63 65 66 68 69 70 71 73 74]
test_index [16 18 20 22 45 47 48 54 55 56 57 58 64 67 72]
train_index [ 0  2  3  4  5  6  7  9 10 12 13 15 16 17 18 19 20 21 22 24 25 26 27 28
 29 30 33 34 35 37 38 39 42 43 44 45 46 47 48 49 52 54 55 56 57 58 59 60
 61 62 64 65 66 67 68 69 70 72 73 74]
test_index [ 1  8 11 14 23 31 32 36 40 41 50 51 53 63 71]
train_index [ 1  3  4  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 28 30
 31 32 33 34 35 36 37 39 40 41 42 44 45 46 47 48 49 50 51 52 53 54 55 56
 57 58 59 60 62 63 64 65 67 70 71 72]
test_index [ 0  2  5  7 26 27 29 38 43 61 66 68 69 73 74]
train_index [ 0  1  2  3  4  6  7  8 10 11 13 15 17 18 19 20 21 22 23 24 25 27 28 29
 30 31 32 33 34 36 37 38 39 40 41 44 45 46 47 48 49 51 52 53 54 55 56 57
 59 60 61 66 67 68 69 70 71 72 73 74]
test_index [ 5  9 12 14 16 26 35 42 43 50 58 62 63 64 65]
train_index [ 0  1  2  4  5  6  7  8  9 10 11 12 14 15 16 18 19 22 23 24 25 26 29 30
 31 32 34 35 36 37 38 39 40 41 42 43 44 47 48 49 50 51 55 56 57 58 59 62
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 3 13 17 20 21 27 28 33 45 46 52 53 54 60 61]
train_index [ 0  1  3  4  5  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 25 26 27
 28 29 30 31 32 33 34 35 36 38 39 41 42 43 45 46 47 48 49 50 51 52 53 54
 55 56 58 60 61 62 63 64 65 67 70 71]
test_index [ 2  7 23 24 37 40 44 57 59 66 68 69 72 73 74]
train_index [ 0  2  3  5  7  9 10 12 13 14 16 17 18 19 20 21 22 23 24 26 27 28 29 30
 32 33 35 37 38 39 40 41 42 43 44 45 46 49 50 51 52 53 54 56 57 58 59 60
 61 62 63 64 65 66 68 69 70 72 73 74]
test_index [ 1  4  6  8 11 15 25 31 34 36 47 48 55 67 71]
train_index [ 1  2  3  4  5  6  7  8  9 11 12 13 14 15 16 17 20 21 23 24 25 26 27 28
 31 33 34 35 36 37 40 42 43 44 45 46 47 48 50 52 53 54 55 57 58 59 60 61
 62 63 64 65 66 67 68 69 71 72 73 74]
test_index [ 0 10 18 19 22 29 30 32 38 39 41 49 51 56 70]
train_index [ 0  1  2  3  4  5  7  8  9 13 14 16 17 18 20 21 22 23 24 25 26 27 28 29
 30 31 32 34 35 36 37 38 40 41 42 43 44 45 46 47 48 50 53 54 56 59 60 61
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 6 10 11 12 15 19 33 39 49 51 52 55 57 58 62]
train_index [ 2  3  4  5  6  7 10 11 12 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
 30 31 32 33 34 36 37 39 40 42 43 45 46 47 48 49 50 51 52 53 55 56 57 58
 59 60 61 62 63 64 65 66 67 69 72 74]
test_index [ 0  1  8  9 13 14 35 38 41 44 54 68 70 71 73]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 18 19 20 26 27 28 29 32
 33 34 35 36 37 38 39 40 41 43 44 45 47 48 49 50 51 52 53 54 55 56 57 58
 59 60 62 63 65 66 68 69 70 71 73 74]
test_index [ 2 17 21 22 23 24 25 30 31 42 46 61 64 67 72]
train_index [ 0  1  2  4  6  7  8  9 10 11 12 13 14 15 17 19 20 21 22 23 24 25 26 27
 29 30 31 32 33 35 37 38 39 41 42 44 46 49 50 51 52 53 54 55 57 58 59 60
 61 62 63 64 67 68 69 70 71 72 73 74]
test_index [ 3  5 16 18 28 34 36 40 43 45 47 48 56 65 66]
train_index [ 0  1  2  3  5  6  8  9 10 11 12 13 14 15 16 17 18 19 21 22 23 24 25 28
 30 31 33 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 51 52 54 55 56 57
 58 61 62 64 65 66 67 68 70 71 72 73]
test_index [ 4  7 20 26 27 29 32 37 50 53 59 60 63 69 74]
train_index [ 0  1  3  4  5  7  8 11 12 13 14 15 16 18 19 20 21 22 23 24 25 26 27 28
 29 30 31 32 34 35 36 37 38 39 41 42 43 44 45 46 48 50 51 52 54 56 57 58
 59 60 62 63 64 65 66 67 69 70 73 74]
test_index [ 2  6  9 10 17 33 40 47 49 53 55 61 68 71 72]
train_index [ 2  3  4  5  6  7  9 10 12 13 14 15 16 17 18 19 21 24 25 27 29 31 32 33
 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 55 57 58 59 60
 61 62 63 64 65 66 67 68 69 70 71 72]
test_index [ 0  1  8 11 20 22 23 26 28 30 37 54 56 73 74]
train_index [ 0  1  2  5  6  7  8  9 10 11 13 14 15 17 19 20 21 22 23 24 26 28 30 31
 32 33 35 36 37 40 41 42 43 44 46 47 48 49 50 51 53 54 55 56 57 58 59 60
 61 62 63 64 65 67 68 70 71 72 73 74]
test_index [ 3  4 12 16 18 25 27 29 34 38 39 45 52 66 69]
train_index [ 0  1  2  3  4  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 33 34 35 37 38 39 40 44 45 47 49 50 52 53 54 55 56
 57 61 62 64 65 66 68 69 71 72 73 74]
test_index [ 5 32 36 41 42 43 46 48 51 58 59 60 63 67 70]
train_index [ 0  1  2  3  4  5  6  8  9 10 11 12 16 17 18 20 22 23 25 26 27 28 29 30
 32 33 34 36 37 38 39 40 41 42 43 45 46 47 48 49 51 52 53 54 55 56 58 59
 60 61 63 66 67 68 69 70 71 72 73 74]
test_index [ 7 13 14 15 19 21 24 31 35 44 50 57 62 64 65]
train_index [ 0  1  2  3  4  6  7  8  9 10 11 12 13 15 16 17 18 19 22 23 24 26 27 28
 30 31 32 33 34 35 36 37 38 39 43 44 45 46 47 48 51 52 53 54 55 56 57 59
 60 61 62 65 66 67 68 69 70 72 73 74]
test_index [ 5 14 20 21 25 29 40 41 42 49 50 58 63 64 71]
train_index [ 0  1  2  3  4  5  7  9 11 14 15 18 19 20 21 22 23 25 26 27 28 29 30 31
 32 33 34 35 36 37 38 39 40 41 42 44 46 47 48 49 50 51 52 53 55 56 57 58
 60 61 62 63 64 65 67 68 69 70 71 72]
test_index [ 6  8 10 12 13 16 17 24 43 45 54 59 66 73 74]
train_index [ 0  1  3  4  5  6  8  9 10 12 13 14 15 16 17 18 20 21 22 23 24 25 28 29
 30 31 32 33 35 38 40 41 42 43 44 45 46 47 48 49 50 51 53 54 56 57 58 59
 60 61 62 63 64 66 68 69 71 72 73 74]
test_index [ 2  7 11 19 26 27 34 36 37 39 52 55 65 67 70]
train_index [ 2  4  5  6  7  8  9 10 11 12 13 14 15 16 17 19 20 21 22 24 25 26 27 28
 29 32 34 36 37 38 39 40 41 42 43 45 46 47 49 50 52 53 54 55 56 57 58 59
 61 63 64 65 66 67 68 70 71 72 73 74]
test_index [ 0  1  3 18 23 30 31 33 35 44 48 51 60 62 69]
train_index [ 0  1  2  3  5  6  7  8 10 11 12 13 14 16 17 18 19 20 21 23 24 25 26 27
 29 30 31 33 34 35 36 37 39 40 41 42 43 44 45 48 49 50 51 52 54 55 58 59
 60 62 63 64 65 66 67 69 70 71 73 74]
test_index [ 4  9 15 22 28 32 38 46 47 53 56 57 61 68 72]
train_index [ 2  3  4  6  8  9 10 11 12 13 14 15 16 18 19 20 21 22 23 24 26 27 29 30
 32 33 34 35 36 37 38 39 40 42 44 45 46 47 48 49 50 51 53 54 56 59 60 61
 62 63 64 65 66 67 68 70 71 72 73 74]
test_index [ 0  1  5  7 17 25 28 31 41 43 52 55 57 58 69]
train_index [ 0  1  3  4  5  6  7  8 11 12 13 15 16 17 18 19 20 21 22 23 24 25 27 28
 29 30 31 32 34 35 36 40 41 43 44 45 47 48 50 52 53 54 55 56 57 58 59 60
 61 63 64 65 67 68 69 70 71 72 73 74]
test_index [ 2  9 10 14 26 33 37 38 39 42 46 49 51 62 66]
train_index [ 0  1  2  5  7  9 10 11 12 14 16 17 18 19 21 22 23 24 25 26 28 29 31 33
 34 35 36 37 38 39 40 41 42 43 46 47 48 49 50 51 52 54 55 56 57 58 59 61
 62 63 65 66 67 68 69 70 71 72 73 74]
test_index [ 3  4  6  8 13 15 20 27 30 32 44 45 53 60 64]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 17 20 22 23 24 25 26 27
 28 30 31 32 33 34 35 36 37 38 39 41 42 43 44 45 46 48 49 51 52 53 54 55
 57 58 60 61 62 63 64 66 68 69 72 73]
test_index [16 18 19 21 29 40 47 50 56 59 65 67 70 71 74]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 13 14 15 16 17 18 19 20 21 25 26 27 28
 29 30 31 32 33 37 38 39 40 41 42 43 44 45 46 47 49 50 51 52 53 55 56 57
 58 59 60 62 64 65 66 67 69 70 71 74]
test_index [11 12 22 23 24 34 35 36 48 54 61 63 68 72 73]
train_index [ 0  2  3  4  5  7  8  9 10 12 13 14 15 16 17 18 19 20 22 24 25 26 27 28
 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 46 47 48 49 51 52 53 57 58
 59 60 61 62 63 64 65 66 67 69 73 74]
test_index [ 1  6 11 21 23 29 45 50 54 55 56 68 70 71 72]
train_index [ 0  1  2  3  4  5  6  7  9 10 11 12 15 16 18 19 20 21 23 24 25 26 27 28
 29 30 31 32 34 35 36 37 38 39 40 43 44 45 46 48 49 50 51 52 53 54 55 56
 57 59 60 63 64 65 66 68 69 70 71 72]
test_index [ 8 13 14 17 22 33 41 42 47 58 61 62 67 73 74]
train_index [ 1  2  3  4  5  6  7  8  9 11 12 13 14 16 17 18 19 21 22 23 25 26 27 28
 29 30 33 35 36 37 38 41 42 43 44 45 47 48 50 53 54 55 56 57 58 59 60 61
 62 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 0 10 15 20 24 31 32 34 39 40 46 49 51 52 63]
train_index [ 0  1  3  4  5  6  7  8 10 11 13 14 15 16 17 18 20 21 22 23 24 28 29 30
 31 32 33 34 35 36 37 39 40 41 42 44 45 46 47 49 50 51 52 54 55 56 58 59
 61 62 63 64 65 67 68 70 71 72 73 74]
test_index [ 2  9 12 19 25 26 27 38 43 48 53 57 60 66 69]
train_index [ 0  1  2  6  8  9 10 11 12 13 14 15 17 19 20 21 22 23 24 25 26 27 29 31
 32 33 34 38 39 40 41 42 43 45 46 47 48 49 50 51 52 53 54 55 56 57 58 60
 61 62 63 66 67 68 69 70 71 72 73 74]
test_index [ 3  4  5  7 16 18 28 30 35 36 37 44 59 64 65]
train_index [ 0  1  2  4  5  9 10 12 15 16 17 18 19 20 21 22 24 25 26 27 28 29 30 31
 32 33 34 36 38 39 40 41 42 44 45 46 47 48 49 50 51 52 54 55 56 57 58 59
 60 61 62 63 64 65 66 68 69 71 72 73]
test_index [ 3  6  7  8 11 13 14 23 35 37 43 53 67 70 74]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 20 21 22 23 24 25 26
 27 28 29 31 32 33 34 35 37 40 42 43 44 45 46 47 49 50 53 54 55 56 57 58
 59 60 61 62 63 65 67 68 69 70 72 74]
test_index [ 2 18 19 30 36 38 39 41 48 51 52 64 66 71 73]
train_index [ 0  1  2  3  4  5  6  7  8  9 11 12 13 14 16 17 18 19 23 24 26 27 28 29
 30 32 34 35 36 37 38 39 40 41 43 44 45 46 48 49 50 51 52 53 56 57 58 59
 60 62 63 64 65 66 67 70 71 72 73 74]
test_index [10 15 20 21 22 25 31 33 42 47 54 55 61 68 69]
train_index [ 2  3  6  7  8 10 11 12 13 14 15 16 18 19 20 21 22 23 25 26 27 28 30 31
 32 33 34 35 36 37 38 39 40 41 42 43 45 47 48 49 51 52 53 54 55 57 59 60
 61 62 63 64 66 67 68 69 70 71 73 74]
test_index [ 0  1  4  5  9 17 24 29 44 46 50 56 58 65 72]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 13 14 15 17 18 19 20 21 22 23 24 25
 29 30 31 33 35 36 37 38 39 41 42 43 44 46 47 48 50 51 52 53 54 55 56 58
 61 64 65 66 67 68 69 70 71 72 73 74]
test_index [12 16 26 27 28 32 34 40 45 49 57 59 60 62 63]
from sklearn.model_selection import cross_validate
cross_validate(knn,X_train,y_train,cv=rkf,scoring="accuracy",return_estimator=True)
{'fit_time': array([0.00099969, 0.        , 0.00099897, 0.        , 0.        ,
        0.00100088, 0.00100112, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.00099134, 0.00101256, 0.00099635,
        0.        , 0.        , 0.        , 0.00099874, 0.        ,
        0.00105643, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.00100422,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ]),
 'score_time': array([0.00099945, 0.00100017, 0.        , 0.00099826, 0.0010016 ,
        0.00099826, 0.00112462, 0.00212598, 0.00103188, 0.00099683,
        0.0009737 , 0.00103641, 0.        , 0.        , 0.        ,
        0.00097394, 0.00102925, 0.00099778, 0.        , 0.00100136,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.00100565, 0.00099897, 0.        , 0.00099373, 0.00099897,
        0.00100088, 0.00106072, 0.00103712, 0.00107408, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.00101113, 0.0010767 , 0.00099373, 0.00093102]),
 'estimator': [KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6)],
 'test_score': array([1.        , 1.        , 1.        , 1.        , 0.93333333,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        0.93333333, 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ])}
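The 50 fold scores above are easier to read as a mean and a standard deviation. The sketch below rebuilds the printed `test_score` array by hand (the variable names are illustrative, not from the original notebook). With a real run, one would read the array from the `"test_score"` entry of the dictionary returned by `cross_validate` instead.

```python
import numpy as np

# Rebuild the 50 fold scores printed above: 48 perfect folds and two folds
# (positions 4 and 20) at 0.93333333 = 14/15, i.e. one of the 15 validation
# samples in that fold was misclassified.
test_score = np.ones(50)
test_score[[4, 20]] = 14 / 15

mean_score = test_score.mean()
std_score = test_score.std()
print(f"accuracy: {mean_score:.4f} +/- {std_score:.4f}")
```

Reporting the spread alongside the mean shows how much the score depends on the particular split.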
#5 Search for the best number of neighbors K; candidates run from 1 to K-1 (1 to 14 in the call below)
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier
def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
    #Repeated 5-fold cross-validation, 5 repeats, so each candidate is scored on 25 folds
    rkf=RepeatedKFold(n_repeats=5,n_splits=5,random_state=42)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        #scoring="f1" assumes a binary classification target
        result=cross_validate(knn,X_train,y_train,cv=rkf,scoring="f1")
        score=result["test_score"].mean()
        score=round(score,2)
        print(score,num)
        #Strict comparison: when rounded scores tie, the smallest k is kept
        if score>best_score:
            best_k=num
            best_score=score
    return best_k,best_score
best_k,best_score=getBestK(X_train,y_train,15)
best_k,best_score
0.98 1
0.99 2
0.99 3
0.99 4
0.99 5
0.99 6
1.0 7
0.99 8
0.99 9
0.98 10
0.98 11
0.97 12
0.98 13
0.97 14


(7, 1.0)
knn=KNeighborsClassifier(best_k)
knn.fit(X_train,y_train)
knn.score(X_test,y_test)
1.0

Try automatic tuning: loop over the candidate values to find the best k.
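As an alternative to a hand-written loop, scikit-learn's `GridSearchCV` can run the same search. This is only a sketch under assumptions: the data comes from `make_classification` rather than the notebook's dataset, and the names `X_demo`, `y_demo`, and `search` are made up for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, RepeatedKFold
from sklearn.neighbors import KNeighborsClassifier

# Synthetic binary classification data (an assumption, not the notebook's data).
X_demo, y_demo = make_classification(
    n_samples=100, n_features=4, n_informative=2, n_redundant=0, random_state=0
)

# Same cross-validation scheme and candidate range as the manual search.
rkf = RepeatedKFold(n_splits=5, n_repeats=5, random_state=42)
search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": list(range(1, 15))},
    cv=rkf,
    scoring="f1",
)
search.fit(X_demo, y_demo)
print(search.best_params_, round(search.best_score_, 2))
```

After fitting, `search.best_estimator_` is a classifier refit on all the data passed to `fit` with the selected k, so it can be used directly for prediction.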

Experiment: using KNN for a regression task

1 Prepare the data

import numpy as np
x1=np.linspace(-10,10,100)
x2=np.linspace(-5,15,100)
#Construct some data by hand
w1=5
w2=4
y=x1*w1+x2*w2
y
array([-70.        , -68.18181818, -66.36363636, -64.54545455,
       -62.72727273, -60.90909091, -59.09090909, -57.27272727,
       -55.45454545, -53.63636364, -51.81818182, -50.        ,
       -48.18181818, -46.36363636, -44.54545455, -42.72727273,
       -40.90909091, -39.09090909, -37.27272727, -35.45454545,
       -33.63636364, -31.81818182, -30.        , -28.18181818,
       -26.36363636, -24.54545455, -22.72727273, -20.90909091,
       -19.09090909, -17.27272727, -15.45454545, -13.63636364,
       -11.81818182, -10.        ,  -8.18181818,  -6.36363636,
        -4.54545455,  -2.72727273,  -0.90909091,   0.90909091,
         2.72727273,   4.54545455,   6.36363636,   8.18181818,
        10.        ,  11.81818182,  13.63636364,  15.45454545,
        17.27272727,  19.09090909,  20.90909091,  22.72727273,
        24.54545455,  26.36363636,  28.18181818,  30.        ,
        31.81818182,  33.63636364,  35.45454545,  37.27272727,
        39.09090909,  40.90909091,  42.72727273,  44.54545455,
        46.36363636,  48.18181818,  50.        ,  51.81818182,
        53.63636364,  55.45454545,  57.27272727,  59.09090909,
        60.90909091,  62.72727273,  64.54545455,  66.36363636,
        68.18181818,  70.        ,  71.81818182,  73.63636364,
        75.45454545,  77.27272727,  79.09090909,  80.90909091,
        82.72727273,  84.54545455,  86.36363636,  88.18181818,
        90.        ,  91.81818182,  93.63636364,  95.45454545,
        97.27272727,  99.09090909, 100.90909091, 102.72727273,
       104.54545455, 106.36363636, 108.18181818, 110.        ])
x1=x1.reshape(len(x1),1)
x2=x2.reshape(len(x2),1)
y=y.reshape(len(y),1)
import pandas as pd
data=np.hstack([x1,x2,y])
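# Add some noise to the data
# np.random.seed must be called; assigning to it would replace the function and seed nothing
np.random.seed(10)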
data=data+np.random.normal(0.1,1,[100,3])
data
array([[-9.80997918e+00, -4.47671228e+00, -6.86113562e+01],
       [-9.07863100e+00, -3.29030887e+00, -6.75412089e+01],
       [-8.17535392e+00, -4.85515660e+00, -6.56682184e+01],
       [-9.33603110e+00, -4.67304042e+00, -6.39943055e+01],
       [-8.31454149e+00, -3.61401814e+00, -6.15552168e+01],
       [-9.35462761e+00, -3.99216837e+00, -6.16450829e+01],
       [-7.35641032e+00, -5.10713257e+00, -5.80574405e+01],
       [-7.75808720e+00, -2.81374154e+00, -5.72785817e+01],
       [-7.85420726e+00, -3.25192460e+00, -5.58260703e+01],
       [-7.79785201e+00, -4.59268755e+00, -5.46208629e+01],
       [-9.90411101e+00, -7.55985286e-01, -5.19239440e+01],
       [-4.91167456e+00, -1.48242138e+00, -5.06778041e+01],
       [-9.25608953e+00, -1.12391146e+00, -4.80701720e+01],
       [-6.92987717e+00, -3.58106474e+00, -4.58459514e+01],
       [-7.19890084e+00, -2.10260074e+00, -4.46497119e+01],
       [-8.56812108e+00, -2.45314063e+00, -4.19130070e+01],
       [-6.97527315e+00, -3.25615055e+00, -4.15373469e+01],
       [-6.09201512e+00, -1.07060626e+00, -4.05034362e+01],
       [-5.94248008e+00,  6.42232477e-01, -3.64281226e+01],
       [-5.99567467e+00, -2.26531046e+00, -3.32873129e+01],
       [-7.56906953e+00, -6.81005515e-01, -3.42368449e+01],
       [-6.54272630e+00, -7.32829423e-01, -3.18556358e+01],
       [-4.68241322e+00, -1.55653397e+00, -2.99105801e+01],
       [-5.61148642e+00, -1.96269845e+00, -2.80144819e+01],
       [-4.64818297e+00,  2.21684956e-01, -2.56420739e+01],
       [-5.64237828e+00, -5.05215614e-02, -2.44150985e+01],
       [-4.77269716e+00,  3.12543954e-01, -2.35962190e+01],
       [-3.93579614e+00,  3.14368041e-01, -2.04078436e+01],
       [-4.67599369e+00,  1.38646098e+00, -1.95569688e+01],
       [-4.56613680e+00,  2.18761537e-01, -1.76443732e+01],
       [-4.12462083e+00,  7.81731566e-01, -1.55500903e+01],
       [-5.00893448e+00,  8.43167883e-01, -1.37904298e+01],
       [-3.32575389e+00,  8.87284515e-01, -1.16870554e+01],
       [-4.60962500e+00,  2.47674165e+00, -9.43497025e+00],
       [-2.55399230e+00,  1.60304976e+00, -7.30116575e+00],
       [-3.92552974e+00,  2.02861216e+00, -8.47211685e+00],
       [-2.85445054e+00,  1.32252697e+00, -2.27221086e+00],
       [-3.20383909e+00,  1.56885433e+00, -1.46024067e+00],
       [-1.87732669e+00,  1.18972183e+00, -1.68276177e+00],
       [-1.35842429e+00,  3.76086938e+00,  3.35135047e-01],
       [-7.24957523e-01,  4.37716480e+00,  1.17352349e+00],
       [-3.70453016e+00,  5.08438460e+00,  3.35207490e+00],
       [-7.97872551e-01,  2.78241431e+00,  5.09073378e+00],
       [-3.08232423e+00,  4.21925884e+00,  7.90719675e+00],
       [ 5.28844300e-01,  4.16412164e+00,  1.01885052e+01],
       [-2.64895900e-02,  4.04451188e+00,  1.32964325e+01],
       [ 7.67644414e-01,  4.38295411e+00,  1.20330676e+01],
       [-3.17298624e-01,  5.52193479e+00,  1.44587349e+01],
       [-4.05576007e-01,  6.15916945e+00,  1.77192591e+01],
       [ 2.58635850e-01,  4.36652636e+00,  2.08469868e+01],
       [-1.15875757e+00,  5.86049204e+00,  2.12312972e+01],
       [-7.16862753e-01,  7.60609045e+00,  2.24464377e+01],
       [ 1.00827677e+00,  7.13593566e+00,  2.60236434e+01],
       [ 8.64304920e-01,  7.70071685e+00,  2.67335947e+01],
       [ 3.14401551e+00,  5.74841619e+00,  2.76627520e+01],
       [-1.18085370e-02,  5.45967297e+00,  3.01731518e+01],
       [ 9.67211352e-01,  6.30044676e+00,  3.31847137e+01],
       [ 1.32254229e+00,  6.51216091e+00,  3.31636096e+01],
       [ 9.66206984e-01,  8.15352634e+00,  3.54552668e+01],
       [ 1.50374715e+00,  8.38063421e+00,  3.82675089e+01],
       [ 1.20333031e+00,  8.30155252e+00,  4.05759780e+01],
       [ 2.84702572e+00,  7.44997601e+00,  4.16313092e+01],
       [ 2.82319554e+00,  7.03396275e+00,  4.33733979e+01],
       [ 3.88755763e+00,  9.63373825e+00,  4.63550733e+01],
       [ 3.31979805e+00,  1.00825563e+01,  4.66602506e+01],
       [ 3.67714879e+00,  8.98817386e+00,  4.71815191e+01],
       [ 5.61673924e+00,  8.83321195e+00,  4.90218726e+01],
       [ 4.64376606e+00,  1.05003123e+01,  5.16821640e+01],
       [ 3.38312917e+00,  9.93985678e+00,  5.44523927e+01],
       [ 2.90435391e+00,  8.76211593e+00,  5.72974806e+01],
       [ 1.94362594e+00,  8.37086325e+00,  5.69748221e+01],
       [ 4.86357671e+00,  8.79920772e+00,  5.92178403e+01],
       [ 5.21731274e+00,  8.76064972e+00,  6.30249467e+01],
       [ 5.86040809e+00,  1.12868041e+01,  6.26973140e+01],
       [ 4.05985223e+00,  8.65847315e+00,  6.61012727e+01],
       [ 6.19899121e+00,  8.30649111e+00,  6.37680817e+01],
       [ 5.73989925e+00,  1.00161474e+01,  6.92336558e+01],
       [ 5.38266361e+00,  1.03971821e+01,  7.17084241e+01],
       [ 7.23264561e+00,  1.20494918e+01,  7.05362027e+01],
       [ 6.11948179e+00,  1.19855375e+01,  7.55318286e+01],
       [ 8.03847795e+00,  9.79749582e+00,  7.47950707e+01],
       [ 8.30070319e+00,  1.07233637e+01,  7.93806649e+01],
       [ 7.44456666e+00,  1.11936713e+01,  7.84042566e+01],
       [ 6.87035796e+00,  1.23168763e+01,  8.01532295e+01],
       [ 6.57153443e+00,  1.12686434e+01,  8.32735790e+01],
       [ 8.06216701e+00,  1.26805930e+01,  8.58973008e+01],
       [ 8.75001919e+00,  1.36698902e+01,  8.72099703e+01],
       [ 7.30252179e+00,  1.34260600e+01,  8.71816534e+01],
       [ 1.02174549e+01,  1.12734356e+01,  9.06574864e+01],
       [ 9.16397441e+00,  1.35946035e+01,  9.12502949e+01],
       [ 7.65119402e+00,  1.26062408e+01,  9.37067133e+01],
       [ 7.88012441e+00,  1.20190767e+01,  9.49682650e+01],
       [ 8.32044954e+00,  1.32807945e+01,  9.65808990e+01],
       [ 8.01089317e+00,  1.64722621e+01,  9.82354518e+01],
       [ 9.02271142e+00,  1.33190747e+01,  1.00825525e+02],
       [ 8.09970303e+00,  1.46680917e+01,  1.03017581e+02],
       [ 1.13875348e+01,  1.46989516e+01,  1.04003935e+02],
       [ 1.01333057e+01,  1.33257429e+01,  1.05931984e+02],
       [ 9.38629399e+00,  1.39040038e+01,  1.10363757e+02],
       [ 1.13412247e+01,  1.61090392e+01,  1.10731822e+02]])
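If the noisy data needs to be reproducible, one option is NumPy's seeded `Generator`. The sketch below is an assumption about how that could look, not the notebook's code. It draws noise with the same mean (0.1), standard deviation (1), and shape as above from two generators that share a seed, and checks that the draws match.

```python
import numpy as np

# Two generators created with the same seed produce identical streams.
rng_a = np.random.default_rng(10)
rng_b = np.random.default_rng(10)

noise_a = rng_a.normal(0.1, 1, size=(100, 3))
noise_b = rng_b.normal(0.1, 1, size=(100, 3))

same_noise = np.array_equal(noise_a, noise_b)
print(same_noise)  # True
```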
#Split the data into training and test sets
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(data[:,:2],data[:,-1])
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))

2 Predict the label of the current sample as the mean of the labels of its K nearest neighbors

#Rewritten helper:
#return the mean label of the K nearest neighbors as the prediction for x
#K (the number of neighbors) is a global, assumed to have been set earlier
def calcu_distance_return(x,X_train,y_train):
    KNN_x=[]
    #Iterate over every sample in the training set
    for i in range(X_train.shape[0]):
        dist=euclidean(x,X_train[i])
        if len(KNN_x)<K:
            KNN_x.append((dist,y_train[i]))
        else:
            #Sort by distance and replace the farthest kept neighbor if this sample is closer
            KNN_x.sort()
            if dist<KNN_x[-1][0]:
                KNN_x[-1]=(dist,y_train[i])
    knn_label=[item[1] for item in KNN_x]
    return np.mean(knn_label)
#Predict for the whole test set
def predict(X_test):
    y_pred=np.zeros(X_test.shape[0])
    for i in range(X_test.shape[0]):
        y_hat_i=calcu_distance_return(X_test[i],X_train,y_train)
        y_pred[i]=y_hat_i
    return y_pred
#Output the predictions
y_pred= predict(X_test)
y_pred
array([-48.77391118, -61.82953142,  -7.08681066,  31.79119171,
        89.89605669,  49.28413251,  52.97713079,  33.48545677,
        63.32131747,  98.05154212, -55.78008004,  98.04210317,
         7.02443886, -19.02562562,  11.49285143, -13.67585848,
        52.97713079,  21.82629113,  10.45687568,  55.14568247,
        -9.552268  ,  94.91846026, -11.51277047,  22.35944142,
        86.13169115])
y_test
array([-41.53734685, -58.05744051,  -1.46024067,  40.57597798,
       103.01758072,  66.10127272,  46.66025056,  56.97482206,
        63.0249467 , 100.8255246 , -54.62086294,  91.25029492,
         3.3520749 , -23.59621905,   1.17352349, -20.40784363,
        46.35507328,  21.23129715,   5.09073378,  59.21784029,
         7.90719675,  98.23545178,  -1.68276177,  17.71925914,
        78.40425661])
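The per-sample loop above can also be written with NumPy broadcasting. The following is a minimal sketch, assuming Euclidean distance and an unweighted mean of the neighbor labels; `knn_regress` and the toy arrays are hypothetical names, not part of the original notebook.

```python
import numpy as np

def knn_regress(X_train, y_train, X_query, k=3):
    """Predict each query point as the mean label of its k nearest training points."""
    # Pairwise differences via broadcasting: shape (n_query, n_train, n_features).
    diff = X_query[:, None, :] - X_train[None, :, :]
    # Euclidean distances: shape (n_query, n_train).
    dist = np.sqrt((diff ** 2).sum(axis=2))
    # Indices of the k smallest distances in each row.
    nearest = np.argsort(dist, axis=1)[:, :k]
    return y_train[nearest].mean(axis=1)

# Toy data: four points on a line, labels equal to the x coordinate.
X_demo = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]])
y_demo = np.array([0.0, 1.0, 2.0, 10.0])

pred = knn_regress(X_demo, y_demo, np.array([[0.9, 0.0]]), k=3)
print(pred)  # [1.]
```

The query point (0.9, 0) is closest to the training points at x = 1, 0, and 2, so the prediction is the mean of their labels, 1.0. The full distance matrix takes memory proportional to the number of queries times the number of training points, which is fine for small data like this.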

3 Evaluate with the R² score

from sklearn.metrics import r2_score
r2_score(y_test,y_pred)
0.9634297760055799
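For reference, the R² score compares the squared prediction error with the variance of the targets: R² = 1 - SS_res / SS_tot, where SS_res = Σ(y_i - ŷ_i)² and SS_tot = Σ(y_i - ȳ)². The sketch below computes it by hand on a small made-up example; `r2_manual` and the demo arrays are illustrative names.

```python
import numpy as np

def r2_manual(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot

y_true_demo = np.array([1.0, 2.0, 3.0, 4.0])
y_pred_demo = np.array([1.0, 2.0, 3.0, 3.0])

r2_demo = r2_manual(y_true_demo, y_pred_demo)
print(r2_demo)  # 0.8
```

Here SS_res is 1 and SS_tot is 5, giving 0.8. A score of 1 means perfect predictions, and a score of 0 means the model does no better than always predicting the mean of the targets.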

Appendix: articles in this series

Experiment  Contents  Direct link
1 NumPy and visualization review https://want595.blog.csdn.net/article/details/131891689
2 Linear regression https://want595.blog.csdn.net/article/details/131892463
3 Logistic regression https://want595.blog.csdn.net/article/details/131912053
4 Multi-class classification in practice (based on logistic regression) https://want595.blog.csdn.net/article/details/131913690
5 Applied machine learning practice: manual hyperparameter tuning https://want595.blog.csdn.net/article/details/131934812
6 Bayesian inference https://want595.blog.csdn.net/article/details/131947040
7 KNN nearest-neighbor algorithm https://want595.blog.csdn.net/article/details/131947885
8 K-means unsupervised clustering https://want595.blog.csdn.net/article/details/131952371
9 Decision trees https://want595.blog.csdn.net/article/details/131991014
10 Random forests and ensemble learning https://want595.blog.csdn.net/article/details/132003451
11 Support vector machines https://want595.blog.csdn.net/article/details/132010861
12 Neural networks: the perceptron https://want595.blog.csdn.net/article/details/132014769
13 Regression and classification experiments with neural networks https://want595.blog.csdn.net/article/details/132127413
14 Convolutional neural network for handwriting recognition https://want595.blog.csdn.net/article/details/132223494
15 Applying LeNet5 to the CIFAR-10 dataset https://want595.blog.csdn.net/article/details/132223751
16 Convolution, downsampling, and classic convolutional networks https://want595.blog.csdn.net/article/details/132223985
