基于LSTM实现春联上联对下联

按照阿光的项目做出了学习笔记,pytorch深度学习实战项目100例

基于LSTM实现春联上联对下联

基于LSTM(长短期记忆网络)实现春联上联对下联是一种有趣且具有挑战性的任务,它涉及到自然语言处理(NLP)中的序列到序列(seq2seq)模型。LSTM是处理序列数据的理想选择,因为它能够记住长期的依赖信息,这对于生成符合语境和文化习俗的春联下联至关重要。

数据

https://github.com/wb14123/couplet-dataset

感谢大佬的分享的对联数据集

对数据集的处理

def data_generator(data):# 计算每个对联长度的权重data_probability = [float(len(x)) for wordcount, [x, y] in data.items()]  # [每个字数key对应对联list中上联数据的个数]data_probability = np.array(data_probability) / sum(data_probability)  # 标准化至[0,1],这是每个字数的权重# 随机选择字数,然后随机选择字数对应的上联样本,生成batchfor idx in range(15):# 随机选字数id,概率为上面计算的字数权重idx = idx + 1size = min(batch_size, len(data[idx][0]))  # batch_size=64,len(data[idx][0])随机选择的字数key对应的上联个数# 从上联列表下标list中随机选出大小为size的listidxs = np.random.choice(len(data[idx][0]), size=size)# 返回选出的上联X与下联y, 将原本1-d array维度扩展为(row,col,1)yield data[idx][0][idxs], np.expand_dims(data[idx][1][idxs], axis=2)# 加载文本数据
def load_data(input_path, output_path):# 数据读取与切分def read_data(file_path):txt = codecs.open(file_path, encoding='utf-8').readlines()txt = [line.strip().split(' ') for line in txt]  # 每行按空格切分txt = [line for line in txt if len(line) < 16]  # 过滤掉字数超过maxlen的对联return txt# 产生数据字典def generate_count_dict(result_dict, x, y):for i, idx in enumerate(x):j = len(idx)if j not in result_dict:result_dict[j] = [[], []]  # [样本数据list,类别标记list]result_dict[j][0].append(idx)result_dict[j][1].append(y[i])return result_dict# 将字典数据转为numpydef to_numpy_array(dict):for count, [x, y] in dict.items():dict[count][0] = np.array(x)dict[count][1] = np.array(y)return dictx = read_data(input_path)y = read_data(output_path)# 获取词表vocabulary = x + y# 构造字符级别的特征string = ''for words in vocabulary:for word in words:string += word# 所有的词汇表vocabulary = set(string)word2idx = {word: i for i, word in enumerate(vocabulary)}idx2word = {i: word for i, word in enumerate(vocabulary)}# 训练数据中所有词的个数vocab_size = len(word2idx.keys())  # 词汇表大小# 将x和y转为数值x = [[word2idx[word] for word in sent] for sent in x]y = [[word2idx[word] for word in sent] for sent in y]train_dict = {}train_dict = generate_count_dict(train_dict, x, y)train_dict = to_numpy_array(train_dict)return train_dict, vocab_size, idx2word, word2idx

在这里插入图片描述
基本想法:
这种场景是典型的 Encoder-Decoder 框架应用问题。

在这个框架中:

  • Encoder 负责读取输入序列(上联)并将其转换成一个固定长度的内部表示形式,通常是最后一个时间步的隐藏状态。这个内部表示被视为输入序列的“上下文”或“意义”,包含了生成输出序列所需的所有信息。
  • Decoder 接收这个内部表示并开始生成输出序列(下联),一步一步地生成,直到产生序列结束标记或达到特定长度。

在这里插入图片描述

构建模型

模型架构:使用seq2seq模型,该模型一般包括一个编码器(encoder)和一个解码器(decoder),两者都可以是LSTM网络。编码器负责处理上联,而解码器则生成下联。
嵌入层:通常在模型的第一层使用嵌入层,将每个字符或词转换为固定大小的向量,这有助于模型更好地理解语言中的语义信息。
在这里插入图片描述

