使用PyTorch实现LSTM生成ai诗

最近学习torch的一个小demo。

什么是LSTM?

长短时记忆网络(Long Short-Term Memory,LSTM)是一种循环神经网络(RNN)的变体,旨在解决传统RNN在处理长序列时的梯度消失和梯度爆炸问题。LSTM引入了一种特殊的存储单元和门控机制,以更有效地捕捉和处理序列数据中的长期依赖关系。

通俗点说就是:LSTM是一种改进版的递归神经网络(RNN)。它的主要特点是可以记住更长时间的信息,这使得它在处理序列数据(如文本、时间序列、语音等)时非常有效。

步骤如下

数据准备

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string
import os# 数据加载和预处理
def load_data(filepath):with open(filepath, 'r', encoding='utf-8') as file:text = file.read()return textdef preprocess_text(text):text = text.lower()text = text.translate(str.maketrans('', '', string.punctuation))return textdata_path = 'poetry.txt'  # 替换为实际的诗歌数据文件路径
text = load_data(data_path)
text = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")

模型构建

定义LSTM模型:

class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hidden

训练模型

将数据转换成LSTM需要的格式:

def prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 100
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 20
learning_rate = 0.001model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]x = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(inputs):.4f}")

生成

def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string)
print(generated_text)

运行结果如下:

运行的肯定不好,但至少出结果了。诗歌我这边只放了几句,可以自己通过外部文件放入更多素材。

