构建医疗Mini DeepSeek R1:用强化学习训练

构建 医疗迷你 DeepSeek R1:用强化学习训练

在当今快速发展的技术时代,大语言模型(LLMs)与医疗的结合带来了无限的机遇和独特的挑战。本文探索如何利用 Group Relative Policy Optimization(GRPO)——由 DeepSeek 团队最近引入的有前景的强化学习技术,来调整阿里巴巴的 Qwen-3B 模型,使其能够进行医疗推理。

为什么这很重要?

  • 患者安全至上:医疗 AI 中的“幻觉”现象可能会带来危险。
  • 领域专业化:通用 LLM 在临床推理方面存在困难。
  • 效率: 3B 参数模型可以在消费级 GPU 上运行。

像 O3 和 DeepSeek R1 这样的推理模型在许多具有挑战性的基准测试中表现出前所未有的改进。它们改变了从监督微调到实际强化学习(RL)的趋势。许多深度学习领域的突破都来自于 RL,例如 AlphaGo,因为模型能够通过与不同真实场景的交互来学习,而这些场景在监督微调中很难提供示例。
在这里插入图片描述

DeepSeek R1 在关键基准测试中的表现

DeepSeek 实现了一个用于 LLM 微调的实用 GRPO 框架。

该算法的直觉是,它使所有导致正确或错误答案的选择更有可能或不太可能。这些选择可以是标记集或推理步骤。

正如下面的图所示:目标是激励模型在正确的 <reasoning><answer> 块中生成响应,以及一个可以轻松验证的最终正确答案(例如数学问题)。
在这里插入图片描述

案例实践

本文中使用的代码可以在 Colab 笔记本中轻松运行,使用 T4 免费套餐即可。

安装 Unsloth 和 TRL

开源技术已经取得了长足的进步。在这个教程中,将使用两个令人惊叹的开源库:

  • Unsloth:一个可以帮助我们从 GPU 中获取尽可能多的内存并提高训练性能的库。
  • TRL:一个来自 Hugging Face 的开源库,将实现 GRPO。

还将使用 Qlora 技术,以更节省内存的方式微调模型

!pip install unsloth vllm  # 内存高效的训练和推理
!pip install trl@git+https://github.com/huggingface/trl  # GRPO 实现from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

下载并初始化模型
并利用 50% 的 GPU 容量,结合 vLLM 推理,使用 Qlora 加速 GRPO 训练。

 from unsloth import is_bfloat16_supported
import torchmax_seq_length = 2048  # 可以增加以支持更长的推理路径
lora_rank = 64  # 较大的秩 = 更智能,但更慢model, tokenizer = FastLanguageModel.from_pretrained(model_name="Qwen/Qwen2.5-3B-Instruct",max_seq_length=max_seq_length,load_in_4bit=True,  # False 用于 16 位 LoRAfast_inference=True,  # 启用 vLLM 快速推理max_lora_rank=lora_rank,gpu_memory_utilization=0.5,  # 如果内存不足,可以减少
)model = FastLanguageModel.get_peft_model(model,r=lora_rank,  # 选择任意大于 0 的数字!建议 8、16、32、64、128target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",],  # 如果内存不足,可以移除 QKVOlora_alpha=lora_rank,use_gradient_checkpointing="unsloth",  # 启用长上下文微调random_state=3407,
)

关键选择
量化:支持在 16/24GB GPU 上进行训练(兼容 T4/A10)。
LoRA Rank 64:平衡性能与内存。
vLLM 集成:在 RL 中将生成速度提高 50%。

数据策略:

将使用 Hugging Face 的 interleave_datasets 混合三个关键数据集:

  • PubMedQA(占总数据的 70%):
    临床问答,答案为“是/否/也许”。
    为提高内存效率,筛选出少于 1024 个标记的数据。
  • GSM8K:
    数学文字问题,用于保持数值推理能力。
  • Health Benchmarks:
    50+ 个医学专业多选题,涵盖从心脏病学到疫苗接种等多个类别。

权重应反映数据集的复杂性——PubMedQA 的权重是其他数据集的 3 倍,以处理其细微差别。我们没有使用显式的权重,而是通过数据集的随机混洗来实现这一点,因为 PubMedQA 的样本数量是其他数据集的 3 倍,因此模型有 3 倍的机会接触到这些样本。

