基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

一、ChatYuan-large-v2

ChatYuan-large-v2是一个开源的支持中英双语的功能型对话语言大模型,与其他 LLM 不同的是模型十分轻量化,并且在轻量化的同时效果相对还不错,仅仅通过0.7B参数量就可以实现10B模型的基础效果,正是其如此的轻量级,使其可以在普通显卡、 CPU、甚至手机上进行推理,而且 INT4 量化后的最低只需 400M

v2 版本相对于以前的 v1 版本,是使用了相同的技术方案,但在指令微调、人类反馈强化学习、思维链等方面进行了优化。

在本专栏前面文章介绍了 ChatYuan-large-v2langchain 相结合的使用,地址如下:

LangChain 本地化方案 - 使用 ChatYuan-large-v2 作为 LLM 大语言模型

本篇文章以 ChatYuan-large-v2 模型为基础 Fine-tuning 广告生成 任务。

二、数据集处理

数据集这里使用 ChatGLM 官方在 Fine-tuning 中使用到的 广告生成 数据集。

下载地址如下:

https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view

数据已 JSON 的形式存放,分为了 traindev 两种类型:

在这里插入图片描述
数据格式如下所示:

{"content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤","summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{"content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤","summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{"content":"类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖","summary":"圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"
}
{"content":"类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则","summary":"宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"
}
{"content":"类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙","summary":"踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"
}
{"content":"类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接","summary":"想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"
}

其中任务的方式为根据输入(content)生成一段广告词(summary)。

train.json 共有 114599 条记录,这里为了演示效果取前 50000 条数据进行训练、5000 条数据进行验证:

import os# 将训练集进行提取
def doHandle(json_path, train_size, val_size, out_json_path):train_count = 0val_count = 0train_f = open(os.path.join(out_json_path, "train.json"), "a", encoding='utf-8')val_f = open(os.path.join(out_json_path, "val.json"), "a", encoding='utf-8')with open(json_path, "r", encoding='utf-8') as f:for line in f:if train_count < train_size:train_f.writelines(line)train_count = train_count + 1elif val_count < val_size:val_f.writelines(line)val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':json_path = "./data/AdvertiseGen/train.json"out_json_path = "./data/"train_size = 50000val_size = 5000doHandle(json_path, train_size, val_size, out_json_path)

处理之后可以看到两个生成的文件:

在这里插入图片描述

下面基于上面的数据格式构建 Dataset

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import jsonclass SummaryDataSet(Dataset):def __init__(self, json_path: str, tokenizer, max_length=300):self.tokenizer = tokenizerself.max_length = max_lengthself.content_data = []self.summary_data = []with open(json_path, "r", encoding='utf-8') as f:for line in f:if not line or line == "":continuejson_line = json.loads(line)content = json_line["content"]summary = json_line["summary"]self.content_data.append(content)self.summary_data.append(summary)print("data load , size:", len(self.content_data))def __len__(self):return len(self.content_data)def __getitem__(self, index):source_text = str(self.content_data[index])target_text = str(self.summary_data[index])source = self.tokenizer.batch_encode_plus([source_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)target = self.tokenizer.batch_encode_plus([target_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)source_ids = source["input_ids"].squeeze()source_mask = source["attention_mask"].squeeze()target_ids = target["input_ids"].squeeze()target_mask = target["attention_mask"].squeeze()return {"source_ids": source_ids.to(dtype=torch.long),"source_mask": source_mask.to(dtype=torch.long),"target_ids": target_ids.to(dtype=torch.long),"target_ids_y": target_ids.to(dtype=torch.long),}

三、模型训练

下载 ChatYuan-large-v2 模型:

https://huggingface.co/ClueAI/ChatYuan-large-v2/tree/main

在这里插入图片描述
在这里插入图片描述

下面基于 ChatYuan-large-v2 进行训练:

import pandas as pd
import torch
from torch.utils.data import DataLoader
import os, time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gen_dataset import SummaryDataSetdef train(epoch, tokenizer, model, device, loader, optimizer):model.train()time1 = time.time()for _, data in enumerate(loader, 0):y = data["target_ids"].to(device, dtype=torch.long)y_ids = y[:, :-1].contiguous()lm_labels = y[:, 1:].clone().detach()lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100ids = data["source_ids"].to(device, dtype=torch.long)mask = data["source_mask"].to(device, dtype=torch.long)outputs = model(input_ids=ids,attention_mask=mask,decoder_input_ids=y_ids,labels=lm_labels,)loss = outputs[0]# 每100步打印日志if _ % 100 == 0 and _ != 0:time2 = time.time()print(_, "epoch:" + str(epoch) + "-loss:" + str(loss) + ";each step's time spent:" + str(float(time2 - time1) / float(_ + 0.0001)))optimizer.zero_grad()loss.backward()optimizer.step()def validate(epoch, tokenizer, model, device, loader, max_length):model.eval()predictions = []actuals = []with torch.no_grad():for _, data in enumerate(loader, 0):y = data['target_ids'].to(device, dtype=torch.long)ids = data['source_ids'].to(device, dtype=torch.long)mask = data['source_mask'].to(device, dtype=torch.long)generated_ids = model.generate(input_ids=ids,attention_mask=mask,max_length=max_length,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]if _ % 1000 == 0:print(f'Completed {_}')predictions.extend(preds)actuals.extend(target)return predictions, actualsdef T5Trainer(train_json_path, val_json_path, model_dir, batch_size, epochs, output_dir, max_length=300):tokenizer = T5Tokenizer.from_pretrained(model_dir)model = T5ForConditionalGeneration.from_pretrained(model_dir)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)train_params = {"batch_size": batch_size,"shuffle": True,"num_workers": 0,}training_set = SummaryDataSet(train_json_path, tokenizer, max_length=max_length)training_loader = DataLoader(training_set, **train_params)val_params = {"batch_size": batch_size,"shuffle": False,"num_workers": 0,}val_set = SummaryDataSet(val_json_path, tokenizer, max_length=max_length)val_loader = DataLoader(val_set, **val_params)optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)for epoch in range(epochs):train(epoch, tokenizer, model, device, training_loader, optimizer)print("保存模型")model.save_pretrained(output_dir)tokenizer.save_pretrained(output_dir)# 验证with torch.no_grad():predictions, actuals = validate(epoch, tokenizer, model, device, val_loader, max_length)# 验证结果存储final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})final_df.to_csv(os.path.join(output_dir, "predictions.csv"))if __name__ == '__main__':train_json_path = "./data/train.json"val_json_path = "./data/val.json"# 下载模型目录位置model_dir = "chatyuan_large_v2"batch_size = 5epochs = 1max_length = 300output_dir = "./model"# 开始训练T5Trainer(train_json_path,val_json_path,model_dir,batch_size,epochs,output_dir,max_length)

运行后可以看到如下日志打印,训练大概占用 33G 的显存,如果显存不够可以调低些 batch_size 的大小:

在这里插入图片描述

等待训练结束后:

在这里插入图片描述

可以在 model 下看到保存的模型:

在这里插入图片描述

这里可以先看下 predictions.csv 验证集的效果:

在这里插入图片描述

可以看到模型生成的结果有点不太好,这里仅对前 50000 条进行了训练,并且就训练了一个 epoch ,后面可以增加数据集大小和增加 epoch 应该能达到更好的效果,下面通过调用模型测试一下生成的文本效果。

四、模型测试

# -*- coding: utf-8 -*-
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch# 这里是模型下载的位置
model_dir = './model'tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)while True:text = input("请输入内容: \n ")if not text or text == "":continueif text == "q":breakencoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=300)input_ids = torch.tensor([encoded_input['input_ids']])attention_mask = torch.tensor([encoded_input['attention_mask']])generated_ids = model.generate(input_ids=input_ids,attention_mask=attention_mask,max_length=300,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)reds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]print(reds)

