MedSLM — Medical Small Language Model (~381M Parameters)

A 381M parameter transformer language model pre-trained on curated medical text from PubMed abstracts, PMC full-text articles, and clinical guidelines.

Architecture

MedSLM uses a modern GPT-style transformer with several architectural improvements over the standard GPT-2 design:

Component            Detail
Normalization        RMSNorm (faster than LayerNorm; used in LLaMA/Mistral)
Positional Encoding  Rotary Positional Embeddings (RoPE), for better length generalization
Feed-Forward         SwiGLU activation (gated FFN; outperforms GELU)
Attention            Grouped-Query Attention (GQA), with shared KV heads for efficiency
Layers               24 transformer blocks
Attention Heads      16 query heads, 8 KV heads
Embedding Dim        1024
Context Length       1024 tokens
Vocab Size           50,257 (GPT-2 BPE tokenizer)
Parameters           381,373,440 (~381M)
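
The headline parameter count can be reproduced from the table above. The sketch below assumes untied input/output embeddings, bias-free linear layers, and a SwiGLU hidden dimension of 2752 (roughly 8/3 × 1024, rounded to a multiple of 64); none of these are stated in the card, but under these assumptions the total matches 381,373,440 exactly.

```python
# Hypothetical parameter-count breakdown for MedSLM. Sizes come from the
# table above; d_ff = 2752, untied embeddings, and no biases are assumptions.
vocab, d_model, n_layers = 50_257, 1024, 24
n_q_heads, n_kv_heads = 16, 8
head_dim = d_model // n_q_heads          # 64
d_ff = 2752                              # assumed SwiGLU hidden size

embed = vocab * d_model                  # token embedding table
lm_head = vocab * d_model                # output projection (assumed untied)

# GQA attention: full-width Q and O projections, narrower K/V
# (8 KV heads x 64 dims = 512 output features each)
attn = 2 * d_model * d_model + 2 * d_model * (n_kv_heads * head_dim)

# SwiGLU FFN uses three weight matrices: gate, up, and down projections
ffn = 3 * d_model * d_ff

norms = 2 * d_model                      # two RMSNorm weight vectors per block

per_layer = attn + ffn + norms
total = n_layers * per_layer + embed + lm_head + d_model  # + final RMSNorm

print(total)  # 381373440
```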

Training

  • Dataset: Saminx22/medical_data_for_slm (~44M tokens)
  • Sources: PubMed abstracts, PMC Open Access full-text, Clinical Guidelines
  • Tokenizer: GPT-2 BPE tokenizer (50,257 vocab)
  • Optimizer: AdamW (betas=0.9/0.95, weight_decay=0.1)
  • LR Schedule: Linear warmup (1000 steps) + Cosine decay
  • Peak LR: 0.0003
  • Precision: bfloat16
  • Effective Batch Size: 256
  • Max Steps: 20,000
  • Best Val Loss: 3.2198 (at step 19,500)
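
The learning-rate schedule above (1,000 linear warmup steps into a cosine decay over 20,000 total steps, peaking at 3e-4) can be sketched as follows. Decaying to a floor of exactly 0 is an assumption; the card does not state a minimum LR.

```python
import math

def lr_at(step, peak_lr=3e-4, warmup_steps=1000, max_steps=20_000):
    """Linear warmup to peak_lr, then cosine decay toward 0 at max_steps.
    A sketch of the schedule described above; the zero floor is assumed."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

print(lr_at(500))     # halfway through warmup: 0.00015
print(lr_at(1000))    # peak: 0.0003
print(lr_at(20_000))  # end of decay: ~0
```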

Usage

Loading the Model

import torch
import json
from safetensors.torch import load_file
from transformers import AutoTokenizer

# Load config
with open("config.json") as f:
    config_dict = json.load(f)

# Reconstruct model (requires the MedSLM and MedSLMConfig class definitions from this repo)
config = MedSLMConfig(**{k: v for k, v in config_dict.items()
                         if k in MedSLMConfig.__dataclass_fields__})
model = MedSLM(config)

# Load weights
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("tokenizer/")

Generating Text

prompt = "The patient presented with acute myocardial infarction"
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

output = model.generate(input_ids, max_new_tokens=200, temperature=0.8, top_k=50, top_p=0.9)
print(tokenizer.decode(output.squeeze().tolist()))
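
For readers unfamiliar with the sampling knobs passed to generate, the sketch below shows what one decoding step plausibly does with them, using plain Python on a toy five-token distribution. The helper name sample_step and the exact filtering order are illustrative assumptions, not MedSLM's actual implementation, which operates on the model's logits.

```python
import math
import random

def sample_step(logits, temperature=0.8, top_k=50, top_p=0.9):
    """One hypothetical decoding step: temperature-scale the logits,
    keep the top_k most likely tokens, then keep the smallest prefix
    whose cumulative probability reaches top_p, and sample from it."""
    scaled = [l / temperature for l in logits]
    # Softmax (subtract the max for numerical stability)
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token ids by probability, descending, and apply the top-k filter
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = ranked[:top_k]
    # Top-p (nucleus) filter: smallest prefix reaching cumulative mass top_p
    kept, cum = [], 0.0
    for i in ranked:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    weights = [probs[i] for i in kept]
    return random.choices(kept, weights=weights, k=1)[0]

random.seed(0)
token = sample_step([2.0, 1.0, 0.5, 0.1, -1.0], top_k=3, top_p=0.9)
print(token)  # one of the three most likely token ids (0, 1, or 2)
```

Lower temperatures sharpen the distribution toward the top token, while top_k and top_p cut off the unreliable low-probability tail before sampling.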

Resuming Training

# Recreate the optimizer with the training hyperparameters, then restore its state
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
optimizer_state = torch.load("optimizer.pt")
optimizer.load_state_dict(optimizer_state)

Files

File                  Description
model.safetensors     Model weights (safetensors format)
optimizer.pt          Optimizer state dict for resuming training
config.json           Model architecture configuration
training_config.json  Training hyperparameters and loss history
tokenizer/            GPT-2 tokenizer files
loss_curves.png       Training/validation loss plot

Intended Use

This model is intended for research purposes in medical NLP. It can be used as:

  • A foundation model for downstream medical NLP tasks (NER, classification, QA)
  • A starting point for medical instruction tuning
  • A baseline for comparing medical language model architectures

Limitations

  • Not for clinical use: This model should NOT be used for clinical decision-making
  • Small scale: ~381M parameters is relatively small; larger models will perform better
  • Limited data: Trained on ~44M tokens (production models use trillions)
  • No alignment: This is a base model without instruction tuning or RLHF
  • English only: Trained exclusively on English medical text
  • Potential biases: May reflect biases present in the medical literature

License

Apache 2.0
