基于ChatYuan-large-v2 微调训练 医疗问答 任务

一、ChatYuan-large-v2

上篇基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练了广告生成任务,总体生成效果还可以,但上篇文章的训练是微调的模型全部的参数,本篇文章还是以 ChatYuan-large-v2 作为基础模型,继续探索仅训练解码器层参数,并在医疗问答任务上的效果如何。

下面是上篇文章的地址:

基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

二、数据集处理

数据集这里使用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

https://github.com/Toyhom/Chinese-medical-dialogue-data

数据分了 6 个科目类型:

在这里插入图片描述

数据格式如下所示:

在这里插入图片描述

其中 ask 为病症的问题描述,answer 为病症的回答。

整体加起来数据比较多,这里为了演示效果,只训练 内科、肿瘤科、儿科、外科 四个科目的数据,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证:

import json
import pandas as pddata_path = ["./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000def doHandler():train_f = open(train_json_path, "a", encoding='utf-8')val_f = open(val_json_path, "a", encoding='utf-8')for path in data_path:data = pd.read_csv(path, encoding='ANSI')train_count = 0val_count = 0for index, row in data.iterrows():ask = row["ask"]answer = row["answer"]line = {"content": ask,"summary": answer}line = json.dumps(line, ensure_ascii=False)if train_count < train_size:train_f.write(line + "\n")train_count = train_count + 1elif val_count < val_size:val_f.write(line + "\n")val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':doHandler()

处理之后可以看到两个生成的文件:

在这里插入图片描述

下面基于上面的数据格式构建 Dataset

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import jsonclass SummaryDataSet(Dataset):def __init__(self, json_path: str, tokenizer, max_length=300):self.tokenizer = tokenizerself.max_length = max_lengthself.content_data = []self.summary_data = []with open(json_path, "r", encoding='utf-8') as f:for line in f:if not line or line == "":continuejson_line = json.loads(line)content = json_line["content"]summary = json_line["summary"]self.content_data.append(content)self.summary_data.append(summary)print("data load , size:", len(self.content_data))def __len__(self):return len(self.content_data)def __getitem__(self, index):source_text = str(self.content_data[index])target_text = str(self.summary_data[index])source = self.tokenizer.batch_encode_plus([source_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)target = self.tokenizer.batch_encode_plus([target_text],max_length=self.max_length,pad_to_max_length=True,truncation=True,padding="max_length",return_tensors="pt",)source_ids = source["input_ids"].squeeze()source_mask = source["attention_mask"].squeeze()target_ids = target["input_ids"].squeeze()target_mask = target["attention_mask"].squeeze()return {"source_ids": source_ids.to(dtype=torch.long),"source_mask": source_mask.to(dtype=torch.long),"target_ids": target_ids.to(dtype=torch.long)}

三、模型训练

构建训练过程,注意这里只训练解码层参数,因此需要将其他层的参数进行冻结:

# 只训练解码层
for name, param in model.named_parameters():if "decoder" not in name:param.requires_grad = False

整体训练过程如下:

# -*- coding: utf-8 -*-
import pandas as pd
import torch
from torch.utils.data import DataLoader
import os, time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gen_dataset import SummaryDataSetdef train(epoch, tokenizer, model, device, loader, optimizer):model.train()time1 = time.time()for _, data in enumerate(loader, 0):y = data["target_ids"].to(device, dtype=torch.long)y_ids = y[:, :-1].contiguous()lm_labels = y[:, 1:].clone().detach()lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100ids = data["source_ids"].to(device, dtype=torch.long)mask = data["source_mask"].to(device, dtype=torch.long)outputs = model(input_ids=ids,attention_mask=mask,decoder_input_ids=y_ids,labels=lm_labels,)loss = outputs[0]# 每100步打印日志if _ % 100 == 0 and _ != 0:time2 = time.time()print(_, "epoch:" + str(epoch) + "-loss:" + str(loss) + ";each step's time spent:" + str(float(time2 - time1) / float(_ + 0.0001)))optimizer.zero_grad()loss.backward()optimizer.step()def validate(tokenizer, model, device, loader, max_length):model.eval()predictions = []actuals = []with torch.no_grad():for _, data in enumerate(loader, 0):y = data['target_ids'].to(device, dtype=torch.long)ids = data['source_ids'].to(device, dtype=torch.long)mask = data['source_mask'].to(device, dtype=torch.long)generated_ids = model.generate(input_ids=ids,attention_mask=mask,max_length=max_length,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]if _ % 100 == 0:print(f'Completed {_}')predictions.extend(preds)actuals.extend(target)return predictions, actualsdef T5Trainer(train_json_path, val_json_path, model_dir, batch_size, epochs, output_dir, max_length=300):tokenizer = T5Tokenizer.from_pretrained(model_dir)model = T5ForConditionalGeneration.from_pretrained(model_dir)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)# 只训练解码层for name, param in model.named_parameters():if "decoder" not in name:param.requires_grad = Falsetrain_params = {"batch_size": batch_size,"shuffle": True,"num_workers": 0,}training_set = SummaryDataSet(train_json_path, tokenizer, max_length=max_length)training_loader = DataLoader(training_set, **train_params)val_params = {"batch_size": batch_size,"shuffle": False,"num_workers": 0,}val_set = SummaryDataSet(val_json_path, tokenizer, max_length=max_length)val_loader = DataLoader(val_set, **val_params)optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)for epoch in range(epochs):train(epoch, tokenizer, model, device, training_loader, optimizer)print("保存模型")model.save_pretrained(output_dir)tokenizer.save_pretrained(output_dir)# 验证with torch.no_grad():predictions, actuals = validate(tokenizer, model, device, val_loader, max_length)# 验证结果存储final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})final_df.to_csv(os.path.join(output_dir, "predictions.csv"))if __name__ == '__main__':train_json_path = "./data/train.json"val_json_path = "./data/val.json"model_dir = "chatyuan_large_v2"batch_size = 5epochs = 5max_length = 300output_dir = "./model"T5Trainer(train_json_path,val_json_path,model_dir,batch_size,epochs,output_dir,max_length)

运行后可以看到如下日志打印,训练大概占用 21G 的显存,如果显存不够可以调低些 batch_size 的大小:

在这里插入图片描述

等待训练结束后:

在这里插入图片描述

可以在 model 下看到保存的模型:

在这里插入图片描述

这里可以看到 predictions.csv 验证集的效果:

在这里插入图片描述

四、模型测试

# -*- coding: utf-8 -*-
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch# 这里是模型下载的位置
model_dir = './model'tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)while True:text = input("请输入内容: \n ")if not text or text == "":continueif text == "q":breakencoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=300)input_ids = torch.tensor([encoded_input['input_ids']])attention_mask = torch.tensor([encoded_input['attention_mask']])generated_ids = model.generate(input_ids=input_ids,attention_mask=attention_mask,max_length=300,num_beams=2,repetition_penalty=2.5,length_penalty=1.0,early_stopping=True)reds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g ingenerated_ids]print(reds)

