ChatGLM-6B模型训练自己的数据集

ChatGLM-6B模型训练自己的数据集

上期我主要分享了一下ChatGLM-6B官方模型的部署、官方数据集的微调、推理以及测试过程,这期我将主要分享一下使用ChatGLM-6B微调自己数据集的过程。上期链接

1.首先将自己处理好的数据集拷贝到’ChatGLM-6B/ptuning/’文件夹下,可以新建一个自己的数据集文件夹如mydata。

我新建的文件夹mydata

2.首先要修改train.sh中的参数,官方train.sh文档:

PRE_SEQ_LEN=128
LR=2e-2CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file AdvertiseGen/train.json \--validation_file AdvertiseGen/dev.json \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path THUDM/chatglm-6b \--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_train_batch_size 1 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 16 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4

参数解释如下(自己的理解,不正确欢迎指正):

1. PRE_SEQ_LEN:这个是输入的最大序列长度,可以根据你的数据集适当调大或调小,一般64-128是一个合理范围。
2. LR:学习率,可以适当调小,如1e-3 - 5e-3,因为这是微调,学习率不需要太大。
3. CUDA_VISIBLE_DEVICES:设置使用的GPU设备。
4. --train_file 和 --validation_file:设置你自己的数据集路径。
5. --model_name_or_path:设置为 chatglm-6b 的模型路径。
6. --output_dir:设置微调后的模型和日志输出路径。
7.- max_source_length 指定我们输入对话文本(即软提示)的最大长度。如果某个输入文本超过这个长度,则会截断;如果短于这个长度,则会在末尾填充padding。
- max_target_length 指定模型输出的响应文本的最大长度。如果模型生成的响应超过这个长度,则会截断;如果短于这个长度,则会在末尾填充padding。
-所以,这两个参数的设置值需要根据你的硬件情况和数据特点来确定。一般来说,64-128对于输入,32-64对于输出是比较合理的范围。你可以在开发集上进行测试,观察模型输出的响应结果和计算开销来选择最优值。
设置这两个参数的目的是:
(1) 使所有训练实例输入和输出的长度统一,以方便模型的训练和batch的构造。
(2) 避免过长的输入和输出导致的计算开销大和内存超限的问题。
(3) 鼓励模型学习如何在约束长度内生成更精准和相关的响应。
7. --gradient_accumulation_steps:设置梯度累积步数,可以适当增大,如8-32,以便使用更大的batch size。
8. --per_device_train_batch_size 和 --per_device_eval_batch_size:可以适当增大,以加快训练速度,如4-8。
9. --max_steps:设置最大训练步数,根据数据集大小适当调整,一般3000-10000步是一个合理范围。
11. --logging_steps 和 --save_steps:设置日志记录步数和模型保存步数,可以根据需要自己更改。
12.--quantization_bit:模型参数的精度选择。
这个参数是关于 P-Tuning 方法中的模型量化(model quantization)方法的选择。
P-Tuning 方法中的量化是指将模型参数从浮点数(FP32)精度量化为低精度,如FP16(半精度)或更低。这可以大大减小模型的参数量,加速推理速度,同时牺牲一定的精度。
在这句话中,作者指出:
1. P-Tuning 方法会冻结(freeze)预训练模型(如GPT2)的全部参数,以此作为量化的起点。
2. 通过调整quantization_bit 参数可以选择不同的量化级别。不设置此参数则默认使用FP16半精度。
3. 调整量化级别会影响模型精度,需要在验证集上测试不同的量化级别,选择精度损失最小的配置。
也就是说,P-Tuning 方法采用冻结预训练模型,然后对其进行量化和微调的策略。设置quantization_bit参数可以选择FP16之外的更低精度,如8bit或4bit,以获得更高的加速比,同时尽可能保留精度。
量化会使模型参数变得更加稀疏,进而可以大幅压缩模型体积和加速推理。但同时也会损失一定精度。所以需要根据实际情况选择一个合理的trade-off。
总的来说,这句话的意思就是说明P-Tuning方法采用了模型冻结和可调量化级别来实现从FP32到低精度的转换,以获得较高的加速比。quantization_bit参数可以选择不同的量化精度,无此参数则默认为FP16。

