NLP之LSTM与BiLSTM

文章目录

  • 代码展示
  • 代码解读
  • 双向LSTM介绍(BiLSTM)

代码展示

import pandas as pd
import tensorflow as tf
tf.random.set_seed(1)
df = pd.read_csv("../data/Clothing Reviews.csv")
print(df.info())df['Review Text'] = df['Review Text'].astype(str)
x_train = df['Review Text']
y_train = df['Rating']
print(y_train.unique())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23486 entries, 0 to 23485
Data columns (total 11 columns):#   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 0   Unnamed: 0               23486 non-null  int64 1   Clothing ID              23486 non-null  int64 2   Age                      23486 non-null  int64 3   Title                    19676 non-null  object4   Review Text              22641 non-null  object5   Rating                   23486 non-null  int64 6   Recommended IND          23486 non-null  int64 7   Positive Feedback Count  23486 non-null  int64 8   Division Name            23472 non-null  object9   Department Name          23472 non-null  object10  Class Name               23472 non-null  object
[4 5 3 2 1]
from tensorflow.keras.preprocessing.text import Tokenizerdict_size = 14848
tokenizer = Tokenizer(num_words=dict_size)tokenizer.fit_on_texts(x_train)
print(len(tokenizer.word_index),tokenizer.index_word)x_train_tokenized = tokenizer.texts_to_sequences(x_train)
from tensorflow.keras.preprocessing.sequence import pad_sequences
max_comment_length = 120
x_train = pad_sequences(x_train_tokenized,maxlen=max_comment_length)for v in x_train[:10]:print(v,len(v))
# 构建RNN神经网络
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tfrnn = Sequential()
# 对于rnn来说首先进行词向量的操作
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))
# RNN:simple_rnn (SimpleRNN)  (None, 100)   16100
# LSTM:simple_rnn (SimpleRNN)  (None, 100)  64400
rnn.add(Bidirectional(LSTM(units=100)))  # 第二层构建了100个RNN神经元
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))  # 输出分类的结果
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])
print(rnn.summary())
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)
print(result)
print(result.history)

代码解读

首先,我们来总结这段代码的流程:

  1. 导入了必要的TensorFlow Keras模块。
  2. 初始化了一个Sequential模型,这表示我们的模型会按顺序堆叠各层。
  3. 添加了一个Embedding层,用于将整数索引(对应词汇)转换为密集向量。
  4. 添加了一个双向LSTM层,其中包含100个神经元。
  5. 添加了两个Dense全连接层,分别包含10个和6个神经元。
  6. 使用sparse_categorical_crossentropy损失函数编译了模型。
  7. 打印了模型的摘要。
  8. 使用给定的训练数据和验证数据对模型进行了训练。
  9. 打印了训练的结果。

现在,让我们逐行解读代码:

  1. 导入依赖:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tf

你导入了创建和训练RNN模型所需的TensorFlow Keras库。

  1. 初始化模型:
rnn = Sequential()

你选择了一个顺序模型,这意味着你可以简单地按顺序添加层。

  1. 添加Embedding层:
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))

此层将整数索引转换为固定大小的向量。dict_size是词汇表的大小,max_comment_length是输入评论的最大长度。

  1. 添加LSTM层:
rnn.add(Bidirectional(LSTM(units=100)))

你选择了双向LSTM,这意味着它会考虑过去和未来的信息。它有100个神经元。

  1. 添加全连接层:
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))

这两个Dense层用于模型的输出,最后一层使用softmax激活函数进行6类的分类。

  1. 编译模型:
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])

你选择了一个适合分类问题的损失函数,并选择了adam优化器。

  1. 显示模型摘要:
print(rnn.summary())

这将展示模型的结构和参数数量。

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================embedding (Embedding)       (None, 120, 60)           890880    bidirectional (Bidirectiona  (None, 200)              128800    l)                                                              dense (Dense)               (None, 10)                2010      dense_1 (Dense)             (None, 6)                 66        =================================================================
Total params: 1,021,756
Trainable params: 1,021,756
Non-trainable params: 0
_________________________________________________________________
None
  1. 训练模型:
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)

