【深度学习】NLP中的对抗训练

        在NLP中,对抗训练往往都是针对嵌入层(包括词嵌入,位置嵌入,segment嵌入等等)开展的,思想很简单,即针对嵌入层添加干扰,从而提高模型的鲁棒性和泛化能力,下面结合具体代码讲解一些NLP中常见对抗训练算法。

1.Fast Gradient Method(FGM)

        FGM的思想是针对词嵌入加入梯度方向的干扰,至于干扰的大小是我们可以调节的,增加干扰后的样本可以作为额外的对抗样本进行训练,以此提高模型的效果。由于我们在训练时会针对每个样本都进行一次额外的增加干扰后的训练,所以使用FGM后训练时间理论上也会大概增加一倍。

        FGM在原训练代码的基础上,主要增加了以下几个额外的操作:针对嵌入层添加干扰并备份参数,计算添加干扰后的损失,梯度回传从而累积添加干扰后的梯度,恢复原来的嵌入层参数。

1.1 算法流程

对于每个x:1.计算x的前向loss、反向传播得到梯度2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上4.将embedding恢复为(1)时的值5.根据(3)的梯度对参数进行更新

  1.2 具体代码

import torch
class FGM():def __init__(self, model):self.model = modelself.backup = {}def attack(self, epsilon=1., emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:#print('增加扰动的对象是', name)#print(type(param.grad))self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.backupparam.data = self.backup[name]self.backup = {}

1.3 具体用法

fgm = FGM(model) # (#1)初始化
for batch_input, batch_label in data:loss = model(batch_input, batch_label) # 正常训练loss.backward() # 反向传播,得到正常的grad# 对抗训练fgm.attack() # (#2)在embedding上添加对抗扰动loss_adv = model(batch_input, batch_label) # (#3)计算含有扰动的对抗样本的lossloss_adv.backward() # (#4)反向传播,并在正常的grad基础上,累加对抗训练的梯度fgm.restore() # (#5)恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

2.Projected Gradient Descent (PGD

        Project Gradient Descent(PGD)是一种迭代攻击算法,相比于普通的FGM 仅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都会将扰动投射到规定范围内。其中r为扰动约束空间(一个半径为r的球体),原始的输入样本对应的初识点为球心,避免扰动超过球面。迭代多次后,保证扰动在一定范围内,如下图所示:

 2.1 算法流程

对于每个x:1.计算x的前向loss、反向传播得到梯度并备份对于每步t:2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r(超出范围则投影回epsilon内)3.t不是最后一步: 将梯度归0,根据1的x+r计算前后向并得到梯度4.t是最后一步: 恢复(1)的梯度,计算最后的x+r并将梯度累加到(1)上5.将embedding恢复为(1)时的值6.根据(4)的梯度对参数进行更新

 2.2  具体代码

import torch
class PGD():def __init__(self, model):self.model = modelself.emb_backup = {}self.grad_backup = {}def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, epsilon)def restore(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data, epsilon):r = param_data - self.emb_backup[param_name]if torch.norm(r) > epsilon:r = epsilon * r / torch.norm(r)return self.emb_backup[param_name] + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:self.grad_backup[name] = param.grad.clone()def restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:param.grad = self.grad_backup[name]

2.3 具体用法

pgd = PGD(model)
K = 3
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)loss.backward() # 反向传播,得到正常的gradpgd.backup_grad()# 累积多次对抗训练——每次生成对抗样本后,进行一次对抗训练,并不断累积梯度for t in range(K):pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.dataif t != K-1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度pgd.restore() # 恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

Reference:

1.NLP中的对抗训练_colourmind的博客-CSDN博客

2.【NLP】NLP中的对抗训练_风度78的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/90580.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[NepCTF 2023] crypto 复现

这个赛很不理想,啥都不会。 拿了WP看了几个题,记录一下 random_RSA 这题不会是正常情况,我认为。对于论文题,不知道就是不知道,基本没有可能自己去完成论文。 题目不长,只有两个菜单,共可交…

win10中Docker安装、构建镜像、创建容器、Vscode连接实例

Docker方便一键构建项目所需的运行环境:首先构建镜像(Image)。然后镜像实例化成为容器(Container),构成项目的运行环境。最后Vscode连接容器,方便我们在本地进行开发。下面以一个简单的例子介绍在win10中实现:Docker安装、构建镜像…

九州未来参与编制的开源领域3项团体标准获批发布

日前,中电标2023年第21号团体标准公告正式发布,其中由九州未来参与编制的3项开源领域团体标准正式获批发布,于2023年8月1日正式实施。 具体内容如下: 《T/CESA 1269-2023 信息技术 开源 术语与综述》,本文件界定了信息…

“万恶”之源的KieServices,获取代码就一行,表面代码越少里面东西就越多,本以为就是个简单的工厂方法,没想到里面弯弯绕绕这么多东西

Drools用户手册看了得有一段时间了,现在开始看源码了,因为每次使用drools都会看见这么一段代码: 代码段1 起手代码 KieServices ks KieServices.Factory.get(); 那我就从这段代码出发开始研究drools的源码吧,这么一小段代码起初…

LangChain手记 Models,Prompts and Parsers

整理并翻译自DeepLearning.AILangChain的官方课程:Models,Prompts and Parsers 模型,提示词和解析器(Models, Prompts and Parsers) 模型:大语言模型提示词:构建传递给模型的输入的方式解析器:…

