Pytorch实现RNN实验

一、实验要求

        用 Pytorch 模块的 RNN 实现生成唐诗。要求给定一个字能够生成一首唐诗。

二、实验目的

  1. 理解循环神经网络(RNN)的基本原理:通过构建一个基于RNN的诗歌生成模型,学会RNN是如何处理序列数据的,以及如何在PyTorch中实现它。
  2. 掌握文本数据的预处理技巧,并学会构建一个文本生成模型
  3. 加深对循环神经网络的了解

三、实验过程

1.搭建模型

整体思路:

        先进行文本预处理,读取诗歌数据,清理文本数据,并构建词汇表,记录每个字符的出现频率。然后将清理过的文本数据转换成模型可用的数字表示形式,并将整数表示的文本数据切分为多个序列,构建训练数据集。随后,定义一个基于RNN的模型,通过训练数据集迭代训练模型来优化模型参数。模型训练完成后可利用模型生成一定长度的新诗歌文本。

1)导入库和检查GPU可用性

导入Pytorch库并检查GPU是否可用。如果GPU可用,返回“True”

0d7c31528e784a4385fd612ac838b4fa.png

导入进行数据预处理和标记所需的库

12e03ed704584408a311bca179a0cd75.png

2)定义超参数

定义了学习率、最大训练轮次、批处理大小以及是否使用GPU的标志。

a95a519588c046c89d24463dd6a09781.png

3)数据处理

引入诗歌文件,形成诗歌数据集,并通过替换换行符和中文标点符号来清理文本

f02607369439424aafdb5fd640bdea33.png

 ‘TextConverter’类负责对文本数据进行预处理和转换

 e903211fa6bb432cb08a5a6e925d0d49.png

c3a2bdc80da948c5a2fd6dcd7224606c.png 600f40d27203467299d38899cdc0337c.png

字符到整数和整数到字符的转换方法:

  1. word_to_int方法接受一个字符作为参数,返回字符在词汇表中的整数索引。如果字符不在词汇表中,则返回词汇表大小。
  2. int_to_word方法接受一个整数索引作为参数,返回该索引对应的字符。如果索引等于词汇表大小,返回中文逗号",";如果索引小于词汇表大小,则返回对应的字符;否则,抛出异常。

 a8b8ac010c4d4a49bfb093d784705c7e.png

 

文本到数组和数组到文本的转换方法:

  1. text_to_arr方法接受一个文本字符串作为参数,返回一个由文本中每个字符对应整数索引组成的NumPy数组。
  2. arr_to_text方法接受一个整数索引数组作为参数,返回由数组中每个索引对应字符组成的字符串

 57a537561c9d4856a75704602c46afaf.png

准备数据集 

d12a2aa6f726424882c41d14fc07587e.png 

定义数据集 

87e91f85726749719c0510fbb6cde7ef.png 

4)定义RNN模型

        使用PyTorch的nn.Module定义了RNN模型的结构

        通过嵌入层将字符索引映射为密集向量,然后通过RNN层处理这些向量序列。最后,通过线性层将RNN输出映射为词汇表大小的向量。

3f2a1ecae5914cdf98c83f87c55f53c9.png

 

5)模型初始化、损失和优化器

使用交叉熵损失函数(nn.CrossEntropyLoss())来度量模型输出与实际标签之间的差异。

使用Adam优化器(torch.optim.Adam)来更新模型参数,其中学习率为Learning_rate。

f6aaf40adb784438b06c8fec3635b5d4.png

6)训练循环

通过反复迭代,模型在每个Epoch中根据训练数据调整参数,逐渐提高对中文诗歌模式的学习,使得生成的文本更符合训练数据的特征

8de168c7ffe547ed989c6af34a001d46.png

 

2.对模型进行优化、改进

1)运行程序

823bae72b8d549df8d0a91bcf0583b0d.png

根据提供的训练输出结果来看,Perplexity的数值较大,而Loss较高,说明模型在训练数据上的拟合效果相对较差。通常情况下,Perplexity较低且Loss较小的模型效果更好。

分析可能导致模型效果一般的原因:

  1. 增加模型复杂性:添加更多层或增加现有层中的隐藏单元数
  2. 使用LSTM或者GRU:捕捉序列中的长期依赖关系
  1. 调整嵌入维度:尝试不同的myRNN类中的embed_dim参数值
  2. 调整学习率
  3. 增加训练次数
  4. 实现验证集:将数据集拆分为训练集和验证集。使用验证集来监控训练过程中模型的性能。在验证损失不再下降或开始上升时停止训练。

 2)修改模型结构,使用LSTM结构

 f124531b85e64938ba4987fa516502c0.png

 

