Pytorch封装简单RNN模型,进行中文训练及文本预测

简述

使用pytorch封装简单RNN模型,使用单层nn.RNNnn.Linear等实现,然后做简单的文本预测。

数据集

代码参考李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html,但他使用的是一篇英文小说,
这里改为使用COIG-CQIA中文数据集中的:douban_book_introduce.jsonlruozhiba_ruozhiba_ruozhiba.jsonl两个文件,本文目的是为了学习rnn,所以数据集比较简单,不过这个数据集由于都是问答形式,不像小说那样有主题性,所以感觉学习效果不好。理想的应该还是找个中文长篇小说之类。

COIG-CQIA: https://huggingface.co/datasets/m-a-p/COIG-CQIA

另外由于COIG-CQIA的数据是指令问答形式的json文件,所以这里稍作处理,改为单个问题+答案为一行的纯文本txt格式, 去除其它json字段及各种符号。

代码如下:

def jsonl_to_txt(dir_path):  dict_list = []  jsonl_list = os.listdir(dir_path)  qa_list = list()  chars_to_remove = r'[,。?;、:“”:!~()『』「」【】\"\[\]➕〈〉/<>()‰\%《》\*\?\-\.…·○01234567890123456789•\n\t abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ—*]'  for jsonl in jsonl_list:  path = os.path.join(dir_path, jsonl)  print(path)  with open(path, 'r', encoding='utf-8') as f:  jsonl_data = f.readlines()  for line in jsonl_data:  line_dict = JSON.loads(line)  qa = line_dict['instruction'] + line_dict['output']  qa = re.sub(chars_to_remove, '', qa).strip()  qa_list.append(qa)  path = os.path.join(dir_path, 'chengyu_qa.txt')  with open(path, 'w', encoding='utf-8') as f:  f.write('\n'.join(qa_list))  if __name__ == '__main__':  dir_path = '../data/COIG-CQIA'  jsonl_to_txt(dir_path)  print()

上面处理完毕后,还需要进行词元化、构建词典等步骤,参考:
python实现简单中文词元化、词典构造、时序数据集封装等-CSDN博客

模型封装

RNN — PyTorch 2.4 documentation

可以先观察一下tensorboard的add_graph函数对模型可视化后的结构:

在这里插入图片描述

这里使用单层的RNN(nn.RNN有默认参数num_layers=1),nn.functional.one_hot是为了实现单词的向量化表示,后续可以优化成nn.Embedding来做词向量。

nn.functional.one_hot前将x进行了转置,这里有点抽象,来关注一下nn.RNN的参数要求,便可理解。

先看x的初始shape为(batch_size, time_size),转置并向量化后为(time_size, batch_size, vocab_size)

若不转置直接向量化,则为(batch_size, time_size, vocab_size),实际上这两种格式的数据nn.RNN都支持。

但若为(batch_size, time_size, vocab_size)形式,则需在创建nn.RNN实例时指定参数batch_first=False。

在这里插入图片描述

另外,还需要提供一个初始的隐状态,这里用init_state函数实现。

在这里插入图片描述

class SimpleRNNModel(nn.Module):  def __init__(self, vocab_size, hidden_size):  super(SimpleRNNModel, self).__init__()  self.vocab_size = vocab_size  self.hidden_size = hidden_size  self.rnn = nn.RNN(vocab_size, hidden_size)  self.linear = nn.Linear(hidden_size, vocab_size)  def forward(self, x, hidden_state=None):  x = nn.functional.one_hot(x.T.long(), num_classes=self.vocab_size)  x = x.to(torch.float32)  outputs, hidden_state = self.rnn(x, hidden_state)  # rrn的outputs.shape(N, L, D*H)  outputs = outputs.reshape(-1, self.hidden_size)  outputs = self.linear(outputs)  return outputs, hidden_state  def init_state(self, device, batch_size=1):  return torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device)  

梯度裁剪

源自李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html

def grad_clipping(net, max_norm):  if isinstance(net, nn.Module):  params = [p for p in net.parameters() if p.requires_grad]  else:  params = net.params  norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))  if norm > max_norm:  for param in params:  param.grad[:] *= max_norm / norm

