LLM优化:开源星火13B显卡及内存占用优化

1. 背景

本qiang~这两天接了一个任务,部署几个开源的模型,并且将本地经过全量微调的模型与开源模型做一个效果对比。

部署的开源模型包括:星火13B,Baichuan2-13B, ChatGLM6B等

其他两个模型基于transformers架构封装,因此推理服务启动还是十分丝滑,但星火13B是基于Megatron-DeepSpeed框架实现,地址是:https://gitee.com/iflytekopensource/iFlytekSpark-13B,启动推理服务的过程中发现启动13B的显卡占用71G-78G,有些反直觉。

此文就是整理开源星火13B的显存及内存排查并优化的整理过程,至于哪家开源模型效果好,不在此文的讨论范围内。

2. 原因分析

直观上来说,13B的模型,数据类型为bf16,显卡占用大概在26G左右,但星火13B直接占用70G+,不可思议,怪不得网上关于星火开源模型的讨论少之又少,原因显而易见,这么大的显存占用只能用多卡或者A800等80G显卡才能适配。穷人家的孩子,哪有这么多余粮。

排查原因的过程中,少不了源码的调试与分析。在排查的过程中,启动推理服务的文件run_iFlytekSpark_text_generation.py中,model_provider方法是初始化模型并加载模型文件的方法。

def model_provider(pre_process=True, post_process=True):"""Build the model."""print_rank_0('building iFlytekSpark model ...')args = get_args()config = core_transformer_config_from_args(args)### 初始化星火模型model = iFlytekSparkModel(config,num_tokentypes=0,parallel_output=False,pre_process=pre_process,post_process=post_process,return_moe_loss=False)if args.from_pretrained is not None:assert os.path.exists(args.from_pretrained)ckpt_path = get_checkpoint_name(args.from_pretrained)print_rank_0('Loading from {} '.format(args.from_pretrained))# 模型加载权重文件state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")if 'module' in state_dict:state_dict = state_dict['module']model.load_state_dict(state_dict)return model

其中,加载权重文件可以看到,加载state_dict时,直接将权重文件加载到显卡中,而非加载至CPU,然后再执行to方法,转移到GPU。因此该处是一个潜在的优化点

再打入iFlytekSparkModel内部,词表Embedding层,线性转换层,等初始化weight时,也是直接将weight分配在GPU上运行。例如下例:

class RowParallelLinear(torch.nn.Module):def __init__(self, input_size: int, output_size: int, *,config: ModelParallelConfig,init_method: Callable,bias: bool = True,input_is_parallel: bool = False,stride: int = 1,keep_master_weight_for_test: bool = False,skip_bias_add: bool = False,moe=False, enable_expert_tensor_parallelism=False):super(RowParallelLinear, self).__init__()# .........if config.use_cpu_initialization:self.weight = Parameter(torch.empty(self.output_size,self.input_size_per_partition,dtype=config.params_dtype))if config.perform_initialization:self.master_weight = _initialize_affine_weight_cpu(self.weight, self.output_size, self.input_size,self.input_size_per_partition, 1, init_method,stride=stride, return_master_weight=keep_master_weight_for_test,params_dtype=config.params_dtype)else:# 默认按照启动sh命令,会走该分支self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition,device=get_accelerator().current_device_name(), dtype=config.params_dtype))if config.perform_initialization:_initialize_affine_weight_gpu(self.weight, init_method,partition_dim=1, stride=stride)if bias:if config.use_cpu_initialization:self.bias = Parameter(torch.empty(self.output_size,dtype=config.params_dtype))else:# 默认按照启动sh命令,会走该分支self.bias = Parameter(torch.empty(self.output_size, device=get_accelerator().current_device_name(),dtype=config.params_dtype))setattr(self.bias, 'sequence_parallel', self.sequence_parallel)if config.perform_initialization:# Always initialize bias to zero.with torch.no_grad():self.bias.zero_()else:self.register_parameter('bias', None)

3. 优化方案

1. 模型初始化时,模型的Embedding,线性层的权重weight均直接加载至GPU,因此可以优化为先将这些weight加载至CPU

改进的方式也很简单,从上面的源码层面,可以看到,当增加参数” use_cpu_initialization”,将使用CPU进行初始化权重,因此只需要在启动推理服务的脚本中增加” --use-cpu-initialization”参数即可。

2. 加载模型文件时,直接加载至GPU,然后run_iFlytekSpark_text_generation.py中的get_model方法中,当模型加载完成后,会进行分配至GPU以及FP16的转换的操作。如下代码所示。

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):"""Build the model."""args = get_args()args.model_type = model_type# ..........# GPU allocation.for model_module in model:model_module.to(get_accelerator().current_device_name())# Fp16 conversion.if args.fp16 or args.bf16:model = [Float16Module(model_module, args) for model_module in model]# .......return model

