rabbitfishai committed (verified)
Commit f61dfa3 · 1 Parent(s): 2b510fc

Update README.md

Files changed (1): README.md (+115 -15)
README.md CHANGED
@@ -1,24 +1,124 @@
  ---
- license: apache-2.0
- pipeline_tag: text-classification
  language:
  - en
  tags:
- - leads
  - healthcare
- - classifier
- - intent
- - multi-label
- - deberta
- base_model:
- - microsoft/deberta-large-mnli
  ---

- # DocMap Leads Classifier (DeBERTa-v3-base)

- Predicts:
- - intent (single label)
- - symptoms (multi-label)
- - specialties (multi-label)

- Load the model and heads using your service code; apply sigmoid thresholds (e.g., 0.4–0.5) for multi-label outputs.

1
  ---
2
+ license: mit
 
3
  language:
4
  - en
5
  tags:
6
+ - text-classification
7
+ - multi-label-classification
8
+ - intent-classification
9
  - healthcare
10
+ - social-media
11
+ - uk
12
+ library_name: transformers
13
+ pipeline_tag: text-classification
14
+
15
+
16
  ---
17
 
+ ## DocMap Leads Classifier (UK healthcare, social media)
+
+ Multi-head encoder classifier built on `microsoft/deberta-v3-base` that identifies:
+ - Intent (single label)
+ - Symptoms (multi-label)
+ - Specialties (multi-label)
+
+ Trained on the DocMap leads_v1 JSONL export (Supabase), using weak labels derived from keyword rules and zero-shot prompts. See `label_config.json` for the canonical label spaces.
+
+ ### What’s in this repo
+ - `model.safetensors`: encoder weights plus the three classifier heads
+ - `label_config.json`: lists of `intents`, `symptoms`, and `specialties` (a sketch of its shape is shown below)
+ - `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`, `spm.model`
+ - `README.md`, `.gitattributes`
+ - (Training checkpoints may be removed to keep the repo small)
+
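+ For reference, `label_config.json` is expected to look like the sketch below. Only the three keys are guaranteed; the label values shown are purely illustrative, and the canonical lists ship with this repo:
+
+ ```json
+ {
+   "intents": ["seeking_advice", "recommendation_request", "other"],
+   "symptoms": ["fever", "back_pain", "anxiety"],
+   "specialties": ["general_practice", "physiotherapy", "dermatology"]
+ }
+ ```
+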
+ ### Intended use
+ - Lead identification from public social media text in a UK healthcare context.
+ - Outputs: an intent label plus multi-label sets for symptoms and specialties (example record below).
+
+ Not a medical device. Do not use it for diagnosis; it is intended only for triage and marketing pre-filtering.
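+
+ For orientation, each classified post yields a record shaped like the return value of `predict()` in the Quick start section; the label values here are illustrative only:
+
+ ```python
+ # Illustrative output shape only; real labels come from label_config.json.
+ [{"intent": "seeking_advice", "symptoms": ["fever"], "specialties": ["general_practice"]}]
+ ```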
+
+ ### Training
+ - Base: `microsoft/deberta-v3-base`
+ - Epochs: 3, batch size: 16, learning rate: 2e-5, max sequence length: 256 (a joint-objective sketch follows this list)
+ - Split: 10% held out for validation
+ - A threshold sweep on the validation split suggested `0.3` as the default for the multi-label heads.
+
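+ The training script itself is not part of this repo. The snippet below is a minimal sketch of how the three heads could be trained jointly under the hyperparameters above, assuming cross-entropy for the single-label intent head and binary cross-entropy for the two multi-label heads; this is an assumed setup, not the shipped recipe.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ # Assumed joint objective: cross-entropy for the intent head,
+ # binary cross-entropy on logits for the symptom/specialty heads.
+ ce = nn.CrossEntropyLoss()
+ bce = nn.BCEWithLogitsLoss()
+
+ def joint_loss(outputs, intent_ids, sym_targets, spec_targets):
+     # `outputs` is the dict returned by LeadsClassifier.forward (defined in Quick start below).
+     loss = ce(outputs["intent_logits"], intent_ids)           # intent_ids: [B] class indices
+     loss = loss + bce(outputs["sym_logits"], sym_targets)     # sym_targets: [B, num_symptoms] floats in {0, 1}
+     loss = loss + bce(outputs["spec_logits"], spec_targets)   # spec_targets: [B, num_specialties] floats in {0, 1}
+     return loss
+
+ # AdamW at the listed learning rate is the usual choice for this setup:
+ # optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+ ```
+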
+ ### Quick start (Python)
+ This model uses a lightweight custom multi-head classifier, so it cannot be loaded via the standard HF widget or `pipeline()`. Load it with the snippet below.
+
+ ```python
+ import json
+
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from transformers import AutoTokenizer, AutoModel
+
+ repo_id = "YOUR_USER_OR_ORG/docmap-leads-classifier-v1"  # change to your repo
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the label spaces (intents, symptoms, specialties).
+ cfg_path = hf_hub_download(repo_id, "label_config.json")
+ with open(cfg_path, "r") as f:
+     cfg = json.load(f)
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+
+ class LeadsClassifier(nn.Module):
+     def __init__(self, base_model_name, num_intents, num_symptoms, num_specialties):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base_model_name)
+         hidden = self.encoder.config.hidden_size
+         self.dropout = nn.Dropout(0.1)
+         self.intent_head = nn.Linear(hidden, num_intents)
+         self.sym_head = nn.Linear(hidden, num_symptoms)
+         self.spec_head = nn.Linear(hidden, num_specialties)
+
+     def forward(self, input_ids=None, attention_mask=None):
+         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         # Classify from the [CLS] token representation.
+         cls = self.dropout(out.last_hidden_state[:, 0, :])
+         return {
+             "intent_logits": self.intent_head(cls),
+             "sym_logits": self.sym_head(cls),
+             "spec_logits": self.spec_head(cls),
+         }
+
+ model = LeadsClassifier(
+     base_model_name="microsoft/deberta-v3-base",
+     num_intents=len(cfg["intents"]),
+     num_symptoms=len(cfg["symptoms"]),
+     num_specialties=len(cfg["specialties"]),
+ ).to(device)
+
+ # Load the fine-tuned weights (safetensors format).
+ sd_path = hf_hub_download(repo_id, "model.safetensors")
+ model.load_state_dict(load_file(sd_path))
+ model.eval()
+
+ def predict(texts, thr=0.3):
+     batch = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt").to(device)
+     with torch.no_grad():
+         # Pass only the inputs that forward() above expects.
+         out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
+     intent = torch.softmax(out["intent_logits"], dim=-1).argmax(dim=-1).tolist()
+     sym_prob = torch.sigmoid(out["sym_logits"])
+     spec_prob = torch.sigmoid(out["spec_logits"])
+     intents = [cfg["intents"][i] for i in intent]
+     symptoms = [[cfg["symptoms"][j] for j, p in enumerate(row) if p >= thr] for row in sym_prob.tolist()]
+     specialties = [[cfg["specialties"][j] for j, p in enumerate(row) if p >= thr] for row in spec_prob.tolist()]
+     return [{"intent": i, "symptoms": s, "specialties": sp} for i, s, sp in zip(intents, symptoms, specialties)]
+
+ print(predict(["Any advice on fever? Based in Glasgow, started 3 days ago."], thr=0.3))
+ ```
+
+ ### Inference thresholds
+ - Default multi-label threshold: **0.3** (from the validation sweep).
+ - Tune per use case: 0.5 is stricter (higher precision), 0.2 is more sensitive (higher recall). A sketch of a simple sweep is shown after this list.
+
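+ The sweep script is not included in this repo; the following is a minimal sketch of how a per-head threshold could be re-derived on your own labelled validation split, using micro-F1 over a small grid (an assumed approach, consistent with the 0.3 default above):
+
+ ```python
+ import numpy as np
+ from sklearn.metrics import f1_score
+
+ def pick_threshold(probs, targets, grid=np.arange(0.1, 0.95, 0.05)):
+     """probs: [N, num_labels] sigmoid probabilities; targets: [N, num_labels] 0/1 indicators."""
+     best_thr, best_f1 = 0.3, -1.0
+     for thr in grid:
+         f1 = f1_score(targets, (probs >= thr).astype(int), average="micro", zero_division=0)
+         if f1 > best_f1:
+             best_thr, best_f1 = float(thr), f1
+     return best_thr, best_f1
+
+ # e.g. reuse the sigmoid probabilities computed inside predict():
+ # sym_thr, sym_f1 = pick_threshold(sym_prob.cpu().numpy(), sym_true)
+ ```
+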
+ ### Limitations and risks
+ - Weakly supervised labels; expect label noise and possible leakage from the labelling heuristics.
+ - Trained on social media text; may not generalize to clinical text.
+ - Not for medical diagnosis or emergency advice.

+ ### License
+ - MIT. The base model, `microsoft/deberta-v3-base`, is also released under MIT.
+
+ ### Citation
+ Please cite the base model and this repository if you use it in research or production.