---
license: mit
language:
- en
tags:
- text-classification
- multi-label-classification
- intent-classification
- healthcare
- social-media
- uk
library_name: transformers
pipeline_tag: text-classification
---
## DocMap Leads Classifier (UK healthcare, social media)
Multi-head encoder classifier built on `microsoft/deberta-v3-base` that identifies:
- Intent (single label)
- Symptoms (multi-label)
- Specialties (multi-label)

Trained on DocMap leads_v1 JSONL (Supabase), using simple weak labels derived from keyword rules and zero-shot prompts. See `label_config.json` for the canonical label spaces.
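The card does not publish its weak-labeling rules. As a rough illustration of what keyword-derived weak labels can look like, the sketch below assigns an intent by the first matching pattern and builds multi-hot symptom and specialty vectors from keyword hits. Every label name, keyword, and the `weak_label` function are hypothetical placeholders, not the rules actually used for leads_v1, and the zero-shot part is not shown.

```python
import re

# Hypothetical label spaces and keyword rules (NOT the actual leads_v1 rules).
SYMPTOMS = ["fever", "cough", "back_pain"]
SPECIALTIES = ["general_practice", "orthopaedics"]
SYMPTOM_KEYWORDS = {
    "fever": ["fever", "temperature", "feverish"],
    "cough": ["cough", "coughing"],
    "back_pain": ["back pain", "sore back"],
}
SPECIALTY_KEYWORDS = {
    "general_practice": ["gp", "doctor"],
    "orthopaedics": ["physio", "spine", "back pain"],
}
INTENT_PATTERNS = [  # checked in order; the first match wins
    ("seeking_advice", r"\b(any advice|recommend|should i)\b"),
    ("booking", r"\b(book|appointment)\b"),
]
DEFAULT_INTENT = "other"


def _hits(text, keywords):
    """True if any keyword occurs in text as a whole word or phrase."""
    return any(re.search(r"\b" + re.escape(k) + r"\b", text) for k in keywords)


def weak_label(text):
    """Return (intent, symptom multi-hot list, specialty multi-hot list) for one post."""
    t = text.lower()
    intent = next((name for name, pat in INTENT_PATTERNS if re.search(pat, t)), DEFAULT_INTENT)
    sym = [int(_hits(t, SYMPTOM_KEYWORDS[s])) for s in SYMPTOMS]
    spec = [int(_hits(t, SPECIALTY_KEYWORDS[s])) for s in SPECIALTIES]
    return intent, sym, spec


print(weak_label("Any advice on fever? Based in Glasgow, started 3 days ago."))
# → ('seeking_advice', [1, 0, 0], [0, 0])
```

Rules like these are cheap but noisy (negations such as "no fever" still fire), which is one source of the label noise mentioned under limitations.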
### What’s in this repo
- `model.safetensors`: classifier heads + encoder weights
- `label_config.json`: lists of `intents`, `symptoms`, `specialties`
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`, `spm.model`
- `README.md`, `.gitattributes`
- Intermediate training checkpoints may have been removed to keep the repo small.
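The quick-start code sizes each head from the list lengths in `label_config.json`, so a config that does not match the trained heads only shows up as a size-mismatch error in `load_state_dict`. A small pre-check can catch malformed configs earlier. The sketch below assumes only what the card states (three lists named `intents`, `symptoms`, `specialties`); `check_label_config` and the example label names are hypothetical.

```python
import json

REQUIRED_KEYS = ("intents", "symptoms", "specialties")


def check_label_config(cfg):
    """Validate a parsed label_config.json and return the size of each label space."""
    sizes = {}
    for key in REQUIRED_KEYS:
        labels = cfg.get(key)
        if not isinstance(labels, list) or not labels:
            raise ValueError(f"{key!r} must be a non-empty list")
        if not all(isinstance(x, str) for x in labels):
            raise ValueError(f"{key!r} must contain only strings")
        if len(set(labels)) != len(labels):
            raise ValueError(f"{key!r} contains duplicate labels")
        sizes[key] = len(labels)
    return sizes


# Hypothetical config; the real label names live in the repo's label_config.json.
example = json.loads(
    '{"intents": ["seeking_advice", "other"], "symptoms": ["fever"],'
    ' "specialties": ["general_practice", "orthopaedics"]}'
)
print(check_label_config(example))  # {'intents': 2, 'symptoms': 1, 'specialties': 2}
```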
### Intended use
- Lead identification from public social media text in a UK healthcare context.
- Outputs: one intent label, plus multi-label sets of symptoms and specialties.

Not a medical device. Do not use it for diagnosis; use it only for triage or marketing pre-filtering.
### Training
- Base: `microsoft/deberta-v3-base`
- Epochs: 3, batch size: 16, lr: 2e-5, max_len: 256
- Split: 10% validation
- A threshold sweep on the validation split suggested `0.3` as the default decision threshold for the multi-label heads.
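The card does not say how the three heads were trained jointly. A common objective for this layout, shown here only as an assumption, adds a softmax cross-entropy loss on the single-label intent head to sigmoid binary cross-entropy losses on the two multi-label heads. It consumes the dict of logits returned by `LeadsClassifier.forward` in the quick-start code; the name `multi_head_loss` and the equal default weights are illustrative choices.

```python
import torch
import torch.nn.functional as F


def multi_head_loss(outputs, intent_labels, sym_targets, spec_targets, weights=(1.0, 1.0, 1.0)):
    """Weighted sum of per-head losses (assumed objective, not confirmed by the card).

    outputs: dict with "intent_logits" (B, n_intents), "sym_logits" (B, n_symptoms),
             "spec_logits" (B, n_specialties)
    intent_labels: LongTensor (B,) of intent class indices
    sym_targets / spec_targets: float multi-hot tensors with the same shapes as their logits
    """
    w_int, w_sym, w_spec = weights
    intent_loss = F.cross_entropy(outputs["intent_logits"], intent_labels)
    sym_loss = F.binary_cross_entropy_with_logits(outputs["sym_logits"], sym_targets)
    spec_loss = F.binary_cross_entropy_with_logits(outputs["spec_logits"], spec_targets)
    return w_int * intent_loss + w_sym * sym_loss + w_spec * spec_loss


# Toy check with random logits: batch of 4, 3 intents, 5 symptoms, 2 specialties.
torch.manual_seed(0)
outputs = {
    "intent_logits": torch.randn(4, 3),
    "sym_logits": torch.randn(4, 5),
    "spec_logits": torch.randn(4, 2),
}
loss = multi_head_loss(
    outputs,
    intent_labels=torch.tensor([0, 2, 1, 0]),
    sym_targets=torch.randint(0, 2, (4, 5)).float(),
    spec_targets=torch.randint(0, 2, (4, 2)).float(),
)
print(loss.item())
```

Using `binary_cross_entropy_with_logits` rather than a separate sigmoid plus `binary_cross_entropy` is the numerically stable choice, and it matches how the quick-start `predict` applies `torch.sigmoid` only at inference time.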
### Quick start (Python)
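This model uses lightweight custom classification heads, so the hosted Hugging Face inference widget does not support it. Load it with the snippet below, which needs `torch`, `transformers`, `huggingface_hub`, and `safetensors`.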
```python
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1"  # change to your repo
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the label spaces (intents, symptoms, specialties)
cfg_path = hf_hub_download(repo_id, "label_config.json")
with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(repo_id)


class LeadsClassifier(nn.Module):
    """DeBERTa-v3 encoder with three linear heads on the first-token ([CLS]) embedding."""

    def __init__(self, base_model_name, num_intents, num_symptoms, num_specialties):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.intent_head = nn.Linear(hidden, num_intents)    # single-label
        self.sym_head = nn.Linear(hidden, num_symptoms)      # multi-label
        self.spec_head = nn.Linear(hidden, num_specialties)  # multi-label

    def forward(self, input_ids=None, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = self.dropout(out.last_hidden_state[:, 0, :])
        return {
            "intent_logits": self.intent_head(cls),
            "sym_logits": self.sym_head(cls),
            "spec_logits": self.spec_head(cls),
        }


model = LeadsClassifier(
    base_model_name="microsoft/deberta-v3-base",
    num_intents=len(cfg["intents"]),
    num_symptoms=len(cfg["symptoms"]),
    num_specialties=len(cfg["specialties"]),
).to(device)

# Load the fine-tuned weights. model.safetensors is a safetensors file,
# so read it with safetensors rather than torch.load.
sd_path = hf_hub_download(repo_id, "model.safetensors")
model.load_state_dict(load_file(sd_path, device=device))
model.eval()


def predict(texts, thr=0.3):
    batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
    with torch.no_grad():
        # Pass only the inputs forward() accepts; the DeBERTa-v3 tokenizer also returns token_type_ids.
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    # argmax over logits gives the same class as argmax over softmax probabilities
    intent_ids = out["intent_logits"].argmax(dim=-1).tolist()
    sym_prob = torch.sigmoid(out["sym_logits"]).tolist()
    spec_prob = torch.sigmoid(out["spec_logits"]).tolist()
    intents = [cfg["intents"][i] for i in intent_ids]
    symptoms = [[cfg["symptoms"][j] for j, p in enumerate(row) if p >= thr] for row in sym_prob]
    specialties = [[cfg["specialties"][j] for j, p in enumerate(row) if p >= thr] for row in spec_prob]
    return [{"intent": i, "symptoms": s, "specialties": sp} for i, s, sp in zip(intents, symptoms, specialties)]


print(predict(["Any advice on fever? Based in Glasgow, started 3 days ago."], thr=0.3))
```
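`predict` tokenizes all of its inputs as one padded batch, which can exhaust memory on a large scrape of posts. A minimal chunking wrapper keeps each call bounded. It assumes only that `predict` behaves as defined in the quick start (a list of texts in, one dict per text out, in order); `predict_in_batches` and the batch size of 32 are illustrative choices, not part of this repo.

```python
from typing import Callable, Dict, List


def predict_in_batches(
    texts: List[str],
    predict_fn: Callable[..., List[Dict]],
    batch_size: int = 32,
    **predict_kwargs,
) -> List[Dict]:
    """Call predict_fn on consecutive slices of texts and concatenate the results in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    results: List[Dict] = []
    for start in range(0, len(texts), batch_size):
        results.extend(predict_fn(texts[start:start + batch_size], **predict_kwargs))
    return results


# Hypothetical stand-in for the quick-start predict(), so this sketch runs without the model.
def fake_predict(batch, thr=0.3):
    return [{"intent": "other", "symptoms": [], "specialties": [], "len": len(t)} for t in batch]


posts = [f"post {i}" for i in range(70)]
out = predict_in_batches(posts, fake_predict, batch_size=32, thr=0.3)
print(len(out))  # 70
```

With the real model loaded, the call would be `predict_in_batches(posts, predict, batch_size=32, thr=0.3)`.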
### Inference thresholds
- Default multi-label threshold: **0.3** (from validation sweep).
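- Tune per use case: 0.5 is stricter (fewer predicted labels, typically higher precision); 0.2 is more sensitive (more predicted labels, typically higher recall).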
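One way to tune the threshold for a given use case is to repeat a validation sweep: score each candidate threshold by micro-averaged F1 over one head's validation probabilities and keep the best. The card does not state which metric or grid produced `0.3`, so micro-F1, the 0.05-step grid, and the helper names `micro_f1` and `sweep_threshold` are assumptions. In practice `probs` would be `torch.sigmoid` of `sym_logits` or `spec_logits` on the validation split.

```python
import numpy as np


def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-hot arrays of shape (n_samples, n_labels)."""
    tp = np.logical_and(y_true == 1, y_pred == 1).sum()
    fp = np.logical_and(y_true == 0, y_pred == 1).sum()
    fn = np.logical_and(y_true == 1, y_pred == 0).sum()
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def sweep_threshold(probs, y_true, grid=np.arange(0.05, 0.96, 0.05)):
    """Return (best_threshold, best_f1) over the grid, using p >= thr as predict() does."""
    best_thr, best_f1 = None, -1.0
    for thr in grid:
        f1 = micro_f1(y_true, (probs >= thr).astype(int))
        if f1 > best_f1:  # strict ">" keeps the lowest threshold among ties
            best_thr, best_f1 = float(thr), f1
    return best_thr, best_f1


# Toy validation data: 3 posts x 3 labels (made-up numbers, for illustration only).
probs = np.array([[0.9, 0.35, 0.1], [0.2, 0.6, 0.22], [0.4, 0.05, 0.8]])
y_true = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 1]])
print(sweep_threshold(probs, y_true))
```

Sweeping each head separately can yield different thresholds for symptoms and specialties; the quick-start `predict` applies a single `thr` to both, so per-head values would need a small change there.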
### Limitations and risks
- Weakly supervised labels; potential label noise and leakage.
- Social media domain; may not generalize to clinical text.
- Not for medical diagnosis or emergency advice.
### License
- MIT (matching the base model’s MIT license).
### Citation
Please cite the base model and this repository if you use it in research or production.