基于BILSTM及其他RNN序列模型的人名分类器

数据集Kaggle链接

NameNationalLanguage | Kaggle

数据集分布:

第一列为人名,第二列为国家标签

代码开源地址

Kaggle代码链接

https://www.kaggle.com/code/houjijin/name-nationality-classification

Gitee码云链接

人名国籍分类 Name Nation classification: using BILSTM to predict individual's nationality by their name

github链接

GitHub - Foxbabe1q/Name-Nation-classification: Use BILSTM to do the classification of individuals by their names

RNN序列模型类编写

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.input_size = input_sizeself.num_layers = num_layersself.output_size = 18self.rnn = nn.RNN(input_size, hidden_size, num_layers = num_layers, batch_first=True)self.fc = nn.Linear(self.hidden_size, self.output_size)def forward(self, x, hidden):output, hidden = self.rnn(x, hidden)output = output[:, -1, :]output = self.fc(output)return output, hiddendef init_hidden(self, batch_size):hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)return hiddenclass SimpleLSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers):super(SimpleLSTM, self).__init__()self.hidden_size = hidden_sizeself.input_size = input_sizeself.num_layers = num_layersself.output_size = 18self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)self.fc = nn.Linear(self.hidden_size, self.output_size)def forward(self, x, hidden, c):output, (hidden, c) = self.rnn(x, (hidden, c))output = output[:, -1, :]output = self.fc(output)return output, hidden, cdef init_hidden(self, batch_size):hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)return hidden, c0class SimpleBILSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers):super(SimpleBILSTM, self).__init__()self.hidden_size = hidden_sizeself.input_size = input_sizeself.num_layers = num_layersself.output_size = 18self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)self.fc = nn.Linear(self.hidden_size*2, self.output_size)def forward(self, x, hidden, c):output, (hidden, c) = self.rnn(x, (hidden, c))output = output[:, -1, :]output = self.fc(output)return output, hidden, cdef init_hidden(self, batch_size):hidden = torch.zeros(self.num_layers*2, batch_size, self.hidden_size, device=device)c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=device)return hidden, c0class SimpleGRU(nn.Module):def __init__(self, input_size, hidden_size, num_layers):super(SimpleGRU, self).__init__()self.hidden_size = hidden_sizeself.input_size = input_sizeself.num_layers = num_layersself.output_size = 18self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)self.fc = nn.Linear(self.hidden_size, self.output_size)def forward(self, x, hidden):output, hidden = self.rnn(x, hidden)output = output[:, -1, :]output = self.fc(output)return output, hiddendef init_hidden(self, batch_size):hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)return hidden

注意这里BILSTM类中,由于双向lstm会使用两个lstm模型分别处理前向序列和反向序列,所以在初始化隐藏层和记忆细胞层的时候要设置num_layers为2.

导包

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from RNN_Series1 import SimpleRNN, SimpleLSTM, SimpleGRU, SimpleBILSTM
from torch.utils.data import Dataset, DataLoader
import string
from sklearn.preprocessing import LabelEncoder
import time

字符序列及device定义

letters = string.ascii_letters + " .,;'"
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

数据读取及标签列编码

def load_data():data = pd.read_csv('name_classfication.txt', sep='\t', names = ['name', 'country'])X = data[['name']]lb = LabelEncoder()y = data['country']y = lb.fit_transform(y)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)return X_train, X_test, y_train, y_test

数据集定义

class create_dataset(Dataset):def __init__(self, X, y):self.X = Xself.y = yself.length = len(self.X)def __len__(self):return self.lengthdef __getitem__(self, idx):data = torch.zeros(10, len(letters), dtype = torch.float, device=device)for i, letter in enumerate(self.X.iloc[idx,0]):if i==10:breakdata[i,letters.index(letter)] = 1label = torch.tensor(self.y[idx], dtype = torch.long, device=device)return data, label

这里使用字符序列进行独热编码,并且由于名字长度不一,所以经过序列长度分布,选取了10作为截断长度.

使用RNN训练