整体代码直接运行即可:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import string# 预定义一些中文诗歌数据
text = """
春眠不觉晓,处处闻啼鸟。
夜来风雨声,花落知多少。
床前明月光,疑是地上霜。
举头望明月,低头思故乡。
红豆生南国,春来发几枝。
愿君多采撷,此物最相思。
"""# 数据预处理
def preprocess_text(text):text = text.replace('\n', '')return texttext = preprocess_text(text)
chars = sorted(list(set(text)))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
vocab_size = len(chars)print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, x, hidden):lstm_out, hidden = self.lstm(x, hidden)output = self.fc(lstm_out[:, -1, :])output = self.softmax(output)return output, hiddendef init_hidden(self, batch_size):weight = next(self.parameters()).datahidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),weight.new(self.num_layers, batch_size, self.hidden_size).zero_())return hiddendef prepare_data(text, seq_length):inputs = []targets = []for i in range(0, len(text) - seq_length, 1):seq_in = text[i:i + seq_length]seq_out = text[i + seq_length]inputs.append([char_to_idx[char] for char in seq_in])targets.append(char_to_idx[seq_out])return inputs, targetsseq_length = 10
inputs, targets = prepare_data(text, seq_length)# Convert to tensors
inputs = torch.tensor(inputs, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)batch_size = 64
input_size = vocab_size
hidden_size = 256
output_size = vocab_size
num_epochs = 50
learning_rate = 0.003model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Training loop
for epoch in range(num_epochs):h = model.init_hidden(batch_size)total_loss = 0for i in range(0, len(inputs), batch_size):x = inputs[i:i + batch_size]y = targets[i:i + batch_size]if x.size(0) != batch_size:continuex = nn.functional.one_hot(x, num_classes=vocab_size).float()output, h = model(x, h)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(inputs):.4f}")def generate_text(model, start_str, length=100):model.eval()with torch.no_grad():input_eval = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()h = model.init_hidden(1)predicted_text = start_strfor _ in range(length):output, h = model(input_eval, h)prob = torch.softmax(output, dim=1).datapredicted_idx = torch.multinomial(prob, num_samples=1).item()predicted_char = idx_to_char[predicted_idx]predicted_text += predicted_charinput_eval = torch.tensor([[predicted_idx]], dtype=torch.long)input_eval = nn.functional.one_hot(input_eval, num_classes=vocab_size).float()return predicted_textstart_string = "春眠不觉晓"
generated_text = generate_text(model, start_string, length=100)
print(generated_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/351010.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

初识PHP

一、格式 每行以分号结尾 <?phpecho hello; ?>二、echo函数和print函数 作用&#xff1a;两个函数都是输出内容到页面中&#xff0c;多用于代码调试。 <?php echo "<h1 styletext-align: center;>test</h1>"; print "<h1 stylet…

笔记 | 用go写个docker

仅作为自己学习过程的记录&#xff0c;不具备参考价值 前言 看到一段非常有意思的话&#xff1a; 很多人刚接触docker的时候就会感觉非常神奇&#xff0c;感觉这个技术非常新颖&#xff0c;其实并不然&#xff0c;docker使用到的技术都是之前已经存在过的&#xff0c;只不过旧…

如何在Spring Boot中实现图片上传至本地和阿里云OSS

在开发Web应用时&#xff0c;处理文件上传是常见的需求之一&#xff0c;尤其是在涉及到图片、视频等多媒体数据时。本文将详细介绍如何使用Spring Boot实现图片上传至本地服务器以及阿里云OSS存储服务&#xff0c;并提供完整的代码示例。 一、上传图片至本地 首先&#xff0c…

CMU最新论文:机器人智慧流畅的躲避障碍物论文详细讲解

CMU华人博士生Tairan He最新论文&#xff1a;Agile But Safe: Learning Collision-Free High-Speed Legged Locomotion 代码开源&#xff1a;Code: https://github.com/LeCAR-Lab/ABS B站实际效果展示视频地址&#xff1a;bilibili效果地址 我会详细解读论文的内容,让我们开始吧…

这个网站有点意思,可做SPRINGBOOT的启动图

在 SpringBoot 项目的 resources 目录下新建一个 banner.txt 文本文件&#xff0c;然后将启动 Banner 粘贴到此文本文件中&#xff0c;启动项目&#xff0c;即可在控制台展示对应的内容信息。 下面这个工具很好用&#xff0c;收藏精哦

C/C++:指针用法详解

C/C&#xff1a;指针 指针概念 指针变量也是一个变量 指针存放的内容是一个地址&#xff0c;该地址指向一块内存空间 指针是一种数据类型 指针变量定义 内存最小单位&#xff1a;BYTE字节&#xff08;比特&#xff09; 对于内存&#xff0c;每个BYTE都有一个唯一不同的编号…

积木搭建游戏-第13届蓝桥杯省赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第83讲。 积木搭建游戏&…

QT属性系统,简单属性功能快速实现 QT属性的简单理解 属性学习如此简单 一文就能读懂QT属性 QT属性最简单的学习

4.4 属性系统 Qt 元对象系统最主要的功能是实现信号和槽机制&#xff0c;当然也有其他功能&#xff0c;就是支持属性系统。有些高级语言通过编译器的 __property 或者 [property] 等关键字实现属性系统&#xff0c;用于提供对成员变量的访问权限&#xff0c;Qt 则通过自己的元对…

回归预测 | Matlab实现GWO-ESN基于灰狼算法优化回声状态网络的多输入单输出回归预测

回归预测 | Matlab实现GWO-ESN基于灰狼算法优化回声状态网络的多输入单输出回归预测 目录 回归预测 | Matlab实现GWO-ESN基于灰狼算法优化回声状态网络的多输入单输出回归预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现GWO-ESN基于灰狼算法优化回声状态…

软件下载网站源码附手机版和图文教程

PHP游戏应用市场APP软件下载平台网站源码手机版 可自行打包APP&#xff0c;带下载统计&#xff0c;带多套模板&#xff0c;带图文教程&#xff0c;可以做软件库&#xff0c;也可以做推广app下载等等&#xff0c;需要的朋友可以下载 源码下载 软件下载网站源码附手机版和图文…

Guava-EventBus 源码解析

EventBus 采用发布订阅者模式的实现方式&#xff0c;它实现了泛化的注册方法以及泛化的方法调用,另外还考虑到了多线程的问题,对多线程使用时做了一些优化&#xff0c;观察者模式都比较熟悉&#xff0c;这里会简单介绍一下&#xff0c;重点介绍的是如何泛化的进行方法的注册以及…

FineReport简单介绍

一、介绍 官网 &#xff1a;FineReport产品简介- FineReport帮助文档 - 全面的报表使用教程和学习资料 报表是以表格、图表的形式来动态展示数据&#xff0c;企业通过报表进行数据分析&#xff0c;进而用于辅助经营管理决策。 FineReport 是一款用于报表制作&#xff0c;分析和…

uniapp中unicloud接入支付宝订阅消息完整教程

经过无数次的尝试,终于还是让我做出来了 准备工作 设置接口加签方式 使用支付宝小程序订阅消息,首先要设置接口加签方式,需要下载支付宝开放平台密钥工具,按照步骤生成秘钥,然后按照支付宝设置密钥加签方式添加接口加签方式。 有一点需要注意的,因为要在云函数中使用,…

Mac M3 Pro安装Hadoop-3.3.6

1、下载Hadoop安装包 可以到官方网站下载&#xff0c;也可以使用网盘下载 官网下载地址&#xff1a;Hadoop官网下载地址 网盘地址&#xff1a;https://pan.baidu.com/s/1p4BXq2mvby2B76lmpiEjnA?pwdr62r提取码: r62r 2、解压并添加环境变量 # 将安装包移动到指定目录 mv …

基于flask的网站如何使用https加密通信-问题记录

文章目录 项目场景&#xff1a;问题1问题描述原因分析解决步骤解决方案 问题2问题描述原因分析解决方案 参考文章 项目场景&#xff1a; 项目场景&#xff1a;基于flask的网站使用https加密通信一文中遇到的问题记录 问题1 问题描述 使用下面的命令生成自签名的SSL/TLS证书和…

大模型基础——从零实现一个Transformer(3)

大模型基础——从零实现一个Transformer(1)-CSDN博客 大模型基础——从零实现一个Transformer(2)-CSDN博客 一、前言 之前两篇文章已经讲了Transformer的Embedding,Tokenizer,Attention,Position Encoding, 本文我们继续了解Transformer中剩下的其他组件. 二、归一化 2.1 L…

红队攻防渗透技术实战流程:中间件安全:JettyJenkinsWeblogicWPS

红队攻防渗透实战 1. 中间件安全1.1 中间件-Jetty-CVE&信息泄漏1.2 中间件-Jenkins-CVE&RCE执行1.2.1 cve_2017_1000353 JDK-1.8.0_291 其他版本失效1.2.2 CVE-2018-10008611.2.3 cve_2019_100300 需要用户帐号密码1.3 中间件-Weblogic-CVE&反序列化&RCE1.4 应…

使用python绘制三维曲线图

使用python绘制三维曲线图 三维曲线图定义特点 效果代码 三维曲线图 三维曲线图&#xff08;3D曲线图&#xff09;是一种用于可视化三维数据的图表&#xff0c;它展示了数据在三个维度&#xff08;X、Y、Z&#xff09;上的变化。 定义 三维曲线图通过在三维坐标系中绘制曲线…

数据结构之线性表(4)

前面我们了解到线性表中的顺序表、链表等结构&#xff0c;今天我们探讨新的一种线性表——栈。 那么我们开始栈的探讨之旅吧。 1.栈的基本概念 1.1栈&#xff08;Stack&#xff09;&#xff1a; 是只允许在一端进行插入或删除的线性表。首先栈是一种线性表&#xff0c;但限定…

sudo 用户切换

切换到centos 用户 sudo -i -u centos 解决centos sudo执行仍旧显示Permission denied 方法一&#xff08;建议&#xff09; 暂时切换到root用户 sudo -i然后执行命令即可 方法二 赋给当前用户权限&#xff1a; sudo chmod -R 777 目录路径 sudo chmod 777 文件路径.txt…