import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""def extract_xml_answer(text: str) -> str:answer = text.split("<answer>")[-1]answer = answer.split("</answer>")[0]return answer.strip()def extract_hash_answer(text: str) -> str | None:if "####" not in text:return Nonereturn text.split("####")[1].strip()# uncomment middle messages for 1-shot prompting
def get_datasets(split = "train") -> Dataset:data = load_dataset('openai/gsm8k', 'main')[split] # type: ignoredata = data.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{'role': 'user', 'content': x['question']}],'answer': extract_hash_answer(x['answer']),'db_set':'gsm8k'}) # type: ignoredata = data.remove_columns(['question'])data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # two times more than other datasetsdata_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # avoid long tracesdata_qa = data_qa.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{"role": "user","content": "Given the scientific context below:\n" + "\n".join(x['context']['contexts']) + "\n\nAnswer the following question:\n" +x['question'] + " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."},],'answer': x['final_decision'],'db_set': 'pubmedqa'}) # type: ignoredata_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])data_mc = data_mc.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{"role": "user","content": "\n\nAnswer the following question:\n" +x['Questions'] + "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."},],'answer': x['Answers'],'db_set': 'med_mc'}) # type: ignoredata_mc = data_mc.remove_columns(['Answers', 'Questions'])dataset = concatenate_datasets([data, data_qa, data_mc])return dataset

奖励工程

多奖励系统既奖励推理结构,也奖励医疗准确度(详细的奖励函数请参阅笔记本):

def correctness_reward(responses, answers):# Gives 2.0 for exact matches, 1.0 for partialreturn [2.0 if match else (1.0 if partial else 0.0)...]def format_reward(completions):# Enforces <reasoning>...</answer> structurereturn [0.5 if re.match(XML_PATTERN) else 0.0...]

奖励层级:
正确(权重 50%):与实际答案对齐。
格式化(权重 30%):XML 风格的推理路径。
中间检查(权重 20%):有效的答案类型。

就像教导一名医学实习生——既表扬诊断的准确性,也表扬正确的文档记录。

GRPO 训练配置

可以根据自己的需求进行调整和实验。

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(use_vllm = True, # use vLLM for fast inference!learning_rate = 5e-6,adam_beta1 = 0.9,adam_beta2 = 0.99,weight_decay = 0.1,warmup_ratio = 0.1,lr_scheduler_type = "cosine",optim = "adamw_8bit",logging_steps = 1,bf16 = is_bfloat16_supported(),fp16 = not is_bfloat16_supported(),per_device_train_batch_size = 1,gradient_accumulation_steps = 1, # Increase to 4 for smoother trainingnum_generations = 6, # Decrease if out of memorymax_prompt_length = 1024,max_completion_length = 1024,#num_train_epochs = 1, # Set to 1 for a full training runmax_steps = 750,save_steps = 100,max_grad_norm = 0.1,report_to = "none", # Can use Weights & Biasesoutput_dir = "outputs",
)trainer = GRPOTrainer(model = model,processing_class = tokenizer,reward_funcs = [xmlcount_reward_func,soft_format_reward_func,strict_format_reward_func,int_reward_func,correctness_reward_func,],args = training_args,train_dataset = train_dataset,eval_dataset=test_dataset,
)
trainer.train()

由于 奖励函数质量很高—— 很快就能看到奖励函数的值迅速上升

在这里插入图片描述

模型开始具备推理能力

在这里插入图片描述

让我们来看一些模型在训练过程中生成的示例答案:

Prompt:Given the scientific context below:We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that the level of protection afforded varies with the studied endpoint.Answer the following question:
Does a broad-spectrum sunscreen prevent cumulative damage from repeated exposure to sub-erythemal solar ultraviolet radiation representative of temperate latitudes? with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering. Model Response:<reasoning>
The context provided discusses a study where repeated exposures to sub-erythemal UV radiation (DUVR) can cause detrimental effects over a period of 4 weeks. It also mentions that a broad-spectrum daily-care product with an SPF 8 and UVA protection factor with a 7 and 3* rating was used. The study aims to evaluate whether this broad-spectrum sunscreen can prevent cumulative damage from such repeated exposures. The mention of repeated exposure and the presence of a broad-spectrum product strongly suggests that the sunscreen can prevent cumulative damage. However, the actual effectiveness of the sunscreen in preventing cumulative damage is not explicitly stated in the provided context. Given the experimental setup and the corrective action (broad-spectrum product) taken, it can be inferred that the answer is likely 'yes'.The sunscreen contains SPF 8, which means it provides protection to around 92% of the UVB radiation, and it has a UVA protection factor (UVA-PF) of 7, which likely provides varying degrees of protection to different wavelengths, including those of UVA. Such broad-spectrum protection would effectively prevent the cumulative damage associated with prolonged exposure to sub-erythemal UV radiation.</reasoning><answer>
yes
</answer>

