使用deepspeed小记

1. 减少显存占用的历程忠告

医学图像经常很大,所以训练模型有时候会有难度,但是现在找到了很多减少显存的方法。
不知道为什么,使用transformers的trainer库确确实实会减少显存的占用,即使没有使用deepspeed,占用的显存也会减少。

别自己造轮子
我之前也使用过 LoRA,自己也设计过,非常非常建议千万不要自己去写LoRA,很浪费时间,设计很费时间,同时检验模型LoRA的有效性也很浪费时间,权重的融合也很浪费时间,尽量使用其他已经写好的LoRA。

我推荐使用transformers集成模型和训练集,只需要写一个dataset和collate_fn,最多再多写一个Trainer的computer_loss,模型就可以自然而然的搞定。效率最高最有效。

2. Deepspeed方便快捷

在这里插入图片描述
使用 deepspeed 的流程是最短的

2.1 如果warning,需要加载一些库

moudle ava
moudle load compiler/gcc/7.3.1
moudle load cuda/7/11.8

由于deepspeed进行编译实际上是将GPU的一些指令重新编译,让CPU执行,同时还要符合CUDA的计算结构,能和GPU交互,所以GCC编译,CUDA编译都要符合版本要求

2.2 编写Trainer的python文件

建议使用transformers的trainer函数,这样很多json文件可以直接设置auto,同时还方便指定json配置文件。
同时要注意,这里可能会要求你加入 args,设置一个 local_rank 全局管控。
TrainingArguments 指定 ds_config.json 文件

import argparse
import sys
def parse_agrs():parser = argparse.ArgumentParser()parser.add_argument("--local_rank", type=int, default=-1, help="Local rank. Necessary for using the torch.distributed.launch utility.")return argsargs = parse_agrs()training_args = TrainingArguments(output_dir='./checkpoint/Eff_R2GenCMN_base',num_train_epochs=1000,per_device_train_batch_size=10,per_device_eval_batch_size=10,warmup_steps=500,weight_decay=0.01,logging_dir='./checkpoint/Eff_R2GenCMN_base/output_logs',logging_steps=10,save_strategy='steps',  # 添加保存策略为每一定步骤保存一次save_steps=100,  # 每100步保存一次模型save_total_limit=5,  # 最多保存5个模型report_to="none",fp16=True,  # 启用混合精度训练deepspeed='./ds_config.json',
)tokenizer = Tokenizer()
args = parse_agrs()
model = R2GenCMN(args, tokenizer)
dataset_train = Dataset(xlsx_file="./dataset/train_dataset.xlsx")
dataset_test = Dataset(xlsx_file="./dataset/test_dataset.xlsx")
trainer = MyTrainer(model=model,  # 使用的模型args=training_args,  # 训练参数train_dataset=dataset_train,  # 训练数据集eval_dataset=dataset_test,  # 验证数据集data_collator=collate_fn,# 可能需要定义compute_metrics函数来计算评估指标
)

2.3 编写ds_config文件

编写ds_config文件的目的就是简介python文件,同时更改参数方便,减少大脑记忆负担,便于使用。
ds_config.json 文件脚本通常是 通用的, batch如果写auto,deepspeed会根据显卡给你 自动设置batch 大小
这里只是设置了

stage2的

