【课程总结】Day18:Seq2Seq的深入了解

前言

在上一章【课程总结】Day17(下):初始Seq2Seq模型中,我们初步了解了Seq2Seq模型的基本情况及代码运行效果,本章内容将深入了解Seq2Seq模型的代码,梳理代码的框架图、各部分组成部分以及运行流程。

框架图

工程目录结构

查看项目目录结构如下:

seq2seq_demo/
├── data.txt                     # 原始数据文件,包含训练或测试数据
├── dataloader.py                # 数据加载器,负责读取和预处理数据
├── decoder.py                   # 解码器实现,用于生成输出序列
├── encoder.py                   # 编码器实现,将输入序列编码为上下文向量
├── main.py                      # 主程序入口,执行模型训练和推理
├── seq2seq.py                   # seq2seq 模型的实现,整合编码器和解码器
└── tokenizer.py                 # 分词器实现,将文本转换为模型可处理的格式

查看各个py文件整理关系图结构如下:

  • main.py 文件是主程序入口,同时其中也定义了 Translation类,用于训练和推理。
  • Translation类在 __init__() 方法中调用 get_tokenizer() 方法实例化tokenizer对象。
  • Translation类在 __init__() 方法中调用 get_model() 实例化seq2seq类对象,进而实例化 EncoderDecoder 对象。
  • Translation类在 train() 方法中调用 get_dataloader() 方法实例化dataloader对象。

核心逻辑

初始化过程

  • 上述流程中较为重要的代码主要是 build_dict() 、encoder实例化、decoder实例化初始化过程:
Build_dict()
def build_dict(self):"""构建字典"""if os.path.exists(self.saved_dict):self.load()print("加载本地字典成功")returninput_words = {"<UNK>", "<PAD>"}output_words = {"<UNK>", "<PAD>", "<SOS>", "<EOS>"}with open(file=self.data_file, mode="r", encoding="utf8") as f:for line in tqdm(f.readlines()):if line:input_sentence, output_sentence = line.strip().split("\t")input_sentence_words = self.split_input(input_sentence)output_sentence_words = self.split_output(output_sentence)input_words = input_words.union(set(input_sentence_words))output_words = output_words.union(set(output_sentence_words))# 输入字典self.input_word2idx = {word: idx for idx, word in enumerate(input_words)}self.input_idx2word = {idx: word for word, idx in self.input_word2idx.items()}self.input_dict_len = len(self.input_word2idx)# 输出字典self.output_word2idx = {word: idx for idx, word in enumerate(output_words)}self.output_idx2word = {idx: word for word, idx in self.output_word2idx.items()}self.output_dict_len = len(self.output_word2idx)# 保存self.save()print("保存字典成功")

代码解析:

  • 首先,判断本地是否有字典,有的话直接加载;
  • 其次,在input_wordsoutput_words 集合中添加特殊符号(special tokens):
    • <UNK>:表示未知单词,用于表示输入序列中未在字典中找到的单词;
    • <PAD>:表示填充符号,用于填充输入序列和输出序列,使它们具有相同的长度;
    • <SOS>:表示序列的开始,用于表示输出序列的起始位置;
    • <EOS>:表示序列的结束,用于表示输出序列的结束位置。
  • 然后,读取data.txt文件,以\t切分数据并切分单词:
    • 输入的英文调用split_input进行预处理,例如:I’m a student.→[‘i’, ‘m’, ‘a’, ‘student’, ‘.’]
    • 输出的中文调用split_output进行切分,例如:我爱北京天安门→[‘我’, ‘爱’, ‘北京’, ‘天安门’]
  • 最后,调用self.save() 方法将字典保存到本地文件 self.saved_dict 中。