效果测试:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/78912.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VBA技术资料MF40:VBA_计数筛选状态的数据行数

【分享成果&#xff0c;随喜正能量】人唯有与喜欢的事物发展关系&#xff0c;不管是人或者是物还是事&#xff0c;包括喜欢自己外表、个性的部分&#xff0c;喜欢自己做的事&#xff0c;喜欢自己的创造&#xff0c;喜欢的风景……才给人带来对自己的认同。在与喜欢的事物互动关…

【Linux命令详解 | cd命令】Linux系统中用于更改当前工作目录的命令

文章标题 简介一&#xff0c;参数列表二&#xff0c;使用介绍1. 使用cd命令切换到特定目录2. 使用cd命令与路径相关的特殊字符3. 使用cd命令切换到包含空格的目录4. 使用cd命令切换到前一个和后一个目录5. 使用cd命令切换到用户的主目录6. 使用cd命令与绝对路径和相对路径 总结…

Vue day01

Vue 1.简介&#xff1a; ​ Vue是一套用于构建用户界面的渐进式框架。与其他大型框架不同的是&#xff0c;Vue被设计为可以自底向上逐层应用。Vue的核心库只关注视图层&#xff0c;不仅容易上手&#xff0c;还便于与第三方库或既有项目整合。另一方面&#xff0c;当与现代化的工…

