llama源码学习·model.py[7]Transformer类

一、源码展示

class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.params = paramsself.vocab_size = params.vocab_sizeself.n_layers = params.n_layersself.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)self.layers = torch.nn.ModuleList()for layer_id in range(params.n_layers):self.layers.append(TransformerBlock(layer_id, params))self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads,params.max_seq_len * 2,params.rope_theta,)@torch.inference_mode()def forward(self, tokens: torch.Tensor, start_pos: int):_bsz, seqlen = tokens.shapeh = self.tok_embeddings(tokens)self.freqs_cis = self.freqs_cis.to(h.device)freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]mask = Noneif seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)mask = torch.triu(mask, diagonal=1)# When performing key-value caching, we compute the attention scores# only for the new sequence. Thus, the matrix of scores is of size# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for# j > cache_len + i, since row i corresponds to token cache_len + i.mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)h = self.norm(h)output = self.output(h).float()return output

二、原理图

在这里插入图片描述

三、代码注释

class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()# 基本参数self.params = params# 词汇表大小self.vocab_size = params.vocab_size# 模型的层数self.n_layers = params.n_layers# 这个嵌入层会把每个单词映射到一个高维向量,这个高维向量就是这个单词的嵌入。self.tok_embeddings = ParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)# 创建了一个空的模块列表self.layers = torch.nn.ModuleList()# 添加了n_layers个TransformerBlock到列表中for layer_id in range(params.n_layers):self.layers.append(TransformerBlock(layer_id, params))# 创建了一个RMSNorm层,它用于对输入数据进行归一化处理。self.norm = RMSNorm(params.dim, eps=params.norm_eps)# ColumnParallelLinear层是一个线性层,用于将输入数据的特征从params.dim维映射到params.vocab_size维。# 这种映射是通过学习一组权重来实现的,权重矩阵的大小为 params.dim x params.vocab_size。# 简言之,将输入转化为params.vocab_size维的输出,这个输出可以看作是预测每个词汇的概率分布。self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)# 计算了freqs_cis,这是一个预计算的张量,用于后面的旋转位置嵌入(Rotary Position Embedding)self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_headers, self.params.max_seq_len * 2,)# 通过torch.inference_mode()装饰器来指示这个方法将用于模型推理,# 这可以帮助PyTorch优化计算,并在可能的情况下减少内存使用。@torch.inference_mode()def forward(self, tokens: torch.Tensor, start_pos: int):# 批量大小(_bsz)和序列长度(seqlen)_bsz, seqlen = tokens.shape# 词嵌入向量h = self.tok_embeddings(tokens)# 根据输入的序列起始位置start_pos和序列长度seqlen,从self.freqs_cis中取出对应的旋转嵌入。# 这些旋转嵌入将用于后续的Transformer层中,对输入的词嵌入进行旋转操作,以编码位置信息。freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]mask = Noneif seqlen > 1:# 模型首先生成了一个掩码(mask),这个掩码被用于transformer层以防止在自注意力机制中考虑到未来的词汇。mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)# 这是通过填充一个全为负无穷的矩阵,然后使用torch.triu(取上三角)函数,来创建一个遮罩,# 该遮罩对应的位置上的元素,# 如果它们代表的词在序列中是在当前词之后的词,则值为负无穷,否则为0。mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)# 对每个transformer层,依次将当前的嵌入向量(或者前一层的输出)作为输入,# 执行该层的前向传播,计算结果将用于下一层的输入。for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 将最后一层transformer层的输出通过一个规范化(norm)层,然后通过一个全连接层(self.output),# 转换为最后的模型输出。这个输出的尺寸应该与词汇表的大小相同,因此每个词都有一个对应的分数,# 这个分数代表模型认为该词是下一个词的可能性。h = self.norm(h)output = self.output(h).float()return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/39244.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MD2Card(markdown)

MD2Card 介绍: 1.小红书爆款神器,Markdown笔记秒转高颜值卡片 2.实时预览15种主题,自动拆长文,图片/SVG导出即用 3.零门槛不登录,免费无限生成,专治排版废和设计手残党 网站地址: https://md2…

第二节第一部分:String字符串