1. 我准备过两天去看我叔叔,顺便带些人参,但是他有高血压,您好人参高血压可以吃吗?

模型回答:

人参具有补气的药物,高血压患者是可以吃人参的,但是需要有在医生的指导下口服,不能够盲目的口服,以免引来血压升高,平时要留意低盐低脂饮食,多吃许多新鲜的蔬菜和水果,不要熬夜,确保充足的深度睡眠,防止情绪震荡太大,不要生气或者是发火。调节身心健康,保持自身卫生.避免因抵抗力下降而导致细菌入侵。

在这里插入图片描述

2. 检测报告单五项指标乙肝表面抗原结果为阳性,抗体结果为阴性,e抗原为阳性e抗体为阴性,核心抗体为阳性。要有打预防工作疫苗么?

模型回答:

乙肝表面抗原阳性,抗体结果为阴性,e抗原为阳性,e抗体为阴性,核心抗体为阳性,证明既往传染过乙肝病毒,现在已经康复了,并且产生了保护性抗体。建议你到医院检查一下肝功能和HBV-DNA,如果都正常的话,就用不着打预防工作疫苗了,平时多留意歇息,别加班,别喝啤酒,不定期复检就可以了。患者直面乙肝困恼,我们应当积极采用救治措施,不要避忌就诊,如果乙肝久治不愈,那么就要及时对换治疗方案。

在这里插入图片描述

3. 男孩,11岁,上小学了,这一周,孩子一直扣耳朵,说痒,另外,偶尔还会有黄色的液体流出,另外,伴有一些低烧,请问:孩子中耳炎流黄水怎样治疗。

模型回答:

抗生素和消炎药是目前临床治疗小儿中耳炎的方法和药物,家长可对症进行选择,也可局部用药,比如给宝宝采取使用消炎类型的滴耳剂,孩子如果是以耳痛的症状为主的话,建议家长是可以适当的给点镇痛药缓解的,要是伴随发高烧的情况,那么根据孩子的症状使用药物,严重的情况请尽快去医院进行救治,以上方法都是目前治疗中耳炎比较常用的,此外要是出现耳膜穿孔的情况是需要及时的进行手术治疗的,与此同时也要注意给孩子安排合理的饮食和作息习惯,帮助疾病尽早康复。

