MapReduce实现KNN算法分类推测鸢尾花种类

这篇具有很好参考价值的文章主要介绍了MapReduce实现KNN算法分类推测鸢尾花种类。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

代码地址

https://gitcode.net/m0_56745306/knn_classifier.git

一、KNN算法简介

该部分内容参考自:https://zhuanlan.zhihu.com/p/45453761

  • KNN(K-Nearest Neighbor) 算法是机器学习算法中最基础、最简单的算法之一。它既能用于分类,也能用于回归。KNN通过测量不同特征值之间的距离来进行分类。

  • KNN算法的思想非常简单:对于任意n维输入向量,分别对应于特征空间中的一个点,输出为该特征向量所对应的类别标签或预测值。

  • 对于一个需要预测的输入向量x,我们只需要在训练数据集中寻找k个与向量x最近的向量的集合,然后把x的类别预测为这k个样本中类别数最多的那一类。
    MapReduce实现KNN算法分类推测鸢尾花种类
    如图所示,ω1、ω2、ω3分别代表训练集中的三个类别。其中,与xu最相近的5个点(k=5)如图中箭头所指,很明显与其最相近的5个点中最多的类别为ω1,因此,KNN算法将xu的类别预测为ω1。

二、KNN算法示例:推测鸢尾花种类

鸢尾花数据集记载了三类花(Setosa,versicolor,virginica)以及它们的四种属性(花萼长度、花萼宽度、花瓣长度、花瓣宽度)。例如:

4.9,3.0,1.4,0.2,setosa
6.4,3.2,4.5,1.5,versicolor
6.0,2.2,5.0,1.5,virginica

对于给定的测试数据,我们需要根据它的四种信息判断其属于哪一种鸢尾花。并输出它的序号:
例如:

#假设该数据为第一条数据(对应序号为0)
5.7,3.0,4.2,1.2  

输出可以为:

0 setosa

三、MapReduce+Hadoop实现KNN鸢尾花分类:

1. 实现环境

  • Ubuntu20.04
  • Hadoop3.3.5
  • Java8
  • Maven3.9.1

2.pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>


    <groupId>org.example</groupId>
    <artifactId>KNN_Classifier</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>KNN_Classifier</name>
    <url>http://maven.apache.org</url>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.1.0</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>module-info.class</exclude>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                    <transformers>
                        <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                            <!-- main()所在的类,注意修改 -->
                            <mainClass>KNN_Classifier.KNN_Driver</mainClass>
                        </transformer>
                    </transformers>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>


    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <java.version>17</java.version>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>3.3.5</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-hdfs -->
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-hdfs</artifactId>
            <version>3.3.5</version>
        </dependency>

        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-mapreduce-client-core</artifactId>
            <version>3.3.5</version>
        </dependency>

    </dependencies>
</project>

3.设计思路及代码

1. KNN_Driver类

Diriver类主要负责初始化job的各项属性,同时将训练数据加载到缓存中去,以便于Mapper读取。同时为了记录测试数据量,在conf中设置testDataNum用于在map阶段记录。

package KNN_Classifier;


import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class KNN_Driver {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        GenericOptionsParser optionParser = new GenericOptionsParser(conf, args);
        String[] remainingArgs = optionParser.getRemainingArgs();
        if (remainingArgs.length != 3) {
            System.err.println("Usage: KNN_Classifier <training dataset> <test dataset> <output>");
            System.exit(2);
        }

        conf.setInt("K",5);//设置KNN算法的K值
        conf.setInt("testDataNum",0);//设置全局计数器,记录测试数据数目
        conf.setInt("dimension",4);//设置向量维度
        Job job = Job.getInstance(conf, "KNN_Classifier");
        job.setJarByClass(KNN_Driver.class);
        job.setMapperClass(KNN_Mapper.class);
        job.setReducerClass(KNN_Reducer.class);

        //将训练数据添加到CacheFile中
        job.addCacheFile(new Path(remainingArgs[0]).toUri());
        FileInputFormat.addInputPath(job, new Path(remainingArgs[1]));
        FileOutputFormat.setOutputPath(job, new Path(remainingArgs[2]));
        job.waitForCompletion(true);
        System.exit(0);
    }
}

2. MyData类

这个类对每条测试数据进行封装,同时用于计算向量距离。