并且将训练次数增加到50

输出结果为:

f9e9d5bd2f5541289e350e87388adf0f.png

调整学习率为1e-5 ,输出

fa6f66caef28446693f81094510c70fd.png

3)实现测试集:将数据集拆分为训练集和测试集

8a0154d0a87c4bde85bbeea0d7b98bc5.png

 发现多次调参,调整Embedding层,调模型结构都没调出合适的模型,输出的诗句有很多重复的字。

4)选择将原模型增加测试集进行尝试

f390c840ac774e8893bcb837067e15ee.png

af88e781be7c4073a06608730b264df6.png

8b12628f2b6649e1be2ba53b9610f459.png 

输出结果为 

21f44249761949babbafc8e56bfd2498.png

考虑到古诗上下文之间有一定的关联性

将n_step设置为30

输出结果为

0eca6277c86d41b09676fd805cdb1c72.png

 

将n_step设置为40

输出结果为

0a2224a9ede34865a74bd8369848a1a4.png

 发现这种情况是所有实验中Loss最小的一种

四、实验结果

        经过多次调参,优化模型,发现使用RNN结构,学习率为1e-4,epochs为50,n_setp为40时,得出的Loss最小。

五、实验总结

        在修改深度学习代码,特别是从RNN迁移到LSTM的过程中,我遇到了一些挑战。首先,了解LSTM与RNN的区别和工作原理对于成功修改代码至关重要。其次,我注意到LSTM层的输入格式要求与RNN不同,需要将batch_first设置为True。在调试过程中,还遇到了一些GPU不可用的问题,通过检查CUDA是否可用、GPU驱动程序和PyTorch版本等方面找到解决方案。总的来说,通过修改代码将RNN替换为LSTM,我更深入地理解了这两者之间的差异。但是,由于自己的能力有限,在修改为LSTM后并没有成功优化模型。所以,最后还是将RNN结构模型增加测试集,得出一个相对较好的结果。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/439406.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序使用picker,数组怎么设置默认值

默认先显示请选择XXX。然后点击弹出选择列表。如果默认value是0的话&#xff0c;他就直接默认显示数组的第一个了。<picker mode"selector" :value"planIndex" :range"planStatus" range-key"label" change"bindPlanChange&qu…

使用Conda管理python环境的指南

1. 准备 .yml 文件 确保你有一个定义了 Conda 环境的 .yml 文件。这个文件通常包括环境的依赖和配置设置。文件内容可能如下所示&#xff1a; name: myenv channels:- defaults dependencies:- python3.8- numpy- pandas- scipy- pip- pip:- torch- torchvision- torchaudio2…

OpenCV马赛克

#马赛克 import cv2 import numpy as np import matplotlib.pyplot as pltimg cv2.imread(coins.jpg,1) imgInfo img.shape height imgInfo[0] width imgInfo[1]for m in range(200,400): #m,n表示打马赛克区域for n in range(200,400):# pixel ->10*10if m%10 0 and …

Hive数仓操作(十七)

一、Hive的存储 一、Hive 四种存储格式 在 Hive 中&#xff0c;支持四种主要的数据存储格式&#xff0c;每种格式有其特点和适用场景&#xff0c;不过一般只会使用Text 和 ORC &#xff1a; 1. Text 说明&#xff1a;Hive 的默认存储格式。存储方式&#xff1a;行存储。优点…

华硕天选笔记本外接音箱没有声音

系列文章目录 文章目录 系列文章目录一.前言二.解决方法第一种方法第二种方法 一.前言 华硕天选笔记本外接音箱没有声音&#xff0c;在插上外接音箱时&#xff0c;系统会自动弹出下图窗口 二.解决方法 第一种方法 在我的电脑上选择 Headphone Speaker Out Headset 这三个选项…

Custom C++ and CUDA Extensions - PyTorch

0. Abstract 经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容: LibTorch 初步;C Extension 例子; 1. LibTorch …

国庆刷题(day4)

C语言&#xff1a; C&#xff1a;

gdb 调试 linux 应用程序的技巧介绍

使用 gdb 来调试 Linux 应用程序时&#xff0c;可以显著提高开发和调试的效率。gdb&#xff08;GNU 调试器&#xff09;是一款功能强大的调试工具&#xff0c;适用于调试各类 C、C 程序。它允许我们在运行程序时检查其状态&#xff0c;设置断点&#xff0c;跟踪变量值的变化&am…

电脑手机下载小米xiaomi redmi刷机包太慢 解决办法

