【NLP练习】Transformer实战-单词预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

参考链接:
LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT

任务:自定义输入一段英文文本进行预测

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn,Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math, os, torchclass TransformerModel(nn.Module):def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):super().__init__()self.pos_encoder = PositionalEncoding(d_model, dropout)#定义编码器层encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)#定义编码器,pytorch将Transformer编码器进行了打包,这里直接调用即可self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.embedding = nn.Embedding(ntoken,d_model)self.d_model = d_modelself.linear = nn.Linear(d_model, ntoken)self.init_weights()#初始化权重def init_weights(self) -> None:initrange = 0.1self.embedding.weight.data.uniform_(-initrange, initrange)self.linear.bias.data.zeros_()self.linear.weight.data.uniform_(-initrange, initrange)def forward(self, src:Tensor, src_mask: Tensor = None) -> Tensor:"""Arguments:src:      Tensor, 形状为[seq_len, batch_size]src_mask: Tensor, 形状为[seq_len, seq_len]Returns:输出的Tensor,形状为[seq_len, batch_size, ntoken]"""src = self.embedding(src) * math.sqrt(self.d_model)src = self.pos_encoder(src)output = self.transformer_encoder(src, src_mask)output = self.linear(output)return output
class PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(p = dropout)#生成位置编码的位置张量position = torch.arange(max_len).unsqueeze(1)#计算位置编码的除数项div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))#创建位置编码张量pe = torch.zeros(max_len, 1, d_model)#使用正弦函数计算位置编码中的基数维度部分pe[:, 0, 1::2] = torch.sin(position * div_term)#使用余弦函数计算位置编码中的偶数维度部分pe[:, 0, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x: Tensor) -> Tensor:"""Arguments:x:      Tensor, 形状为[seq_len, batch_size, embedding_dim]"""#将位置编码添加到输入张量x = x + self.pe[:x.size(0)]#应用dropoutreturn self.dropout(x)

二、加载数据集

本实验使用torchtext生成Wikitext-2数据集。在此之前,你需要安装下面的包:

  • pip install portalocker
  • pip install torchdata
import torchtext
from torchtext.datasets.wikitext2 import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator#从torchtext库中导入WikiTetx2数据集
train_iter = WikiText2(split = 'train')#获取基本的英语分词器
tokenizer = get_tokenizer('basic_english')
#通过迭代器构建词汇表
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
#将默认索引设置为'<unk>'
vocab.set_default_index(vocab['<unk>'])def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:"""将原始文本转换为扁平的张量"""data = [torch.tensor(vocab(tokenizer(item)),dtype = torch.long) for item in raw_text_iter]return torch.cat(tuple(filer(lambda t: t.numel() > 0, data)))#由于构建词汇表时"train_iter"被使用了,所以需要重新创建
train_iter, val_iter, test_iter = WikiText2()#队训练、验证和测试数据进行处理
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)#检查是否有可用的CUDA设备,将设备设置为GPU或者CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def batchify(data: Tensor, bsz: int) -> Tensor:"""将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素。参数:data: Tensor,形状为``[N]``bsz:int,批大小返回:形状为[N // bsz, bsz]的张量"""seq_len = data.size(0) // bszdata = data[:seq_len * bsz]data = data.view(bsz, seq_len).t().contiguous()return data.to(device)#设置批大小和评估批大小
batch_size = 20
eval_batch_size = 10
#将训练、验证和测试数据进行批处理
train_data = batchify(train_data, batch_size)   #形状为[seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35#获取批次数据
def get_batch(source:Tensor, i: int) -> Tuple[Tensor, Tensor]:"""参数:source: Tensor,形状为``[full_seq_len, batch_size]``i : int, 当前批次索引返回:tuple(data, target),-data形状为[seq_len, batch_size],-target形状为[seq_len * batch_size]"""#计算当前批次的序列长度,最大为bptt,确保不超过source的长度seq_len = min(bptt, len(source) - 1 - i)#获取data,从i开始,长度为seq_lendata = source[i:i+seq_len]#获取target,从i+1开始,长度为seq_len,并将其形状转换为一维张量target = source[i+1:i+1+seq_len].reshape(-1)return data, target

三、初始化实例

ntokend = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
#创建transformer模型
model = TransformerModel(ntokend,emsize,nhead,d_hid,nlayers,dropout).to(device)

四、训练模型

结合使用CrossEntropyLoss与SGD(随机梯度下降优化器)。训练期间,使用torch.nn.utils.clip_grad_norm_来防止梯度爆炸

