动手学深度学习——循环神经网络的从零开始实现(原理解释+代码详解)

文章目录

    • 循环神经网络的从零开始实现
      • 1. 独热编码
      • 2. 初始化模型参数
      • 3. 循环神经网络模型
      • 4. 预测
      • 5. 梯度裁剪
      • 6. 训练

循环神经网络的从零开始实现

从头开始基于循环神经网络实现字符级语言模型。

# 读取数据集
%matplotlib inline
import math
import torchfrom torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1. 独热编码

每个词元都有一个对应的索引,表示为特征向量,即每个索引映射为相互不同的单位向量。

词元表不同词元个数为N,词元索引范围为0到N-1。词元的索引为整数,那么将创建一个长度为N的全0向量,并将第i处元素设置为1。则此向量是原始词元的一个独热编码。

假如有2个词元"cat"和"dog"

  • "cat"对应:[1, 0]
  • "dog"对应:[0, 1]

索引为0和2的独热向量

# 索引为0和2的独热向量
F.one_hot(torch.tensor([0, 2]), len(vocab))

在这里插入图片描述
采样的小批量数据形状为二维张量:(批量大小,时间步数),one_hot函数将其转换为三维张量:(时间步数,批量大小,词表大小)

# 采样的小批量数据形状为二维张量:(批量大小,时间步数)
# one_hot函数将其转换为三维张量:(时间步数,批量大小,词表大小)
# 方便我们通过最外层维度,一步一步更新小批量数据的隐状态
X = torch.arange(10).reshape((2, 5))
print(F.one_hot(X.T, 28).shape)
# 显示第一行
F.one_hot(X.T, 28)[0,:,:]

在这里插入图片描述

2. 初始化模型参数

隐藏单元数num_hiddens是一个可调的超参数

训练语言模型时,输入和输出来自相同的词表,具有相同的维度即词表大小

"""
初始化模型参数:1、隐藏层参数2、输出层参数3、附加梯度
"""
# (词表大小,隐藏层数,设备)
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_size# 定义函数normal(),初始化模型的参数def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

3. 循环神经网络模型

定义init_rnn_state函数在初始化时返回隐状态,该函数的返回是一个张量,张量全用0填充,形状为(批量大小,隐藏单元数)。

# 定义init_rnn_state函数在初始化时返回隐状态
# 该函数的返回是一个张量,张量全用0填充,形状为(批量大小,隐藏单元数)
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

在这里插入图片描述
循环神经网络通过最外层的维度实现循环,以便时间步更新小批量数据的隐状态H

# 循环神经网络通过最外层的维度实现循环,以便时间步更新小批量数据的隐状态H
def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:# 激活函数tanh,更新隐状态HH = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数

"""
从零开始实现的循环神经网络模型:
1、定义网络模型的参数
2、对词表进行独热编码
3、初始化模型参数并返回隐状态
"""
class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""# 定义类的初始化,将传入的参数赋值给对象的属性,以便后续使用def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):# 对输入进行独热编码,返回状态及参数X = F.one_hot(X.T, self.vocab_size).type(torch.float32)return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):# 初始化参数return self.init_state(batch_size, self.num_hiddens, device)

检查输出是否具有正确的形状。 例如,隐状态的维数是否保持不变。

num_hiddens = 512
# 网络模型
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
# 获得网络初始状态
state = net.begin_state(X.shape[0], d2l.try_gpu())
# 将X移到GPU上,并且返回输出Y和状态
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape

在这里插入图片描述
可以看到输出形状是(时间步数x批量大小,词表大小), 而隐状态形状保持不变,即(批量大小,隐藏单元数)。

4. 预测

定义预测函数

"""
定义预测函数:
1、prefix是用户提供的字符串;
2、循环遍历prefix的开始字符时不输出,不断将隐状态传递给下一个时间步;
3、在此期间模型进行自我更新(隐状态),不进行预测;
4、2和3步骤称为预热期,预热期过后隐状态的值更适合预测,从而预测字符并输出。
"""
# prefix:前缀字符串
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]# 匿名函数:改变输出的形状get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 预热期:不进行输出for y in prefix[1:]:  # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])# 预热期过了之后,进行预测for _ in range(num_preds):  # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])

