LLaMA-Adapter源码解析

LLaMA-Adapter源码解析

伪代码

def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):residual =xy= zero_init_attention(soft_prompt, x) # llama-adapter: prepend prefixx= self_attention(x)x = x+ gating_factor * y  # llama-adapter: apply zero_init_attentionx = LayerNorm(x+residual)residual = xx = FullyConnectedLayers(x)x = AdapterLayers(x)x = LayerNorm(x + residual)return x

源码

class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()self.head_dim = args.dim // args.n_headsself.wq = ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,bias=False,gather_output=False,init_method=lambda x: x,)self.wk = ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,bias=False,gather_output=False,init_method=lambda x: x,)self.wv = ColumnParallelLinear(args.dim,args.n_heads * self.head_dim,bias=False,gather_output=False,init_method=lambda x: x,)self.wo = RowParallelLinear(args.n_heads * self.head_dim,args.dim,bias=False,input_is_parallel=True,init_method=lambda x: x,)self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()self.gate = torch.nn.Parameter(torch.zeros(1))def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):bsz, seqlen, _ = x.shapexq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)self.cache_k = self.cache_k.to(xq)self.cache_v = self.cache_v.to(xq)self.cache_k[:bsz, start_pos : start_pos + seqlen] = xkself.cache_v[:bsz, start_pos : start_pos + seqlen] = xvkeys = self.cache_k[:bsz, : start_pos + seqlen]values = self.cache_v[:bsz, : start_pos + seqlen]if adapter is not None:adapter_len = adapter.shape[1]adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)adapter_k = adapter_k.transpose(1, 2)adapter_v = adapter_v.transpose(1, 2)xq = xq.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)if mask is not None:scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)scores = F.softmax(scores.float(), dim=-1).type_as(xq)output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)if adapter is not None:adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)output = output + torch.matmul(adapter_scores, adapter_v)output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)return self.wo(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/179399.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PY32F071单片机,主频最高72 M,带一路DAC,USB

PY32F071 系列微控制器采用高性能的 32 位 ARM Cortex-M0内核,宽电压工作范围的 MCU。嵌入高达128 Kbytes flash 和 16 Kbytes SRAM 存储器,最高工作频率 72 MHz。包含多种不同封装类型多款产品。芯片集成多路 I2C、SPI、USART 等通讯外设,1 …

配置git并把本地项目连接github

一.配置git 1.下载git(Git),但推荐使用国内镜像下载(CNPM Binaries Mirror) 选好64和版本号下载,全部点下一步 下载完成后打开终端,输入 git --version 出现版本号则说明安装成功 然后继续…

C# OpenCvSharp DNN 部署L2CS-Net人脸朝向估计

效果 项目 代码 using OpenCvSharp; using OpenCvSharp.Dnn; using System; using System.Collections.Generic; using System.Drawing; using System.Drawing.Drawing2D; using System.Linq; using System.Text; using System.Windows.Forms;namespace OpenCvSharp_DNN_Demo …

MySQL中表的增删改查

目录 一、CRUD 二、新增(Create) (1)语法 (2)单行数据全列插入 (3)多行数据指定列插入 三、查询(Retrieve) (1)语法 …

swift语言下SurfGen库做的爬虫是什么样的 ?

Swift语言并没有内置的爬虫库,但是你可以使用第三方库来实现爬虫功能。其中比较常用的是Alamofire和SwiftyJSON。Alamofire是一个基于Swift语言的HTTP网络库,可以用来发送HTTP请求和接收HTTP响应。而SwiftyJSON则是一个用于处理JSON数据的Swift库&#x…

vue使用JsBarcode生成条形码

在工作中&#xff0c;有一个需求是接口返回的订单号生成条形码&#xff0c;如图&#xff1a; 1.安装依赖 yarn add jsbarcode2.引入 在script标签中引入 import JsBarcode from jsbarcode 3.使用 this.$refs.a.src的值为条形码的地址。 <template><div><img…

自定义类型结构体(下)

目录 结构体传参结构体实现位段什么是位段位段的内存分配位段的跨平台问题总结&#xff1a; 位段的应用位段使用的注意事项** 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 &#x1f412;&#x1f412;&#x1f412; 个人主页 &#x1f978;&#x1f978;&a…

应聘遇到性格测试,怎么做才能通过?

本人xxx专业&#xff0c;最近找工作&#xff0c;总会做到性格测试题&#xff0c;好几次了&#xff0c;做完性格测试就没消息了&#xff0c;已经疯了&#xff0c;直接给跪了&#xff0c;完全找不到套路&#xff0c;都快怀疑人生了。所以想问一下&#xff0c;像这样公司的性格测试…

