金融数据_PySpark-3.0.3决策树(DecisionTreeClassifier)实例

这篇具有很好参考价值的文章主要介绍了金融数据_PySpark-3.0.3决策树(DecisionTreeClassifier)实例。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

金融数据_PySpark-3.0.3决策树(DecisionTreeClassifier)实例

逻辑回归: 逻辑回归常被用于二分类问题, 比如涨跌预测。你可以将涨跌标记为类别, 然后使用逻辑回归进行训练。

决策树和随机森林: 决策树和随机森林是用于分类问题的强大模型。它们能够处理非线性关系, 并且对于特征的重要性有较好的解释。

在 PySpark-3.x.x 中构建决策树主要使用 pyspark.ml 模块中的 DecisionTreeClassifier。

下面是一个简单的示例, 演示如何使用 PySpark-3.x.x 构建和训练决策树模型。

实例数据

本实例截取了 “湖北宜化(000422)” 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Spark 中的决策树分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 “涨跌幅(Change) >= 0”, 则用 1 表示, 如果 “涨跌幅(Change) < 0”, 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 “短期均线(MA5)” 与 “长期均线(MA10)” 的关系, 如果 “MA5 > MA10”, 则用 1 表示, 如果 “MA5 = MA10”, 则用 0 表示, 如果 “MA5 < MA10”, 则用 -1 表示。

导入 pyspark.sql 相关模块

Spark SQL 是用于结构化数据处理的 Spark 模块。它提供了一种成为 DataFrame 编程抽象, 是由 SchemaRDD 发展而来。

不同于 SchemaRDD 直接继承 RDD, DataFrame 自己实现了 RDD 的绝大多数功能。

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType

导入 pyspark.ml 相关模块

Spark 在核心数据抽象 RDD 的基础上, 支持 4 大组件, 其中机器学习占其一。

进一步的, Spark 中实际上支持两个机器学习模块, MLlib 和 ML, 区别在于前者主要是基于 RDD 数据结构, 当前处于维护状态; 而后者则是 DataFrame 数据结构, 支持更多的算法, 后续将以此为主进行迭代。

所以, 在实际应用中优先使用 ML 子模块。

Spark 的 ML 库与 Python 中的另一大机器学习库 Sklearn 的关系是: Spark 的 ML 库支持大部分机器学习算法和接口功能, 虽远不如 Sklearn 功能全面, 但主要面向分布式训练, 针对大数据。

而 Sklearn 是单点机器学习算法库, 支持几乎所有主流的机器学习算法, 从样例数据, 特征选择, 模型选择和验证, 基础学习算法和集成学习算法, 提供了机器学习一站式解决方案, 但仅支持并行而不支持分布式。

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

创建 SparkSession 对象

Spark 2.0 以上版本的 spark-shell 在启动时会自动创建一个名为 spark 的 SparkSession 对象。

当需要手工创建时, SparkSession 可以由其伴生对象的 builder 方法创建出来。

spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()

使用 Spark 构建 DataFrame 数据 (Optional)

当数据量较小时, 可以使用该方法手工构建 DataFrame 数据。

