For implementation details, please refer to the DiagRL GitHub repository.

Introduction

This repository contains the 7B version of DiagRL, a model trained with retrieval-augmented reinforcement learning for medical diagnosis. It is post-trained from Qwen2.5-7B-Instruct and retains the base model's general capabilities while improving retrieval and reasoning in the medical domain.

The model can be loaded and used directly for inference as a general-purpose medical LLM:

Basic Usage


from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "QiaoyuZheng/DiagRL-7B"

# Load the weights in their stored dtype and spread them across available devices
# (device_map="auto" requires the accelerate package)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Free-text description of the patient's symptoms
prompt = "A patient suffers from fever, pain, feeling very tired, paleness in face, frequent infections, easy infections and bruising, bleeding with no clear cause, such as in the nose or gums and shortness of breath."
messages = [
    {"role": "system", "content": "You are DiagRL, created by SJTU. You are a helpful agent on diagnosis."},
    {"role": "user", "content": prompt}
]
# Render the conversation with the model's chat template and open the assistant turn
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Drop the prompt tokens so that only the newly generated tokens are decoded
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

Advanced Usage

When paired with a custom retrieval corpus, the model can interact with the environment through the LangChain framework.

Detailed usage instructions are coming soon.
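Until the retrieval environment and its LangChain integration are documented, the following is only a minimal, hypothetical sketch of retrieval-augmented prompting with a custom corpus. A naive keyword-overlap retriever (a stand-in chosen for illustration, not DiagRL's actual retrieval environment) picks the passages that share the most tokens with the patient description and prepends them to the user turn. The helper names `retrieve` and `build_messages`, the toy corpus, and the "Reference material" prompt layout are all assumptions.

```python
import re
from collections import Counter

# System prompt copied from the Basic Usage example.
SYSTEM_PROMPT = "You are DiagRL, created by SJTU. You are a helpful agent on diagnosis."


def tokenize(text: str) -> list[str]:
    """Lowercase the text and split it into alphanumeric tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


def retrieve(query: str, corpus: list[str], k: int = 2) -> list[str]:
    """Hypothetical retriever: rank passages by how many tokens they share with the query."""
    query_counts = Counter(tokenize(query))

    def overlap(passage: str) -> int:
        # Multiset intersection counts each shared token min(query, passage) times.
        return sum((query_counts & Counter(tokenize(passage))).values())

    # sorted() stays stable with reverse=True, so ties keep their corpus order.
    return sorted(corpus, key=overlap, reverse=True)[:k]


def build_messages(query: str, passages: list[str]) -> list[dict[str, str]]:
    """Prepend retrieved passages to the user turn as numbered references (assumed layout)."""
    context = "\n".join(f"[{i}] {p}" for i, p in enumerate(passages, start=1))
    user_content = f"Reference material:\n{context}\n\nPatient description:\n{query}"
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]


# Toy corpus standing in for a customized retrieval corpus.
corpus = [
    "Aplastic anemia: fatigue, pallor, frequent infections, easy bruising and bleeding.",
    "Migraine: recurrent headache, nausea, sensitivity to light.",
    "Iron deficiency anemia: fatigue, pallor, shortness of breath.",
]
query = "Fever, fatigue, pallor, frequent infections, easy bruising, bleeding gums and shortness of breath."
passages = retrieve(query, corpus, k=2)
messages = build_messages(query, passages)
print(messages[1]["content"])
```

The resulting `messages` list can be passed to `tokenizer.apply_chat_template` in the same way as in Basic Usage. Whether DiagRL expects retrieved evidence in this particular format is not stated here, so treat the layout as an assumption; a real setup would replace the keyword overlap with a proper retriever over the custom corpus.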

Benchmark

We show the main results of DiagRL and compare it to vanilla models and other baselines. For more details, please refer to our paper.

Table: Main diagnosis performance.
We compute top-1 and top-5 accuracy (Acc@1 and Acc@5) on common- and rare-disease diagnosis datasets and compare DiagRL with other representative models. “Env” means the model is allowed to use our proposed environment as assistance. All results are percentages.

| Model | MIMIC-C Acc@1 | MIMIC-C Acc@5 | PMC-Patient Acc@1 | PMC-Patient Acc@5 | MedDialog Acc@1 | MedDialog Acc@5 | MIMIC-R Acc@1 | MIMIC-R Acc@5 | RareArena Acc@1 | RareArena Acc@5 | RareBench Acc@1 | RareBench Acc@5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Qwen-2.5-14B | 8.80 | 12.40 | 17.73 | 27.66 | 17.87 | 32.34 | 7.93 | 16.71 | 6.53 | 13.23 | 18.07 | 31.38 |
| Baichuan-M1 | 11.8 | 14.48 | 26.95 | 39.84 | 26.81 | 38.85 | 8.35 | 19.25 | 10.69 | 21.63 | 26.93 | 44.79 |
| DeepSeek-R1 | 5.65 | 15.32 | 29.62 | 41.52 | 28.34 | 40.96 | 12.05 | 23.90 | 10.98 | 22.56 | 28.22 | 50.83 |
| GPT-4o | 6.43 | 9.82 | 23.51 | 36.10 | 22.59 | 36.01 | 7.65 | 15.58 | 12.83 | 23.10 | 24.25 | 43.54 |
| Qwen14B (Env) | 13.22 | 15.91 | 24.38 | 35.57 | 24.69 | 36.22 | 16.54 | 24.33 | 10.08 | 15.47 | 34.70 | 59.20 |
| GPT-4o (Env) | 15.07 | 21.25 | 28.64 | 38.38 | 25.86 | 39.41 | 20.47 | 29.05 | 11.24 | 19.32 | 40.11 | 63.28 |
| Ours (Llama8B) | 21.05 | 27.83 | 34.15 | 45.74 | 35.51 | 46.92 | 42.00 | 55.02 | 22.41 | 29.95 | 64.33 | 73.86 |
| Ours (Qwen7B) | 33.09 | 42.87 | 41.41 | 46.80 | 49.28 | 55.34 | 52.44 | 61.53 | 25.97 | 35.32 | 64.47 | 79.51 |
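Acc@k in the table is top-k accuracy: the share of cases whose ground-truth diagnosis appears among the model's first k ranked candidates, reported as a percentage. The sketch below illustrates the metric under two assumptions: each model output has already been parsed into a ranked list of diagnosis names, and a match means exact string equality after lowercasing and whitespace normalization. The paper's actual matching criterion may be more lenient (for example, synonym-aware), and the function names and toy cases are hypothetical.

```python
def normalize(name: str) -> str:
    """Lowercase and collapse whitespace so trivially different spellings compare equal."""
    return " ".join(name.lower().split())


def top_k_accuracy(ranked_predictions: list[list[str]], gold_labels: list[str], k: int) -> float:
    """Percentage of cases whose gold diagnosis appears among the first k ranked predictions."""
    if len(ranked_predictions) != len(gold_labels):
        raise ValueError("need exactly one ranked prediction list per gold label")
    if not gold_labels:
        raise ValueError("need at least one case")
    hits = 0
    for preds, gold in zip(ranked_predictions, gold_labels):
        top_k = {normalize(p) for p in preds[:k]}
        hits += normalize(gold) in top_k  # a bool adds 0 or 1
    return 100.0 * hits / len(gold_labels)


# Hypothetical toy cases: each output already parsed into a ranked candidate list.
predictions = [
    ["aplastic anemia", "leukemia", "iron deficiency anemia"],  # gold at rank 1
    ["migraine", "tension headache", "cluster headache"],       # gold at rank 3
    ["pneumonia", "bronchitis", "tuberculosis"],                # gold missing
    ["lupus", "rheumatoid arthritis", "Sjogren syndrome"],      # gold at rank 1
]
labels = ["Aplastic Anemia", "cluster headache", "asthma", "lupus"]

print(f"Acc@1 = {top_k_accuracy(predictions, labels, 1):.2f}")  # Acc@1 = 50.00
print(f"Acc@5 = {top_k_accuracy(predictions, labels, 5):.2f}")  # Acc@5 = 75.00
```

Because each toy case lists only three candidates, Acc@5 here equals accuracy over the whole list; a gold diagnosis ranked below position k never counts, which is why Acc@5 is always at least Acc@1 in the table.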
Model size: 7.62B params (Safetensors, tensor type BF16)

Model tree for QiaoyuZheng/DiagRL-7B: fine-tuned from the base model Qwen/Qwen2.5-7B; 1 quantized version is available.