def train_rnn():X_train, X_test, y_train, y_test = load_data()criterion = nn.CrossEntropyLoss(reduction='sum')loss_list = []acc_list = []val_acc_list = []val_loss_list = []epochs = 10my_dataset = create_dataset(X_train, y_train)val_dataset = create_dataset(X_test, y_test)my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)my_rnn = SimpleRNN(len(letters), 128,2)my_rnn.to(device)optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)start_time = time.time()for epoch in range(epochs):my_rnn.train()total_loss = 0total_acc = 0total_sample = 0for i, (X,y) in enumerate(my_dataloader):output, hidden = my_rnn(X, my_rnn.init_hidden(batch_size=len(y)))total_sample += len(y)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()prediction = output.argmax(dim=1)acc_num = torch.sum(prediction == y).item()total_acc += acc_numloss_list.append(total_loss/total_sample)acc_list.append(total_acc/total_sample)my_rnn.eval()with torch.no_grad():for i, (X_val, y_val) in enumerate(val_dataloader):output, hidden = my_rnn(X_val, my_rnn.init_hidden(batch_size=len(y_test)))loss = criterion(output, y_val)prediction = output.argmax(dim=1)acc_num = torch.sum(prediction == y_val).item()val_acc_list.append(acc_num/len(y_val))val_loss_list.append(loss.item()/len(y_val))print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')torch.save(my_rnn.state_dict(), 'rnn.pt')plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.xticks(np.arange(1,11))plt.title('Loss')plt.legend()plt.savefig('logg.png')plt.show()plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.xticks(np.arange(1,11))plt.title('Accuracy')plt.legend()plt.savefig('accuracy.png')plt.show()

使用BILSTM训练

def train_bilstm():X_train, X_test, y_train, y_test = load_data()criterion = nn.CrossEntropyLoss(reduction='sum')loss_list = []acc_list = []val_acc_list = []val_loss_list = []epochs = 10my_dataset = create_dataset(X_train, y_train)val_dataset = create_dataset(X_test, y_test)my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)my_rnn = SimpleBILSTM(len(letters), 128,2)my_rnn.to(device)optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)start_time = time.time()for epoch in range(epochs):my_rnn.train()total_loss = 0total_acc = 0total_sample = 0for i, (X,y) in enumerate(my_dataloader):hidden,c0 = my_rnn.init_hidden(batch_size=len(y))output, hidden,c = my_rnn(X, hidden,c0)total_sample += len(y)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()prediction = output.argmax(dim=1)acc_num = torch.sum(prediction == y).item()total_acc += acc_numloss_list.append(total_loss/total_sample)acc_list.append(total_acc/total_sample)my_rnn.eval()with torch.no_grad():for i, (X_val, y_val) in enumerate(val_dataloader):hidden, c0 = my_rnn.init_hidden(batch_size=len(y_val))output, hidden ,c= my_rnn(X_val, hidden,c0)loss = criterion(output, y_val)prediction = output.argmax(dim=1)acc_num = torch.sum(prediction == y_val).item()val_acc_list.append(acc_num/len(y_val))val_loss_list.append(loss.item()/len(y_val))print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')torch.save(my_rnn.state_dict(), 'bilstm.pt')plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.xticks(np.arange(1,11))plt.title('Loss')plt.legend()plt.savefig('loss.png')plt.show()plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.xticks(np.arange(1,11))plt.title('Accuracy')plt.legend()plt.savefig('accuracy.png')plt.show()

这里超参数设置为epochs:10,lr:1e-3,Adam优化器

