pyspark使用XGboost训练模型实例

遇到一个还不错的使用Xgboost训练模型的githubhttps://github.com/MachineLP/Spark-/tree/master/pyspark-xgboost

1、这是一个跑通的代码实例,使用的是泰坦尼克生还数据,分类模型。

这里使用了Pipeline来封装特征处理和模型训练步骤,保存为pipelineModel

注意这里加载xgboost依赖的jar包和zip包的方法。

#这是用 pipeline 包装了XGBOOST的例子。 此路通!import os
import sys
import time
import pandas as pd
import numpy as np
import pyspark.sql.types as typ
import pyspark.ml.feature as ft
from pyspark.sql.functions import isnan, isnullfrom pyspark.sql.types import StructType, StructFieldfrom pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import col
from pyspark.sql import SparkSessionos.environ['PYSPARK_PYTHON'] = 'Python3.7/bin/python'
os.environ['PYSPARK_SUBMIT_ARGS'] = '--jars xgboost4j-spark-0.90.jar,xgboost4j-0.90.jar pyspark-shell'spark = SparkSession \.builder \.appName("PySpark XGBOOST Titanic") \.config('spark.driver.allowMultipleContexts', 'true') \.config('spark.pyspark.python', 'Python3.7/bin/python') \.config('spark.yarn.dist.archives', 'hdfs://ns62007/user/dmc_adm/_PYSPARK_ENV/Python3.7.zip#Python3.7') \.config('spark.executorEnv.PYSPARK_PYTHON', 'Python3.7/bin/python') \.config('spark.sql.autoBroadcastJoinThreshold', '-1') \.enableHiveSupport() \.getOrCreate()spark.sparkContext.addPyFile("sparkxgb.zip")schema = StructType([StructField("PassengerId", DoubleType()),StructField("Survived", DoubleType()),StructField("Pclass", DoubleType()),StructField("Name", StringType()),StructField("Sex", StringType()),StructField("Age", DoubleType()),StructField("SibSp", DoubleType()),StructField("Parch", DoubleType()),StructField("Ticket", StringType()),StructField("Fare", DoubleType()),StructField("Cabin", StringType()),StructField("Embarked", StringType())])upload_file = "titanic/train.csv"
hdfs_path = "hdfs://tmp/gao/dev_data/dmb_upload_data/"
file_path = os.path.join(hdfs_path, upload_file.split("/")[-1])df_raw = spark\.read\.option("header", "true")\.schema(schema)\.csv(file_path)df_raw.show(20)
df = df_raw.na.fill(0)sexIndexer = StringIndexer()\.setInputCol("Sex")\.setOutputCol("SexIndex")\.setHandleInvalid("keep")cabinIndexer = StringIndexer()\.setInputCol("Cabin")\.setOutputCol("CabinIndex")\.setHandleInvalid("keep")embarkedIndexer = StringIndexer()\.setInputCol("Embarked")\.setHandleInvalid("keep")# .setOutputCol("EmbarkedIndex")\vectorAssembler = VectorAssembler()\.setInputCols(["Pclass", "Age", "SibSp", "Parch", "Fare"])\.setOutputCol("features")from sparkxgb import XGBoostClassifier
xgboost = XGBoostClassifier(maxDepth=3,missing=float(0.0),featuresCol="features",labelCol="Survived"
)pipeline = Pipeline(stages=[vectorAssembler, xgboost])trainDF, testDF = df.randomSplit([0.8, 0.2], seed=24)
trainDF.show(2)
model = pipeline.fit(trainDF)print (88888888888888888888)
model.transform(testDF).select(col("PassengerId"), col("Survived"), col("prediction")).show()
print (9999999999999999999)# Write model/classifier
model.write().overwrite().save(os.path.join(hdfs_path,"xgboost_class_test"))from pyspark.ml import PipelineModel
model1 = PipelineModel.load(os.path.join(hdfs_path,"xgboost_class_test"))
model1.transform(testDF).show()

这是执行结果:

2、当然也可以不用pipeline封装,直接训练xgboost模型,并保存。

