第N6周:使用Word2vec实现文本分类

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
deviceimport pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()# 构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)# 将文本转化为向量
def average_vec(text):vec =np.zeros(100).reshape((1,100))for word in text:try:vec +=w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)label_name =list(set(train_data[1].values[:]))
print(label_name)text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)text_pipeline("你在干嘛")
label_pipeline("Travel-Query")from torch.utils.data import DataLoader
def collate_batch(batch):label_list,text_list=[],[]for(_text,_label)in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)from torch import nn
class TextclassificationModel(nn.Module):def __init__(self,num_class):super(TextclassificationModel,self).__init__()self.fc = nn.Linear(100,num_class)def forward(self,text):return self.fc(text)num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)import time
def train(dataloader):model.train()#切换为训练模式total_acc,train_loss,total_count =0,0,0log_interval=50start_time= time.time()for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)# grad属性归零optimizer.zero_grad()loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,labelloss.backward()#反向传播torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪optimizer.step()#每一步自动更新#记录acc与losstotal_acc+=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval==0 and idx>0:elapsed =time.time()-start_timeprint('Iepoch {:1d}I{:4d}/{:4d} batches''|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count =0,0,0start_time = time.time()
def evaluate(dataloader):model.eval()#切换为测试模式total_acc,train_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)loss = criterion(predicted_label,label)# 计算loss值# 记录测试数据total_acc+=(predicted_label.argmax(1)== label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr =optimizer.state_dict()['param_groups'][0]['1r']if total_accu is not None and total_accu>val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('|epoch {:1d}|time:{:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-'*69)# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.120211511.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.805954991.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.996075184.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646-0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.63949711.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986-1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518-0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139-2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857-0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.279315671.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217-1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.222498812.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.277228682.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259-1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.808146040.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.835356470.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/295852.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Golang面试系列3-内存管理

3.1 内存分配机制 Go内存管理本质上是一个经过内部优化的内存池:自动伸缩内存池大小,合理的切割内存块。 分配逻辑:针对不同大小对象有不同的分配逻辑 (0,16B)且不含指针的对象:Tiny分配(0,16B)且含指针的对象:正常…

推荐多样性 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C++ 题目描述 推荐多样性需要从多个列表中选择元素,一次性要返回N屏数据(窗口数量),每屏展示K个元素(窗口大小),选择策略: 各个列表元素需要做穿插处理,即先从第一个列表中为每屏选择一个元素,再从第二个列表…

k8s 基础入门

1.namespace k8s中的namespace和docker中namespace是两码事,可以理解为k8s中的namespace是为了多租户,dockers中的namespace是为了网络、资源等隔离 2.deployment kubectl create #新建 kubectl aply #新建 更新 升级: 滚动升级&#x…

pygame--坦克大战(一)

项目搭建 本游戏主要分为两个对象,分别是我方坦克和敌方坦克。用户可以通过控制我方的坦克来摧毁敌方的坦克保护自己的“家”,把所有的敌方坦克消灭完达到胜利。敌方的坦克在初始的时候是默认5个的(这可以自己设置),当然,如果我方坦克被敌方坦克的子弹打中,游戏结束。从…

蓝色wordpress外贸建站模板

蓝色wordpress外贸建站模板 https://www.mymoban.com/wordpress/7.html

Ansys Zemax | 如何将光栅数据从Lumerical导入至OpticStudio(上)

附件下载 联系工作人员获取附件 本文介绍了一种使用Ansys Zemax OpticStudio和Lumerical RCWA在整个光学系统中精确仿真1D/2D光栅的静态工作流程。将首先简要介绍方法。然后解释有关如何建立系统的详细信息。 本篇内容将分为上下两部分,上部将首先简要介绍方法工…

C++——异常机制

目录 一,背景 1.1 C语言处理错误的方式 1.2 C异常概念 二,异常的使用 2.1 异常的简单使用 2.2 异常的匹配原则 2.3 异常抛对象 2.4 异常的重新抛出 2.5 异常安全 三,自定义异常体系 四,异常优缺点 4.1 优点 4.2 缺点 …

声音文件格式有哪几种?常见的声音格式和转换方法分享

