Pytorch实现transformer语言模型

转载自:| 03_language_model/02_Transformer语言模型.ipynb | 从头训练Transformer语言模型 |Open In Colab |

Transformer语言模型

本节训练一个 sequence-to-sequence 模型,使用pytorch的
nn.Transformer <https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer> module.

PyTorch 1.2 基于论文 Attention is All YouNeed <https://arxiv.org/pdf/1706.03762.pdf> 实现了一个 Transformer 模型, nn.Transformer 模块依赖于 attention 机制实现表达输入和输出文本的关系。

定义模型

基于 nn.TransformerEncoder 模型训练语言模型。

语言模型任务是为句子后跟随单词输出一个似然概率,表征这个单词可能出现的概率。

首先做 embedding,再做 positional encoding, 表征单词位置关系。nn.TransformerEncoder 由多层nn.TransformerEncoderLayer <https://pytorch.org/docs/master/nn.html?highlight=transformerencoderlayer#torch.nn.TransformerEncoderLayer>组成,对于语言模型任务,每个未来可能出现的单词都需要 mask 并预测其概率,为了得到实际的预测单词,nn.TransformerEncoder模型的输出后需要接一个 log-Softmax 函数。

import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass TransformerModel(nn.Module):def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):super(TransformerModel, self).__init__()from torch.nn import TransformerEncoder, TransformerEncoderLayerself.model_type = 'Transformer'self.src_mask = Noneself.pos_encoder = PositionalEncoding(ninp, dropout)encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.encoder = nn.Embedding(ntoken, ninp)self.ninp = ninpself.decoder = nn.Linear(ninp, ntoken)self.init_weights()def _generate_square_subsequent_mask(self, sz):mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))return maskdef init_weights(self):initrange = 0.1self.encoder.weight.data.uniform_(-initrange, initrange)self.decoder.bias.data.zero_()self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, src):if self.src_mask is None or self.src_mask.size(0) != len(src):device = src.devicemask = self._generate_square_subsequent_mask(len(src)).to(device)self.src_mask = masksrc = self.encoder(src) * math.sqrt(self.ninp)src = self.pos_encoder(src)output = self.transformer_encoder(src, self.src_mask)output = self.decoder(output)return output

PositionalEncoding 模块包括 relative 和 absolute 位置编码,positional encodings 与 embeddings 的维度是一样的,这样两者可以相加。

class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0), :]return self.dropout(x)

加载数据

