花费7元训练自己的GPT 2模型

在上一篇博客中,我介绍了用Tensorflow来重现GPT 1的模型和训练的过程。这次我打算用Pytorch来重现GPT 2的模型并从头进行训练。

GPT 2的模型相比GPT 1的改进并不多,主要在以下方面:

1. GPT 2把layer normalization放在每个decoder block的前面。

2. 最终的decoder block之后额外添加了一个layer normalization。

3. 残差层的参数初始化根据网络深度进行调节

4. 训练集采用了webtext(45GB),而不是之前采用的bookcorpus(5GB)

5. 更深的网络结构,最大的模型拥有15亿的参数,对比GPT 1是1.2亿的参数

GPT 2有以下四种不同深度的模型架构,如图:

以下我将用pytorch代码来搭建一个GPT 2的模型,以最小的GPT 2为例,采用bookcorpus的数据,在AutoDL平台的一个40G显存的A100显卡上进行训练,看看效果如何。

模型结构

整个模型的结构和GPT 1是基本一致的。

定义一个多头注意力模块,如以下代码:

class MHA(nn.Module):def __init__(self, d_model, num_heads, attn_pdrop, resid_pdrop):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.attn_pdrop = attn_pdropself.resid_dropout = nn.Dropout(resid_pdrop)self.ln = nn.Linear(d_model, d_model*3)self.c_proj = nn.Linear(d_model, d_model)def forward(self, x):B, T, C = x.size()x_qkv = self.ln(x)q, k, v = x_qkv.split(self.d_model, dim=2)q = q.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)k = k.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)v = v.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_pdrop if self.training else 0, is_causal=True)y = y.transpose(1, 2).contiguous().view(B, T, C)y = self.c_proj(y)y = self.resid_dropout(y)return y

这个模块接收一个输入数据,大小为(batch_size, seq_len, dimension),然后进行一个线性变换层,把数据映射为(batch_size, seq_len, dimension*3)的维度,这里的dimension*3表示的是qkv这三个值的拼接。接着就把这个数据切分为q,k,v三份,然后每份都把维度调整为(batch_size, seq_len, num_head, dimension/num_head),num_head表示这个自注意力模块包含多少个head。最后就可以调用scaled_dot_product_attention进行qk的相似度计算,进行缩放之后与v值相乘。Pytorch的这个函数提供了最新的flash attention的实现,可以大幅提升计算性能。最后就是对qkv的结果进行一个线性变换,映射为一个(batch_size, seq_len, dimension)的向量。

自注意力模块的输出结果,将通过一个Feed forward层进行计算,代码如下:

class FeedForward(nn.Module):def __init__(self, d_model, dff, dropout):super().__init__()self.ln1 = nn.Linear(d_model, dff)self.ln2 = nn.Linear(dff, d_model)self.dropout = nn.Dropout(dropout)self.layernorm = nn.LayerNorm(d_model)self.gelu = nn.GELU()def forward(self, x):x = self.ln1(x)x = self.gelu(x)x = self.ln2(x)x = self.dropout(x)return x

代码很简单,就是做了两次线性变换,第一次把维度扩充到dimension*4,第二次把维度恢复为dimension。

最后定义一个decoder block模块,把多头注意力模块和feed forward模块组合起来,代码如下:

class Block(nn.Module):def __init__(self, d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout):super().__init__()self.layernorm1 = nn.LayerNorm(d_model)self.attn = MHA(d_model, num_heads, attn_pdrop, resid_pdrop)self.layernorm2 = nn.LayerNorm(d_model)self.ff = FeedForward(d_model, dff, dropout)def forward(self, x):x = x + self.attn(self.layernorm1(x))x = x + self.ff(self.layernorm2(x))return x

有了decoder block之后,GPT 2的模型就是把这些block串起来,例如最小的GPT 2的模型结构是定义了12个decoder block。模型接收的是字符序列经过tokenizer之后的数字,然后把这些数字通过embedding层映射为向量表达,例如对每个token id,映射为784维度的一个向量。为了能在embedding的向量里面反映字符的位置信息,我们需要把字符的位置也做一个embedding,然后两个embedding相加。