package KNN_Classifier;

import java.util.Vector;

public class MyData {

    //向量维度
    private Integer dimension;
    //向量坐标
    private Vector<Double>vec = new Vector<Double>();
    //属性,这里是水仙花的种类
    private String attr = new String();

    public  void setAttr(String attr)
    {
        this.attr = attr;
    }

    public void setVec(Vector<Double> vec) {
        this.dimension = vec.size();
        for(Double d : vec)
        {
            this.vec.add(d);
        }
    }

    public double calDist(MyData data1)//计算两条数据之间的欧式距离
    {
        try{
            if(this.dimension != data1.dimension)
                throw new Exception("These two vectors have different dimensions.");

        }
        catch (Exception e)
        {
            System.out.println(e.getMessage());
            System.exit(-1);
        }
        double dist = 0;
        for(int i = 0;i<dimension;i++)
        {
            dist += Math.pow(this.vec.get(i)-data1.vec.get(i),2);
        }
        dist = Math.sqrt(dist);
        return dist;
    }

    public String getAttr() {
        return attr;
    }
}

3. KNN_Mapper类

  • setup:用于加载缓存中的训练数据到Mapper的列表当中,同时读取K值、维度等必要信息。

  • readTrainingData:由setup调用,加载缓存训练数据。

  • Gaussian:用于计算欧式距离x所占权重,它的公式为:
    f ( x ) = a e ( x − b ) 2 − 2 c 2 f(x) = ae^{\frac{(x-b)^2}{-2c^2}} f(x)=ae2c2(xb)2
    它的图像为:

MapReduce实现KNN算法分类推测鸢尾花种类

∣ x ∣ |x| x绝对值增加, f ( x ) f(x) f(x)的值越来越小,可以反映距离对权重的影响:即欧式距离越大,权重越小,对标签的影响也越小。

实际上高斯函数各个参数的确定需要对样本数据经过多次交叉验证得出,但为了简单起见,这里另a=1,b=0,c=0.9即可(这种情况下训练的结果比较好一些)。

  • map:对得到的测试数据进行KNN算法处理,它的伪代码如下:

    map(key,val): #key为样本数据偏移量,val为该行数据
    	testData = getTestData ; #从val中读取测试数据信息
    	K_Nearest = Empty ; #K最近邻,可以用最大堆来实现
    	for trainingData in trainingDataSet : #遍历可以改为用KDTree优化
    		dist = CalDist(testData,trainingData) ;
    		if sizeof(K_Nearest) < K : #如果此时还未达到K值,直接添加
    			K_Nearest.add(dist,trainingData.attr) ;
    		else :
    			if dist < K_Nearest.maxDist : #如果计算得出的距离大于当前K个点之中最大距离,则替换之
    				replace pair with maxDist to (dist,trainingData.attr) ; 
    	
    	calculate weight sum for every attr ; #为每种标签计算权重和
    	write(idx,max_weight_attr); #写入序号,最大权重标签,完成分类
    

综上,下面是KNN_Mapper的代码:

package KNN_Classifier;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.*;
import java.net.URI;
import java.io.BufferedReader;
import java.io.FileReader;
import javafx.util.Pair;

public class KNN_Mapper extends Mapper<LongWritable, Text, LongWritable, Text> {
    private Text text = new Text();//输出Val值

    private LongWritable longWritable = new LongWritable();//输出K值
    private Integer K;//K值