你用训练数据集训练了模型,其中30%的数据用作验证,训练了10个周期。

Epoch 1/10
257/257 [==============================] - 74s 258ms/step - loss: 1.2142 - accuracy: 0.5470 - val_loss: 1.0998 - val_accuracy: 0.5521
Epoch 2/10
257/257 [==============================] - 57s 221ms/step - loss: 0.9335 - accuracy: 0.6293 - val_loss: 0.9554 - val_accuracy: 0.6094
Epoch 3/10
257/257 [==============================] - 59s 229ms/step - loss: 0.8363 - accuracy: 0.6616 - val_loss: 0.9321 - val_accuracy: 0.6168
Epoch 4/10
257/257 [==============================] - 61s 236ms/step - loss: 0.7795 - accuracy: 0.6833 - val_loss: 0.9812 - val_accuracy: 0.6089
Epoch 5/10
257/257 [==============================] - 56s 217ms/step - loss: 0.7281 - accuracy: 0.7010 - val_loss: 0.9559 - val_accuracy: 0.6043
Epoch 6/10
257/257 [==============================] - 56s 219ms/step - loss: 0.6934 - accuracy: 0.7156 - val_loss: 1.0197 - val_accuracy: 0.5999
Epoch 7/10
257/257 [==============================] - 57s 220ms/step - loss: 0.6514 - accuracy: 0.7364 - val_loss: 1.1192 - val_accuracy: 0.6080
Epoch 8/10
257/257 [==============================] - 57s 222ms/step - loss: 0.6258 - accuracy: 0.7486 - val_loss: 1.1350 - val_accuracy: 0.6100
Epoch 9/10
257/257 [==============================] - 57s 220ms/step - loss: 0.5839 - accuracy: 0.7749 - val_loss: 1.1537 - val_accuracy: 0.6019
Epoch 10/10
257/257 [==============================] - 57s 222ms/step - loss: 0.5424 - accuracy: 0.7945 - val_loss: 1.1715 - val_accuracy: 0.5744
<keras.callbacks.History object at 0x00000244DCE06D90>
  1. 显示训练结果:
print(result)
<keras.callbacks.History object at 0x0000013AEAAE1A30>
print(result.history)
{'loss': [1.2142471075057983, 0.9334620833396912, 0.8363043069839478, 0.7795010805130005, 0.7280740141868591, 0.693393349647522, 0.6514003872871399, 0.6257606744766235, 0.5839114189147949, 0.5423741340637207], 
'accuracy': [0.5469586253166199, 0.6292579174041748, 0.6616179943084717, 0.6833333373069763, 0.7010340690612793, 0.7156326174736023, 0.7363746762275696, 0.748600959777832, 0.7748783230781555, 0.7944647073745728], 
'val_loss': [1.0997602939605713, 0.9553984999656677, 0.932131290435791, 0.9812102317810059, 0.9558586478233337, 1.019730806350708, 1.11918044090271, 1.1349923610687256, 1.1536787748336792, 1.1715185642242432], 
'val_accuracy': [0.5520862936973572, 0.609423816204071, 0.6168038845062256, 0.6088560819625854, 0.6043145060539246, 0.5999148488044739, 0.6080045700073242, 0.6099914908409119, 0.6019017696380615, 0.574368417263031]
}

这将展示训练过程中的损失和准确性等信息。

双向LSTM介绍(BiLSTM)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
例子:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/180528.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计算机网络实验/wireshark】tcp建立和释放

wireshark开始捕获后&#xff0c;浏览器打开xg.swjtu.edu.cn&#xff0c;网页传输完成后&#xff0c;关闭浏览器&#xff0c;然后停止报文捕获。 若捕获不到dns报文&#xff0c;先运行ipconfig/flushdns命令清空dns缓存 DNS报文 设置了筛选条件&#xff1a;dns 查询报文目的…

