python-pytorch编写transformer模型实现问答0.5.00--训练和预测

python-pytorch编写transformer模型实现问答0.5.00--训练和预测

    • 背景
    • 代码
    • 训练
    • 预测
    • 效果

背景

代码写不了这么长,接上一篇
https://blog.csdn.net/m0_60688978/article/details/139360270

代码

#  定义解码器类
n_layers = 6  # 设置 Decoder 的层数
class Decoder(nn.Module):def __init__(self, corpus):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(vocab_size, d_embedding) # 词嵌入层self.pos_emb = nn.Embedding.from_pretrained( \get_sin_enc_table(vocab_size+1, d_embedding), freeze=True) # 位置嵌入层        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) # 叠加多层def forward(self, dec_inputs, enc_inputs, enc_outputs): #------------------------- 维度信息 --------------------------------# dec_inputs 的维度是 [batch_size, target_len]# enc_inputs 的维度是 [batch_size, source_len]# enc_outputs 的维度是 [batch_size, source_len, embedding_dim]#-----------------------------------------------------------------   # 创建一个从 1 到 source_len 的位置索引序列pos_indices = torch.arange(1, dec_inputs.size(1) + 1).unsqueeze(0).to(dec_inputs)#------------------------- 维度信息 --------------------------------# pos_indices 的维度是 [1, target_len]#-----------------------------------------------------------------              # 对输入进行词嵌入和位置嵌入相加dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(pos_indices)#------------------------- 维度信息 --------------------------------# dec_outputs 的维度是 [batch_size, target_len, embedding_dim]#-----------------------------------------------------------------        # 生成解码器自注意力掩码和解码器 - 编码器注意力掩码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # 填充位掩码dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs) # 后续位掩码dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask \+ dec_self_attn_subsequent_mask), 0) dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # 解码器 - 编码器掩码#------------------------- 维度信息 --------------------------------        # dec_self_attn_pad_mask 的维度是 [batch_size, target_len, target_len]# dec_self_attn_subsequent_mask 的维度是 [batch_size, target_len, target_len]# dec_self_attn_mask 的维度是 [batch_size, target_len, target_len]# dec_enc_attn_mask 的维度是 [batch_size, target_len, source_len]#-----------------------------------------------------------------       dec_self_attns, dec_enc_attns = [], [] # 初始化 dec_self_attns, dec_enc_attns# 通过解码器层 [batch_size, seq_len, embedding_dim]for layer in self.layers:dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)dec_self_attns.append(dec_self_attn)dec_enc_attns.append(dec_enc_attn)#------------------------- 维度信息 --------------------------------# dec_outputs 的维度是 [batch_size, target_len, embedding_dim]# dec_self_attns 是一个列表,每个元素的维度是 [batch_size, n_heads, target_len, target_len]# dec_enc_attns 是一个列表,每个元素的维度是 [batch_size, n_heads, target_len, source_len]#----------------------------------------------------------------- # 返回解码器输出,解码器自注意力和解码器 - 编码器注意力权重       return dec_outputs, dec_self_attns, dec_enc_attns# 定义 Transformer 模型
class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()        self.encoder = Encoder(encoder_input) # 初始化编码器实例        self.decoder = Decoder(decoder_input) # 初始化解码器实例# 定义线性投影层,将解码器输出转换为目标词汇表大小的概率分布self.projection = nn.Linear(d_embedding, vocab_size, bias=False)def forward(self, enc_inputs, dec_inputs):#------------------------- 维度信息 --------------------------------# enc_inputs 的维度是 [batch_size, source_seq_len]# dec_inputs 的维度是 [batch_size, target_seq_len]#-----------------------------------------------------------------        # 将输入传递给编码器,并获取编码器输出和自注意力权重        enc_outputs, enc_self_attns = self.encoder(enc_inputs)#------------------------- 维度信息 --------------------------------# enc_outputs 的维度是 [batch_size, source_len, embedding_dim]# enc_self_attns 是一个列表,每个元素的维度是 [batch_size, n_heads, src_seq_len, src_seq_len]        #-----------------------------------------------------------------          # 将编码器输出、解码器输入和编码器输入传递给解码器# 获取解码器输出、解码器自注意力权重和编码器 - 解码器注意力权重     dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)#------------------------- 维度信息 --------------------------------# dec_outputs 的维度是 [batch_size, target_len, embedding_dim]# dec_self_attns 是一个列表,每个元素的维度是 [batch_size, n_heads, tgt_seq_len, src_seq_len]# dec_enc_attns 是一个列表,每个元素的维度是 [batch_size, n_heads, tgt_seq_len, src_seq_len]   #-----------------------------------------------------------------                # 将解码器输出传递给投影层,生成目标词汇表大小的概率分布dec_logits = self.projection(dec_outputs)  #------------------------- 维度信息 --------------------------------# dec_logits 的维度是 [batch_size, tgt_seq_len, tgt_vocab_size]#-----------------------------------------------------------------# 返回逻辑值 ( 原始预测结果 ), 编码器自注意力权重,解码器自注意力权重,解 - 编码器注意力权重return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns

