LLM PreTraining from scratch -- 大模型从头开始预训练指北

最近做了一些大模型训练相关的训练相关的技术储备,在内部平台上完成了多机多卡的llm 预训练的尝试,具体的过程大致如下:

数据准备:

大语言模型的训练依赖于与之匹配的语料数据,在开源社区有一群人在自发的整理高质量的语料数据,可以通过 以下的一些链接获取

liwu/MNBVC at main

Skywork/SkyPile-150B · Datasets at Hugging Face

预训练框架:

选用了百川智能的开源框架

原始版本代码训练准备:

根据README 里面的介绍,需要准备以下几样东西:

  • 训练数据,按照训练的卡的数目分成多个文件,每个文件的每一行为一整句的语料,类似这样的文件

添加图片注释,不超过 140 字(可选)

  • 分词器(tokenizer) ,下载 分词器 到当前目录下。

  • 修改hostfile训练脚本,单机训练情况下,不依赖于多机多卡的hostfile, 修改启动脚本 添加启动项 --num_nodes 即可完成单机多卡的训练

 

#!/bin/bash

deepspeed --hostfile config/hostfile --num_nodes=1 \ train.py \ --deepspeed \ --deepspeed_config config/deepspeed.json

原版的训练中几个要处理的问题

  • deepspeed.zero.Init. 错误

这个错误是发生在 deepspeed 设置优化等级不为3的情况下,调用deepspeed.zero.Init 函数会报错,需要在初始化的时候判断一下优化等级是不是3,因此修改代码如下:

 
//train.py
def prepare_model():ds_config = json.load(open(args.deepspeed_config))# print(type(ds_config["zero_optimization"]['stage']))with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config,enabled=ds_config["zero_optimization"]['stage']==3,mem_efficient_linear=False,mpu=None):model = BaiChuanForCausalLM(smallconfig)
  • 数据不可无限读取

    def get_data(self):# todo 循环读取#data = self.data.pop(0)

在原版本的实现中,数据是从队列中pop 出来的,导致当数据读完了之后会报一个错误导致训练中断

此外原版代码中所有的语料是读取到内存中再进行操作,但是随着语料的量级达到T级别,基本无法全部用内存hold 住所有的语料,另外读取语料的到内存的时间也会很长,基于以上几点考虑,重新选择tfrecord 作为新的数据的存储方式

TFRecord

tfrecord 是一种在tensorflow 中常用的数据格式,数据基于protobuf 完成序列化存储,配合对应的index 可以实现高速的数据读取从而减少数据读取造成的性能瓶颈,原版的训练代码基于的pytorch的框架,可以pip 安装

pip install tfrecord

来使用这个数据结构, 注意这个库可能会遇到protobuf 版本库的问题,通过pip 重新安装 protobuf==3.19 可以解决,

编写对应的代码完成将原来的jsonl 数据转换成tfrecord

//tools/jsonlmutiltfrecord.py
import tfrecord
import os
from tqdm import tqdm
import json
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset, TFRecordDatasetori_path = "/workspace/mnt/storage/zhaozhijian/silk-debug/Baichuan-7B/data_dir_ori"
out_path = "/workspace/mnt/storage/zhaozhijian/silk-debug/Baichuan-7B/data_dir_mutil_test"if not os.path.exists(out_path):os.mkdir(out_path)numgpu = 16 
fidlist = []
for i in range(numgpu):writer = tfrecord.TFRecordWriter(os.path.join(out_path,"data" +str(i)+".tfrecord"))fidlist.append(writer)if 1:files = os.listdir(ori_path)count = 0for file in files:with open(os.path.join(ori_path, file)) as f:for line in tqdm(f.readlines()):dict_ = json.loads(line)fidlist[count%numgpu].write({"text":(dict_["text"].encode('utf-8'), "byte")})count +=1for writer in fidlist:writer.close()os.system("python3 -m tfrecord.tools.tfrecord2idx " + os.path.join(out_path))tfrecord_path = os.path.join(out_path,"data{}.tfrecord")
index_path = os.path.join(out_path,"data{}.tfindex")
splits = {"1": 1,
}
description = {"text": "byte"}
dataset = MultiTFRecordDataset(tfrecord_path, index_path, splits, description, infinite=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=1)for item in loader:print(item['text'].decode('utf-8'))