因此,优化的方式也很简单,可以优化为先加载至CPU,再运行get_model中的默认分配至GPU,加载完后,再使用垃圾回收机制清除CPU占用的内存即可

话不多说,优化后的代码如下:

def model_provider(pre_process=True, post_process=True):"""Build the model."""print_rank_0('building iFlytekSpark model ...')args = get_args()config = core_transformer_config_from_args(args)model = iFlytekSparkModel(config,num_tokentypes=0,parallel_output=False,pre_process=pre_process,post_process=post_process,return_moe_loss=False)if args.from_pretrained is not None:print(args.from_pretrained)assert os.path.exists(args.from_pretrained)ckpt_path = get_checkpoint_name(args.from_pretrained)print_rank_0('Loading from {} '.format(args.from_pretrained))# state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")# CPU进行加载state_dict = torch.load(ckpt_path, map_location=f"cpu")if 'module' in state_dict:state_dict = state_dict['module']model.load_state_dict(state_dict)# 加载完成,删除state_dict,并垃圾回收del state_dictgc.collect()torch.cuda.empty_cache()return model

4. 效果对比

(1) 优化前的显卡占用: 71.5G

(2) 优化前的内存占用: 虚拟内存占用94.5G

(3) 优化后的显卡占用: 26G

(4) 优化后的内存占用: 43.1G

5. 总结

一句话足矣~

本文主要是针对开源星火13B的显存及内存占用过大的一个代码优化。核心思想是使用CPU预加载模型,再转换至GPU。

后期如有遇到此类问题,可以借鉴之~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/317761.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java-stream流案例

需求 代码 Vote类 // 1. 定义一个投票类 public class Vote {private String name;private ArrayList<String> voteList;public Vote(String name, ArrayList<String> voteList) {this.name name;this.voteList voteList;}public String getName() {return nam…

微信小程序:5.数据绑定

在Data中定义数据早wxml中进行数据使用 在data中定义数据 在页面对应的js对象中找到data&#xff0c;然后把数据进行定义即可 Page({data: {motto: Hello World,userInfo: {avatarUrl: defaultAvatarUrl,nickName: ,},hasUserInfo: false,canIUseGetUserProfile: wx.canIUse…

Ollamallama

Olllama 直接下载ollama程序&#xff0c;安装后可在cmd里直接运行大模型&#xff1b; llama 3 meta 开源的最新llama大模型&#xff1b; 下载运行 1 ollama ollama run llama3 2 github 下载仓库&#xff0c;需要linux环境&#xff0c;windows可使用wsl&#xff1b; 接…

linux 光驱(光盘)安装

文章目录 选择光盘自带 YUM 库创建 repo创建文件夹挂载光驱开机自启动挂载安装软件YUM 安装RPM 安装源码包安装 选择光盘 vmware 选择光盘 自带 YUM 库 ls /etc/yum.repos.d创建 repo vim /etc/yum.repo.d/demo.repo // 编写 repo 相关配置 [demo] namedemo baseurlfile://…

Java进阶-File、递归、IO流

目录 前言 File类概述 File类创建对象 绝对路径和相对路径 File类常用API 判断文件类型、获取文件信息 创建文件、删除文件功能 遍历文件夹 方法递归 递归的形式和特点 递归的算法流程、核心要素 常见案例 经典问题-猴子吃桃问题 非规律化递归案例-文件搜索 非规…

【网络原理】UDP协议 | UDP报文格式 | 校验和 | UDP的特点 | 应用层的自定义格式

文章目录 一、UDP协议1.UDP的传输流程发送方接收方 2.UDP协议报文格式&#xff1a;长度受限校验和如何校验&#xff1a;CRC算法&#xff1a;循环冗余算法md5算法&#xff1a; 2.UDP的特点 二、开发中常见的自定义格式1.xml&#xff08;古老&#xff09;2.json&#xff08;最流行…

java案例-读取xml文件

需求 导入依赖 <dependencies><!-- dom4j --><dependency><groupId>dom4j</groupId><artifactId>dom4j</artifactId><version>1.6.1</version></dependency> </dependencies>代码 SAXReader saxReade…

Ansible自动化

Ansible自动化 自动化的需求&#xff1a; 1. 在什么样的场景下需要自动化&#xff1f; 批量化的工作&#xff1a; 装软件包、配置服务、升级、下发文件… 2. 为什么在自动化工具中选择ansible&#xff1f; 对比shell脚本&#xff1a; 相对于用shell的脚本来实现自动化&#x…