import time
criterion = nn.CrossEntropyLoss() #定义交叉熵损失函数
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(),lr = lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gama = 0.95)def train(model: nn.Module) -> None:model.train() #开启训练模式total_loss = 0.log_interval = 200 #start_time = time.time()num_batches = len(train_data) // bpttfor batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)output = model(data)output_flat = output.view(-1, ntokens)loss = criterion(output_flat, targets) #计算损失optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()total_loss += loss.item()if batch % log_interval == 0 and batch > 0:lr = scheduler.get_last_lr()[0]ms_per_batch = (time.time() - start_time) * 1000 / log_intervalcur_loss = total_loss / log_intervalppl = math.exp(cur_loss)print(f'| epoch{epoch:3d} | {batch:5d} / {num_batches:5d} batches |'f'lr{lr:02.2f} | ms/batch {ms_per_batch:5.2f} |'f'loss {cur_loss:5.2f}|ppl{ppl:8.2f}')total_loss = 0start_time = time.time()def evaluate(model:nn.Module, eval_data:Tensor) -> float:model.eval()total_loss = 0.with torch.no_grad():for i in range(0,eval_data.size(0) - 1, bptt):data, targets = get_batch(eval_data,i)seq_len = data.size(0)output = model(data)output_flat = output.view(-1,ntokens)total_loss += seq_len * criterion(output_flat, targets).item()return total_loss / (len(eval_data) - 1)
best_val_loss = float('inf')
epochs = 1with TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, "best_model_params.pt")for epoch in range(1, epochs + 1):epoch_start_time = time.time()train(model)val_loss = evaluate(model, val_data)val_ppl = math_exp(val_loss)elapsed = time.time() - epoch_start_time#打印当前epoch的信息,包括耗时、验证损失和困惑度print('-' * 89)print(f'|end of epoch {epoch:3d} | time:{elapsed: 5.2f}s |'f'valid loss {val_loss:5.2f} | valid ppl {val_ppl: 8.2f}')print('-' * 89)if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), best_model_params_path)scheduler.step()    #更新学习率model.load_state_dict(torch.load(best_model_params_path))

代码输出:
在这里插入图片描述

五、总结

加载数据集时,注意包的版本关联关系。另外,注意结合使用优化器提升优化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/359937.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多维表格/业务库表格大数据量性能瓶颈

先说最终结论&#xff1a;Angular 组件创建性能损耗是当下主要的性能瓶颈 理由&#xff1a; 基于以往编辑器性能优化的经验&#xff0c;编辑器在动态渲染内容时会创建很多壳子组件&#xff08;也就是Angular 组件&#xff09;&#xff0c;排查的时候就发现如果略这些壳子组件性…

Day9 —— 大数据技术之ZooKeeper

ZooKeeper快速入门系列 ZooKeeper的概述什么是ZooKeeper&#xff1f;ZooKeeper的特点和功能使用ZooKeeper的原因 ZooKeeper数据模型ZooKeeper安装ZooKeeper配置ZooKeeper命令行操作常见服务端命令 ZooKeeper的概述 什么是ZooKeeper&#xff1f; ZooKeeper是一个开源的分布式协…

ZOOM太卡怎么办?公司如何解决ZOOM会议卡顿?

ZOOM作为一种常见的办公工具&#xff0c;尤其在跨国公司和外资企业中&#xff0c;在线会议非常普遍。然而&#xff0c;由于ZOOM的服务器部署在国外&#xff0c;国内用户使用时可能会遇到卡顿、不稳定和声音断续等问题。那么&#xff0c;如何有效解决ZOOM卡顿的问题呢&#xff1…

数据库管理-第212期 上期SQL性能优化勘误与扩展(20240624)

数据库管理212期 2024-06-24 数据库管理-第212期 上期SQL性能优化勘误与扩展&#xff08;20240624&#xff09;1 环境2 方案1问题3 问题引申总结 数据库管理-第212期 上期SQL性能优化勘误与扩展&#xff08;20240624&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海…

OS中断机制-外部中断触发

中断函数都定义在中断向量表中,外部中断通过中断跳转指令触发中断向量表中的中断服务函数,中断指令可以理解为由某个中断寄存器的状态切换触发的汇编指令,这个汇编指令就是中断跳转指令外部中断通过在初始化的时候使能对应的中断服务函数如何判断外部中断被触发的条件根据Da…

【机器学习项目实战(二)】基于朴素贝叶斯的中文垃圾短信分类

完整代码、数据集和相应的报告 链接已经放在了正文最下方, 供大家参考学习 摘要 ​ 本文探讨了中文垃圾短信分类的问题,通过收集实际数据集,运用多种机器学习算法进行分类,并对比了不同算法在垃圾短信分类任务上的性能。本研究旨在提高中文垃圾短信的识别准确率,为构建更…

《UDS协议从入门到精通》系列——图解0x2A:通过周期读ID数据

《UDS协议从入门到精通》系列——图解0x2A&#xff1a;通过周期读ID数据 一、简介二、数据包格式2.1 服务请求格式2.2 服务响应格式2.2.1 肯定响应2.2.2 否定响应 三、通信示例 Tip&#x1f4cc;&#xff1a;本文描述中但凡涉及到其他UDS服务的&#xff0c;将陆续提供链接跳转方…

CAN总线在新能源行业中的重要应用