什么是TCY油封?

机械由无数组件协同工作以确保平稳运行&#xff0c;其中一种不可或缺的部件是油封&#xff0c;特别是TCY油封。本文旨在阐明TCY油封的应用、其重要性以及它们如何提高机械的整体效率。 TCY油封主要用于轴密封。轴是一种旋转机器元件&#xff0c;横截面通常为圆形&#xff0c;用…

Microsoft 365 管理自动化

Microsoft 365 服务被大多数组织广泛使用&#xff0c;每天生成的数据量巨大。解决 Microsoft 365 中的问题可能非常困难&#xff0c;并且使用多个管理中心来保护组织变得复杂。本机控制台还缺少某些批量管理任务、全面的审计报告和基于角色的精细访问控制。 Microsoft 360 管理…

c++获取和设置环境变量

这个功能非常常用&#xff0c;但是容易忘记&#xff0c;这里做个记录。 注意&#xff0c;设置的环境变量只在当前进程中生效&#xff0c;所以在电脑中的环境变量设置区域看不到。 std::string env getenv("PATH");env "X:\\envtest";std::string newEnv…

【机器学习】几种常用的机器学习调参方法

在机器学习中&#xff0c;模型的性能往往受到模型的超参数、数据的质量、特征选择等因素影响。其中&#xff0c;模型的超参数调整是模型优化中最重要的环节之一。超参数&#xff08;Hyperparameters&#xff09;在机器学习算法中需要人为设定&#xff0c;它们不能直接从训练数据…

0基础学习PyFlink——事件时间和运行时间的窗口

大纲 定制策略运行策略Reduce完整代码滑动窗口案例参考资料 在 《0基础学习PyFlink——时间滚动窗口(Tumbling Time Windows)》一文中&#xff0c;我们使用的是运行时间(Tumbling ProcessingTimeWindows)作为窗口的参考时间&#xff1a; reducedkeyed.window(TumblingProcess…

国际测试委员会BenchCouncil首发“开源系统杰出成果榜” 百度飞桨上榜

&#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff0c;喜爱音乐的一位博主。 &#x1f4d7;本文收录于恒川的日常汇报系列&#xff0c;大家有兴趣的可以看一看 &#x1f4d8;相关专栏C语言初阶、C…

关于pytorch张量维度转换及张量运算

关于pytorch张量维度转换大全 1 tensor.view()2 tensor.reshape()3 tensor.squeeze()和tensor.unsqueeze()3.1 tensor.squeeze() 降维3.2 tensor.unsqueeze(idx)升维 4 tensor.permute()5 torch.cat([a,b],dim)6 torch.stack()7 torch.chunk()和torch.split()8 与tensor相乘运算…

RESTful接口实现与测试

目录标题 是什么&#xff1f;设计风格HTTP协议四种传参方式常用注解RequestBody与ResponseBodyRequestMapping注解RestController与ControllerPathVariable 与RequestParam 接受复杂嵌套对象参数Http数据转换的原理自定义HttpMessageConverter统一规划接口响应的数据格式实战&a…

为什么重写 redisTemplate

为什么重写 redisTemplate 1.安装 redis 上传 redis 的安装包tar -xvf redis-5.0.7.tar.gzyum -y install gcc-cmakemake PREFIX/soft/redis installcd /soft/redis/bin./redis-server redis.conf 2. 集成 redisTemplate maven 依赖 <dependency><groupId>org…

详解Java经典数据结构——HashMap

Java 的 HashMap 是一个常用的基于哈希表的数据结构&#xff0c;它实现了 Map 接口&#xff0c;可以存储键值对。下面我们进行详细介绍&#xff1a; 基本结构&#xff1a;HashMap 底层是基于哈希表来实现的&#xff0c;每次插入一个键值对时&#xff0c;会先对该键进行 Hash 运…

Locust:可能是一款最被低估的压测工具

01、Locust介绍 开源性能测试工具https://www.locust.io/&#xff0c;基于Python的性能压测工具&#xff0c;使用Python代码来定义用户行为&#xff0c;模拟百万计的并发用户访问。每个测试用户的行为由您定义&#xff0c;并且通过Web UI实时监控聚集过程。 压力发生器作为性…

本地部署Jellyfin影音服务器并实现远程访问影音库

文章目录 1. 前言2. Jellyfin服务网站搭建2.1. Jellyfin下载和安装2.2. Jellyfin网页测试 3.本地网页发布3.1 cpolar的安装和注册3.2 Cpolar云端设置3.3 Cpolar本地设置 4.公网访问测试5. 结语 1. 前言 随着移动智能设备的普及&#xff0c;各种各样的使用需求也被开发出来&…