TensorFlow2实战-系列教程11:RNN文本分类3

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

6、构建训练数据

  • 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
  • tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
  • tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, y
  1. 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
  2. 打开文件
  3. 打印文件路径
  4. 遍历每行数据
  5. 获取标签和文本
  6. 文本按照空格分离出单词
  7. 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
  8. 如果当前句子大于预设的最大句子长度
  9. 进行截断操作
  10. 如果小于
  11. 补充0
  12. 标签从str转换为int类型
  13. yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中

也就是说yield 会从上一次取得地方再接着去取数据,而return却不会

def dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds
  1. 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
  2. 当前shape值
  3. 1
  4. 是否在训练,如果是:
  5. 构建一个Dataset
  6. 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
  7. 输出的shape值
  8. 输出的类型
  9. 指定shuffle
  10. 指定 batch_size
  11. 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
  12. 如果不是在训练,则:
  13. 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
  14. 最后把做好的Datasets返回回去

7、自定义网络模型

在这里插入图片描述
一条文本变成一组向量/矩阵的基本流程:

  1. 拿到一个英文句子
  2. 通过查语料表将句子变成一组索引
  3. 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了

在这里插入图片描述
在这里插入图片描述
BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。

class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)  def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)        x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x
  1. 自定义一个模型,继承tf.keras.Model模块
  2. 初始化函数
  3. 初始化
  4. 词嵌入,把之前保存好的词嵌入文件向量读进来
  5. 定义一层dropout1
  6. 定义一层dropout2
  7. 定义一层dropout3
  8. 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
  9. 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
  10. 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
  11. 定义全连接层的dropout
  12. 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
  13. 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
  14. 定义前向传播函数,传进来一个batch的数据和是否是在训练
  15. 如果输入数据不是tf.int32类型
  16. 转换成tf.int32类型
  17. 取出batch_size
  18. 设置LSTM神经元个数,双向乘以2
  19. 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
  20. 数据通过第1个 Dropout 层
  21. 数据通过第1个rnn
  22. 数据通过第2个 Dropout 层
  23. 数据通过第2个rnn
  24. 数据通过第3个 Dropout 层
  25. 数据通过第3个rnn
  26. 经过全连接层对应的Dropout
  27. 数据通过一个全连接层
  28. 最后,数据通过一个输出层
  29. 返回最终的模型输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/248372.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring | Spring的“数据库开发“ (Srping JDBC)

目录: Spring JDBC1.Spring JDBC的核心类 ( JdbcTemplate类 )2.Srping JDBC 的配置3.JdbcTemplate类的“常用方法”execute( ):直接执行“sql语句”,没有返回值update( ) :“增删改”,返回 “影响的行数”query( ) : “…

86.网游逆向分析与插件开发-物品使用-物品丢弃的逆向分析与C++代码的封装

内容参考于:易道云信息技术研究院VIP课 上一个内容:物品使用的逆向分析与C代码的封装-CSDN博客 码云地址(ui显示角色数据 分支):https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号:7563f86877c…

网络协议与攻击模拟_11DHCP欺骗防护

开启DHCP 监听 ip dhcp snooping 指定监听vlan ip dhcp snooping vlan 1 由于开启监听后,交换机上的接口就全部变成非信任端口, 非信任端口会拒绝DHCP报文,会造成正常的DHCP请求和响应都无法完成。 现在是请求不到IP地址的,…

添加了gateway之后远程调用失败