epoch: 1, train_loss: 1.70, train_acc: 0.51, val_loss: 1.50, val_acc: 0.56, time:  11.83
epoch: 2, train_loss: 1.36, train_acc: 0.60, val_loss: 1.25, val_acc: 0.64, time:  22.84
epoch: 3, train_loss: 1.19, train_acc: 0.65, val_loss: 1.10, val_acc: 0.69, time:  33.76
epoch: 4, train_loss: 1.05, train_acc: 0.69, val_loss: 0.97, val_acc: 0.72, time:  44.63
epoch: 5, train_loss: 0.93, train_acc: 0.73, val_loss: 0.91, val_acc: 0.74, time:  55.49
epoch: 6, train_loss: 0.85, train_acc: 0.75, val_loss: 0.85, val_acc: 0.75, time:  66.38
epoch: 7, train_loss: 0.78, train_acc: 0.77, val_loss: 0.78, val_acc: 0.77, time:  77.38
epoch: 8, train_loss: 0.73, train_acc: 0.78, val_loss: 0.75, val_acc: 0.77, time:  88.27
epoch: 9, train_loss: 0.68, train_acc: 0.79, val_loss: 0.71, val_acc: 0.78, time:  99.44
epoch: 10, train_loss: 0.64, train_acc: 0.80, val_loss: 0.72, val_acc: 0.78, time:  110.43

完整代码的开源链接可以查询kaggle,gitee,github链接,其中gitee和github仓库中有训练好的模型权重,有需要可以在模型实例化后直接使用.

如需使用其他rnn序列模型如lstm和gru也可以直接实例化这里对应的模型类进行训练即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/469936.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

期权懂|期权新手入门教学:期权合约有哪些要素?

期权小懂每日分享期权知识,帮助期权新手及时有效地掌握即市趋势与新资讯! 期权新手入门教学:期权合约有哪些要素? 期权合约:是指约定买方有权在将来某一时间以特定价格买入或卖出约定标的物的标准化或非标准化合约。期…

Oracle OCP认证考试考点详解082系列16

题记: 本系列主要讲解Oracle OCP认证考试考点(题目),适用于19C/21C,跟着学OCP考试必过。 76. 第76题: 题目 解析及答案: 以下哪三项活动会被记录在数据库的警报日志中? A. 块损坏错误 数据库…

【Linux篇】面试——用户和组、文件类型、权限、进程

目录 一、权限管理 1. 用户和组 (1)相关概念 (2)用户命令 ① useradd(添加新的用户账号) ② userdel(删除帐号) ③ usermod(修改帐号) ④ passwd&…

论文阅读《机器人状态估计中的李群》

目录 摘要1 介绍2 微李理论2.1 李群2.2 group actions2.3 正切空间和李代数 摘要 李群是一个古老的数学抽象对象,可以追溯到19世纪,当时数学家 Sophus Lie奠定了连续变换群理论的基础。多年后,它的影响已经蔓延到科学和技术的各个领域。在机…

智能零售柜商品识别

项目源码获取方式见文章末尾! 600多个深度学习项目资料,快来加入社群一起学习吧。 《------往期经典推荐------》 项目名称 1.【基于CNN-RNN的影像报告生成】 2.【卫星图像道路检测DeepLabV3Plus模型】 3.【GAN模型实现二次元头像生成】 4.【CNN模型实现…

综合案例铁锅炖(CSS项目大杂烩)

小工具:snipaste 全世界最好用的截图工具来了 Snipaste是一个强大简单的截图工具,方便的点就在于可以把截图贴回屏幕上 常用快捷方式有这些: 1.F1截图,同时测量大小,设置箭头,文字书写 2.F3在桌面置顶显示…

稀疏视角CBCT重建的几何感知衰减学习|文献速递-基于深度学习的病灶分割与数据超分辨率