这里需要注意,tfrecord 的写入的数据只有int,float, byte 3种形式,因此string 格式的数据数据需要通过utf-8的编码写入到tfrecord 中,再读取的时候通过utf-8的解码才能还原为写入的string数据,对应修改train.py 文件,

from tfrecord.torch.dataset import TFRecordDataset, MultiTFRecordDataset...
class DataEngine():...def load_tfrecode_data_mutil(self):splits = {}for file_path in self.local_input_paths:   splits[file_path.replace('.tfrecord', '')] = 1.0/len(self.local_input_paths)tfrecord_path = "{}.tfrecord"index_path = "{}.tfindex"description = {"text": "byte"}dataset = MultiTFRecordDataset(tfrecord_path, index_path, splits, description, infinite=False)self.loader = torch.utils.data.DataLoader(dataset, batch_size=1)return
...
def prepare_data():data_dir = args.data_dir....#    data_engine.load_data()data_engine.load_tfrecode_data_mutil()return data_engine
...
def train(data_engine, model_engine):model_engine.train()step = 0data =[]for item in data_engine.loader:while 1:line = item['text'].decode('utf-8')cc = data_engine.sp.EncodeAsIds(line.strip()) + [data_engine.EOS_TOKEN_ID]if len(cc) < data_engine.MIN_TEXT_LEN:continuedata.extend(cc)if len(data) >= data_engine.micro_batch_size * (data_engine.max_length + 1):index = data_engine.micro_batch_size * (data_engine.max_length + 1)data = data[:index]breakseq = np.asarray(data).reshape(data_engine.micro_batch_size, data_engine.max_length + 1)data = torch.LongTensor(seq)data = data.cuda(non_blocking=True)loss = model_engine(data, labels=data).lossmodel_engine.backward(loss)model_engine.step()step += 1data =[]return
成功解决数据加载中内存和读取速度的问题

多机多卡训练

原版本使用的hostfile 做为启动器,这个有一个前提条件需要各个机器之间可以通过ssh协议互相通信,但是在我们的内部ATOM的环境中无法做到这个,所以启动多机多卡的训练的时候会出现启动两个单机训练和无法启动训练两种情况,这些和我们的多机多卡训练不符

经过摸索后,我们采用了torchrun的启动方式,利用master_addr 等环境变量,用torchstyle 的方式启动多机多卡训练,解决了deepspeed 启动器对于ssh 通信的依赖

NUM_GPUS=8
torchrun --nnodes=$WORLD_SIZE --nproc-per-node=$NUM_GPUS --master-addr=$MASTER_ADDR \--master-port=$MASTER_PORT --node-rank=$RANK \train.py \--deepspeed \--deepspeed_config config/deepspeed.json >log$RANK.txt

对应的修改train.py 中的一些内容:

//train.py
###
deepspeed.init_distributed()
args.local_rank=int(os.environ['LOCAL_RANK'])
###
def prepare_data():...model = BaiChuanForCausalLM(smallconfig)torch.cuda.set_device(args.local_rank)...def train(data_engine, model_engine):model_engine.train()local_rank = int(os.environ['LOCAL_RANK'])...data = data.cuda(non_blocking=True).to(local_rank)...

一些遗留的BUG:

启动训练会卡住: 原因特别傻,就是现在在数据目录下会有tfrecord 和 index 两种后缀的文件,在按照radnk分的时候由于不够随机,会有loader 读取不到文件,导致计算loss 时候卡住,修改 DataEngine

files = [x for _, x in enumerate(self.global_input_paths)if x.find('.tfrecord') != -1]self.local_input_paths = [x for i, x inenumerate(files)if i % dist.get_world_size() == dist.get_rank()]

即可。

最终的训练loss 如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/271614.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

