深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

文章目录

一、前期工作
1. 设置GPU
2. 导入预处理词库类
二、导入预处理词库类
三、参数设定
四、创建模型
五、训练模型函数
六、测试模型函数
七、训练模型与预测

今天给大家带来一个简单的中文新闻分类模型,利用TextCNN模型进行训练,TextCNN的主要流程是:获取文本的局部特征:通过不同的卷积核尺寸来提取文本的N-Gram信息,然后通过最大池化操作来突出各个卷积操作提取的最关键信息,拼接后通过全连接层对特征进行组合,最后通过交叉熵损失函数来训练模型。

                                                                 textCNN的模型架构
注:N-Gram是大词汇连续语音识别中常用的一种语言模型。⼜被称为⼀阶马尔科夫链。它的基本思想是将⽂本⾥⾯的内容按照字节进行大小为 N 的滑动窗⼝操作,形成了长度是 N 的字节⽚段序列。每⼀个字节⽚段称为 gram,对所有的 gram 的出现频度进⾏统计,并且按照事先设定好的阈值进⾏过滤,形成关键 gram 列表,是这个⽂本的向量特征空间。列表中的每⼀种 gram 就是⼀个特征向量维度。

一、前期工作

1. 设置GPU

如果使用的是CPU可以注释掉这部分的代码。

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")#导入库包
import tensorflow.keras as keras
from config import Config
import os
from sklearn import metrics
import numpy as np
from keras.models import Sequential
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding,Dropout,Conv1D,ReLU,GlobalMaxPool1D,InputLayer

2. 导入预处理词库类

trainingSet_path = "cnews.train.txt"
valSet_path = "cnews.val.txt"
model_save_path = "CNN_model.h5"
testingSet_path = "cnews.test.txt"#创建 文本处理类:preprocesser
class preprocesser(object):def __init__(self):self.config = Config()# 读取文本txt 函数def read_txt(self, txt_path):with open(txt_path, "r", encoding='utf-8') as f:data = f.readlines()labels = []contents = []for line in data:label, content = line.strip().split('\t')labels.append(label)contents.append(content)return labels, contents# 读取分词文档def get_vocab_id(self):vocab_path = "cnews.vocab.txt"with open(vocab_path, "r", encoding="utf-8") as f:infile = f.readlines()vocabs = list([word.replace("\n", "") for word in infile])vocabs_dict = dict(zip(vocabs, range(len(vocabs))))return vocabs, vocabs_dict# 获取新闻属性id 函数def get_category_id(self):categories = ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"]cates_dict = dict(zip(categories, range(len(categories))))return cates_dict#将语料中各文本转换成固定max_length后返回各文本的标签与文本tokensdef word2idx(self, txt_path, max_length):# vocabs:分词词汇表# vocabs_dict:各分词的索引vocabs, vocabs_dict = self.get_vocab_id()# cates_dict:各分类的索引cates_dict = self.get_category_id()# 读取语料labels, contents = self.read_txt(txt_path)# labels_idx:用来存放语料中的分类labels_idx = []# contents_idx:用来存放语料中各样本的索引contents_idx = []# 遍历语料for idx in range(len(contents)):# tmp:存放当前语句indextmp = []# 将该idx(样本)的标签加入至labels_idx中labels_idx.append(cates_dict[labels[idx]])# contents[idx]:为该语料中的样本遍历项# 遍历contents中各词并将其转换为索引后加入contents_idx中for word in contents[idx]:if word in vocabs:tmp.append(vocabs_dict[word])else:# 第5000位设置为未知字符tmp.append(5000)# 将该样本index后结果存入contents_idx作为结果等待传回contents_idx.append(tmp)# 将各样本长度pad至max_lengthx_pad = keras.preprocessing.sequence.pad_sequences(contents_idx, max_length)y_pad = keras.utils.to_categorical(labels_idx, num_classes=len(cates_dict))return x_pad, y_paddef word2idx_for_sample(self, sentence, max_length):# vocabs:分词词汇表# vocabs_dict:各分词的索引vocabs, vocabs_dict = self.get_vocab_id()result = []# 遍历语料for word in sentence:# tmp:存放当前语句indexif word in vocabs:result.append(vocabs_dict[word])else:# 第5000位设置为未知字符,实际中为vocabs_dict[5000],使得vocabs_dict长度变成len(vocabs_dict+1)result.append(5000)x_pad = keras.preprocessing.sequence.pad_sequences([result], max_length)return x_padpre = preprocesser() # 实例化preprocesser()类

数据集样式:

 二、参数设定