测试predict_ch8函数。 我们将前缀指定为time traveller, 并基于这个前缀生成10个后续字符

# 测试predict_ch8函数。 我们将前缀指定为time traveller, 并基于这个前缀生成10个后续字符。
# 未训练模型,输出预测结果没有联系
predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())

在这里插入图片描述

5. 梯度裁剪

为什么要梯度裁剪:
1、对于长度为T的序列,我们在迭代中计算T个时间步上的梯度,在反向传播过程中产生长度为T的矩阵乘法链;
2、T较大时,会导致数值不稳定,例如梯度消失或者梯度爆炸。

一个流行的替代方案是通过将梯度g投影回给定半径 (例如θ)的球来裁剪梯度g。
在这里插入图片描述

def grad_clipping(net, theta):  #@save"""裁剪梯度"""if isinstance(net, nn.Module):# 附加梯度的参数params = [p for p in net.parameters() if p.requires_grad]else:# 梯度的范数:对应图里作为分母的"||g||"params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))# 如果梯度过大,将其限制到θif norm > theta:for param in params:param.grad[:] *= theta / norm

6. 训练

在一个迭代周期内训练模型:
1、序列数据的不同采样方法(随机采样和顺序分区)将导致状态初始化的差异;
2、在更新模型参数之前裁剪梯度,这样可以保证训练过程中如果某点发生梯度爆炸,模型也不会发散;
3、用困惑度评价模型,使得不同长度的序列也有了可比性。

  • 顺序分区:只在每个迭代周期的开始位置初始化隐状态。
  • 随机抽样:每个样本都是在一个随机位置抽样的,因此需要在每个迭代周期重新初始化隐状态。
#@save
"""
训练网络一个迭代周期:
1、初始化状态,将数据传到GPU上
2、计算损失,进行梯度裁剪并更新模型参数
"""
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练网络一个迭代周期(定义见第8章)"""# 状态,时间state, timer = None, d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失之和,词元数量for X, Y in train_iter:if state is None or use_random_iter:# 在第一次迭代或使用随机抽样时初始化statestate = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(net, nn.Module) and not isinstance(state, tuple):# state对于nn.GRU是个张量# detach_()将张量从计算图中分离出来,不会影响到原始张量state.detach_()else:# state对于nn.LSTM或对于我们从零开始实现的模型是个张量for s in state:s.detach_()# 将Y 进行转置并展平成一维向量y = Y.T.reshape(-1)# 将X,y移动到设备上,并且输入到模型中X, y = X.to(device), y.to(device)y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()# 如果更新器 updater 是 torch.optim.Optimizer 类型,则调用 updater.step() 方法进行参数更新;# 否则调用 updater(batch_size=1) 进行参数更新。if isinstance(updater, torch.optim.Optimizer):updater.zero_grad() # 梯度置零l.backward() # 反向传播,知道如何调整参数以最小化损失函数grad_clipping(net, 1) # 梯度裁剪updater.step() # 使用优化器来更新参数else:l.backward()grad_clipping(net, 1)# 因为已经调用了mean函数updater(batch_size=1)# y.numel()计算y中元素数量metric.add(l * y.numel(), y.numel())# 使用指数损失函数计算累积平均困惑度 math.exp(metric[0] / metric[1]) 和训练速度 metric[1] / timer.stop()。return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
  • updater.zero_grad(): 这一行代码将模型参数的梯度置零,以便在每次迭代中计算新的梯度。
  • l.backward(): 这一行代码使用反向传播算法计算损失函数对模型参数的梯度。通过计算梯度,我们可以知道如何调整模型参数以最小化损失函数。
  • grad_clipping(net, 1): 这一行代码对模型的梯度进行裁剪,以防止梯度爆炸的问题。梯度爆炸可能会导致训练不稳定,裁剪梯度可以限制梯度的范围。
  • updater.step(): 这一行代码使用优化器(如SGD、Adam等)来更新模型的参数。优化器根据计算得到的梯度和预定义的学习率来更新模型参数,以使模型更好地拟合训练数据。

循环神经网络的训练函数也支持高级API实现

# 循环神经网络的训练函数也支持高级API实现
#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型(定义见第8章)"""loss = nn.CrossEntropyLoss()# 动画窗口:窗口显示一个图例,图例名称为 "train",x 轴的范围从 10 到 num_epochsanimator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 训练和预测for epoch in range(num_epochs):ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)# 每10个epoch,对输入字符串进行预测,并将预测结果添加到动画中if (epoch + 1) % 10 == 0:print(predict('time traveller'))animator.add(epoch + 1, [ppl])print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')print(predict('time traveller'))print(predict('traveller'))

