word2vector训练代码详解

目录

1.代码实现

2.知识点 


 

1.代码实现

#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集  ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5  
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)  
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape   #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])

embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法 

#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center)                 #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1))  #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape

 torch.Size([2, 1, 4])

#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__()  #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1)  #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape

 torch.Size([2, 4])

#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

 tensor([0.9352, 1.8462])

#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))

 

#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)  net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2)   #2种数据需要累加for epoch in range(num_epochs):  #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter)    #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter):   #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播  ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel())   #l.sum()数值求和累加, l.numel()数量累加#   %  取余数      #  //  商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作#  i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:  #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')      
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)

#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重    (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]]     #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]:   #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])

 

W的shape:torch.Size([6719, 100])
x的shape:torch.Size([100])
cos的shape:torch.Size([6719])
cosine sim=0.430:feed
cosine sim=0.418:precious
cosine sim=0.412:drink

2.知识点 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/435948.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《深度学习》卷积神经网络CNN 实现手写数字识别

目录 一、卷积神经网络CNN 1、什么是CNN 2、核心 3、构造 二、案例实现 1、下载训练集、测试集 代码实现如下&#xff1a; 2、展示部分图片 运行结果&#xff1a; 3、图片打包 运行结果&#xff1a; 4、判断当前使用的CPU还是GPU 5、定义卷积神经网络 运行结果&a…

后端-对表格数据进行添加、删除和修改

一、添加 要求&#xff1a; 按下添加按钮出现一个板块输入添加的数据信息&#xff0c;点击板块的添加按钮&#xff0c;添加&#xff1b;点击取消&#xff0c;板块消失。 实现&#xff1a; 1.首先&#xff0c;设计页面输入框格式&#xff0c;表格首行 2.从数据库里调数据 3.添加…

LPDDR4芯片学习(二)——Functional Description

一、LPDDR4寻址表 以每个die容量为4GB为例&#xff1a; Memory density(per channel) 2Gb&#xff1a;每个通道大小为2Gb&#xff0c;一个die有两个通道Configuration 16Mb 16DQ 8 banks 2 channels &#xff1a;16Mb的寻址空间16位每个channels8个bank*每个die两channels。1…

Java基础(Arrays工具类)(asList()方法)(详细)

目录 一、Arrays工具类 &#xff08;1&#xff09;引言 &#xff08;2&#xff09;基本介绍 &#xff08;3&#xff09;主要功能&#xff08;提供的方法&#xff09; &#xff08;I&#xff09;排序&#xff08;Arrays.sort()&#xff09; &#xff08;II&#xff09;搜索(查找…

ECCV 2024 现场:参会者付高价、跨万里,却无法入场?

ECCV&#xff08;European Conference on Computer Vision&#xff0c;欧洲计算机视觉国际会议&#xff09;是计算机视觉领域的重要国际会议之一&#xff0c;与CVPR和ICCV并称为计算机视觉的三大顶级会议。 ECCV2024是该系列会议的第18届会议&#xff0c;2024年9月29日至10月4…

第3篇:常见的Webshell查杀工具----应急响应篇

当网站服务器被入侵时&#xff0c;我们需要一款Webshell检测工具&#xff0c;来帮助我们发现webshell&#xff0c;进一步排查系统可能存在的安全漏洞。 本文推荐了10款Webshll检测工具&#xff0c;用于网站入侵排查。当然&#xff0c;目前市场上的很多主机安全产品也都提供这种…

引入Scrum激发研发体系活力

引言 在当今快速变化的技术环境中&#xff0c;IT企业面临着持续的市场压力和竞争&#xff0c;传统的瀑布式开发模式已经难以满足现代企业的需要。瀑布模型过于僵化&#xff0c;缺乏灵活性&#xff0c;导致项目经常延期&#xff0c;成本增加&#xff0c;最终可能无法达到预期效果…

这款工具在手,前端开发轻松搞定!

这款工具在手&#xff0c;前端开发轻松搞定&#xff01; 引言 在之前的一篇文章中&#xff0c;已经给大家分享了一款AI助手。尽管该助手能够生成前端代码&#xff0c;但遗憾的是缺少了实时预览的功能。而现在&#xff0c;这一缺憾已经被弥补——你只需要描述你的设计想法&…

递归算法介绍和【题解】——数楼梯

