从零学习大模型(二)-----AG_NEWS上训练自回归Transformer

有兴趣的同学可以在自己的电脑上跑跑看看实验结果

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocabclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super(PositionalEncoding, self).__init__()# 创建位置编码矩阵,形状为 (max_len, d_model)pe = torch.zeros(max_len, d_model)# 创建位置的张量 (0, 1, 2, ..., max_len-1) 并扩展其维度position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)# 计算正弦和余弦函数的除数项div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))# 对位置编码的偶数索引应用正弦函数,奇数索引应用余弦函数pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 添加一个额外的维度以便与批次兼容pe = pe.unsqueeze(0).transpose(0, 1)# 注册位置编码为缓冲区,在训练期间不更新self.register_buffer('pe', pe)def forward(self, x):# 将位置编码加到输入的嵌入上return x + self.pe[:x.size(0), :]class TransformerEncoderLayerCustom(nn.Module):def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):super(TransformerEncoderLayerCustom, self).__init__()# 自定义多头自注意力机制self.d_model = d_modelself.nhead = nheadself.dropout = nn.Dropout(dropout)# 前馈网络,包含两个线性层和一个激活函数(ReLU)self.linear1 = nn.Linear(d_model, dim_feedforward)self.linear2 = nn.Linear(dim_feedforward, d_model)# 层归一化self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)# 激活函数self.activation = F.reludef scaled_dot_product_attention(self, query, key, value, mask=None):# 计算注意力分数scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(self.d_model // self.nhead)# 应用掩码(如果有)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))# 计算注意力权重attn_weights = F.softmax(scores, dim=-1)# 应用注意力权重到值上return torch.matmul(attn_weights, value)def forward(self, src, src_mask=None):batch_size, seq_len, _ = src.size()head_dim = self.d_model // self.nhead# 将输入分割成多个头query = key = value = src.view(batch_size, seq_len, self.nhead, head_dim).transpose(1, 2)# 计算多头注意力attn_output = self.scaled_dot_product_attention(query, key, value, mask=src_mask)attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)attn_output = self.dropout(attn_output)# 残差连接和层归一化src = self.norm1(src + attn_output)# 前馈网络src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差连接和层归一化src = self.norm2(src + src2)return srcclass AutoregressiveTransformer(nn.Module):def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_len):super(AutoregressiveTransformer, self).__init__()# 嵌入层,将标记索引转换为稠密向量self.embedding = nn.Embedding(vocab_size, d_model)# 位置编码,用于将序列信息添加到嵌入中self.pos_encoder = PositionalEncoding(d_model, max_len)# 自定义的 Transformer 编码器层,使用指定的参数encoder_layers = TransformerEncoderLayerCustom(d_model, nhead, dim_feedforward)# 堆叠多个 Transformer 编码器层self.transformer_encoder = nn.ModuleList([encoder_layers for _ in range(num_encoder_layers)])self.d_model = d_model# 线性层,将编码器的输出投影到词汇表大小self.decoder = nn.Linear(d_model, vocab_size)def forward(self, src, src_mask):# 对源标记应用嵌入层,并按 sqrt(d_model) 进行缩放src = self.embedding(src) * np.sqrt(self.d_model)# 将位置编码加到嵌入后的标记上src = self.pos_encoder(src)# 通过所有的 Transformer 编码器层for layer in self.transformer_encoder:src = layer(src, src_mask)# 将输出投影到词汇表大小output = self.decoder(src)return outputdef generate_square_subsequent_mask(self, sz):# 生成一个掩码,以防模型关注未来的位置mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))return maskdef train(model, data, vocab_size, num_epochs=10, learning_rate=0.0005):# 定义损失函数为交叉熵损失criterion = nn.CrossEntropyLoss()# 使用 Adam 优化器进行训练optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):model.train()  # 将模型设置为训练模式total_loss = 0for batch in data:# 输入序列为除最后一个标记外的所有标记src = batch[:-1]# 目标序列为除第一个标记外的所有标记(向右移动一个位置)tgt = batch[1:]# 为输入序列生成后续掩码src_mask = model.generate_square_subsequent_mask(len(src)).to(src.device)optimizer.zero_grad()  # 在反向传播前将梯度归零# 通过模型进行前向传播output = model(src, src_mask)# 计算模型输出与目标序列之间的损失loss = criterion(output.view(-1, vocab_size), tgt.view(-1))# 反向传播损失loss.backward()# 更新模型参数optimizer.step()total_loss += loss.item()# 打印每个 epoch 的平均损失print(f"Epoch {epoch+1}, Loss: {total_loss / len(data)}")def predict(model, start_token, max_len, vocab_size):model.eval()  # 将模型设置为评估模式generated_sequence = [start_token]  # 使用起始标记初始化生成的序列# 使用起始标记创建初始输入张量src = torch.tensor([start_token]).unsqueeze(1)  # 形状为 (seq_len, batch_size)for _ in range(max_len - 1):# 为当前输入序列生成掩码src_mask = model.generate_square_subsequent_mask(len(src)).to(src.device)# 通过模型进行前向传播output = model(src, src_mask)# 获取具有最高概率的标记作为下一个标记next_token = torch.argmax(output[-1, 0, :], dim=-1).item()# 将预测的标记添加到生成的序列中generated_sequence.append(next_token)# 通过添加新标记更新输入序列src = torch.cat([src, torch.tensor([[next_token]])], dim=0)return generated_sequence# 数据处理部分
train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')
counter = Counter()
for (label, line) in train_iter:counter.update(tokenizer(line))
vocab = Vocab(counter, specials=['<unk>', '<pad>'])# 超参数
vocab_size = len(vocab)  # 使用 AG_NEWS 数据集的词汇表大小
d_model = 32  # 嵌入向量的维度
nhead = 2  # 注意力头的数量
num_encoder_layers = 2  # Transformer 编码器层的数量
dim_feedforward = 64  # Transformer 中前馈网络的维度
max_len = 20  # 输入序列的最大长度# 实例化模型
model = AutoregressiveTransformer(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_len)# 生成训练数据
train_iter = AG_NEWS(split='train')
data = []
for (label, line) in train_iter:tokens = [vocab[token] for token in tokenizer(line)]if len(tokens) >= max_len:tokens = tokens[:max_len]else:tokens = tokens + [vocab['<pad>']] * (max_len - len(tokens))data.append(torch.tensor(tokens))if len(data) >= 100:  # 使用 100 条样本进行演示break# 训练模型
train(model, data, vocab_size, num_epochs=10)# 预测序列
start_token = vocab['<unk>']  # 序列生成的起始标记
generated_sequence = predict(model, start_token, max_len, vocab_size)
print("Generated Sequence:", generated_sequence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/453342.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring day 1021

ok了家人们&#xff0c;这周学习spring框架&#xff0c;我们一起去看看吧 Spring 一.Spring概述 1.1 Spring介绍 官网&#xff1a; https://spring.io/ 广义的 Spring &#xff1a; Spring 技术栈 &#xff08;全家桶&#xff09; 广义上的 Spring 泛指以 Spring Framework…

Spring AI 整体介绍_关键组件快速入门_prompt_embedding等

Spring AI&#xff1a;Java开发者的AI集成新利器 在过去&#xff0c;Java开发者在构建AI应用时面临着缺乏统一框架的问题&#xff0c;导致不同AI服务的集成过程复杂且耗时。Spring AI应运而生&#xff0c;旨在为基于Java的应用程序提供一个标准化、高效且易于使用的AI开发平台…

浅说差分算法(下)

我们上节课学了一维的差分&#xff0c;但其实还有二维差分&#xff0c;只是比较难写。 差分 二维差分的定义 二维差分是指对于一个n*m的矩阵a&#xff0c;要求支持操作pro(x1,y1,x2,y2,a)&#xff0c;表示对于以(x1,y1)为左上角&#xff0c;(x2,y2)为右下角的矩形区域&#…

生产车间质量管理有什么用?怎么做?

在生产车间的质量管理中&#xff0c;科学有效的管理方法和严格规范的执行流程是至关重要的&#xff0c;它能够帮助企业提高产品质量、降低次品率、确保生产过程的稳定性和效率。然而&#xff0c;许多企业在生产车间质量管理方面存在诸多问题&#xff0c;常常会面临以下困境&…

多微批量自动加好友

在数字化时代&#xff0c;微信不仅是社交通讯的工具&#xff0c;更是一个拥有庞大用户基础的流量平台。对于企业而言&#xff0c;微信是打造私域流量池的理想选择之一。然而&#xff0c;随着微信号的增多&#xff0c;手动添加好友和备注变得既繁琐又耗时。幸运的是&#xff0c;…

UNI VFX Missiles Explosions for Visual Effect Graph

Unity URP和HDRP的通用视觉效果 使用在视觉效果图中制作的高性能GPU粒子系统。 无需进入视觉效果图编辑器即可轻松自定义VFX。 使用(VFX)事件——一个游戏对象可存储多个效果,这些效果可通过C#或视觉脚本触发。 总共32个事件(不包括“停止”事件)。 ❓ 什么是(VFX)事件?…

Cpp::STL—容器适配器Stack和Queue的讲解和模拟实现(15)

文章目录 前言一、适配器模式概念分类 二、Stack核心作用代码实现 三、Queue核心作用代码实现 四、deque双端队列貌似兼收并蓄&#xff1f;实则也难以兼得~ 总结 前言 适配器也是STL六大组件之一&#xff0c;请跟我一起领悟它的智慧&#xff01;   正文开始&#xff01; 一、…

consumer 角度讲一下i2c外设

往期内容 I2C子系统专栏&#xff1a; I2C&#xff08;IIC&#xff09;协议讲解-CSDN博客SMBus 协议详解-CSDN博客I2C相关结构体讲解:i2c_adapter、i2c_algorithm、i2c_msg-CSDN博客内核提供的通用I2C设备驱动I2c-dev.c分析&#xff1a;注册篇内核提供的通用I2C设备驱动I2C-dev.…

浅析建造者模式

建造者模式 一、基础知识介绍 1. 问题引出 上图面存在的问题&#xff1a;产品和产品创建的过程是封装在一起的。耦合性太强 解决方法: 将二者解耦和 2.建造者模式介绍 将复杂对象的构造过程抽象出来&#xff0c;用户不用知晓里面的构建细节 3.四个角色 建造者模式的四个角…

Java项目-基于springboot框架的财务管理系统项目实战(附源码+文档)

作者&#xff1a;计算机学长阿伟 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、ElementUI等&#xff0c;“文末源码”。 开发运行环境 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBoot、Vue、Mybaits Plus、ELementUI工具&#xff1a;IDEA/…

【element-tiptap】如何修改选中内容时的背景颜色?

前言&#xff1a;element-tiptap 用鼠标选中内容的时候&#xff0c;背景颜色跟系统设置的主题有关&#xff0c;比如的我的就是卡哇伊的pink&#xff0c;默认是淡蓝色 但是我们观察一下语雀&#xff0c;背景颜色是它规定好的颜色 这篇文章来探索一下&#xff0c;怎么自己规定选…

实操上手TinyEngine低代码引擎插件化开发

1.背景介绍 1.1 TinyEngine 低代码引擎简介 低代码开发是近些年非常热门的一种开发方式&#xff0c;用户可以通过可视化的方式&#xff0c;简单拖拽&#xff0c;不写代码或者编写少量代码&#xff0c;类似搭积木一样搭建业务应用。 TinyEngine是一个强大的低代码引擎&#x…

企业博客SEO优化:8个必备工具与资源指南

在当今数字化时代&#xff0c;企业博客已远远超越了传统意义上的信息展示平台。它不仅是企业展示品牌形象、传递品牌价值的重要窗口&#xff0c;更是吸引潜在客户、增强用户粘性、提升网站流量和搜索引擎排名的关键。通过精心策划和高质量的内容创作&#xff0c;企业博客能够建…

ChatGPT4o、o1 谁才是最佳大模型?

如何选择合适的 ChatGPT 模型&#xff1f;OpenAI 更新细节与 GPTs 的深入解析 随着人工智能的发展&#xff0c;ChatGPT 已成为众多用户的强大助手&#xff0c;广泛应用于写作、编程、学习和商业等多个领域。然而&#xff0c;面对 OpenAI 提供的众多模型&#xff08;如 GPT-4、…

idea中,git提交时忽略某些本地修改.将文件从git暂存区移除

我们有时候在本地调试代码时&#xff0c;某些配置文件需要修改成本地环境中。当改完后&#xff0c;需要提交代码时&#xff0c;这些文件又不能推到git上。如下图&#xff1a; 当出现这种情况&#xff0c;我们每次都需要手动去将不需要提交的文件的对号去掉。文件多了后&#x…

[Redis] 在Linux中安装Redis并连接图形化工具详细过程(附下载链接)

前言 安装Redis之前应该在虚拟机中安装Linux系统&#xff0c;这里使用centos7版本 [linux] 在VMware中安装linux、文件下载及详细安装过程&#xff08;附下载链接&#xff09;-CSDN博客 安装Linux后&#xff0c;更换yum源为阿里云并安装gcc依赖 [Linux] CentOS7替换yum源为阿…

Rust 语言持续崛起,即将冲击 TIOBE 指数前十,能否成为编程语言新王者?

Rust 语言持续崛起&#xff0c;即将冲击 TIOBE 指数前十&#xff0c;能否成为编程语言新王者&#xff1f; 2024 年 10 月&#xff0c;全球编程语言 TIOBE 排行榜再次更新&#xff0c;各大编程语言在各自领域中继续发挥着独特的优势。官方的标题是&#xff1a; Rust排名稳步攀升…

【代码随想录Day47】单调栈Part02

42. 接雨水 题目链接/文章讲解&#xff1a;代码随想录 视频讲解&#xff1a;单调栈&#xff0c;经典来袭&#xff01;LeetCode:42.接雨水_哔哩哔哩_bilibili 思路概述 问题理解&#xff1a;我们需要计算在给定柱子高度之间可以接住的雨水总量。雨水的量取决于柱子的高度和它们…

PP-ChatOCRv3—文档场景信息抽取v3产线使用教程

文档场景信息抽取v3产线使用教程 1. 文档场景信息抽取v3产线介绍 文档场景信息抽取v3&#xff08;PP-ChatOCRv3&#xff09;是飞桨特色的文档和图像智能分析解决方案&#xff0c;结合了 LLM 和 OCR 技术&#xff0c;一站式解决版面分析、生僻字、多页 pdf、表格、印章识别等常…

有同学问:拿到大厂JAVA OFFER,但是会不会不稳定,有失业风险?!

昨天在直播里面有一个同学说拿到了大厂的offer&#xff0c;但是最近看了很多很多的报道&#xff0c;说大厂Java会不会也失业&#xff1f; 前两天也有家长私信咨询说孩子去了外企&#xff0c;拿着23K的工资&#xff0c;会不会也不稳定&#xff1f; 现在很多同学看了新闻报道或…