输入数据经过embedding处理后,通过多个decoder block处理之后,数据的维度为(batch_size, seq_len, dimension), 我们需要通过一个权重维度为(dimension, vocab_size)的线性变换,把数据映射为(batch_size, seq_len, vocab_size)的维度。这里vocab_size表示tokenizer的单词表的长度,例如对于GPT 2所用的tokenizer,有50257个单词。对于输出数据进行softmax计算之后,我们就可以得到每个token的预测概率,从而可以和label数据,即真实的下一个token id进行比较,计算loss值。

GPT 2模型的代码如下:

class GPT2(nn.Module):def __init__(self, vocab_size, d_model, block_size, embed_pdrop, num_heads, dff, attn_pdrop, resid_pdrop, dropout, num_layer):super().__init__()self.token_embed = nn.Embedding(vocab_size, d_model, sparse=False)self.pos_embed = nn.Embedding(block_size, d_model, sparse=False)self.dropout_embed = nn.Dropout(embed_pdrop)#self.blocks = [Block(d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout) for _ in range(num_layer)]self.blocks = nn.ModuleList([Block(d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout) for _ in range(num_layer)])self.num_layer = num_layerself.block_size = block_sizeself.lm_head = nn.Linear(d_model, vocab_size, bias=False)self.token_embed.weight = self.lm_head.weightself.layernorm = nn.LayerNorm(d_model)self.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layer))def _init_weights(self, module):if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)def forward(self, x, targets=None):device = x.deviceb, t = x.size()pos = torch.arange(0, t, dtype=torch.long, device=device) x = self.token_embed(x) + self.pos_embed(pos)x = self.dropout_embed(x)for block in self.blocks:x = block(x)x = self.layernorm(x)if targets is not None:logits = self.lm_head(x)loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)else:logits = self.lm_head(x[:, -1, :])loss = Nonereturn logits, lossdef configure_optimizers(self, weight_decay, learning_rate, betas, device_type):# start with all of the candidate parametersparam_dict = {pn: p for pn, p in self.named_parameters()}# filter out those that do not require gradparam_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]optim_groups = [{'params': decay_params, 'weight_decay': weight_decay},{'params': nodecay_params, 'weight_decay': 0.0}]num_decay_params = sum(p.numel() for p in decay_params)num_nodecay_params = sum(p.numel() for p in nodecay_params)print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")# Create AdamW optimizer and use the fused version if it is availablefused_available = 'fused' in inspect.signature(torch.optim.AdamW).parametersuse_fused = fused_available and device_type == 'cuda'extra_args = dict(fused=True) if use_fused else dict()optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)print(f"using fused AdamW: {use_fused}")return optimizer@torch.no_grad()def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, block_size=512):for _ in range(max_new_tokens):# if the sequence context is growing too long we must crop it at block_sizeidx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]# forward the model to get the logits for the index in the sequencelogits, _ = self(idx_cond)# pluck the logits at the final step and scale by desired temperaturelogits = logits / temperature# optionally crop the logits to only the top k optionsif top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits[logits < v[:, [-1]]] = -float('Inf')# apply softmax to convert logits to (normalized) probabilitiesprobs = F.softmax(logits, dim=-1)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1)# append sampled index to the running sequence and continueidx = torch.cat((idx, idx_next), dim=1)return idx

模型训练

定义好模型之后,我们就可以开始训练了。

首先我们需要准备训练数据集。GPT 2采用的是webtext,网上的一些公开网页数据来进行训练。在Huggingface上面有对应的一个公开数据集。不过考虑到我们的资源有限,我这次还是采用GPT 1所用的bookcorpus数据集来训练。

以下代码是下载huggingface的数据集,并用GPT 2的tokenizer来进行编码:

from datasets import load_dataset
from transformers import GPT2Tokenizerdataset = load_dataset("bookcorpusopen", split="train")block_size=513
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")def tokenize_function(examples):token_ids = [tokenizer(text) for text in examples["text"]]total_length = [len(t["input_ids"]) for t in token_ids]total_length = [(l//(block_size+1))*(block_size+1) for l in total_length]result = []label = []for i in range(len(total_length)):result.extend([token_ids[i]["input_ids"][j:j+block_size+1] for j in range(0, total_length[i], block_size+1)])return {"token_ids": result}ds_test = ds['train'].select(range(10000))tokenized_datasets = ds_test.map(tokenize_function, batched=True, num_proc=8, remove_columns=["title", "text"], batch_size=100
)tokenized_datasets.save_to_disk("data/boocorpusopen_10000_512tokens")

GPT1采用的bookcorpus有7000多本书,huggingface的bookcorpusopen数据集有14000多本,这里我只采用了10000本书来构建数据集,对于每本书进行tokenizer转化后,每513个token写入为1条记录。这样我们在训练时,每条记录我们采用前1-512个token作为训练,取2-513个token作为label。

以下代码将读取我们处理好的数据集,并转化为pytorch的dataloader

from datasets import load_from_diskdataset = load_from_disk("data/boocorpusopen_10000_512tokens")
dataset = dataset.with_format("torch")
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

然后我们就可以实例化一个GPT 2的model并开始训练,具体的代码可以见repo https://github.com/gzroy/gpt2_torch.git 里面的train.py文件。

如果在本地显卡上训练,对应12层的网络结构需要30多G的显存,我的显卡是2080Ti,只有11G显存,因此只能指定6层decoder。我们可以在autodl上面租用一个40G显存的A100显卡,价格是3.45元每小时,在这个显卡上开启半精度进行训练,大约1个小时可以跑10000个迭代,batch大小为64。我总共训练了2小时,最终在训练集上的Loss值为3.5左右,准确度为35%,花费为7元。

生成文本

最后我们可以基于这个训练了1个小时的GPT 2模型来测试一下,看生成文本的效果如何,如以下代码:

from transformers import GPT2Tokenizer
from model import GPT2
import torch
from torch.nn import functional as F
import argparseif __name__ == '__main__':parser = argparse.ArgumentParser(description='gpt2 predict')parser.add_argument('--checkpoint_path', type=str, default='checkpoints/')parser.add_argument('--checkpoint_name', type=str, default='')parser.add_argument('--d_model', type=int, default=768)parser.add_argument('--block_size', type=int, default=512)parser.add_argument('--dff', type=int, default=768*4)parser.add_argument('--heads', type=int, default=12)parser.add_argument('--decoder_layers', type=int, default=6)parser.add_argument('--device', type=str, default='cuda')parser.add_argument('--input', type=str)parser.add_argument('--generate_len', type=int, default=100)parser.add_argument('--topk', type=int, default=5)args = parser.parse_args()tokenizer = GPT2Tokenizer.from_pretrained("gpt2")vocab_size = len(tokenizer.get_vocab())model = GPT2(vocab_size, args.d_model, args.block_size, 0, args.heads, args.dff, 0, 0, 0, args.decoder_layers)model.to(args.device)model = torch.compile(model)checkpoint = torch.load(args.checkpoint_path+args.checkpoint_name)model.load_state_dict(checkpoint['model_state_dict'])token_id = tokenizer.encode(args.input)input_data = torch.reshape(torch.tensor(token_id, device=args.device), [1,-1])predicted = model.generate(input_data, args.generate_len, 1.0, args.topk, args.block_size)print("Generated text:\n-------------------")print(tokenizer.decode(predicted.cpu().numpy()[0]))

运行以下命令,给定一个文本的开头,然后让模型生成200字看看:

python predict.py --checkpoint_name model_1.pt --input 'it was saturday night, the street' --generate_len 200 --topk 10

生成的文本如下:

it was saturday night, the street lights blared and the street lights flickered on. A few more houses were visible.The front door opened, and a large man stepped in and handed him one. He handed the man the keys and a small smile. It looked familiar, and then a little too familiar. The door was closed."Hey! You guys out there?" he said, his eyes wide."What are you up to?" the man asked."I'm just asking for you out in my office."The man was about thirty feet away from them."I'm in a serious situation, but it's just the way you are."He looked around at the man, the man looked up and down, and then his eyes met hers. He was a little older than he was, but his eyes were blue with red blood. He looked like a giant. His eyes were blue and red, and his jaw looked like a giant