随着数字技术的飞速发展,声音文件已经成为了我们日常生活和工作中不可或缺的一部分。无论是音乐、电影、游戏还是各类应用程序,声音文件都扮演着重要的角色。本文将为大家介绍常见的声音文件格式以及如何进行格式转换。 一、常见的声音文件格式 &#x…

shopee虾皮业绩一直没办法提升?不同时期要有不同的运营思路

店铺运营“开荒期”需要根据自身店铺数据调整运营策略,“运营期”就需要更多分析竞品的运营数据,分析接近上架时间段的出单同款/相似款,有效找到影响起量的因素;在出单缓慢,接近瓶颈期时找同行的策略方案,抓…

filetype: python中判断图像格式库imghdr替代库

引言 imghdr库是python中的一个内置库,用来判断图像原本格式的。自己一直有在用,不过近来看到这个库在python 3.13中会被移除。 自己感觉一直被python版本赶着走。这不找了好久,才找到一个替代库–filetype Python各个版本将要移除和可替代…

后台返回数据需要自己匹配图标,图标命名与后台返回的变量保持一致

testItemId为后台返回匹配图标的变量名 sportsTargetsData:{suggestSportTargetId: "2",unlocks: [{ testItemId: vo2max_high_knee, sportTargetName: 心肺能力, indexName: 心肺能力, sportTargetId: 1 },{ testItemId: grip_strength, sportTargetName: 基础力量…

第12章 集合框架

一 集合框架概述 1.1 生活中的容器 1.2 数组的特点与弊端 一方面,面向对象语言对事物的体现都是以对象的形式,为了方便对多个对象的操作,就要对对象进行存储。另一方面,使用数组存储对象方面具有一些弊端,而Java 集合…

Linux学习笔记————C 语言版 LED 灯实验

这里写目录标题 一、实验程序编写二、 汇编部分实验程序编写三、C 语言部分实验程序编写四、编译下载验证 汇编 LED 灯实验中,我们讲解了如何使用汇编来编写 LED 灯驱动,实际工作中是很少用到汇编去写嵌入式驱动的,毕竟汇编太难,而…

Java中的可变字符串

Java中的可变字符串 一、什么是可变字符串二、可变字符串的使用场景以及使用步骤1.新建一个可变字符串2.可变字符串的一系列方法 一、什么是可变字符串 可变字符串是Java.lang包下的 在我们学习到JDBC的时候需要将原有的sql语句根据不同的差异添加一段新的关键字或者单词&…

Spring Boot 学习(2)——HelloWorld

HelloWorld!全宇宙码农的第一个(行)程序(代码)。 1、创建项目 打开idea,新建一个maven项目。 1)选择项目sdk(本例是1.8) 2)输入GroupId(co…

DDL ---- 数据库的操作

1.查询所有数据库 show databases; 上图除了自创的,其他的四个都是mysql自带的数据库 。(不区分大小写) 2.查询当前数据库 select database(); 最开始没有使用数据库,那么查找结果为NULL 所以我们就需要先使用数据库&#xff…

Spring Boot 配置文件

1. 配置文件的作用 配置文件主要是为了解决硬件编码带来的问题,把可能会发生改变的信息,放在一个集中的地方,当我们启动某个程序时,程序从配置文件中读取一些数据,并加载运行。 硬编码是将数据直接放在源代码中&…

【HTML】标签学习(下.2)

(大家好哇,今天我们将继续来学习HTML(下.2)的相关知识,大家可以在评论区进行互动答疑哦~加油!💕) 目录 二.列表标签 2.1 无序列表(重点) 2.2有序列表(理解) 2.3 自定义列表(重点…

数字信号处理实验操作教程:3-3 mp3音频编码实验(AD7606采集)

一、实验目的 学习AD7606采集音频数据的方法并掌握MP3音频编码的原理,并实现AD7606采集音频数据进行MP3编码并保存到SD卡。 二、实验原理 AD7606原理图 硬件原理图,找到AD采集,可查看相关控制引脚,同时可看到ADC输入的V1V8通道…

稀碎从零算法笔记Day37-LeetCode:所有可能的真二叉树

今天的每日一题,感觉理解的还不够深,有待加深理解 题型:树、分治、递归 链接:894. 所有可能的真二叉树 - 力扣(LeetCode) 来源:LeetCode 题目描述 给你一个整数 n ,请你找出所有…