GPT建模与预测实战

代码链接见文末

效果图:

1.数据样本生成方法

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl

其中train.pkl是处理后的文件

因此,我们首先需要执行preprocess.py进行预处理操作,配置参数:

--data_path data/novel --save_path data/train.pkl --win_size 200 --step 200

其中--vocab_file是语料表,一般不用修改,--log_path是日志路径

预处理流程如下:

  • 首先,初始化tokenizer
  • 读取作文数据集目录下的所有文件,预处理后,对于每条数据,使用滑动窗口对其进行截断
  • 最后,序列化训练数据 

代码如下:

# 初始化tokenizertokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")#pip install jiebaeod_id = tokenizer.convert_tokens_to_ids("<eod>")   # 文档结束符sep_id = tokenizer.sep_token_id# 读取作文数据集目录下的所有文件train_list = []logger.info("start tokenizing data")for file in tqdm(os.listdir(args.data_path)):file = os.path.join(args.data_path, file)with open(file, "r", encoding="utf8")as reader:lines = reader.readlines()title = lines[1][3:].strip()    # 取出标题lines = lines[7:]   # 取出正文内容article = ""for line in lines:if line.strip() != "":  # 去除换行article += linetitle_ids = tokenizer.encode(title, add_special_tokens=False)article_ids = tokenizer.encode(article, add_special_tokens=False)token_ids = title_ids + [sep_id] + article_ids + [eod_id]# train_list.append(token_ids)# 对于每条数据,使用滑动窗口对其进行截断win_size = args.win_sizestep = args.stepstart_index = 0end_index = win_sizedata = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += stepwhile end_index+50 < len(token_ids):  # 剩下的数据长度,大于或等于50,才加入训练数据集data = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += step# 序列化训练数据with open(args.save_path, "wb") as f:pickle.dump(train_list, f)

2.模型训练过程

 (1) 数据与标签

        在训练过程中,我们需要根据前面的内容预测后面的内容,因此,对于每一个词的标签需要向后错一位。最终预测的是每一个位置的下一个词的token_id的概率。

(2)训练过程

        对于每一轮epoch,我们需要统计该batch的预测token的正确数与总数,并计算损失,更新梯度。

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

(3)部署与网页预测展示

        app.py既是模型预测文件,又能够在网页中展示,这需要我们下载一个依赖库:

pip install streamlit

        

生成下一个词流程,每次只根据当前位置的前context_len个token进行生成:

  • 第一步,先将输入文本截断成训练的token大小,训练时我们采用的200,截断为后200个词
  • 第二步,预测的下一个token的概率,采用温度采样和topk/topp采样

最终,我们不断的以自回归的方式不断生成预测结果

这里指定模型目录 

进入项目路径

执行streamlit run app.py 

 生成效果:

 数据与代码链接:https://pan.baidu.com/s/1XmurJn3k_VI5OR3JsFJgTQ?pwd=x3ci 
提取码:x3ci 

 

         

      

 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/306859.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式锁-redission锁的MutiLock原理

5.5 分布式锁-redission锁的MutiLock原理 为了提高redis的可用性&#xff0c;我们会搭建集群或者主从&#xff0c;现在以主从为例 此时我们去写命令&#xff0c;写在主机上&#xff0c; 主机会将数据同步给从机&#xff0c;但是假设在主机还没有来得及把数据写入到从机去的时…

【Android surface 】二:源码分析App的surface创建过程

文章目录 画布surfaceViewRoot的创建&setView分析setViewrequestLayoutViewRoot和WMS的关系 activity的UI绘制draw surfacejni层分析Surface无参构造SurfaceSessionSurfaceSession_init surface的有参构造Surface_copyFromSurface_writeToParcelSurface_readFromParcel 总结…

文心一言

文章目录 前言一、首页二、使用总结 前言 今天给大家带来百度的文心一言,它基于百度的文心大模型,是一种全新的生成式人工智能工具。 一、首页 首先要登录才能使用,左侧可以看到以前的聊天历史 3.5的目前免费用,但是4.0的就需要vip了 二、使用 首先在最下方文本框输入你想要搜…

Harmony鸿蒙南向驱动开发-SDIO接口使用

功能简介 SDIO是安全数字输入输出接口&#xff08;Secure Digital Input and Output&#xff09;的缩写&#xff0c;是从SD内存卡接口的基础上演化出来的一种外设接口。SDIO接口兼容以前的SD卡&#xff0c;并且可以连接支持SDIO接口的其他设备。 SDIO接口定义了操作SDIO的通用…

总分410+专业130+国防科技大学831信号与系统考研经验国防科大电子信息与通信工程,真题,大纲,参考书。

好几个学弟催着&#xff0c;总结一下我自己的复习经历&#xff0c;希望大家复习少走弯路&#xff0c;投入的复习正比换回分数。我专业课831信号与系统130&#xff08;感觉比估分要低&#xff0c;后面找Jenny老师讨论了自己拿不准的地方也没有错误&#xff0c;心里最近也这经常回…

记账本|基于SSM的家庭记账本小程序设计与实现(源码+数据库+文档)

家庭记账本小程序目录 基于SSM的家庭记账本小程序设计与实现 一、前言 二、系统设计 三、系统功能设计 1、小程序端&#xff1a; 2、后台 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大…

【opencv】示例-imagelist_creator.cpp 从命令行参数中创建一个图像文件列表(yaml格式)...

/* 这个程序可以创建一个命令行参数列表的yaml或xml文件列表 */ // 包含必要的OpenCV头文件 #include "opencv2/core.hpp" #include "opencv2/imgcodecs.hpp" #include "opencv2/highgui.hpp" #include <string> #include <iostream>…

