---
license: mit
language:
- en
tags:
- text-classification
- multi-label-classification
- intent-classification
- healthcare
- social-media
- uk
library_name: transformers
pipeline_tag: text-classification
---
## DocMap Leads Classifier (UK healthcare, social media)

Multi-head encoder classifier built on `microsoft/deberta-v3-base` to identify:

- Intent (single label)
- Symptoms (multi-label)
- Specialties (multi-label)

Trained on DocMap leads_v1 JSONL (Supabase), using simple weak labels derived from keywords/zero-shot prompts. See `label_config.json` for the canonical label spaces.
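The labelling pipeline itself is not part of this repo; as a rough illustration of the keyword half of that weak labelling, a minimal sketch (with hypothetical keyword lists chosen purely for illustration) might look like this:

```python
# Illustrative keyword-based weak labelling. The keyword lists here are
# hypothetical; the actual pipeline also used zero-shot prompts.
SYMPTOM_KEYWORDS = {
    "fever": ["fever", "feverish", "high temperature"],
    "cough": ["cough", "coughing"],
}

def weak_symptom_labels(text):
    """Return every symptom label whose keywords appear in the text."""
    lowered = text.lower()
    return [
        label
        for label, keywords in SYMPTOM_KEYWORDS.items()
        if any(kw in lowered for kw in keywords)
    ]

print(weak_symptom_labels("Had a fever and a dry cough since Monday"))
# -> ['fever', 'cough']
```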
### What’s in this repo

- `model.safetensors`: classifier heads + encoder weights
- `label_config.json`: lists of `intents`, `symptoms`, `specialties` (see the inspection snippet below)
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`, `spm.model`
- `README.md`, `.gitattributes`
- (Training checkpoints may be removed to keep the repo small.)
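Before loading the full model, the label spaces can be inspected directly from `label_config.json` (the repo id below is a placeholder, as in the quick start):

```python
import json

from huggingface_hub import hf_hub_download

repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1"  # change to your repo
cfg_path = hf_hub_download(repo_id, "label_config.json")
with open(cfg_path) as f:
    cfg = json.load(f)

# Print how many labels each head predicts over.
for key in ("intents", "symptoms", "specialties"):
    print(f"{key}: {len(cfg[key])} labels")
```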
### Intended use

- Lead identification from public social media text in a UK healthcare context.
- Outputs: an intent label plus multi-label sets of symptoms and specialties.

Not a medical device. Do not use for diagnosis; intended only for triage or marketing pre-filtering.
### Training

- Base: `microsoft/deberta-v3-base`
- Epochs: 3, batch size: 16, learning rate: 2e-5, max sequence length: 256 tokens
- Split: 10% held out for validation
- A threshold sweep on the validation set suggested `0.3` as the default for the multi-label heads (sketched below).
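The sweep script itself is not shipped with this repo; the snippet below is a minimal sketch of how such a per-head threshold sweep can be reproduced, using synthetic stand-in arrays for the validation labels and head probabilities (swap in the real ones) and scikit-learn's micro-F1.

```python
import numpy as np
from sklearn.metrics import f1_score

# Stand-in data only: replace with the validation label matrix and the
# sigmoid probabilities from one of the multi-label heads.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(500, 20))                             # binary label matrix
probs = np.clip(y_true * 0.6 + rng.random((500, 20)) * 0.5, 0.0, 1.0)   # fake scores

best_thr, best_f1 = None, -1.0
for thr in np.arange(0.1, 0.9, 0.05):
    preds = (probs >= thr).astype(int)
    f1 = f1_score(y_true, preds, average="micro", zero_division=0)
    if f1 > best_f1:
        best_thr, best_f1 = float(thr), f1

print(f"best threshold ~ {best_thr:.2f} (micro-F1 = {best_f1:.3f})")
```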
### Quick start (Python)

This model uses a lightweight custom head, so the hosted inference widget is not available. Load it with the snippet below.
```python
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1"  # change to your repo
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the label spaces (intents, symptoms, specialties)
cfg_path = hf_hub_download(repo_id, "label_config.json")
with open(cfg_path, "r") as f:
    cfg = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(repo_id)


class LeadsClassifier(nn.Module):
    """DeBERTa encoder with one single-label head and two multi-label heads."""

    def __init__(self, base_model_name, num_intents, num_symptoms, num_specialties):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.intent_head = nn.Linear(hidden, num_intents)
        self.sym_head = nn.Linear(hidden, num_symptoms)
        self.spec_head = nn.Linear(hidden, num_specialties)

    def forward(self, input_ids=None, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the [CLS] token representation
        cls = self.dropout(out.last_hidden_state[:, 0, :])
        return {
            "intent_logits": self.intent_head(cls),
            "sym_logits": self.sym_head(cls),
            "spec_logits": self.spec_head(cls),
        }


model = LeadsClassifier(
    base_model_name="microsoft/deberta-v3-base",
    num_intents=len(cfg["intents"]),
    num_symptoms=len(cfg["symptoms"]),
    num_specialties=len(cfg["specialties"]),
).to(device)

# Load the fine-tuned weights (safetensors format, so use safetensors' loader rather than torch.load)
sd_path = hf_hub_download(repo_id, "model.safetensors")
model.load_state_dict(load_file(sd_path))
model.eval()


def predict(texts, thr=0.3):
    batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**batch)
    intent = torch.softmax(out["intent_logits"], dim=-1).argmax(dim=-1).tolist()
    sym_prob = torch.sigmoid(out["sym_logits"])
    spec_prob = torch.sigmoid(out["spec_logits"])
    intents = [cfg["intents"][i] for i in intent]
    symptoms = [[cfg["symptoms"][j] for j, p in enumerate(row) if p >= thr] for row in sym_prob.tolist()]
    specialties = [[cfg["specialties"][j] for j, p in enumerate(row) if p >= thr] for row in spec_prob.tolist()]
    return [{"intent": i, "symptoms": s, "specialties": sp} for i, s, sp in zip(intents, symptoms, specialties)]


print(predict(["Any advice on fever? Based in Glasgow, started 3 days ago."], thr=0.3))
```
### Inference thresholds

- Default multi-label threshold: **0.3** (from the validation sweep).
- Tune per use case: 0.5 is stricter, 0.2 is more sensitive (see the example below).
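For example, reusing the `predict` helper and example text from the quick-start snippet above:

```python
# Assumes `predict` from the quick-start snippet is already defined.
text = ["Any advice on fever? Based in Glasgow, started 3 days ago."]
print(predict(text, thr=0.5))  # stricter: fewer symptoms/specialties pass
print(predict(text, thr=0.2))  # more sensitive: more labels pass the cut-off
```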
### Limitations and risks

- Weakly supervised labels; expect some label noise and leakage.
- Trained on social media text; may not generalize to clinical text.
- Not for medical diagnosis or emergency advice.
### License

- MIT (inherits the base model’s MIT license).

### Citation

Please cite the base model and this repository if you use it in research or production.