encoder
import torch
from torch import nnclass Encoder(nn.Module):"""定义一个 编码器"""def __init__(self, tokenizer):super(Encoder, self).__init__()self.tokenizer = tokenizer# 嵌入层self.embed = nn.Embedding(num_embeddings=self.tokenizer.input_dict_len,embedding_dim=self.tokenizer.input_embed_dim,padding_idx=self.tokenizer.input_word2idx.get("<PAD>"))# GRU单元self.gru = nn.GRU(input_size=self.tokenizer.input_embed_dim,hidden_size=self.tokenizer.input_hidden_size,batch_first=False)def forward(self, x, x_len):# [seq_len, batch_size] --> [seq_len, batch_size, embed_dim]x = self.embed(x)# 压紧被填充的序列x = nn.utils.rnn.pack_padded_sequence(input=x,lengths=x_len,batch_first=False)out, hn = self.gru(x)# 填充被压紧的序列out, out_len = nn.utils.rnn.pad_packed_sequence(sequence=out,batch_first=False,padding_value=self.tokenizer.input_word2idx.get("<PAD>"))# out: [seq_len, batch_size, hidden_size]# hn: [1, batch_size, hidden_size]return out, hn

代码解析:

  • encoder是一个典型的RNN结构,其定义了embedding层用于词嵌入,以及GRU单元进行序列处理。
  • forward方法中,首先将输入序列进行词嵌入,然后使用pack_padded_sequence将被填充的序列压紧,以便于GRU单元处理。
decoder
import torch
from torch import nn
import randomdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Decoder(nn.Module):def __init__(self, tokenizer):super(Decoder, self).__init__()self.tokenizer = tokenizer# 嵌入self.embed = nn.Embedding(num_embeddings=self.tokenizer.output_dict_len,embedding_dim=self.tokenizer.output_embed_dim,padding_idx=self.tokenizer.output_word2idx.get("<PAD>"),)# 抽取特征self.gru = nn.GRU(input_size=self.tokenizer.output_embed_dim,hidden_size=self.tokenizer.output_hidden_size,batch_first=False,)# 转换维度,做概率输出self.fc = nn.Linear(in_features=self.tokenizer.output_hidden_size,out_features=self.tokenizer.output_dict_len,)def forward_step(self, decoder_input, decoder_hidden):"""单步解码:decoder_input: [1, batch_size]decoder_hidden: [1, batch_size, hidden_size]"""# [1, batch_size] --> [1, batch_size, embedding_dim]decoder_input = self.embed(decoder_input)# 输入:[1, batch_size, embedding_dim] [1, batch_size, hidden_size]# 输出:[1, batch_size, hidden_size] [1, batch_size, hidden_size]# 因为只有1步,所以 out 跟 decoder_hidden是一样的out, decoder_hidden = self.gru(decoder_input, decoder_hidden)# [batch_size, hidden_size]out = out.squeeze(dim=0)# [batch_size, dict_len]out = self.fc(out)# out: [batch_size, dict_len]# decoder_hidden: [1, batch_size, hidden_size]return out, decoder_hiddendef forward(self, encoder_hidden, y, y_len):"""训练时的正向传播- encoder_hidden: [1, batch_size, hidden_size]- y: [seq_len, batch_size]- y_len: [batch_size]"""# 计算输出的最大长度(本批数据的最大长度)output_max_len = max(y_len.tolist()) + 1# 本批数据的批量大小batch_size = encoder_hidden.size(1)# 输入信号 SOS  读取第0步,启动信号# decoder_input: [1, batch_size]# 输入信号 SOS [1, batch_size]decoder_input = torch.LongTensor([[self.tokenizer.output_word2idx.get("<SOS>")] * batch_size]).to(device=device)# 收集所有的预测结果# decoder_outputs: [seq_len, batch_size, dict_len]decoder_outputs = torch.zeros(output_max_len, batch_size, self.tokenizer.output_dict_len)# 隐藏状态 [1, batch_size, hidden_size]decoder_hidden = encoder_hidden# 手动循环for t in range(output_max_len):# 输入:decoder_input: [batch_size, dict_len], decoder_hidden: [1, batch_size, hidden_size]# 返回值:decoder_output_t: [batch_size, dict_len], decoder_hidden: [1, batch_size, hidden_size]decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)# 填充结果张量 [seq_len, batch_size, dict_len]decoder_outputs[t, :, :] = decoder_output_t# teacher forcing 教师强迫机制use_teacher_forcing = random.random() > 0.5# 0.5 概率 实行教师强迫if use_teacher_forcing:# [1, batch_size] 取标签中的下一个词decoder_input = y[t, :].unsqueeze(0)else:# 取出上一步的推理结果 [1, batch_size]decoder_input = decoder_output_t.argmax(dim=-1).unsqueeze(0)# decoder_outputs: [seq_len, batch_size, dict_len]return decoder_outputs# ...(其他函数暂略)

