R-多分类logistic回归(机器学习)

这篇具有很好参考价值的文章主要介绍了R-多分类logistic回归(机器学习)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

多分类logistic回归

在之前文章介绍了,如何在R里面处理多分类的回归模型,得到的是各个因素的系数及相对OR,但是解释性,比二元logistic回归方程要冗杂的多。

那么今天继续前面的基础上,用机器学习的方法来解释多分类问题。
其实最终回归到这类分类问题的本质:有了一系列的影响因素x,那么根据这些影响因素来判断最终y属于哪一类别。

R-多分类logistic回归(机器学习),机器学习,机器学习,r语言,分类

image.png

1.数据案例

这里主要用到DALEX包里面包含的HR数据,里面记录了职工在工作岗位的状态与年龄,性别,工作时长,评价及薪水有关。根据7847条记录来评估,如果一个职工属于男性,68岁,薪水及评价处于3等级,那么该职工可能会处于什么状态。

library(DALEX)
library(iBreakDown)
library(car)
library(questionr)
try(data(package="DALEX"))
data(HR)

# split
set.seed(543)
ind = sample(2,nrow(HR),replace=TRUE,prob=c(0.9,0.1))
trainData = HR[ind==1,]
testData = HR[ind==2,]

# randforest
m_rf = randomForest(status ~ . , data = trainData)

2.随机森林模型

我们根据上述数据,分成训练集与测试集(Train and Test)测试集用来估计随机森林模型的效果。

2.1模型评估

通过对Train数据构建rf模型后,我们对Train数据进行拟合,看一下模型的效果,Accuracy : 0.9357 显示很好,kappa一致性为90%。
那再用该fit去预测test数据, Accuracy : 0.7166 , Kappa : 56% ,显示效果不怎么理想。

# Prediction and Confusion Matrix - Training data 
pred1 <- predict(m_rf, trainData)
head(pred1)
confusionMatrix(pred1, trainData$status)  #

pred2 <- predict(m_rf, testData)
head(pred2)
confusionMatrix(pred2, testData$status)  #

