基于weka手工实现ID3决策树

这篇具有很好参考价值的文章主要介绍了基于weka手工实现ID3决策树。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、决策树ID3算法

相比于logistic回归、BP网络、支持向量机等基于超平面的方法,决策树更像一种算法,里面的数学原理并不是很多,较好理解。

决策树就是一个不断地属性选择、属性划分地过程,直到满足某一情况就停止划分。

  1. 当前样本全部属于同一类别了(信息增益为0);
  2. 已经是空叶子了(没有样本了);
  3. 当前叶子节点所有样本所有属性上取值相同,无法划分了(信息增益为0)。

信息增益如何计算?根据信息熵地变化量,信息熵减少最大地属性就是我们要选择地属性。

信息熵定义:

E n t ( D ) = − ∑ k = 1 ∣ y ∣ p k l o g 2 p k Ent(D)=-\sum_{k=1}^{|y|}p_klog_2p_k Ent(D)=k=1ypklog2pk

信息增益定义:

G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 v ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D,a)=Ent(D)-\sum_{v=1}^v\frac{|D^v|}{|D|}Ent(D^v) Gain(D,a)=Ent(D)v=1vDDvEnt(Dv)

信息增益越大,则意味着属性a来划分所获得的“纯度提升”越大。

ID3就是以信息增益作为属性选择和划分的标准的。有了决策树生长和停止生长的条件,剩下的其实就是一些编程技巧了,我们就可以进行编码了。

除此之外,决策树还有C4.5等其它实现的算法,包括基尼系数、增益率、剪枝、预剪枝等防止过拟合的方法,但决策树最本质、朴素的思想还是在ID3中体现的最好。

具体可以参考这篇博客:机器学习06:决策树学习.文章来源地址https://www.toymoban.com/news/detail-629751.html

二、基于weka平台实现ID3决策树

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.*;

/**
 * @author YFMan
 * @Description 自定义的 ID3 分类器
 * @Date 2023/5/25 18:07
 */
public class myId3 extends Classifier {

    // 当前节点 的 后续节点
    private myId3[] m_Successors;

    // 当前节点的划分属性 (如果为空,说明当前节点是叶子节点;否则,说明当前节点是中间节点)
    private Attribute m_Attribute;

    // 当前节点的类别分布 (如果为中间节点,全为 0;为叶子节点,为类别分布)
    private double[] m_Distribution;

    // 当前节点的类别 (如果为中间节点,为 0;为叶子节点,为类别分布)
    // (用于获取类别的索引,对于算法本身没用,但对于可视化 决策树有用)
    private double m_ClassValue;