num_classes = 10     # 类别数
vocab_size = 5000    #语料词大小
seq_length = 600     #词长度conv1_num_filters = 128   # 第一层输入卷积维数
conv1_kernel_size = 1     # 卷积核数conv2_num_filters = 64   # 第二层输入卷维数
conv2_kernel_size = 1    # 卷积核数hidden_dim = 128         # 隐藏层维度
dropout_keep_prob = 0.5  # dropout层丢弃0.5batch_size = 64     # 每次训练批次数  

四、创建模型

def TextCNN():#创建模型序列model = Sequential()model.add(InputLayer((seq_length,)))model.add(Embedding(vocab_size+1, 256, input_length=seq_length))model.add(Conv1D(conv1_num_filters, conv1_kernel_size, padding="SAME"))model.add(Conv1D(conv2_num_filters, conv2_kernel_size, padding="SAME"))model.add(GlobalMaxPool1D())model.add(Dense(hidden_dim))model.add(Dropout(dropout_keep_prob))model.add(ReLU())model.add(Dense(num_classes, activation="softmax"))model.compile(loss="categorical_crossentropy",optimizer="adam",metrics=["acc"])print(model.summary())return model

五、训练模型函数

def train(epochs):model = TextCNN()model.summary()x_train, y_train = pre.word2idx(trainingSet_path, max_length=seq_length)x_val, y_val = pre.word2idx(valSet_path, max_length=seq_length)model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_val, y_val))model.save(model_save_path, overwrite=True)

 六、测试模型函数

def test():if os.path.exists(model_save_path):model = keras.models.load_model(model_save_path)print("-----model loaded-----")model.summary()x_test, y_test = pre.word2idx(testingSet_path, max_length=seq_length)print(x_test.shape)print(type(x_test))print(y_test.shape)# print(type(y_test))pre_test = model.predict(x_test)# print(pre_test.shape)# metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1), digits=4, output_dict=True)print(metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1)))

七、训练模型与预测

if __name__ == '__main__':train(20)  # 训练模型Epoch 1/20
782/782 [==============================] - 119s 152ms/step - loss: 0.7380 - accuracy: 0.7696 - val_loss: 0.5568 - val_accuracy: 0.8334
Epoch 2/20
782/782 [==============================] - 122s 156ms/step - loss: 0.3898 - accuracy: 0.8823 - val_loss: 0.4342 - val_accuracy: 0.8588
Epoch 3/20
782/782 [==============================] - 121s 154ms/step - loss: 0.3382 - accuracy: 0.8979 - val_loss: 0.4154 - val_accuracy: 0.8648
Epoch 4/20
782/782 [==============================] - 116s 148ms/step - loss: 0.3091 - accuracy: 0.9055 - val_loss: 0.4408 - val_accuracy: 0.8688
Epoch 5/20
782/782 [==============================] - 117s 150ms/step - loss: 0.2904 - accuracy: 0.9116 - val_loss: 0.3880 - val_accuracy: 0.8844
Epoch 6/20
782/782 [==============================] - 119s 153ms/step - loss: 0.2724 - accuracy: 0.9153 - val_loss: 0.4412 - val_accuracy: 0.8664
Epoch 7/20
782/782 [==============================] - 117s 149ms/step - loss: 0.2601 - accuracy: 0.9206 - val_loss: 0.4217 - val_accuracy: 0.8726
Epoch 8/20
782/782 [==============================] - 116s 149ms/step - loss: 0.2423 - accuracy: 0.9243 - val_loss: 0.4205 - val_accuracy: 0.8760
Epoch 9/20
782/782 [==============================] - 117s 150ms/step - loss: 0.2346 - accuracy: 0.9275 - val_loss: 0.4022 - val_accuracy: 0.8808
Epoch 10/20
782/782 [==============================] - 116s 148ms/step - loss: 0.2249 - accuracy: 0.9301 - val_loss: 0.4297 - val_accuracy: 0.8726
....model = keras.models.load_model(model_save_path)print("-----model loaded-----")model.summary()test = preprocesser()# 测试文本x_test = '5月6日,上海莘庄基地田径特许赛在第二体育运动学校鸣枪开赛。男子110米栏决赛,19岁崇明小囡秦伟搏以13.35秒的成绩夺冠,创造本赛季亚洲最佳。谢文骏迎来赛季首秀,以13.38秒获得亚军'x_test = test.word2idx_for_sample(x_test, 600)categories = ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"]pre_test = model.predict(x_test)index = int(np.argmax(pre_test, axis=1)[0])print('该新闻为:', categories[index])

训练20次后,训练集损失函数loss: 0.1635 ,训练集准确率:accuracy: 0.9462

验证集函数:val_loss: 0.4554 验证集准确率 val_accuracy: 0.8820

运行结果:该新闻为: 体育

 往期作品:

深度学习实战项目

1.深度学习实战1-(keras框架)企业数据分析与预测

