【深度学习】循环神经网络及文本生成模型构建

循环神经网络

词嵌入层

词嵌入层的作用就是将文本转换为向量。
​ 词嵌入层首先会根据输入的词的数量构建一个词向量矩阵,例如: 我们有 100 个词,每个词希望转换成 128 维度的向量,那么构建的矩阵形状即为: 100*128,输入的每个词都对应了一个该矩阵中的一个向量.

在这里插入图片描述

在 PyTorch 中,使用 nn.Embedding 词嵌入层来实现输入词的向量化。

nn.Embedding(num_embeddings=10, embedding_dim=4)
  • nn.Embedding 对象构建时,最主要有两个参数:
    • num_embeddings 表示词的数量
    • embedding_dim 表示用多少维的向量来表示每个词

​ 接下来,我们将会学习如何将词转换为词向量,其步骤如下:
​ 先将语料进行分词,构建词与索引的映射,我们可以把这个映射叫做词表,词表中每个词都对应了一个唯一的索引
​ 然后使用 nn.Embedding 构建词嵌入矩阵,词索引对应的向量即为该词对应的数值化后的向量表示。
​ 例如,我们的文本数据为: “北京冬奥的进度条已经过半,不少外国运动员在完成自己的比赛后踏上归途。”,

import torch
import torch.nn as nn
import jiebaif __name__ == '__main__':# 0.文本数据text = '北京冬奥的进度条已经过半,不少外国运动员在完成自己的比赛后踏上归途。'# 1. 文本分词words = jieba.lcut(text)print('文本分词:', words)# 2.分词去重并保留原来的顺序获取所有的词语unique_words = list(set(words))print("去重后词的个数:\n",len(unique_words))# 3. 构建词嵌入层:num_embeddings: 表示词的总数量;embedding_dim: 表示词嵌入的维度embed = nn.Embedding(num_embeddings=len(unique_words), embedding_dim=4)print("词嵌入的结果:\n",embed)# 4. 词语的词向量表示for i, word in enumerate(unique_words):# 获得词嵌入向量word_vec = embed(torch.tensor(i))print('%3s\t' % word, word_vec)

在这里插入图片描述

RNN网络原理

​ 文本数据是具有序列特性的
​ 例如: “我爱你”, 这串文本就是具有序列关系的,“爱” 需要在 “我” 之后,“你” 需要在 “爱” 之后, 如果颠倒了顺序,那么可能就会表达不同的意思。
​ 为了表示出数据的序列关系,需要使用循环神经网络(Recurrent Nearal Networks, RNN) 来对数据进行建模,RNN 是一个作用于处理带有序列特点的样本数据。
在这里插入图片描述

​ h 表示隐藏状态,
​ 每一次的输入都会包含两个值: 上一个时间步的隐藏状态、当前状态的输入值,输出当前时间步的隐藏状态和当前时间步的预测结果。上图有三个神经元处理’我爱你’这三个字,实际上是一个他们三个字重复输入到同一个神经元

在这里插入图片描述

​ 我们举个例子来理解上图的工作过程,假设我们要实现文本生成,也就是输入 “我爱” 这两个字,来预测出 “你”,其如下图所示:

在这里插入图片描述

将上图展开成不同时间步的形式,如下图所示:

在这里插入图片描述

​ 首先初始化出第一个隐藏状态h0,一般都是全0的一个向量,然后将 “我” 进行词嵌入,转换为向量的表示形式,送入到第一个时间步,然后输出隐藏状态 h1,然后将 h1 和 “爱” 输入到第二个时间步,得到隐藏状态 h2, 将 h2 送入到全连接网络,得到 “你” 的预测概率。

在这里插入图片描述

上述公式中:
Wih 表示输入数据的权重
bih 表示输入数据的偏置
Whh 表示输入隐藏状态的权重
bhh 表示输入隐藏状态的偏置
最后对输出的结果使用 tanh 激活函数进行计算,得到该神经元你的输出。

在这里插入图片描述

Pytorch RNN层的使用

  • 输入数据和输出结果
    将RNN实例化就可以将数据送入其中进行处理,处理的方式如下所示:
output, hn = RNN(x, h0)
  • 输入数据:输入主要包括词嵌入的x 、初始的隐藏层h0
    • x的表示形式为[seq_len, batch, input_size],即[句子的长度,batch的大小,词向量的维度]
    • h0的表示形式为[num_layers, batch, hidden_size],即[隐藏层的层数,batch的大,隐藏层h的维数]\
  • 输出结果:主要包括输出结果output,最后一层的hn
    • output的表示形式与输入x类似,为[seq_len, batch, hidden_size],即[句子的长度,batch的大小,输出向量的维度]
    • hn的表示形式与输入h0一样,为[num_layers, batch, hidden_size],即[隐藏层的层数,batch的大,隐藏层h的维度]