大语言模型:LLM的概念是个啥?

一、说明 大语言模型(维基:LLM- large language model)是以大尺寸为特征的语言模型。它们的规模是由人工智能加速器实现的,人工智能加速器能够处理大量文本数据,这些数据大部分是从互联网上抓取的。 [1]所构建的人工神…

Qt应用开发(基础篇)——工具箱 QToolBox

一、前言 QToolBox类继承于QFrame,QFrame继承于QWidget,是Qt常用的基础工具部件。 框架类QFrame介绍 QToolBox工具箱类提供了一列选项卡窗口,当前项显示在当前选项卡下面,适用于分类浏览、内容展示、操作指引这一类的使用场景。 二…

基于熵权法对Topsis模型的修正

由于层次分析法的最大缺点为:主观性太强,影响判断,对结果有很大影响,所以提出了熵权法修正。 变异程度方差/标准差。 如何度量信息量的大小: 把不可能的事情变成可能,这里面就有很多信息量。 概率越大&…

KCC@广州开源读书会广州开源建设讨论会

亲爱的开源读书会朋友们, 在下个周末我们将举办一场令人激动的线下读书会,探讨两本引人入胜的新书《只是为了好玩》和《开源之迷》。作为一个致力于推广开源精神和技术创新的社区,这次我们还邀请了圈内大咖前来参与,会给大家提供一…

瑞数信息《2023 API安全趋势报告》重磅发布: API攻击持续走高,Bots武器更聪明

如今API作为连接服务和传输数据的重要通道,已成为数字时代的新型基础设施,但随之而来的安全问题也日益凸显。为了让各个行业更好地应对API安全威胁挑战,瑞数信息作为国内首批具备“云原生API安全能力”认证的专业厂商,近年来持续输…

观察者模式实战

场景 假设创建订单后需要发短信、发邮件等其它的操作,放在业务逻辑会使代码非常臃肿,可以使用观察者模式优化代码 代码实现 自定义一个事件 发送邮件 发送短信 最后再创建订单的业务逻辑进行监听,创建订单 假设后面还需要做其它的…

【12】Git工具 协同工作平台使用教程 Gitee使用指南 腾讯工蜂使用指南【Gitee】【腾讯工蜂】【Git】

tips:少量的git安装和使用教程,更多讲快速使用上手Gitee和工蜂平台 一、准备工作 1、下载git Git - Downloads (git-scm.com) 找到对应操作系统,对应版本,对应的位数 下载后根据需求自己安装,然后用git --version验…

自动化更新导致的各种问题解决办法

由于最近自动化频频更新导致出现各种问题,因此在创建驱动对象代码时改成这种方式 我最近就遇到了由于更新而导致的代码报错,报错信息如下: 复制内容如下: Exception in thread “main” org.openqa.selenium.remote.http.Connecti…

【C++】多态的概念和简单介绍、虚函数、虚函数重写、多态构成的条件、重载、重写、重定义

文章目录 多态1.多态的概念和介绍2.虚函数2.1final2.2override 3.虚函数的重写3.1协变3.2析构函数的重写 4.多态构成的条件5.重载、重写、重定义...... 多态 1.多态的概念和介绍 C中的多态是一种面向对象编程的特性,它允许不同的对象对同一个消息做出不同的响应。 …

Hazel 引擎学习笔记

目录 Hazel 引擎学习笔记学习方法思考引擎结构创建工程程序入口点日志系统Premake\MD没有 cpp 文件的项目会出错include 到某个库就要包含这个库的路径,注意头文件展开 事件系统 获取和利用派生类信息预编译头文件抽象窗口类和 GLFWgit submodule addpremake 脚本禁…

【JVM】对String::intern()方法深入详解(JDK7及以上)

文章目录 1、什么是intern?2、经典例题解释例1例2例3 1、什么是intern? String::intern()是一个本地方法,它的作用是如果字符串常量池中已经包含一个等于此String对象的字符串,则返回代表池中这个字符串的String对象的引用&#…

7-15 然后是几点

有时候人们用四位数字表示一个时间,比如 1106 表示 11 点零 6 分。现在,你的程序要根据起始时间和流逝的时间计算出终止时间。 读入两个数字,第一个数字以这样的四位数字表示当前时间,第二个数字表示分钟数,计算当前时…

【Vue-Router】嵌套路由

footer.vue <template><div><router-view></router-view><hr><h1>我是父路由</h1><div><router-link to"/user">Login</router-link><router-link to"/user/reg" style"margin-left…

代码随想录算法训练营(二叉树总结篇)

一.二叉树的种类 1.满二叉树&#xff1a;就是说每一个非叶子节点的节点都有两个子节点。 2.完全二叉树&#xff1a;此二叉树只有最后一层可能没填满&#xff0c;并且存在的叶子节点都集中在左侧&#xff01;&#xff01;&#xff01; &#xff08;满二叉树也是完全二叉树&…

【Flutter】【基础】CustomPaint 绘画功能(一)

功能&#xff1a;CustomPaint 相当于在一个画布上面画画&#xff0c;可以自己绘制不同的颜色形状等 在各种widget 或者是插件不能满足到需求的时候&#xff0c;可以自己定义一些形状 使用实例和代码&#xff1a; CustomPaint&#xff1a; 能使你绘制的东西显示在你的ui 上面&a…