读《文明之光》第1册总结

人类几千年的文明史和地球的历史相比&#xff0c;实在是太短暂了&#xff0c;大约相当于几分钟和一年的关系。人类已经走过的路&#xff0c;相比今后要走的漫漫长路&#xff0c;只能算是刚刚起步。如果跳出一个个具体事件&#xff0c;站在历史的高度去看&#xff0c;我们会发现…

前端实现一个绕圆心转动的功能

得知了转换关系&#xff0c;我们就可以定义一个变量 angle 来表示我们这个 div 做圆周运动时绕圆心转过的角度&#xff0c;则弧度&#xff08;radian&#xff09; 为 radian &#xff08;angle*π&#xff09;/180 我们先在草稿纸上演练一遍我们的逻辑是否可行。让我们先准备一…

货运物流小程序开发功能 发货运输更简单

随着互联网的快速发展&#xff0c;线上接单已经成为物流行业的主流趋势。货运物流接单小程序作为物流企业的得力助手&#xff0c;能够提高运输效率、降低成本、提升服务质量&#xff0c;成为物流行业的发展新方向。 1. 用户注册与登录功能&#xff1a;用户可以通过手机号、邮箱…

光谱下的养殖业:数据可视化的现代变革

在数字化时代&#xff0c;数据可视化在养殖业中崭露头角&#xff0c;为这一传统行业注入了新的活力。无论是家禽养殖还是水产养殖&#xff0c;数据可视化都以其直观、高效的特点&#xff0c;为养殖业带来了全新的发展机遇。下面我就以可视化从业者的角度&#xff0c;简单聊聊这…

华为od机试C卷-开源项目热度榜单

1、题目描述 某个开源社区希望将最近热度比较高的开源项目出一个榜单&#xff0c;推荐给社区里面的开发者。 对于每个开源项目&#xff0c;开发者可以进行关注(watch)、收藏(star)、fork、提issue、提交合并请求(MR)等。 数据库里面统计了每个开源项目关注、收藏、fork、issue…

【自然语言处理六-最重要的模型-transformer-上】

自然语言处理六-最重要的模型-transformer-上 什么是transformer模型transformer 模型在自然语言处理领域的应用transformer 架构encoderinput处理部分&#xff08;词嵌入和postional encoding&#xff09;attention部分addNorm Feedforward & add && NormFeedforw…

数睿通2.0数据接入升级——支持增量字段同步,表单独映射

引言 上次数睿通 2.0 更新是在 23 年12 月 底&#xff0c;已经过去了接近三个月的时间&#xff0c;中间由于过年加上年前年后实在是工作繁忙&#xff0c;所以一直没有腾出空来更新代码&#xff0c;希望大家可以理解&#xff0c;平台的发展离不开你们的支持&#xff0c;在此表示…

2021年PAT--春