代码解析:

  • decoder定义了三个层:embed(词嵌入)、gru和fc(全链接层)。
  • 全链接层用于输出的是字典长度,即每个位置代表着每个字的概率。
  • decoder的forward_step方法,用于一步一步地执行,属于手动循环;forward方法,把所有步都执行完进行推理,属于自动循环。
  • forward方法中:
    • 首先,计算本批数据的最大长度(用于标签对齐)
    • 其次,使用encoder_hidden.size(1)获取批量大小
    • 然后,增加启动信号,即<SOS>
    • 然后,准备全0的张量 decoder_outputs
    • 然后,开始循环
      • 在循环每一步中,将输入和隐藏状态传给forward_step进行处理,得到输出概率decoder_output_t
      • 将结果概率放在decoder_outputs
      • 启用教师强迫机制(teacher forcing):
        • 即有50%概率,使用标准答案作为下一步的输入;
        • 否则,使用上一步的推理结果中概率最大的词作为下一步的输入。
    • 最后,返回结果概率张量 decoder_outputs

训练过程

  • 上述流程中较为重要的代码主要是 调用collate_fn具体训练过程手动循环进行正向推理
调用collate_fn
def collate_fn(batch, tokenizer):# 根据 x 的长度来 倒序排列batch = sorted(batch, key=lambda ele: ele[1], reverse=True)# 合并整个批量的每一部分input_sentences, input_sentence_lens, output_sentences, output_sentence_lens = zip(*batch)# 转索引【按本批量最大长度来填充】input_sentence_len = input_sentence_lens[0]input_idxes = []for input_sentence in input_sentences:input_idxes.append(tokenizer.encode_input(input_sentence, input_sentence_len))# 转索引【按本批量最大长度来填充】output_sentence_len = max(output_sentence_lens)output_idxes = []for output_sentence in output_sentences:output_idxes.append(tokenizer.encode_output(output_sentence, output_sentence_len))# 转张量 [seq_len, batch_size]input_idxes = torch.LongTensor(input_idxes).t()output_idxes = torch.LongTensor(output_idxes).t()input_sentence_lens = torch.LongTensor(input_sentence_lens)output_sentence_lens = torch.LongTensor(output_sentence_lens)return input_idxes, input_sentence_lens, output_idxes, output_sentence_lens

代码解析:

  • 当文字长度不一样齐的时候,需要进行补充<PAD>,以保持所有序列长度一致

例如:
I’m a student.
I’m OK.
Here is your change.

  • 但是补充<PAD>本身对训练过程会造成干扰,所以我们需要采用一种机制:既保证对齐数据批量化训练,又能消除填充对训练过程的影响。
  • 这种机制原理:在训练时知道实际的数据长度,这样在训练时就可以略过<PAD>。
  • torch提供了相应的API,其大致过程是:
    • 首先,根据 x(上句) 的长度倒序排序
    • 其次,获取本批量最大的长度
    • 然后,将数据填充到本批量最大长度
    • 最后,在返回数据时,不知返回数据,还会带着真实长度