模型训练

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
print(f'\ndevice: {device}')  corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  vocab_size = len(vocab)  
hidden_size = 256  
epochs = 5  
batch_size = 50  
learning_rate = 0.01  
time_size = 4  
max_grad_max_norm = 0.5  dataset = make_dataset(corpus=corpus, time_size=time_size)  
data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  net = SimpleRNNModel(vocab_size, hidden_size)  
net.to(device)  # print(net.state_dict())  criterion = nn.CrossEntropyLoss()  
criterion.to(device)  
optimizer = optim.Adam(net.parameters(), lr=learning_rate)  writer = SummaryWriter('./train_logs')  
# 随便定义个输入, 好使用add_graph  
tmp = torch.rand((batch_size, time_size)).to(device)  
writer.add_graph(net, tmp)  loss_counter = 0  
total_loss = 0  
ppl_list = list()  
total_train_step = 0  for epoch in range(epochs):  print('------------Epoch {}/{}'.format(epoch + 1, epochs))  for X, y in data_loader:  X, y = X.to(device), y.to(device)  # 如果各个批次间的时序是连续的,则可以把上次的hidden_state传入下个批次, 不然就要重置hidden_state  # 这里batch_size=X.shape[0]是因为在加载数据时, DataLoader没有设置丢弃不完整的批次, 所以存在实际批次不满足设定的batch_size  hidden_state = net.init_state(batch_size=X.shape[0], device=device)  outputs, hidden_state = net(X, hidden_state=hidden_state)  optimizer.zero_grad()  # y也变成 时间序列*批次大小的行数, 才和 outputs 一致  y = y.T.reshape(-1)  # 交叉熵的第二个参数需要LongTorch  loss = criterion(outputs, y.long())  loss.backward()  # 求完梯度之后可以考虑梯度裁剪, 再更新梯度  grad_clipping(net, max_grad_max_norm)  optimizer.step()  total_loss += loss.item()  loss_counter += 1  total_train_step += 1  if total_train_step % 10 == 0:  print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  writer.add_scalar('train_loss', loss.item(), total_train_step)  ppl = np.exp(total_loss / loss_counter)  ppl_list.append(ppl)  print(f'Epoch {epoch + 1} 结束, batch_loss_average: {total_loss / loss_counter}, perplexity: {ppl}')  writer.add_scalar('ppl', ppl, epoch + 1)  total_loss = 0  loss_counter = 0  torch.save(net.state_dict(), './save/epoch_{}_ppl_{}.pth'.format(epoch + 1, ppl))  writer.close()

tensorboard训练过程观察

横轴为训练epoch。

在这里插入图片描述

横轴为训练次数。

在这里插入图片描述

文本预测

这里首先完善模型的预测函数(该函数放到模型中):

def predict(self, prefix, num_preds, vocab, device):  state = self.init_state(batch_size=1, device=device)  # prefix为字符, 转成索引  outputs = [vocab.word2idx(prefix[0])]  get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))  # 一个字符一个字符跑一遍, 对用户输入进行预热, 即对输入的各个字符间建立联系  for y in prefix[1:]:  # 预热期  _, state = self.forward(get_input(), state)  outputs.append(vocab.word2idx(y))  # 刚好每次都用上一次的预测值做输入  for _ in range(num_preds):  # 预测num_preds步  y, state = self.forward(get_input(), state)  outputs.append(int(y.argmax(dim=1).reshape(1)))  return ''.join([vocab.idx2word(i) for i in outputs])

实现对提示词处理及预测函数的调用:

注意:这里的语料库应和训练使用的一致。

def predict(state_dict_path, vocab, prefix=None, num_preds=3):  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  vocab_size = len(vocab)  hidden_size = 256  net = SimpleRNNModel(vocab_size, hidden_size).to(device)  net.load_state_dict(torch.load(state_dict_path, map_location=device, weights_only=True))  net.eval()  with torch.no_grad():  outputs = net.predict(prefix=prefix, num_preds=num_preds, vocab=vocab, device=device)  return outputs  if __name__ == '__main__':  corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  # corpus, vocab = load_corpus("../data/COIG-CQIA/chengyu_qa.txt")  # print(len(vocab))  # idx = [vocab.word2idx(ch) for ch in prefix]  path = "../save/Simple/新建文件夹/state_dict-time_size_30-ppl_1.pth"  prefix = "有什么超赞的诗句"  print(f'提示词: {prefix}')  outputs = predict(path, vocab, prefix=prefix, num_preds=22)  print(f'预测输出: {outputs}\n')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411732.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通配符证书的简介和申请方法