【Docker】DockerFile

目录 一、镜像原理 二、如何制作镜像 1、容器转镜像 2、DockerFile 三、DockerFile关键字​编辑 四、案例&#xff1a;部署SpringBoot项目 一、镜像原理 docker镜像是由一个特殊的文件系统叠加而成的&#xff0c;他的最低端是bootfs&#xff0c;并使用宿主机的bootfs&…

Maven设置阿里云路径(防止加载过慢)

<?xml version"1.0" encoding"UTF-8"?><!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding …

【web逆向】全报文加密流量的去加密测试方案

aHR0cHM6Ly90ZGx6LmNjYi5jb20vIy9sb2dpbg 国密混合 WEB JS逆向篇 先看报文&#xff1a;请求和响应都是全加密&#xff0c;这种情况就不像参数加密可以方便全文搜索定位加密代码&#xff0c;但因为前端必须解密响应的密文&#xff0c;因此万能的方法就是搜索拦截器&#xff0c…

Golang之路---04 并发编程——WaitGroup

WaitGroup 为了保证 main goroutine 在所有的 goroutine 都执行完毕后再退出&#xff0c;前面使用了 time.Sleep 这种简单的方式。 由于写的 demo 都是比较简单的&#xff0c; sleep 个 1 秒&#xff0c;我们主观上认为是够用的。 但在实际开发中&#xff0c;开发人员是无法…

Swish - Mac 触控板手势窗口管理工具[macOS]

Swish for Mac是一款Mac触控板增强工具&#xff0c;借助直观的两指轻扫&#xff0c;捏合&#xff0c;轻击和按住手势&#xff0c;就可以从触控板上控制窗口和应用程序。 Swish for Mac又不仅仅只是一个窗口管理器&#xff0c;Swish具有28个易于使用的标题栏&#xff0c;停靠栏…

设计模式之策略模式(Strategy)

一、概述 定义一系列的算法&#xff0c;把它们一个个封装起来,并且使它们可相互替换。本模式使得算法可独立于使用它的类而变化。 二、适用性 1.许多相关的类仅仅是行为有异。“策略”提供了一种用多个行为中的一个行为来配置一个类的方法。 2.需要使用一个算法的不同变体。…

【DFS】17. 电话号码的字母组合

【DFS】17. 电话号码的字母组合 Halo&#xff0c;这里是Ppeua。平时主要更新C语言&#xff0c;C&#xff0c;数据结构算法…感兴趣就关注我bua&#xff01; TOC 题目: [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rikjhVFD-1691333891079)(C:\Us…

高成本帖子——RDK X3模组快速体验载板辅助选型