在数据集中只使用了10000个词元, 所以模型需要更多的迭代周期来更好地收敛

# 在数据集中只使用了10000个词元, 所以模型需要更多的迭代周期来更好地收敛
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())

在这里插入图片描述
检查一下随机抽样方法的结果

# 检查一下随机抽样方法的结果
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),use_random_iter=True)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/198048.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sqli-labs关卡20(基于http头部报错盲注)通关思路

文章目录 前言一、回顾上一关知识点二、靶场第二十关通关思路1、判断注入点2、爆数据库名3、爆数据库表4、爆数据库列5、爆数据库关键信息 总结 前言 此文章只用于学习和反思巩固sql注入知识,禁止用于做非法攻击。注意靶场是可以练习的平台,不能随意去尚…

【Linux】安全审计-audit

文章目录 一、audit简介二、开启auditd服务三、相关文件四、审计规则五、审计日志查询及分析附录1:auditctl -h附录2:systemcall 类型 参考文章: 1、安全-linux audit审计使用入门 2、audit详细使用配置 3、Linux-有哪些常见的System Call&a…

golang学习笔记——接口interfaces

文章目录 Go 语言接口例子空接口空接口的定义空接口的应用空接口作为函数的参数空接口作为map的值 类型断言接口值 类型断言例子001类型断言例子002类型断言例子003巩固练习 Go 语言接口 接口(interface)定义了一个对象的行为规范,只定义规范…

Java面向对象(高级)-- 类的成员之四:代码块

文章目录 一、回顾(1)三条主线(2)类中可以声明的结构及作用1.结构2.作用 二、代码块(1)代码块的修饰与分类1. 代码块的修饰2. 代码块的分类3. 举例 (2) 静态代码块1. 语法格式2. 静态…

【数据结构】栈与队列面试题(C语言)

我们再用C语言做题时,是比较不方便的,因此我们在用到数据结构中的某些时只能手搓或者Ctrlcv 我们这里用到的栈或队列来自栈与队列的实现 目录 有效的括号解题思路:代码实现: 用队列实现栈解题思路:代码实现&#xff1a…

4月2日-3日·上海 | 3DCC 第二届3D细胞培养与类器官研发峰会携手CGT Asia 重磅来袭

类器官(Organoids)作为干细胞研究领域最重要的成果之一,在基础医学研究、转化医学及药物研发领域展现出巨大的应用潜力,特别是在精准医疗以及药物安全性和有效性评价等方向凭借其先天优势引起了极大的市场关注,成为各大…

LabVIEW进行MQTT通信及数据解析

需求:一般通过串口的方式进行数据的解析,但有时候硬件的限制,没法预留串口,那么如何通过网络的方式特别是MQTT数据的通信及解析 解决方式: 1.MQTT通信控件: 参考开源的mqtt-LabVIEW https://github.com…

【iOS】——知乎日报第五周总结

