nlp培训重点-5

1. LoRA微调

loader:

# -*- coding: utf-8 -*-import json
import re
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
"""
数据加载
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.index_to_label = {0: '家居', 1: '房产', 2: '股票', 3: '社会', 4: '文化',5: '国际', 6: '教育', 7: '军事', 8: '彩票', 9: '旅游',10: '体育', 11: '科技', 12: '汽车', 13: '健康',14: '娱乐', 15: '财经', 16: '时尚', 17: '游戏'}self.label_to_index = dict((y, x) for x, y in self.index_to_label.items())self.config["class_num"] = len(self.index_to_label)if self.config["model_type"] == "bert":self.tokenizer = BertTokenizer.from_pretrained(config["pretrain_model_path"])self.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for line in f:line = json.loads(line)tag = line["tag"]label = self.label_to_index[tag]title = line["title"]if self.config["model_type"] == "bert":input_id = self.tokenizer.encode(title, max_length=self.config["max_length"], pad_to_max_length=True)else:input_id = self.encode_sentence(title)input_id = torch.LongTensor(input_id)label_index = torch.LongTensor([label])self.data.append([input_id, label_index])returndef encode_sentence(self, text):input_id = []for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))input_id = self.padding(input_id)return input_id#补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id):input_id = input_id[:self.config["max_length"]]input_id += [0] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = index + 1  #0留给padding位置,所以从1开始return token_dict#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("valid_tag_news.json", Config)print(dg[1])

model:

import torch.nn as nn
from config import Config
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim import Adam, SGDTorchModel = AutoModelForSequenceClassification.from_pretrained(Config["pretrain_model_path"])def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

evaluate:

# -*- coding: utf-8 -*-
import torch
from loader import load_data"""
模型效果测试
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)self.stats_dict = {"correct":0, "wrong":0}  #用于存储测试结果def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.model.eval()self.stats_dict = {"correct": 0, "wrong": 0}  # 清空上一轮结果for index, batch_data in enumerate(self.valid_data):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_ids, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():pred_results = self.model(input_ids)[0]self.write_stats(labels, pred_results)acc = self.show_stats()return accdef write_stats(self, labels, pred_results):# assert len(labels) == len(pred_results)for true_label, pred_label in zip(labels, pred_results):pred_label = torch.argmax(pred_label)# print(true_label, pred_label)if int(true_label) == int(pred_label):self.stats_dict["correct"] += 1else:self.stats_dict["wrong"] += 1returndef show_stats(self):correct = self.stats_dict["correct"]wrong = self.stats_dict["wrong"]self.logger.info("预测集合条目总量:%d" % (correct +wrong))self.logger.info("预测正确条目:%d,预测错误条目:%d" % (correct, wrong))self.logger.info("预测准确率:%f" % (correct / (correct + wrong)))self.logger.info("--------------------")return correct / (correct + wrong)

 main:

# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import torch.nn as nn
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
from peft import get_peft_model, LoraConfig, \PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig #[DEBUG, INFO, WARNING, ERROR, CRITICAL]
logging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)def main(config):#创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加载训练数据train_data = load_data(config["train_data_path"], config)#加载模型model = TorchModel#大模型微调策略tuning_tactics = config["tuning_tactics"]if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)model = get_peft_model(model, peft_config)# print(model.state_dict().keys())if tuning_tactics == "lora_tuning":# lora配置会冻结原始模型中的所有层的权重,不允许其反传梯度# 但是事实上我们希望最后一个线性层照常训练,只是bert部分被冻结,所以需要手动设置for param in model.get_submodule("model").get_submodule("classifier").parameters():param.requires_grad = True# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()#加载优化器optimizer = choose_optimizer(config, model)#加载效果测试类evaluator = Evaluator(config, model, logger)#训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):if cuda_flag:batch_data = [d.cuda() for d in batch_data]optimizer.zero_grad()input_ids, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况output = model(input_ids)[0]loss = nn.CrossEntropyLoss()(output, labels.view(-1))loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))acc = evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "%s.pth" % tuning_tactics)save_tunable_parameters(model, model_path)  #保存模型权重return accdef save_tunable_parameters(model, path):saved_params = {k: v.to("cpu")for k, v in model.named_parameters()if v.requires_grad}torch.save(saved_params, path)if __name__ == "__main__":main(Config)

pred:

import torch
import logging
from model import TorchModel
from peft import get_peft_model, LoraConfig, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfigfrom evaluate import Evaluator
from config import Configlogging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)#大模型微调策略
tuning_tactics = Config["tuning_tactics"]print("正在使用 %s"%tuning_tactics)if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])
elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)#重建模型
model = TorchModel
# print(model.state_dict().keys())
# print("====================")model = get_peft_model(model, peft_config)
# print(model.state_dict().keys())
# print("====================")state_dict = model.state_dict()#将微调部分权重加载
if tuning_tactics == "lora_tuning":loaded_weight = torch.load('output/lora_tuning.pth')
elif tuning_tactics == "p_tuning":loaded_weight = torch.load('output/p_tuning.pth')
elif tuning_tactics == "prompt_tuning":loaded_weight = torch.load('output/prompt_tuning.pth')
elif tuning_tactics == "prefix_tuning":loaded_weight = torch.load('output/prefix_tuning.pth')print(loaded_weight.keys())
state_dict.update(loaded_weight)#权重更新后重新加载到模型
model.load_state_dict(state_dict)#进行一次测试
model = model.cuda()
evaluator = Evaluator(Config, model, logger)
evaluator.eval(0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/30779.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql5.7-yum安装和更改mysql数据存放路径-2020年记录

记录下官网里用yum rpm源安装mysql, 1 官网下载rpm https://dev.mysql.com/downloads/repo/yum/ https://dev.mysql.com/doc/refman/5.7/en/linux-installation-yum-repo.html(附官网操作手册) wget https://repo.mysql.com//mysql80-community-release…

nvm list available为空

nvm list available为空 该问题主要是因为nvm 获取不到node导致,排查网络问题外,可能就是由于nvm环境变量配置问题导致,本次我这个问题就是由于环境变量配置缺少导致的。 第一步:排查并排除了网络问题。 第二步:排查环…

mosfet的驱动设计-栅极电阻

栅极电阻在MOSFET驱动电路中具有关键作用,其阻值直接影响器件开关速度、功率损耗及电磁干扰水平。本文将从物理原理出发,推导典型栅极电阻计算公式,并详细说明各参数选取依据。 本人查阅了很多资料,不同的资料介绍的计算方法也不尽…

Unity Dots从入门到精通之 Prefab引用 转 实体引用

文章目录 前言安装 DOTS 包实体引用Authoring 前言 DOTS(面向数据的技术堆栈)是一套由 Unity 提供支持的技术,用于提供高性能游戏开发解决方案,特别适合需要处理大量数据的游戏,例如大型开放世界游戏。 本文讲解我在…

并查集模板

注意理解路径压缩 static class UnionFind {int[] fa;public UnionFind(int n) {fa new int[n];for (int i 0; i < n; i) {fa[i] i;}}public int find(int i) {if (fa[i] ! i) {fa[i] find(fa[i]);}return fa[i];}public void union(int i, int j) {int fai find(i);in…

深入了解Linux —— 调试程序

前言 我们已经学习了linux下许多的工具&#xff0c;vim、gcc、make/makefile等&#xff1b; 已经能够在linux写代码&#xff0c;并且进行编译运行&#xff0c;让程序在linux下跑起来。 但是&#xff0c;如果我们在写代码的时候遇见了错误&#xff1b;但是我们并不知道错误在哪&…

Python接口自动化之断言封装!

该框架支持两种断言方式&#xff0c;相等和包含。 先看一下断言的yaml文件编写规范&#xff1a; validate: - equals: {status_code: 200} - contains: $ddt{assert_str} 其中assert_str和之前用例一样&#xff0c;作为变量&#xff0c;放在对应的data yaml文件中 # D…

基于Rye的Django项目通过Pyinstaller用Github工作流简单打包

前言 Rye的介绍和安装 Ryehttps://rye.astral.sh/Rye 完整使用教程_安装rye-CSDN博客https://blog.csdn.net/zhenndbc/article/details/144544692 正文 项目建立 配置好环境后 新建文件夹 新建文件夹&#xff0c;进入项目 初始化 rye init下载依赖 rye syncpycharm 打…

Pycharm 取消拼写错误检查(Typo:in word xxx)

现象 Pycharm显示单词存在错误&#xff0c;下面看着有下划波浪线&#xff0c;看着很不舒服。 快捷键AltEnter&#xff0c;查看提示错误。 Typo是啥? "Typo" 这个词通常用于描述打字或排印过程中的小错误&#xff0c;尤其是拼写错误。它指的是在文本中由于打字或印刷…

K8S学习之基础十七:k8s的蓝绿部署

蓝绿部署概述 ​ 蓝绿部署中&#xff0c;一共有两套系统&#xff0c;一套是正在提供服务的系统&#xff0c;一套是准备发布的系统。两套系统都是功能完善、正在运行的系统&#xff0c;只是版本和对外服务情况不同。 ​ 开发新版本&#xff0c;要用新版本替换线上的旧版本&…

三、0-1搭建springboot+vue3前后端分离-idea新建springboot项目

一、ideal新建项目1 ideal新建项目2 至此父项目就创建好了&#xff0c;下面创建多模块&#xff1a; 填好之后点击create 不删了&#xff0c;直接改包名&#xff0c;看自己喜欢 修改包名和启动类名&#xff1a; 打开ServiceApplication启动类&#xff0c;修改如下&#xff1a; …

任天堂Switch拉美游戏价涨,传Switch 2全球或提价

易采游戏网3月9日独家消息&#xff1a;近日据相关资讯显示&#xff0c;在拉丁美洲地区&#xff0c;任天堂Switch的游戏价格出现了上扬态势。这一变化引发了玩家与市场的关注&#xff0c;不过就目前而言&#xff0c;其并未波及全球游戏市场的整体定价格局。但值得注意的是&#…

10.2 继承与多态

文章目录 继承多态 继承 继承的作用是代码复用。派生类自动获得基类的除私有成员外的一切。基类描述一般特性&#xff0c;派生类提供更丰富的属性和行为。在构造派生类时&#xff0c;其基类构造函数先被调用&#xff0c;然后是派生类构造函数。在析构时顺序刚好相反。 // 基类…

如何在需求分析阶段考虑未来扩展性

在需求分析阶段考虑未来扩展性的关键在于 前瞻规划、灵活架构、标准设计。其中&#xff0c;前瞻规划尤为重要&#xff0c;因为通过全面分析业务发展趋势与技术演进&#xff0c;能够在初期设计阶段预留足够扩展空间&#xff0c;降低后期改造成本&#xff0c;为企业长期发展奠定坚…

PawSQL for MSSQL:PawSQL 支持 SQL Server 的SQL优化、SQL审核、性能巡检

0. 概述 在PawSQL的最新版本中&#xff0c;PawSQL 为 SQL Server 数据库提供了全方位的SQL优化、SQL审核、性能巡检支持&#xff0c;覆盖SQL开发、测试、运维的整个生命周期&#xff0c;助力用户充分发挥 SQL Server 数据库的性能潜力。 1. 纳管SQL Server 实例 工作空间是SQ…

【Java代码审计 | 第六篇】XSS防范

文章目录 XSS防范使用HTML转义使用Content Security Policy (CSP)输入验证使用安全的库和框架避免直接使用用户输入构建JavaScript代码 XSS防范 使用HTML转义 在输出用户输入时&#xff0c;对特殊字符进行转义&#xff0c;防止它们被解释为HTML或JavaScript代码。 例如&…

NO.26十六届蓝桥杯备战|字符数组七道练习|islower|isupper|tolower|toupper|strstr(C++)

P5733 【深基6.例1】自动修正 - 洛谷 小写字母 - 32 大写字母 大写字母 32 小写字母 #include <bits/stdc.h> using namespace std;const int N 110; char a[N] { 0 };int main() {ios::sync_with_stdio(false);cin.tie(nullptr);cin >> a;int i 0;while (a…

langChainv0.3学习笔记(初级篇)

LangChain自0.1版本发布以来&#xff0c;已经历了显著的进化&#xff0c;特别是向AI时代的适应性提升。在0.1版本中&#xff0c;LangChain主要聚焦于提供基本的链式操作和工具集成&#xff0c;帮助开发者构建简单的语言模型应用。该版本适用于处理简单任务&#xff0c;但在应对…

qt 播放pcm音频

一、获取PCM音频 ffmpeg -i input.mp3 -acodec pcm_s16le -ar 44100 -ac 2 -f s16le output.pcm -acodec pcm_s16le&#xff1a;指定16位小端PCM编码格式&#xff08;兼容性最佳&#xff09;-ar 44100&#xff1a;设置采样率为CD标准44.1kHz&#xff08;可替换为16000/8000等&a…

Windsuf 连接失败问题:[unavailable] unavailable: dial tcp...

问题描述 3月6日&#xff0c;在使用Windsuf 时&#xff0c;遇到以下网络连接错误&#xff1a; [unavailable] unavailable: dial tcp 35.223.238.178:443: connectex: A connection attempt failed because the connected party did not properly respond after a period of…