Python鸢尾花SVM分类模型代码

这篇具有很好参考价值的文章主要介绍了Python鸢尾花SVM分类模型代码。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一.前言

       机器学习的经典实验,对于数据集进行分类,网上看了一点其他的和GPT写的,好像只展示了4个特征中两个特征与3种类别的分类图,在我做这个实验交报告时,老师就问这个特征之间有很多交叉的点,在线性模型不应该得到分类准确度接近1的效果,后面改进加上另外两个特征的分类图可以发现,另外两个特征和类别有非常明显的线性关系,且分类的界限也非常清晰,所以模型分类准确度是合理的。下面主要是代码分享,给有这个学习需求或者课程实验的朋友们提供这个代码来学习或者参考。

二.实验要求

相当于我下面展示的代码的实现功能了

1.鸢尾花数据集准备与理解,并对数据集进行可视化分析;
2.随机划分数据集,80%样本作为训练数据,20%样本作为测试数据;
4.用训练数据分别训练以下2种SVM模型:
线性SVM模型
基于RBF的非线性SVM模型
5.分别对上述2种模型进行调优,性能指标为分类准确度﹔
6.测试上述2种模型在测试集上的分类性能。

三.说明

3.1.需导入的库

1.pandas

2.matplotblib

3.sklearn(安装时使用全称:scikit-learn)

全部都使用pip安装命令即可,简单快捷

win+R,输入cmd回车,pip命令【后面是镜像,不加速度非常慢】:

pip install 库名 -i https://pypi.doubanio.com/simple

3.2.数据集特征说明

这里使用的数据集是sklearn内置的数据集,不需要再下载数据集

load_iris()数据集包含了150个样本,每个样本有4个特征,分别是:

  • 花萼长度(sepal length)
  • 花萼宽度(sepal width)
  • 花瓣长度(petal length)
  • 花瓣宽度(petal width)

。这4个特征的单位均为厘米。这个数据集是鸢尾花数据集(Iris dataset)的一部分,共有3个不同的鸢尾花种类,每个种类有50个样本。3种鸢尾花的种类分别是:

  • 0:Setosa(山鸢尾)
  • 1:Versicolour(杂色鸢尾)
  • 2:Virginica(维吉尼亚鸢尾)

四.代码分享

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score


# 导入数据,分离特征与输出
iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
target = pd.Series(iris.target)
print('特征和目标的大小', iris_df.shape, target.shape)

# 画布属性设置
mpl.rcParams['font.family'] = ['sans-serif']
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False

# 散点图
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
scatter = plt.scatter(iris_df.iloc[:, 0], iris_df.iloc[:, 1], c=target, cmap='viridis')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('鸢尾花特征分析散点图')
a, b = scatter.legend_elements()
print('散点对象及其标签', a, b)
b = ['setosa', 'versicolor', 'virginica']  # 原标签0,1,2分别对应以下三种
plt.legend(a, b)

plt.subplot(1, 2, 2)
scatter = plt.scatter(iris_df.iloc[:, 2], iris_df.iloc[:, 3], c=target, cmap='viridis')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.title('鸢尾花特征分析散点图')
a, b = scatter.legend_elements()
print('散点对象及其标签', a, b)
b = ['setosa', 'versicolor', 'virginica']  # 原标签0,1,2分别对应以下三种
plt.legend(a, b)
plt.show()

# 划分训练和测试数据
X_train, X_test, y_train, y_test = train_test_split(iris_df, target, test_size=0.2, random_state=5)

# 线型SVM模型训练
linear_svm = SVC(kernel='linear', C=1, random_state=0)
linear_svm.fit(X_train, y_train)

# 基于RBF非线性SVM模型训练
rbf_svm = SVC(kernel='rbf', C=1, gamma=0.1, random_state=0)
rbf_svm.fit(X_train, y_train)

# 模型评估
linear_y_pred = linear_svm.predict(X_test)
linear_accuracy = accuracy_score(y_test, linear_y_pred)

rbf_y_pred = rbf_svm.predict(X_test)
rbf_accuracy = accuracy_score(y_test, rbf_y_pred)

print()
print('线性SVM模型模型评估:')
print('参数C:', linear_svm.C)
print('分类准确度accuracy: ', linear_accuracy)
print()
print('RBF非线性SVM模型评估:')
print('参数C:{:d},gamma:{:.1f}'.format(rbf_svm.C, rbf_svm.gamma))
print('分类准确度accuracy: ', rbf_accuracy)

# print('真实分类:', list(y_test))
# print('预测分类', list(linear_y_pred))