递归算法介绍和【题解】——数楼梯 1.递推算法介绍2.数楼梯题目描述输入格式输出格式输入输出样例输入 #1输出 #1 提示 1.思路解析2.AC代码 1.递推算法介绍 有些目标是宏大的&#xff0c;比如如果你想找到一个好工作&#xff0c;需要先把面试通过。要把面试通过&#xff0c;就需…

力扣(leetcode)每日一题 1014 最佳观光组合

题干 1014. 最佳观光组合 给你一个正整数数组 values&#xff0c;其中 values[i] 表示第 i 个观光景点的评分&#xff0c;并且两个景点 i 和 j 之间的 距离 为 j - i。 一对景点&#xff08;i < j&#xff09;组成的观光组合的得分为 values[i] values[j] i - j &#…

总结C/C++中内存区域划分

目录 1.C/C程序内存分配主要的几个区域&#xff1a; 2.内存分布图 1.C/C程序内存分配主要的几个区域&#xff1a; 1、栈区 2、堆区 3、数据段&#xff08;静态区&#xff09; 4.代码段 2.内存分布图 如图&#xff1a; static修饰静态变量成员——放在静态区 int globalVar 是…

uniapp在线打包的ios后调用摄像头失败的解决方法

uniapp在线打包的ios后调用摄像头失败的解决方法 解决方法&#xff1a; 由于未选中打包模块的配置 当你在测试时发现能够正常的开启摄像头&#xff0c;但是当你对其进行在线打包后&#xff0c;发现当你点击启用摄像头时&#xff0c;没有反应&#xff0c;或者是打开是黑屏状态…

《情书》你的名字,是最美的情书

《情书》你的名字&#xff0c;是最美的情书 岩井俊二&#xff0c;日本电影导演&#xff0c;作家及记录片导演。被誉为日本最有潜质的新近“映像作家”&#xff0c;也有中国影迷称他为“日本王家卫”。影像清新独特、感情细腻丰富。&#xff08;来自豆瓣&#xff09; 穆晓芳 译 …

网页WebRTC电话和软电话哪个好用?

关于WebRTC电话与软件电话哪个更好用&#xff0c;这实际上取决于多个因素&#xff0c;并没有一个绝对的答案。不过&#xff0c;我可以根据WebRTC技术的一些特点&#xff0c;以及与传统软件电话相比的优劣势&#xff0c;为你提供一个清晰的对比。 首先&#xff0c;让我们了解一下…

无监督算法目标识别-工业异常检测模型Padim+PatchCore的C++_libtorch实现

基于anomalib的python代码完美复现 示例&#xff1a; 使用无监督算法识别缺陷&#xff1a;图像复杂不能太高&#xff0c;尽量是简单背景的图片&#xff0c;如果太复杂了还是直接上有监督算法识别泛化能力强 代码实现详见&#xff1a;****Gitee

11.全面学习面向对象技术

面向对象开发 相关概念 对象&#xff1a;由数据及其操作所构成的封装体&#xff0c;是系统中用来描述客观事务的一个实体&#xff0c;是构成系统的一个基本单位。一个对象通常可以由对象名、属性和方法3个部分组成。类&#xff1a;现实世界中实体的形式化描述&#xff0c;类…

Chainlit集成LlamaIndex实现知识库高级检索(组合对象检索)

检索原理 对象组合索引的原理 是利用IndexNode索引节点&#xff0c;将两个不同类型的检索器作为节点对象&#xff0c;使用 SummaryIndex &#xff08;它可以用来构建一个包含多个索引节点的索引结构。这种索引通常用于从多个不同的数据源或索引方法中汇总信息&#xff0c;并能…

第18章 中断和异常的处理与抢占式多任务

第18章 中断和异常的处理与抢占式多任务 中断和异常 中断和异常概述 中断&#xff08;Interrupt&#xff09;&#xff1a; 硬件中断是由外围硬件设备发出的中断信号引发的&#xff0c;以请求处理器提供服务。软中断是由int n指令引发的中断处理&#xff0c;n是中断号或者叫…

【Python】数据可视化之分布图

分布图主要用来展示某些现象或数据在地理空间、时间或其他维度上的分布情况。它可以清晰地反映出数据的空间位置、数量、密度等特征&#xff0c;帮助人们更好地理解数据的内在规律和相互关系。 目录 单变量分布 变量关系组图 双变量关系 核密度估计 山脊分布图 单变量分布…

5.数据结构与算法-类C语言的有关操作

元素类型说明 数组定义 C语言的动态内存分配 C动态存储分配 C的参数传递 传值方式 传地址方式 形参变化影响实参 形参变化不影响实参 数组名做参数 引用类型做参数