import torch
import torch.nn as nn#  RNN层送入批量数据
def test():# 词向量维度 128, 隐藏向量维度 256rnn = nn.RNN(input_size=128, hidden_size=256)# 第一个数字: 表示句子长度,也就是词语个数# 第二个数字: 批量个数,也就是句子的个数# 第三个数字: 词向量维度inputs = torch.randn(5, 32, 128)hn = torch.zeros(1, 32, 256)# 获取输出结果output, hn = rnn(inputs, hn)print("输出向量的维度:\n",output.shape)print("隐含层输出的维度:\n",hn.shape)if __name__ == '__main__':test()

RNN及其变体LSTM、GRU

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

文本生成模型构建

项目需求

​ 文本生成任务是一种常见的自然语言处理任务,输入一个开始词能够预测出后面的词序列。本案例将会使用循环神经网络来实现周杰伦歌词生成任务。

import torch
import re
import jieba
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import timedef build_dataset():"""获取数据集:return: unique_words, word_to_idx, count, corpus_idx"""# 数据集位置file_name = '../data/jaychou_lyrics.txt'# 分词结果存储位置unique_words = []all_words = []# 遍历数据集中的每一行文本for line in open(file_name, 'r', encoding='utf-8'):# 使用jieba分词,分割结果是一个列表words = jieba.lcut(line)# print(words)# 所有的分词结果存储到all_sentences,其中包含重复的词组all_words.append(words)# 遍历分词结果,去重后存储到unique_wordsfor word in words:if word not in unique_words:unique_words.append(word)# 语料中词的数量word_count = len(unique_words)# 词到索引映射word_to_index = {word: idx for idx, word in enumerate(unique_words)}# 词表索引表示corpus_idx = []# 遍历每一行的分词结果for words in all_words:temp = []# 获取每一行的词,并获取相应的索引for word in words:temp.append(word_to_index[word])# 在每行词之间添加空格隔开temp.append(word_to_index[' '])# 获取当前文档中每个词对应的索引corpus_idx.extend(temp)return unique_words, word_to_index, word_count, corpus_idxclass LyricsDataset(torch.utils.data.Dataset):def __init__(self, corpus_idx, num_chars):# 文档数据中词的索引self.corpus_idx = corpus_idx# 每个句子中词的个数self.num_chars = num_chars# 词的数量self.word_count = len(self.corpus_idx)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):# 返回句子数量return self.numberdef __getitem__(self, idx):# idx指词的索引,并将其修正索引值到文档的范围里面start = min(max(idx, 0), self.word_count - self.num_chars - 2)# 输入值x = self.corpus_idx[start: start + self.num_chars]# 网络预测结果(目标值)y = self.corpus_idx[start + 1: start + 1 + self.num_chars]# 返回结果return torch.tensor(x), torch.tensor(y)# 模型构建
class TextGenerator(nn.Module):def __init__(self, word_count):super(TextGenerator, self).__init__()# 初始化词嵌入层: 词向量的维度为128self.ebd = nn.Embedding(word_count, 128)# 循环网络层: 词向量维度 128, 隐藏向量维度 128, 网络层数1self.rnn = nn.RNN(128, 128, 1)# 输出层: 特征向量维度128与隐藏向量维度相同,词表中词的个数self.out = nn.Linear(128, word_count)def forward(self, inputs, hidden):# 输出维度: (batch, seq_len,词向量维度 128)embed = self.ebd(inputs)# 修改维度: (seq_len, batch,词向量维度 128)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 输入维度: (seq_len*batch,词向量维度 ) 输出维度: (seq_len*batch, 128)output = self.out(output.reshape((-1, output.shape[-1])))# 网络输出结果return output, hiddendef init_hidden(self, bs):# 隐藏层的初始化:[网络层数, batch, 隐藏层向量维度]return torch.zeros(1, bs, 128)# 模型训练
def train():# 构建词典index_to_word, word_to_index, word_count, corpus_idx = build_dataset()# 数据集dataset = LyricsDataset(corpus_idx, 32)# 初始化模型model = TextGenerator(word_count)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 20for epoch_idx in range(epoch):# 数据加载器lyrics_dataloader = DataLoader(dataset, shuffle=True, batch_size=2)# 训练时间start = time.time()iter_num = 0  # 迭代次数# 训练损失total_loss = 0.0# 遍历数据集for x, y in lyrics_dataloader:# 隐藏状态的初始化hidden = model.init_hidden(x.size(0))# 模型计算output, hidden = model(x, hidden)# 计算损失# y:[batch,seq_len]->[seq_len,batch]->[seq_len*batch]y = torch.transpose(y, 0, 1).contiguous().view(-1)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()iter_num += 1  # 迭代次数加1total_loss += loss.item()# 打印训练信息print('epoch %3s loss: %.5f time %.2f' % (epoch_idx + 1, total_loss / iter_num, time.time() - start))# 模型存储torch.save(model.state_dict(), '../model/lyrics_model_%d.pth' % epoch)def predict(start_word, sentence_length):# 构建词典index_to_word, word_to_index, word_count, _ = build_dataset()# 构建模型model = TextGenerator(word_count)# 加载参数model.load_state_dict(torch.load('../model/lyrics_model_10.pth'))# 隐藏状态hidden = model.init_hidden(bs=1)# 将起始词转换为索引word_idx = word_to_index[start_word]# 产生的词的索引存放位置generate_sentence = [word_idx]# 遍历到句子长度,获取每一个词for _ in range(sentence_length):# 模型预测output, hidden = model(torch.tensor([[word_idx]]), hidden)# 获取预测结果word_idx = torch.argmax(output)generate_sentence.append(word_idx)# 根据产生的索引获取对应的词,并进行打印for idx in generate_sentence:print(index_to_word[idx], end='')if __name__ == '__main__':# train()predict('回忆', 100)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/475523.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文阅读:Uni-ISP Unifying the Learning of ISPs from Multiple Cameras

