深度学习之生成唐诗案例(Pytorch版)

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

运行结果:

代码部分:
Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderdef deal_tangshi():with open("tangshis.txt", "r", encoding="utf-8") as fr:lines = fr.read().strip().split("\n")tangshis = []for line in lines:splits = line.split(":")if len(splits) != 2:continuetangshis.append("S" + splits[1] + "E")word2idx = {"S": 0, "E": 1}word2idx_count = 2tangshi_ids = []for tangshi in tangshis:for word in tangshi:if word not in word2idx:word2idx[word] = word2idx_countword2idx_count += 1idx2word = {idx: w for w, idx in word2idx.items()}for tangshi in tangshis:tangshi_ids.extend([word2idx[w] for w in tangshi])return word2idx, idx2word, tangshis, word2idx_count, tangshi_idsword2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()class TangShiDataset(Dataset):def __init__(self, tangshi_ids, num_chars):# 语料数据self.tangshi_ids = tangshi_ids# 语料长度self.num_chars = num_chars# 词的数量self.word_count = len(self.tangshi_ids)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):return self.numberdef __getitem__(self, idx):# 修正索引值到: [0, self.word_count - 1]start = min(max(idx, 0), self.word_count - self.num_chars - 2)x = self.tangshi_ids[start: start + self.num_chars]y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]return torch.tensor(x), torch.tensor(y)def __test_Dataset():dataset = TangShiDataset(tangshi_ids, 8)x, y = dataset[0]print(x, y)if __name__ == '__main__':# deal_tangshi()__test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as Fclass TangShiRNN(nn.Module):def __init__(self, vocab_size):super().__init__()# 初始化词嵌入层self.ebd = nn.Embedding(vocab_size, 128)# 循环网络层self.rnn = nn.RNN(128, 128, 1)# 输出层self.out = nn.Linear(128, vocab_size)def forward(self, inputs, hidden):embed = self.ebd(inputs)# 正则化层embed = F.dropout(embed, p=0.2)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 正则化层embed = F.dropout(output, p=0.2)output = self.out(output.squeeze())return output, hiddendef init_hidden(self):return torch.zeros(1, 64, 128)

 main.py:

import timeimport torchfrom Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def train():dataset = TangShiDataset(tangshi_ids, 128)epochs = 100model = TangShiRNN(word2idx_count).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)for idx in range(epochs):dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)start_time = time.time()total_loss = 0total_num = 0total_correct = 0total_correct_num = 0hidden = model.init_hidden()for x, y in tqdm(dataloader):x = x.to(device)y = y.to(device)# 隐藏状态hidden = model.init_hidden()hidden = hidden.to(device)# 模型计算output, hidden = model(x, hidden)# print(output.shape)# print(y.shape)# 计算损失loss = criterion(output.permute(1, 2, 0), y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_loss += loss.sum().item()total_num += len(y)total_correct_num += y.shape[0] * y.shape[1]# print(output.shape)total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %(idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")if __name__ == '__main__':train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def predict():model = TangShiRNN(word2idx_count)model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))model.eval()hidden = torch.zeros(1, 1, 128)start_word = input("输入第一个字:")flag = Nonetangshi_strs = []while True:if not flag:outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)tangshi_strs.append("S")flag = Trueelse:tangshi_strs.append(start_word)outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)top_i = torch.argmax(outputs, dim=-1)if top_i.item() == word2idx["E"]:breakprint(top_i)start_word = idx2word[top_i.item()]print(tangshi_strs)if __name__ == '__main__':predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201448.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于材料生成优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

矩阵知识补充

正交矩阵 定义: 正交矩阵是一种满足 A T A E A^{T}AE ATAE的方阵 正交矩阵具有以下几个重要性质: A的逆等于A的转置,即 A − 1 A T A^{-1}A^{T} A−1AT**A的行列式的绝对值等于1,即 ∣ d e t ( A ) ∣ 1 |det(A)|1 ∣det(A)∣…

【开源】基于Vue.js的教学过程管理系统

项目编号: S 054 ,文末获取源码。 \color{red}{项目编号:S054,文末获取源码。} 项目编号:S054,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 教师端2.2 学生端2.3 微信小程序端2…

D. Absolute Beauty - 思维

题面 分析 补题。配上题解的图,理解了很长时间,思维还需要提高。 对于每一对 a i a_i ai​和 b i b_i bi​,可以看作一个线段的左右端点,这是关键的一步,那么他们的绝对值就是线段的长度,对于线段相对位…

Microsoft Visual Studio 2019下载及安装流程记录

第一周任务: 1.笔记本上安装vc2019的环境 2.再把OpenCV安装上 3.根据网上的教程,试着写几个opencv的程序 一、安装Visual Studio 2019社区版 首先先完成安装vc2019的环境, 因为: Microsoft Visual C是用于C编程的工具集合&am…

计算机毕业论文内容参考|基于深度学习的交通标识智能识别系统的设计与维护

文章目录 导文摘要前言绪论1课题背景2国内外现状与趋势3课题内容相关技术与方法介绍系统分析总结与展望导文 基于深度学习的交通标识智能识别系统是一种利用深度学习模型对交通标识进行识别和解析的系统。它可以帮助驾驶员更好地理解交通规则和安全提示,同时也可以提高道路交通…

