For implementation details, please refer to the DiagRL GitHub repository.
## Introduction
This repository contains the 7B version of the DiagRL model, trained with retrieval-augmented reinforcement learning for medical diagnosis. The model is post-trained from Qwen2.5-7B-Instruct, preserving general capabilities while achieving stronger retrieval-and-reasoning abilities in the medical domain.
The model can be loaded and used directly for inference as a general-purpose medical LLM:
## Basic Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "QiaoyuZheng/DiagRL-7B"

# Load the model and tokenizer (device_map="auto" requires the accelerate package).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# A free-text patient description to be diagnosed.
prompt = "A patient suffers from fever, pain, feeling very tired, paleness in face, frequent infections, easy infections and bruising, bleeding with no clear cause, such as in the nose or gums and shortness of breath."
messages = [
    {"role": "system", "content": "You are DiagRL, created by SJTU. You are a helpful agent on diagnosis."},
    {"role": "user", "content": prompt}
]

# Render the chat messages into the model's prompt format.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate a diagnosis and strip the prompt tokens from the output.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
## Advanced Usage
When assisted by a customized retrieval corpus, the model can interact with the environment through the LangChain framework. Detailed usage is coming soon; a rough sketch is provided below.
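Until then, the snippet below shows one way such an interaction could look. It is a minimal sketch, not the environment from the paper: the toy `corpus`, the `sentence-transformers/all-MiniLM-L6-v2` embedding model, the FAISS index, and the prompt-augmentation format are all illustrative assumptions, and it reuses `prompt`, `model`, and `tokenizer` from Basic Usage (requires `langchain-community`, `langchain-huggingface`, `faiss-cpu`, and `sentence-transformers`).

```python
# Minimal retrieval-augmentation sketch (NOT the official DiagRL environment).
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Hypothetical knowledge snippets standing in for a real medical corpus.
corpus = [
    "Acute lymphoblastic leukemia: fever, fatigue, pallor, frequent infections, easy bruising, unexplained bleeding.",
    "Iron-deficiency anemia: fatigue, pale skin, shortness of breath, brittle nails.",
    "Aplastic anemia: pancytopenia leading to infections, bleeding, and fatigue.",
]

# Index the corpus with FAISS and expose it as a top-k retriever.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
retriever = FAISS.from_texts(corpus, embeddings).as_retriever(search_kwargs={"k": 2})

# Retrieve evidence for the patient description and prepend it to the prompt.
docs = retriever.invoke(prompt)
evidence = "\n".join(doc.page_content for doc in docs)
augmented_prompt = f"Retrieved evidence:\n{evidence}\n\nPatient description:\n{prompt}"
```

The `augmented_prompt` can then replace `prompt` in the `messages` list from Basic Usage.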
## Benchmark
We present the main results of DiagRL and compare it with the vanilla base models and other baselines. For more details, please refer to our paper.
**Table: Main diagnosis performance.** We report top-1 and top-5 accuracy (Acc@1 / Acc@5, in %) on common- and rare-disease diagnosis datasets, comparing DiagRL with other representative models. "Env" means the model is allowed to use our proposed environment as assistance.
Model | MIMIC-C Acc@1 | MIMIC-C Acc@5 | PMC-Patient Acc@1 | PMC-Patient Acc@5 | MedDialog Acc@1 | MedDialog Acc@5 | MIMIC-R Acc@1 | MIMIC-R Acc@5 | RareArena Acc@1 | RareArena Acc@5 | RareBench Acc@1 | RareBench Acc@5 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Qwen-2.5-14B | 8.80 | 12.40 | 17.73 | 27.66 | 17.87 | 32.34 | 7.93 | 16.71 | 6.53 | 13.23 | 18.07 | 31.38 |
Baichuan-M1 | 11.80 | 14.48 | 26.95 | 39.84 | 26.81 | 38.85 | 8.35 | 19.25 | 10.69 | 21.63 | 26.93 | 44.79 |
DeepSeek-R1 | 5.65 | 15.32 | 29.62 | 41.52 | 28.34 | 40.96 | 12.05 | 23.90 | 10.98 | 22.56 | 28.22 | 50.83 |
GPT-4o | 6.43 | 9.82 | 23.51 | 36.10 | 22.59 | 36.01 | 7.65 | 15.58 | 12.83 | 23.10 | 24.25 | 43.54 |
Qwen-2.5-14B (Env) | 13.22 | 15.91 | 24.38 | 35.57 | 24.69 | 36.22 | 16.54 | 24.33 | 10.08 | 15.47 | 34.70 | 59.20 |
GPT-4o (Env) | 15.07 | 21.25 | 28.64 | 38.38 | 25.86 | 39.41 | 20.47 | 29.05 | 11.24 | 19.32 | 40.11 | 63.28 |
Ours (Llama8B) | 21.05 | 27.83 | 34.15 | 45.74 | 35.51 | 46.92 | 42.00 | 55.02 | 22.41 | 29.95 | 64.33 | 73.86 |
Ours (Qwen7B) | 33.09 | 42.87 | 41.41 | 46.80 | 49.28 | 55.34 | 52.44 | 61.53 | 25.97 | 35.32 | 64.47 | 79.51 |
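Here Acc@k is standard top-k accuracy: a case counts as correct when the ground-truth diagnosis appears among the model's k highest-ranked candidates. The sketch below illustrates the metric on toy data; it assumes exact-string matching over ranked candidate lists and is not the evaluation harness used in the paper.

```python
def top_k_accuracy(predictions, labels, k=5):
    """Percentage of cases where the true diagnosis is among the top-k candidates.

    predictions: list of ranked candidate-diagnosis lists, best first.
    labels: list of ground-truth diagnoses.
    """
    hits = sum(label in preds[:k] for preds, label in zip(predictions, labels))
    return 100.0 * hits / len(labels)

# Toy example: Acc@1 = 50.0, Acc@5 = 100.0 over these two cases.
preds = [["ALL", "AML", "aplastic anemia"], ["anemia", "ALL"]]
gold = ["ALL", "ALL"]
print(top_k_accuracy(preds, gold, k=1), top_k_accuracy(preds, gold, k=5))
```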