17、Flink 之Table API: Table API 支持的操作(1)

Flink 系列文章 1、Flink 部署、概念介绍、source、transformation、sink使用示例、四大基石介绍和示例等系列综合文章链接 13、Flink 的table api与sql的基本概念、通用api介绍及入门示例 14、Flink 的table api与sql之数据类型: 内置数据类型以及它们的属性 15、Flink 的ta…

代码随想录训练营第60天 | 503.下一个更大元素II ● 42. 接雨水● 84.柱状图中的最大矩形

503.下一个更大元素II 题目链接&#xff1a;https://leetcode.com/problems/next-greater-element-ii/ 解法&#xff1a; 由于是循环数组&#xff0c;可以直接把两个数组拼接在一起&#xff0c;然后使用单调栈求下一个最大值。 写法上&#xff0c;可以巧妙一些&#xff0c…

【马蹄集】—— 百度之星 2023

百度之星 2023 目录 BD202301 公园⭐BD202302 蛋糕划分⭐⭐⭐BD202303 第五维度⭐⭐ BD202301 公园⭐ 难度&#xff1a;钻石    时间限制&#xff1a;1秒    占用内存&#xff1a;64M 题目描述 今天是六一节&#xff0c;小度去公园玩&#xff0c;公园一共 N N N 个景点&am…

快速灵敏的 Flink1

一、flink单机安装 1、解压 tar -zxvf ./flink-1.13.2-bin-scala_2.12.tgz -C /opt/soft/ 2、改名字 mv ./flink-1.13.2/ ./flink1132 3、profile配置 #FLINK export FLINK_HOME/opt/soft/flink1132 export PATH$FLINK_HOME/bin:$PATH 4、查看版本 flink --version 5、…

