Java语言在Spark3.2.4集群中使用Spark MLlib库完成XGboost算法

这篇具有很好参考价值的文章主要介绍了Java语言在Spark3.2.4集群中使用Spark MLlib库完成XGboost算法。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、概述

XGBoost是一种基于决策树的集成学习算法,它在处理结构化数据方面表现优异。相比其他算法,XGBoost能够处理大量特征和样本,并且支持通过正则化控制模型的复杂度。XGBoost也可以自动进行特征选择并对缺失值进行处理。

二、代码实现步骤

1、导入相关库

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor};
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SparkSession;

2、加载数据

SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate();
DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv");

3、准备特征向量

String[] featureCols = data.columns();
featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1);
VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features");
DataFrame inputData = assembler.transform(data).select("features", "output");
inputData.show(false);

4、划分训练集和测试集

double[] weights = {0.7, 0.3};
DataFrame[] splitData = inputData.randomSplit(weights);
DataFrame train = splitData[0];
DataFrame test = splitData[1];

5、定义XGBoost模型

GBTRegressor gbt = new GBTRegressor()
    .setLabelCol("output")
    .setFeaturesCol("features")
    .setMaxIter(100)
    .setStepSize(0.1)
    .setMaxDepth(6)
    .setLossType("squared")
    .setFeatureSubsetStrategy("auto");

6、构建管道

Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt});

7、训练模型

GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0];

8、进行预测并评估模型

DataFrame predictions = model.transform(test);
predictions.show(false);

RegressionEvaluator evaluator = new RegressionEvaluator()
    .setMetricName("rmse")
    .setLabelCol("output")
    .setPredictionCol("prediction");

double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

以上就是Java语言中基于SparkML的XGBoost算法实现的示例代码。需要注意的是,这里使用了GBTRegressor作为XGBoost的实现方式,但是也可以使用其他实现方式,例如XGBoostRegressor或者XGBoostClassification。

三、完整代码

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor};
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;

public class XGBoostExample {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate();

        // 加载数据
        DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv");
        data.printSchema();
        data.show(false);

        // 准备特征向量
        String[] featureCols = data.columns();
        featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1);
        VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features");
        DataFrame inputData = assembler.transform(data).select("features", "output");
        inputData.show(false);

        // 划分训练集和测试集
        double[] weights = {0.7, 0.3};
        DataFrame[] splitData = inputData.randomSplit(weights);
        DataFrame train = splitData[0];
        DataFrame test = splitData[1];

        // 定义XGBoost模型
        GBTRegressor gbt = new GBTRegressor()
                .setLabelCol("output")
                .setFeaturesCol("features")
                .setMaxIter(100)
                .setStepSize(0.1)
                .setMaxDepth(6)
                .setLossType("squared")
                .setFeatureSubsetStrategy("auto");

        // 构建管道
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt});

        // 训练模型
        GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0];

        // 进行预测并评估模型
        DataFrame predictions = model.transform(test);
        predictions.show(false);

        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("output")
                .setPredictionCol("prediction");

        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

        spark.stop();
    }
}

在运行代码之前需要将数据文件data.csv放置到程序所在目录下,以便加载数据。另外,需要将代码中的相关路径和参数按照实际情况进行修改。 文章来源地址https://www.toymoban.com/news/detail-411531.html