构建数据行 Row (以前 3 行为例):

Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85")
ROW(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85")
Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")

将构建好的数据行 Row 加入列表 (以前 3 行为例):

Data_Rows = [
    Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85"),
    ROW(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85"),
    Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")
]

生成 DataFrame 数据框 (以前 3 行为例):

SDF = spark.createDataFrame(Data_Rows)

输出 DataFrame 数据框 (以前 3 行为例):

print("[Message] Builded Spark DataFrame:")
SDF.show()

输出:

+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|    Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498| 1.39152E7|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84|  0.01148|     0.018662| 1.67559E7|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886| 1.42638E7|7.90|7.81|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+

使用 Spark 读取 CSV 数据

调用 SparkSession 的 .read 方法读取 CSV 数据:

其中 .option 是读取文件时的选项, 左边是 “键(Key)”, 右边是 “值(Value)”, 例如 .option(“header”, “true”) 与 {header = “true”} 类同。

SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")

输出 DataFrame 数据框:

print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()

输出:

[Message] Readed CSV File: D:\HBYH_000422_20150806_20151231.csv
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|  Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498|13915200|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84| 0.011480|     0.018662|16755900|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886|14263800|7.90|7.81|
|2015-12-28|'000422|8.03|8.08|7.70| 7.71|     8.03|-0.039851|     0.030821|27672800|7.91|7.78|
|2015-12-25|'000422|8.03|8.05|7.93| 8.03|     7.99| 0.005006|     0.021132|18974000|7.93|7.78|
|2015-12-24|'000422|7.93|8.16|7.87| 7.99|     7.92| 0.008838|     0.026487|23781900|7.85|7.72|
|2015-12-23|'000422|7.97|8.11|7.88| 7.92|     7.89| 0.003802|     0.042360|38033600|7.80|7.69|
|2015-12-22|'000422|7.86|7.93|7.76| 7.89|     7.83| 0.007663|     0.026929|24178700|7.73|7.68|
|2015-12-21|'000422|7.59|7.89|7.56| 7.83|     7.63| 0.026212|     0.030777|27633600|7.66|7.67|
|2015-12-18|'000422|7.71|7.74|7.57| 7.63|     7.74|-0.014212|     0.024764|22234900|7.62|7.71|
|2015-12-17|'000422|7.58|7.75|7.57| 7.74|     7.55| 0.025166|     0.028054|25188400|7.59|7.77|
|2015-12-16|'000422|7.57|7.62|7.53| 7.55|     7.55| 0.000000|     0.020718|18601600|7.58|7.79|
|2015-12-15|'000422|7.63|7.66|7.52| 7.55|     7.62|-0.009186|     0.025902|23256600|7.64|7.78|
|2015-12-14|'000422|7.40|7.64|7.36| 7.62|     7.51| 0.014647|     0.021005|18860100|7.68|7.76|
|2015-12-11|'000422|7.65|7.70|7.41| 7.51|     7.67|-0.020860|     0.020477|18385900|7.80|7.73|
|2015-12-10|'000422|7.78|7.87|7.65| 7.67|     7.83|-0.020434|     0.019972|17931900|7.95|7.69|
|2015-12-09|'000422|7.76|8.00|7.75| 7.83|     7.77| 0.007722|     0.025137|22569700|8.00|7.68|
|2015-12-08|'000422|8.08|8.18|7.76| 7.77|     8.24|-0.057039|     0.036696|32948200|7.92|7.66|
|2015-12-07|'000422|8.12|8.39|7.94| 8.24|     8.23| 0.001215|     0.064590|57993100|7.84|7.64|
|2015-12-04|'000422|7.85|8.48|7.80| 8.23|     7.92| 0.039141|     0.100106|89881900|7.65|7.58|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
only showing top 20 rows

转换 Spark 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

# 转换 Spark 中 DateFrame 数据类型。
SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))
SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))
SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))
SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))
SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))
SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))
SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))
SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))
SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))
SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))

# 输出 Spark 中 DataFrame 字段和数据类型。
print("[Message] Changed Spark DataFrame Data Type:")
SDF.printSchema()

输出:

[Message] Changed Spark DataFrame Data Type:
root
 |-- Date: date (nullable = true)
 |-- Code: string (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Close: double (nullable = true)
 |-- Pre_Close: double (nullable = true)
 |-- Change: double (nullable = true)
 |-- Turnover_Rate: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- MA5: double (nullable = true)
 |-- MA10: double (nullable = true)

将 Spark 的 DateFrame 和 Spark RDD 互相转换并计算数据

编写 “向 spark.sql 的 Row 对象添加字段和字段值” 函数:

def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row: 

    """
    [Require] import pyspark
    
    [Example] >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)
              >>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())
              >>> print(NewRow)
              Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10, Weekday=4)
    """
    
    # Convert Obj "pyspark.sql.types.Row" to Dict. 
    # ----------------------------------------------
    Row_Dict = SrcRow.asDict()
    
    # Add a New Key in the Dictionary With the New Column Name and Value.
    # ----------------------------------------------
    Row_Dict[FldName] = FldVal
    
    # Convert Dict to Obj "pyspark.sql.types.Row". 
    # ----------------------------------------------
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    
    # ==============================================
    return NewRow

编写 “判断股票涨跌” 函数:

def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int: 

    if (ChgRate >= 0.0): return 1
    if (ChgRate <  0.0): return 0
    
    # ==============================================
    # End of Function.

编写 “判断股票短期均线和长期均线关系” 函数:

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int: 

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1
    
    # ==============================================
    # End of Function.

编写 “返回星期几(中文)” 函数:

def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.datetime) -> str:
    
    """
    [Require] import datetime
    
    [Explain] Python3 中 datetime.datetime 对象的 .weekday() 方法返回的是从 0 到 6 的数字 (0 代表周一, 6 代表周日)。
    """
    
    Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    
    # ==============================================
    return Weekday_Str_Chinese[SrcDtm.weekday()]