2.深度学习实战2-(keras框架)企业信用评级与预测

3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别

5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目

6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测

7.深度学习实战7-电商产品评论的情感分析

8.深度学习实战8-生活照片转化漫画照片应用

9.深度学习实战9-文本生成图像-本地电脑实现text2img

10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)

11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例

12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正

13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星

14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问

16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别

17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例

18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务

19.深度学习实战19(进阶版)-ChatGPT的本地实现部署测试,自己的平台就可以实现ChatGPT

...(待更新)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/42423.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

各大佬抨击ICML审稿太随意:LeCun三篇全没中,马毅说以后再也不投了

梦晨 发自 凹非寺量子位 | 公众号 QbitAI 顶会ICML结果一出,掀起一片混乱。 LeCun分享自己的“战果”:三篇全没中,推特上就没看到有谁中了的。 UC伯克利教授马毅也表示今后不再投ICML了,评审非常随意、不透明,最大的担…

生物信息经典书籍一揽子推荐!

每次遇到有人问怎么学生信时,总会碰到一个尴尬的问题,“有没有什么书可以推荐下”。骤然之下,也不知道该怎么回答。这个问题有点大,想不到一本书可以囊括。而且对提问人的基础和学习倾向没有了解,与其指条错路&#xf…

AI 建模师 素养手冊(4)

从深度神经网络 (Deep Neural Network) 认识隐藏空间(Latent spaces) By 高焕堂 / AI 建模师俱乐部会长 文章目录 前言一、隐藏空间(Latent spaces)的特质二、从<单层 NN 模型>说起 三、多层的 NN 模型 四、隐藏空间在 AIGC 领域的角总结 前言 *** 本文摘自 高焕堂 的下…

ChatGLM 微调实战

在之前的文章中&#xff0c;我们已经讲过了 ChatGPT 的三个主要流程&#xff1a; SFT&#xff1a;通过 Instruction Tuning 来微调一个监督学习模型。Reward Model&#xff1a;通过排序序列来训练一个打分模型。Reinforcement Learning&#xff1a;通过强化学习来进一步优化模…

【Instruction Tuning】ChatGLM 微调实战(附源码)

在之前的文章中&#xff0c;我们已经讲过了 ChatGPT 的三个主要流程&#xff1a; SFT&#xff1a;通过 Instruction Tuning 来微调一个监督学习模型。Reward Model&#xff1a;通过排序序列来训练一个打分模型。Reinforcement Learning&#xff1a;通过强化学习来进一步优化模…

王者荣耀 业务分析

王者荣耀 业务分析 王者荣耀是一款组队竞技游戏。王者荣耀有三个基本玩家财富字段:金币、钻石、点券。 英雄的获取方式有:限免的英雄可用金币或钻石购买&#xff0c;有一些非限免的可用点券购买&#xff0c;一些英雄可通过特定的游戏活动获得。并结合了游戏奖励机制。 它的主要…

王者荣耀游戏时间计算机制,王者荣耀:荣耀战力计算机制解析,如何提高到金牌乃至国服...

原标题&#xff1a;王者荣耀&#xff1a;荣耀战力计算机制解析&#xff0c;如何提高到金牌乃至国服 每个英雄都对应有一个荣耀战力&#xff0c;战力分越高一定程度上代表着你对该英雄的理解越好&#xff0c;当分数高到一定程度上时&#xff0c;会获得荣耀称号&#xff0c;由低到…

c语言王者荣耀制作,易语言制作王者荣耀刷金币脚本的代码

打开黑夜模拟器&#xff0c;按下F10&#xff0c;王者荣耀进入挑战-魔女回忆&#xff0c;开始即可。 王者荣耀刷金币脚本 此功能需要加载精易模块5.6 .版本 2 .支持库 shellEx .支持库 EThread .支持库 eAPI .程序集 窗口程序集_启动窗口 .程序集变量 热键F10, 整数型 .程序集变…

王者荣耀各服务器位置,盘点王者荣耀各位置国服战力排名,辅助榜表示只有富婆才玩的懂...

原标题&#xff1a;盘点王者荣耀各位置国服战力排名&#xff0c;辅助榜表示只有富婆才玩的懂 七月份也是终于过去了&#xff0c;而国服最新的战力榜也是新鲜出炉&#xff0c;当然小伙伴们可能觉得这些跟自己没什么关系&#xff0c;正所谓内行看门道&#xff0c;外行看热闹。今天…

王者营地登录服务器维护,王者营地怎么查看登录记录