{"bfloat16": {"enabled": false},"fp16": {"enabled": "auto","loss_scale": 0,"loss_scale_window": 1000,"initial_scale_power": 16,"hysteresis": 2,"min_loss_scale": 1},"optimizer": {"type": "AdamW","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"scheduler": {"type": "WarmupLR","params": {"warmup_min_lr": "auto","warmup_max_lr": "auto","warmup_num_steps": "auto"}},"zero_optimization": {"stage": 2,"offload_optimizer": {"device": "cpu","pin_memory": true},"allgather_partitions": true,"allgather_bucket_size": 2e8,"overlap_comm": true,"reduce_scatter": true,"reduce_bucket_size": 2e8,"contiguous_gradients": true},"gradient_accumulation_steps": "auto","gradient_clipping": "auto","train_batch_size": "auto","train_micro_batch_size_per_gpu": "auto","steps_per_print": 1e5
}

或者使用stage3

{"bfloat16": {"enabled": false},"fp16": {"enabled": "auto","loss_scale": 0,"loss_scale_window": 1000,"initial_scale_power": 16,"hysteresis": 2,"min_loss_scale": 1},"optimizer": {"type": "AdamW","params": {"lr": "auto","betas": "auto","eps": "auto","weight_decay": "auto"}},"scheduler": {"type": "WarmupLR","params": {"warmup_min_lr": "auto","warmup_max_lr": "auto","warmup_num_steps": "auto"}},"zero_optimization": {"stage": 3,"offload_optimizer": {"device": "cpu","pin_memory": true},"offload_param": {"device": "cpu","pin_memory": true},"overlap_comm": true,"contiguous_gradients": true,"sub_group_size": 1e9,"reduce_bucket_size": "auto","stage3_prefetch_bucket_size": "auto","stage3_param_persistence_threshold": "auto","stage3_max_live_parameters": 1e9,"stage3_max_reuse_distance": 1e9,"stage3_gather_fp16_weights_on_model_save": true},"gradient_accumulation_steps": "auto","gradient_clipping": "auto","steps_per_print": 1e5,"train_batch_size": "auto","train_micro_batch_size_per_gpu": "auto","wall_clock_breakdown": false
}

2.4 运行程序

最终deepspeed运行就可以了
这里的warning实际上没有影响模型的运行,是重新编译。

deepspeed train.py
[2024-04-02 12:04:43,112] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:05:48,493] [WARNING] [runner.py:196:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2024-04-02 12:05:48,493] [INFO] [runner.py:555:main] cmd = /public/home/v-yumy/anaconda3/envs/llava2/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None transformer_train.py
[2024-04-02 12:05:51,627] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:05:55,944] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
[2024-04-02 12:05:55,944] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
[2024-04-02 12:05:55,944] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2024-04-02 12:05:55,944] [INFO] [launch.py:163:main] dist_world_size=1
[2024-04-02 12:05:55,944] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
[2024-04-02 12:06:29,136] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:06:31,519] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-04-02 12:06:31,519] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-04-02 12:06:31,519] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.742 seconds.
Prefix dict has been built successfully.
EfficientNet: replace first conv
EncoderDecoder 的Transformer 是 base
EncoderDecoder 是 base
视觉特征,不进行预训练[WARNING]  cpu_adam cuda is missing or is incompatible with installed torch, only cpu ops can be compiled!
Using /public/home/v-yumy/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Emitting ninja build file /public/home/v-yumy/.cache/torch_extensions/py310_cu117/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.7046074867248535 seconds
Rank: 0 partition count [1] and sizes[(42770360, False)] 
{'loss': 6.7285, 'learning_rate': 1.6730270909663467e-05, 'epoch': 0.02}                                                                                                                                   
{'loss': 6.0535, 'learning_rate': 2.3254658315702903e-05, 'epoch': 0.05}                                                                                                                                   
{'loss': 5.598, 'learning_rate': 2.6809450068309278e-05, 'epoch': 0.07}                                                                                                                                    
{'loss': 5.2824, 'learning_rate': 2.9266416338062584e-05, 'epoch': 0.1}                                                                                                                                    
{'loss': 5.0738, 'learning_rate': 3.114597855245884e-05, 'epoch': 0.12}                                                                                                                                    
{'loss': 4.8191, 'learning_rate': 3.266853634404809e-05, 'epoch': 0.15}                                                                                                                                    
{'loss': 4.5336, 'learning_rate': 3.3948300828875964e-05, 'epoch': 0.17}       

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/295586.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 8.0.13安装配置教程

