目录
- 一、Seq2seq 简介
- 二、编码器
- 三、解码器
- 四、编码器-解码器的训练
遇到看不明白的地方,欢迎在评论中留言呐,一起讨论,一起进步!
需掌握的前提知识: LSTM、词嵌入
本文参考:【官方双语】编码、解码神经网络,一个视频讲清楚,seq2seq模型
一、Seq2seq 简介
Seq2seq(Sequence to Sequence)的作用是将一种序列转换为另一种序列,比如将英文句子翻译为中文句子,或者将一篇文章进行概括。
一种解决 Seq2seq 问题的方法是编码器-解码器模型。下面我们以英语句子翻译为西班牙句子为例来进行介绍。
二、编码器
首先先摆出编码器的作用:将输入的句子编码为上下文向量
我们知道不同的英语句子长度可能不同,我们需要让不同长度的句子作为输入;另外不同的中文句子长度也可能不同,因此我们还需要生成不同长度的句子作为输出。说到这里,我们会想到 LSTM,这个模型便可以处理具有可变长度的输入和输出。
现在,我们还需要把单词通过嵌入层转换为数字再塞入 LSTM 中。因为词汇中包含了单词和符号,我们将词汇表中的各个元素称作 tokens。各个 token 通过已经训练好的嵌入层便可以转换为数字词汇表。
我们将这个嵌入层放在 LSTM 的输入前面
从理论上来讲,这就是对输入的句子进行编码所需要的全部工作。然后在实践中,为了有更多的权重和偏差来让模型更加对数据适用,人们经常在输入中添加额外的 LSTM 单元。
简单起见,我们在这里只添加了一个额外的 LSTM 单元。这意味着单词的嵌入值会用作两个不同 LSTM 单元的输入值,而这两个不同的 LSTM 单元有单独的权重和偏差集。
为了再添加更多的权重和偏差使模型更好,人们会添加额外的 LSTM 层。
我们在这里再添加一层 LSTM 层。这意味着第一层 LSTM 单元的输出会用作第二层 LSTM 单元的输入。
最后我们初始化长期记忆和短期记忆之后,我们就完成了创建编码器的部分。本质上,编码器对输入的句子进行编码,形成长期记忆(细胞状态)和短期记忆(隐藏状态)的集合,即上下文向量(Context Vector)。
三、解码器
现在我们介绍解码器,它的作用是:解码上下文向量为输出句子
首先我们需要将长期记忆和短期记忆连接起来形成上下文向量,再切换到解码器中一组新的 LSTM,这个上下文向量用来初始化这个解码器的长期记忆和短期记忆。
像编码器一样,这组新的 LSTM 也有两层,每层有两个 LSTM 单元,但是这组 LSTM 有着自己独立的权重和偏差。
解码器需要将上下文向量进行解码,从而输出句子。像编码器一样,解码器第一层 LSTM 单元的输入来自嵌入层。现在,嵌入层创建了西班牙单词的嵌入值。
对比编码器和解码器的嵌入层,它们具有不同 tokens 作为输入,还具有不同的权重使的每个 token 的嵌入向量也有所不同。
解码器顶层的 LSTM 单元的长期记忆和短期记忆之后会作为全连接层(Full Connected Later)的输入,这个全连接层就是一个是基本的普通神经网络。全连接层的输出对应着西班牙词汇表中的每个 token,输入和输出之间通过权重和偏差进行联系。之后我们将输出通过 Softmax 函数,从而选出最后我们要输出的词。
解码器可以从最后一个 token 的嵌入值开始,也可以从第一个 token 的嵌入值开始。我们这里选择从 <EOS>
开始。
我们看到 Softmax 函数的输出是 Vamos
,即 let's go
的西班牙翻译。目前为止翻译是正确的,但是解码器直到输出结尾 token 即 <EOS>
才会停止。
所以我们将 Vamos
插入到解码器的嵌入层进行下一轮操作,最后我们得到的输出是 <EOS>
,到此为止我们就将英文句子 let's go
翻译为了正确的西班牙句子。
总结一下解码器:
- 由编码两层 LSTM 单元创建的上下文向量用于初始化解码器中的 LSTM 单元
- 解码器 LSTM 的输入来自以
<EOS>
开始的词嵌入层输出 - 解码器的输入又将作为下一轮的输入,直到输出为
<EOS>
或者达到输出的最大长度 - 通过将编码器和解码器解耦,输入和输出句子的长度可以不同(例如,上面我们将长度为 2 的英文句子翻译为了长度为 1 的西班牙句子)
四、编码器-解码器的训练
就像所有的神经网络一样,编码器-解码器中所有的权重和偏差都是通过反向传播进行训练的。但是在训练的时候有两个特别的地方:
- 在上面的解码器例子中,我们使用第一轮输出
Vamos
作为第二轮输入,但是在训练权重和偏差的时候,第二轮的输入不使用第一轮的预测值,而是使用已知的正确的 token。也就是说如果在训练的时候第一轮的输出是y
,这是一个错误的翻译,在第二轮我们仍然以正确的Vamos
作为输入,而不是以第一轮预测输出y
作为输入。 - 在上面的解码器例子中,我们需要等到输出为
EOS
才停止。但是在训练过程中,我们以正确翻译的长度为标准,比如训练时第二轮的输出是y
,但是正确的翻译已经到了EOS
,因此我们会在这时停止下一轮的展开,尽管预测输出还没有到了EOS
。这个现象被称为教师强迫(Teacher Forcing)