[oneAPI] 使用字符级 RNN 生成名称

[oneAPI] 使用字符级 RNN 生成名称

  • oneAPI特殊写法
  • 使用字符级 RNN 生成名称
    • Intel® Optimization for PyTorch
    • 数据下载
    • 加载数据并对数据进行处理
    • 创建网络
    • 训练过程
      • 准备训练
      • 训练网络
    • 结果
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

oneAPI特殊写法

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')rnn = RNN(n_letters, 128, n_letters)
optim = torch.optim.SGD(rnn.parameters(), lr=0.01)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
rnn, optim = ipex.optimize(rnn, optimizer=optim)criterion = nn.NLLLoss()

使用字符级 RNN 生成名称

为了深入探索语言模型在分类和生成方面的卓越能力,我们特意设计了一个独特的任务。此任务的独特之处在于,它旨在综合学习多种语言的词义特征,以确保生成的内容与各种语言的词组相关性一致。

在任务的具体描述中,我们提供了一个多语言数据集,这个数据集包含多种语言的文本。通过这个数据集,我们的目标是使模型能够在生成名称时融合不同语言的特征。具体来说,我们会提供一个词的开头作为提示,然后模型将能够根据这个开头生成对应语言的名称,从而将不同语言的词意和语法特征进行完美融合。

通过这一任务,我们旨在实现一个在多语言环境中具有卓越生成和分类能力的语言模型。通过学习并融合不同语言的词义和语法特征,我们让使模型具备更广泛的应用潜力,能够在不同语境下生成准确、符合语法规则的名称。

> python sample.py Russian RUS
Rovakov
Uantov
Shavakov> python sample.py German GER
Gerren
Ereng
Rosher> python sample.py Spanish SPA
Salla
Parer
Allan> python sample.py Chinese CHI
Chan
Hang
Iun

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

数据下载

从这里下载数据 并将其解压到当前目录。

加载数据并对数据进行处理

简而言之,有一堆data/names/[Language].txt每行都有一个名称的纯文本文件。我们将行分割成一个数组,将 Unicode 转换为 ASCII,最后得到一个字典。{language: [names …]}

from io import open
import glob
import os
import unicodedata
import stringall_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # Plus EOS markerdef findFiles(path): return glob.glob(path)# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn'and c in all_letters)# Read a file and split into lines
def readLines(filename):with open(filename, encoding='utf-8') as some_file:return [unicodeToAscii(line.strip()) for line in some_file]# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('data/names/*.txt'):category = os.path.splitext(os.path.basename(filename))[0]all_categories.append(category)lines = readLines(filename)category_lines[category] = linesn_categories = len(all_categories)if n_categories == 0:raise RuntimeError('Data not found. Make sure that you downloaded data ''from https://download.pytorch.org/tutorial/data.zip and extract it to ''the current directory.')print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

Output:

# categories: 18 ['Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French', 'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese']
O'Neal

创建网络

序列到序列网络,或 seq2seq 网络,或编码器解码器网络,是由两个称为编码器和解码器的 RNN 组成的模型。编码器读取输入序列并输出单个向量,解码器读取该向量以产生输出序列。

我添加了第二个线性层o2o(在组合隐藏层和输出层之后)以赋予其更多的功能。还有一个 dropout 层,它以给定的概率(此处为 0.1)随机将部分输入归零,通常用于模糊输入以防止过度拟合。在这里,我们在网络末端使用它来故意添加一些混乱并增加采样多样性。
在这里插入图片描述

######################################################################
# Creating the Network
# ====================
import torch
import torch.nn as nnimport intel_extension_for_pytorch as ipexclass RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)self.o2o = nn.Linear(hidden_size + output_size, output_size)self.dropout = nn.Dropout(0.1)self.softmax = nn.LogSoftmax(dim=1)def forward(self, category, input, hidden):input_combined = torch.cat((category, input, hidden), 1)hidden = self.i2h(input_combined)output = self.i2o(input_combined)output_combined = torch.cat((hidden, output), 1)output = self.o2o(output_combined)output = self.dropout(output)output = self.softmax(output)return output, hiddendef initHidden(self):return torch.zeros(1, self.hidden_size)

训练过程

准备训练

import random# Random item from a list
def randomChoice(l):return l[random.randint(0, len(l) - 1)]# Get a random category and random line from that category
def randomTrainingPair():category = randomChoice(all_categories)line = randomChoice(category_lines[category])return category, line
# One-hot vector for category
def categoryTensor(category):li = all_categories.index(category)tensor = torch.zeros(1, n_categories)tensor[0][li] = 1return tensor# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):tensor = torch.zeros(len(line), 1, n_letters)for li in range(len(line)):letter = line[li]tensor[li][0][all_letters.find(letter)] = 1return tensor# ``LongTensor`` of second letter to end (EOS) for target
def targetTensor(line):letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]letter_indexes.append(n_letters - 1) # EOSreturn torch.LongTensor(letter_indexes)