这是 ECCV 2024 的一篇文章,文章作者想建立一个统一的 ISP 模型,以实现在不同手机之间的自由切换。文章作者是香港中文大学的 xue tianfan 和 Gu jinwei 老师。 Abstract 现代端到端图像信号处理器(ISPs)能够学习从 RAW/XYZ 数据…

ROS2指令总结(跟随古月居教程学习)

​ 博主跟随古月居博客进行ROS2学习,对ROS2相关指令进行了总结,方便学习和回顾。 古月居ROS2博文链接:https://book.guyuehome.com/ 本文会持续进行更新,觉得有帮助的朋友可以点赞收藏。 1. ROS2安装命令 $ sudo apt update &am…

Qt不同类之间参数的传递

一、信号槽方式 1: 在需要发送信号的子类增加一个信号函数 void set_send(double lonx, double laty);sub.h sub.cpp emit set_send(lonx,laty);2: 在需要接收信号的类增加一个槽函数 main.h void set_rece(double lonx, double laty);main.cpp 1)引入子类头文…

labview记录系统所用月数和天数

在做项目时会遇到采集系统的记录,比如一个项目测试要跑很久这个时候就需要在软件系统上显示项目运行了多少天,从开始测试开始一共用了多少年多少月。 年的话还好计算只需要把年份减掉就可以了,相比之下月份和天数就比较难确定,一…

机器翻译基础与模型 之一: 基于RNN的模型

一、机器翻译发展历程 基于规则的-->基于实例的-->基于统计方法的-->基于神经网络的 传统统计机器翻译把词序列看作离散空间里的由多个特征函数描述的点,类似 于 n-gram 语言模型,这类模型对数据稀疏问题非常敏感。神经机器翻译把文字序列表示…

WPF Prism框架

一、关于Prism框架 Prism.Core:【Prism.dll】实现MVVM的核心功能,属于一个与平台无关的项目 Prism.Wpf:【Prism.Wpf】包含了DialogService,Region,Module,Navigation,其他的一些WPF的功能 Prism.Unity:【Prism.Unity.Wpf】,IOC容器 Prism.Unity>Pr…

STM32F103系统时钟配置

时钟是单片机运行的基础,时钟信号推动单片机内各个部分执行相应的指令。时钟系统就是CPU的脉搏,决定CPU速率,像人的心跳一样 只有有了心跳,人才能做其他的事情,而单片机有了时钟,才能够运行执行指令&#x…

2024年 Web3开发学习路线全指南

Web3是一个包含了很多领域的概念,不讨论币圈和链圈的划分,Web3包括有Defi、NFT、Game等基于区块链的Dapp应用的开发;也有VR、AR等追求视觉沉浸感的XR相关领域的开发;还有基于区块链底层架构或者协议的开发。 这篇文章给出的学习路…

