copy is all you need前向绘图 和疑惑标记

疑惑的起因
在这里插入图片描述
简化前向图
在这里插入图片描述
GPT4解释

这段代码实现了一个神经网络模型,包含了BERT、GPT-2和MLP等模块。主要功能是给定一个文本序列和一个查询序列,预测查询序列中的起始和结束位置,使其对应文本序列中的一个短语。具体实现细节如下:

首先,使用BERT将文本序列编码成一个张量作为输入,并从GPT-2模型中提取查询序列的表示。然后,将查询序列的表示拆分为起始和结束两部分,并使用MLP将短语的起始和结束点表示投影到另一个空间中。
接下来,利用生成式模型计算每个查询起始和结束位置之间的所有令牌的概率分布,并根据此生成预测位置。使用交叉熵损失函数来计算这些预测和真实值之间的差距。
最后,计算起始和结束位置的准确性和整个短语的准确性,并返回损失和准确性的值来优化模型的性能。

import torch
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel
from torch import nn
from torch.nn import functional as Fclass Copyisallyouneed(nn.Module):def __init__(self, **args):super(Copyisallyouneed, self).__init__()self.args = args# bert-encoder modelself.phrase_encoder = AutoModel.from_pretrained(self.args['phrase_encoder_model'][self.args['lang']])self.bert_tokenizer = AutoTokenizer.from_pretrained(self.args['phrase_encoder_tokenizer'][self.args['lang']])self.bert_tokenizer.add_tokens(['<|endoftext|>', '[PREFIX]'])self.prefix_token_id = self.bert_tokenizer.convert_tokens_to_ids('[PREFIX]')self.phrase_encoder.resize_token_embeddings(self.phrase_encoder.config.vocab_size + 2)# model and tokenizerself.tokenizer = AutoTokenizer.from_pretrained(self.args['prefix_encoder_tokenizer'][self.args['lang']])self.vocab_size = len(self.tokenizer)self.pad = self.tokenizer.pad_token_id if self.args['lang'] == 'zh' else self.tokenizer.bos_token_idself.model = GPT2LMHeadModel.from_pretrained(self.args['prefix_encoder_model'][self.args['lang']])self.token_embeddings = nn.Parameter(list(self.model.lm_head.parameters())[0])# MLP: mapping bert phrase start representationsself.s_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))# MLP: mapping bert phrase end representationsself.e_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))self.gen_loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad)@torch.no_grad()def get_query_rep(self, ids):self.eval()output = self.model(input_ids=ids, output_hidden_states=True)['hidden_states'][-1][:, -1, :]return outputdef get_token_loss(self, ids, hs, ids_mask):# no pad tokenlabel = ids[:, 1:]logits = torch.matmul(hs[:, :-1, :],self.token_embeddings.t())# TODO: inner loss function remove the temperature factorlogits /= self.args['temp']loss = self.gen_loss_fct(logits.view(-1, logits.size(-1)), label.reshape(-1))chosen_tokens = torch.max(logits, dim=-1)[1]gen_acc = (chosen_tokens.reshape(-1) == label.reshape(-1)).to(torch.long)valid_mask = (label != self.pad).reshape(-1)valid_tokens = gen_acc & valid_maskgen_acc = valid_tokens.sum().item() / valid_mask.sum().item()return loss, gen_accdef forward(self, batch):## gpt2 query encoderids, ids_mask = batch['gpt2_ids'], batch['gpt2_mask']last_hidden_states = \self.model(input_ids=ids, attention_mask=ids_mask, output_hidden_states=True).hidden_states[-1]# get token lossloss_0, acc_0 = self.get_token_loss(ids, last_hidden_states, ids_mask)## encode the document with the BERT encoder modeldids, dids_mask = batch['bert_ids'], batch['bert_mask']output = self.phrase_encoder(dids, dids_mask, output_hidden_states=True)['hidden_states'][-1]  # [B, S, E]# collect the phrase start representations and phrase end representationss_rep = self.s_proj(output)e_rep = self.e_proj(output)s_rep = s_rep.reshape(-1, s_rep.size(-1))e_rep = e_rep.reshape(-1, e_rep.size(-1))  # [B_doc*S_doc, 768//2]# collect the query representationsquery = last_hidden_states[:, :-1].reshape(-1, last_hidden_states.size(-1))query_start = query[:, :self.model.config.hidden_size // 2]query_end = query[:, self.model.config.hidden_size // 2:]# training the representations of the start tokenscandidate_reps = torch.cat([self.token_embeddings[:, :self.model.config.hidden_size // 2],s_rep], dim=0)logits = torch.matmul(query_start, candidate_reps.t())logits /= self.args['temp']# build the padding mask for query sidequery_padding_mask = ids_mask[:, :-1].reshape(-1).to(torch.bool)# build the padding mask: 1 for valid and 0 for maskattention_mask = (dids_mask.reshape(1, -1).to(torch.bool)).to(torch.long)padding_mask = torch.ones_like(logits).to(torch.long)# Santiy check overpadding_mask[:, self.vocab_size:] = attention_mask# build the position mask: 1 for valid and 0 for maskpos_mask = batch['pos_mask']start_labels, end_labels = batch['start_labels'][:, 1:].reshape(-1), batch['end_labels'][:, 1:].reshape(-1)position_mask = torch.ones_like(logits).to(torch.long)query_pos = start_labels > self.vocab_size# ignore the padding maskposition_mask[query_pos, self.vocab_size:] = pos_maskassert padding_mask.shape == position_mask.shape# overall maskoverall_mask = padding_mask * position_mask## remove the position mask# overall_mask = padding_masknew_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), start_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_1 = (-loss_.sum(dim=-1)).mean()## split the token accuaracy and phrase accuracyphrase_indexes = start_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]phrase_start_acc = phrase_start_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]token_start_acc = token_start_acc.to(torch.float).mean().item()# training the representations of the end tokenscandidate_reps = torch.cat([self.token_embeddings[:, self.model.config.hidden_size // 2:],e_rep], dim=0)logits = torch.matmul(query_end, candidate_reps.t())  # [Q, B*]  logits /= self.args['temp']new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), end_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_2 = (-loss_.sum(dim=-1)).mean()# split the phrase and token accuracyphrase_indexes = end_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]phrase_end_acc = phrase_end_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]token_end_acc = token_end_acc.to(torch.float).mean().item()return (loss_0,  # token lossloss_1,  # token-head lossloss_2,  # token-tail lossacc_0,  # token accuracyphrase_start_acc,phrase_end_acc,token_start_acc,token_end_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/105634.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【脚踢数据结构】查找