具体训练过程
    # (其他部分代码略)# 训练过程is_complete = Falsefor epoch in range(self.epochs):self.model.train()for batch_idx, (x, x_len, y, y_len) in enumerate(train_dataloader):x = x.to(device=self.device)y = y.to(device=self.device)results = self.model(x, x_len, y, y_len)loss = self.get_loss(decoder_outputs=results, y=y)# 简单判定一下,如果损失小于0.5,则训练提前完成if loss.item() < 0.3:is_complete = Trueprint(f"训练提前完成, 本批次损失为:{loss.item()}")breakloss.backward()self.optimizer.step()self.optimizer.zero_grad()# 过程监控with torch.no_grad():if batch_idx % 100 == 0:print(f"第 {epoch + 1}{batch_idx + 1} 批, 当前批次损失: {loss.item()}")x_true = self.get_real_input(x)y_pred = self.model.batch_infer(x, x_len)y_true = self.get_real_output(y)samples = random.sample(population=range(x.size(1)), k=2)for idx in samples:print("\t真实输入:", x_true[idx])print("\t真实结果:", y_true[idx])print("\t预测结果:", y_pred[idx])print("\t----------------------------------------------------------")# 外层提前退出if is_complete:# print("训练提前完成")break# 保存模型torch.save(obj=self.model.state_dict(), f="./model.pt")
手动循环进行正向推理
    #(其他部分略)def batch_infer(self, encoder_hidden):"""推理时的正向传播- encoder_hidden: [1, batch_size, hidden_size]"""# 推理时,设定一个最大的固定长度output_max_len = self.tokenizer.output_max_len# 获取批量大小batch_size = encoder_hidden.size(1)# 输入信号 SOS [1, batch_size]decoder_input = torch.LongTensor([[self.tokenizer.output_word2idx.get("<SOS>")] * batch_size]).to(device=device)# print(decoder_input)results = []# 隐藏状态# encoder_hidden: [1, batch_size, hidden_size]decoder_hidden = encoder_hiddenwith torch.no_grad():# 手动循环for t in range(output_max_len):# decoder_input: [1, batch_size]# decoder_hidden: [1, batch_size, hidden_size]decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)# 取出结果 [1, batch_size]decoder_input = decoder_output_t.argmax(dim=-1).unsqueeze(0)results.append(decoder_input)# [seq_len, batch_size]results = torch.cat(tensors=results, dim=0)return results

代码解析:

  • 相比训练的时候,推理的时候函数入参没有y标准答案。
  • 推理的过程:
    • (与训练类似)获取最大长度、获取批量大小、构建启动信号。
    • (与训练不同)在无梯度环境里,调用forward_step函数,进行循环推理。
    • (与训练不同)因为推理时不需要teacher forcing机制,所以直接使用贪心思想获得概率最大的词。
    • 循环结束后,将结果拼接起来,返回。

补充知识

tqdm

定义

tqdm 是一个用于在 Python 中显示进度条的库,非常适合在长时间运行的循环中使用。

安装方法
pip install tqdm
使用方法
from tqdm import tqdm
import time# 示例:在一个简单的循环中使用 tqdm
for i in tqdm(range(10)):time.sleep(1)  # 模拟某个耗时操作

运行结果:

OpenCC

定义

OpenCC(Open Chinese Convert)是一个用于简体中文和繁体中文之间转换的工具

安装方法
pip install OpenCC
使用方法
import opencc# 创建转换器,使用简体到繁体的配置
converter = opencc.OpenCC('s2t')  # s2t: 简体到繁体# 输入简体中文
simplified_text = "我爱编程"# 进行转换
traditional_text = converter.convert(simplified_text)print(traditional_text)  
# 输出结果:我愛編程

内容小结

  • Seq2Seq项目整体组成由tokenizer(分词器)、dataloader(数据加载)、encoder(编码器)、decoder(解码器)、seq2seq和main六个部分组成
  • 在分词器中重点工作是构建自定义字典,并添加特殊符号(special tokens)
    • <UNK>:表示未知单词,用于表示输入序列中未在字典中找到的单词;
    • <PAD>:表示填充符号,用于填充输入序列和输出序列,使它们具有相同的长度;
    • <SOS>:表示序列的开始,用于表示输出序列的起始位置;上文不会增加。
    • <EOS>:表示序列的结束,用于表示输出序列的结束位置,上文不会增加。
  • 在decoder的forward函数中,增加了一个teacher_forcing_ratio参数,用于控制是否使用教师强迫机制。
    • 有50%概率,使用标准答案作为下一步的输入;
    • 有50%概率,使用上一步的推理结果中概率最大的词作为下一步的输入。
    • 该机制用于提升训练速度。
  • 在训练过程中会使用collate_fn用于数据对齐时消除PAD的影响。