在这里插入图片描述

4. 我家宝宝最近不知晓为什么脸上再次出现了湿疹的情况,而且好象还蛮严重的,宝宝看上去也不舒服,请问婴儿全身长湿疹怎么办

模型回答:

婴儿全身长湿疹,建议你先带孩子去正规的医院仔细检查下,根据医生指导恰当对症救治。平时要留意护理好孩子,始终保持皮肤的清洁干燥,防止搔抓,以免严重病情。患者日常要养成良好的生活习惯,规律作息,禁烟酒,多增加日常锻炼以此来提高患者免疫力,避免病情复发。

在这里插入图片描述

5. 我对我的身体情况一直都是比较重视的,所以隔段时间就会抽时间去了一医院检查一下,近来的时候患有了小肠气,小肠气是什么样的一种疾病,小肠气到哪治?

模型回答:

你好,您的情况是检查确诊疝气。需要手术治疗。疝气是临床上比较常见的疾病。多是由于腹壁存在薄弱环节,腹腔内压力增高引起。主要症状是腹部可复性肿物。直立或者腹内压增加肿物突出,平卧可以消失或者回纳。这种情况需要积极手术治疗。开腹手术或者腹腔镜微创手术都可以。同时避免腹内压增加因素,比如咳嗽,便秘,弯腰提重物,剧烈运动等,以免发生嵌顿或者术后复发。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/98208.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centos 8 网卡connect: Network is unreachable错误解决办法

