【深度学习】RNN的简单实现

目录

1.RNNCell

2.RNN

3.RNN_Embedding


1.RNNCell

import torchinput_size = 4
hidden_size = 4
batch_size = 1idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol# 独热向量
one_hot_lookup = [[1, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)  # (seqLen,batchSize,inputSize)
labels = torch.LongTensor(y_data).view(-1, 1)  # (seqLen,1)class Model(torch.nn.Module):def __init__(self, input_size, hidden_size, batch_size):super(Model, self).__init__()self.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnncell = torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)def forward(self, input, hidden):hidden = self.rnncell(input, hidden)  # input:(batch, input_size) hidden:(batch, hidden_size)return hiddendef init_hidden(self):return torch.zeros(self.batch_size, self.hidden_size)net = Model(input_size, hidden_size, batch_size)criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)for epoch in range(15):loss = 0optimizer.zero_grad()hidden = net.init_hidden()print('Predicted string: ', end='')for input, label in zip(inputs, labels):hidden = net(input, hidden)loss += criterion(hidden, label)idx = torch.argmax(hidden, dim=1)print(idx2char[idx.item()], end='')loss.backward()optimizer.step()print(', Epoch [%d/15] loss=%.4f' % (epoch+1, loss.item()))

2.RNN

import torchinput_size = 4  # 输入的维度,例如hello为四个字母表示,其维度为四
hidden_size = 4  # 隐藏层维度
num_layers = 1  # number of layers
batch_size = 1
seq_len = 5idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol# 独热向量
one_hot_lookup = [[1, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
labels = torch.LongTensor(y_data)  # (seqSize*batchSize, 1)class Model(torch.nn.Module):def __init__(self, input_size, hidden_size, batch_size, num_layers=1):super(Model, self).__init__()self.num_layers = num_layersself.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnn = torch.nn.RNN(input_size=self.input_size,hidden_size=self.hidden_size,num_layers=num_layers)def forward(self, input):hidden = torch.zeros(self.num_layers,self.batch_size,self.hidden_size)  # (numLayers, batchSize, hiddenSize)out, hidden_last = self.rnn(input, hidden)  # out:(seqLen, batchSize, hiddenSize), hidden_last:最后一个hiddenreturn out.view(-1, self.hidden_size)  # (seqLen×batchSize, hiddenSize)net = Model(input_size, hidden_size, batch_size, num_layers)criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)for epoch in range(15):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()idx = torch.argmax(outputs, dim=1)print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

3.RNN_Embedding

import torch# parameters
num_class = 4  # 引入线性层,不用不必要求一个输入就有一个输出,可以多个
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch:1, seq_len:5)
y_data = [3, 1, 2, 3, 2]  # (batch * seq_len)
inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = torch.nn.Embedding(input_size, embedding_size)self.rnn = torch.nn.RNN(input_size=embedding_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)# batchSize在第一位: (batchSize:一共几个句子, seqLen:每个句子有几个单词, inputSize:每个单词有多少特征)self.fc = torch.nn.Linear(hidden_size, num_class)def forward(self, x):hidden = torch.zeros(num_layers, x.size(0), hidden_size)x = self.emb(x)  # (batch, seqLen, embeddingSize), 输入数据x首先经过嵌入层,将字符索引转换为向量x, _ = self.rnn(x, hidden)x = self.fc(x)return x.view(-1, num_class)net = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)for epoch in range(15):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()idx = torch.argmax(outputs, dim=1)print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/452672.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用Python合并Excel文件中的多个Sheet

在日常工作中,我们经常会遇到需要处理多个Excel工作表(Sheet)的情况。比如,一个Excel文件中包含了一个月内每天的数据,每个工作表代表一天。有时候,为了方便分析,我们需要将这些分散的数据合并到…

数据结构单向链表的插入和删除(一)

链表 一、链表结构: (物理存储结构上不连续,逻辑上连续;大小不固定)二、单链表:三、单项链表的代码实现:四、开发可用的链表:四、单链表的效率分析: 一、链表结构&#x…

Javaweb基础-vue

Vue.js Vue是一套用于构建用户界面的渐进式框架。 起步 引入vue <head><script src"static/js/vue2.6.12.min.js"></script> </head> 创建vue应用 <body> <div id"index"><p>{{message}}</p> </div>…

Vulnhub:Me-and-My-Girlfriend-1

一.环境启动/信息收集 &#xff08;1&#xff09;根据物理地址用nmap的主机发现功能得出IP地址 nmap -P 192.168.138.0/24 //同网段下主机发现得到IP为192.168.138.180&#xff08;2&#xff09;做nmap的目录扫描和端口扫描来发现其他站带以及信息 nmap -p- 192.168.138.180 …

CTF(五)

导言&#xff1a; 本文主要讲述在CTF竞赛中&#xff0c;web类题目easyphp。 靶场链接&#xff1a;攻防世界 (xctf.org.cn) 参考文章原文链接&#xff1a;Web安全攻防世界05 easyphp&#xff08;江苏工匠杯&#xff09;_攻防世界 easyphp-CSDN博客 一&#xff0c;观察页面。…

玄机平台-应急响应-webshell查杀

首先xshell连接 然后进入/var/www/html目录中&#xff0c;将文件变成压缩包 cd /var/www/html tar -czvf web.tar.gz ./* 开启一个http.server服务&#xff0c;将文件下载到本地 python3 -m http.server 放在D盾中检测 基本可以确认木马文件就是这四个 /var/www/html/shell.p…

动态规划-子数组系列——乘积最大子数组

1.题目解析 题目来源&#xff1a;152.乘积最大子数组——力扣 测试用例 2.算法原理 1.状态表示 由于题目给的数组中可以包含负数&#xff0c;因此求最大乘积有两种情况&#xff1a; a.负数乘以最小数得出最大乘积 b.整数乘以最大数得出最大乘积 所以需要两个表分别求出最大最…

Leetcode 柱状图中最大的矩形

h 是右边界&#xff0c;连续多个高度递增的柱子&#xff0c;如果遇到下一个 h < 栈顶元素(是最大的元素&#xff0c;单调递增栈)&#xff0c;那么会不断出栈来更新计算最大面积。 并非是一次性计算出最大面积的&#xff0c;很重要的一点是while (!stack.isEmpty()这一部分的…

【Linux】多线程安全之道:互斥、加锁技术与底层原理

目录 1.线程的互斥 1.1.进程线程间的互斥相关背景概念 1.2.互斥量mutex的基本概念 所以多线程之间为什么要有互斥&#xff1f; 为什么抢票会抢到负数&#xff0c;无法获得正确结果&#xff1f; 为什么--操作不是原子性的呢&#xff1f; 解决方式&#xff1a; 2.三种加锁…

项目实战:构建 effet.js 人脸识别交互系统的实战之路

&#x1f4dd;个人主页&#x1f339;&#xff1a;Eternity._ &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; ❀构建 effet.js &#x1f4d2;1. 什么是effet.js&#x1f4dc;2. 为什么需要使用effet.js&#x1f4dd;3. effet.js的功能&#x1f4da;4. 使用…

HarmonyOS NEXT 应用开发实战(五、页面的生命周期及使用介绍)

HarmonyOS NEXT是华为推出的最新操作系统&#xff0c;arkUI是其提供的用户界面框架。arkUI的页面生命周期管理对于开发者来说非常重要&#xff0c;因为它涉及到页面的创建、显示、隐藏、销毁等各个阶段。以下是arkUI页面生命周期的介绍及使用举例。 页面的生命周期的作用 页面…

聊聊Go语言的异常处理机制

背景 最近因为遇到了一个panic问题&#xff0c;加上之前零零散散看了些关于程序异常处理相关的东西&#xff0c;对这块有点兴趣&#xff0c;于是整理了一下golang对于异常处理的机制。 名词介绍 Painc golang的内置方法&#xff0c;能够改变程序的控制流。 当函数调用了pan…

T113 内核中 adbd相关配置1

准备工作 1. 配置 系统&#xff1a;ubuntu24.04docker&#xff08;ubuntu18.04&#xff09; 软件vscode, sdk:Tina-linux&#xff08;BingPi-M2&#xff09; 2. 构建环境直接使用自带的 source ./build/envsetup.sh lunch 选择 6 编译开启16线程 make -j16boot编译 mboot 打包…

设计模式——装饰者模式(8)

一、定义 指在不改变现有对象结构的情况下&#xff0c;动态地给该对象增加一些职责&#xff08;即增加其额外功能&#xff09;的模式。我们先来看一个快餐店的例子。快餐店有炒面、炒饭这些快餐&#xff0c;可以额外附加鸡蛋、火腿、培根这些配菜&#xff0c;当然加配菜需要额…

【网络安全】简单P1:通过开发者工具解锁专业版和企业版功能

未经许可,不得转载。 文章目录 前言发现过程前言 在探索一个SaaS平台的过程中,我发现了一个漏洞,使得我能够在无需订阅的情况下解锁高级(专业/企业)功能。 发现过程 我使用一个没有任何高级功能的基本用户账户进行常规登录。在浏览平台时,我注意到某些按钮和功能上带有…

基于微信小程序的购物系统【附源码、文档】

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

web前端--html 5---qq注册

<!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>qq注册</title> <link rel"impo…

[Linux#66][TCP->IP] 面向字节流 | TCP异常 | filesocket | 网络层IP

目录 1. 面向字节流 思考&#xff1a;对于UDP协议来说&#xff0c;是否也存在“粘包问题”呢&#xff1f; 2.TCP 异常情况 3.知识 1.UDP实现可靠传输(经典面试题) 2. 网络抓包 | 爬虫 3.打通文件和 socket 的关系 4.网络层&#xff1a;IP 前置知识 1. 面向字节流 udp…

java逻辑运算符 C语言结构体定义

1. public static void main(String[] args) {System.out.println(true&true);//&两者均为true才trueSystem.out.println(false|false);// | 两边都是false才是falseSystem.out.println(true^false);//^ 相同为false&#xff0c;不同为trueSystem.out.println(!false)…

git 安装

文章目录 一、ubuntu 安装 git二、centos 安装 git三、检查安装 git 一、ubuntu 安装 git sudo apt-get install get -y二、centos 安装 git sudo s install git -y三、检查安装 git git --version出现此标志git版本号&#xff0c;表示git安装完成。