可见生成的文本语法没有问题,内容上也比较连贯,上下文的逻辑也有关联。如果模型继续训练更长时间,相信生成文本的内容会更加好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/73594.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL索引特性

MySQL索引特性 理论部分&#xff1a;一.什么是索引&#xff1f;二.索引的概念三.认识磁盘1. 磁盘的结构2. 磁盘的随机访问&#xff08;Random Access&#xff09;与连续访问&#xff08;Sequential Access&#xff09; 四.MySQL与磁盘交互的基本单位五.索引的理解1. 主键索引现…

SpringBoot 实现数据加密脱敏(注解 + 反射 + AOP)

SpringBoot 实现数据加密脱敏&#xff08;注解 反射 AOP&#xff09; 场景&#xff1a;响应政府要求&#xff0c;商业软件应保证用户基本信息不被泄露&#xff0c;不能直接展示用户手机号&#xff0c;身份证&#xff0c;地址等敏感信息。 根据上面场景描述&#xff0c;我们…

不同 vlan 之间互通

不同VLAN间的用户要实现互通 如果是不同网段用户&#xff0c;常用的技术为&#xff1a;vlanif 和 单臂路由都可以解决不同 vlan 之间三层包互通问题。 VLANIF VLANIF接口是一种三层的逻辑接口&#xff0c;能实现不同VLAN间&#xff0c;不同网段的用户进行三层互通。由于配置…

概率论与数理统计复习总结3

概率论与数理统计复习总结&#xff0c;仅供笔者复习使用&#xff0c;参考教材&#xff1a; 《概率论与数理统计》/ 荣腾中主编. — 第 2 版. 高等教育出版社《2024高途考研数学——概率基础精讲》王喆 概率论与数理统计实际上是两个互补的分支&#xff1a;概率论 在 已知随机…