现象1、ifconfig没有ens160配置 [testlocalhost ~]$ ifconfig lo: flags73<UP,LOOPBACK,RUNNING> mtu 65536 inet 127.0.0.1 netmask 255.0.0.0 inet6 ::1 prefixlen 128 scopeid 0x10<host> loop txqueuelen 1000 (Local Loopba…

基于HTML+CSS+Echarts大屏数据可视化集合共99套

基于HTMLCSSEcharts大屏数据可视化集合共99套 一、介绍二、展示1.大数据展示系统2.物流订单系统3.物流信息系统4.办税渠道监控平台5.车辆综合管控平台 三、其他系统实现四、获取源码 一、介绍 基于HTML/CSS/Echarts的会议展览、业务监控、风险预警、数据分析展示等多种展示需求…

Python绘制爱心代码(七夕限定版)

写在前面&#xff1a; 又到了一年一度的七夕节啦&#xff01;你还在发愁送女朋友什么礼物&#xff0c;不知道怎样表达你满满的爱意吗&#xff1f;别担心&#xff0c;我来帮你&#xff01;今天&#xff0c;我将教你使用Python绘制一个跳动的爱心&#xff0c;用创意和幽默为这个…

【面试专题】Java核心基础篇②

&#x1f4c3;个人主页&#xff1a;个人主页 &#x1f525;系列专栏&#xff1a;Java面试专题 目录 1.接口和抽象类有什么区别&#xff1f; 2.两个对象的 hashCode() 相同&#xff0c;则 equals()也一定为 true&#xff0c;对吗&#xff1f; 3.说一说hashCode()和equals()的…

机器学习基础之《分类算法(3)—模型选择与调优》

作用是如何选择出最好的K值 一、什么是交叉验证&#xff08;cross validation&#xff09; 1、定义 交叉验证&#xff1a;将拿到的训练数据&#xff0c;分为训练和验证集。以下图为例&#xff1a;将数据分成5份&#xff0c;其中一份作为验证集。然后经过5次(组)的测试&#x…

【简单认识Docker基本管理】

文章目录 一、Docker概述1、定义2.容器化流行的原因3.Docker和虚拟机的区别4.Docker核心概念 二、安装docker三、镜像管理1.搜索镜像2.下载&#xff08;拉取&#xff09;镜像3.查看已下载镜像4.查看镜像详细信息5.修改镜像标签6.删除镜像7.导出镜像文件和拉取本地镜像文件8.上传…

cloud 问题

eureka 自我保护机制 eureka 是定时去拿 统计最近15分钟的eureka实例正常的心跳占比&#xff0c;如果低于85%&#xff0c;那么就会触发自我保护机制。 触发了自我保护机制&#xff0c;eureka 暂时会把失效的服务保护起来&#xff0c;不让其过期&#xff0c;但是这些服务也不是…

用讯飞星火认知大模型构建一个带界面的聊天机器人

在本篇博客中&#xff0c;我们将使用 最新的星火认知大模型和 Tkinter 构建一个简单的聊天机器人界面。通过这个界面&#xff0c;用户可以与聊天机器人进行交互&#xff0c;并查看机器人的回答。 C:\pythoncode\blog\static_16919367030710594_SparkApi_Python\testui.py 准备…

Flutter:简单搞一个内容高亮

内容高亮并不陌生&#xff0c;特别是在搜索内容页面&#xff0c;可以说四处可见&#xff0c;就拿掘金这个应用而言&#xff0c;针对某一个关键字&#xff0c;我们搜索之后&#xff0c;与关键字相同的内容&#xff0c;则会高亮展示&#xff0c;如下图所示&#xff1a; 如上的效果…

安装Vue_dev_tools

Vue控制台出现Download the Vue Devtools extension for a better development experience: 下载Vue_dev_tools,这里给出网盘链接&#xff0c;有Vue2和Vue3的&#xff0c;dev_tools 以Google浏览器为例 点击设置&#xff08;就是那三个点&#xff09;->扩展程序->管理扩…

pycharm上传项目到github,版本管理

前提&#xff1a;下载git 设置Git路径 登录Github 此时自动打开浏览器&#xff0c;并打开连接页面&#xff0c;点击 Authorize GitHub。登录&#xff1a; 创建本地仓库 提交到Github 填写初始提交相关信息 origin&#xff0c;它们只是远程服务器的一个别名&#xff0c;否则你就…

ATTCK覆盖度97.1%!360终端安全管理系统获赛可达认证

近日&#xff0c;国际知名第三方网络安全检测服务机构——赛可达实验室&#xff08;SKD Labs&#xff09;发布最新测试报告&#xff0c;360终端安全管理系统以ATT&CK V12框架攻击技术覆盖面377个、覆盖度97.1%&#xff0c;勒索病毒、挖矿病毒检出率100%&#xff0c;误报率0…

MFC140.dll缺失的修复方法,安装MFC140.dll文件

大家好&#xff0c;今天我要和大家分享的是如何正确安装和使用MFC140.dll。MFC140.dll是一种常见的动态链接库文件&#xff0c;它是Microsoft Foundation Classes(MFC)的一部分&#xff0c;被广泛应用于Windows操作系统中的各种应用程序中。在本文中&#xff0c;我们将详细介绍…

【C++入门到精通】C++入门 —— 容器适配器、stack和queue(STL)

阅读导航 前言stack1. stack概念2. stack特点3. stack使用 queue1. queue概念2. queue特点3. queue使用 容器适配器1. 什么是适配器2. STL标准库中stack和queue的底层结构3. STL标准库中对于stack和queue的模拟实现⭕stack的模拟实现⭕stack的模拟实现 总结温馨提示 前言 文章…

微信小程序:模板使用

目录 模板的优点&#xff1a; 一、静态模板创建 二、静态模板使用 1.*.wxml引入模板 2.模板使用 3.*.wxss引入模板的样式 三、动态模板创建 四、动态模板使用 1.*.wxml引入模板 2.模板使用 3.*.js定义动态数据 五、结果展示 总结 模板的优点&#xff1a; 有利于保持网…

记录--webpack和vite原理

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 前言 每次用vite创建项目秒建好&#xff0c;前几天用vue-cli创建了一个项目&#xff0c;足足等了我一分钟&#xff0c;那为什么用 vite 比 webpack 要快呢&#xff0c;这篇文章带你梳理清楚它们的原理…

【算法——双指针】LeetCode 11 盛最多水的容器

题目描述&#xff1a; 解题思路&#xff1a; 如图所示&#xff1a; 1、我们考虑相距最远的两个柱子所能容纳水的面积。宽度是两根柱子之间的距离8&#xff1b;高度取决于两根柱子之间较短的那个&#xff0c;即左边柱子的高度3。水的面积就是3824。 2、如果选择固定一根柱子&…

在线HmacMd5加密工具--在线获取哈希值又称摘要

具体请前往&#xff1a;在线计算HMacMd5工具

Flink-----Yarn应用模式作业提交流程

Yarn应用模式作业提交流程 在Yarn当中又分为Session&#xff0c;PerJob&#xff0c;Application&#xff0c;建议和推荐使用独立集群的&#xff0c;其中就包含PerJob 和Application&#xff0c;但是1.17版本的Flink已将PerJob标记为过时&#xff0c;并且Application可以解决Pe…

地球IT

地球是我们生活的家园&#xff0c;也是人类发展的基地。地球不仅仅是一个行星&#xff0c;更是一个复杂而有机的生态系统。 地球直径约为12,742公里&#xff0c;被称为“蓝色星球”&#xff0c;因为它的表面约70%被水覆盖。海洋是地球上最大的生态系统之一&#xff0c;它们扮演…