CTF--php伪协议结合Base64绕过

Base64绕过 在ctf中,base64是比较常见的编码方式,在做题的时候发现自己对于base64的编码和解码规则不是很了解,并且恰好碰到了类似的题目,在翻阅了大佬的文章后记录一下,对于base64编码的学习和一个工具 base64编码是…

Linux 命令之 tar

文章目录 1 tar 命令介绍2 压缩与解压缩2.1 压缩2.2 解压 4 高级用法4.1 排除目录4.2 显示进度4.2.1 脚本解压缩4.2.2 命令解压缩4.2.3 压缩进度 1 tar 命令介绍 常见的压缩包有 .tar.gz、.tar.xz、.tar.bz2,以及 .rar、.zip、.7z 等压缩包。 常见的 tar 选项&#…

Jenkins修改LOGO

重启看的LOGO和登录页面左上角的LOGO 进入LOGO存在的目录 [roottest-server01 svgs]# pwd /opt/jenkins_data/war/images/svgs [roottest-server01 svgs]# ll logo.svg -rw-r--r-- 1 jenkins jenkins 29819 Oct 21 10:58 logo.svg #jenkins_data目录是我挂载到了/opt目录&…

【大模型】LLaMA: Open and Efficient Foundation Language Models

链接:https://arxiv.org/pdf/2302.13971 论文:LLaMA: Open and Efficient Foundation Language Models Introduction 规模和效果 7B to 65B,LLaMA-13B 超过 GPT-3 (175B)Motivation 如何最好地缩放特定训练计算预算的数据集和模型大小&…

vue添加LCD字体(液晶字体)数字美化,前端如何引用LCD字体液晶字体,如何转换?@font-face 如何使用?

文章目录 一、效果二、下载字体格式【[https://www.dafont.com/theme.php?cat302&text0123456789](https://www.dafont.com/theme.php?cat302&text0123456789)】三、下载后,解压后都是.ttf文件,在【[https://www.fontsquirrel.com/tools/webfo…

【大数据学习 | Spark】关于distinct算子

只有shuffle类的算子能够修改分区数量,这些算子不仅仅存在自己的功能,比如分组算子groupBy,它的功能是分组但是却可以修改分区。 而这里我们要讲的distinct算子也是一个shuffle类的算子。即可以修改分区。 scala> val arr Array(1,1,2,…

Qt桌面应用开发 第五天(常用控件 自定义控件)

目录 1.QPushButton和ToolButton 1.1QPushButton 1.2ToolButton 2.RadioButton和CheckBox 2.1RadioButton单选按钮 2.2CheckBox多选按钮 3.ListWidget 4.TreeWidget控件 5.TableWidget控件 6.Containers控件 6.1QScrollArea 6.2QToolBox 6.3QTabWidget 6.4QStacke…

Excel - VLOOKUP函数将指定列替换为字典值

背景:在根据各种复杂的口径导出报表数据时,因为关联的表较多、数据量较大,一行数据往往会存在三个以上的字典数据。 为了保证导出数据的效率,博主选择了导出字典code值后,在Excel中处理匹配字典值。在查询百度之后&am…

ctfshow-web入门-SSRF(web351-web360)

目录 1、web351 2、web352 3、web353 4、web354 5、web355 6、web356 7、web357 8、web358 9、web359 10、web360 1、web351 看到 curl_exec 函数,很典型的 SSRF 尝试使用 file 协议读文件: urlfile:///etc/passwd 成功读取到 /etc/passwd 同…

【vmware+ubuntu16.04】ROS学习_博物馆仿真克隆ROS-Academy-for-Beginners软件包处理依赖报错问题

首先安装git 进入终端,输入sudo apt-get install git 安装后,创建一个工作空间名为tutorial_ws, 输入 mkdir tutorial_ws#创建工作空间 cd tutorial_ws#进入 mkdir src cd src git clone https://github.com/DroidAITech/ROS-Academy-for-Be…

AI数字人视频小程序:引领未来互动新潮流

当下,随着人工智能技术的不断创新发展,各类AI系统已经成为了创新市场发展的重要力量,AI文案、AI数字人、AI视频等,为大众带来更加便捷的创作方式,AI成为了一个全新的风口,各种AI红利持续释放,市…

leetcode400第N位数字

代码 class Solution {public int findNthDigit(int n) {int base 1;//位数int weight 9;//权重while(n>(long)base*weight){//300n-base*weight;base;weight*10;}//n111 base3 weight900;n--;int res (int)Math.pow(10,base-1)n/base;int index n%base;return String…