(꒪ꇴ꒪ )&#xff0c;Hello我是祐言QAQ我的博客主页&#xff1a;C/C语言&#xff0c;Linux基础&#xff0c;ARM开发板&#xff0c;软件配置等领域博主&#x1f30d;快上&#x1f698;&#xff0c;一起学习&#xff0c;让我们成为一个强大的攻城狮&#xff01;送给自己和读者的…

数据结构双向链表

Hello&#xff0c;好久不见&#xff0c;今天我们讲链表的双向链表&#xff0c;这是一个很厉害的链表&#xff0c;带头双向且循环&#xff0c;学了这个链表&#xff0c;你会发现顺序表的头插头删不再是一个麻烦问题&#xff0c;单链表的尾插尾删也变得简单起来了&#xff0c;那废…

智慧化工地SaaS平台源码,PC端+APP端+智慧数据可视化大屏端,源码完全开源不封装,自主研发,支持二开,项目使用,微服务+Java++vue+mysql

智慧工地管理平台充分运用数字化技术&#xff0c;聚焦施工现场岗位一线&#xff0c;依托物联网、互联网、AI等技术&#xff0c;围绕施工现场管理的人、机、料、法、环五大维度&#xff0c;以及施工过程管理的进度、质量、安全三大体系为基础应用&#xff0c;实现全面高效的工程…

Windows下 MySql通过拷贝data目录迁移数据库的方法

MySQL数据库的文件目录下图所示&#xff0c; 现举例说明通过COPY文件夹data下数据库文件&#xff0c;进行数据拷贝的步骤&#xff1b;源数据库运行在A服务器上&#xff0c;拷贝到B服务器&#xff0c;假定B服务器上MySQL数据库已经安装完成&#xff0c;为空数据库。 首先进入A服…

[Stable Diffusion教程] 第一课 原理解析+配置需求+应用安装+基本步骤

第一课 原理解析配置需求应用安装基本步骤 本次内容记录来源于B站的一个视频 以下是自己安装过程中整理的问题及解决方法&#xff1a; 问题&#xff1a;stable-diffusion-webui启动No Python at ‘C:\xxx\xxx\python.exe‘ 解答&#xff1a;打开webui.bat 把 if not de…

React Diff算法

文章目录 React Diff算法一、它的作用是什么&#xff1f;二、React的Diff算法1.了解一下什么是调和&#xff1f;2.react的diff算法3.React Diff的三大策略4.tree diff&#xff1a;1、如果DOM节点出现了跨层级操作&#xff0c;Diff会怎么办? 5. component diff&#xff1a;6. e…

SEMIDRIVE X9U 插入 USB 不识别调试要点

一、前言 客户用芯驰 X9U 平台做的智能座舱产品&#xff0c;在烧写固件时发现&#xff0c;通过 USB 连接到 SSA 的 USB 接口&#xff0c;Windows 上无法识别出 USB 设备&#xff0c;一直处在 Ready 状态。 二、SEMIDRIVE X9U 插入 USB 不识别调试要点 ① 建议客户测量 SoC 的…

macOS M1使用TensorFlow GPU加速