    // 当前节点的类别属性 (如果为中间节点,为 null;为叶子节点,为类别属性)
    // (用于获取类别的名称,对于算法本身没用,但对于可视化 决策树有用)
    private Attribute m_ClassAttribute;

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 决策树
     * @Date 2023/5/25 18:43
     * @Param [data]
     * @return void
     **/
    public void buildClassifier(Instances data) throws Exception {
        // 建树
        makeTree(data);
    }

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 决策树
     * @Date 2023/5/25 18:43
     * @Param [data] 训练数据
     * @return void
     **/
    private void makeTree(Instances data) throws Exception {

        // 如果是空叶子,拒绝建树 (拒判)
        if (data.numInstances() == 0) {
            m_Attribute = null;
            m_ClassValue = Instance.missingValue();
            m_Distribution = new double[data.numClasses()];
            return;
        }

        // 计算 所有属性的 信息增益
        double[] infoGains = new double[data.numAttributes()];
        // 遍历所有属性
        for(int i = 0; i < data.numAttributes(); i++) {
            // 如果是类别属性,跳过
            if (i == data.classIndex()) {
                infoGains[i] = 0;
            } else {
                // 计算信息增益
                infoGains[i] = computeInfoGain(data, data.attribute(i));
            }
        }

        // 选择信息增益最大的属性
        m_Attribute = data.attribute(Utils.maxIndex(infoGains));

        // 如果信息增益为 0,说明当前节点包含的样例都属于同一类别,直接设置为叶子节点
        if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
            // 设置为叶子节点
            m_Attribute = null;
            m_Distribution = new double[data.numClasses()];
            // 遍历所有样例
            for (int i = 0; i < data.numInstances(); i++) {
                // 获取当前样例的类别
                Instance inst = data.instance(i);
                // 统计类别分布
                m_Distribution[(int) inst.classValue()]++;
            }
            // 归一化
            Utils.normalize(m_Distribution);
            // 设置类别
            m_ClassValue = Utils.maxIndex(m_Distribution);
            m_ClassAttribute = data.classAttribute();
        } else { // 否则,递归建树
            // 划分数据集
            Instances[] splitData = splitData(data, m_Attribute);
            // 创建叶子
            m_Successors = new myId3[m_Attribute.numValues()];
            // 叶子再去长叶子,递归调用
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                m_Successors[j] = new myId3();
                m_Successors[j].makeTree(splitData[j]);
            }
        }


    }

    /*
     * @Author YFMan
     * @Description 根据 instance 进行分类
     * @Date 2023/5/25 18:33
     * @Param [instance] 待分类的实例
     * @return double[] 类别分布
     **/
    public double[] distributionForInstance(Instance instance)
            throws NoSupportForMissingValuesException {
        // 如果到达叶子节点,返回类别分布
        if (m_Attribute == null) {
            // 如果 m_Distribution 全为 0(是空叶子),随机返回一个类别分布
            if (Utils.eq(Utils.sum(m_Distribution), 0)) {
                // 在 0~类别数-1 之间随机选择一个类别
                m_Distribution = new double[m_ClassAttribute.numValues()];
                m_Distribution[(int) Math.round(Math.random() * m_ClassAttribute.numValues())] = 1.0;
            }
            return m_Distribution;
        } else {
            // 否则,递归调用
            return m_Successors[(int) instance.value(m_Attribute)].
                    distributionForInstance(instance);
        }
    }

    /*
     * @Author YFMan
     * @Description 计算当前数据集 选择某个属性的 信息增益
     * @Date 2023/5/25 18:29
     * @Param [data, att] 当前数据集,选择的属性
     * @return double 信息增益
     **/
    private double computeInfoGain(Instances data, Attribute att)
            throws Exception {
        // 计算 data 的信息熵
        double infoGain = computeEntropy(data);
        // 计算 data 按照 att 属性进行划分的信息熵
        // 划分数据集
        Instances[] splitData = splitData(data, att);
        // 遍历划分后的数据集
        for (Instances instances : splitData) {
            // 计算概率
            double probability = (double) instances.numInstances() / data.numInstances();
            // 计算信息熵
            infoGain -= probability * computeEntropy(instances);
        }
        // 返回信息增益
        return infoGain;
    }

    /*
     * @Author YFMan
     * @Description 计算信息熵
     * @Date 2023/5/25 18:18
     * @Param [data] 计算的数据集
     * @return double 信息熵
     **/
    private double computeEntropy(Instances data) throws Exception {
        // 计不同类别的数量
        double[] classCounts = new double[data.numClasses()];
        // 遍历数据集
        for(int i=0;i<data.numInstances();i++){
            // 获取类别
            int classIndex = (int) data.instance(i).classValue();
            // 数量加一
            classCounts[classIndex]++;
        }
        // 计算信息熵
        double entropy = 0;
        // 遍历类别
        for (double classCount : classCounts) {
            // 注意:这里是大于 0,因为 log2(0) = -Infinity;
            // 如果是等于 0,那么计算结果就是 NaN,熵就出错了
            if(classCount > 0){
                // 计算概率
                double probability = classCount / data.numInstances();
                // 计算信息熵
                entropy -= probability * Utils.log2(probability);
            }
        }
        // 返回信息熵
        return entropy;
    }

    /*
     * @Author YFMan
     * @Description 根据属性划分数据集
     * @Date 2023/5/25 18:23
     * @Param [data, att] 数据集,属性
     * @return weka.core.Instances[] 划分后的数据集
     **/
    private Instances[] splitData(Instances data, Attribute att) {
        // 定义划分后的数据集
        Instances[] splitData = new Instances[att.numValues()];
        // 遍历划分后的数据集
        for(int i=0;i<splitData.length;i++){
            // 创建数据集 (这里主要是为了初始化 数据集 header)
            // Constructor copying all instances and references to the header
            // information from the given set of instances.
            splitData[i] = new Instances(data,0);
        }
        // 遍历数据集
        for(int i=0;i<data.numInstances();i++){
            // 获取实例
            Instance instance = data.instance(i);
            // 获取实例的属性值
            double value = instance.value(att);
            // 将实例添加到对应的数据集中
            splitData[(int) value].add(instance);
        }
        // 返回划分后的数据集
        return splitData;
    }

    private String toString(int level) {

        StringBuffer text = new StringBuffer();

        if (m_Attribute == null) {
            if (Instance.isMissingValue(m_ClassValue)) {
                text.append(": null");
            } else {
                text.append(": " + m_ClassAttribute.value((int) m_ClassValue));
            }
        } else {
            for (int j = 0; j < m_Attribute.numValues(); j++) {
                text.append("\n");
                for (int i = 0; i < level; i++) {
                    text.append("|  ");
                }
                text.append(m_Attribute.name() + " = " + m_Attribute.value(j));
                text.append(m_Successors[j].toString(level + 1));
            }
        }
        return text.toString();
    }

    public String toString() {

        if ((m_Distribution == null) && (m_Successors == null)) {
            return "Id3: No model built yet.";
        }
        return "Id3\n\n" + toString(0);
    }

    /**
     * Main method.
     *
     * @param args the options for the classifier
     */
    public static void main(String[] args) {
        runClassifier(new myId3(), args);
    }
}

