【mT5多语言翻译】之五——训练:中央日志、训练可视化、PEFT微调

·请参考本系列目录:【mT5多语言翻译】之一——实战项目总览

[1] 模型训练与验证

  在上一篇实战博客中,我们讲解了访问数据集中每个batch数据的方法。接下来我们介绍如何训练mT5模型进行多语言翻译微调。

  首先加载模型,并把模型设置为训练状态,然后定义优化器、学习率衰减,并设置一些初始状态值。

model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
model.train()optimizer = AdamW(model.parameters(), lr=config.learning_rate)
# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_iter) * config.num_epochs)
total_batch = 0  # 记录进行到多少batch
losses = []

【注】设置了model.train()后,会激活dropout。而设置model.eval()后,会关闭dropout,防止影响模型的推理结果。

  接着,我们就可以访问数据加载器中的数据进行模型的微调了:

    for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):optimizer.zero_grad()model_out = model.forward(input_ids=input_batch, labels=label_batch)loss = model_out.losslosses.append(loss.item())loss.backward()optimizer.step()scheduler.step()# 打印步if (total_batch + 1) % config.print_step == 0:avg_loss = np.mean(losses[-config.print_step:])logger.info('Epoch: {} | Step: {} | Train Avg. loss: {:.3f} | lr: {} | Time: {}'.format(epoch + 1,total_batch + 1, avg_loss, scheduler.get_last_lr()[0], get_time_dif(start_time)))# 验证步if (total_batch + 1) % config.checkpoint_step == 0:test_loss = evaluate(model, dev_iter)torch.save(model.state_dict(), config.save_path)time_dif = get_time_dif(start_time)logger.info('Test Avg. loss: {:.3f} | Time: {} '.format(test_loss, time_dif))model.train()total_batch += 1torch.save(model.state_dict(), config.save_path)

  如上,训练的代码还是很少的,里面有一些注意事项需要详细说明。

  1、由于我们不是每一个batch都要输出一下损失到控制台。所有我们需要设置打印步,来控制每隔多少个batch输出一次模型的loss。因此需要一个数组将每次的训练损失记录下来,然后打印时求下平均值。

  2、同理,在验证集上测试性能也是这个道理。但是由于我们的数据集很大,即使模型在验证集上的loss不再下降,也不应该主动把模型停止。因为大模型的训练可以抽象为“压缩数据”的概念,它没见过的数据就是不会产生相应的知识,所以最好还是让模型一直训练下去,直到把数据集训练完。

  模型的验证代码属于训练的阉割版,比较简单,如下:

def evaluate(model, data_iter):model.eval()eval_losses = []with torch.no_grad():for input_batch, label_batch in data_iter:model_out = model.forward(input_ids=input_batch, labels=label_batch)eval_losses.append(model_out.loss.item())return np.mean(eval_losses)

[2] 中央日志

  由于模型训练的时间较长,一旦断联可能就无法再继续观测模型在控制台打印的输出。因此,我们需要设计一个日志功能,让模型即可以实时的打印输出,又能同时记录输出到文件中,以便于我们后期查看。

  还有一个待解决的问题是,项目中不同文件的输出日志,我们需要将其定位到同一个log日志中。

  所以需要在项目中配置根日志器。配置代码如下:

import logging.config# 定义日志配置
LOGGING_CONFIG = {'version': 1,'formatters': {'default': {'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',},},'handlers': {'console': {'class': 'logging.StreamHandler',  # 输出到控制台'level': 'INFO','formatter': 'default',},'file': {'class': 'logging.FileHandler',  # 输出到文件'filename': 'train.log','level': 'DEBUG','formatter': 'default',},},'loggers': {'': {  # root logger'handlers': ['console', 'file'],'level': 'DEBUG',},},
}# 在根日志器(没有名称的日志器)上应用日志配置,那么所有子日志器都会继承这些设置
logging.config.dictConfig(LOGGING_CONFIG)
# 获取日志器
logger = logging.getLogger(__name__)

  我们在项目的任意一个py文件中写下这样的配置。然后在其他py文件中只需要2行代码即可让输出流即展示在控制台也能保存在同一个log日志文件中:

# 获取日志器
logger = logging.getLogger(__name__)
# 记录日志
logger.info("xxxxxx")

[3] 训练可视化

  我们在训练过程中,会有实时观测损失波动、评价指标波动的需求。因此项目集成了tensorBoardX来进行训练可视化。

  先上效果图:
在这里插入图片描述

  首先安装tensorboard,输入命令:

pip install tensorboard 

  然后在代码中使用write即可,代码demo:

import numpy as np
from torch.utils.tensorboard import SummaryWriter  # 也可以使用 tensorboardX
# from tensorboardX import SummaryWriter  # 也可以使用 pytorch 集成的 tensorboardwriter = SummaryWriter('log') # 配置生成的数据保存的地址
for epoch in range(100):writer.add_scalar('test/squared', np.square(epoch), epoch)writer.close()

  执行上述代码后在本文件更目录下生成一个logs文件,且包含了一个事件文件。

  在pycharm中terminal终端输入:

tensorboard --logdir=logs

  一定要注意起初配置的生成文件保存地址,你在terminal终端中命令的地址要能够访问的到!!!

  输入命令后,会生成一个地址,访问即可。

在这里插入图片描述

  在本实战项目的实际训练中,通过is_write变量来控制是否要进行训练可视化,因此实际的代码如下:

	model.train()...if config.is_write:writer = SummaryWriter(log_dir="{0}/{1}".format(config.tb_log_path, time.strftime('%m-%d_%H.%M', time.localtime())))for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):...# 验证步...if config.is_write:writer.add_scalar('loss/train', loss.item(), total_batch)...if config.is_write:writer.close()

  上述的代码判断is_write变量如果为True,那么会创建一个writer对象,并且每个batch都会记录训练loss。

[4] PEFT微调

  PEFT微调可能是大家最需要的方法。

  官方已经提供了peft库,直接安装即可:

pip install peft 

  不过在使用之前,我们需要明确2点:1)使用PEFT如何修改模型结构?2)使用PEFT如何保存模型?

  其实非常简单!

  使用PEFT如何修改模型结构?

  常规加载模型的代码为:

model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)

  使用PEFT方法修改加载模型后的代码为:

from peft import get_peft_model, LoraConfigmodel = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(model, peft_config)

【注】朋友们,没看错!使用PEFT只需要在原来的基础上加两行代码即可,其他模型训练阶段的代码完全不需要改变。我这里演示的是使用LoRA进行微调,peft里目前还集成了Prefix Tuning等其他方法。

  使用PEFT如何保存模型?

  我们知道,使用PEFT微调模型是不会修改大模型本身的参数的,因此我们保存只要保存PEFT方法添加的那部分参数即可,PEFT模型保存方法如下:

model.save_pretrained(config.peft_save)

  而传统保存全量参数的代码如下:

torch.save(model.state_dict(), config.save_path)

【注】peft使用起来真的太方便了。

[5] 集成了所有以上功能的模型训练与验证代码

# 获取日志器
logger = logging.getLogger(__name__)def train(config, model, train_iter, dev_iter):start_time = time.time()model.train()optimizer = AdamW(model.parameters(), lr=config.learning_rate)# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,num_training_steps=len(train_iter) * config.num_epochs)total_batch = 0  # 记录进行到多少batchlosses = []if config.is_write:writer = SummaryWriter(log_dir="{0}/{1}".format(config.tb_log_path, time.strftime('%m-%d_%H.%M', time.localtime())))for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):optimizer.zero_grad()model_out = model.forward(input_ids=input_batch, labels=label_batch)loss = model_out.losslosses.append(loss.item())loss.backward()optimizer.step()scheduler.step()# 打印步if (total_batch + 1) % config.print_step == 0:avg_loss = np.mean(losses[-config.print_step:])logger.info('Epoch: {} | Step: {} | Train Avg. loss: {:.3f} | lr: {} | Time: {}'.format(epoch + 1,total_batch + 1, avg_loss, scheduler.get_last_lr()[0], get_time_dif(start_time)))# 验证步if (total_batch + 1) % config.checkpoint_step == 0:test_loss = evaluate(model, dev_iter)if config.is_peft:model.save_pretrained(config.peft_save)else:torch.save(model.state_dict(), config.save_path)time_dif = get_time_dif(start_time)logger.info('Test Avg. loss: {:.3f} | Time: {} '.format(test_loss, time_dif))if config.is_write:writer.add_scalar('loss/train', loss.item(), total_batch)model.train()total_batch += 1if config.is_peft:model.save_pretrained(config.peft_save)else:torch.save(model.state_dict(), config.save_path)if config.is_write:writer.close()def evaluate(model, data_iter):model.eval()eval_losses = []with torch.no_grad():for input_batch, label_batch in data_iter:model_out = model.forward(input_ids=input_batch, labels=label_batch)eval_losses.append(model_out.loss.item())return np.mean(eval_losses)