# 线型模型分类结果展示
plt.figure(figsize=(8, 8))
plt.subplot(2, 2, 1)
plt.scatter(X_test.iloc[:, 0], X_test.iloc[:, 1], c=y_test, cmap='viridis')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('真实散点图')
plt.subplot(2, 2, 2)
plt.scatter(X_test.iloc[:, 0], X_test.iloc[:, 1], c=linear_y_pred, cmap='viridis')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('预测散点图')
plt.subplot(2, 2, 3)
plt.scatter(X_test.iloc[:, 2], X_test.iloc[:, 3], c=y_test, cmap='viridis')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.title('真实散点图')
plt.subplot(2, 2, 4)
plt.scatter(X_test.iloc[:, 2], X_test.iloc[:, 3], c=linear_y_pred, cmap='viridis')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.title('预测散点图')
plt.suptitle('线性模型分类结果对照图')
plt.show()

# rbf非线型模型分类结果展示
plt.figure(figsize=(8, 8))
plt.subplot(2, 2, 1)
plt.scatter(X_test.iloc[:, 0], X_test.iloc[:, 1], c=y_test, cmap='viridis')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('真实散点图')
plt.subplot(2, 2, 2)
plt.scatter(X_test.iloc[:, 0], X_test.iloc[:, 1], c=rbf_y_pred, cmap='viridis')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('预测散点图')
plt.subplot(2, 2, 3)
plt.scatter(X_test.iloc[:, 2], X_test.iloc[:, 3], c=y_test, cmap='viridis')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.title('真实散点图')
plt.subplot(2, 2, 4)
plt.scatter(X_test.iloc[:, 2], X_test.iloc[:, 3], c=rbf_y_pred, cmap='viridis')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.title('预测散点图')
plt.suptitle('rbf非线性模型分类结果对照图')
plt.show()

# 参数调优
print()
print('模型调优')
linear_params = {'C': [0.1, 0.5, 1, 5, 10]}
linear_gridsearch = GridSearchCV(SVC(kernel='linear'), linear_params, scoring='accuracy', cv=5)
linear_gridsearch.fit(X_train, y_train)
print('线性SVM参数调优结果:', linear_gridsearch.best_params_)
linear_svm.C = linear_gridsearch.best_params_['C']

rbf_params = {'C': [0.1, 0.5, 1, 5, 10], 'gamma': [0.001, 0.01, 0.1, 1, 10]}
rbf_gridsearch = GridSearchCV(SVC(kernel='rbf'), rbf_params, cv=5)
rbf_gridsearch.fit(X_train, y_train)
print('RBF非线性SVM参数调优结果:', rbf_gridsearch.best_params_)
rbf_svm.C = rbf_gridsearch.best_params_['C']
rbf_svm.gamma = rbf_gridsearch.best_params_['gamma']

# 模型评估
linear_y_pred = linear_svm.predict(X_test)
linear_accuracy = accuracy_score(y_test, linear_y_pred)

rbf_y_pred = rbf_svm.predict(X_test)
rbf_accuracy = accuracy_score(y_test, rbf_y_pred)

print()
print('参数调优后的模型性能:')
print('C:', linear_svm.C)
print('线性SVM模型准分类确度: ', linear_accuracy)
print('C,gamma:', rbf_svm.C, rbf_svm.gamma)
print('RBF非线性SVM模型准分类确度:  ', rbf_accuracy)

五.结果展示

输出:

参数调优可以看到并没有增加分类准确度,调优在于使得模型训练能更快收敛。同时也可以看出非线性模型的分类效果比线性模型好一点

线性SVM模型模型评估:
参数C: 1
分类准确度accuracy:  0.9333333333333333

RBF非线性SVM模型评估:
参数C:1,gamma:0.1
分类准确度accuracy:  0.9666666666666667

模型调优
线性SVM参数调优结果: {'C': 0.5}
RBF非线性SVM参数调优结果: {'C': 5, 'gamma': 0.1}

参数调优后的模型性能:
C: 0.5
线性SVM模型准分类确度:  0.9333333333333333
C,gamma: 5 0.1
RBF非线性SVM模型准分类确度:   0.9666666666666667

Process finished with exit code 0
 

第一张图是4个特征与3种分类的散点图,一张图就描述两个特征,可以看到前两个特征来对花进行分类的话,ve,vi两类有很多交叉,分类是会产生误差的,但是后面两个特征与类别有明显线性关系,有清晰的分界线,除了少量点落在分界线附件可能产生错误分类,基本不会出现错误分类。

 用鸢尾花算法实现svm,机器学习,人工智能,python,分类

 下面二张图是线性分类模型和rbf非线性模型训练后的预测对照图,分类准确度都超过90%,其实是故意用随机数种子得到的,改一下种子可以得到100%,因为特征与分类有明显线性关系,所以如果对数据集划分训练集和测试集时,测试集刚刚好没有被划分到在交界处的点,则分类可达100,都是正确的。

用鸢尾花算法实现svm,机器学习,人工智能,python,分类

用鸢尾花算法实现svm,机器学习,人工智能,python,分类

 六.后言

        代码还有些地方待优化,但问题不大,懒得改了。嘿嘿文章来源地址https://www.toymoban.com/news/detail-810576.html