跨平台桌面客户端开发框架

跨平台桌面客户端开发框架允许开发者创建能够在多个操作系统上运行的桌面应用程序。以下是一些流行的跨平台桌面客户端开发框架。这些框架各有优势&#xff0c;选择哪个框架取决于项目需求、团队的技术栈以及对特定特性的偏好。 1.Electron &#xff1a; 使用JavaScript, HTML…

uniapp视频播放器(h5+app)

关于uniapp视频播放器遇到的一些问题&#xff0c;mark下。 中途遇到了很多问题&#xff0c;如果有相同的伙伴遇到了类似的&#xff0c;欢迎交流 官方的video播放器在app上不友好&#xff0c;有以下功能不支持。 loadedmetadata、controlstoggle不支持导致只能手写控制层。 不…

python学习之词云图片生成

代码实现 import jieba import wordcloudf open("D:/Pythonstudy/data/平凡的世界.txt", "r", encoding"utf-8") t f.read() print(t) f.close() ls jieba.lcut(t) txt " ".join(ls)w wordcloud.WordCloud(font_path"D:/cc…

【webrtc】MessageHandler 8: 基于线程的消息处理:处理音频输入输出断开

m98代码,看起来m114 去掉了MessageHandler :音频的录制和播放 都使用了on message,但只是用来通知并处理流的断开的。AAudioRecorder AAudioRecorder 处理流断开 OnErrorCallback :有可能 错误回调是别处来的,是其他线程, 但是这个错误的处理要再自己的线程执行: 音频播…

Docker:centos7安装docker

官网&#xff1a;https://www.docker.com/官网 文档地址 - 确认centos7及其以上的版本 查看当前系统版本 cat /etc/redhat-release- 卸载旧版本 依照官网执行 - yum安装gcc相关 yum -y install gccyum -y install gcc-c- 安装需要的软件包 yum install -y yum-utils- 设置s…

Flask简介

Flask简介 安装概述使用PyCharm创建一个Flask程序 Flask程序的基本结构初始化路由和视图函数启动服务器请求-响应循环 安装 概述 Flask算是小型框架&#xff0c;小到可以称为“微框架”。Flask 非常小&#xff0c;因此你一旦能够熟练使用它&#xff0c;很可能就能读懂它所有的…

小米笔记本文件夹里是空白怎么办?分享原因及解决方案

随着科技的不断发展&#xff0c;笔记本电脑已成为我们日常生活和工作中不可或缺的一部分。而小米&#xff0c;作为知名的科技品牌&#xff0c;其笔记本产品凭借其出色的性能和合理的价格&#xff0c;受到了广大用户的喜爱。然而&#xff0c;在使用过程中&#xff0c;有时我们可…

Android AOSP探索之Ubantu下Toolbox的安装

文章目录 概述安装Toolbox解决运行的问题 概述 由于最近需要进军android的framework,所以需要工具的支持&#xff0c;之前听说江湖上都流传source insight,我去弄了一个破解版&#xff0c;功能确实强大&#xff0c;但是作为多年android开发的我习惯使用android studio。虽然使…

Delta lake with Java--利用spark sql操作数据2

上一篇文章尝试了建库&#xff0c;建表&#xff0c;插入数据&#xff0c;还差删除和更新&#xff0c;所以在这篇文章补充一下&#xff0c;代码很简单&#xff0c;具体如下&#xff1a; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession;publi…

网盘——分享文件——界面设计

本文主要讲解网盘中文件操作的分享文件部分&#xff0c;主要包含两方面的设计&#xff1a;分享文件界面设计和逻辑设计。 1、界面设计 1.1、添加一个类 1.2、引入头文件 #include <QPushButton> #include <QHBoxLayout> #include <QVBoxLayout> #include …

C++ 矩阵

目录 了解矩阵的数学原理&#xff08;大学线性代数&#xff09; 矩阵及转置矩阵 矩阵乘法 矩阵快速幂 相伴矩阵模板 [相伴矩阵,快速矩阵幂]CSES1722 Fibonacci Numbers 了解矩阵的数学原理&#xff08;大学线性代数&#xff09; 矩阵及转置矩阵 这里A就是一个矩阵&…

C++中的异常

目录 1.C语言传统的处理错误的方式 2. C异常概念 3. 异常的使用 3.1 异常的抛出和捕获 3.2 异常的重新抛出 3.3异常安全 3.4 异常规范 4.自定义异常体系 5.C标准库的异常体系 6.异常的优缺点 7.func&#xff08;&#xff09; throw();的方式规范化 1.C语言传统的处理…