视频秒播优化实践

本文字数&#xff1a;2259字 预计阅读时间&#xff1a;10分钟 视频起播时间&#xff0c;即首帧时间&#xff0c;是视频类应用的一个重要核心指标&#xff0c;也是影响用户观看体验的核心因素之一。如果视频要加载很久才能开始播放&#xff0c;用户放弃播放甚至离开 App 的概率都…

飞书api增加权限

1&#xff0c;进入飞书开发者后台&#xff1a;飞书开放平台 给应用增加权限 2&#xff0c;进入飞书管理后台 https://fw5slkpbyb3.feishu.cn/admin/appCenter/audit 审核最新发布的版本 如果还是不行&#xff0c;则需要修改数据权限&#xff0c;修改为全部成员可修改。 改完…

大数据架构之关系型数据仓库——解读大数据架构(二)

文章目录 前言什么是关系型数仓对数仓的错误认识与使用自上而下的方法关系型数仓的优点关系型数仓的缺点数据加载加载数据的频率如何确定变更数据 关系型数仓会消失吗总结 前言 本文对关系型数据仓库&#xff08;RDW&#xff09;进行了简要的介绍说明&#xff0c;包括什么是关…

课时93:流程控制_函数进阶_综合练习

1.1.3 综合练习 学习目标 这一节&#xff0c;我们从 案例解读、脚本实践、小结 三个方面来学习。 案例解读 案例需求 使用shell脚本绘制一个杨辉三角案例解读 1、每行数字左右对称&#xff0c;从1开始变大&#xff0c;然后变小为1。    2、第n行的数字个数为n个&#xf…

Bug及异常:unity场景角色移动卡墙壁的问题

场景是一个小的杠铃形状封闭空间&#xff0c;美术没有给包围盒&#xff0c;我自己用blender做了一个&#xff08;属于兴趣爱好&#xff09;&#xff0c;如下&#xff1a; 导入场景中使用meshcollider做成空气墙&#xff0c;发现角色移动到角落继续行走会卡角落处&#x…

谷歌pixel6/7pro等手机WiFi不能上网,显示网络连接受限

近期在项目中遇到一个机型出现的问题,先对项目代码进行排查,发现别的设备都能正常运行,就开始来排查机型的问题,特意写出来方便后续查看,也方便其它开发者来自查。 设备机型:Pixel 6a 设备安卓版本:13 该方法无需root,只需要电脑设备安装adb(即Android Debug Bridge…

分布式技术---------------消息队列中间件之 Kafka

目录 一、Kafka 概述 1.1为什么需要消息队列&#xff08;MQ&#xff09; 1.2使用消息队列的好处 1.2.1解耦 1.2.2可恢复性 1.2.3缓冲 1.2.4灵活性 & 峰值处理能力 1.2.5异步通信 1.3消息队列的两种模式 1.3.1点对点模式&#xff08;一对一&#xff0c;消费者主动…

出海企业如何从海外云手机中受益?

随着全球化的推进&#xff0c;越来越多的企业开始将目光投向海外市场。然而&#xff0c;不同国家和地区的网络环境、政策限制&#xff0c;以及语言文化的差异&#xff0c;给出海企业的市场拓展带来了诸多挑战。在这一背景下&#xff0c;海外云手机作为一种新兴解决方案&#xf…

MySQL如何定位慢查询?如何分析这条慢查询?

常见的慢查询 聚合查询&#xff08;常用的聚合函数有&#xff1a;MAX&#xff08;&#xff09;、MIN&#xff08;&#xff09;、COUNT&#xff08;&#xff09;、SUM&#xff08;&#xff09;、AVG&#xff08;&#xff09;&#xff09;。 多表查询 表数据过大查询 深度分页…

【Tomcat 文件读取/文件包含(CVE-2020-1938)漏洞复现】

文章目录 前言 一、漏洞名称 二、漏洞描述 三、受影响端口 四、受影响版本 五、漏洞验证 六、修复建议 前言 近日在做漏扫时发现提示服务器存在CVE-2020-1938漏洞&#xff0c;故文章记录一下相关内容。 一、漏洞名称 Tomcat 文件读取/文件包含漏洞(CVE-2020-1938) 二、漏洞描…

【SpringBoot】mybatis-plus实现增删改查

mapper继承BaseMapper service 继承ServiceImpl 使用方法新增 save,updateById新增和修改方法返回boolean值,或者使用saveOrUpdate方法有id执行修改操作,没有id 执行新增操作 案例 Service public class UserService extends ServiceImpl<UserMapper,User> {// Au…

送礼物动态特效直播和短视频特效源码送礼物动态连麦PK特效语音视频聊天室朋友圈十套

内附十套动画效果源码&#xff0c;可F5刷新随机显示特效预览。送礼物动态特效直播和短视频特效源码送礼物动态连麦PK特效语音视频聊天室朋友圈十套 SVGA 是一种用于嵌入式动画的矢量文件格式&#xff0c;通常用于在移动应用程序和网页中展示高质量的动画效果。相对于传统的 GIF…

办公软件巨头CCED、WPS迎来新挑战,新款办公软件已形成普及之势

办公软件巨头CCED、WPS的成长经历 CCED与WPS&#xff0c;这两者均是中国办公软件行业的佼佼者&#xff0c;为人们所熟知。 然而&#xff0c;它们的成功并非一蹴而就&#xff0c;而是经过了长时间的积累与沉淀。 CCED&#xff0c;这款中国大陆早期的文本编辑器&#xff0c;在上…