为了训练过程中的方便,我们将创建一个randomTrainingExample 函数来获取随机(类别、线)对并将它们转换为所需的(类别、输入、目标)张量。

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():category, line = randomTrainingPair()category_tensor = categoryTensor(category)input_line_tensor = inputTensor(line)target_line_tensor = targetTensor(line)return category_tensor, input_line_tensor, target_line_tensor

训练网络

criterion = nn.NLLLoss()learning_rate = 0.0005def train(category_tensor, input_line_tensor, target_line_tensor):target_line_tensor.unsqueeze_(-1)hidden = rnn.initHidden()rnn.zero_grad()loss = torch.Tensor([0]) # you can also just simply use ``loss = 0``for i in range(input_line_tensor.size(0)):output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)l = criterion(output, target_line_tensor[i])loss += lloss.backward()for p in rnn.parameters():p.data.add_(p.grad.data, alpha=-learning_rate)return output, loss.item() / input_line_tensor.size(0)

为了跟踪训练需要多长时间,我添加了一个 timeSince(timestamp)返回人类可读字符串的函数:

import time
import mathdef timeSince(since):now = time.time()s = now - sincem = math.floor(s / 60)s -= m * 60return '%dm %ds' % (m, s)

训练就像平常一样 - 多次调用训练并等待几分钟,打印当前时间和每个print_every 示例的损失,并存储每个plot_every示例的平均损失all_losses以供稍后绘制。

rnn = RNN(n_letters, 128, n_letters)n_iters = 100000
print_every = 5000
plot_every = 500
all_losses = []
total_loss = 0 # Reset every ``plot_every`` ``iters``start = time.time()for iter in range(1, n_iters + 1):output, loss = train(*randomTrainingExample())total_loss += lossif iter % print_every == 0:print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))if iter % plot_every == 0:all_losses.append(total_loss / plot_every)total_loss = 0

结果

在这里插入图片描述

参考资料

https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/101780.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Temu闯关日韩受挫?跨境电商卖家如何打磨好营销链路

海外版拼多多 Temu 先后在日本和韩国上线,然而效果不似预期,日韩市场对这套“低价补贴”策略并不买账。作为一个尚未被日韩消费者熟悉的网站,其价格之便宜无法让消费者信任。除此之外更大的问题是,在日本卷不过线下零售与百元店&a…

生信学院|08月25日《SOLIDWORKS PDM帮助企业对设计数据版本的管理应用》

课程主题:SOLIDWORKS PDM帮助企业对设计数据版本的管理应用 课程时间:2023年08月25日 14:00-14:30 主讲人:车立洋 生信科技 PDM专家 1、图纸&文档的版本管理对于企业的重要性 2、SolidWorks PDM对图纸&文档版本的管理 3、SolidW…

Android6:片段和导航

创建项目Secret Message strings.xml <resources><string name"app_name">Secret Message</string><string name"welcome_text">Welcome to the Secret Message app!Use this app to encrypt a secret message.Click on the Star…

【深度学习 | 数据可视化】 视觉展示分类边界: Perceptron模型可视化iris数据集的决策边界

&#x1f935;‍♂️ 个人主页: AI_magician &#x1f4e1;主页地址&#xff1a; 作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域优质创作者。 &#x1f468;‍&#x1f4bb;景愿&#xff1a;旨在于能和更多的热爱计算机的伙伴一起成长&#xff01;&#xff01;&…

JavaWeb+JSP+SQL server学生学籍管理系统设计与实现(源代码+论文+开题报告+外文翻译+答辩PPT)

需求分析 本系统主要是针对各个高校的学生学籍进行管理&#xff0c;系统满足以下几点要求&#xff1a; 系统安全性。由于此系统中的操作都是由用户操作的&#xff0c;所以对于用户的权限设置比较严格。对于数据库&#xff0c;设置了不同用户的权限&#xff0c;不同权限进入不…

Spring Boot实践八--用户管理系统

一&#xff0c;技术介绍 技术选型功能说明springboot是一种基于 Spring 框架的快速开发应用程序的框架&#xff0c;它的主要作用是简化 Spring 应用程序的配置和开发&#xff0c;同时提供一系列开箱即用的功能和组件&#xff0c;如内置服务器、数据访问、安全、监控等&#xf…

win11安装ubuntu 子系统安装过程及注意事项

第一步 &#xff1a;安装系统必须组件 由于子系统是系统自带组件&#xff0c;需要安装软件支持 第二步&#xff1a;应用商店安装 ubuntu 编辑 编辑 这个时候打开会报错 第三步&#xff0c;运行linux子系统 选择Windows PowerShell 以管理员身份运行&#xff09; 输入&#…

简单计算器的实现(含转移表实现)