王者营地怎么升级到5级 王者营地怎么升级到5级?具体的升级方法是什么呢?很多小伙伴还不是很了解,那么接下来,就跟随玩游戏网的小编一起继续往下看,感兴趣的小伙伴一定不要错过哦!信誉快速升到5级方法介绍:1、多参与游戏,这个是基础,不然的话信誉积分肯定是无法增加的;…

王者荣耀服务器什么时候维护好19赛季,王者荣耀:S19新赛季开启时间确定,国服战力排名会提前锁定...

S19新赛季的开启时间已经确定&#xff0c;正是之前无意间透露出来的3月31日&#xff0c;正好是周二&#xff0c;符合惯例更新的时间点&#xff0c;但并不是以往大版本更新的周四&#xff0c;为什么会出现这样的情况呢&#xff1f;或许是因为从这一次的苹果的审核比较快吧。 在S…

王者荣耀前端模仿

作品初衷 因为想着要写答辩&#xff0c;网页这东西展现的会更直观一点&#xff0c;刚好也复习一下自己基础我牢固的前端知识&#xff0c;也想着提升一下自己前端网页排版能力和逻辑能力 作品简介 编辑器&#xff1a;sublime 语言&#xff1a;htmlcssjs 代码量&#xff1a;…

鸿蒙os版王者荣耀,王者荣耀鸿蒙版

王者荣耀鸿蒙版下载&#xff0c;快猴网为大家带来的王者荣耀鸿蒙版是为了适配华为的鸿蒙系统而特别设立的版本&#xff0c;玩家可以体验远超一般系统的流畅度&#xff0c;让你的手速能轻松跟上你的意识&#xff0c;享受成为王者荣耀鸿蒙版最强王者的快乐吧! 王者荣耀鸿蒙版游戏…

王者荣耀技术指南

者荣耀技术指南&#xff1a; 本人是在2018年9月和10月都获得国服达摩称号&#xff0c;其实就打了一个月国服达摩&#xff0c;由于9月底正好赛季更新&#xff0c;我月底就打了一把10月份国服达摩了。 先讲一下达摩基本连招&#xff1a; 1闪现a3上墙2破甲3第二段贴近敌人&#xf…

国服最强王者之最良心王者

国服最强王者之最良心王者 何为lol最强王者?想成为最强王者&#xff0c;你必须是所在服务器的前50名&#xff0c;这是绝对实力的体现。并且还会面临着紧随其后的玩家的挑战&#xff0c;若有第51个人挑战”最强王者“成功&#xff0c;则原属于”最强王者“最后一名次的玩家或战…

王者荣耀服务器维护S19,王者荣耀:S19官宣31日更新,国服玩家集体声讨天美:1个月白打了...

前言&#xff1a;就在2020年3月28日深夜&#xff0c;王者荣耀官方官宣了一则重大消息&#xff0c;那就是S19赛季即将在2020年3月31日更新。也就是说今年新赛季的到来真是让人猝不及防。但是新赛季的到来&#xff0c;也意味着众多小伙伴们又有一波福利可以领了。不过几家欢喜几家…

王者荣耀各服务器位置,王者荣耀全国排行功能新上线,位置战力系统介绍[多图]...

王者荣耀全国排行功能新上线。在最近的体验服当中王者荣耀官方又尝试进行了多项改动&#xff0c;在原本的省级和国服最强之间增添了全国排名这一级别的荣耀称号&#xff0c;新添加了根据分路了来判断的位置战力和位置段位系统&#xff0c;同时还改版了国服最强的展示页面。 1.荣…

王者荣耀微信哪个服务器怎么选,王者荣耀:国服战力对比!手Q和微信哪个大区的战力更胜一筹?...

原标题&#xff1a;王者荣耀&#xff1a;国服战力对比&#xff01;手Q和微信哪个大区的战力更胜一筹&#xff1f; 每个玩家都希望自己能够获得国服称号&#xff0c;想要获得的难度却是非常大的&#xff0c;但这也不妨玩家们对国服最高战力的讨论。 在王者荣耀里将国服战力分成了…

如何获取宝宝取名软件注册码

如何获取宝宝取名软件注册码 宝宝的名字不仅要伴随宝宝的一生&#xff0c;还具有改变命运的重大意义。所以现在家长都十分重视给孩子起名&#xff0c;现在很多父母给孩子取名上网搜索、翻阅字典、参考各种资料&#xff0c;取得名字很好听&#xff0c;孩子不一定能用。主要是怕名…

宝宝智能起名,免费起名实现方案,带源码

宝宝智能起名&#xff0c;免费起名实现方案&#xff01;&#xff08;带源码&#xff09; 作为一个coder&#xff0c;没有什么问题是代码解决不了的&#xff0c;哈哈哈&#xff01;调皮一下&#xff01; 一个好的名字可以潜移默化的改变人的一生&#xff0c;很多人从来没有考虑到…