文章目录 修改前下载速度修改后下载速度修改方法&#xff08;修改host&#xff09; 修改前下载速度 一开始笔者以为是迅雷没开会员的问题&#xff0c;在淘宝上买了一个临时会员后下载速度依然最高才100KB/s 修改后下载速度 修改方法&#xff08;修改host&#xff09; host文…

Python编码规范与常见问题纠正

Python编码规范与常见问题纠正 Python 是一种以简洁和易读性著称的编程语言&#xff0c;因此&#xff0c;遵循良好的编码规范不仅能使代码易于维护&#xff0c;还能提升代码的可读性和可扩展性。编写规范的 Python 代码也是开发者职业素养的一部分&#xff0c;本文将从 Python…

TryHackMe 第6天 | Web Fundamentals (一)

这一部分我们要简要介绍以下 Web Hacking 的基本内容&#xff0c;预计分三次博客。 在访问 Web 应用时&#xff0c;浏览器提供了若干个工具来帮助我们发现一些潜在问题和有用的信息。 比如可以查看网站源代码。查看源代码可以 右键 网页&#xff0c;然后选择 查看网站源代码&…

【复习】CSS中的选择器

文章目录 东西有点多 以实战为主选择器盒子模型 东西有点多 以实战为主 选择器 CSS选择器&#xff08;CSS Selectors&#xff09;是用于在HTML或XML文档中查找和选择元素&#xff0c;以便应用CSS样式的一种方式。 元素选择器&#xff08;Type Selector&#xff09; 选择所有…

探索 aMQTT:Python中的AI驱动MQTT库

文章目录 探索 aMQTT&#xff1a;Python中的AI驱动MQTT库背景介绍aMQTT是什么&#xff1f;如何安装aMQTT&#xff1f;简单库函数使用方法场景应用常见问题及解决方案总结 探索 aMQTT&#xff1a;Python中的AI驱动MQTT库 背景介绍 在物联网和微服务架构的浪潮中&#xff0c;MQ…

CSS3练习--电商web

免责声明&#xff1a;本文仅做分享&#xff01; 目录 小练--小兔鲜儿 目录构建 SEO 三大标签 Favicon 图标 布局网页 版心 快捷导航&#xff08;shortcut&#xff09; 头部&#xff08;header&#xff09; logo 导航 搜索 购物车 底部&#xff08;footer&#xff0…

2024年计算机视觉与艺术研讨会(CVA 2024)

目录 基本信息 大会简介 征稿主题 会议议程 参会方式 基本信息 大会官网&#xff1a;www.icadi.net&#xff08;点击了解参会投稿等信息&#xff09; 大会时间&#xff1a;2024年11月29-12月1日 大会地点&#xff1a;中国-天津 大会简介 2024年计算机视觉与艺术国际学术…

Redis --- 第三讲 --- 通用命令

一、get和set命令 Redis中最核心的两个命令 get 根据key来取value set 把key和value存储进去 redis是按照键值对的方式存储数据的。必须要先进入到redis客户端。 语法 set key value &#xff1a; key和value都是字符串。 对于上述这里的key value 不需要加上引号&#…

数据库概述(1)

课程主页&#xff1a;Guoliang Li Tsinghua 数据库在计算机系统中的位置 首先&#xff0c;数据库是在设计有大量数据存储需求的软件时必不可少可的基础。 最常见的是&#xff1a;我们通过app或者是浏览器来实现一些特定需求——比如转账、订车票。即引出背后的CS和BS两种网…

重学SpringBoot3-集成Redis(三)

更多SpringBoot3内容请关注我的专栏&#xff1a;《SpringBoot3》 期待您的点赞&#x1f44d;收藏⭐评论✍ 重学SpringBoot3-集成Redis&#xff08;三&#xff09; 1. 引入 Redis 依赖2. 配置 RedisCacheManager 及自定义过期策略2.1 示例代码&#xff1a;自定义过期策略 3. 配置…

如何使用ssm实现民族大学创新学分管理系统分析与设计+vue

TOC ssm763民族大学创新学分管理系统分析与设计vue 第1章 绪论 1.1 课题背景 二十一世纪互联网的出现&#xff0c;改变了几千年以来人们的生活&#xff0c;不仅仅是生活物资的丰富&#xff0c;还有精神层次的丰富。在互联网诞生之前&#xff0c;地域位置往往是人们思想上不…

【rCore OS 开源操作系统】Rust 字符串(可变字符串String与字符串切片str)

【rCore OS 开源操作系统】Rust 语法详解: Strings 前言 这次涉及到的题目相对来说比较有深度&#xff0c;涉及到 Rust 新手们容易困惑的点。 这一次在直接开始做题之前&#xff0c;先来学习下字符串相关的知识。 Rust 的字符串 Rust中“字符串”这个概念涉及多种类型&…