文章目录 计算器的一般实现使⽤函数指针数组的实现&#xff08;转移表&#xff09; 计算器的一般实现 通过函数的调用&#xff0c;实现加减乘除 # define _CRT_SECURE_NO_WARNINGS#include<stdio.h>int Add(int x, int y) {return x y; }int Sub(int x, int y) {retur…

Ros noetic Move_base 监听Move状态 实战使用教程

前言: 承接上一篇文章,在上一文中我们了解到move_base有几种监听的状态,我一文章中我将开源全部监听代码,本文将从0开始建立监听包,并覆上全部的工程代码,和仿真实操结果。 本文,还将解决当临时障碍物与机身相交时,机器人回人为自己被“卡住”,局部规划器规划的速度为…

Linux socket网络编程

一、主机字节序列和网络字节序列 主机字节序列分为大端字节序列和小端字节序列&#xff0c;不同的主机采用的字节序列可能不同。大端字节序列是指一个整数的高位字节存储在内存的低地址处&#xff0c;低位字节存储在内存的高地址处。小端字节序列是指整数的高位字节存储在内存…

【Windows 常用工具系列 10 -- linux ssh登录脚本输入密码】

文章目录 脚本输入SSH登录密码scp 脚本免密传输 脚本输入SSH登录密码 sshpass 是一个用于运行时非交互式ssh密码提供的工具&#xff0c;它允许你直接在命令行中为ssh命令提供密码。这就意味着你可以在脚本中使用ssh命令&#xff0c;而不需要用户交互地输入密码。 一般来说&am…

技术分享| WebRTC之SDP详解

一&#xff0c;什么是SDP WebRTC 是 Web Real-Time Communication&#xff0c;即网页实时通信的缩写&#xff0c;是 RTC 协议的一种Web实现&#xff0c;项目由 Google 开源&#xff0c;并和 IETF 和 W3C 制定了行业标准。 WebRTC是点对点通讯&#xff0c;他的通话建立需要交换…

【数据结构练习】链表面试题锦集一

目录 前言&#xff1a; 1. 删除链表中所有值为key的节点 方法一&#xff1a;正常删除&#xff0c;头结点另外讨论 方法二:虚拟头结点法 方法三&#xff1a;递归 2.反转链表 方法一&#xff1a;双指针迭代 方法二&#xff1a;递归法解析&#xff1a; 3.链表的中间结点 方法…

AWVS安装~Windows~激活

目录 1.下载安装包 2.双击acunetix_15.1.221109177.exe进行安装 3.配置C:\Windows\System32\drivers\etc\hosts 4.复制wvsc.exe到C:\Program Files (x86)\Acunetix\15.1.221109177下 5.复制license_info.json与wa_data.dat到C:\ProgramData\Acunetix\shared\license下&…

CentOS中Oracle11g进程有哪些

最近遇到Oracle数据库运行过程实例进程由于某种原因导致中止的问题&#xff0c;专门看了下正常Oracle数据库启动后的进程有哪些&#xff0c;查阅资料了解了下各进程的作用&#xff0c;记录如下。 oracle 3032 1 0 07:36 ? 00:00:00 ora_pmon_orcl oracle …

解决 node-gyp 错误问题

gyp verb check python checking for Python executable “python2.7” in the PATH gyp verb which failed Error: not found: python2.7 安装老项目老是报错Python找不到&#xff0c;以为是自己node版本高过了node-sass导致的&#xff0c;把node版本降下来还是不行。然后找到…

ONLYOFFICE协作空间服务器如何一键安装自托管私有化部署

ONLYOFFICE协作空间服务器如何一键安装自托管私有化部署 如何在 Ubuntu 上部署 ONLYOFFICE 协作空间社区版&#xff1f;https://blog.csdn.net/m0_68274698/article/details/132069372?ops_request_misc&request_id&biz_id102&utm_termonlyoffice%20%E5%8D%8F%E4…

vue3+element下拉多选框组件

<!-- 下拉多选 --> <template><div class"select-checked"><el-select v-model"selected" :class"{ all: optionsAll, hidden: selectedOptions.data.length < 2 }" multipleplaceholder"请选择" :popper-app…

ui设计师简历自我评价(合集)

UI设计最新面试题及答案 1、说说你是怎么理解UI的? UI是最直观的把产品展示展现在用户面前的东西&#xff0c;是一个产品的脸面。人开始往往是先会先喜欢上美好的事物后&#xff0c;在去深究内在的东西的。 那么也就意味着一个产品的UI首先要做的好看&#xff0c;无论风格是…

智慧班牌云平台源码 (人脸识别、信息发布、校园风采、家校互通、教务管理、考勤管理)

电子班牌是一款智慧校园的管理工具&#xff0c;也是校园的多媒体展示平台&#xff0c;智慧电子班牌系统是专为学校智慧教育设计的一款智慧校园的管理工具&#xff0c;融合了多媒体信息发布、校园风采、家校互通、教务管理、考勤管理、日常办公等一系列应用。具备智慧教育功能和…