写个博客记录一下&#xff0c;省得下次换设备换系统还要到处翻教程&#xff0c;直接匹配自己常用的8.0.13版本 1.MySQL包解压到某个路径 2.将bin的路径加到系统环境变量Path下 3.在安装根目录下新建my.ini配置文件&#xff0c;并用编辑器写入如下数据 [mysqld] [client] port…

30. UE5 RPG GamplayAbility的配置项

在上一篇文章&#xff0c;我们介绍了如何将GA应用到角色身上的&#xff0c;接下来这篇文章&#xff0c;将主要介绍一下GA的相关配置项。 在这之前&#xff0c;再多一嘴&#xff0c;你要能激活技能&#xff0c;首先要先应用到ASC上面&#xff0c;才能够被激活。 标签 之前介绍…

【SpringBoot整合系列】SpirngBoot整合EasyExcel

目录 背景需求发展 EasyExcel官网介绍优势常用注解 SpringBoot整合EaxyExcel1.引入依赖2.实体类定义实体类代码示例注解解释 3.自定义转换器转换器代码示例涉及的枚举类型 4.Excel工具类5.简单导出接口SQL 6.简单导入接口SQL 7.复杂的导出&#xff08;合并行、合并列&#xff0…

python Flask扩展:如何查找高效开发的第三方模块(库/插件)

如何找到扩展以及使用扩展的文档 一、背景二、如何寻找框架的扩展&#xff1f;三、找到想要的扩展四、找到使用扩展的文档五、项目中实战扩展 一、背景 刚入门python的flask的框架&#xff0c;跟着文档学习了一些以后&#xff0c;想着其实在项目开发中&#xff0c;经常会用到发…

每日面经分享(Spring Boot: part3 Service层)

SpringBoot Service层的作用 a. 封装业务逻辑&#xff1a;Service层负责封装应用程序的业务逻辑。Service层是控制器&#xff08;Controller&#xff09;和数据访问对象&#xff08;DAO&#xff09;之间的中间层&#xff0c;负责处理业务规则和业务流程。通过将业务逻辑封装在S…

当面试官问你插入排序算法,你敢说自己会吗?

算法学习的重要性 在程序员的世界里&#xff0c;算法就如同一座桥梁&#xff0c;连接着问题与解决方案&#xff0c;是实现优秀程序的关键。 掌握算法&#xff0c;就能够在面对各种问题时&#xff0c;找到最合适的解决方法&#xff0c;以最少的时间和空间&#xff0c;实现最优的…

基于FPGA的SPI_FLASH程序设计

SPI_FLASH简介 spi_flash是一种通用存储器&#xff0c;也称为SPI NOR Flash或SPI Flash。它使用SPI&#xff08;Serial Peripheral Interface&#xff09;接口进行通信&#xff0c;可以通过串行方式读写数据。spi_flash的特点是工作电压低&#xff0c;体积小&#xff0c;读写速…

梨花带雨网页音乐播放器二开优化修复美化版全开源版本源码

源码简介 最新梨花带雨网页音乐播放器二开优化修复美化版全开源版本源码下载 梨花带雨播放器基于thinkphp6开发的XPlayerHTML5网页播放器前台控制面板,支持多音乐平台音乐解析。二开内容&#xff1a;修复播放器接口问题&#xff0c;把接口本地化&#xff0c;但是集成外链播放器…

C++的并发世界(三)——线程对象生命周期

0.案例代码 先看下面一个例子&#xff1a; #include <iostream> #include <thread>void ThreadMain() {std::cout << "begin sub thread:" << std::this_thread::get_id()<<std::endl;for (int i 0; i < 10; i){std::cout <&…

矩阵间关系的建立

参考文献 2-D Compressive Sensing-Based Visually Secure Multilevel Image Encryption Scheme 加密整体流程如下: 我们关注左上角这一部分: 如何在两个图像之间构建关系,当然是借助第3个矩阵。 A. Establish Relationships Between Different Images 简单说明如下: …

Android的图片加载框架