到了这里,关于基于weka手工实现ID3决策树的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【数据挖掘大作业】基于决策树的评教指标筛选(weka+数据+报告+操作步骤)

    数据挖掘大作业 下载链接:【数据挖掘大作业】基于决策树的评教指标筛选(weka使用手册+数据+实验报告) 一、考核内容 现有某高校评教数据(pjsj.xls),共计842门课程,属性包括:课程名称、评价人数、总平均分以及10个评价指标Index1-Index10。指标内容详见表1。 表 1 学生评教

    2024年02月09日
    浏览(43)
  • 决策树之ID3的matlab实现

    森林内的两条分叉路,我选择了人迹罕见的一条,从此一切变得不一样。 ------佛洛斯特Robert Frost 目录 一 .决策树介绍 1.1 相关概念 1.2 图形表示 1.3 规则表示 二.决策树的信息计算 三.ID3相关介绍 3.1 ID3算法概述 3.2 算法流程 四.matlab实现

    2024年02月11日
    浏览(60)
  • ID3决策树及Python实现(详细)

    目录 一、划分特征的评价指标: 二、决策树学习算法伪代码: 三、决策树生成实例: 四、Python实现ID3决策树: 1、信息熵 Ent(D): 信息熵,是度量样本集合纯度的一种指标,Ent(D)的值越小,则样本集D的纯度越高; 2、信息增益 Gain(D,a): 信息增益越大,则意味着使用属性a来

    2024年02月09日
    浏览(41)
  • 基于weka手工实现KNN

    K最近邻(K-Nearest Neighbors,简称KNN)算法是一种常用的基于实例的监督学习算法。它可以用于分类和回归问题,并且是一种非常直观和简单的机器学习算法。 KNN算法的基本思想是:对于一个新的样本数据,在训练数据集中找到与其最接近的K个邻居,然后根据这K个邻居的标签

    2024年02月13日
    浏览(49)
  • 在西瓜数据集上用Python实现ID3决策树算法完整代码

    在西瓜数据集上用Python实现ID3决策树算法完整代码 1、决策树算法代码ID3.py 2、可视化决策树代码visual_decision_tree.py 3、贴几张运行结果图 1、生成的可视化决策树 2、代码运行结果 输出每次划分的每个属性特征的信息增益以及最后的决策树 3、记事本上手动跑程序的草图

    2024年02月08日
    浏览(47)
  • 基于weka平台手工实现朴素贝叶斯分类

    B事件发生后,A事件发生的概率可以如下表示: p ( A ∣ B ) = p ( A ∩ B ) P ( B ) (1) p(A|B)=frac{p(Acap B)}{P(B)}tag{1} p ( A ∣ B ) = P ( B ) p ( A ∩ B ) ​ ( 1 ) A事件发生后,B事件发生的概率可以如下表示: p ( B ∣ A ) = p ( A ∩ B ) P ( A ) (2) p(B|A)=frac{p(Acap B)}{P(A)}tag{2} p ( B ∣ A ) = P ( A ) p

    2024年02月13日
    浏览(46)
  • 基于weka手工实现多层感知机(BPNet)

    单层感知机,就是只有一层神经元,它的模型结构如下 1 : 对于权重 w w w 的更新,我们采用如下公式: w i = w i + Δ w i Δ w i = η ( y − y ^ ) x i (1) w_i=w_i+Delta w_i \\\\ Delta w_i=eta(y-hat{y})x_itag{1} w i ​ = w i ​ + Δ w i ​ Δ w i ​ = η ( y − y ^ ​ ) x i ​ ( 1 ) 其中, y y y 为标签, y

    2024年02月17日
    浏览(50)
  • 基于weka手工实现K-means

    K均值聚类(K-means clustering)是一种常见的无监督学习算法,用于将数据集中的样本划分为K个不同的类别或簇。它通过最小化样本点与所属簇中心点之间的距离来确定最佳的簇划分。 K均值聚类的基本思想如下: 随机选择K个初始聚类中心(质心)。 对于每个样本,计算其与各

    2024年02月13日
    浏览(36)
  • 基于weka手工实现逻辑斯谛回归(Logistic回归)

    逻辑斯谛回归模型其实是一种分类模型,这里实现的是参考李航的《统计机器学习》以及周志华的《机器学习》两本教材来整理实现的。 假定我们的输入为 x x x , x x x 可以是多个维度的,我们想要根据 x x x 去预测 y y y , y ∈ { 0 , 1 } yin {0,1} y ∈ { 0 , 1 } 。逻辑斯谛的模型

    2024年02月15日
    浏览(45)
  • 基于weka平台手工实现(LinearRegression | Ridge Regression,岭回归)

    线性回归主要采用最小二乘法来实现,主要思想如下: X = ( x 11 x 12 ⋯ x 1 d 1 x 21 x 22 ⋯ 5 1 ⋮ ⋮ ⋱ ⋮ ⋮ x m 1 x m 2 ⋯ x m d 1 ) X=left( begin{matrix} x_{11} x_{12} cdots x_{1d} 1 \\\\ x_{21} x_{22} cdots 5 1 \\\\ vdots vdots ddots vdots vdots \\\\ x_{m1} x_{m2} cdots x_{md} 1 \\\\ end{matrix} right) X = ​ x 11 ​ x

    2024年02月12日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包