    private Configuration conf;//全局配置
    private Integer dimension;//维度
    private List<MyData> training_data = new ArrayList<>();
    
    
    private void readTrainingData(URI uri)//读取训练数据到training_data中
    {
        System.err.println("Read Training Data");
        try{
            Path patternsPath = new Path(uri.getPath());
            String patternsFileName = patternsPath.getName().toString();
            BufferedReader reader = new BufferedReader(new FileReader(
                    patternsFileName));
            String line;
            Vector<Double>vec = new Vector<>();
            while ((line = reader.readLine()) != null) {
                // TODO: your code here
                //
                String[] strings = line.split(",");

                for(int i=0;i<dimension;i++)
                {
                    vec.add(Double.valueOf(strings[i]));
                }
                MyData myData = new MyData();
                myData.setVec(vec);
                myData.setAttr(strings[dimension]);
                System.out.println(strings[dimension]);
                training_data.add(myData);
                vec.clear();
            }
            reader.close();
        }
        catch (FileNotFoundException e)
        {
            e.printStackTrace();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        System.err.println("Read End");
    }

    private double Gaussian(double dist)
    {
        //a = 1,b=0,c = 0.9,2*c^2 = 1.62
        double weight = Math.exp(-Math.pow(dist,2)/(1.62));
        return weight;
    }
    
    @Override
    public void setup(Context context) throws IOException,
            InterruptedException {

        conf = context.getConfiguration();
        this.K = conf.getInt("K",1);
        this.dimension = conf.getInt("dimension",1);
        URI[] uri = context.getCacheFiles();
        readTrainingData(uri[0]);
    }

    @Override
    public void map(LongWritable key, Text value, Context context
    ) throws IOException, InterruptedException {
        String line = value.toString();
        try {
            String[] strings = line.split(",");
            if (strings.length!=dimension) {
                throw new Exception("Error line format in the table.");
            }

            //获取测试数据信息
            Vector<Double>vec = new Vector<>();
            for(String s:strings)
            {
                System.err.println("S: "+s);
                vec.add(Double.valueOf(s));
            }
            MyData testData = new MyData();
            testData.setVec(vec);

            //计算与样本的K近邻

            //存放K近邻的优先级队列,元素类型为<距离,属性>
            PriorityQueue<Pair<Double,String>>K_nearst = new PriorityQueue<>((a,b)->(a.getKey()>b.getKey())?-1:1);
            double dist;
            for(MyData data : this.training_data)
            {
                dist = testData.calDist(data);
                if(K_nearst.size()<this.K)
                {
                    K_nearst.add(new Pair<>(dist,data.getAttr()));
                }
                else{
                    if(dist < K_nearst.peek().getKey())
                    {
                        K_nearst.poll();
                        K_nearst.add(new Pair<>(dist,data.getAttr()));
                    }
                }
            }

            //获取到K近邻后,通过高斯函数处理每条数据,并累加相同属性的权值,通过Hash_table实现
            Hashtable<String,Double>weightTable = new Hashtable<>();
            while(!K_nearst.isEmpty())
            {
                double d = K_nearst.peek().getKey();
                String attr = K_nearst.peek().getValue();
                double w = this.Gaussian(d);
                if(!weightTable.contains(attr))
                {
                    weightTable.put(attr,w);

                }
                else{
                    weightTable.put(attr,weightTable.get(attr)+w);
                }
                K_nearst.poll();
            }

            //选取权重最大的标签作为输出
            Double max_weight = Double.MIN_VALUE;
            String target_attr = "";

            for(Iterator<String> itr = weightTable.keySet().iterator();itr.hasNext();){
                String hash_key = (String)itr.next();
                Double hash_val = weightTable.get(hash_key);
                if(hash_val > max_weight)
                {
                    target_attr = hash_key;
                    max_weight = hash_val;
                }
            }

            text.set(target_attr);
            //获取测试数据条数,用作下标计数
            longWritable.set(conf.getLong("testDataNum",0));
            conf.setLong("testDataNum",longWritable.get()+1);//计数加一
            context.write(longWritable,text);
        }
        catch (Exception e) {
            System.err.println(e.toString());
            System.exit(-1);
        }
    }
}

4. KNN_Reducer类

由于Mapper类已经完成了所有工作,所以传入到Reducer中的键值对都是Index,Attr的形式,直接写入即可。文章来源地址https://www.toymoban.com/news/detail-479059.html

package KNN_Classifier;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import java.io.IOException;

public class KNN_Reducer extends Reducer<LongWritable, Text,LongWritable,Text> {
    public void reduce(LongWritable key, Iterable<Text> values,
                       Context context
    ) throws IOException, InterruptedException {

        for(Text val:values)
        {
            context.write(key,val);
        }
    }
}

到了这里,关于MapReduce实现KNN算法分类推测鸢尾花种类的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