参考资料

(暂无)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/391499.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大模型系列】LanguageBind(ICLR2024.01)

Paper&#xff1a;https://arxiv.org/abs/2310.01852Github&#xff1a;https://github.com/PKU-YuanGroup/LanguageBindHuggingface&#xff1a;https://huggingface.co/spaces/LanguageBind/LanguageBindAuthor&#xff1a;Bin Zhu et al. 北大袁粒团队 文章目录 1 LanguageB…

入门mem0.NET

入门mem0.NET 安装包 如果你的项目使用了EntityFrameworkCore,那么你可以跟随这个教程走 <ItemGroup><PackageReference Include"mem0.NET" Version"0.1.7" /><PackageReference Include"mem0.NET.Qdrant" Version"0.1.7…

软件测试需要具备的基础知识【功能测试】---前端知识(一)

​ ​ 您好&#xff0c;我是程序员小羊&#xff01; 前言 为了更好的学习软件测试的相关技能&#xff0c;需要具备一定的基础知识。需要学习的基础知识包括&#xff1a; 1、计算机基础 2、前端知识 3、后端知识 4、软件测试理论 后期分四篇文章进行编写&#xff0c;这是第二篇 …

【精选】通信与感知(ISAC)必读好文

微信公众号&#xff1a;EW Frontier 个人博客&#xff1a;106.54.201.174 QQ交流群&#xff1a;949444104 简介 通信与感知&#xff08;ISAC&#xff09;也被称为联合雷达通信 (JRC) / 联合通信和雷达传感 (JCAS) / 双功能雷达通信 (DFRC) 定义&#xff1a;将传感和通信系统集…

记录一次学习过程(msf、cs的使用、横向渗透等等)

目录 用python搭建一个简单的web服务器 代码解释 MSF msfvenom 功能 用途 查看payloads列表 msfconsole 功能 用途 msfvenom和msfconsole配合使用 来个例子 msf会话中用到的一些命令 在windows中net user用法 列出所有用户账户 显示单个用户账户信息 创建用户账…

学python的第一天:PyCharm创建项目

创建项目 打开工具 PyCharm 点击“新建项目” 点击“创建” 环境 系统会创建虚拟环境&#xff0c;稍等 初始设置 创建完成后会进入main.py文件 性能 可以看到 右下角提示我们增强性能&#xff0c;点“自动” 会获取到管理员权限 完成后会提示完成

【数据结构】栈和队列(c语言实现)(附源码)

&#x1f31f;&#x1f31f;作者主页&#xff1a;ephemerals__ &#x1f31f;&#x1f31f;所属专栏&#xff1a;数据结构 目录 一、栈 1.栈的概念与结构 2.栈的实现 2.1 栈的结构定义 2.2 方法的声明 2.3 方法的实现 2.3.1 初始化 2.3.2 销毁 2.3.3 判空 2.3.4 压…

常见CMS漏洞(WordPress、DeDeCMS、ASPCMS、PHPMyadmin、Pageadmin)

目录 一&#xff1a;WordPress 步骤一:进入Vulhub靶场并执行以下命令开启靶场;在浏览器中访问并安装好子... 步骤二:思路是修改其WP的模板写入一句话木马后门并访问其文件即可GetShel;登陆WP后点击【外观】--》【编辑】 --》 404.php 步骤三:访问以下连接即可获取WebShel...…

用VBA在Word中随机打乱单词表,进行分列

