RNN经典案例——构建人名分类器

RNN经典案例——人名分类器

  • 一、数据处理
    • 1.1 去掉语言中的重音标记
    • 1.2 读取数据
    • 1.3 构建人名类别与人名对应关系字典
    • 1.4 将人名转换为对应的onehot张量
  • 二、构建RNN模型
    • 2.1 构建传统RNN模型
    • 2.2 构建LSTM模型
    • 2.3 构建GRU模型
  • 三、构建训练函数并进行训练
    • 3.1 从输出结果中获得指定类别函数
    • 3.2 随机生成训练数据
    • 3.3 构建传统的RNN训练函数
    • 3.4 构建LSTM训练函数
    • 3.5 构建GRU训练函数
    • 3.6 构建时间计算函数
    • 3.7 构建训练过程的日志打印函数
    • 3.8 调用train函数, 进行模型的训练
  • 四、构建评估模型并预测
    • 4.1 构建传统RNN评估函数
    • 4.2 构建LSTM评估函数
    • 4.3 构建GRU评估函数
    • 4.4 构建预测函数

一、数据处理

from io import open
import glob
import os
import string 
import unicodedata
import random
import time
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

1.1 去掉语言中的重音标记

# 获取常用字符数量和常用标点
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
print("all_letters:",all_letters)
print("n_letters:",n_letters)

在这里插入图片描述