训练

transfomer=Transformer()
import torch # 导入 torch
import torch.optim as optim # 导入优化器
model = Transformer() # 创建模型实例
print(model)
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.00001) # 优化器
epochs = 40 # 训练轮次
for epoch in range(epochs): # 训练 100 轮optimizer.zero_grad() # 梯度清零
#     enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size) # 创建训练数据    
#     print(enc_inputs, dec_inputs, target_batch)outputs, _, _, _ = model(encoder_input, decoder_input) # 获取模型输出 loss = criterion(outputs.view(-1, vocab_size), decoder_target.view(-1)) # 计算损失if (epoch + 1) % 1 == 0: # 打印损失print(f"Epoch: {epoch + 1:04d} cost = {loss:.6f}")loss.backward()# 反向传播        optimizer.step()# 更新参数

预测

model.eval()
question_text = '张学友是哪里人'
question_cut = list(jieba.cut(question_text))
encoder_x = make_data([question_cut])
decoder_x = [[word2index['SOS']]]
encoder_x,  decoder_x = torch.LongTensor(encoder_x), torch.LongTensor(decoder_x)# decoder_x=torch.tensor([[1, 0, 0,  0,  0]])decoder_x=torch.zeros(1,seq_length,dtype=torch.long)outt=1
for i in range(seq_length):decoder_x[0][i]=outtpredict, enc_self_attns, dec_self_attns, dec_enc_attns = model(encoder_x, decoder_x) # 用模型进行翻译predict = predict.view(-1,vocab_size) # 将预测结果维度重塑predict = predict.data.max(1, keepdim=True)[1] # 找到每个位置概率最大的词汇的索引outt=predict[i].item()predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(encoder_x, decoder_x) # 用模型进行翻译
predictWords=predict.data.max(-1)answer = ''
for i in predictWords[1][0]:if i.item() in [2,0]:breakanswer += index2word[i.item()]
print('问题:', question_text)
print('回答:', answer)

效果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337469.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

向量数据库引领 AI 创新——Zilliz 亮相 2024 亚马逊云科技中国峰会

2024年5月29日,亚马逊云科技中国峰会在上海召开,此次峰会聚集了来自全球各地的科技领袖、行业专家和创新企业,探讨云计算、大数据、人工智能等前沿技术的发展趋势和应用场景。作为领先的向量数据库技术公司,Zilliz 在本次峰会上展…

SpringBoot+layui实现Excel导入操作

excel导入步骤 第三方插件引入插件 效果图 (方法1)代码实现(方法1)Html代码( 公共)下载导入模板 js实现 (方法1)上传文件实现 效果图(方法2)代码实现&#xf…

mimkatz获取windows10明文密码

目录 mimkatz获取windows10明文密码原理 lsass.exe进程的作用 mimikatz的工作机制 Windows 10的特殊情况 实验 实验环境 实验工具 实验步骤 首先根据版本选择相应的mimikatz 使用管理员身份运行cmd 修改注册表 ​编辑 重启 重启电脑后打开mimikatz 在cmd切换到mi…

Matlab|基于粒子群算法优化Kmeans聚类的居民用电行为分析

目录 主要内容 部分代码 结果一览 下载链接 主要内容 在我们研究电力系统优化调度模型的过程中,由于每天负荷和分布式电源出力随机性和不确定性,可能会优化出很多的结果,但是经济调度模型试图做到通用策略,同样的策…

HarmonyOS鸿蒙学习笔记(25)相对布局 RelativeContainer详细说明

RelativeContainer 简介 前言核心概念官方实例官方实例改造蓝色方块改造center 属性说明参考资料 前言 RelativeContainer是鸿蒙的相对布局组件,它的布局很灵活,可以很方便的控制各个子UI 组件的相对位置,其布局理念有点类似于android的约束…

如何看待时间序列与机器学习?

GPT-4o 时间序列与机器学习的关联在于,时间序列数据是一种重要的结构化数据形式,而机器学习则是一种强大的工具,用于从数据中提取有用的模式和信息。在很多实际应用中,时间序列与机器学习可以结合起来,发挥重要作用。…

基于 Apache Doris 的实时/离线一体化架构,赋能中国联通 5G 全连接工厂解决方案

作者:田向阳,联通西部创新研究院 大数据专家 共创:SelectDB 技术团队 导读: 数据是 5G 全连接工厂的核心要素,为支持全方位的数据收集、存储、分析等工作的高效进行,联通 5G 全连接工厂从典型的 Lambda 架…

利用ArcGIS Python批量拼接遥感影像(arcpy batch processing)