https://colab.research.google.com/github/hooman650/MedQwenReasoner/blob/main/MedQwen3B_Reasoner.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/21897.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在mfc中使用自定义三维向量类和计算多个三维向量的平均值

先添加一个普通类, Vector3.h, // Vector3.h: interface for the Vector3 class. // //#if !defined(AFX_VECTOR3_H__53D34D26_95FF_4377_BD54_57F4271918A4__INCLUDED_) #define AFX_VECTOR3_H__53D34D26_95FF_4377_BD54_57F4271918A4__INCLUDED_#if _MSC_VER > 1000 #p…

DeepSeek、微信、硅基流动、纳米搜索、秘塔搜索……十种不同方法实现DeepSeek使用自由

为了让大家实现 DeepSeek 使用自由&#xff0c;今天分享 10 个畅用 DeepSeek 的平台。 一、官方满血版&#xff1a;DeepSeek官网与APP 首推&#xff0c;肯定是 DeepSeek 的官网和 APP&#xff0c;可以使用满血版 R1 和 V3 模型&#xff0c;以及联网功能。 网址&#xff1a; htt…

Solon Cloud —— 介绍

说明 前面的章节&#xff0c;我们讲解了 Solon 的开发应用&#xff0c;接下来准备讲解 Solon Cloud 的的开发。Solon Cloud 是为微服务和云原生准备的分布式开发套件。 微服务 就像 MVC 一样&#xff0c;对于微服务的理解也是有不同的。微服务是一组协调工作的小而自治的服务…

python中的异常-模块-包

文章目录 异常异常的定义异常捕获语法捕获常规异常捕获指定异常捕获多个异常捕获所有异常异常else异常finally 异常传递总结 模块概念导入自定义模块及导入main方法all变量 总结 包自定义包定义pycharm中建包的基本步骤导入方式 第三方包 异常 异常的定义 当检测到一个错误时…

跟着柳叶刀数字健康,学习如何通过病理切片预测分子分类对预后的影响|项目复现

小罗碎碎念 项目复现 今天和大家分享一个非常具有参考价值的项目,手把手带着大家复现一篇发表在柳叶刀数字健康的文章。 花了六个小时才完成的这篇推送,信息量非常大,遇到了很多报错问题,但是解决以后的感觉是非常爽的,先给大家展示一下最终的成果——在同一张切片上,通…

Android Http-server 本地 web 服务

时间&#xff1a;2025年2月16日 地点&#xff1a;深圳.前海湾 需求 我们都知道 webview 可加载 URI&#xff0c;他有自己的协议 scheme&#xff1a; content:// 标识数据由 Content Provider 管理file:// 本地文件 http:// 网络资源 特别的&#xff0c;如果你想直接…

PyTorch 源码学习:阅读经验 代码结构

分享自己在学习 PyTorch 源码时阅读过的资料。本文重点关注阅读 PyTorch 源码的经验和 PyTorch 的代码结构。因为 PyTorch 不同版本的源码实现有所不同&#xff0c;所以笔者在整理资料时尽可能按版本号升序&#xff0c;版本号见标题前[]。最新版本的源码实现还请查看 PyTorch 仓…

Flowmix/Docx 多模态文档编辑器:新增【操作留痕】功能,让文档编辑有迹可循!...

hi, 大家好, 我是徐小夕. 最近 flowmix/docx 多模态文档编辑器新增了【操作留痕】功能&#xff1a; 体验地址: https://orange.turntip.cn/docx-react 在和大家分享更新功能之前&#xff0c;我简单介绍一下flowmix/docx 的【操作留痕】功能。 操作留痕功能就像是一位忠实的助手…

[每周一更]-(第135期):AI融合本地知识库(RAG),谁才是最强王者!

文章目录 简单看下DeepSeek满血版配置RAG是什么&#xff1f;**RAG 的核心原理**RAG的局限性**RAG 技术栈**技术实现 **RAG 的应用场景****RAG vs 传统 LLM****方法 1&#xff1a;配合 LangChain 加载知识库****方法 2&#xff1a;使用 Ollama****方法 3&#xff1a;结合 Anythi…