但这里遇到无法加载训练好的xgb模型的问题。

# Train a xgboost model
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.ml import Pipeline
from sparkxgb import XGBoostClassifierassembler = VectorAssembler(inputCols=[ 'Pclass', 'Age', 'SibSp', 'Parch','Fare'],outputCol="features", handleInvalid="skip")xgboost = XGBoostClassifier(maxDepth=3,missing=float(0.0),featuresCol="features", labelCol="Survived")# pipeline = Pipeline(stages=[assembler, xgboost])
# trained_model = pipeline.fit(data)td = assembler.transform(data)
trained_raw_model = xgboost.fit(td)result = trained_raw_model.transform(td)
result.select(["Survived", "rawPrediction", "probability", "prediction"]).show()# save trained model to local disk
trained_raw_model.nativeBooster.saveModel("outputmodel.xgboost")# 无法加载已经训练好的XGB模型
from sparkxgb import XGBoostClassifier,XGBoostClassificationModel
model1= XGBoostClassificationModel.load("outputmodel.xgboost")
model1.transform(td).show()

这是运行结果:

 这里报错,无法使用 XGBoostClassificationModel加载已经训练好的XGB模型。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75974.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode每日一题】——304.二维区域和检索-矩阵不可变

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时间频度】九【代码实现】十【提交结果】 一【题目类别】 矩阵 二【题目难度】 中等 三【题目编号】 304.二维区域和检索-矩阵不可变 四【题目描述】 …

JS进阶-Day3

🥔:永远做自己的聚光灯 JS进阶-Day1——点击此处(作用域、函数、解构赋值等) JS进阶-Day2——点击此处(深入对象之构造函数、实例成员、静态成员等;内置构造函数之引用类型、包装类型等) 更多JS…

【C++】—— 多态常见的笔试和面试问题

序言: 在上期,我们对多态进行了详细的讲解。本期,我给大家带来的是关于有关多态常见的笔试和面试问题,帮助大家理解记忆相关知识点。 目录 (一)概念查考 (二)问答题 1、简述一下…

C数据结构与算法——哈希表/散列表创建过程中的冲突与聚集(哈希查找) 应用

实验任务 (1) 掌握散列算法(散列函数、散列存储、散列查找)的实现; (2) 掌握常用的冲突解决方法。 实验内容 (1) 选散列函数 H(key) key % p,取散列表长 m 为 10000,p 取小于 m 的最大素数; (2) 测试 α…

阿里云安全组设置

简介​ 云主机安全组必须打开如下端口: ssh:22http:80https:443ftp:21、20000~30000 阿里云安全组端口开放教程​ 腾讯云安全组端口开放教程​ 华为云安全组端口开放教程​

C 语言高级2-多维数组,结构体,递归操作

1. 多维数组 1.1 一维数组 元素类型角度:数组是相同类型的变量的有序集合内存角度:连续的一大片内存空间 在讨论多维数组之前,我们还需要学习很多关于一维数组的知识。首先让我们学习一个概念。 1.1.1 数组名 考虑下面这些声明&#xff1…

华为云低代码平台Astro Canvas 搭建汽车展示大屏——实验指导手册

实验背景 大屏应用Astro Canvas是华为云低代码平台Astro的子服务之一,是以数据可视化为核心,以屏幕轻松编排,多屏适配可视为基础,用户可通过图形化界面轻松搭建专业水准的数据可视化大屏。例如汽车展示大屏、监控大屏、项目开发大…

docker端口映射详解(随机端口、指定IP端口、随意ip指定端口、指定ip随机端口)

目录 docker端口映射详解 一、端口映射概述: 二、案例实验: 1、-P选项,随机端口 2、使用-p可以指定要映射到的本地端口。 Local_Port:Container_Port,任意地址的指定端口 Local_IP:Local_Port:Container_Port 映射到指定地…

从零开始理解Linux中断架构(24)软中断核心函数__do_softirq