【tomcat】java.lang.Exception: Socket bind failed: [730048

项目中一些旧工程运行情况处理 问题 1、启动端口占用 2、打印编码乱码 ʮһ�� 13, 2023 9:33:26 ���� org.apache.coyote.AbstractProtocol init ����: Fa…

Active Directory 和域名系统(DNS)的相互关系

什么是域名系统(DNS) 域名系统(DNS),从一般意义上讲是一种将主机名或域名解析为相应IP地址的手段。 在 AD 的中,DNS 服务维护 DNS 域和子域的工作命名空间,这些域和子域主要有助于查找过程&am…

echarts 几千条分钟级别在小时级别图标上展示

需求背景解决效果ISQQW代码地址strategyChart.vue 需求背景 需要实现 秒级数据几千条在图表上显示&#xff0c;(以下是 设计图表上是按小时界别显示数据&#xff0c;后端接口为分钟级别数据) 解决效果 ISQQW代码地址 链接 strategyChart.vue <!--/** * author: liuk *…

2023 年 亚太赛 APMCM ABC题 国际大学生数学建模挑战赛 |数学建模完整代码+建模过程全解全析

当大家面临着复杂的数学建模问题时&#xff0c;你是否曾经感到茫然无措&#xff1f;作为2022年美国大学生数学建模比赛的O奖得主&#xff0c;我为大家提供了一套优秀的解题思路&#xff0c;让你轻松应对各种难题。 以五一杯 A题为例子&#xff0c;以下是咱们做的一些想法呀&am…

电子眼与无人机在城市安防中的协同应用研究

随着城市化进程的快速推进&#xff0c;城市安全问题成为了人们关注的焦点。传统的安防手段已经无法满足现代城市复杂多变的安全需求。因此&#xff0c;结合电子眼与无人机技术&#xff0c;实现二者之间的协同应用&#xff0c;成为提升城市安防能力的重要途径。 一、电子眼与无人…

Unity开发之C#基础-File文件读取

前言 今天我们将要讲解到c#中 对于文件的读写是怎样的 那么没接触过特别系统编程小伙伴们应该会有一个疑问 这跟文件有什么关系呢&#xff1f; 我们这样来理解 首先 大家对电脑或多或少都应该有不少的了解吧 那么我们这些软件 都是通过变成一个一个文件保存在电脑中 我们才可以…

SecureCRT -- 使用说明

【概念解释】什么是SSH&#xff1f; SSH的英文全称是Secure Shell 传统的网络服务程序&#xff0c;如&#xff1a;ftp和telnet在本质上都是不安全的&#xff0c;因为它们在网络上用明文传送口令和数据&#xff0c;别有用心的人非常容易就可以截获这些口令和数据。而通过使用SS…

6.基于蜻蜓优化算法 (DA)优化的VMD参数(DA-VMD)

代码原理 基于蜻蜓优化算法 (Dragonfly Algorithm, DA) 优化的 VMD 参数&#xff08;DA-VMD&#xff09;是指使用蜻蜓优化算法对 VMD 方法中的参数进行自动调优和优化。 VMD&#xff08;Variational Mode Decomposition&#xff09;是一种信号分解方法&#xff0c;用于将复杂…

ASUS华硕ROG幻13笔记本电脑GV301QE原厂Windows10系统

链接&#xff1a;https://pan.baidu.com/s/1aPW0ctRXRNAhE75mzVPdTg?pwdds78 提取码&#xff1a;ds78 华硕玩家国度幻13笔记本电脑锐龙版Ryzen 7 5800HS,显卡3050 3050Ti,3060,3060Ti,3070,3070Ti 原厂W10系统自带所有驱动、出厂主题壁纸、系统属性专属LOGO标志、Office办…

leetcode:环形链表

题目描述 题目链接&#xff1a;141. 环形链表 - 力扣&#xff08;LeetCode&#xff09; 题目分析 我们先了解一个知识&#xff1a;循环链表 尾结点不指向NULL&#xff0c;指向头就是循环链表 那么带环链表就意味着尾结点的next可以指向链表的任意一个结点&#xff0c;甚至可…

系列二、Lock接口

一、多线程编程模板 线程 操作 资源类 高内聚 低耦合 二、实现步骤 1、创建资源类 2、资源类里创建同步方法、同步代码块 三、12306卖票程序 3.1、synchronized实现 3.1.1、Ticket /*** Author : 一叶浮萍归大海* Date: 2023/11/20 8:54* …

LangChain库简介

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

关于 Docker

关于 Docker 1. 术语Docker Enginedockerd&#xff08;Docker daemon&#xff09;containerdOCI (Open Container Initiative)runcDocker shimCRI (Container Runtime Interface)CRI-O 2. 容器启动过程在 Linux 中的实现daemon 的作用 Docker 是个划时代的开源项目&#xff0c;…

【React-Router】路由快速上手

1. 创建路由开发环境 # 使用CRA创建项目 npm create-react-app react-router-pro# 安装最新的ReactRouter包 npm i react-router-dom2. 快速开始 // index.jsimport React from react; import ReactDOM from react-dom/client; import ./index.css; import App from ./App; i…