模型训练过程使用来自 torchtext 的Wikitext-2数据集。vocab 基于 train 数据集构建。batchify()函数将数据集排列成列,在将数据划分为大小为`batch_size``的批次后,删除所有剩余的标记。

例如,将字母表作为序列(总长度为26),批量大小为4,我们将字母表分成4个长度为6的序列:
在这里插入图片描述

import os
import torchtext
from torchtext.data.utils import get_tokenizerdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")TEXT = torchtext.legacy.data.Field(init_token='<sos>',eos_token='<eos>',lower=True)
train_txt, val_txt, test_txt = torchtext.legacy.datasets.language_modeling.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)TEXT
len(train_txt.examples[0].text)
# 2088628
def batchify(data, bsz):data = TEXT.numericalize([data.examples[0].text])# Divide the dataset into bsz parts.nbatch = data.size(0) // bsz# Trim off any extra elements that wouldn't cleanly fit (remainders).data = data.narrow(0, 0, nbatch * bsz)# Evenly divide the data across the bsz batches.data = data.view(bsz, -1).t().contiguous()return data.to(device)batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)print(train_data.shape)
print(val_data.shape)
# torch.Size([104431, 20])
# torch.Size([21764, 10])

定义生成target文本

bptt = 35
def get_batch(source, i):seq_len = min(bptt, len(source) - 1 - i)data = source[i:i+seq_len]target = source[i+1:i+1+seq_len].view(-1)return data, target

试一下模型效果

设置超参:

ntokens = len(TEXT.vocab.stoi)  # the size of vocabulary
emsize = 200  # embedding dimension
nhid = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # the number of heads in the multiheadattention models
dropout = 0.2  # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid,nlayers, dropout).to(device)

运行模型

import time
criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)def train():model.train()  # Turn on the train modetotal_loss = 0.start_time = time.time()ntokens = len(TEXT.vocab.stoi)for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)optimizer.zero_grad()output = model(data)loss = criterion(output.view(-1, ntokens), targets)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()total_loss += loss.item()log_interval = 200if batch % log_interval == 0 and batch > 0:cur_loss = total_loss / log_intervalelapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches | ''lr {:02.2f} | ms/batch {:5.2f} | ''loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],elapsed * 1000 / log_interval,cur_loss, math.exp(cur_loss)))total_loss = 0start_time = time.time()def evaluate(eval_model, data_source):eval_model.eval()  # Turn on the evaluation modetotal_loss = 0.ntokens = len(TEXT.vocab.stoi)with torch.no_grad():for i in range(0, data_source.size(0) - 1, bptt):data, targets = get_batch(data_source, i)output = eval_model(data)output_flat = output.view(-1, ntokens)total_loss += len(data) * criterion(output_flat, targets).item()return total_loss / (len(data_source) - 1)

在validation loss最优时保存模型,在每个epoch结束时调整learning rate。

best_val_loss = float("inf")
epochs = 10  # The number of epochs
best_model = None
MODEL_PATH = 'transformer_lm.pth'
for epoch in range(1, epochs + 1):epoch_start_time = time.time()train()val_loss = evaluate(model, val_data)print('-' * 89)print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ''valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),val_loss, math.exp(val_loss)))print('-' * 100)if val_loss < best_val_loss:best_val_loss = val_lossbest_model = modeltorch.save(best_model.state_dict(), MODEL_PATH)scheduler.step()best_model.load_state_dict(torch.load(MODEL_PATH))

在这里插入图片描述
Evaluate the model with the test dataset

Apply the best model to check the result with the test dataset.

test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)
import os
os.remove('transformer_lm.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/465805.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

<Project-20 YT-DLP> 给视频网站下载工具 yt-dlp/yt-dlp 加个页面 python web

介绍 yt-dlp Github 项目&#xff1a;https://github.com/yt-dlp/yt-dlp A feature-rich command-line audio/video downloader 一个功能丰富的视频与音频命令行下载器 原因与功能 之前我用的 cobalt 因为它不再提供Client Web功能&#xff0c;只能去它的官网使用。 翻 redd…

Sqli-Labs

目录 解题思路 题目设计原理 总结 解题思路 什么&#xff1f;sqli-labs&#xff1f;让我看看。还真是。想起了当初刚学被支配的恐惧。 悄咪咪点开第一关看看能不能秒了。测试闭合老样子&#xff0c;单引号闭合&#xff0c;双引号等都成功。这里 and 11 和 # 都不能通过检测&…

【基于Zynq FPGA对雷龙SD NAND的测试】

一、SD NAND 特征 1.1 SD 卡简介 雷龙的 SD NAND 有很多型号&#xff0c;在测试中使用的是 CSNP4GCR01-AMW 与 CSNP32GCR01-AOW。芯片是基于 NAND FLASH 和 SD 控制器实现的 SD 卡。具有强大的坏块管理和纠错功能&#xff0c;并且在意外掉电的情况下同样能保证数据的安全。 …

【NOIP提高组】引水入城

【NOIP提高组】引水入城 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 在一个遥远的国度&#xff0c;一侧是风景秀美的湖泊&#xff0c;另一侧则是漫无边际的沙漠。该国的行政 区划十分特殊&#xff0c;刚好构成一个N行M列的矩形&#xff…

鸿蒙开发:arkts 如何读取json数据

为了支持ArkTS语言的开发&#xff0c;华为提供了完善的工具链&#xff0c;包括代码编辑器、编译器、调试器、测试工具等。开发者可以使用这些工具进行ArkTS应用的开发、调试和测试。同时&#xff0c;华为还提供了DevEco Studio这一一站式的开发平台&#xff0c;为运行在Harmony…

OpenCV视觉分析之目标跟踪(11)计算两个图像之间的最佳变换矩阵函数findTransformECC的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 根据 ECC 标准 78找到两幅图像之间的几何变换&#xff08;warp&#xff09;。 该函数根据 ECC 标准 ([78]) 估计最优变换&#xff08;warpMatri…

【2024最新版Kotlin教程】Kotlin第一行代码系列第五课-类继承,抽象类,接口

【2024最新版Kotlin教程】Kotlin第一行代码系列第五课-类继承&#xff0c;抽象类&#xff0c;接口 为什么要有继承呢&#xff0c;现实中也是有继承的&#xff0c;对吧&#xff0c;你继承你爸的遗产&#xff0c;比如你爸建好了一个房子&#xff0c;儿子继承爸&#xff0c;就得了…

iOS用rime且导入自制输入方案

iPhone 16 的 cantonese 只能打传统汉字&#xff0c;没有繁简转换&#xff0c;m d sh d。考虑用「仓」输入法 [1] 使用 Rime 打字&#xff0c;且希望导入自制方案 [2]。 仓输入法有几种导入方案的方法&#xff0c;见 [3]&#xff0c;此处记录 wifi 上传法。准备工作&#xff1…

基于Zynq FPGA的雷龙SD NAND存储芯片性能测试

文章目录 前言一、SD NAND特征1.1 SD卡简介1.2 SD卡Block图 二、SD卡样片三、Zynq测试平台搭建3.1 测试流程3.2 SOC搭建 四、软件搭建五、测试结果六、总结 前言 随着嵌入式系统和物联网设备的快速发展&#xff0c;高效可靠的存储解决方案变得越来越重要。雷龙发展推出的SD NA…

【动态规划 数学】2745. 构造最长的新字符串|1607

本文涉及知识点 C动态规划 数学 LeetCode2745. 构造最长的新字符串 给你三个整数 x &#xff0c;y 和 z 。 这三个整数表示你有 x 个 “AA” 字符串&#xff0c;y 个 “BB” 字符串&#xff0c;和 z 个 “AB” 字符串。你需要选择这些字符串中的部分字符串&#xff08;可以全…

【Linux驱动开发】timer库下的jiffies时间戳和延时驱动编写

【Linux驱动开发】timer库下的jiffies时间戳和延时驱动编写 gitee地址&#xff1a; https://gitee.com/Mike_Zhou_Admin/Linux_Driver_Timestamp_Driver/更新以gitee为准 文章目录 timer库时间戳函数延时函数驱动代码应用测试附录&#xff1a;嵌入式Linux驱动开发基本步骤开发…

了解云计算工作负载保护的重要性及必要性

云计算de小白 云计算技术的快速发展使数据和应用程序安全成为一种关键需求&#xff0c;而不仅仅是一种偏好。随着越来越多的客户公司将业务迁移到云端&#xff0c;保护他们的云工作负载&#xff08;指所有部署的应用程序和服务&#xff09;变得越来越重要。云工作负载保护&…

C语言 循环高级

时间&#xff1a;2024.11.6 一、学习内容 1、无限循环 无限循环&#xff1a;循环永远停不下来 注意点&#xff1a;无限循环因为永远停不下来&#xff0c;所以下面不能再写其他的代码了 2、break 跳转控制语句&#xff1a; 在循环的过程中&#xff0c;跳到其他语句上执行 #…

易语言模拟真人动态生成鼠标滑动路径

一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序&#xff0c;它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言&#xff0c;原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势&#xff1a; 模拟…

CSS学习之Grid网格布局基本概念、容器属性

网格布局 网格布局&#xff08;Grid&#xff09;是将网页划分成一个个网格单元&#xff0c;可任意组合不同的网格&#xff0c;轻松实现各种布局效果&#xff0c;也是目前CSS中最强大布局方案&#xff0c;比Flex更强大。 基本概念 容器和项目 当一个 HTML 元素将 display 属性…

聊一聊Elasticsearch的索引的分片分配机制

1、什么是分片分配 分片分配是由ES主节点将索引分片移动到ES集群中各个节点上的过程。 该过程尽量保证&#xff0c;同一个索引的分片尽量分配到更多的节点上&#xff0c;以此来达到读写索引的时候可以利用更多硬件资源的效果。 在分配过程当中&#xff0c;也不能将某个主分片…

springboot的增删改查商城小实践(b to c)

首先准备一张表&#xff0c;根据业务去设计表 订单编号是参与业务的&#xff0c;他那订单编号里面是有特殊意义的&#xff0c;比如说像什么一些年月日什么的&#xff0c;一些用户的ID都在那编号里面呢&#xff1f;不能拿这种东西当主件啊 根据数据量去决定数据类型 价格需要注意…

Ubuntu 安装 RTL8811cu 网卡驱动

一、支持的网卡 RTL8811AU、RTL8811CU、RTL8821AU、RTL8821CU 二、下载驱动 github&#xff1a;https://github.com/brektrou/rtl8821CU 直接下载zip源码即可 三、安装驱动 sudo su -i #切换到root用户 apt-get update #更新安装源 apt-get install -y dkms …

解锁炎症和肿瘤免疫治疗新靶点:TREM1&TREM2

前 言 TREM家族属于细胞表面受体&#xff0c;介导调控炎症反应&#xff0c;现已成为癌症、神经退行性疾病以及炎症性疾病等多种疾病最有潜力的药物靶点。截至2023年6月&#xff0c;有5项FDA注册的临床前或临床试验正在进行中&#xff0c;有3项是TREM2在阿尔茨海默症&#xff…

【Unity】Unity拖拽在Android设备有延迟和卡顿问题的解决

一、介绍 在制作Block类游戏时&#xff0c;其核心的逻辑就是拖拽方块放入到地图中&#xff0c;这里最先想到的就是Unity的拖拽接口IDragHandler,然后通过 IPointerDownHandler, IPointerUpHandler 这两个接口判断按下和松手&#xff0c;具体的实现逻辑就是下面 public void On…