通配符证书是一种SSL证书&#xff0c;它利用域名字段中的通配符&#xff08;*&#xff09;来指示&#xff0c;允许用户在一个证书中关联多个顶级域名及其子域&#xff0c;从而简化证书管理流程&#xff0c;节省成本和时间。以下是通配符证书的简介和申请方法的详细说明&#xf…

Springsecurity中的Eureka报错:Cannot execute request on any known server

完整报错信息&#xff1a; com.netflix.discovery.shared.transport.TransportException: Cannot execute request on any known server 报错体现&#xff1a; 访问eureka控制面板&#xff1a; 访问测试地址&#xff1a; 控制台报错&#xff1a; 可能的报错原因&#xff…

ZW3D二次开发_UI_ZsCc::OptionRadios控件回调

1.ZW3D中的OptionRadios控件如何实现点击触发回调并且获取点击后的值&#xff1f;如下图 2.教程如下&#xff1a; 1&#xff09;添加OptionRadios控件到表单中 2&#xff09;增加radio按钮 3&#xff09;添加回调 4&#xff09;编写回调函数 int radioCallbackDemo(char* for…

【信息安全】基于CBC的3DES加解密-实验报告

实验运行效果截图 3DES进行加密 3DES进行解密 然后可以选择你想要的操作,继续加密解密或者退出。 基于CBC模式的3DES加解密 一、实验内容 基于3DES加解密算法,编程实现对任意文件实现加解密的软件。 编程实现DES加密和解密算法,并使用DES加解密算法实现3DES加解密算法。选…

Android活动(activity)与服务(service)进行通信

文章目录 Android活动&#xff08;activity&#xff09;与服务&#xff08;service&#xff09;进行通信活动与服务进行通信服务的生命周期 Android活动&#xff08;activity&#xff09;与服务&#xff08;service&#xff09;进行通信 活动与服务进行通信 上一小节中我们学…

基于FPGA的SD卡的数据读写实现(SD NAND FLASH)

文章目录 目录 1、存储芯片分类 2、NOR Flash 与 NAND Flash的区别 3、什么是SD卡&#xff1f; 4、什么是SD NAND&#xff1f; 5、SD NAND的控制时序 6、FPGA实现SD NAND读写 1、存储芯片分类 目前市面上的存储芯片&#xff0c;大致可以将其分为3大类&#xff1a; ① …

【回眸】QAC软件指南——错误分析篇(完整版)

前言 近期需要再次测一下代码&#xff0c;相比以前测试更有经验&#xff0c;也做了比较多的记录&#xff0c;正好将经验通过博客保留下来&#xff0c;为以后可能的QAC测试做准备。 安装导入分析代码 这部分在上一篇中已经详细介绍&#xff0c;具体请见&#xff0c;如有疑问可…

百元蓝牙耳机什么牌子的好?四大宝藏机型真实推荐,快速收藏!

作为一位蓝牙耳机爱好者&#xff0c;无论是上班、娱乐、学习我都离不开蓝牙耳机。通勤时候能听听音乐&#xff0c;是最好的享受&#xff0c;可以让我更加放松&#xff0c;尽情享受音乐带来的乐趣。但市面上的大多数蓝牙耳机都是货不对板的&#xff0c;不是音质一般、就是续航时…

谷歌发布 3 款 Gemini 新模型;字节开源 FLUX Dev Hyper SD Lora,8 步生图丨 RTE 开发者日报

开发者朋友们大家好&#xff1a; 这里是 「RTE 开发者日报」 &#xff0c;每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE&#xff08;Real-Time Engagement&#xff09; 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

Seata执行原理分析-AT、XA、TCC、SAGA比较