Arithmetic Progression of Primes In mathematics, an arithmetic progression (AP&#xff0c;等差数列) is a sequence of numbers such that the difference between the consecutive terms is constant. In 2004, Terence Tao (陶哲轩) and Ben Green proved that for an…

sql server使用逗号,分隔保存多个id的一些查询保存

方案一&#xff0c;前后不附加逗号&#xff1a; 方案二&#xff0c;前后附加逗号&#xff1a; 其他保存方案&#xff1a; &#xff08;这里是我做一个程序的商家日期规则搞得&#xff0c;后面再补具体操作&#xff09;&#xff1a; 1,2,3 | 1,2,3 | 1,2,3; 1,2,3 &#xff1…

奖励建模(Reward Modeling)实现人类对智能体的反馈

奖励建模&#xff08;Reward Modeling&#xff09;是强化学习中的一个重要概念和技术&#xff0c;它主要用于训练智能体&#xff08;如AI机器人或大型语言模型&#xff09;如何更有效地学习和遵循人类期望的行为。在强化学习环境中&#xff0c;智能体通过尝试不同的行为获得环境…

S4---FPGA-K7板级原理图硬件实战

视频链接 FPGA-K7板级系统硬件实战01_哔哩哔哩_bilibili FPGA-K7板级原理图硬件实战 基于XC7K325TFFG900的FPGA硬件实战框图 基于XILINX 的KINTEX-7 芯片XC7K325FPGA的硬件平台&#xff0c;FPGA 开发板挂载了4 片512MB 的高速DDR3 SDRAM 芯片&#xff0c;另外板上带有一个SODIM…

【新版Hi3521DV200处理器性能】

新版Hi3521DV200处理器性能 Hi3521DV200是针对多路高清/超高清&#xff08;1080p/4M/5M/4K&#xff09;DVR产品应用开发的新一代专业SoC芯片。Hi3521DV200集成了ARM Cortex-A7四核处理器和性能强大的神经网络推理引擎&#xff0c;支持多种智能算法应用。同时&#xff0c;Hi352…

UE4升级UE5 蓝图节点变更汇总(4.26/27-5.2/5.3)

一、删除部分 Ploygon Editing删除 Polygon Editing这个在4.26、4.27中的插件&#xff0c;在5.1后彻底失效。 相关的蓝图&#xff0c;如编辑器蓝图 Generate mapping UVs等&#xff0c;均失效。 如需相关功能&#xff0c;请改成Dynamic Mesh下的方法。 GetSupportedClass删…

【c语言】算法1.1:二分查找

目录 题目 算法步骤&#xff08;没带数位板&#xff0c;希望没有丑到您的眼睛&#xff09; 代码 题目 算法步骤&#xff08;没带数位板&#xff0c;希望没有丑到您的眼睛&#xff09; 代码 #include <stdio.h> int main() {int num[4]{1,3,5,6};int t;scanf("%d&…

FPGA FIFO 读取模式

FPGA FIFO 读取模式分两种&#xff1a; Normal Mode: In normal mode, the “rdreq” signal serves as the read request or read enable. When this signal goes high, the data output provides the first data from the FIFO.Essentially, in normal mode, data is availa…

【Spring面试题】

目录 前言 1.Spring框架中的单例bean是线程安全的吗? 2.什么是AOP? 3.你们项目中有没有使用到AOP&#xff1f; 4.Spring中的事务是如何实现的&#xff1f; 5.Spring中事务失效的场景有哪些&#xff1f; 6.Spring的bean的生命周期。 7.Spring中的循环引用 8.构造方法…

ArcGIS筛选工具:19段SQL示例代码,所有需求一网打尽

一、使用方法 筛选工具(Select_analysis)主要用于从输入要素类或输入要素图层中提取要素&#xff08;通常使用选择或结构化查询语言 (SQL) 表达式&#xff09;&#xff0c;并将其存储于输出要素类中。 以三调图斑为例&#xff0c;图斑中有一个【DLMC】字段&#xff0c;该字段…

Facebook的社交未来:元宇宙时代的数字共融

引言&#xff1a; 随着科技的不断进步和社会的快速发展&#xff0c;人们对于社交网络的需求和期待也在不断演变。在这个数字化时代&#xff0c;元宇宙的概念逐渐引发了人们对社交体验的重新思考。作为全球最大的社交网络之一&#xff0c;Facebook正在积极探索元宇宙时代的社交…

知识管理系统:初创企业的智慧助手

一、什么是知识管理系统 用通俗易懂的语言来解释&#xff0c;知识管理系统就像一个超级大脑&#xff0c;帮助企业和团队更好地记住、分享和使用他们学到的东西。无论是工作中的经验、方案还是项目成果&#xff0c;这个系统都能帮大家保存下来&#xff0c;并方便以后查找和使用。…

Redis与 Memcache区别

Redis与 Memcache区别 1 , Redis 和 Memcache 都是将数据存放在内存中&#xff0c;都是内存数据库。不过 Memcache 还可用于缓存 其他东西&#xff0c;例如图片、视频等等。 2 , Memcache 仅支持key-value结构的数据类型&#xff0c;Redis不仅仅支持简单的key-value类型的数据&…