    什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm)。它是一种简单但非常有效的分类和回归算法。 该算法的基本思想是:对于一个新的输入样本,通过计算它与训练集中所有样本的距离,找到与它距离最近的K个训练集样

    2024年02月03日
    浏览(41)
  • 【机器学习】KNN算法-鸢尾花种类预测

    K最近邻(K-Nearest Neighbors,KNN)算法是一种用于模式识别和分类的简单但强大的机器学习算法。它的工作原理非常直观:给定一个新数据点,KNN算法会查找离这个数据点最近的K个已知数据点,然后基于这K个最近邻数据点的类别来决定新数据点的类别。简而言之,KNN算法通过周

    2024年02月07日
    浏览(42)
  • 【机器学习算法】KNN鸢尾花种类预测案例和特征预处理。全md文档笔记(已分享,附代码)

    本系列文章md笔记(已分享)主要讨论机器学习算法相关知识。机器学习算法文章笔记以算法、案例为驱动的学习,伴随浅显易懂的数学知识,让大家掌握机器学习常见算法原理,应用Scikit-learn实现机器学习算法的应用,结合场景解决实际问题。包括K-近邻算法,线性回归,逻

    2024年02月19日
    浏览(47)
  • 【机器学习案例】不同的模型算法对鸢尾花数据集进行分类

    经典机器学习入门项目,使用逻辑回归、线性判别分析、KNN、分类与回归树、朴素贝叶斯、向量机、随机森林、梯度提升决策树对不同占比的训练集进行分类 数据源 :Iris Species | Kaggle 150行,5列,分三种鸢尾花类型,每种类型50个样本,每行数据包含花萼长度、花萼宽度、花

    2024年02月04日
    浏览(34)
  • 初识机器学习——感知机(Perceptron)+ Python代码实现鸢尾花分类

      假设输入空间 χ ⊆ R n chisubseteq R^n χ ⊆ R n ,输出空间为 γ = { + 1 , − 1 } gamma=left { +1,-1right } γ = { + 1 , − 1 } 。其中每一个输入 x ⊆ χ xsubseteq chi x ⊆ χ 表示对应于实例的特征向量,也就是对应于输入空间(特征空间)的一个点, y ⊆ γ ysubseteq gamma y ⊆ γ 输出表

    2023年04月08日
    浏览(50)
  • 【决策树-鸢尾花分类】

    决策树是一种基于树状结构的分类与回归算法。它通过对数据集进行递归分割,将样本划分为多个类别或者回归值。决策树算法的核心思想是通过构建树来对数据进行划分,从而实现对未知样本的预测。 决策树的构建过程包括以下步骤: 选择特征 :从数据集中选择一个最优

    2024年02月15日
    浏览(41)
  • K-Means算法实现鸢尾花数据集聚类

    张勇,男,西安工程大学电子信息学院,2022级研究生 研究方向:智能信息处理与信息系统研究 电子邮件:17605542959@163.com 陈梦丹,女,西安工程大学电子信息学院,2022级硕士研究生,张宏伟人工智能课题组 研究方向:机器视觉与人工智能 电子邮件:1169738496@qq.com K-Means聚类

    2024年02月07日
    浏览(43)
  • 【机器学习】鸢尾花分类-逻辑回归示例

    功能: 这段代码演示了如何使用逻辑回归对鸢尾花数据集进行训练,并将训练好的模型保存到文件中。然后,它允许用户输入新的鸢尾花特征数据,使用保存的模型进行预测,并输出预测结果。 步骤概述: 加载数据和预处理: 使用 Scikit-Learn 中的 datasets 模块加载鸢尾花数据

    2024年02月10日
    浏览(43)
  • Python鸢尾花SVM分类模型代码

           机器学习的经典实验,对于数据集进行分类,网上看了一点其他的和GPT写的,好像只展示了4个特征中两个特征与3种类别的分类图,在我做这个实验交报告时,老师就问这个特征之间有很多交叉的点,在线性模型不应该得到分类准确度接近1的效果,后面改进加上另外

    2024年01月21日
    浏览(61)
  • Python基于Pytorch Transformer实现对iris鸢尾花的分类预测,分别使用CPU和GPU训练

    iris数据集是机器学习中一个经典的数据集,由英国统计学家Ronald Fisher在1936年收集整理而成。该数据集包含了3种不同品种的鸢尾花(Iris Setosa,Iris Versicolour,Iris Virginica)各50个样本,每个样本包含了花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度

    2024年02月01日
    浏览(60)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包