【使用维纳滤波进行信号分离】基于维纳-霍普夫方程的信号分离或去噪维纳滤波器估计(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

error: #5: cannot open source input file “core_cmInstr.h“

GD32F103VET6和STM32F103VET6引脚兼容。 GD32F103VET6工程模板需要包含头文件&#xff1a;core_cmInstr.h和core_cmFunc.h&#xff0c;这个和STM32F103还是有区别的&#xff0c;否则会报错&#xff0c;如下&#xff1a; error: #5: cannot open source input file "core…

两个镜头、视野、分辨率不同的相机(rgb、红外)的视野校正

文章目录 背景实际效果查找资料资料1资料2 解决方案最终结果 背景 目前在做的项目用到两个摄像头&#xff0c;一个是热成像摄像头、另一个是普通的rgb摄像头。 一开始的目标是让他们像素级重合&#xff0c;使得点击rgb图像时&#xff0c;即可知道其像素对应的温度。但是在尝试…

js中的设计模式

设计模式 代码整体的结构会更加清楚&#xff0c;管理起来会更加方便&#xff0c;更好地维护 设计模式是一种思想 发布订阅 模块化开发 导入很多模块 容器即数组存储未来要执行的方法&#xff0c;同addEventListener 数组塌陷问题* 由于删除了元素&#xff0c;导致从删除元素的位…

ppt怎么压缩到10m以内?分享好用的压缩方法

PPT是一种常见的演示文稿格式&#xff0c;有时候文件过大&#xff0c;我们会遇到无法发送、上传的现象&#xff0c;这时候简单的解决方法就是压缩其大小&#xff0c;那怎么才能将PPT压缩到10M以内呢&#xff1f; PPT文件大小受到影响的主要因素就是以下几点&#xff1a; 1、图…

VR全景旅游,智慧文旅发展新趋势!

引言&#xff1a; VR全景旅游正在带领我们踏上一场全新的旅行体验。这种沉浸式的旅行方式&#xff0c;让我们可以足不出户&#xff0c;却又身临其境地感受世界各地的美景。 一&#xff0e;VR全景旅游是什么&#xff1f; VR全景旅游是一种借助虚拟现实技术&#xff0c;让用户…

AssetBundle学习

官方文档&#xff1a;AssetBundle 工作流程 - Unity 手册 (unity3d.com) 之前写的博客&#xff1a;AssetBundle学习_zaizai1007的博客-CSDN博客 使用流程图&#xff1a; 1&#xff0c;指定资源的AssetBundle属性 &#xff08;xxxa/xxx&#xff09;这里xxxa会生成目录&…

Arcgis 分区统计majority参数统计问题

利用Arcgis 进行分区统计时&#xff0c;需要统计不同矢量区域中栅格数据的众数&#xff08;majority&#xff09;&#xff0c;出现无法统计majority参数问题解决 解决&#xff1a;利用copy raster工具&#xff0c;将原始栅格数据 64bit转为16bit

iOS 应用上架流程详解

iOS 应用上架流程详解 欢迎来到我的博客&#xff0c;今天我将为大家分享 iOS 应用上架的详细流程。在这个数字化时代&#xff0c;移动应用已经成为了人们生活中不可或缺的一部分&#xff0c;而 iOS 平台的 App Store 则是开发者们发布应用的主要渠道之一。因此&#xff0c;了解…

Vision Transformer (ViT):图像分块、图像块嵌入、类别标记、QKV矩阵与自注意力机制的解析

作者&#xff1a;CSDN _养乐多_ 本文将介绍Vision Transformers &#xff08;ViT&#xff09;中的关键点。包括图像分块&#xff08;Image Patching&#xff09;、图像块嵌入&#xff08;Patch Embedding&#xff09;、类别标记、&#xff08;class_token&#xff09;、QKV矩…

微服务 云原生:搭建 K8S 集群

为节约时间和成本&#xff0c;仅供学习使用&#xff0c;直接在两台虚拟机上模拟 K8S 集群搭建 踩坑之旅 系统环境&#xff1a;CentOS-7-x86_64-Minimal-2009 镜像&#xff0c;为方便起见&#xff0c;直接在 root 账户下操作&#xff0c;现实情况最好不要这样做。 基础准备 关…

pycharm——涟漪散点图

from pyecharts import options as opts from pyecharts.charts import EffectScatterc (EffectScatter().add_xaxis( ["高等数学1&#xff0c;2","C语言程序设计","python程序设计","大数据导论","数据结构","大数据…

CentOS 8 上安装 Nginx

Nginx是一款高性能的开源Web服务器和反向代理服务器&#xff0c;以其轻量级和高效能而广受欢迎。在本教程中&#xff0c;我们将学习在 CentOS 8 操作系统上安装和配置 Nginx。 步骤 1&#xff1a;更新系统 在安装任何软件之前&#xff0c;让我们先更新系统的软件包列表和已安…

opencv 31-图像平滑处理-方框滤波cv2.boxFilter()

方框滤波&#xff08;Box Filtering&#xff09;是一种简单的图像平滑处理方法&#xff0c;它主要用于去除图像中的噪声和减少细节&#xff0c;同时保持图像的整体亮度分布。 方框滤波的原理很简单&#xff1a;对于图像中的每个像素&#xff0c;将其周围的一个固定大小的邻域内…

HTTP、HTTPS协议详解

文章目录 HTTP是什么报文结构请求头部响应头部 工作原理用户点击一个URL链接后&#xff0c;浏览器和web服务器会执行什么http的版本持久连接和非持久连接无状态与有状态Cookie和Sessionhttp方法&#xff1a;get和post的区别 状态码 HTTPS是什么ssl如何搞到证书nginx中的部署 加…

Sqli-labs1~65关 通关详解 解题思路+解题步骤+解析

Sqli-labs 01关 (web517) 输入?id1 正常 输入?id1 报错 .0 输入?id1-- 正常判断是字符型注入&#xff0c;闭合方式是这里插一句。limit 100,1是从第100条数据开始&#xff0c;读取1条数据。limit 6是读取前6条数据。 ?id1 order by 3-- 正常判断回显位有三个。?id…