# 定义网络结构
class LSTM(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):super(LSTM, self).__init__()self.hidden_dim = hidden_dimself.embeddings = nn.Embedding(vocab_size + 1, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.linear = nn.Linear(hidden_dim, vocab_size)def forward(self, x):time_step, batch_size = x.size()  # 124, 16embeds = self.embeddings(x)output, (h_n, c_n) = self.lstm(embeds)output = self.linear(output.reshape(time_step * batch_size, -1))# 要返回所有时间点的数据,每个时间点对应一个字,也就是vocab_size维度的向量return output

训练模型

# 加载数据
train_dict, vocab_size, idx2word, word2idx = load_data(input_path, output_path)# 模型训练
model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim,embedding_dim=embedding_dim, num_layers=num_layers)Configimizer = optim.Adam(model.parameters(), lr=lr)  # 优化器
criterion = nn.CrossEntropyLoss()  # 多分类损失函数model.to(device)
loss_meter = meter.AverageValueMeter()best_loss = 999  # 保存loss
best_model = None  # 保存对应最好准确率的模型参数for epoch in range(epochs):model.train()  # 开启训练模式loss_meter.reset()for x, y in data_generator(train_dict):x = torch.from_numpy(x).long().transpose(1, 0).contiguous()x = x.to(device)y = torch.from_numpy(y).long().transpose(1, 0).contiguous()y = y.to(device)Configimizer.zero_grad()# 形成预测结果output_ = model(x)# 计算损失loss = criterion(output_, y.long().view(-1))loss.backward()Configimizer.step()loss_meter.add(loss.item())# 打印信息print("【EPOCH: 】%s" % str(epoch + 1))print("训练损失为%s" % (str(loss_meter.mean)))# 保存模型及相关信息if loss_meter.mean < best_loss:best_loss = loss_meter.meanbest_model = model.state_dict()# 在训练结束保存最优的模型参数if epoch == epochs - 1:# 保存模型torch.save(best_model, './best_model.pkl')

测试

