Update README.md
README.md
CHANGED (@@ -1,3 +1,81 @@): the previous three-line front matter (`---` / `license: apache-2.0` / `---`) is replaced by the full model card below.

---
license: apache-2.0
language: en
datasets:
- gretelai/symptom_to_diagnosis
metrics:
- f1
pipeline_tag: text-classification
widget:
- text: "I have a sharp pain in my chest, difficulty breathing, and a persistent cough."
---

# Symptom-to-Condition Classifier

This repository contains the artefacts for a LightGBM classification model that predicts a likely medical condition based on a user's textual description of their symptoms.

**This model is a proof-of-concept for a portfolio project and is NOT a medical diagnostic tool.**

## Model Details

This is not a standard end-to-end transformer model. It is a classical machine learning pipeline that uses a pre-trained transformer for feature extraction.

- **Feature Extractor:** The model uses embeddings from `emilyalsentzer/Bio_ClinicalBERT`. Specifically, it generates a 768-dimensional vector for each symptom description by applying **mean pooling** to the last hidden state of the BERT model.
- **Classifier:** The actual classification is performed by a `LightGBM` (Light Gradient Boosting Machine) model trained on those embeddings (a training sketch follows this list).

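The training script itself is not part of this repository, but the following minimal sketch shows how a pipeline like the one described above could be assembled. The dataset column names `input_text` (symptoms) and `output_text` (condition) and the default `LGBMClassifier` hyperparameters are assumptions for illustration, not the exact training configuration.

```python
# Hypothetical training sketch -- not the original training script.
# Assumes the dataset exposes `input_text` / `output_text` columns and
# uses default LGBMClassifier hyperparameters.
import joblib
import torch
from datasets import load_dataset
from lightgbm import LGBMClassifier
from sklearn.preprocessing import LabelEncoder
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

def embed(texts, batch_size=32):
    """Mean-pool the last hidden state into one 768-d vector per text."""
    chunks = []
    for i in range(0, len(texts), batch_size):
        enc = tokenizer(texts[i:i + batch_size], padding=True, truncation=True,
                        max_length=256, return_tensors="pt")
        with torch.no_grad():
            out = bert_model(**enc)
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        chunks.append(pooled.cpu())
    return torch.cat(chunks).numpy()

train = load_dataset("gretelai/symptom_to_diagnosis", split="train")
X = embed(list(train["input_text"]))                   # (n_samples, 768)

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(train["output_text"])  # 22 integer class ids

lgbm_model = LGBMClassifier()
lgbm_model.fit(X, y)

# Save the artefacts referenced in the "How to Use" section below.
joblib.dump(lgbm_model, "lgbm_disease_classifier.joblib")
joblib.dump(label_encoder, "label_encoder.joblib")
```
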
## Intended Use

This model is intended for educational and demonstration purposes only. It takes a string of text describing symptoms and outputs a predicted medical condition from a predefined list of 22 classes.

### Ethical Considerations & Limitations

- **⚠️ Not for Medical Use:** This model should **NEVER** be used to diagnose, treat, or provide medical advice for real-world health issues. It is not a substitute for consultation with a qualified healthcare professional.
- **Data Bias:** The model's knowledge is limited to the `gretelai/symptom_to_diagnosis` dataset. It cannot predict any condition outside of its 22-class training data and may perform poorly on symptom descriptions that are stylistically different from the training set. (The snippet after this list shows how to inspect the full label set.)
- **Correlation, Not Causation:** The model learns statistical correlations between words and labels. It has no true understanding of biology or medicine.

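The complete 22-class label set ships with the model inside the saved `LabelEncoder`, so it can be inspected directly. The snippet below assumes the artefact filenames listed in the next section and keeps the placeholder repo ID from the usage example.

```python
import joblib
from huggingface_hub import hf_hub_download

# Download the saved label encoder and list the conditions the model can output.
# "<YourUsername>/Symptom-to-Condition-Classifier" is a placeholder -- substitute the real repo ID.
path = hf_hub_download(
    repo_id="<YourUsername>/Symptom-to-Condition-Classifier",
    filename="label_encoder.joblib",
)
label_encoder = joblib.load(path)
print(len(label_encoder.classes_))    # 22
print(sorted(label_encoder.classes_)) # full list of supported conditions
```
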
## How to Use

To use this model, you must load the feature extractor (`Bio_ClinicalBERT`), the LightGBM classifier, and the label encoder.

```python
import torch
import joblib
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
HF_REPO_ID = "<YourUsername>/Symptom-to-Condition-Classifier"  # Replace with your repo ID
LGBM_MODEL_FILENAME = "lgbm_disease_classifier.joblib"
LABEL_ENCODER_FILENAME = "label_encoder.joblib"
BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

# --- LOAD ARTIFACTS ---
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
lgbm_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LGBM_MODEL_FILENAME)
label_encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LABEL_ENCODER_FILENAME)
lgbm_model = joblib.load(lgbm_model_path)
label_encoder = joblib.load(label_encoder_path)

# --- INFERENCE PIPELINE ---
def mean_pool(model_output, attention_mask):
    # Average the token embeddings, using the attention mask to ignore padding.
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def predict_condition(text):
    # Embed the symptom description with Bio_ClinicalBERT...
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():
        model_output = bert_model(**encoded_input)
    embedding = mean_pool(model_output, encoded_input['attention_mask'])

    # ...then classify the 768-d embedding with LightGBM and decode the label.
    prediction_id = lgbm_model.predict(embedding.cpu().numpy())
    predicted_condition = label_encoder.inverse_transform(prediction_id)[0]
    return predicted_condition

# --- EXAMPLE ---
symptoms = "I have a burning sensation in my stomach that gets worse when I haven't eaten."
prediction = predict_condition(symptoms)
print(f"Predicted Condition: {prediction}")
```
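Because `predict_condition` returns a single hard label, it can also be useful to look at the model's confidence. Assuming the saved object is a scikit-learn-style `LGBMClassifier` (which the `predict`/`inverse_transform` pattern above implies), `predict_proba` gives per-class probabilities; the sketch below reuses the objects loaded in the previous block to return the top-k conditions.

```python
import numpy as np

def predict_top_k(text, k=3):
    """Return the k most likely conditions with their probabilities.

    Reuses `tokenizer`, `bert_model`, `mean_pool`, `lgbm_model`, and
    `label_encoder` from the snippet above; assumes `lgbm_model` is an
    LGBMClassifier and therefore exposes `predict_proba`.
    """
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():
        model_output = bert_model(**encoded_input)
    embedding = mean_pool(model_output, encoded_input['attention_mask'])

    probs = lgbm_model.predict_proba(embedding.cpu().numpy())[0]
    top_cols = np.argsort(probs)[::-1][:k]
    encoded_labels = lgbm_model.classes_[top_cols]   # column index -> encoded label
    conditions = label_encoder.inverse_transform(encoded_labels)
    return list(zip(conditions, probs[top_cols].tolist()))

print(predict_top_k("I have a burning sensation in my stomach that gets worse when I haven't eaten."))
```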