3.我的train.sh参数设置

PRE_SEQ_LEN=128
LR=2e-3 #因为是微调所以我考虑减小了学习率CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file mydata/train.json \  #我的训练数据集地址--validation_file mydata/test.json \  #我的测试数据集地址--prompt_column question\  #我的数据集标签--response_column answer\  #我的数据集标签--overwrite_cache \--model_name_or_path THUDM/chatglm-6b \--output_dir myoutput/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \   #输入句子的最大目标长度--max_target_length 512\   #输出句子的最大目标长度--per_device_train_batch_size 4 \  #由于我的内存比较大,我增大了batch_size--per_device_eval_batch_size 1 \--gradient_accumulation_steps 8 \   --predict_with_generate \--max_steps 3000 \--logging_steps 50 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \

这里我去掉了 --quantization_bit 4,考虑到我有足够的内存,我采用了精度较高的FP16。

4.推理

推理参数调整参考微调。

5.利用微调后的模型进行验证

参考上期

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/44437.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

刚拿到北京户口就离职,员工赔了180000!

推荐专门分享AI技术的公众号 关注后,回复:ChatGPT ,领取账号 公众号“互联网坊间八卦” 之所以写这个话题,是因为今天又看到了一个关于北京落户的案例。 前不久,北京市政府发布工作报告。数据中提到,2022年…

李彦宏宣布设立10亿创投基金促进大模型生态发展;Kindle中国电子书店停止运营;Bootstrap 5.3发布|极客头条...

「极客头条」—— 技术人员的新闻圈! CSDN 的读者朋友们早上好哇,「极客头条」来啦,快来看今天都有哪些值得我们技术人关注的重要新闻吧。 整理 | 梦依丹 出品 | CSDN(ID:CSDNnews) 一分钟速览新闻点&#…

汉王考勤管理系统 与服务器连接失败,汉王考勤管理系统

汉王考勤管理系统是一款功能强大的考勤管理软件,软件为用户提供了基本信息管理、人员排班管理、考勤处理与统计等多个不同模块,能够帮助用户对企业的的考勤进行统计与管理,而且能够支持一键生成各类报表,并支持以Excel等多种格式导出考勤信息,能够极大的提示用户的报表统计…

时间序列预测之DeepAR

目录 前言 一、模型介绍 1、模型框架介绍 2、训练策略 3、似然函数模型 4、损失函数 二、论文精华 1.尺度处理 三、仿真实验 1、数据集介绍 2、评价指标 2.1 评价指标1(分布式评估) 2.2 评价指标2(点预测评估) 2.3 定性分析 总结 前言 最近看论文《DeepAR:Probabil…

基于Prophet时间序列的监测值预测

留全部代码备份 通过facebook开源模型Prophet对未来时间内某基坑变形监控值进行预测,但该模型好像并不适用于这种施工过程中的数据预测,但是至少能预测,交差总没问题吧。预测10天。 import pandas as pd from matplotlib import pyplot as …

facebook时间序列预测算法prophet解读+实战

facebook时间序列预测算法prophet解读实战 原理解读一、时间序列的分解二、趋势项模型基于逻辑回归的趋势项定义变点(change point) 基于线性回归的趋势项变点的选择 三、季节性趋势四、节假日影响 模型实战 原理解读 prophet与常用的自回归时间序列预测…

Kaggle系列之预测泰坦尼克号人员的幸存与死亡(随机森林模型)

Kaggle是开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,本节是对于初次接触的伙伴们一个快速了解和参与比赛的例子,快速熟悉这个平台。当然提交预测结果需要注册,这个可能需要科学上网了。 我们选择一个预测的入…

【时间序列预测】人口数量预测神经网络程序

下载完整代码 clc;clear; %导入1949年至2010年人口数据 dataimportdata(population_data.txt); lag3; %利用前3年数据做为输入,去预测下一年人口数量 nlength(data); %计算数据长度 %% %准备输入和输出数据 inputszeros(lag,n-lag); for i1:n-lag inpu…