Android的图片加载框架 为什么要使用图片加载框架&#xff1f;图片加载框架1. Universal Image Loader [https://github.com/nostra13/Android-Universal-Image-Loader](https://github.com/nostra13/Android-Universal-Image-Loader)2. Glide [https://muyangmin.github.io/gl…

美摄科技AI智能图像矫正解决方案

图像已经成为了企业传播信息、展示产品的重要媒介&#xff0c;在日常拍摄过程中&#xff0c;由于摄影技巧的限制和拍摄环境的复杂多变&#xff0c;许多企业面临着图像内容倾斜、构图效果不佳等挑战&#xff0c;这无疑给企业的形象展示和信息传递带来了不小的困扰。 美摄科技深…

CentOS7安装flink1.17完全分布式

前提条件 准备三台CenOS7机器&#xff0c;主机名称&#xff0c;例如&#xff1a;node2&#xff0c;node3&#xff0c;node4 三台机器安装好jdk8&#xff0c;通常情况下&#xff0c;flink需要结合hadoop处理大数据问题&#xff0c;建议先安装hadoop&#xff0c;可参考 hadoop安…

顶顶通呼叫中心中间件-话术编辑器机器人转人工坐席配置(mod_cti基于FreeSWITCH)

顶顶通呼叫中心中间件-话术编辑器机器人转人工座席配置(mod_cti基于FreeSWITCH) 配置方法 一、ACD排队转接 二、伴随转接 比如你设置的通知规则是任意满足一个就通知那么通话时间设置为10 秒那样他只要通话时间到10秒他就会转坐席。 如果要转人工的时侯转手机可以这样配置 把…

用于HUD平视显示器的控制芯片:S2D13V40

一款利用汽车抬头显示技术用于HUD平视显示器的控制芯片:S2D13V40。HUD的全称是Head Up Display&#xff0c;即平视显示器&#xff0c;以前应用于军用飞机上&#xff0c;旨在降低飞行员需要低头查看仪表的频率。起初&#xff0c;HUD通过光学原理&#xff0c;将驾驶相关的信息投射…

53 v-bind 和 v-model 的实现和区别

前言 这个主要的来源是 偶尔的情况下 出现的问题 就比如是 el-select 中选择组件之后, 视图不回显, 然后 model 不更新等等 这个 其实就是 vue 中 视图 -> 模型 的数据同步, 我们通常意义上的处理一般是通过 模型 -> 数据 的数据同步, 比如 我们代码里面更新了 model.…

pygame--坦克大战(二)

加载敌方坦克 敌方坦克的方向是随机的&#xff0c;使用随机数生成。 初始化敌方坦克。 class EnemyTank(Tank):def __init__(self,left,top,speed):self.images {U: pygame.image.load(img/enemy1U.gif),D: pygame.image.load(img/enemy1D.gif),L: pygame.image.load(img/e…

10_MVC

文章目录 JSON常用的JSON解析Jackson的常规使用指定日期格式 MVC设计模式MVC介绍前后端分离案例&#xff08;开发与Json相关接口&#xff09; 三层架构三层架构介绍 JSON JSON&#xff08;JavaScript Object Notation&#xff09; 是一种轻量级的数据交换格式&#xff0c;是存…

python 爱心代码

效果图&#xff1a; 代码&#xff1a; import random from math import sin, cos, pi, log from tkinter import *CANVAS_WIDTH 640 CANVAS_HEIGHT 480 CANVAS_CENTER_X CANVAS_WIDTH / 2 CANVAS_CENTER_Y CANVAS_HEIGHT / 2 IMAGE_ENLARGE 11 # 设置颜色 HEART_COLOR &…

QT中的文件操作QFile、QDataStream、QTextStream、QBuffer

文件操作概述 1、Qt中IO操作的处理方式 &#xff08;1&#xff09;、Qt通过统一的接口简化了文件与外部设备的操作方式 &#xff08;2&#xff09;、Qt中的文件被看做是一种特殊的外部设备 &#xff08;3&#xff09;、Qt中的文件操作与外部设备操作相同 2、IO操作中的关键…