Update README.md
Browse files
README.md
CHANGED
@@ -1,24 +1,124 @@
|
|
1 |
---
|
2 |
-
license:
|
3 |
-
pipeline_tag: text-classification
|
4 |
language:
|
5 |
- en
|
6 |
tags:
|
7 |
-
-
|
|
|
|
|
8 |
- healthcare
|
9 |
-
-
|
10 |
-
-
|
11 |
-
|
12 |
-
-
|
13 |
-
|
14 |
-
|
15 |
---
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
-
|
21 |
-
- symptoms (multi-label)
|
22 |
-
- specialties (multi-label)
|
23 |
|
24 |
-
|
|
|
|
1 |
---
|
2 |
+
license: mit
|
|
|
3 |
language:
|
4 |
- en
|
5 |
tags:
|
6 |
+
- text-classification
|
7 |
+
- multi-label-classification
|
8 |
+
- intent-classification
|
9 |
- healthcare
|
10 |
+
- social-media
|
11 |
+
- uk
|
12 |
+
library_name: transformers
|
13 |
+
pipeline_tag: text-classification
|
14 |
+
|
15 |
+
|
16 |
---
|
17 |
|
18 |
+
## DocMap Leads Classifier (UK healthcare, social media)
|
19 |
+
|
20 |
+
Multi-head encoder classifier built on `microsoft/deberta-v3-base` to identify:
|
21 |
+
- Intent (single label)
|
22 |
+
- Symptoms (multi-label)
|
23 |
+
- Specialties (multi-label)
|
24 |
+
|
25 |
+
Trained on DocMap leads_v1 JSONL (Supabase), using simple weak labels derived from keywords/zero-shot prompts. See `label_config.json` for the canonical label spaces.
|
26 |
+
|
27 |
+
### What’s in this repo
|
28 |
+
- `model.safetensors`: classifier heads + encoder weights
|
29 |
+
- `label_config.json`: lists of `intents`, `symptoms`, `specialties`
|
30 |
+
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`, `spm.model`
|
31 |
+
- `README.md`, `.gitattributes`
|
32 |
+
- (Checkpoints may be removed for smaller repo size)
|
33 |
+
|
34 |
+
### Intended use
|
35 |
+
- Lead identification from public social media text in a UK healthcare context.
|
36 |
+
- Outputs: intent label and multi-label sets for symptoms/specialties.
|
37 |
+
|
38 |
+
Not a medical device. Do not use for diagnosis or clinical decisions; intended only for lead triage and marketing pre-filtering.
|
39 |
+
|
40 |
+
### Training
|
41 |
+
- Base: `microsoft/deberta-v3-base`
|
42 |
+
- Epochs: 3, batch size: 16, lr: 2e-5, max_len: 256
|
43 |
+
- Split: 10% validation
|
44 |
+
- Threshold sweep on validation suggested `0.3` as default for multi-label heads.
|
45 |
+
|
46 |
+
### Quick start (Python)
|
47 |
+
This model uses lightweight custom classification heads (intent, symptoms, specialties), so the hosted inference widget is not supported. Load it with the snippet below.
|
48 |
+
|
49 |
+
```python
|
50 |
+
import json, torch
|
51 |
+
import torch.nn as nn
|
52 |
+
from transformers import AutoTokenizer, AutoModel
|
53 |
+
|
54 |
+
repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1" # change to your repo
|
55 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
56 |
+
|
57 |
+
# Load labels
|
58 |
+
from huggingface_hub import hf_hub_download
|
59 |
+
cfg_path = hf_hub_download(repo_id, "label_config.json")
|
60 |
+
with open(cfg_path, "r") as f:
|
61 |
+
cfg = json.load(f)
|
62 |
+
|
63 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
64 |
+
|
65 |
+
class LeadsClassifier(nn.Module):
|
66 |
+
def __init__(self, base_model_name, num_intents, num_symptoms, num_specialties):
|
67 |
+
super().__init__()
|
68 |
+
self.encoder = AutoModel.from_pretrained(base_model_name)
|
69 |
+
hidden = self.encoder.config.hidden_size
|
70 |
+
self.dropout = nn.Dropout(0.1)
|
71 |
+
self.intent_head = nn.Linear(hidden, num_intents)
|
72 |
+
self.sym_head = nn.Linear(hidden, num_symptoms)
|
73 |
+
self.spec_head = nn.Linear(hidden, num_specialties)
|
74 |
+
def forward(self, input_ids=None, attention_mask=None):
|
75 |
+
out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
76 |
+
cls = self.dropout(out.last_hidden_state[:, 0, :])
|
77 |
+
return {
|
78 |
+
"intent_logits": self.intent_head(cls),
|
79 |
+
"sym_logits": self.sym_head(cls),
|
80 |
+
"spec_logits": self.spec_head(cls),
|
81 |
+
}
|
82 |
+
|
83 |
+
model = LeadsClassifier(
|
84 |
+
base_model_name="microsoft/deberta-v3-base",
|
85 |
+
num_intents=len(cfg["intents"]),
|
86 |
+
num_symptoms=len(cfg["symptoms"]),
|
87 |
+
num_specialties=len(cfg["specialties"]),
|
88 |
+
).to(device)
|
89 |
+
|
90 |
+
# Load weights
|
91 |
+
sd_path = hf_hub_download(repo_id, "model.safetensors")
|
92 |
+
from safetensors.torch import load_file  # torch.load cannot read the safetensors format
model.load_state_dict(load_file(sd_path, device=device))
|
93 |
+
model.eval()
|
94 |
+
|
95 |
+
def predict(texts, thr=0.3):
|
96 |
+
batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
|
97 |
+
with torch.no_grad():
|
98 |
+
out = model(**batch)
|
99 |
+
intent = torch.softmax(out["intent_logits"], dim=-1).argmax(dim=-1).tolist()
|
100 |
+
sym_prob = torch.sigmoid(out["sym_logits"])
|
101 |
+
spec_prob = torch.sigmoid(out["spec_logits"])
|
102 |
+
intents = [cfg["intents"][i] for i in intent]
|
103 |
+
symptoms = [[cfg["symptoms"][j] for j, p in enumerate(row) if p >= thr] for row in sym_prob.tolist()]
|
104 |
+
specialties = [[cfg["specialties"][j] for j, p in enumerate(row) if p >= thr] for row in spec_prob.tolist()]
|
105 |
+
return [{"intent": i, "symptoms": s, "specialties": sp} for i, s, sp in zip(intents, symptoms, specialties)]
|
106 |
+
|
107 |
+
print(predict(["Any advice on fever? Based in Glasgow, started 3 days ago."], thr=0.3))
|
108 |
+
```
|
109 |
+
|
110 |
+
### Inference thresholds
|
111 |
+
- Default multi-label threshold: **0.3** (from validation sweep).
|
112 |
+
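- Tune per use case: a higher threshold such as 0.5 is stricter (fewer labels, favouring precision); a lower one such as 0.2 is more sensitive (more labels, favouring recall).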
|
113 |
+
|
114 |
+
|
115 |
+
### Limitations and risks
|
116 |
+
- Weakly supervised labels; potential label noise and leakage.
|
117 |
+
- Social media domain; may not generalize to clinical text.
|
118 |
+
- Not for medical diagnosis or emergency advice.
|
119 |
|
120 |
+
### License
|
121 |
+
- MIT. The base model, `microsoft/deberta-v3-base`, is also released under the MIT license.
|
|
|
|
|
122 |
|
123 |
+
### Citation
|
124 |
+
Please cite the base model and this repository if you use it in research or production.
|