1)概要 __do_softirq函数处理是总是尽可能的执行所有未决软中断。 (1)关闭软中断:在preempt_count设置软中断标志:SOFTIRQ_OFFSET 让in_interrupt检查条件为真,进入软中断处理临界区,后面进来的处理请求,需要检查in_interrupt(),从而达到禁止本cpu上的软中断嵌套的目…

【C语言进阶】指针的高级应用(上)

本专栏介绍:免费专栏,并且会持续更新C语言知识,欢迎各位订阅关注。 关注我,带你了解更多关于机器人、嵌入式、人工智能等方面的优质文章,坚持更新! 大家的支持才是更新的最强动力! 文章目录 一、…

【Spring框架】Spring AOP

目录 什么是AOP?AOP组成Spring AOP 实现步骤Spring AOP实现原理JDK Proxy VS CGLIB 什么是AOP? AOP(Aspect Oriented Programming):⾯向切⾯编程,它是⼀种思想,它是对某⼀类事情的集中处理。⽐如…

林大数据结构【2019】

关键字: 哈夫曼树权值最小、哈夫曼编码、邻接矩阵时间复杂度、二叉树后序遍历、二叉排序树最差时间复杂度、非连通无向图顶点数(完全图)、带双亲的孩子链表、平衡二叉树调整、AOE网关键路径 一、判断 二、单选 三、填空 四、应用题

Flutter运行app时向logcat输出当前打开的界面路径且点击可跳转

当一个项目大了目录文件多了,我们往往会为了找到一个文件花费大量的时间和精力,为了快捷方便的调试我们的项目,我们往往需要在打开app运行的时候需要知道当前打开的界面的文件在哪儿,我们这个代码就能快捷的知道我们app正在打开的…

MySQL存储过程(二十四)

你相信吗, 相信那一天的夕阳吗? 上一章简单介绍了 MySQL的索引(二十三),如果没有看过,请观看上一章 一. 存储过程 MySQL从5.0版本开始支持存储过程和函数。存储过程和函数能够将复杂的SQL逻辑封装在一起, 应用程序无须关注存储过程和函数内部复杂的S…

浪潮服务器硬盘指示灯显示黄色的服务器数据恢复案例

服务器数据恢复环境: 宁夏某市某单位的一台浪潮服务器,该服务器中有一组由6块SAS硬盘组建的RAID5阵列。 服务器上存放的是Oracle数据库文件,操作系统层面划分了1个卷。 服务器故障&初检: 服务器在运行过程中有两块磁盘的指示灯…

一个3年Android的找工作记录

作者:Petterp 这是我最近 1个月 的找工作记录,希望这些经历对你会有所帮助。 有时机会就像一阵风,如果没有握住,那下一阵风什么时候吹来,往往是个运气问题。 写在开始 先说背景: 自考本,3年经验&#xff0…

掌握 JVM 的参数及配置

点击下方关注我,然后右上角点击...“设为星标”,就能第一时间收到更新推送啦~~~ JVM(Java虚拟机)是Java编程语言的核心组件之一,它负责执行Java程序,并提供一系列参数和配置选项,可以调整Java程…

JVM | 从类加载到JVM内存结构

引言 我在上篇文章:JVM | 基于类加载的一次完全实践 中为你讲解如何请“建筑工人”来做一些定制化的工作。但是,大型的Java应用程序时,材料(类)何止数万,我们直接堆放在工地上(JVM)…

AI Chat 设计模式:12. 享元模式

本文是该系列的第十二篇,采用问答式的方式展开,问题由我提出,答案由 Chat AI 作出,灰色背景的文字则主要是我的一些思考和补充。 问题列表 Q.1 给我介绍一下享元模式A.1Q.2 也就是说,其实共享的是对象的内部状态&…

TCP的三次握手四次挥手

TCP的三次握手和四次挥手实质就是TCP通信的连接和断开。 三次握手:为了对每次发送的数据量进行跟踪与协商,确保数据段的发送和接收同步,根据所接收到的数据量而确认数据发送、接收完毕后何时撤消联系,并建立虚连接。 四次挥手&a…