前端提示500,后端提示[400 ] during [GET] to [http://userservice/user/1] 原因是这个,因为在请求地址写了两个参数,实际上只传了一个参数 解决方案:加上(required false)并重启所有相关服务

【pytorch】nn.linear 中为什么是y=xA^T+b

我记得读教材的时候是yWxb, 左乘矩阵W,这样才能表示线性变化。 但是pytorch中的nn.linear中,计算方式是yxA^Tb,其中A是权重矩阵。 为什么右乘也能表示线性变化操作呢?因为pytorch中,照顾到输入是多个样本一起算的&…

基于JDK8的SpringBoot-2.7.6应用程序的jar包能直接通过java -jar 命令运行的原因

文章目录 前言一、JAR是什么?二、嵌套JAR1.java官方不支持嵌套jar读取和加载2.嵌套“shade” jar 方案3.SpringBoot的解决方案三、SpringBoot的Executable Jars1.核心支持模块(spring-boot-loader)2.运行调试工具(JDWP)2.1.保姆级IDEA添加JDWP远程调试示例3.运行调试(jav…

vue3之echarts3D环柱饼图

vue3之echarts3D环柱饼图 效果&#xff1a; 版本 "echarts": "^5.4.1", "echarts-gl": "^2.0.9" 核心代码&#xff1a; <template><div class"content"><div ref"eCharts" class"chart&…

运动耳机怎么选?运动耳机什么牌子的好?2024运动无线耳机推荐

科学有规律的锻炼对身体有很多好处。很多人选择通过跑步来保持健康&#xff0c;而在跑步时&#xff0c;戴上耳机听音乐已经成为了许多人的习惯。这不仅可以放松心情&#xff0c;还可以提高身心免疫力。在健身房里&#xff0c;也会看到很多人选择佩戴运动耳机来增强他们的锻炼体…

elk之安装和简单配置

写在前面 本文看下elk的安装和简单配置&#xff0c;安装我们会尝试通过不同的方式来完成&#xff0c;也会介绍如何使用docker&#xff0c;docker-compose安装。 1&#xff1a;安装es 1.1&#xff1a;安装单实例 下载es安装包 在这里 下载&#xff0c;下载后解压到某个目录…

Idea设置代理后无法clone git项目

背景 对于我们程序员来说&#xff0c;经常上github找项目、找资料是必不可少的&#xff0c;但是一些原因&#xff0c;我们访问的时候速度特别的慢&#xff0c;需要有个代理&#xff0c;才能正常的访问。 今天碰到个问题&#xff0c;使用idea工具 clone项目&#xff0c;速度特…

Java入门高频考查基础知识8(腾讯18问1.5万字参考答案)

刷题专栏&#xff1a;http://t.csdnimg.cn/gvB6r Java 是一种广泛使用的面向对象编程语言&#xff0c;在软件开发领域有着重要的地位。Java 提供了丰富的库和强大的特性&#xff0c;适用于多种应用场景&#xff0c;包括企业应用、移动应用、嵌入式系统等。 以下是几个面试技巧&…

Django模型(四)

一、数据操作初始化 from django.db import models# Create your models here. class Place(models.Model):"""位置信息"""name = models.CharField(max_length=32,verbose_name=地名)address = models.CharField(max_length=64,null=True,verbo…

【日常总结】如何快速迁移Navicat中的全部连接设置到新安装的Navicat中?

一、场景 二、需求 三、解决方案 Stage 1&#xff1a;“文件”-->“导出连接”。 Stage 2&#xff1a;获取备份文件 connections.ncx Stage 3&#xff1a;导入connections.ncx 四、不足 一、场景 公司电脑换新&#xff0c;所有软件需要重装&#xff0c;包括navicat 1…

深度学习(7)--卷积神经网络项目详解

一.项目介绍&#xff1a; 用Keras工具包搭建训练自己的一个卷积神经网络(Simple_VGGNet&#xff0c;简单版VGGNet)&#xff0c;用来识别猫/狗/羊三种图片。 数据集&#xff1a; 二.卷积神经网络构造 查看API文档 Convolution layers (keras.io)https://keras.io/api/layers/…

【Tomcat与网络1】史前时代—没有Spring该如何写Web服务

在前面我们介绍了网络与Java相关的问题&#xff0c; 最近在调研的时候发现这块内容其实非常复杂&#xff0c;涉及的内容多而且零碎&#xff0c;想短时间梳理出整个体系是不太可能的&#xff0c;所以我们还是继续看Tomcat的问题&#xff0c;后面有网络的内容继续补充吧。 目录 …

酒店|酒店管理小程序|基于微信小程序的酒店管理系统设计与实现(源码+数据库+文档)

酒店管理小程序目录 目录 基于微信小程序的酒店管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员模块的实现 (1) 用户信息管理 (2) 酒店管理员管理 (3) 房间信息管理 2、小程序序会员模块的实现 &#xff08;1&#xff09;系统首页 &#xff0…

Visual Studio 2022 打开“程序包管理器控制台”失败

Visual Studio 2022 打开“程序包管理器控制台”失败 昨天下午&#xff0c;正在用Visual studio 2022写代码&#xff0c;当使用EF core 做数据迁移时&#xff0c;需要用到“程序包管理器控制台”&#xff0c;打开失败&#xff0c;前一秒还好好的&#xff0c;怎么突然就用不了了…

excel给数据库初始化/旧数据处理(自动sql拼装)

思路&#xff1a; 首先导出数据到excel编写单条数据操作的sql利用excel CONCATENATE 函数自动生成&#xff0c;每一行数据的操作sql 小技巧:对于需要套娃的字段值&#xff0c;可以加一个临时列同样使用CONCATENATE函数进行sql拼装 案例&#xff1a; 1.临时列:CONCATENATE(C2, …

【计算机网络】网络的网络

网络的网络 客户 customer 接入ISP提供商 provider 全球承载ISP多个ISP的层级结构 第一层ISP &#xff08;tier-1 ISP &#xff09; 位于顶部 区域ISP &#xff08;reginal ISP&#xff09;Level 3通信 &#xff0c;AT&T&#xff0c;Sprint &#xff0c;NTT存在点&#x…

嘿嘿,vue之输出土味情话

有点好玩&#xff0c;记录一下。通过按钮调用网站接口&#xff0c;然后解构数据输出土味情话。 lovetalk.vue: <!--vue简单框架--> <template> <!-- 这是一个div容器&#xff0c;用于显示土味情话 --> <div class"talk"> <!-- 当点…