Update README.md
README.md
CHANGED
|
@@ -1,3 +1,81 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language: en
|
| 4 |
+
datasets:
|
| 5 |
+
- gretelai/symptom_to_diagnosis
|
| 6 |
+
metrics:
|
| 7 |
+
- f1
|
| 8 |
+
pipeline_tag: text-classification
|
| 9 |
+
widget:
|
| 10 |
+
- text: "I have a sharp pain in my chest, difficulty breathing, and a persistent cough."
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Symptom-to-Condition Classifier
|
| 14 |
+
|
| 15 |
+
This repository contains the artefacts for a LightGBM classification model that predicts a likely medical condition based on a user's textual description of their symptoms.
|
| 16 |
+
|
| 17 |
+
**This model is a proof-of-concept for a portfolio project and is NOT a medical diagnostic tool.**
|
| 18 |
+
|
| 19 |
+
## Model Details
|
| 20 |
+
|
| 21 |
+
This is not a standard end-to-end transformer model. It is a classical machine learning pipeline that uses pre-trained transformers for feature extraction.
|
| 22 |
+
|
| 23 |
+
- **Feature Extractor:** The model uses embeddings from `emilyalsentzer/Bio_ClinicalBERT`. Specifically, it generates a 768-dimensional vector for each symptom description by applying **mean pooling** to the last hidden state of the BERT model, using the attention mask so that padding tokens are ignored.
|
| 24 |
+
- **Classifier:** The actual classification is performed by a `LightGBM` (Light Gradient Boosting Machine) model trained on the embeddings.
|
| 25 |
+
|
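As a rough illustration of the mean-pooling step described above, the sketch below averages token vectors while skipping padding positions. It uses plain Python lists and toy 3-dimensional vectors instead of 768-dimensional BERT outputs; the function name `masked_mean_pool` and the sample values are hypothetical and are not part of this repository.

```python
def masked_mean_pool(token_vectors, attention_mask):
    """Average token vectors, ignoring positions where the mask is 0."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    kept = 0
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:
            kept += 1
            for i, value in enumerate(vector):
                totals[i] += value
    # Guard against an all-padding input so we never divide by zero.
    return [total / max(kept, 1) for total in totals]


# Toy input: two real tokens followed by one padding token.
tokens = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [9.0, 9.0, 9.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0, 4.0]
```

The padding vector `[9.0, 9.0, 9.0]` does not affect the result, which is why the attention mask matters when inputs in a batch are padded to a common length.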
| 26 |
+
## Intended Use
|
| 27 |
+
|
| 28 |
+
This model is intended for educational and demonstration purposes only. It takes a free-text description of symptoms and outputs a single predicted medical condition from a predefined list of 22 classes.
|
| 29 |
+
|
| 30 |
+
### Ethical Considerations & Limitations
|
| 31 |
+
|
| 32 |
+
- **⚠️ Not for Medical Use:** This model should **NEVER** be used to diagnose, treat, or provide medical advice for real-world health issues. It is not a substitute for consultation with a qualified healthcare professional.
|
| 33 |
+
- **Data Bias:** The model's knowledge is limited to the `gretelai/symptom_to_diagnosis` dataset. It cannot predict any condition outside of its 22-class training data and may perform poorly on symptom descriptions that are stylistically different from the training set.
|
| 34 |
+
- **Correlation, Not Causation:** The model learns statistical correlations between words and labels. It has no true understanding of biology or medicine.
|
| 35 |
+
|
| 36 |
+
## How to Use
|
| 37 |
+
|
| 38 |
+
To use this model, you must load the feature extractor (`Bio_ClinicalBERT`), the LightGBM classifier, and the label encoder.
|
| 39 |
+
|
| 40 |
+
```python
|
| 41 |
+
import torch
|
| 42 |
+
import joblib
|
| 43 |
+
from transformers import AutoTokenizer, AutoModel
|
| 44 |
+
from huggingface_hub import hf_hub_download
|
| 45 |
+
|
| 46 |
+
# --- CONFIGURATION ---
|
| 47 |
+
HF_REPO_ID = "<YourUsername>/Symptom-to-Condition-Classifier" # Replace with your repo ID
|
| 48 |
+
LGBM_MODEL_FILENAME = "lgbm_disease_classifier.joblib"
|
| 49 |
+
LABEL_ENCODER_FILENAME = "label_encoder.joblib"
|
| 50 |
+
BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
|
| 51 |
+
|
| 52 |
+
# --- LOAD ARTIFACTS ---
|
| 53 |
+
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
|
| 54 |
+
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
|
| 55 |
+
lgbm_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LGBM_MODEL_FILENAME)
|
| 56 |
+
label_encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LABEL_ENCODER_FILENAME)
|
| 57 |
+
lgbm_model = joblib.load(lgbm_model_path)
|
| 58 |
+
label_encoder = joblib.load(label_encoder_path)
|
| 59 |
+
|
| 60 |
+
# --- INFERENCE PIPELINE ---
|
| 61 |
+
def mean_pool(model_output, attention_mask):
|
| 62 |
+
    token_embeddings = model_output.last_hidden_state  # (batch, seq_len, hidden)
|
| 63 |
+
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 64 |
+
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)  # padding contributes zero
|
| 65 |
+
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # avoid division by zero
|
| 66 |
+
    return sum_embeddings / sum_mask
|
| 67 |
+
|
| 68 |
+
def predict_condition(text):
|
| 69 |
+
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
|
| 70 |
+
    with torch.no_grad():
|
| 71 |
+
        model_output = bert_model(**encoded_input)
|
| 72 |
+
    embedding = mean_pool(model_output, encoded_input['attention_mask'])
|
| 73 |
+
|
| 74 |
+
    prediction_id = lgbm_model.predict(embedding.cpu().numpy())
|
| 75 |
+
    predicted_condition = label_encoder.inverse_transform(prediction_id)[0]
|
| 76 |
+
    return predicted_condition
|
| 77 |
+
|
| 78 |
+
# --- EXAMPLE ---
|
| 79 |
+
symptoms = "I have a burning sensation in my stomach that gets worse when I haven't eaten."
|
| 80 |
+
prediction = predict_condition(symptoms)
|
| 81 |
+
print(f"Predicted Condition: {prediction}")
|