文章目录 一、评论区展开与收缩二、FMDB库实现本地持久化FMDB常用类:FMDB的简单使用: 三、点赞和收藏的持久化 一、评论区展开与收缩 有的评论没有被回复评论或者被回复评论过短,这时就不需要展开全文的按钮,所以首先计算被回复评…

量化交易:借助talib使用技术分析指标

什么是技术分析? 所谓股票的技术分析,是相对于基本面分析而言的。基本分析法着重于对一般经济情况以及各个公司的经营管理状况、行业动态等因素进行分析,以此来研究股票的价值,衡量股价的高低。而技术分析则是透过图表或技术指标…

vulhub redis-4-unacc

环境搭建 cd vulhub/redis/4-unacc docker-compose up -d 漏洞复现 检测 redis-cli -h ip 使用redis工具 工具地址:https://github.com/vulhub/redis-rogue-getshell 下载完成后,先进入RedisModulesSDK/exp/ 目录进行make操作 获得exp.so后可以进行…

Jenkinsfile+Dockerfile前端vue自动化部署

前言 本篇主要介绍如何自动化部署前端vue项目 其中,有两种方案: 第一种是利用nginx进行静态资源转发;第二种方案是利用nodejs进行启动访问; 各个组件版本如下: Docker 最新版本;Jenkins 2.387.3nginx …

物联网AI MicroPython学习之语法 I2C总线

学物联网,来万物简单IoT物联网!! I2C 介绍 模块功能: I2C Master设备驱动 接口说明 I2C - 构建硬件I2C对象 函数原型:I2C(id, scl, sda, freq)参数说明: 参数类型必选参数?说明idintYI2C外设&#xff…

带你快速掌握Linux最常用的命令(图文详解)- 最新版(面试笔试常考)

最常用的Linux指令(图文详解)- 最新版 ls:列出目录中的文件和子目录(重点)cd:改变当前工作目录绝对路径:相对路径 pwd:显示当前工作目录的路径mkdir:创建一个新的目录tou…

【开源】基于Vue.js的音乐偏好度推荐系统的设计和实现

项目编号: S 012 ,文末获取源码。 \color{red}{项目编号:S012,文末获取源码。} 项目编号:S012,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 音乐档案模块2.1…

spring中的DI

【知识要点】 控制反转(IOC)将对象的创建权限交给第三方模块完成,第三方模块需要将创建好的对象,以某种合适的方式交给引用对象去使用,这个过程称为依赖注入(DI)。如:A对象如果需要…

代码随想录算法训练营Day 54 || 392.判断子序列、115.不同的子序列

392.判断子序列 力扣题目链接(opens new window) 给定字符串 s 和 t ,判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些(也可以不删除)字符而不改变剩余字符相对位置形成的新字符串。(例如,&quo…

多svn仓库一键更新脚本分享

之前分享过多git仓库一键更新脚本,本期就分享下svn仓库的一键更新脚本 1、首先需要设置svn为可执行命令行 打开SVN安装程序,选择modify,然后点击 command client tools,安装命令行工具 2、update脚本 echo 开始更新SVN目录&…

计算机视觉:使用opencv实现车牌识别

1 引言 汽车车牌识别(License Plate Recognition)是一个日常生活中的普遍应用,特别是在智能交通系统中,汽车牌照识别发挥了巨大的作用。汽车牌照的自动识别技术是把处理图像的方法与计算机的软件技术相连接在一起,以准…

Linux管道的工作过程

常用的匿名管道(Anonymous Pipes),也即将多个命令串起来的竖线。管道的创建,需要通过下面这个系统调用。 int pipe(int fd[2]) 我们创建了一个管道 pipe,返回了两个文件描述符,这表示管道的两端&#xff…

jvm 内存结构 ^_^

1. 程序计数器 2. 虚拟机栈 3. 本地方法栈 4. 堆 5. 方法区 程序计数器 定义: Program Counter Register 程序计数器(寄存器) 作用,是记住下一条jvm指令的执行地址 特点: 是线程私有的 不会存在内存溢出 虚拟机栈…