Prophet:一种大规模时间序列预测模型

前言 Prophet是由facebook开发的开源时间序列预测程序,擅长处理具有季节性特征大规模商业时间序列数据。本文主要介绍了Prophet模型的设计原理,并与经典的时间序列模型ARIMA进行了对比。 1. Prophet模型原理 Prophet模型把一个时间序列看做由3种主要成分…

时间序列预测算法梳理(Arima、Prophet、Nbeats、NbeatsX、Informer)

时间序列预测算法梳理(Arima、Prophet、Nbeats、NbeatsX、Informer) Arima1. 算法原理2. 算法实现 Prophet1. 优点2. 算法实现3.算法api实现(fbprophet调api) Nbeats1. Nbeats优点2. Nbeats模型结构 NbeatsXInformer参考&#xff…

Prophet 时间序列预测

Prophet 允许使用具有指定承载能力的物流增长趋势模型进行预测。 我们必须在列中指定承载能力cap。在这里,我们将假设一个特定的值,但这通常是使用有关市场规模的数据或专业知识来设置的。 # Python df[cap] 8.5需要注意的重要事项是cap必须为数据框中…

时间序列预测方法之 DeepAR

本文链接:个人站 | 简书 | CSDN 版权声明:除特别声明外,本博客文章均采用 BY-NC-SA 许可协议。转载请注明出处。 最近打算分享一些基于深度学习的时间序列预测方法。这是第一篇。 DeepAR 是 Amazon 于 2017 年提出的基于深度学习的时间序列预…

【时间序列】初识时间序列预测神器 NeuralProphet 实战预测股票指数

历经神奇的2022年,终于迎来曙光的2023年,新的一年,MyEncyclopedia 会和小伙伴们一同学习思考实践。长风破浪会有时,直挂云帆济沧海!共勉之 NeuralProphet深度学习Prophet NeuralProphet 负有盛名,是 Facebo…

时序预测 | Python实现TCN时间卷积神经网络时间序列预测

时序预测 | Python实现TCN时间卷积神经网络时间序列预测 目录 时序预测 | Python实现TCN时间卷积神经网络时间序列预测预测效果基本介绍环境准备模型描述程序设计学习小结参考资料预测效果 基本介绍 递归神经网络 (RNN),尤其是 LSTM,非常适合时间序列处理。 作为研究相关技术…

使用sklearn.ensemble.RandomForestRegressor和GridSearchCV进行成人死亡率预测

原文链接:https://blog.csdn.net/weixin_44491423/article/details/127011461 本文借鉴博主hhhcbw实现方法完成随机森林回归预测成人死亡率,使用训练数据测试模型的最优得分R20.8161,在测试集上得分R20.5825 成年人死亡率指的是每一千人中15岁…

基于TCN时间序列预测Python程序

基于TCN预测模型 特色:1、单变量,多变量输入,自由切换 2、单步预测,多步预测,自动切换 3、基于Pytorch架构 4、多个评估指标(MAE,MSE,R2,MAPE等) 5、数据从excel文件中读取,更换简单…

时序预测 | Python实现Attention-TCN注意力机制时间卷积神经网络的多元时间序列预测

时序预测 | Python实现Attention-TCN注意力机制时间卷积神经网络的多元时间序列预测 目录 时序预测 | Python实现Attention-TCN注意力机制时间卷积神经网络的多元时间序列预测预测效果基本介绍环境配置程序设计模型效果参考资料预测效果 基本介绍 使用时间注意卷积神经网络进行…

AI预测死亡时间,准确率95%

(本内容转载自公众号“科技与Python”) 日前,谷歌新出炉的一项研究报告称,该公司已开发出一种新人工智能(AI)算法,可预测人的死亡时间,且准确率高达95%。最近,谷歌的这项研究发表在了《自然》杂…

FACEBOOK 时间序列预测算法 PROPHET 的研究

1.思想 在时间序列分析领域,有一种常见的分析方法叫做时间序列的分解(Decomposition of Time Series),它把时间序列 分成几个部分,分别是季节项 ,趋势项 ,剩余项 。也就是说对所有的 &#xff…