一、效果展示&#xff08;以下是三次随机打乱的结果&#xff09; 二、代码 Sub 随机分单词到后面的单元格()Dim C1 As CellDim str, str1, aDim shuffledArray() As VariantSet C1 Selection.Range.Tables(1).Cell(1, 1)str C1.Range.textstr mid(str, 3, Len(str) - 4)str…

ADC的介绍和工作原理

一&#xff0c;什么是ADC&#xff1f; Analog-to-Digital Converter&#xff0c;指模拟/数字转换器 什么是ADC&#xff1a; ADC可以将引脚上连续变化的模拟电压转换为内存中存储的数字变量&#xff0c;建立模拟电路到数字电路的桥梁 SUCH AS: 12 位 ADC 是一种逐次逼近…

C# Solidworks二次开发------设置按键打开模型查询

一、代码 public void Open_File(string FileNmae) {Process.Start("explorer.exe", FileNmae); }Open_File("路径"); 二、内容 这个代码很简单&#xff0c;我使用其主要的作用是设置一个按键&#xff0c;可以快速的查看我们已生成的三维模型&#xff0…

JS使用 navigator.clipboard 操作剪切板

注意&#xff1a;需要在安全域下才能够使用&#xff0c;比如&#xff1a;https 协议的地址、127.0.0.1、localhost safari浏览器需要打开配置&#xff0c;在地址栏输入 about:config&#xff0c;搜索 clipboard&#xff0c;将 asyncClipboard 由 false 改为 true&#xff0c;然…

C语言初阶(11)

1.结构体定义 结构体就是一群数据类型的集合体。这些数据类型被称为成员变量。结构的成员可以是标量、数组、指针&#xff0c;甚至是其他结构体。 2.结构体的声明和结构体变量命名与初始化 结构体声明由以下结构组成 struct stu {char name[12];int age; }; 结构体命名有两…

算法通关:017_2:二叉树及三种顺序的非递归遍历

文章目录 题目思路运行结果 题目 二叉树及三种顺序的非递归遍历 思路 import java.util.Stack;/*** Author: ggdpzhk* CreateTime: 2024-08-04* 二叉树非递归版本*/ public class _017_Tree2 {public static void main(String[] args) {TreeNode head new TreeNode(1);head.…

keil编程中#pragma NOAREGS的作用和优点

参考 功能 不直接操作内存地址 #pragma NOAREGS在Keil中的使用含义是禁用自动分配寄存器&#xff0c;开发人员指定控制的寄存器。‌例如中断的执行使用的寄存器需要人为的指定&#xff0c;避免分配同样的寄存器导致数据错误。对寄存器R0到R7不直接操作寄存器地址&#xff0c…

学习笔记-Cookie、Session、JWT

目录 一、验证码的生成与校验 1. 创建生成验证码的工具类 2. 写一个 Controller 3. 实现验证码验证 1. 获取验证码 2. 验证码请求过程 3. 验证码的校验 4. 原理说明 5. 验证 6. 总结 二、JWT登录鉴权 1. 为什么要做登录鉴权&#xff1f; 2. 什么是 JWT 3. JWT相比…

Open Interpreter - 开放解释器

文章目录 一、关于演示它是如何工作的&#xff1f;与 ChatGPT 的代码解释器比较 二、快速开始三、更多操作1、互动聊天2、程序化聊天3、开始新的聊天4、保存和恢复聊天5、自定义系统消息6、更改模型7、在本地运行 Open Interpreter终端Python上下文窗口&#xff0c;最大令牌 8、…

【Golang 面试 - 进阶题】每日 3 题(十四)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

python pip怎么安装包

按WinR键打开运行窗口&#xff0c;输入“cmd”&#xff0c;再按回车键&#xff0c;打开命令行窗口。 找到pip安装路径。 Python2/Python3安装路径是相同的&#xff0c;都在x:\Python xx\Scripts路径下。 拖动pip主应用程序到命令行窗口。 输入“install 模块/包名”&#xff…

【Golang 面试 - 进阶题】每日 3 题(十)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…