在 Spark 中将 DataFrame 转换为 Spark RDD 并调用自定义函数:

# 在 Spark 中将 DataFrame 转换为 RDD。
CalcRDD = SDF.rdd

# --------------------------------------------------
# 调用自定义函数: 提取星期索引。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
# ..................................................
# 调用自定义函数: 返回星期几(中文)。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
# ..................................................
# 调用自定义函数: 判断股票涨跌。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
# ..................................................
# 判断股票短期均线和长期均线关系。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))

# 显示计算好的 RDD 前 5 行。
print("[Message] Calculated RDD Top 5 Rows:")
pprint.pprint(CalcRDD.take(5))

输出:

[Message] Calculated RDD Top 5 Rows:
[Row(Date=datetime.date(2015, 12, 31), Code="'000422", Open=7.93, High=7.95, Low=7.76, Close=7.77, Pre_Close=7.93, Change=-0.020177, Turnover_Rate=0.015498, Volume=13915200, MA5=7.86, MA10=7.85, Weekday(Idx)=3, Weekday(CN)='周四', Rise_Fall=0, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 30), Code="'000422", Open=7.86, High=7.93, Low=7.75, Close=7.93, Pre_Close=7.84, Change=0.01148, Turnover_Rate=0.018662, Volume=16755900, MA5=7.9, MA10=7.85, Weekday(Idx)=2, Weekday(CN)='周三', Rise_Fall=1, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 29), Code="'000422", Open=7.72, High=7.85, Low=7.69, Close=7.84, Pre_Close=7.71, Change=0.016861, Turnover_Rate=0.015886, Volume=14263800, MA5=7.9, MA10=7.81, Weekday(Idx)=1, Weekday(CN)='周二', Rise_Fall=1, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 28), Code="'000422", Open=8.03, High=8.08, Low=7.7, Close=7.71, Pre_Close=8.03, Change=-0.039851, Turnover_Rate=0.030821, Volume=27672800, MA5=7.91, MA10=7.78, Weekday(Idx)=0, Weekday(CN)='周一', Rise_Fall=0, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 25), Code="'000422", Open=8.03, High=8.05, Low=7.93, Close=8.03, Pre_Close=7.99, Change=0.005006, Turnover_Rate=0.021132, Volume=18974000, MA5=7.93, MA10=7.78, Weekday(Idx)=4, Weekday(CN)='周五', Rise_Fall=1, MA_Relationship=1)]

计算完成后将 Spark RDD 转换回 Spark 的 DataFrame:

# 在 Spark 中将 RDD 转换为 DataFrame。
NewSDF = CalcRDD.toDF()

print("[Message] Convert RDD to DataFrame and Filter Out Key Columns for Display:")
NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()

输出:

[Message] Convert RDD to DataFrame and Filter Out Key Columns:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+

字符串索引化 (StringIndexer) 演示 (Only Demo)

StringIndexer (字符串-索引变换) 是一个估计器, 是将字符串列编码为标签索引列。索引位于 [0, numLabels), 按标签频率排序, 频率最高的排 0, 依次类推, 因此最常见的标签获取索引是 0。