到了这里,关于Java语言在Spark3.2.4集群中使用Spark MLlib库完成XGboost算法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Spark MLlib ----- ALS算法

    在谈ALS(Alternating Least Squares)之前首先来谈谈LS,即最小二乘法。LS算法是ALS的基础,是一种数优化技术,也是一种常用的机器学习算法,他通过最小化误差平方和寻找数据的最佳匹配,利用最小二乘法寻找最优的未知数据,保证求的数据与已知的数据误差最小。LS也被用于拟

    2024年02月02日
    浏览(43)
  • [机器学习、Spark]Spark MLlib实现数据基本统计

    👨‍🎓👨‍🎓博主:发量不足 📑📑本期更新内容: Spark MLlib基本统计 📑📑下篇文章预告:Spark MLlib的分类🔥🔥 简介:耐心,自信来源于你强大的思想和知识基础!!   目录 Spark MLlib基本统计 一.摘要统计 二.相关统计 三.分层抽样   MLlib提供了很多统计方法,包含

    2024年02月02日
    浏览(48)
  • Spark3 新特性之AQE

    一、 背景 Spark 2.x 在遇到有数据倾斜的任务时,需要人为地去优化任务,比较费时费力;如果任务在Reduce阶段,Reduce Task 数据分布参差不齐,会造成各个excutor节点资源利用率不均衡,影响任务的执行效率;Spark 3新特性AQE极大地优化了以上任务的执行效率。 二、 Spark 为什么需

    2024年02月14日
    浏览(35)
  • spark3.3.0安装&部署过程

    为了防止不必要的报错,部署之前请务必从开头开始看,切勿跳过其中一个部署模式,因为每一个部署模式都是从上一个模式的配置上进行的 下载地址:https://archive.apache.org/dist/spark/ 本文所下载版本为: spark-3.3.0-bin-hadoop2 环境: hadoop-2.7.5 jdk1.8.0 Scala 所谓的Local模式,就是不需

    2023年04月20日
    浏览(81)
  • Spark编程实验六:Spark机器学习库MLlib编程

    目录 一、目的与要求 二、实验内容 三、实验步骤 1、数据导入 2、进行主成分分析(PCA) 3、训练分类模型并预测居民收入  4、超参数调优 四、结果分析与实验体会 1、通过实验掌握基本的MLLib编程方法; 2、掌握用MLLib解决一些常见的数据分析问题,包括数据导入、成分分析

    2024年02月20日
    浏览(42)
  • Windows10系统spark3.0.0配置

    Windows10系统基本环境:spark3.0. 0 +hadoop3.1. 0 +scala2.12.0+java jdk1.8。 环境变量配置路径:电脑→属性→高级系统设置→环境变量 path中加入:%JAVA_HOME%/bin。 注:jdk版本不宜过高。 cmd验证: java -version 官方下载网址:https://www.scala-lang.org/ 选择对应版本,这里我选择的是scala2.12.0版本

    2024年04月26日
    浏览(37)
  • spark3.3.x处理excel数据

    环境: spark3.3.x scala2.12.x 引用: spark-shell --jars spark-excel_2.12-3.3.1_0.18.5.jar 或项目里配置pom.xml 代码: 1、直接使用excel文件第一行作为schema 2、使用自定义schema(该方法如果excel文件第一行不是所需数据,需手动限制读取的数据范围) ps:刚开始用的3.3.3_0.20.1这个版本的不可用,具体

    2024年02月08日
    浏览(33)
  • Hive3 on Spark3配置

    大数据组件 版本 Hive 3.1.2 Spark spark-3.0.0-bin-hadoop3.2 OS 版本 MacOS Monterey 12.1 Linux - CentOS 7.6 1)Hive on Spark说明 Hive引擎包括:默认 mr 、 spark 、 Tez 。 Hive on Spark :Hive既作为存储元数据又负责SQL的解析优化,语法是HQL语法,执行引擎变成了Spark,Spark负责采用RDD执行。 Spark on Hive :

    2024年02月04日
    浏览(39)
  • [机器学习、Spark]Spark机器学习库MLlib的概述与数据类型

    👨‍🎓👨‍🎓博主:发量不足 📑📑本期更新内容: Spark机器学习库MLlib的概述与数据类型 📑📑下篇文章预告:Spark MLlib基本统计 💨💨简介:分享的是一个当代疫情在校封校的大学生学习笔记 目录 Spark机器学习库MLlib的概述 一.MLib的简介 二.Spark机器学习工作流程 数

    2023年04月09日
    浏览(86)
  • Hudi0.14.0集成Spark3.2.3(Spark Shell方式)

    1.1 启动Spark Shell

    2024年01月24日
    浏览(38)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包