# 去掉一些语言中的重音标记
# 如: Ślusàrski ---> Slusarski
def unicodeToAscii(s):ascii = ''.join(# NFD会将每个字符分解为其基本字符和组合标记,Ś会拆分为音掉和S#'Mn'这类字符通常用于表示重音符号、音调符号等c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c) != 'Mn' and c in all_letters)return ascii

1.2 读取数据

# 读取数据
data_path = "./data/names/"def readLines(filename):lines = open(filename,encoding='utf-8').read().strip().split('\n')return [unicodeToAscii(line) for line in lines]
# 调试
filename = data_path + "Chinese.txt"
lines = readLines(filename)
print(lines)

在这里插入图片描述

1.3 构建人名类别与人名对应关系字典

# 类别名字列表
category_lines = {}
# 类别名称
all_category = []for filename in glob.glob(data_path + '*.txt'):category = os.path.splitext(os.path.basename(filename))[0]all_category.append(category)lines = readLines(filename)category_lines[category] = lines# 查看类别总数
n_categories = len(all_category)
print("n_categories:",n_categories)

在这里插入图片描述

1.4 将人名转换为对应的onehot张量

def lineToTensor(line):tensor = torch.zeros(len(line),1,n_letters)for i,letter in enumerate(line):tensor[i][0][all_letters.find(letter)] = 1return tensor  
# 调试
line = 'cui'
line_tensor = lineToTensor(line)
line_tensor

在这里插入图片描述

二、构建RNN模型

2.1 构建传统RNN模型

class RNN(nn.Module):def __init__(self,input_size,hidden_size,output_size,num_layers=1):super(RNN,self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers# 实例化RNNself.rnn = nn.RNN(input_size,hidden_size,num_layers)# RNN 层的输出转换为最终的输出特征self.linear = nn.Linear(hidden_size,output_size)# 将全连接层的输出特征转换为概率分布self.softmax = nn.LogSoftmax(dim=-1)def forward(self,input,hidden):# input 形状为1*n_letters需要变换为三维张量input = input.unsqueeze(0)rr,hn = self.rnn(input,hidden)return self.softmax(self.linear(rr)),hn# 定义初始化隐藏状态def initHidden(self):return torch.zeros(self.num_layers,1,self.hidden_size)

2.2 构建LSTM模型

class LSTM(nn.Module):def __init__(self,input_size,hidden_size,output_size,num_layers=1):super(LSTM,self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size,hidden_size,num_layers)self.linear = nn.Linear(hidden_size,output_size)self.softmax = nn.LogSoftmax(dim=-1)def forward(self,input,hidden,c):input = input.unsqueeze(0)rr,(hn,c) = self.lstm(input,(hidden,c))return self.softmax(self.linear(rr)),hn,cdef initHidden(self):hidden = c = torch.zeros(self.num_layers,1,self.hidden_size)return hidden,c

2.3 构建GRU模型

class GRU(nn.Module):def __init__(self,input_size,hidden_size,output_size,num_layers=1):super(GRU,self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size,hidden_size,num_layers)self.linear = nn.Linear(hidden_size,output_size)self.softmax = nn.LogSoftmax(dim=-1)def forward(self,input,hidden):input = input.unsqueeze(0)rr,hn = self.gru(input,hidden)return self.softmax(self.linear(rr)),hndef initHidden(self):return torch.zeros(self.num_layers,1,self.hidden_size)
# 调用
# 实例化参数
input_size = n_letters
n_hidden = 128
output_size = n_categoriesinput = lineToTensor('B').squeeze(0)
hidden = c = torch.zeros(1,1,n_hidden)rnn = RNN(n_letters,n_hidden,n_categories)
lstm = LSTM(n_letters,n_hidden,n_categories)
gru = GRU(n_letters,n_hidden,n_categories)rnn_output, next_hidden = rnn(input, hidden)
print("rnn:", rnn_output)
lstm_output, next_hidden, c = lstm(input, hidden, c)
print("lstm:", lstm_output)
gru_output, next_hidden = gru(input, hidden)
print("gru:", gru_output) 

在这里插入图片描述

三、构建训练函数并进行训练

3.1 从输出结果中获得指定类别函数

def categoryFromOutput(output):# 从输出张量中返回最大的值和索引top_n,top_i = output.topk(1)category_i = top_i[0].item()# 获取对应语言类别, 返回语⾔类别和索引值return all_category[category_i],category_i
# 调试
category, category_i = categoryFromOutput(gru_output)
print("category:", category)
print("category_i:", category_i)

在这里插入图片描述

3.2 随机生成训练数据

def randomTrainingExample():# 随机获取一个类别category = random.choice(all_category)# 随机获取该类别中的名字line = random.choice(category_lines[category])# 将类别索引转换为tensor张量category_tensor = torch.tensor([all_category.index(category)],dtype=torch.long)# 对名字进行onehot编码line_tensor = lineToTensor(line)return category,line,category_tensor,line_tensor
# 调试
for i in range(10):category,line,category_tensor,line_tensor = randomTrainingExample()print('category =',category,'/ line =',line,'/ category_tensor =',category_tensor,'/ line_tensor =',line_tensor)

在这里插入图片描述

3.3 构建传统的RNN训练函数

# 定义损失函数
criterion = nn.NLLLoss()
# 设置学习率为0.005
learning_rate = 0.005
import torch.optim as optim
def trainRNN(category_tensor,line_tensor):# 实例化对象rnn初始化隐层张量hidden = rnn.initHidden()# 梯度清零optimizer = optim.SGD(rnn.parameters(),lr=0.01,momentum=0.9)optimizer.zero_grad()# 前向传播for i in range(line_tensor.size()[0]):# output 是 RNN 在每个时间步的输出。每个时间步的输出是一个隐藏状态,这些隐藏状态可以用于后续的处理,例如分类、回归等任务。# hidden是 RNN 在最后一个时间步的隐藏状态。这些隐藏状态可以用于捕获整个序列的信息,通常用于后续的处理,例如作为下一个层的输入。output,hidden = rnn(line_tensor[i],hidden)# 计算损失loss = criterion(output.squeeze(0),category_tensor)# 反向传播loss.backward()optimizer.step()# 更新模型中的参数#for p in rnn.parameters():#p.data.add_(-learning_rate,p.grad.data)return output,loss.item()

3.4 构建LSTM训练函数

def trainLSTM(category_tensor,line_tensor):hidden,c = lstm.initHidden()lstm.zero_grad()for i in range(line_tensor.size()[0]):output,hidden,c = lstm(line_tensor[i],hidden,c)loss = criterion(output.squeeze(0),category_tensor)loss.backward()for p in lstm.parameters():p.data.add_(-learning_rate,p.grad.data)return output,loss.item()

3.5 构建GRU训练函数

def trainGRU(category_tensor,line_tensor):hidden = gru.initHidden()gru.zero_grad()for i in range(line_tensor.size()[0]):output,hidden = gru(line_tensor[i],hidden)loss = criterion(output.squeeze(0),category_tensor)loss.backward()for p in gru.parameters():p.data.add_(-learning_rate,p.grad.data)return output,loss.item()

3.6 构建时间计算函数

# 获取每次打印的训练耗时
def timeSince(since):# 获得当前时间now = time.time()# 获取时间差s = now - since# 将秒转换为分m = s // 60# 计算不够1分钟的秒数s -= m * 60return '%dm %ds' % (m,s)

3.7 构建训练过程的日志打印函数

# 设置训练迭代次数
n_iters= 1000
# 设置结果的打印间隔
print_every = 50
# 设置绘制损失曲线上的打印间隔
plot_every = 10
def train(train_typr_fn):# 保存每个间隔的损失函数all_losses = []# 获得训练开始的时间戳start = time.time()# 设置当前间隔损失为0current_loss = 0# 循环训练for iter in range(1,n_iters+1):category,line,category_tensor,line_tensor = randomTrainingExample()output,loss = train_typr_fn(category_tensor,line_tensor)# 计算打印间隔的总损失current_loss += lossif iter % print_every == 0:# 获得预测的类别和索引guess,guess_i = categoryFromOutput(output)if guess == category:correct = '✓'else:correct = '✗(%s)' % categoryprint('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters *100, timeSince(start), loss, line, guess, correct))if iter % plot_every == 0:all_losses.append(current_loss / plot_every)current_loss = 0 return all_losses,int(time.time()-start)       

3.8 调用train函数, 进行模型的训练

# 调⽤train函数, 分别进⾏RNN, LSTM, GRU模型的训练
all_losses1, period1 = train(trainRNN)
all_losses2, period2 = train(trainLSTM)
all_losses3, period3 = train(trainGRU)

在这里插入图片描述

# 创建画布0
plt.figure(0)
# 绘制损失对⽐曲线
plt.plot(all_losses1, label="RNN")
plt.plot(all_losses2, color="red", label="LSTM")
plt.plot(all_losses3, color="orange", label="GRU")
plt.legend(loc='upper left')# 创建画布1
plt.figure(1)
x_data=["RNN", "LSTM", "GRU"]
y_data = [period1, period2, period3]
# 绘制训练耗时对⽐柱状图
plt.bar(range(len(x_data)), y_data, tick_label=x_data)

在这里插入图片描述
在这里插入图片描述

四、构建评估模型并预测

4.1 构建传统RNN评估函数

# 构建传统RNN评估函数
def evaluateRNN(line_tensor):hidden = rnn.initHidden()for i in range(line_tensor.size()[0]):output,hidden = rnn(line_tensor[i],hidden)return output.squeeze(0)

4.2 构建LSTM评估函数

# 构建LSTM评估函数
def evaluateLSTM(line_tensor):hidden,c = lstm.initHidden()for i in range(line_tensor.size()[0]):output,hidden,c = lstm(line_tensor[i],hidden,c)return output.squeeze(0)

4.3 构建GRU评估函数

# 构建GRU评估函数
def evaluateGRU(line_tensor):hidden = gru.initHidden()for i in range(line_tensor.size()[0]):output,hidden = gru(line_tensor[i],hidden)return output.squeeze(0)
# 调试
line = "Bai"
line_tensor = lineToTensor(line)rnn_output = evaluateRNN(line_tensor)
lstm_output = evaluateLSTM(line_tensor)
gru_output = evaluateGRU(line_tensor)
print("rnn_output:", rnn_output)
print("gru_output:", lstm_output)
print("gru_output:", gru_output)

在这里插入图片描述

4.4 构建预测函数

def predict(input_line,evaluate,n_predictions=3):print('\n> %s' % input_line)with torch.no_grad():output = evaluate(lineToTensor(input_line))topv,topi = output.topk(n_predictions,1,True)predictions = []for i in range(n_predictions):# 从topv中取出的output值value = topv[0][i].item()# 取出索引并找到对应的类别category_index = topi[0][i].item()# 打印ouput的值, 和对应的类别print('(%.2f) %s' % (value, all_category[category_index]))# 将结果装进predictions中predictions.append([value, all_category[category_index]])
for evaluate_fn in [evaluateRNN, evaluateLSTM, evaluateGRU]:print("-"*18)predict('Dovesky', evaluate_fn)predict('Jackson', evaluate_fn)predict('Satoshi', evaluate_fn)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/444823.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TON生态小游戏开发:推广、经济模型与UI设计的建设指南

随着区块链技术的快速发展,基于区块链的Web3游戏正引领行业变革。而TON生态小游戏,借助Telegram庞大的用户基础和TON(The Open Network)链上技术,已成为这一领域的明星之一。国内外开发者正迅速涌入,开发和…

基于SpringBoot+Vue的船舶监造系统(带1w+文档)

基于SpringBootVue的船舶监造系统(带1w文档) 基于SpringBootVue的船舶监造系统(带1w文档) 大概在20世纪90年代,我国才开始研发船舶监造系统,与一些发达国家相比,系统研发起步比较晚。当时的计算机技术刚开始发展起来,国家经济力量…

SEO(搜索引擎优化)指南

SEO(Search Engine Optimization)是通过优化网站内容、结构和外部链接,提升网页在搜索引擎结果中的排名,从而增加网站流量的过程。SEO 涉及多个层面,包括技术 SEO、内容优化、外部链接建设等。以下是 SEO 的核心优化策…

京东零售数据湖应用与实践

作者:陈洪健:京东零售大数据架构师,深耕大数据 10 年,2019 年加入京东,主要负责 OLAP 优化、大数据传输工具生态、流批一体、SRE 建设。 当前企业数据处理广泛采用 Lambda 架构。Lambda 架构的优点是保证了数据的完整性…

【论文阅读】Learning a Few-shot Embedding Model with Contrastive Learning

使用对比学习来学习小样本嵌入模型 引用:Liu, Chen, et al. “Learning a few-shot embedding model with contrastive learning.” Proceedings of the AAAI conference on artificial intelligence. Vol. 35. No. 10. 2021. 论文地址:下载地址 论文代码…

强化学习笔记之【SAC算法】

强化学习笔记之【SAC算法】 前言: 本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3 TD3比DDPG少了一个target_actor网络,其它地方有点小改动 CSDN主页:https://blog.csdn.n…

思迈特:在AI时代韧性增长的流量密码

作者 | 曾响铃 文 | 响铃说 “超级人工智能将在‘几千天内’降临。” 最近,OpenAI 公司 CEO 山姆奥特曼在社交媒体罕见发表长文,预言了这一点。之前,很多专家预测超级人工智能将在五年内到来,奥特曼的预期,可能让这…

图论day57|建造最大岛屿(卡码网)【截至目前,图论的最高难度】

图论day57|建造最大岛屿(卡码网)【截至目前所做的题中,图论的最高难度】 思维导图分析 104.建造最大岛屿(卡码网)【截至目前所做的题中,图论的最高难度】 思维导图分析 104.建造最大岛屿(卡码网…

i18n多语言项目批量翻译工具(支持84种语言)

这里写自定义目录标题 打开‘i18n翻译助手’小程序快捷访问 打开‘i18n翻译助手’小程序 1.将需要翻译的json文件复制到输入框(建议一次不要翻译过多,测试1000条以内没什么问题) 2.等待翻译 3.翻译完成,复制结果 快捷访问

从容应对DDoS攻击:小网站的防守之战

前几天收到云服务商短信,服务器正在遭受DDoS攻击 说实话,我的网站只是一个小型站点,平时访问量并不高,没想到会成为攻击的目标。当我看到这次DDoS攻击的通知时,我其实既惊讶又有点小小的“荣幸”,毕竟我的小…

火山引擎边缘智能×扣子,拓展AI Agent物理边界

9月21日, 火山引擎边缘智能扣子技术沙龙在上海圆满落地,沙龙以“探索端智能,加速大模型应用”为主题,边缘智能、扣子、地瓜机器人以及上海交大等多位重磅嘉宾出席,分享 AI 最新趋势及端侧大模型最新探索与应用实践。 …

Java项目-----图形验证码登陆实现

原理: 验证码在前端显示,但是是在后端生成, 将生成的验证码存入redis,待登录时,前端提交验证码,与后端生成的验证码比较. 详细解释: 图形验证码的原理(如下图代码).前端发起获取验证码的请求后, 1 后端接收请求,生成一个键key(随机的键) 然后生成一个验证码作为map的valu…

JAVA接入GPT开发

Spring AI Alibaba:Java开发者的GPT集成新标准 目前,像OpenAI等GPT服务提供商主要提供HTTP接口,这导致大部分Java开发者在接入GPT时缺乏标准化的方法。为解决这一问题,Spring团队推出了Spring AI Alibaba,它作为一套标…

基于Java的可携宠物酒店管理系统的设计与实现(论文+源码)_kaic

摘 要 随着社会经济的不断发‎‏展,现如今出行并住酒店的人越来越多,与之而来的是酒店行业的工作量日益增加,酒店的管理效率亟待提升。此外很多人出门旅游时会有携带宠物的情况,但是现如今酒店对宠物的限制,导致许多…

Java学习-JVM

目录 1. 基本常识 1.1 JVM是什么 1.2 JVM架构图 1.3 Java技术体系 1.4 Java与JVM的关系 2. 类加载系统 2.1 类加载器种类 2.2 执行顺序 2.3 类加载四个时机 2.4 生命周期 2.5 类加载途径 2.6 双亲委派模型 3. 运行时数据区 3.1 运行时数据区构成 3.2 堆 3.3 栈…

【RabbitMQ高级——过期时间TTL+死信队列】

1. 过期时间TTL概述 过期时间TTL表示可以对消息设置预期的时间,在这个时间内都可以被消费者接收获取;过了之后消息将自动被删除。RabbitMQ可以对消息和队列设置TTL。 目前有两种方法可以设置。 第一种方法是通过队列属性设置,队列中所有消…

基于Springboot的宠物咖啡馆平台的设计与实现(源码+定制+参考)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

【操作系统】四、文件管理:1.文件系统基础(文件属性、文件逻辑结构、文件物理结构、文件存储管理、文件目录、基本操作、文件共享、文件保护)

文件管理 文章目录 文件管理八、文件系统基础1.文件的属性2.文件的逻辑结构2.1顺序文件2.2索引文件2.3索引顺序文件2.4多级索引顺序文件 3.目录文件❗3.1文件控制块FCB3.1.1对目录进行的操作 3.2目录结构3.2.1单级目录结构3.2.2两级目录结构3.2.3多级目录结构(树形目…

【大模型部署】本地运行自己的大模型--ollama

ollama简介 ollama是一款开源的、轻量级的框架,它可以快速在本地构建及运行大模型,尤其是一些目前最新开源的模型,如 Llama 3, Mistral, Gemma等。 官网上有大量已经开源的模型,部分针对性微调过的模型也可以选择到,…

Qt源码-Qt多媒体音频框架

Qt 多媒体音频框架 一、概述二、音频设计1. ALSA 基础2. Qt 音频类1. 接口实现2. alsa 插件实现 一、概述 环境详细Qt版本Qt 5.15操作系统Deepin v23代码工具Visual Code源码https://github.com/qt/qtmultimedia/tree/5.15 这里记录一下在Linux下Qt 的 Qt Multimedia 模块的设…