本篇文章将说明如何利用ArcGIS 10.1自带的Python IDLE进行遥感影像的批量拼接与裁剪。 1.运行环境:ArcGIS10.1 (安装传送门)、Python IDLE 2.数据来源:地理空间数据云 GDEMV2 30M分辨率数字高程数据 3.解决问题:制作山西省的DEM影像 如下…

CMake的使用

文章目录 一、CMake概述二、CMake的使用1.注释2.简单编译程序3.定义变量4.指定使用的C标准5.指定输出的路径6.搜索文件7.包含头文件 三、通过CMake 制作库文件1.制作动静态库2.指定动静态库生成的路径3.在程序中链接静态库4.在程序中链接动态库 四、日志五、变量操作1.追加2.字…

521源码网-免费网络教程-Cloudflare使用加速解析-优化大陆访问速度

Cloudfalre 加速解析是由 心有网络 向中国大陆用户提供的公共优化服务 接入服务节点: cf.13d7s.sit 接入使用方式类似于其它CDN的CNAME接入,可以为中国大陆用户访问Cloudflare网络节点大幅度加速,累计节点130 如何接入使用 Cloudflare 加速解析&#…

python-模块-网络编程-多任务

一、模块 1-1 Python 自带模块 Json模块 处理json数据 {"key":"value"} json不是字典 本质是一个有引号的字符串数据 json注意点 {} 中的数据是字符串引号必须是双引号 使用json模块可以实现将json转为字典,使用字典的方法操作数据 。 或者将…

c基础 - 输入输出

目录 一.scanf() 和 printf() 函数 1.printf 2.scanf 二 . getchar() & putchar() 函数 1.int getchar(void) 2.int putchar(int c) 三. gets() & puts() 函数 一.scanf() 和 printf() 函数 #include <stdio.h> 需要引入头文件,stdio.h 1.printf print…

23种软件设计模式——工厂模式

工厂模式 工厂模式&#xff08;Factory Pattern&#xff09;是 Java 中最常用的设计模式之一&#xff0c;它提供了一种创建对象的方式&#xff0c;使得创建对象的过程与使用对象的过程分离。 工厂模式提供了一种创建对象的方式&#xff0c;而无需指定要创建的具体类。 通过使…

高级Web Lab2

高级Web Lab2 12 1 按照“Lab 2 基础学习文档”文档完成实验步骤 实验截图&#xff1a; 2 添加了Web3D场景选择按钮&#xff0c;可以选择目标课程或者学习房间。

[数据集][目标检测][数据集][目标检测]智能手机检测数据集VOC格式5447张

数据集格式&#xff1a;Pascal VOC格式(不包含分割的txt文件&#xff0c;仅仅包含jpg图片和对应的xml) 图片数量(jpg文件个数)&#xff1a;5447 标注数量(xml文件个数)&#xff1a;5447 标注类别数&#xff1a;1 标注类别名称:["phone"] 每个类别标注的框数&#xff…

详解生成式人工智能的开发过程

回到机器学习的“古老”时代&#xff0c;在您可以使用大型语言模型&#xff08;LLM&#xff09;作为调优模型的基础之前&#xff0c;您基本上必须在所有数据上训练每个可能的机器学习模型&#xff0c;以找到最佳&#xff08;或最不糟糕&#xff09;的拟合。 开发生成式人工智能…

爬虫在金融领域的应用:股票数据收集

介绍 在金融领域&#xff0c;准确及时的数据收集对于市场分析和投资决策至关重要。股票价格作为金融市场的重要指标之一&#xff0c;通过网络爬虫技术可以高效地从多个网站获取实时股票价格信息。本文将介绍网络爬虫在金融领域中的应用&#xff0c;重点讨论如何利用Scrapy框架…

【JVM精通之路】垃圾回收-三色标记算法

首先预期你已经基本了解垃圾回收的相关知识&#xff0c;包括新生代垃圾回收器&#xff0c;老年代垃圾回收器&#xff0c;以及他们的算法&#xff0c;可达性分析等等。 先想象一个场景 最开始黑色节点是GC-Roots的根节点&#xff0c;这些对象有这样的特点因此被选为垃圾回收的根…

Python3位运算符

前言 本文介绍的是位运算符&#xff0c;位运算可以理解成对二进制数字上的每一个位进行操作的运算&#xff0c;位运算分为 布尔位运算符 和 移位位运算符。 文章目录 前言一、位运算概览1、布尔位运算符1&#xff09;按位与运算符 ( & )2&#xff09;按位或运算符 ( | )3…

【设计模式深度剖析】【5】【结构型】【桥接模式】| 以电视和遥控器为例加深理解

&#x1f448;️上一篇:组合模式 | 下一篇:外观模式&#x1f449;️ 设计模式-专栏&#x1f448;️ 目 录 桥接模式(Bridge Pattern)定义英文原话是&#xff1a;直译理解 4个角色UML类图代码示例 应用优点缺点使用场景 示例解析&#xff1a;电视和遥控器UML类图 桥接模式…