[6] 进行下一篇实战

  【mT5多语言翻译】之六——推理:多语言翻译与第三方接口设计

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/304271.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STL库 —— list 的编写

目录 一、成员变量 ​编辑 二、push_back 函数 三、迭代器 iterator 3.1 iterator 结构体 3.2 begin() 与 end() 函数 3.3 iterator 运算符重载 3.4 -> 的重载 3.5 const_iterator 四、测试代码 五、修饰符成员 5.1 insert 函数 5.2 erase 函数 5.3 push 函数…

机器学习和深度学习-- 李宏毅(笔记与个人理解)Day10

Day 10 Genaral GUidance training Loss 不够的case Loss on Testing data over fitting 为什么over fitting 留到下下周哦~~ 期待 solve CNN卷积神经网络 Bias-Conplexiy Trade off cross Validation how to split? N-fold Cross Validation mismatch 这节课总体听下来比较…

lovesql 手工sql注入

1.页面 2.万能密码登录成功 我还傻乎乎的以为密码就是flag 但不是 3. 继续注入 判断列数 确定了只有三列 开始尝试联合注入 4.使用联合注入之前先判断显示位 5.之后一步一步的构造,先得到当前数据库名 利用database() 再得到库里有哪些表 …

Vue+el-table 修改表格 单元格横线边框颜色及表格空数据时边框颜色