go orm GORM

官网&#xff1a;https://gorm.io/ docs&#xff1a;https://gorm.io/docs/ 博客&#xff1a;https://www.tizi365.com/archives/6.html import ("fmt""gorm.io/driver/mysql""gorm.io/gorm" )type Product struct {gorm.ModelCode stringPr…

毕业项目推荐:基于yolov8/yolo11的100种中药材检测识别系统(python+卷积神经网络)

文章目录 概要一、整体资源介绍技术要点功能展示&#xff1a;功能1 支持单张图片识别功能2 支持遍历文件夹识别功能3 支持识别视频文件功能4 支持摄像头识别功能5 支持结果文件导出&#xff08;xls格式&#xff09;功能6 支持切换检测到的目标查看 二、数据集三、算法介绍1. YO…

基于Python CNN和词向量的句子相似性度量

毕业设计&#xff1a;基于CNN和词向量的句子相似性度量 注意&#xff1a;因为要计算WMD距离所以需要安装依赖库pyemd 开发环境 Anaconda Pycharm 项目说明 按照老师要求复现论文(论文提出了一个新概念相似元&#xff0c;通过相似元来计算两个句子的相似度‘)&#xff0c;同…

CPU安装pytorch(别点进来)

终于&#xff01; 深度学习环境配置5——windows下的torch-cpu1.2.0环境配置_requirement怎么写torch cu-CSDN博客

Django-Vue 学习-VUE

主组件中有多个Vue组件 是指在Vue.js框架中&#xff0c;主组件是一个父组件&#xff0c;它包含了多个子组件&#xff08;Vue组件&#xff09;。这种组件嵌套的方式可以用于构建复杂的前端应用程序&#xff0c;通过拆分功能和视图&#xff0c;使代码更加模块化、可复用和易于维…

MATLAB基础学习相关知识

MATLAB安装参考&#xff1a;抖音-记录美好生活 MATLAB基础知识学习参考&#xff1a;【1小时Matlab速成教程-哔哩哔哩】 https://b23.tv/CnvHtO3 第1部分&#xff1a;变量定义和基本运算 生成矩阵&#xff1a; % 生成矩阵% 直接法% ,表示行 ;表示列 a [1,2,3;4,5,6;7,8,9];%…

TypeScript - 数据类型 - 声明变量

TypeScript 是一种强类型的 JavaScript 超集&#xff0c;它引入了静态类型检查和类型注解。基础类型是 TypeScript 中最基本的类型&#xff0c;用于定义变量的类型。 一、数据类型 常用基本类型&#xff1a;boolean 、number 、string 常用&#xff0c;都是小写 1.布尔类型&…

有序任务规划的局限性

有序任务规划的局限性&#xff08;Limitation of Ordered-Task Planning&#xff09; 1. 任务前向分解&#xff08;TFD&#xff09;的限制 TFD&#xff08;Task Forward Decomposition&#xff09;是一种 基于完全有序方法&#xff08;Totally Ordered Methods&#xff09;的任…

MATLAB学习之旅:数据插值与曲线拟合

在MATLAB的奇妙世界里,我们已经走过了一段又一段的学习旅程。从基础的语法和数据处理,到如今,我们即将踏入数据插值与曲线拟合这片充满魅力的领域。这个领域就像是魔法中的艺术创作,能够让我们根据现有的数据点,构建出更加丰富的曲线和曲面,从而更好地理解和描述数据背后…

ASP.NET Core 下载文件

本文使用 ASP .NET Core&#xff0c;适用于 .NET Core 3.1、.NET 5、.NET 6和.NET 8。 另请参阅&#xff1a; 如何在将文件发送到浏览器后自动删除该文件。 如何将文件从浏览器上传到服务器。 如何在 ASP.NET Core 应用程序中从 URL/URI 下载文件。 如果使用.NET Framework&am…

Part 3 第十二章 单元测试 Unit Testing

概述 第十二章围绕单元测试展开&#xff0c;阐述了单元测试的实践与重要性&#xff0c;通过对比其他测试类型&#xff0c;突出其特点&#xff0c;还介绍了单元测试的最佳实践、避免的反模式以及与测试替身相关的内容&#xff0c;为编写高质量单元测试提供指导。 章节概要 1…