分布式事务简介 1.1 本地事务 大多数场景下&#xff0c;我们的应用都只需要操作单一的数据库&#xff0c;这种情况下的事务称之为本地事务(Local Transaction)。本地事务的ACID特性是数据库直接提供支持。本地事务应用架构如下所示&#xff1a; 在JDBC编程中&#xff0c;我们…

priority_queue模拟

一、什么是priority_queue? priority_queue是C标准库中的一个容器适配器&#xff0c;用于实现优先队列&#xff08;priority queue&#xff09;的数据结构。优先队列是一种特殊的队列&#xff0c;其中的元素按照一定的优先级进行排序&#xff0c;每次取出的元素都是优先级最高…

从零开始掌握Vue实例

从零开始掌握Vue实例&#xff1a;深入理解数据绑定与生命周期的核心秘诀 引言 简要介绍主题&#xff1a; 在学习Vue.js的过程中&#xff0c;Vue实例是最基础也是最关键的部分。Vue实例是Vue应用的核心&#xff0c;它是数据、DOM元素和Vue组件之间的桥梁。掌握Vue实例的使用对于…

文件上传面板中限制需要的文件格式,并且隐藏“所有文件”选项

直接说需求&#xff1a;需要实现在文件上传面板中限制需要的文件格式&#xff0c;并且不想展示“所有文件”这个选项&#xff0c;应该怎么做嘞&#xff1f;效果如下图&#xff1a; 这里用到了 window.showOpenFilePicker 方法实现&#xff0c;首先定义接受的格式及限制&#xf…

格行“信号增强技术”引领行业创新,格行随身WiFi带你感受不一样的速度与激情,行业第一的随身WiFi并非浪得虚名!

近年来&#xff0c;随着市场保有量的不断提升与相关技术的不断扩展&#xff0c;我国随身WiFi市场已经到了发展质量更高的“2.0”阶段&#xff0c;消费者对随身WiFi的需求变得多元且“高级”。与之对应的供给端&#xff0c;品牌之间的竞争也从未停止&#xff0c;有的品牌选择卷价…

如何使用ssm实现实验室仪器设备管理系统

TOC ssm354实验室仪器设备管理系统jsp 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大&#xff0c;随着当前时代的信息化&#xff0c;科学化发展&#xff0c;让社会各行业领域都争相使用新的信息技术&#xff0c;对行业内的各种相关数据进行科学化&#xff0c;规范化…

快来尝尝,食家巷荞面甜甜圈超赞

当荞面与甜甜圈相遇&#xff0c;便诞生了食家巷荞面甜甜圈&#xff0c;一种独具特色的美食体验。 食家巷荞面甜甜圈&#xff0c;外形圆润可爱&#xff0c;色泽金黄诱人。那精致的环状造型&#xff0c;仿佛是一个小小的魔法圈&#xff0c;散发着迷人的魅力。 与传统甜甜圈…

计算机网络面试真题总结(七)

文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 什么是对称加密、非对称加密&#xff1f; 对称加密是一种常用的加…

探索AI智能问答:改变未来交流的新动力

人工智能(AI)是当今科技领域中最具潜力和影响力的技术之一&#xff0c;AI智能问答系统更是这一领域中的一颗璀璨明珠。随着大数据和机器学习的发展&#xff0c;AI智能问答系统已经不仅仅是科幻小说中的幻想&#xff0c;而是正逐步融入我们的日常生活&#xff0c;从客户服务到教…

生成式AI扩散模型-Diffusion Model【李宏毅2023】概念讲解、原理剖析笔记

目录 一、Diffusion的基本概念和运作方法 1.Diffusion Model是如何运作的&#xff1f; 2.Denoise模块内部正在做的事情 如何训练Noise predictor&#xff1f; 1&#xff09;Forward Process (Diffusion Process) 2&#xff09;noise predictor 3.Text-to-Image 4.两个A…

入门Java第一步—>IDEA的下载与安装与JDK的环境配置(day01)

1.JDK的下载与安装 jdk的安装链接分为不同操作系统如下,点击链接跳转下载页面&#xff1a; windows操作系统JDK下载链接(按住键盘ctrl键单击链接即可)&#xff1a; 链接7天有效&#xff0c;有需要的评论区找我哈 通过网盘分享的文件&#xff1a;jdk-8u271-windows-x64.exe 链…