> confusionMatrix(pred1, trainData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired   ok promoted
  fired     2478  194       49
  ok          43 1738       80
  promoted    25   64     2375

Overall Statistics
                                          
               Accuracy : 0.9354          
                 95% CI : (0.9294, 0.9411)
    No Information Rate : 0.3613          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.9024          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.9733    0.8707          0.9485
Specificity                0.9460    0.9756          0.9804
Pos Pred Value             0.9107    0.9339          0.9639
Neg Pred Value             0.9843    0.9502          0.9718
Prevalence                 0.3613    0.2833          0.3554
Detection Rate             0.3517    0.2467          0.3371
Detection Prevalence       0.3862    0.2641          0.3497
Balanced Accuracy          0.9596    0.9232          0.9644
> 
> pred2 <- predict(m_rf, testData)
> head(pred2)
    1    20    36    42    49    56 
fired fired fired fired fired    ok 
Levels: fired ok promoted
> confusionMatrix(pred2, testData$status)  #
Confusion Matrix and Statistics

          Reference
Prediction fired  ok promoted
  fired      246  62       19
  ok          37 117       37
  promoted    26  46      211

Overall Statistics
                                         
               Accuracy : 0.7166         
                 95% CI : (0.684, 0.7476)
    No Information Rate : 0.3858         
    P-Value [Acc > NIR] : < 2e-16        
                                         
                  Kappa : 0.5692         
                                         
 Mcnemar's Test P-Value : 0.03881        

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.7961    0.5200          0.7903
Specificity                0.8354    0.8715          0.8652
Pos Pred Value             0.7523    0.6126          0.7456
Neg Pred Value             0.8671    0.8230          0.8919
Prevalence                 0.3858    0.2809          0.3333
Detection Rate             0.3071    0.1461          0.2634
Detection Prevalence       0.4082    0.2385          0.3533
Balanced Accuracy          0.8157    0.6958          0.8277

2.2变量重要性

我们看到,对影响因素进行重要性排序,等同于P值。在预测时候,哪些因素对y占影响比重较大。这里的variable_importance(),可以有好几种方式对变量进行衡量,这里采用默认的MeanDecreaseGini.

# vip
vip(m_rf)
var=randomForest::importance(m_rf)
var

image.png

2.2边际效应

我们知道了hours,age比较重要,那么是如何重要的,譬如年龄在什么阶段,会导致升职或者开除。
当工作小时在45以内,被开除/离职的概率较大,当工作时常超过60以后,很有可能会被提升。得到升职加薪的机会。
当然了,也可以绘制2D的边际效应,两个因素相互作用的Partial plot。

# partial plot
partialPlot(m_rf, HR, age)
head(partial(m_rf, pred.var = "age"))  # returns a data frame

# for all varibles
nm=rownames(var)
# Get partial depedence values for top predictors
pd_df <- partial_dependence(fit = m_rf,
                            vars = nm,
                            data = df_rf,
                            n = c(100, 200))
                        
# Plot partial dependence using edarf
plot_pd(pd_df)

image.png

image.png

2.3个体预测

现在假如有一个员工的信息如下,

      gender      age    hours evaluation salary   status
10000 female 57.96254 54.78624          4      4 promoted

去预测该职工最后的状态:
该预测结果显示,这个职工,有97%的可能性要升职加薪。而他的实际状态也是Promoted。

new_observation=tail(HR,1)
p_fun <- function(object, newdata){predict(object, newdata = newdata, type = "prob")}
bd_rf <- local_attributions(m_rf,
                            data = HR_test,
                            new_observation =  new_observation,
                            predict_function = p_fun)

bd_rf
plot(bd_rf)

image.png文章来源地址https://www.toymoban.com/news/detail-786986.html

> sessionInfo()
R version 3.6.2 (2019-12-12)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  utils     datasets  grDevices methods   base     

other attached packages:
 [1] edarf_1.1.1         ranger_0.12.1       questionr_0.7.0     car_3.0-7          
 [5] carData_3.0-3       nnet_7.3-14         DALEX_1.2.1         vip_0.2.2          
 [9] ggpubr_0.3.0        rstatix_0.5.0       caret_6.0-86        lattice_0.20-41    
[13] pdp_0.7.0           randomForest_4.6-14 iBreakDown_1.2.0    hrbrthemes_0.8.0   
[17] reshape2_1.4.4      RColorBrewer_1.1-2  forcats_0.5.0       stringr_1.4.0      
[21] dplyr_0.8.5         purrr_0.3.4         readr_1.3.1         tidyr_1.0.3        
[25] tibble_3.0.1        ggplot2_3.3.0       tidyverse_1.3.0    

参考

  1. iBreakDown plots for classification models
  2. prediction 预测结果输出为概率
  3. pdp 边际效应
  4. Partial dependence (PD) plots For Random Forests
  5. Explaining Black-Box Machine Learning Models
  6. Interpretable Machine Learning

到了这里,关于R-多分类logistic回归(机器学习)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 机器学习11:逻辑回归-Logistic Regression

    目录 1.计算概率 2.损失和正则化 2.1 逻辑回归的损失函数 2.2 逻辑回归中的正则化 3.参考文献

    2024年02月11日
    浏览(62)
  • 机器学习:基于逻辑回归(Logistic Regression)对股票客户流失预测分析

    作者:i阿极 作者简介:Python领域新星作者、多项比赛获奖者:博主个人首页 😊😊😊如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦!👍👍👍 📜📜📜如果有小伙伴需要数据集和学习交流,文章下方有交流学习区!一起学习进步!💪 专栏案例:

    2023年04月26日
    浏览(52)
  • 吴恩达老师《机器学习》课后习题2之逻辑回归(logistic_regression)

    用于解决输出标签y为0或1的二元分类问题 判断邮件是否属于垃圾邮件? 银行卡交易是否属于诈骗? 肿瘤是否为良性? 等等。 案例:根据学生的两门学生成绩,建立一个逻辑回归模型,预测该学生是否会被大学录取 数据集:ex2data1.txt python实现逻辑回归, 目标:建立分类器(求

    2024年02月09日
    浏览(47)
  • 机器学习:分类、回归、决策树

            如:去银行借钱,会有借或者不借的两种类别         如:去银行借钱,预测银行会借给我多少钱,如:1~100000之间的一个数值         为了要将表格转化为一棵树,决策树需要找出最佳节点和最佳的分枝方法,对分类树来说,衡量这个 “ 最佳 ” 的指标 叫

    2024年02月02日
    浏览(49)
  • 【机器学习】逻辑回归(二元分类)

    离散感知器:输出的预测值仅为 0 或 1 连续感知器(逻辑分类器):输出的预测值可以是 0 到 1 的任何数字,标签为 0 的点输出接近于 0 的数,标签为 1 的点输出接近于 1 的数 逻辑回归算法(logistics regression algorithm):用于训练逻辑分类器的算法 sigmoid 函数: g ( z ) = 1 1 +

    2024年02月21日
    浏览(51)
  • 机器学习常识 3: 分类、回归、聚类

    摘要 : 本贴描述分类、回归、聚类问题的基本概念. 机器学习常识 2: 数据类型从输入数据的角度来进行讨论, 这里从输出数据, 或者目标的角度来讨论. 分类 是指将一个样本预测为给定类别之一. 也称为该样本打标签. 例 1: 如果我去向那个女生表白, 她会同意吗? (Y/N) 由于可能的

    2024年02月06日
    浏览(39)
  • R语言-logistic回归

    #logistic- caret::train #划分数据集 set.seed(123) folds - createFolds(y=data$Groups,k=10)

    2024年02月16日
    浏览(40)
  • 【机器学习】鸢尾花分类-逻辑回归示例

    功能: 这段代码演示了如何使用逻辑回归对鸢尾花数据集进行训练,并将训练好的模型保存到文件中。然后,它允许用户输入新的鸢尾花特征数据,使用保存的模型进行预测,并输出预测结果。 步骤概述: 加载数据和预处理: 使用 Scikit-Learn 中的 datasets 模块加载鸢尾花数据

    2024年02月10日
    浏览(42)
  • 机器学习:什么是分类/回归/聚类/降维/决策

    目录 学习模式分为三大类:监督,无监督,强化学习 监督学习基本问题 分类问题 回归问题 无监督学习基本问题 聚类问题 降维问题 强化学习基本问题 决策问题 如何选择合适的算法 我们将涵盖目前「五大」最常见机器学习任务: 回归 分类 聚类 降维 决策 分类是监督学习

    2024年02月12日
    浏览(46)
  • 机器学习算法(一): 基于逻辑回归的分类预测

    逻辑回归的介绍 逻辑回归(Logistic regression,简称LR)虽然其中带有\\\"回归\\\"两个字,但逻辑回归其实是一个 分类 模型,并且广泛应用于各个领域之中。虽然现在深度学习相对于这些传统方法更为火热,但实则这些传统方法由于其独特的优势依然广泛应用于各个领域中。 而对于

    2024年01月15日
    浏览(49)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包