需求 目前 找到对应的css样式进行修改 修改后 css样式 >>>.el-table th.el-table__cell.is-leaf {border-bottom: 1px solid #444B5F !important;}>>>.el-table td.el-table__cell,.el-table th.el-table__cell.is-leaf {border-bottom: 1px solid #444B5F …

目标检测——3D车道数据集

一、重要性及意义 3D车道检测在自动驾驶和智能交通领域具有极其重要的地位,其重要性和意义主要体现在以下几个方面: 首先,3D车道检测可以精确判断车辆在道路上的位置、方向和速度,从而预测潜在的危险情况并及时采取措施。这种能…

如何在极狐GitLab 启用依赖代理功能

本文作者:徐晓伟 GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署极狐GitLab。 本文主要讲述了如何在[极狐GitLab…

UDP实现Mini版在线聊天室

实现原理 只有当客户端先对服务器发送online消息的时候,服务器才会把客户端加入到在线列表。当在线列表的用户发消息的时候,服务器会把消息广播给在线列表中的所有用户。而当用户输入offline时,表明自己要下线了,此时服务器把该用…

【软件测试】个人博客系统测试

个人博客系统测试 一、项目背景1.1 技术背景1.2 功能背景 二、自动化测试2.1 什么是自动化测试2.2 通过使用selenium进行自动化测试的编写(Java实现)2.3 编写测试用例,执行自动化测试2.3.1 输入用户名:test,密码:123,登录成功2.3.…

云服务器上Docker启动的MySQL会自动删除数据库的问题

一、问题说明 除了常见的情况,例如没有实现数据挂载,导致数据丢失外,还需要考虑数据库是否被攻击,下图 REVOVER_YOUR_DATA 就代表被勒索了,这种情况通常是数据库端口使用了默认端口(3306)且密码…

jmeter实验 模拟:从CSV数据到加密请求到解密返回数据再到跨越线程组访问解密后的数据

注意,本实验所说的加密只是模拟加密解密,您需要届时写自己的加解密算法或者引用含有加密算法的相关jar包才行. 思路: 线程组1: 1.从CSV文件读取原始数据 2.将读取到的数据用BeanShell预习处理器进行加密 3.HTTP提取器使用加密后的数据发起请求 4.使用BeanShell后置处理器…

vite+react+ts+scss 创建项目

npm create vitelatest输入项目名称选择react选择typescript swc WC 通过利用 Rust 编写的编译器,使用了更先进的优化技术,使得它在处理 TypeScript 代码时能够更快地进行转换和编译。特别是在大型项目中,SWC 相对于传统的 TypeScript 编译器…

使用QT 开发不规则窗体

使用QT 开发不规则窗体 不规则窗体贴图法的不规则窗体创建UI模板创建一个父类创建业务窗体main函数直接调用user_dialog创建QSS文件 完整的QT工程 不规则窗体 QT中开发不规则窗体有两种方法:(1)第一种方法,使用QWidget::setMask函…

跨域问题一文解决

📝个人主页:五敷有你 🔥系列专栏:Vue ⛺️稳中求进,晒太阳 一、为什么会出现跨域的问题? 是浏览器的同源策略,跨域也是因为浏览器这个机制引起的,这个机制的存在还是在于安全…

FFmpeg: 简易ijkplayer播放器实现--01项目简介

文章目录 项目介绍流程图播放器实现过程界面展示 项目介绍 此项目基于FFmeg中 ffplay.c进行二次开发,实现基本的功能,开发软件为Qt 项目优势: 参考ijkplayer播放器,实现UI界面和播放器核心进行解耦,容易添加其他功能…

机器学习和深度学习 -- 李宏毅(笔记与个人理解1-6)

机器学习和深度学习教程 – 李宏毅(笔记与个人理解) day1 课程内容 什么是机器学习 找函数关键技术(深度学习) 函数 – 类神经网络来表示 ;输入输出可以是 向量或者矩阵等如何找到函数: supervised Lear…

SSRF靶场

SSRF概述 ​ 强制服务器发送一个攻击者的请求 ​ 互联网上的很多web应用提供了从其他服务器(也可以是本地)获取数据的功能。使用用户指定的URL,web应用可以获取图片(载入图片)、文件资源(下载或读取)。如下图所示&…

C语言面试题之返回倒数第 k 个节点

返回倒数第 k 个节点 实例要求 1、实现一种算法,找出单向链表中倒数第 k 个节点;2、返回该节点的值; 示例:输入: 1->2->3->4->5 和 k 2 输出: 4 说明:给定的 k 保证是有效的。实…

uniapp开发h5端使用video播放mp4格式视频黑屏,但有音频播放解决方案

mp4格式视频有一些谷歌播放视频黑屏,搜狗浏览器可以正常播放 可能和视频的编码格式有关,谷歌只支持h.264编码格式的视频播放 将mp4编码格式修改为h.264即可 转换方法: 如果是自己手动上传文件可以手动转换 如果是后端接口调取的地址就需…

效果炸裂!文生图再升级,支持多对象个性化图片生成!开源!

大家好,今天和大家分享一篇最新的文生图工作,代码已开源 标题:Identity Decoupling for Multi-Subject Personalization of Text-to-Image Models 单位:KAIST 论文:https://arxiv.org/pdf/2404.04243.pdf 代码&#xf…

【linux】yum 和 vim

yum 和 vim 1. Linux 软件包管理器 yum1.1 什么是软件包1.2 查看软件包1.3 如何安装软件1.4 如何卸载软件1.5 关于 rzsz 2. Linux编辑器-vim使用2.1 vim的基本概念2.2 vim的基本操作2.3 vim命令模式命令集2.4 vim底行模式命令集2.5 vim操作总结补充:vim下批量化注释…