轻量封装WebGPU渲染系统示例<14>- 多线程模型载入(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/main/src/voxgpu/sample/ModelLoadTest.ts 此示例渲染系统实现的特性: 1. 用户态与系统态隔离。 细节请见&#xff1a;引擎系统设计思路 - 用户态与系统态隔离-CSDN博客 2. 高频调用与低频调用隔离。 …

C语言--判断一个年份是否是闰年(详解)

一.闰年的定义 闰年是指在公历&#xff08;格里高利历&#xff09;中&#xff0c;年份可以被4整除但不能被100整除的年份&#xff0c;或者可以被400整除的年份。简单来说&#xff0c;闰年是一个比平年多出一天的年份&#xff0c;即2月有29天。闰年的目的是校准公历与地球公转周…

CH10_简化条件逻辑

分解条件表达式&#xff08;Decompose Conditional&#xff09; if (!aDate.isBefore(plan.summerStart) && !aDate.isAfter(plan.summerEnd))charge quantity * plan.summerRate; elsecharge quantity * plan.regularRate plan.regularServiceCharge;if (summer())…

【蓝桥杯省赛真题42】Scratch舞台特效 蓝桥杯少儿编程scratch图形化编程 蓝桥杯省赛真题讲解

目录 scratch舞台特效 一、题目要求 编程实现 二、案例分析 1、角色分析

【移远QuecPython】EC800M物联网开发板的内置GNSS定位的恶性BUG(目前没有完全的解决方案)

【移远QuecPython】EC800M物联网开发板的内置GNSS定位的恶性BUG&#xff08;目前没有完全的解决方案&#xff09; GNSS配置如下&#xff1a; 【移远QuecPython】EC800M物联网开发板的内置GNSS定位获取&#xff08;北斗、GPS和GNSS&#xff09; 测试视频&#xff08;包括BUG复…

Iceberg教程

目录 教程来源于尚硅谷1. 简介1.1 概述1.2 特性 2. 存储结构2.1 数据文件(data files)2.2 表快照(Snapshot)2.3 清单列表(Manifest list)2.4 清单文件(Manifest file)2.5 查询流程分析 3. 与Flink集成3.1 环境准备3.1.1 安装Flink3.1.2 启动Sql-Client 3.2 语法 教程来源于尚硅…

【RabbitMQ】RabbitMQ 消息的可靠性 —— 生产者和消费者消息的确认,消息的持久化以及消费失败的重试机制

文章目录 前言&#xff1a;消息的可靠性问题一、生产者消息的确认1.1 生产者确认机制1.2 实现生产者消息的确认1.3 验证生产者消息的确认 二、消息的持久化2.1 演示消息的丢失2.2 声明持久化的交换机和队列2.3 发送持久化的消息 三、消费者消息的确认3.1 配置消费者消息确认3.2…

Git从基础到实践

1.Git是用来做什么的&#xff1f; git就是一款版本控制软件&#xff0c;主要面向代码的管理。你可以理解为Git是一个代码的备份器&#xff0c;给你的每一次修改后的代码做个备份&#xff0c;防止丢失&#xff0c;这个是git最基本的功能。 其次,git不止备份,当你需要比对多…

NEFU数字图像处理(5)图像压缩编码

一、概述 1.1简介 图像压缩编码的过程是在图像存储或传输之前进行&#xff0c;然后再由压缩后的图像数据&#xff08;编码数据&#xff09;恢复出原始图像或者是原始图像的近似图像 无损压缩&#xff1a;在压缩过程中没有信息损失&#xff0c;可由编码数据完全恢复出原始图像有…

iOS App Store上传项目报错 缺少隐私政策网址(URL)解决方法

​ 一、问题如下图所示&#xff1a; ​ 二、解决办法&#xff1a;使用Google浏览器&#xff08;翻译成中文&#xff09;直接打开该网址 https://www.freeprivacypolicy.com/free-privacy-policy-generator.php 按照要求填写APP信息&#xff0c;最后将生成的网址复制粘贴到隐…

【SOC基础】单片机学习案例汇总 Part2:蜂鸣器、数码管显示

&#x1f4e2;&#xff1a;如果你也对机器人、人工智能感兴趣&#xff0c;看来我们志同道合✨ &#x1f4e2;&#xff1a;不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 &#x1f4e2;&#xff1a;文章若有幸对你有帮助&#xff0c;可点赞 &#x1f44d;…

xilinx fpga ddr mig axi

硬件 参考&#xff1a; https://zhuanlan.zhihu.com/p/97491454 https://blog.csdn.net/qq_22222449/article/details/106492469 https://zhuanlan.zhihu.com/p/26327347 https://zhuanlan.zhihu.com/p/582524766 包括野火、正点原子的资料 一片内存是 1Gbit 128MByte 16bit …

【wp】2023鹏城杯初赛 Web web1(反序列化漏洞)

考点&#xff1a; 常规的PHP反序列化漏洞双写绕过waf 签到题 源码&#xff1a; <?php show_source(__FILE__); error_reporting(0); class Hacker{private $exp;private $cmd;public function __toString(){call_user_func(system, "cat /flag");} }class A {p…

Spring基础

文章目录 Spring基础IoC容器基础IoC理论第一个Spring程序Bean注册与配置依赖注入自动装配生命周期与继承工厂模式和工厂Bean注解开发 AOP面向切片配置实现AOP接口实现AOP注解实现AOP Spring基础 Spring是为了简化开发而生&#xff0c;它是轻量级的IoC和AOP的容器框架&#xff…

I/O多路转接之select

承接上文&#xff1a;I/O模型之非阻塞IO-CSDN博客 简介 select函数原型介绍使用 一个select简单的服务器的代码书写 select的缺点 初识select 系统提供select函数来实现多路复用输入/输出模型 select系统调用是用来让我们的程序监视多个文件描述符的状态变化的; 程序会停在s…