转眼&#xff0c;又是一年了&#xff0c;还记得22年6月自己第一次体验X3开发板 [首发] 多方位玩转“地平线新发布AIoT开发板——旭日X3派(Sunrise x3 Pi)” 。至今还记得拿到手的新奇与兴奋&#xff0c;还记得第一次玩BPU就遇到了大坑&#xff0c;第一次发帖就把论坛搞崩。如今…

34.利用matlab解 多变量多目标规划问题(matlab程序)

1.简述 学习目标&#xff1a;适合解 多变量多目标规划问题&#xff0c;例如 收益最大&#xff0c;风险最小 主要目标法&#xff0c;线性加权法&#xff0c;权值我们可以自己设定。 收益函数是 70*x(1)66*x(2) &#xff1b; 风险函数是 0.02*x(1)^20.01*x(2)^20.04*(x…

Tcp的粘包和半包问题及解决方案

目录 粘包&#xff1a; 半包&#xff1a; 应用进程如何解读字节流&#xff1f;如何解决粘包和半包问题&#xff1f; ①&#xff1a;固定长度 ②&#xff1a;分隔符 ③&#xff1a;固定长度字段存储内容的长度信息 粘包&#xff1a; 一次接收到多个消息&#xff0c;粘包 应…

第4章 变量、作用域与内存

引言 由于js是一门只有在声明变量后才能明确类型的语言&#xff0c;并且在任意时刻都可以改变数据类型。这也引起了一些问题 原始值与引用值 原始值就是基本数据类型&#xff0c;引言值就是复杂数据类型 变量在赋值的时候。js会判断如果是原始值&#xff0c;访问时就是按值访问…

1、sparkStreaming概述

1、sparkStreaming概述 1.1 SparkStreaming是什么 它是一个可扩展&#xff0c;高吞吐具有容错性的流式计算框架 吞吐量&#xff1a;单位时间内成功传输数据的数量 之前我们接触的spark-core和spark-sql都是处理属于离线批处理任务&#xff0c;数据一般都是在固定位置上&…

python简单知识点大全

python简单知识点大全 一、变量二、字符串三、比较运算符四、随机数4.1、随机整数4.2、随机浮点数4.3、随机数重现 五、数字类型5.1、整数5.2、浮点数5.3、复数 六、数字运算七、内置函数八、布尔类型八、逻辑运算符九、短路逻辑和运算符优先级十、分支和循环10.1、if语句10.2、…

相机传感器格式与镜头光圈参数

相机靶面大小 CCD/CMOS图像传感器尺寸&#xff08;sensor format&#xff09;1/2’‘、1/3’‘、1/4’实际是多大 1英寸——靶面尺寸为宽12.7mm*高9.6mm&#xff0c;对角线16mm。 2/3英寸——靶面尺寸为宽8.8mm*高6.6mm&#xff0c;对角线11mm。 1/2英寸——靶面尺寸为宽6.…

maven安装(windows)

环境 maven&#xff1a;Apache Maven 3.5.2 jdk环境&#xff1a;jdk 1.8.0_192 系统版本&#xff1a;win10 一、安装 apache官网下载需要的版本&#xff0c;然后解压缩&#xff0c;解压路径尽量不要有空格和中文 官网下载地址 https://maven.apache.org/download.cgihttps:…

微信小程序animation动画,微信小程序animation动画无限循环播放

需求是酱紫的&#xff1a; 页面顶部的喇叭通知&#xff0c;内容不固定&#xff0c;宽度不固定&#xff0c;就是做走马灯&#xff08;轮播&#xff09;效果&#xff0c;从左到右的走马灯&#xff08;轮播&#xff09;&#xff0c;每播放一遍暂停 1500ms &#xff5e; 2000ms 刚…

ARCGIS地理配准出现的问题

第一种。已有省级行政区矢量数据&#xff0c;在网上随便找一个相同省级行政区图片&#xff0c;利用地理配准工具给图片添加坐标信息。 依次添加省级行政区选择矢量数据、浙江省图片。 此时&#xff0c;图层默认的坐标系与第一个加载进来的省级行政区选择矢量数据的坐标系一致…