# 使用 StringIndexer 转换 Weekday(CN) 列。
MyStringIndexer = StringIndexer(inputCol="Weekday(CN)", outputCol="StrIdx")
# 拟合并转换数据。
IndexedSDF = MyStringIndexer.fit(NewSDF).transform(NewSDF)

# 筛选 Date, Weekday(Idx), Weekday(CN), StrIdx 四列, 输出 StringIndexer 效果。
print("[Message] The Effect of StringIndexer:")
IndexedSDF.select(["Date", "Weekday(Idx)", "Weekday(CN)", "StrIdx"]).show()

输出:

[Message] The Effect of StringIndexer:
+----------+------------+-----------+------+
|      Date|Weekday(Idx)|Weekday(CN)|StrIdx|
+----------+------------+-----------+------+
|2015-12-31|           3|       周四|   3.0|
|2015-12-30|           2|       周三|   1.0|
|2015-12-29|           1|       周二|   2.0|
|2015-12-28|           0|       周一|   0.0|
|2015-12-25|           4|       周五|   4.0|
|2015-12-24|           3|       周四|   3.0|
|2015-12-23|           2|       周三|   1.0|
|2015-12-22|           1|       周二|   2.0|
|2015-12-21|           0|       周一|   0.0|
|2015-12-18|           4|       周五|   4.0|
|2015-12-17|           3|       周四|   3.0|
|2015-12-16|           2|       周三|   1.0|
|2015-12-15|           1|       周二|   2.0|
|2015-12-14|           0|       周一|   0.0|
|2015-12-11|           4|       周五|   4.0|
|2015-12-10|           3|       周四|   3.0|
|2015-12-09|           2|       周三|   1.0|
|2015-12-08|           1|       周二|   2.0|
|2015-12-07|           0|       周一|   0.0|
|2015-12-04|           4|       周五|   4.0|
+----------+------------+-----------+------+
only showing top 20 rows

提取 标签(Label)列 和 特征向量(Features)列

在创建特征向量(Features)列时, 将会用到 VectorAssembler 模块, VectorAssembler 将多个特征合并为一个特征向量。

提取 标签(Label) 列:

# 将 Rise_Fall 列复制为 Label 列。
NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))

创建 特征向量(Features) 列:

# VectorAssembler 将多个特征合并为一个特征向量。
FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")

# 拟合数据 (可选, 如果在模型训练时使用 Pipeline, 则无需在此步骤拟合数据, 当然也就无法在此步骤预览数据)。
AssembledSDF = MyAssembler.transform(NewSDF)

输出预览:

print("[Message] Assembled Label and Features for DecisionTreeClassifier:")
AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()

预览:

[Message] Assembled for DecisionTreeClassifier:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|Label|            Features|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|    0|[7.95,7.76,0.0154...|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|    1|[7.93,7.75,0.0186...|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|    1|[7.85,7.69,0.0158...|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|    0|[8.08,7.7,0.03082...|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|    1|[8.05,7.93,0.0211...|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|    1|[8.16,7.87,0.0264...|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|    1|[8.11,7.88,0.0423...|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|    1|[7.93,7.76,0.0269...|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|    1|[7.89,7.56,0.0307...|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|    0|[7.74,7.57,0.0247...|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|    1|[7.75,7.57,0.0280...|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|    1|[7.62,7.53,0.0207...|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|    0|[7.66,7.52,0.0259...|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|    1|[7.64,7.36,0.0210...|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|    0|[7.7,7.41,0.02047...|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|    0|[7.87,7.65,0.0199...|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|    1|[8.0,7.75,0.02513...|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|    0|[8.18,7.76,0.0366...|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|    1|[8.39,7.94,0.0645...|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|    1|[8.48,7.8,0.10010...|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
only showing top 20 rows

训练 决策树分类器(DecisionTreeClassifier) 模型

将数据集划分为 “训练集” 和 “测试集”:

(TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)

创建 决策树分类器(DecisionTreeClassifier):

DTC = DecisionTreeClassifier(labelCol="Label", featuresCol="Features")

创建 Pipeline (可选):

# 创建 Pipeline, 将特征向量转换和决策树模型组合在一起
# 注意: 如果要使用 Pipeline, 则在创建 特征向量(Features)列 的时候不需要拟合数据, 否则会报 "Output column Features already exists." 的错误。
MyPipeline = Pipeline(stages=[MyAssembler, DTC])

训练 决策树分类器(DecisionTreeClassifier) 模型:

如果在创建 特征向量(Features)列 的时候已经拟合数据:

# 训练模型 (普通模式)。
Model = DTC.fit(TrainingData)

如果在创建 特征向量(Features)列 的时候没有拟合数据:

# 训练模型 (Pipeline 模式)。
Model = MyPipeline.fit(TrainingData)

使用 决策树分类器(DecisionTreeClassifier) 模型预测数据

# 在测试集上进行预测。
Predictions = Model.transform(TestData)

# 删除不需要的列 (以免列数太多, 结果显示拥挤, 不好观察)。
Predictions = Predictions.drop("Open")
Predictions = Predictions.drop("High")
Predictions = Predictions.drop("Low")
Predictions = Predictions.drop("Close")
Predictions = Predictions.drop("Pre_Close")
Predictions = Predictions.drop("Turnover_Rate")
Predictions = Predictions.drop("Volume")
Predictions = Predictions.drop("Weekday(Idx)")
Predictions = Predictions.drop("Weekday(CN)")

print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
Predictions.show()

输出:

[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+
|      Date|   Code|   Change| MA5|MA10|Rise_Fall|MA_Relationship|Label|            Features|rawPrediction|         probability|prediction|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+
|2015-08-10|'000422| 0.034105| 8.2|7.92|        1|              1|    1|[8.58,8.18,0.0412...|    [0.0,1.0]|           [0.0,1.0]|       1.0|
|2015-08-14|'000422| 0.009479|8.43|8.24|        1|              1|    1|[8.65,8.43,0.0411...|    [1.0,0.0]|           [1.0,0.0]|       0.0|
|2015-08-18|'000422|-0.095455|8.39|8.32|        0|              1|    0|[8.86,7.92,0.0561...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-08-25|'000422|-0.099424|7.52|7.96|        0|             -1|    0|[6.77,6.25,0.0294...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-02|'000422|-0.053412|6.73|6.91|        0|             -1|    0|[6.88,6.3,0.02228...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-10|'000422|-0.031161|6.76|6.74|        0|              1|    0|[7.01,6.76,0.0174...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-18|'000422|      0.0|6.39|6.62|        1|             -1|    1|[6.58,6.3,0.01662...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-28|'000422| 0.009464|6.48|6.47|        1|              1|    1|[6.42,6.25,0.0088...|    [0.0,4.0]|           [0.0,1.0]|       1.0|
|2015-10-19|'000422|-0.007062|6.94|6.72|        0|              1|    0|[7.13,6.92,0.0312...|    [0.0,2.0]|           [0.0,1.0]|       1.0|
|2015-10-20|'000422| 0.008535|6.98|6.81|        1|              1|    1|[7.09,6.94,0.0244...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-10-21|'000422|-0.062059|6.96|6.85|        0|              1|    0|[7.11,6.61,0.0393...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-10-23|'000422| 0.054412|6.95|6.93|        1|              1|    1|[7.22,6.81,0.0471...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-10-27|'000422| 0.033426|7.04|7.01|        1|              1|    1|[7.48,7.08,0.0576...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-11-02|'000422|-0.027548|7.23| 7.1|        0|              1|    0|[7.26,7.05,0.0168...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-11-11|'000422| 0.005284|7.54|7.37|        1|              1|    1|[7.64,7.52,0.0261...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-11-20|'000422| 0.002635|7.52|7.53|        1|             -1|    1|[7.71,7.53,0.0282...|    [6.0,3.0]|[0.66666666666666...|       0.0|
|2015-12-02|'000422| 0.009511|7.37|7.49|        1|             -1|    1|[7.48,7.2,0.01596...|    [7.0,1.0]|       [0.875,0.125]|       0.0|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+

使用 BinaryClassificationEvaluator 评估模型性能

# 使用 BinaryClassificationEvaluator 评估模型性能。
MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
auc = MyEvaluator.evaluate(Predictions)

print("Area Under ROC (AUC):", auc)

输出:

Area Under ROC (AUC): 0.3

完整代码

#!/usr/bin/python3
# Create By GF 2024-01-06

# 在这个例子中, 我们使用 VectorAssembler 将多个特征列合并为一个特征向量, 并使用 DecisionTreeClassifier 构建决策树模型。
# 最后, 我们使用 BinaryClassificationEvaluator 评估模型性能, 通常使用 ROC 曲线下面积 (AUC) 作为评估指标。
# 请根据你的实际数据和问题调整特征列, 标签列以及其他参数。在实际应用中, 你可能需要进行更多的特征工程, 调参和模型评估。

import datetime
import pprint
# --------------------------------------------------
import pyspark
# --------------------------------------------------
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType
# --------------------------------------------------
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

# 编写 "向 spark.sql 的 Row 对象添加字段和字段值" 函数。
def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row: 

    """
    [Require] import pyspark
    
    [Example] >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)
              >>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())
              >>> print(NewRow)
              Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10, Weekday=4)
    """
    
    # Convert Obj "pyspark.sql.types.Row" to Dict. 
    # ----------------------------------------------
    Row_Dict = SrcRow.asDict()
    
    # Add a New Key in the Dictionary With the New Column Name and Value.
    # ----------------------------------------------
    Row_Dict[FldName] = FldVal
    
    # Convert Dict to Obj "pyspark.sql.types.Row". 
    # ----------------------------------------------
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    
    # ==============================================
    return NewRow

# 编写 "判断股票涨跌" 函数。
def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int: 

    if (ChgRate >= 0.0): return 1
    if (ChgRate <  0.0): return 0
    
    # ==============================================
    # End of Function.

# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int: 

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1
    
    # ==============================================
    # End of Function.

# 编写 "返回星期几(中文)" 函数。
def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.datetime) -> str:
    
    """
    [Require] import datetime
    
    [Explain] Python3 中 datetime.datetime 对象的 .weekday() 方法返回的是从 0 到 6 的数字 (0 代表周一, 6 代表周日)。
    """
    
    Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    
    # ==============================================
    return Weekday_Str_Chinese[SrcDtm.weekday()]

if __name__ == "__main__":

    # Spark 2.0 以上版本的 spark-shell 在启动时会自动创建一个名为 spark 的 SparkSession 对象。
    # 当需要手工创建时, SparkSession 可以由其伴生对象的 builder 方法创建出来。
    spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()
    
    # 调用 SparkSession 的 .read 方法读取 CSV 数据:
    # 其中 .option 是读取文件时的选项, 左边是 "键(Key)", 右边是 "值(Value)", 例如 .option("header", "true") 与 {header = "true"} 类同。
    SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")
    
    print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
    SDF.show()
    
    # 转换 Spark 中 DateFrame 数据类型。
    SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))
    SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))
    SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))
    SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))
    SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))
    SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))
    SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))
    SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
    SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))
    SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))
    SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))
    
    # 输出 Spark 中 DataFrame 字段和数据类型。
    print("[Message] Changed Spark DataFrame Data Type:")
    SDF.printSchema()
    
    # 在 Spark 中将 DataFrame 转换为 RDD。
    CalcRDD = SDF.rdd
    
    # --------------------------------------------------
    # 调用自定义函数: 提取星期索引。
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
    # ..................................................
    # 调用自定义函数: 返回星期几(中文)。
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
    # ..................................................
    # 调用自定义函数: 判断股票涨跌。
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
    # ..................................................
    # 判断股票短期均线和长期均线关系。
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))
    
    # 显示计算好的 RDD 前 5 行。
    print("[Message] Calculated RDD Top 5 Rows:")
    pprint.pprint(CalcRDD.take(5))
    
    # 在 Spark 中将 RDD 转换为 DataFrame。
    NewSDF = CalcRDD.toDF()
    
    print("[Message] Convert RDD to DataFrame and Filter Out Key Columns for Display:")
    NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()
    
    # 提取 标签(Label) 列: 将 Rise_Fall 列复制为 Label 列。
    NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))
    
    # 创建 特征向量(Features) 列: VectorAssembler 将多个特征合并为一个特征向量。
    FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
    MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")

    # 创建 特征向量(Features) 列: 拟合数据 (可选, 如果在模型训练时使用 Pipeline, 则无需在此步骤拟合数据, 当然也就无法在此步骤预览数据)。
    AssembledSDF = MyAssembler.transform(NewSDF)
    
    print("[Message] Assembled Label and Features for DecisionTreeClassifier:")
    AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()
    
    # 将数据集划分为 "训练集" 和 "测试集"。
    (TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)
    
    # 创建 决策树分类器(DecisionTreeClassifier)。
    DTC = DecisionTreeClassifier(labelCol="Label", featuresCol="Features")
    
    # 创建 Pipeline (可选): 将特征向量转换和决策树模型组合在一起
    # 注意: 如果要使用 Pipeline, 则在创建 特征向量(Features)列 的时候不需要拟合数据, 否则会报 "Output column Features already exists." 的错误。
    #MyPipeline = Pipeline(stages=[MyAssembler, DTC])
    
    # 训练模型 (普通模式)。
    Model = DTC.fit(TrainingData)
    
    # 训练模型 (Pipeline 模式)。
    #Model = MyPipeline.fit(TrainingData)
    
    # 在测试集上进行预测。
    Predictions = Model.transform(TestData)
    
    # 删除不需要的列 (以免列数太多, 结果显示拥挤, 不好观察)。
    Predictions = Predictions.drop("Open")
    Predictions = Predictions.drop("High")
    Predictions = Predictions.drop("Low")
    Predictions = Predictions.drop("Close")
    Predictions = Predictions.drop("Pre_Close")
    Predictions = Predictions.drop("Turnover_Rate")
    Predictions = Predictions.drop("Volume")
    Predictions = Predictions.drop("Weekday(Idx)")
    Predictions = Predictions.drop("Weekday(CN)")
    
    print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
    Predictions.show()
    
    # 使用 BinaryClassificationEvaluator 评估模型性能。
    MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
    auc = MyEvaluator.evaluate(Predictions)

    print("Area Under ROC (AUC):", auc)

其它

在这个例子中, 我们使用 VectorAssembler 将多个特征列合并为一个特征向量, 并使用 DecisionTreeClassifier 构建决策树模型。

最后, 我们使用 BinaryClassificationEvaluator 评估模型性能, 通常使用 ROC 曲线下面积 (AUC) 作为评估指标。

请根据你的实际数据和问题调整特征列, 标签列以及其他参数。在实际应用中, 你可能需要进行更多的特征工程, 调参和模型评估。

总结

以上就是关于 金融数据 PySpark-3.0.3决策树(DecisionTreeClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public文章来源地址https://www.toymoban.com/news/detail-859871.html

到了这里,关于金融数据_PySpark-3.0.3决策树(DecisionTreeClassifier)实例的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【观察】金融行业决策智能化“换挡提速” 华为全球智慧金融峰会2023值得期待...

    当前以数字化、智能化为特征的第四次工业革命正“扑面而来”,数字经济浪潮对各行各业都产生着深刻影响。其中,金融行业作为现代经济的核心,也面临着一系列重大的挑战和机遇。 相比于其他企业,金融行业依靠数据分析和智能决策更好地服务客户,已成为其核心竞争

    2024年02月08日
    浏览(42)
  • 人工智能在金融投资决策中的应用与未来

    随着人工智能(AI)技术的不断发展和进步,金融领域也逐渐开始利用这一技术来提高投资决策的效率和准确性。AI在金融投资决策中的应用主要体现在数据分析、风险管理、交易策略优化等方面。本文将从以下几个方面进行阐述: 背景介绍 核心概念与联系 核心算法原理和具体

    2024年02月20日
    浏览(63)
  • 金融供应链智能合约 -- 智能合约实例

    前提 Ownable:监管者合约,有一个函数能转让监管者。 SupplyChainFin:供应链金融合约,银行、公司信息上链,公司和银行之间的转账。 发票:记录者交易双方和交易金额等的一种记录数据。如:我在超市买了一瓶水,超市给我开了一张发票。 Ownable SupplyChainFin

    2024年02月14日
    浏览(39)
  • 车辆换道决策的模糊控制算法实例

    目录 一、模糊控制在换道决策应用上的概念介绍 1.1 模糊化 1.2 建立模糊规则 1.3 解模糊 二、Matlab建立模糊逻辑系统 2.1 Matlab模糊逻辑工具箱 2.2 建立模糊系统的步骤  2.3 建立换道决策的模糊逻辑系统         实际驾车时,车辆在行驶过程中能否去执行换道的决策与道路信

    2024年02月10日
    浏览(56)
  • 使用Pycharm运行spark实例时没有pyspark包(ModuleNotFoundError: No module named ‘py4j‘)

    在安装并配置pyspark,下载并打开Pycharm(专业版)后进行spark实例操作(笔者以统计文件中的行数为例)时,运行程序后提示ModuleNotFoundError: No module named \\\'py4j\\\': 1.下载py4j包后下载pyspark包 打开新终端,在终端中输入(若在pycharm中进行下载可能导致下载失败,这里指定使用清华

    2024年04月26日
    浏览(39)
  • 【Python】PySpark 数据处理 ② ( 安装 PySpark | PySpark 数据处理步骤 | 构建 PySpark 执行环境入口对象 )

    执行 Windows + R , 运行 cmd 命令行提示符 , 在命令行提示符终端中 , 执行 命令 , 安装 PySpark , 安装过程中 , 需要下载 310 M 的安装包 , 耐心等待 ; 安装完毕 : 命令行输出 : 如果使用 官方的源 下载安装 PySpark 的速度太慢 , 可以使用 国内的 镜像网站 https://pypi.tuna.tsinghua.edu.cn/simple

    2024年02月06日
    浏览(43)
  • Python大数据之PySpark(二)PySpark安装

    1-明确PyPi库,Python Package Index 所有的Python包都从这里下载,包括pyspark 2-为什么PySpark逐渐成为主流? http://spark.apache.org/releases/spark-release-3-0-0.html Python is now the most widely used language on Spark. PySpark has more than 5 million monthly downloads on PyPI, the Python Package Index. 记住如果安装特定的版本

    2024年02月04日
    浏览(43)
  • PySpark数据分析基础:PySpark Pandas创建、转换、查询、转置、排序操作详解

    目录 前言 一、Pandas数据结构 1.Series 2.DataFrame  3.Time-Series  4.Panel 5.Panel4D 6.PanelND 二、Pyspark实例创建 1.引入库 2.转换实现 pyspark pandas series创建 pyspark pandas dataframe创建 from_pandas转换  Spark DataFrame转换  三、PySpark Pandas操作 1.读取行列索引 2.内容转换为数组 3.DataFrame统计描述 4.转

    2024年02月02日
    浏览(57)
  • PySpark数据分析基础:PySpark基础功能及DataFrame操作基础语法详解

    目录 前言 一、PySpark基础功能  1.Spark SQL 和DataFrame 2.Pandas API on Spark 3.Streaming 4.MLBase/MLlib 5.Spark Core 二、PySpark依赖 Dependencies 三、DataFrame 1.创建 创建不输入schema格式的DataFrame 创建带有schema的DataFrame 从Pandas DataFrame创建 通过由元组列表组成的RDD创建 2.查看 DataFrame.show() spark.sql.

    2024年01月18日
    浏览(56)
  • Python大数据之PySpark

    Apache Spark是一种用于大规模数据处理的多语言分布式引擎,用于在单节点机器或集群上执行数据工程、数据科学和机器学习 Spark官网:https://spark.apache.org/ 按照官网描述,Spark关键特征包括: 批/流处理 Spark支持您使用喜欢的语言:Python、SQL、Scala、Java或R,统一批量和实时流处

    2024年02月08日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包