本人是在pycharm运行代码&#xff0c;安装了tensorflow版本2.13.0 先运行代码查看有没有使用GPU加速&#xff1a; import tensorflow as tf# Press the green button in the gutter to run the script. if __name__ __main__:physical_devices tf.config.list_physical_dev…

【Vue】vue2项目使用swiper轮播图2023年8月21日实战保姆级教程

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、npm 下载swiper二、使用步骤1.引入库声明变量2.编写页面3.执行js 总结 前言 swiper轮播图官网 参考文章&#xff0c;最好先看完他的介绍&#xff0c;再看…

动手学深度学习—深度卷积神经网络AlexNet(代码详解)

AlexNet 1. 学习表征1.1 缺少的成分&#xff1a;数据1.2 缺少的成分&#xff1a;硬件 2. AlexNet2.1 模型设计2.2 激活函数2.3 容量控制和预处理 3. 读取数据集4. 训练AlexNet ImageNet classification with deep convolutional neural networks 原文链接&#xff1a;https://d…

视频云存储/安防监控视频智能分析网关V3:占道经营功能详解

违规占道经营者经常会在人流量大、车辆集中的道路两旁摆摊&#xff0c;导致公路交通堵塞&#xff0c;给居民出行的造成不便&#xff0c;而且违规占路密集的地方都是交通事故频频发生的区域。 TSINGSEE青犀视频云存储/安防监控视频/AI智能分析网关V3运用视频AI智能分析技术&…

YOLOv5算法改进(5)— 添加ECA注意力机制

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。ECA注意力机制是一种用于图像处理中的注意力机制&#xff0c;是在通道注意力机制的基础上做了进一步的改进。通道注意力机制主要是通过提取权重&#xff0c;作用在原特征图的通道维度上&#xff0c;而ECA注意力机制则使用了…

git常用操作命令(不定时更新)

git常用操作命令 将某个分支的某次提交迁移到另外一个分支查询这次提交的ID号方法一方法二 切换到目标分支执行commitID合并指令 将某个分支的某次提交迁移到另外一个分支 查询这次提交的ID号 方法一 方法二 切换到目标分支 git checkout 目标分支名 执行commitID合并指令 gi…

MySQL视图

一、视图-介绍及基本语法 视图&#xff08;View&#xff09;是一种虚拟存在的表。视图中的数据并不在数据库中实际存在&#xff0c;行和列数据来自定义视图的查询中使用的表&#xff0c;并且是在使用视图时动态生成的。 通俗的讲&#xff0c;视图只保存了查询的SQL逻辑&#xf…

mysql 、sql server trigger 触发器

sql server mySQL create trigger 触发器名称 { before | after } [ insert | update | delete ] on 表名 for each row 触发器执行的语句块## 表名&#xff1a; 表示触发器监控的对象 ## before | after : 表示触发的时间&#xff0c;before : 表示在事件之前触发&am…

mysql基础——认识索引

一、介绍 “索引”是为了能够更快地查询数据。比如一本书的目录&#xff0c;就是这本书的内容的索引&#xff0c;读者可以通过在目录中快速查找自己想要的内容&#xff0c;然后根据页码去找到具体的章节。 二、优缺点 优势&#xff1a;以快速检索&#xff0c;减少I/O次数&am…

低代码赋能| 智慧园区项目开发痛点及解决方案

智慧园区是一个综合体&#xff0c;集技术开发、产业发展和学术研究于一体。作为未来智慧城市建设的核心&#xff0c;智慧园区充当着“产业大脑”和“指挥中心”的角色。它通过整合园区内的制造资源和第三方服务能力&#xff0c;实现园区各组成部分的协调运作、良性循环和相互促…

课程项目设计--项目建立--宿舍管理系统--springboot后端

前要 项目设计–宿舍管理系统 文章目录 项目建立导入依赖配置文件配置目录结构config配置mybatis-plusswagger 生成实体、mapper和servicebaseEntity统一响应实例响应码接口响应码接口实现统一响应result统一分页响应 项目建立 太长了&#xff0c;修改一下 导入依赖 暂时先加…

yyyy-MM-dd‘T‘HH:mm时间格式探索

yyyy-MM-ddTHH:mm:ss 一直以后这个T是为了避免yyyy-MM-dd HH:mm:ss空格出现解析报错 但是这个T实际是一个标识符&#xff0c;作为小时元素的开始。 T代表后面跟着是时间&#xff0c;Z代表0时区&#xff08;相差北京时间8小时&#xff09; T 即代表 UTC&#xff08;Coodinated U…

【面试】一文讲清组合逻辑中的竞争与冒险

竞争的定义&#xff1a;组合逻辑电路中&#xff0c;输入信号的变化传输到电路的各级逻辑门&#xff0c;到达的时间有先后&#xff0c;也就是存在时差&#xff0c;称为竞争。 冒险的定义&#xff1a;当输入信号变化时&#xff0c;由于存在时差&#xff0c;在输出端产生错误&…