import codecsimport numpy as np
import torch
from torch import nn
from torch import optim
from torchnet import meter# 模型输入参数,需要自己根据需要调整
input_path = 'C:\\Users\\kaai\\AppData\\Local\\Temp\\BNZ.65e95f542f0fca6f\\train\\in.txt'
output_path = 'C:\\Users\\kaai\\AppData\\Local\\Temp\\BNZ.65e95f542f0fca6f\\train\\out.txt'
num_layers = 1  # LSTM层数
hidden_dim = 100  # LSTM中的隐层大小
epochs = 50  # 迭代次数
batch_size = 128  # 每个批次样本大小
embedding_dim = 15  # 每个字形成的嵌入向量大小
lr = 0.01  # 学习率
device = 'cpu'  # 设备# 用于生成训练数据
def data_generator(data):# 计算每个对联长度的权重data_probability = [float(len(x)) for wordcount, [x, y] in data.items()]  # [每个字数key对应对联list中上联数据的个数]data_probability = np.array(data_probability) / sum(data_probability)  # 标准化至[0,1],这是每个字数的权重# 随机选择字数,然后随机选择字数对应的上联样本,生成batchfor idx in range(15):# 随机选字数id,概率为上面计算的字数权重idx = idx + 1size = min(batch_size, len(data[idx][0]))  # batch_size=64,len(data[idx][0])随机选择的字数key对应的上联个数# 从上联列表下标list中随机选出大小为size的listidxs = np.random.choice(len(data[idx][0]), size=size)# 返回选出的上联X与下联y, 将原本1-d array维度扩展为(row,col,1)yield data[idx][0][idxs], np.expand_dims(data[idx][1][idxs], axis=2)# 加载文本数据
def load_data(input_path, output_path):# 数据读取与切分def read_data(file_path):txt = codecs.open(file_path, encoding='utf-8').readlines()txt = [line.strip().split(' ') for line in txt]  # 每行按空格切分txt = [line for line in txt if len(line) < 16]  # 过滤掉字数超过maxlen的对联return txt# 产生数据字典def generate_count_dict(result_dict, x, y):for i, idx in enumerate(x):j = len(idx)if j not in result_dict:result_dict[j] = [[], []]  # [样本数据list,类别标记list]result_dict[j][0].append(idx)result_dict[j][1].append(y[i])return result_dict# 将字典数据转为numpydef to_numpy_array(dict):for count, [x, y] in dict.items():dict[count][0] = np.array(x)dict[count][1] = np.array(y)return dictx = read_data(input_path)y = read_data(output_path)# 获取词表vocabulary = x + y# 构造字符级别的特征string = ''for words in vocabulary:for word in words:string += word# 所有的词汇表vocabulary = set(string)word2idx = {word: i for i, word in enumerate(vocabulary)}idx2word = {i: word for i, word in enumerate(vocabulary)}# 训练数据中所有词的个数vocab_size = len(word2idx.keys())  # 词汇表大小# 将x和y转为数值x = [[word2idx[word] for word in sent] for sent in x]y = [[word2idx[word] for word in sent] for sent in y]train_dict = {}train_dict = generate_count_dict(train_dict, x, y)train_dict = to_numpy_array(train_dict)return train_dict, vocab_size, idx2word, word2idx# 定义网络结构
class LSTM(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):super(LSTM, self).__init__()self.hidden_dim = hidden_dimself.embeddings = nn.Embedding(vocab_size + 1, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers)self.linear = nn.Linear(hidden_dim, vocab_size)def forward(self, x):time_step, batch_size = x.size()  # 124, 16embeds = self.embeddings(x)output, (h_n, c_n) = self.lstm(embeds)output = self.linear(output.reshape(time_step * batch_size, -1))# 要返回所有时间点的数据,每个时间点对应一个字,也就是vocab_size维度的向量return outputdef couplet_match(s):# 将字符串转为数值x = [word2idx[word] for word in s]# 将数值向量转为tensorx = torch.from_numpy(np.array(x).reshape(-1, 1))# 加载模型model_path = './best_model.pkl'model = LSTM(vocab_size=vocab_size, hidden_dim=hidden_dim,embedding_dim=embedding_dim, num_layers=num_layers)model.load_state_dict(torch.load(model_path, 'cpu'))y = model(x)y = y.argmax(axis=1)r = ''.join([idx2word[idx.item()] for idx in y])print('上联:%s,下联:%s' % (s, r))
# 加载数据
train_dict, vocab_size, idx2word, word2idx = load_data(input_path, output_path)
# 测试
sentence = '恭喜发财'
couplet_match(sentence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/275276.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Stable diffusion(一)

Stable diffusion 原理解读 名词解释 正向扩散&#xff08;Fixed Forward Diffusion Process&#xff09;&#xff1a;反向扩散&#xff08;Generative Reverse Denoising Process&#xff09; VAE&#xff08;Variational AutoEncoder&#xff09;&#xff1a;一个用于压缩图…

Mysql/Redis缓存一致性

如何保证MySQL和Redis的缓存一致。从理论到实战。总结6种来感受一下。 理论知识 不好的方案 1.先写MySQL&#xff0c;再写Redis 图解说明: 这是一幅时序图&#xff0c;描述请求的先后调用顺序&#xff1b; 黄色的线是请求A&#xff0c;黑色的线是请求B&#xff1b; 黄色的…

智慧城市与绿色出行:共同迈向低碳未来

随着城市化进程的加速&#xff0c;交通拥堵、空气污染、能源消耗等问题日益凸显&#xff0c;智慧城市与绿色出行成为了解决这些问题的关键途径。智慧城市利用信息技术手段&#xff0c;实现城市各领域的智能化管理和服务&#xff0c;而绿色出行则强调低碳、环保的出行方式&#…

LInux系统架构----Nginx模块rewrite的规则与应用场景

LInux系统架构----Nginx模块rewrite的规则与应用场景 一.rewrite跳转实现 Nginx实现跳转通过ngx_http_rewrite_module模块支持URL重写、支持if条件判断&#xff0c;但是不支持else跳转时&#xff0c;循环最多可以执行10次&#xff0c;超过后nginx将返回500错误注&#xff1a;…

STM32 | STM32F407ZE中断、按键、灯(续第三天)

上节回顾 STM32 | 库函数与寄存器开发区别及LED等和按键源码(第三天)一、 中断 中断概念 中断是指计算机运行过程中,出现某些意外情况需主机干预时,机器能自动停止正在运行的程序并转入处理新情况的程序,处理完毕后又返回原被暂停的程序继续运行(面试题)。 STM32外部中断…

【设计模式】设计原则和常见的23种经典设计模式

设计模式 1. 设计原则&#xff08;记忆口诀&#xff1a;SOLID&#xff09;【记忆口诀&#xff1a;单开里依接迪合&#xff08;单开礼仪接地和&#xff09;】 &#xff08;1&#xff09;单一职责原则&#xff08;Single Responsibility Principle, SRP&#xff09; &#xff…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的常见手势识别系统(深度学习模型+UI界面代码+训练数据集)

摘要&#xff1a;开发手势识别系统对于增强人机交互和智能家居控制领域的体验非常关键。本博客详尽阐述了通过深度学习技术构建手势识别系统的过程&#xff0c;并附上了全套实施代码。系统采用了先进的YOLOv8算法&#xff0c;并通过与YOLOv7、YOLOv6、YOLOv5的性能对比&#xf…

Kafka 面试题及答案整理,最新面试题

Kafka中的Producer API是如何工作的&#xff1f; Kafka中的Producer API允许应用程序发布一流的数据到一个或多个Kafka主题。它的工作原理包括&#xff1a; 1、创建Producer实例&#xff1a; 通过配置Producer的各种属性&#xff08;如服务器地址、序列化方式等&#xff09;来…

个人博客系统(测试报告)

一、项目背景 一个Web网站程序&#xff0c;你可以观看到其他用户博客也可以登录自己的账号发布博客&#xff0c;通过使用Selenium定位web元素、操作测试对象等方法来对个人博客系统的进行测试&#xff0c;测试的核心内容有用户登录、博客列表及博客数量的展示、查看全文、写博客…

liteIDE 解决go root报错 go: cannot find GOROOT directory: c:\go

liteIDE环境配置 我使用的liteIDE为 x36 5.9.5版本 。在查看–>选项 中可以看到 LiteEnv&#xff0c;双击LiteEnv &#xff0c;在右侧选择对应系统的env文件&#xff0c;我的是win64系统&#xff0c;所以文件名为win64.env 再双击 win64.env &#xff0c;关闭当前窗口&…

Linux内核编译(版本6.0以及版本v0.01)并用qemu驱动

系统环境&#xff1a; ubuntu-22.04.1-desktop-amd64 目标平台: x86 i386 内核版本: linux-6.0.1 linux-0.0.1 环境配置 修改root密码 sudo passwd 修改软件源&#xff08;非必要&#xff09; vmtools安装&#xff08;实现win-linux软件互传&#xff09; 安装一些必须的软件&…

DevOps本地搭建笔记(个人开发适用)

需求和背景 win11 wsl2 armbian(玩客云矿渣&#xff09;&#xff0c;构建个人cicd流水线&#xff0c;提高迭代效率。 具体步骤 基础设施准备 硬件准备&#xff1a;一台笔记本&#xff0c;用于开发和构建部署&#xff0c;一台服务器&#xff0c;用于日常服务运行。 笔记本…

Celery知识

celery介绍 # celery 的概念&#xff1a; * 翻译过来是芹菜 * 官网&#xff1a;https://docs.celeryq.dev/en/stable/ # 是分布式的异步任务框架&#xff1a; 分布式&#xff1a;一个任务&#xff0c;拆成多个任务在不同机器上做 异步任务&#xff1a;后台执行…

【Greenhills】MULTIIDE集成第三方的编辑器进行源文件编辑工作

【更多软件使用问题请点击亿道电子官方网站查询】 1、 文档目标 在使用GHS进行工作的时候&#xff0c;可以集成第三方的编辑器进行源文件编辑工作 2、 问题场景 用于解决在GHS中进行项目开发时&#xff0c;对于GHS的编辑器使用不习惯&#xff0c;想要切换到其他第三方的编辑…

漏洞发现-漏扫项目篇武装BURP浏览器插件信息收集分析辅助

知识点 1、插件类-武装BurpSuite-漏洞检测&分析辅助 2、插件类-武装谷歌浏览器-信息收集&情报辅助 章节点&#xff1a; 漏洞发现-Web&框架组件&中间件&APP&小程序&系统 扫描项目-综合漏扫&特征漏扫&被动漏扫&联动漏扫 Poc开发-Ymal语…

Qt QDateTime类使用

一.Qt datetime 介绍 Qt中的QDateTime类是用于处理日期和时间的组合的类&#xff0c;它提供了丰富的功能来操作和格式化日期时间数据。以下是其主要特点和用法&#xff1a; 构造函数&#xff1a;QDateTime可以通过组合QDate&#xff08;日期&#xff09;和QTime&#xff08;时…

TypeScript编译选项

编译单个文件&#xff1a;终端 tsc 文件名 自动编译单个文件&#xff1a;终端 tsc 文件名 -w 编译整个项目&#xff1a;tsc 前提是得有ts的配置文件tsconfig.json 自动编译整个项目&#xff1a;tsc --w tsconfig.json默认文件内容&#xff1a; tsconfig.json是ts编译器的配…

阿里云服务器Ngnix配置SSL证书开启HTTPS访问

文章目录 前言一、SSL证书是什么&#xff1f;二、如何获取免费SSL证书三、Ngnix配置SSL证书总结 前言 很多童鞋的网站默认访问都是通过80端口的Http服务进行访问&#xff0c;往往都会提示不安全&#xff0c;很多人以为Https有多么高大上&#xff0c;实际不然&#xff0c;他只是…

C库函数-getopt函数总结学习

1、简介 getopt函数是命令行参数解析函数 1、1命令行组成 Command name 程序文件名 operands 操作对象 option 选项 option argument 选项参数 getopt()函数将传递给mian()函数的argc,argv作为参数&#xff0c;同时接受字符串参数optstring – optstring是由选项Option字母组…

前端Vue列表组件 list组件:实现高效数据展示与交互

前端Vue列表组件 list组件&#xff1a;实现高效数据展示与交互 摘要&#xff1a;在前端开发中&#xff0c;列表组件是展示数据的重要手段。本文将介绍如何使用Vue.js构建一个高效、可复用的列表组件&#xff0c;并探讨其在实际项目中的应用。 效果图如下&#xff1a; 一、引言…