一、导包 二、String字符串 三、String注意事项 四、字符串的比较 五、面试例题 六、String案例一 需求分析: 代码: package com.StringTest;import java.util.Scanner;public class StingTest {public static void main(String[] args) {//1.开发一个…

动态规划(01背包恰好装满型详解):和为目标值的最长子序列长度

0-1背包:有n个物品,第i个物品的体积为w[i],价值为v[i],每个物品至多选择一个,求体积和不超过capacity的最大价值和。 对于第i个物品,我们只有两种选择:选,或者不选。如果选&#xf…

Spring漏洞再现

一、CVE-2017-8046 1、开环境 2、访问目录 /customers/1 3、在当前页抓包,并修改数据包 PATCH /customers/1 HTTP/1.1 Host: 150.158.199.164:8080 Accept-Encoding: gzip, deflate Accept: */* User-Agent: Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1;…

Ftrans飞驰云联受邀参加“2025汽车零部件CIO年会“并荣获智象奖

2025年3月6日,由栖观汽车、栖观资讯和飞羽商务主办的“2025第二届中国汽车&零部件CIO年会暨智象奖颁奖盛典”于上海盛大召开,Ftrans飞驰云联作为国内领先的企业文件传输与数据交换解决方案提供商,受邀出席了年会,并凭借卓越的…

西门子 CPU 1513-1 PN TCP Server 接收字符串前多了一个问号

TIA V17编程环境中(CPU 1513-1 PN),调用TSEND_C以TCP协议向TCP Server发送字符串:abded1234,TCP Server接收到的字符串多了一个问号:?avded1234. TSEND_C 指令的 DATA DB为非优化string类型数据 截图如下: 字符串前面两个字节不是起始字符,第一个是字节是字符串最大长度…

Matlab2024a免费版下载教程

Matlab是一个高性能的数学计算与仿真软件,广泛应用于科学计算、数据分析、算法开发以及工程绘图等多个领域。它提供了强大的矩阵运算能力、丰富的内置函数库以及灵活的编程环境,使得用户能够高效地解决复杂的数学问题。本文,我将为大家详细介…

SpringCould微服务架构之Docker(1)

项目中微服务比较多的时候,一个一个手动的部署太麻烦了,所以就需要用到Docker。 项目部署中的问题: Docker是一种快速交付应用、运行应用的技术。

软件公司高新技术企业代办:机遇与陷阱并存-优雅草卓伊凡

软件公司高新技术企业代办:机遇与陷阱并存-优雅草卓伊凡 在科技飞速发展的当下,软件公司如雨后春笋般涌现,众多企业渴望通过申请高新技术企业来获得政策支持与发展助力。随之而来的,是高新技术企业代办业务的兴起。然而&#xff…

动捕技术革新虚拟直播:解码虚拟主播的“拟真感“破局之路

在元宇宙技术加速落地的今天,虚拟直播已从早期的卡通形象展示,进化为具备情感交互的沉浸式体验,用户对"高拟真度互动"的需求也逐渐增加,这场行业变革的核心驱动力,离不开动捕技术的持续迭代。 虚拟直播的&q…

python字节码文件.pyc反编译成.py文件

一、前言 在 Python 开发过程中,.pyc 文件(Python 字节码文件)是 Python 解释器运行程序时生成的一种中间文件。它通常用于提高程序的运行效率,避免每次运行时都重新编译源代码。然而,由于各种原因,我们可…

C++友元:跨墙访问的三种姿势

目录 友元 友元之普通函数形式 友元之成员函数形式 友元类 友元的特点 友元 什么叫友元? 一般来说,类的私有成员只能在类的内部访问,类之外是不能访问它们的。但如果将其他类/函数设置为类的友元,那么友元类/函数就可以在前…

Typora安装使用教程 简单易用的Markdown编辑器

Typora markdown 编辑器下,最后一个免费版本 0.11.18,但可能会提示过期无法使用, 建议大家可以使用 0.9.96 Windows 版,下载 Windows X64 版。 Typora简介 Typora 是一款由 Abner Lee 开发的轻量级 Markdown 编辑器,与其他 Mark…

图解AUTOSAR_SWS_WatchdogInterface

AUTOSAR Watchdog Interface (WdgIf) 详解 AUTOSAR经典平台看门狗接口模块技术详解 目录 1. 概述 1.1 WdgIf模块的作用1.2 WdgIf在AUTOSAR中的位置2. 架构设计 2.1 WdgIf架构概览2.2 接口设计2.3 序列设计3. 配置详解 3.1 配置参数3.2 配置结构3.3 配置类型4. 总结 4.1 主要特点…

(Arxiv-2025)Magic 1-For-1:在一分钟内生成一分钟视频剪辑

Magic 1-For-1:在一分钟内生成一分钟视频剪辑 paper是PKU发布在Arxiv 2025的工作 paper title:Magic 1-For-1: Generating One Minute Video Clips within One Minute Code:地址 Abstract 在本技术报告中,我们提出了 Magic 1-For-1&#xff…

谷歌大型推理模型曝光!击败Claude-3.7-Thinking

哎!最近推特上的网友在LMSYS Arena 发现了个泄漏的大模型 Nebula,效果据说特别好,打败了o1、o3-mini、Claude 3.7 Thinking等模型: 网友们通过询问和分析 API,发现这似乎是谷歌正在秘密测试的新推理模型!推…

css-grid布局

文章目录 1、布局2、网格轨道3、间距Gap4、网格线5、网格别名 当一个 HTML 元素将 display 属性设置为 grid 或 inline-grid 后,它就变成了一个网格容器,这个元素的所有直系子元素将成为网格元素。 1、布局 启用grid布局类似与flex布局,不过g…

菱形虚拟继承的原理

一 :菱形继承的问题 普通的菱形继承存在数据冗余和二义性的问题 ,如下代码: class Person { public:string _name; //姓名 };class Student : public Person { protected:int _num; //学号 };class Teacher : public Person { protected:int…

<数据集>轨道异物识别数据集<目标检测>

数据集下载链接:https://download.csdn.net/download/qq_53332949/90527370 数据集格式:VOCYOLO格式 图片数量:1659张 标注数量(xml文件个数):1659 标注数量(txt文件个数):1659 标注类别数:6 标注类别…

高效PDF翻译解决方案:多引擎支持+格式零丢失

软件介绍 在AI翻译工具大行其道的今天,传统翻译软件市场逐渐饱和,但专业领域的深度需求依然存在。本文推荐的PDF翻译工具凭借20余种专业翻译接口,为学术文献、技术文档等复杂内容提供更精准的翻译服务,在保留文档原始排版的同时…