构建 医疗迷你 DeepSeek R1:用强化学习训练
在当今快速发展的技术时代,大语言模型(LLMs)与医疗的结合带来了无限的机遇和独特的挑战。本文探索如何利用 Group Relative Policy Optimization(GRPO)——由 DeepSeek 团队最近引入的有前景的强化学习技术,来调整阿里巴巴的 Qwen-3B 模型,使其能够进行医疗推理。
为什么这很重要?
- 患者安全至上:医疗 AI 中的“幻觉”现象可能会带来危险。
- 领域专业化:通用 LLM 在临床推理方面存在困难。
- 效率: 3B 参数模型可以在消费级 GPU 上运行。
像 O3 和 DeepSeek R1 这样的推理模型在许多具有挑战性的基准测试中表现出前所未有的改进。它们改变了从监督微调到实际强化学习(RL)的趋势。许多深度学习领域的突破都来自于 RL,例如 AlphaGo,因为模型能够通过与不同真实场景的交互来学习,而这些场景在监督微调中很难提供示例。
DeepSeek R1 在关键基准测试中的表现
DeepSeek 实现了一个用于 LLM 微调的实用 GRPO 框架。
该算法的直觉是,它使所有导致正确或错误答案的选择更有可能或不太可能。这些选择可以是标记集或推理步骤。
正如下面的图所示:目标是激励模型在正确的 <reasoning>
和 <answer>
块中生成响应,以及一个可以轻松验证的最终正确答案(例如数学问题)。
案例实践
本文中使用的代码可以在 Colab 笔记本中轻松运行,使用 T4 免费套餐即可。
安装 Unsloth 和 TRL
开源技术已经取得了长足的进步。在这个教程中,将使用两个令人惊叹的开源库:
- Unsloth:一个可以帮助我们从 GPU 中获取尽可能多的内存并提高训练性能的库。
- TRL:一个来自 Hugging Face 的开源库,将实现 GRPO。
还将使用 Qlora 技术,以更节省内存的方式微调模型
!pip install unsloth vllm # 内存高效的训练和推理
!pip install trl@git+https://github.com/huggingface/trl # GRPO 实现from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)
下载并初始化模型
并利用 50% 的 GPU 容量,结合 vLLM 推理,使用 Qlora 加速 GRPO 训练。
from unsloth import is_bfloat16_supported
import torchmax_seq_length = 2048 # 可以增加以支持更长的推理路径
lora_rank = 64 # 较大的秩 = 更智能,但更慢model, tokenizer = FastLanguageModel.from_pretrained(model_name="Qwen/Qwen2.5-3B-Instruct",max_seq_length=max_seq_length,load_in_4bit=True, # False 用于 16 位 LoRAfast_inference=True, # 启用 vLLM 快速推理max_lora_rank=lora_rank,gpu_memory_utilization=0.5, # 如果内存不足,可以减少
)model = FastLanguageModel.get_peft_model(model,r=lora_rank, # 选择任意大于 0 的数字!建议 8、16、32、64、128target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",], # 如果内存不足,可以移除 QKVOlora_alpha=lora_rank,use_gradient_checkpointing="unsloth", # 启用长上下文微调random_state=3407,
)
关键选择
量化:支持在 16/24GB GPU 上进行训练(兼容 T4/A10)。
LoRA Rank 64:平衡性能与内存。
vLLM 集成:在 RL 中将生成速度提高 50%。
数据策略:
将使用 Hugging Face 的 interleave_datasets 混合三个关键数据集:
- PubMedQA(占总数据的 70%):
临床问答,答案为“是/否/也许”。
为提高内存效率,筛选出少于 1024 个标记的数据。 - GSM8K:
数学文字问题,用于保持数值推理能力。 - Health Benchmarks:
50+ 个医学专业多选题,涵盖从心脏病学到疫苗接种等多个类别。
权重应反映数据集的复杂性——PubMedQA 的权重是其他数据集的 3 倍,以处理其细微差别。我们没有使用显式的权重,而是通过数据集的随机混洗来实现这一点,因为 PubMedQA 的样本数量是其他数据集的 3 倍,因此模型有 3 倍的机会接触到这些样本。
import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""def extract_xml_answer(text: str) -> str:answer = text.split("<answer>")[-1]answer = answer.split("</answer>")[0]return answer.strip()def extract_hash_answer(text: str) -> str | None:if "####" not in text:return Nonereturn text.split("####")[1].strip()# uncomment middle messages for 1-shot prompting
def get_datasets(split = "train") -> Dataset:data = load_dataset('openai/gsm8k', 'main')[split] # type: ignoredata = data.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{'role': 'user', 'content': x['question']}],'answer': extract_hash_answer(x['answer']),'db_set':'gsm8k'}) # type: ignoredata = data.remove_columns(['question'])data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # two times more than other datasetsdata_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # avoid long tracesdata_qa = data_qa.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{"role": "user","content": "Given the scientific context below:\n" + "\n".join(x['context']['contexts']) + "\n\nAnswer the following question:\n" +x['question'] + " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."},],'answer': x['final_decision'],'db_set': 'pubmedqa'}) # type: ignoredata_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])data_mc = data_mc.map(lambda x: { # type: ignore'prompt': [{'role': 'system', 'content': SYSTEM_PROMPT},{"role": "user","content": "\n\nAnswer the following question:\n" +x['Questions'] + "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."},],'answer': x['Answers'],'db_set': 'med_mc'}) # type: ignoredata_mc = data_mc.remove_columns(['Answers', 'Questions'])dataset = concatenate_datasets([data, data_qa, data_mc])return dataset
奖励工程
多奖励系统既奖励推理结构,也奖励医疗准确度(详细的奖励函数请参阅笔记本):
def correctness_reward(responses, answers):# Gives 2.0 for exact matches, 1.0 for partialreturn [2.0 if match else (1.0 if partial else 0.0)...]def format_reward(completions):# Enforces <reasoning>...</answer> structurereturn [0.5 if re.match(XML_PATTERN) else 0.0...]
奖励层级:
正确(权重 50%):与实际答案对齐。
格式化(权重 30%):XML 风格的推理路径。
中间检查(权重 20%):有效的答案类型。
就像教导一名医学实习生——既表扬诊断的准确性,也表扬正确的文档记录。
GRPO 训练配置
可以根据自己的需求进行调整和实验。
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(use_vllm = True, # use vLLM for fast inference!learning_rate = 5e-6,adam_beta1 = 0.9,adam_beta2 = 0.99,weight_decay = 0.1,warmup_ratio = 0.1,lr_scheduler_type = "cosine",optim = "adamw_8bit",logging_steps = 1,bf16 = is_bfloat16_supported(),fp16 = not is_bfloat16_supported(),per_device_train_batch_size = 1,gradient_accumulation_steps = 1, # Increase to 4 for smoother trainingnum_generations = 6, # Decrease if out of memorymax_prompt_length = 1024,max_completion_length = 1024,#num_train_epochs = 1, # Set to 1 for a full training runmax_steps = 750,save_steps = 100,max_grad_norm = 0.1,report_to = "none", # Can use Weights & Biasesoutput_dir = "outputs",
)trainer = GRPOTrainer(model = model,processing_class = tokenizer,reward_funcs = [xmlcount_reward_func,soft_format_reward_func,strict_format_reward_func,int_reward_func,correctness_reward_func,],args = training_args,train_dataset = train_dataset,eval_dataset=test_dataset,
)
trainer.train()
由于 奖励函数质量很高—— 很快就能看到奖励函数的值迅速上升
模型开始具备推理能力
让我们来看一些模型在训练过程中生成的示例答案:
Prompt:Given the scientific context below:We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that the level of protection afforded varies with the studied endpoint.Answer the following question:
Does a broad-spectrum sunscreen prevent cumulative damage from repeated exposure to sub-erythemal solar ultraviolet radiation representative of temperate latitudes? with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering. Model Response:<reasoning>
The context provided discusses a study where repeated exposures to sub-erythemal UV radiation (DUVR) can cause detrimental effects over a period of 4 weeks. It also mentions that a broad-spectrum daily-care product with an SPF 8 and UVA protection factor with a 7 and 3* rating was used. The study aims to evaluate whether this broad-spectrum sunscreen can prevent cumulative damage from such repeated exposures. The mention of repeated exposure and the presence of a broad-spectrum product strongly suggests that the sunscreen can prevent cumulative damage. However, the actual effectiveness of the sunscreen in preventing cumulative damage is not explicitly stated in the provided context. Given the experimental setup and the corrective action (broad-spectrum product) taken, it can be inferred that the answer is likely 'yes'.The sunscreen contains SPF 8, which means it provides protection to around 92% of the UVB radiation, and it has a UVA protection factor (UVA-PF) of 7, which likely provides varying degrees of protection to different wavelengths, including those of UVA. Such broad-spectrum protection would effectively prevent the cumulative damage associated with prolonged exposure to sub-erythemal UV radiation.</reasoning><answer>
yes
</answer>
https://colab.research.google.com/github/hooman650/MedQwenReasoner/blob/main/MedQwen3B_Reasoner.ipynb