到了这里,关于Python鸢尾花SVM分类模型代码的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 初识机器学习——感知机(Perceptron)+ Python代码实现鸢尾花分类

      假设输入空间 χ ⊆ R n chisubseteq R^n χ ⊆ R n ,输出空间为 γ = { + 1 , − 1 } gamma=left { +1,-1right } γ = { + 1 , − 1 } 。其中每一个输入 x ⊆ χ xsubseteq chi x ⊆ χ 表示对应于实例的特征向量,也就是对应于输入空间(特征空间)的一个点, y ⊆ γ ysubseteq gamma y ⊆ γ 输出表

    2023年04月08日
    浏览(38)
  • 机器学习与深度学习——通过SVM线性支持向量机分类鸢尾花数据集iris求出错误率并可视化

    先来看一下什么叫数据近似线性可分,如下图所示,蓝色圆点和红色圆点分别代表正类和负类,显然我们不能找到一个线性的分离超平面将这两类完全正确的分开;但是如果将数据中的某些特异点(黑色箭头指向的点)去除之后,剩下的大部分样本点组成的集合是线性可分的,

    2023年04月18日
    浏览(39)
  • 机器学习-KNN算法(鸢尾花分类实战)

    前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 K近邻(K Nearest Neighbors,KNN)算法是最简单的分类算法之一,也就是根据现有训练数据判断输入样本是属于哪一个类别。 “近朱者赤近墨者黑\\\",所谓的K近邻,也就

    2023年04月08日
    浏览(63)
  • 机器学习:KNN算法对鸢尾花进行分类

    1.算法概述 KNN(K-NearestNeighbor)算法经常用来解决分类与回归问题, KNN算法的原理可以总结为\\\"近朱者赤近墨者黑\\\",通过数据之间的相似度进行分类。就是通过计算测试数据和已知数据之间的距离来进行分类。 如上图,四边形代表测试数据,原型表示已知数据,与测试数据最

    2024年02月09日
    浏览(39)
  • MapReduce实现KNN算法分类推测鸢尾花种类

    https://gitcode.net/m0_56745306/knn_classifier.git 该部分内容参考自:https://zhuanlan.zhihu.com/p/45453761 KNN(K-Nearest Neighbor) 算法是机器学习算法中最基础、最简单的算法之一。它既能用于分类,也能用于回归。KNN通过测量不同特征值之间的距离来进行分类。 KNN算法的思想非常简单:对于任

    2024年02月08日
    浏览(38)
  • 经典案例——利用 KNN算法 对鸢尾花进行分类

    实现流程:         1、获取数据集         2、数据基本处理         3、数据集预处理-数据标准化         4、机器学习(模型训练)         5、模型评估         6、模型预测 具体API: 1、获取数据集  查看各项属性  2、数据基本处理   3、数据集预处理

    2024年02月02日
    浏览(37)
  • 使用决策树对鸢尾花进行分类python

    鸢尾花数据集介绍 target介绍 1:绘制直方图 2.png)] 1:划分训练集和测试集 构建训练集和测试集,分别保存在X_train,y_train,X_test,y_test from sklearn.model_selection import train_test_split 2:训练和分类 from sklearn.tree import DecisionTreeClassifier DecisionTreeClassifier() DecisionTreeClassifier(criterion=‘entro

    2024年02月06日
    浏览(29)
  • 机器学习之线性回归与逻辑回归【完整房价预测和鸢尾花分类代码解释】

    目录 前言 一、什么是线性回归 二、什么是逻辑回归 三、基于Python 和 Scikit-learn 库实现线性回归 示例代码:  使用线性回归来预测房价: 四、基于Python 和 Scikit-learn 库实现逻辑回归 五、总结  线性回归的优缺点总结: 逻辑回归(Logistic Regression)是一种常用的分类算法,具有

    2024年04月13日
    浏览(34)
  • 【Python】使用Pandas和随机森林对鸢尾花数据集进行分类

    我在鼓楼的夜色中 为你唱花香自来 在别处 沉默相遇和期待 飞机飞过 车水马龙的城市 千里之外 不离开 把所有的春天 都揉进了一个清晨 把所有停不下的言语变成秘密 关上了门 莫名的情愫啊 请问 谁来将它带走呢 只好把岁月化成歌 留在山河                      🎵

    2024年04月26日
    浏览(26)
  • 机器学习---使用 TensorFlow 构建神经网络模型预测波士顿房价和鸢尾花数据集分类

    1. 预测波士顿房价 1.1 导包 最后一行设置了TensorFlow日志的详细程度: tf.logging.DEBUG :最详细的日志级别,用于记录调试信息。 tf.logging.INFO :用于记录一般的信息性消息,比如训练过程中的指标和进度。 tf.logging.WARN :用于记录警告消息,表示可能存在潜在问题,但不会导致

    2024年02月08日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包