Title 题目 Geometry-Aware Attenuation Learning forSparse-View CBCT Reconstruction 稀疏视角CBCT重建的几何感知衰减学习 01 文献速递介绍 稀疏视角锥形束计算机断层扫描(CBCT)重建的几何感知学习方法 锥形束计算机断层扫描(CBCT&a…

河南省的一级科技查新机构有哪些?

科技查新,简称查新,是指权威机构对查新项目的新颖性作出文献评价的情报咨询服务。这一服务在科研立项、成果鉴定、项目申报等方面发挥着至关重要的作用。河南省作为中国的重要科技和教育基地,拥有多个一级科技查新机构,为本省及全…

https网站 请求http图片报错:net::ERR_SSL_PROTOCOL_ERROR

问题描述 场景: https网站,请求http图片资源报错:net::ERR_SSL_PROTOCOL_ERROR 原因: Chrome 81 中,对混合内容资源加载策略进行了改变,会自动升级到 https:// ,如果无法通过 https:// 加载&am…

【机器学习】机器学习中用到的高等数学知识-3.微积分 (Calculus)

3. 微积分 (Calculus) 导数和梯度:用于优化算法(如梯度下降)中计算损失函数的最小值。偏导数:在多变量函数中优化目标函数。链式法则:在反向传播算法中用于计算神经网络的梯度。 导数和梯度:用于优化算法…

华为大咖说 | 浅谈智能运维技术

本文分享自华为云社区:华为大咖说 | 浅谈智能运维技术-云社区-华为云 本文作者:李文轩 ( 华为智能运维专家 ) 全文约2695字,阅读约需8分钟 在大数据、人工智能等新兴技术的加持下,智能运维(AI…

WebStorm 如何调试 Vue 项目

前言 在日常开发和各种教程中,最常见的 debug 方式就是在代码中插入 console.log 语句,然后在 Chrome 控制台中查看日志。显而易见,插入console.log 的效率不高,那是否有更高效的 debug 方式呢?断点调试允许开发者在代…

【循环神经网络】

循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络,擅长处理具有时间依赖或顺序结构的数据。RNN通过循环连接的结构,使得当前时刻的输出可以受之前时刻信息的影响,因此被广泛应用于自然语…

sqoop import将Oracle数据加载至hive,数据量变少,只能导入一个mapper的数据量

sqoop脚本如下: sqoop import -D mapred.job.queue.namehighway \ -D mapreduce.map.memory.mb4096 \ -D mapreduce.map.java.opts-Xmx3072m \ --connect "jdbc:oracle:thin://1.2.3.4.5:61521/LZY2" \ --username root \ --password 123456 \ --query &…

低功耗WTK6900P语音ic方案助力电子烟技术革新 打造个性化吸烟体验

在这个科技日新月异的时代,每一个细节的创新都是对传统的一次超越。今天,我们自豪地宣布一项革命性的融合——将先进的频谱技术与电子烟相结合,通过WTK6900P芯片的卓越性能,为您开启前所未有的个性化吸烟体验。这不仅是一次技术的…

《基于深度学习的车辆行驶三维环境双目感知方法研究》

复原论文思路: 《基于深度学习的车辆行驶三维环境双目感知方法研究》 1、双目测距的原理 按照上述公式算的话,求d的话,只和xl-xr有关系,这样一来,是不是只要两张图像上一个测试点的像素位置确定,对应的深…

Chromium 中sqlite数据库操作演示c++

本文主要演示sqlite数据库 增删改查创建数据库以及数据库表的基本操作,仅供学习参考。 一、sqlite数据库操作类封装: sql\database.h sql\database.cc // Copyright 2012 The Chromium Authors // Use of this source code is governed by a BSD-sty…

Qt初识简单使用Qt

使用C代码实现hello world 之前介绍过用图形化界面的方式创建hello world&#xff0c;这里我们使用C代码的方式再来实现一次hello world。 如上&#xff0c;首先要先包含一个头文件。 在QT这里&#xff0c;每一个类都有一个对应的同名头文件。比如这里我就包含了 <QLabel&…

高效运维:构建全面监控与自动化管理体系

在当今数字化时代&#xff0c;企业IT系统的稳定运行直接关系到业务的连续性和竞争力。运维团队作为保障系统稳定运行的中坚力量&#xff0c;面临着前所未有的挑战。随着云计算、大数据、物联网等技术的快速发展&#xff0c;系统架构日益复杂&#xff0c;运维工作也从传统的被动…

Docker网络和overlay的基础讲解

本人发现了两篇写的不错的文章&#xff1a;Docker网络 - docker network详解-CSDN博客&#xff0c;Docker 容器跨主机通信 overlay_docker overlay 网络-CSDN博客 因为这两篇文章中含有大量的例子&#xff0c;新手看起来毫不费力。于是我偷了个小懒&#xff0c;在本篇文章中没有…