2019年10月26日第三届中国&#xff08;佛山&#xff09;氢能源大会中展示了氢燃料电池城市客车&#xff0c;与目前的纯电动汽车和传统汽车相比&#xff0c;作为一种新的驱动形式出现。但是&#xff0c;新能源汽车整车网络的实现依旧离不开CAN总线通讯。 工程师们通过CAN总线读取…

Linux Redhat ens33不显示IP问题

优质博文&#xff1a;IT-BLOG-CN 【第一步】&#xff1a;查看系统网卡设备 : ip addr show 【第二步】&#xff1a;修改网卡配置参数 cd /etc/sysconfig/network-scripts/ vi ifcfg-ens33 修改ONBOOT参数为yes 【第三步】&#xff1a;重启网卡&#xff0c;然后ping检测…

Vue + SpringBoot 实现文件的断点上传、秒传,存储到Minio

一、前端 1. 计算文件的md5值 前端页面使用的elment-plus的el-upload组件。 <el-upload action"#" :multiple"true" :auto-upload"false" :on-change"handleChange" :show-file-list"false"><FileButton content&…

python中的<class ‘complex‘>

一般编程里面不怎么会讲&#xff0c;但是还是挺强大的一个类。 在 Python 中&#xff0c;<class complex> 表示复数类型。复数是一种包含实部和虚部的数学数&#xff0c;可以用 a bj 的形式表示&#xff0c;其中 a 表示实部&#xff0c;b 表示虚部&#xff0c;j 是虚数…

kylinos 国产操作系统离线安装firefox 麒麟操作系统安装新版本firefox

1. 火狐地址&#xff1a; 下载 Firefox 浏览器&#xff0c;这里有简体中文及其他 90 多种语言版本供您选择 2. 选择&#xff1a; 3. 下载完之后&#xff0c;上传到离线机器 4. 解压缩&#xff1a; tar -xvjf firefox-127.0.1.tar.bz2 5. 去点击解压后的文件夹&#xff0c;找…

Redis持久化(RDB、AOF)详解

Redis持久化详解 一、Redis为什么需要持久化&#xff1f; Redis 是一个基于内存的数据库&#xff0c;拥有极高的读写性能&#xff0c;但是内存中的数据在断电或服务器重启时会全部丢失&#xff0c;因此需要一种持久化机制来将数据保存到硬盘上&#xff0c;以便在需要时进行恢复…

Latex学习之“usefont”用法

Latex学习之“\usefont”用法 一、通俗的解释 \usefont 是 LaTeX 中的一个命令&#xff0c;用于在文档中临时改变字体&#xff0c;其基本语法如下&#xff1a; \usefont{字体编码}{字体族}{字体系列}{字体形状}这样看起来好像蛮抽象&#xff0c;你可能以及晕了&#xff0c;什…

《代码大模型安全风险防范能力要求及评估方法》正式发布

​代码大模型在代码生成、代码翻译、代码补全、错误定位与修复、自动化测试等方面为研发人员带来了极大便利的同时&#xff0c;也带来了对安全风险防范能力的挑战。基于此&#xff0c;中国信通院依托中国人工智能产业发展联盟&#xff08;AIIA&#xff09;&#xff0c;联合开源…

Midway + TypeORM项目部署到BT后启动失败,MySQL报错

Midway TypeORM项目部署到BT后启动失败&#xff0c;MySQL报错 前沿 您需要先了解这篇文章&#xff1a;https://blog.csdn.net/weixin_45687201/article/details/139336111 错误日志 服务状态开启后就失败项目日志&#xff0c;输出 \> my-midway-project1.0.0 start \&…

前端vue-cli相关知识与搭建过程(项目创建,组件路由)very 详细

一.关于vue-cli 1.什么是vue Vue (读音 /vju ː /&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式框架。Vue 的核心库只关注视图层&#xff0c;不仅易于上手&#xff0c;还便于与第三方库或既有项目整合。 Vue.js 是前端的主流框架之一&#xff0c;和 Angular.js…

19、删除链表的倒数第

1、题目描述 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5]示例 2&#xff1a; 输入&#xff1a;head [1], n 1 输出&#xff1a;[]示例 …

多商户零售外卖超市外卖商品系统源码

构建你的数字化零售王国 一、引言&#xff1a;数字化零售的崛起 在数字化浪潮的推动下&#xff0c;零售业务正经历着前所未有的变革。多商户零售外卖超市商品系统源码应运而生&#xff0c;为商户们提供了一个全新的数字化零售解决方案。通过该系统源码&#xff0c;商户们可以…

基于Openmv的追小球的云台

介绍 在这篇文章&#xff0c;我会先介绍需要用到且需要注意的函数&#xff0c;之后再给出整体代码 在追小球的云台中&#xff0c;比较重要的部分就是云台&#xff08;实质上就是舵机&#xff09;的控制以及对识别的色块位置进行处理得到相应信息后控制云台进行运动 1、舵机模…