---
license: apache-2.0
language: en
datasets:
- gretelai/symptom_to_diagnosis
metrics:
- f1
pipeline_tag: text-classification
widget:
- text: "I have a sharp pain in my chest, difficulty breathing, and a persistent cough."
---

# Symptom-to-Condition Classifier

This repository contains the artefacts for a LightGBM classification model that predicts a likely medical condition based on a user's textual description of their symptoms.

**This model is a proof-of-concept for a portfolio project and is NOT a medical diagnostic tool.**

## Model Details

This is not a standard end-to-end transformer model. It is a classical machine learning pipeline that uses a pre-trained transformer for feature extraction only; the transformer is not the classifier.

- **Feature Extractor:** The model uses embeddings from `emilyalsentzer/Bio_ClinicalBERT`. Specifically, it generates a 768-dimensional vector for each symptom description by applying **mean pooling** to the last hidden state of the BERT model.
- **Classifier:** The actual classification is performed by a `LightGBM` (Light Gradient Boosting Machine) model trained on the embeddings.
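
To make the mean-pooling step concrete, here is a minimal pure-Python sketch on a toy sequence. It is an illustration only: the function name `masked_mean_pool` and the toy values are made up for this example, and the hidden size is 2 rather than the 768 used by the real feature extractor.

```python
def masked_mean_pool(token_vectors, attention_mask):
    """Average the token vectors of one sequence, skipping padding positions.

    token_vectors: list of per-token vectors (each a list of floats)
    attention_mask: list of 1 (real token) or 0 (padding), same length
    """
    # Keep only the vectors whose mask entry is 1.
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    # Guard against division by zero for an all-padding sequence.
    count = max(len(kept), 1)
    # zip(*kept) groups values by dimension, so each sum is over tokens.
    return [sum(dimension) / count for dimension in zip(*kept)]


# Toy sequence: three token slots, hidden size 2 (hypothetical values).
toy_vectors = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
toy_mask = [1, 1, 0]  # the third slot is padding and must not affect the mean

pooled = masked_mean_pool(toy_vectors, toy_mask)
print(pooled)  # [2.0, 3.0]
```

The padding vector `[100.0, 100.0]` is ignored, so the result is the average of the two real tokens only. The attention mask plays the same role when pooling the BERT output.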

## Intended Use

This model is intended for educational and demonstration purposes only. It takes a string of text describing symptoms and outputs a predicted medical condition from a predefined list of 22 classes.

### Ethical Considerations & Limitations

- **⚠️ Not for Medical Use:** This model should **NEVER** be used to diagnose, treat, or provide medical advice for real-world health issues. It is not a substitute for consultation with a qualified healthcare professional.
- **Data Bias:** The model's knowledge is limited to the `gretelai/symptom_to_diagnosis` dataset. It cannot predict any condition outside of its 22-class training data and may perform poorly on symptom descriptions that are stylistically different from the training set.
- **Correlation, Not Causation:** The model learns statistical correlations between words and labels. It has no true understanding of biology or medicine.

## How to Use

To use this model, load three components: the feature extractor (`Bio_ClinicalBERT`), the LightGBM classifier, and the label encoder. The complete inference script appears at the end of this card.

## Training Data

This model was trained on the `gretelai/symptom_to_diagnosis` dataset, which contains roughly 1,000 symptom descriptions across 22 balanced classes.

## Evaluation

The model achieves a macro F1-score of 0.834 and an accuracy of 0.835 on the test set.
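
For readers unfamiliar with the two metrics, the sketch below computes macro F1 and accuracy from scratch on a tiny made-up example. The labels `A`, `B`, `C` and the predictions are hypothetical and unrelated to the model's actual test set; this is not the evaluation script used to produce the numbers above.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs.
        denominator = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denominator if denominator else 0.0)
    return sum(f1_scores) / len(f1_scores)


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical labels and predictions, for illustration only.
y_true = ["A", "A", "B", "B", "C", "C"]
y_pred = ["A", "B", "B", "B", "C", "A"]

print(f"Macro F1: {macro_f1(y_true, y_pred):.3f}")
print(f"Accuracy: {accuracy(y_true, y_pred):.3f}")
```

In this toy case the per-class F1 scores are 0.5, 0.8, and about 0.667, so macro F1 is about 0.656 while accuracy is about 0.667. Because macro F1 weights every class equally, the two metrics tend to stay close when classes are balanced, as they are in this dataset.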

The following script loads the feature extractor, the LightGBM classifier, and the label encoder, then runs inference on a single symptom description:

```python
import torch
import joblib
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
HF_REPO_ID = "<YourUsername>/Symptom-to-Condition-Classifier"  # Replace with your repo ID
LGBM_MODEL_FILENAME = "lgbm_disease_classifier.joblib"
LABEL_ENCODER_FILENAME = "label_encoder.joblib"
BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

# --- LOAD ARTIFACTS ---
# Note: unpickling the classifier with joblib requires lightgbm to be installed
# (and, presumably, scikit-learn for the label encoder).
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
lgbm_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LGBM_MODEL_FILENAME)
label_encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LABEL_ENCODER_FILENAME)
lgbm_model = joblib.load(lgbm_model_path)
label_encoder = joblib.load(label_encoder_path)


# --- INFERENCE PIPELINE ---
def mean_pool(model_output, attention_mask):
    """Average the last hidden state over real tokens, ignoring padding."""
    token_embeddings = model_output.last_hidden_state  # (batch, seq_len, hidden)
    # Broadcast the mask over the hidden dimension so padding contributes zero.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp the token count to avoid division by zero.
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


def predict_condition(text):
    """Embed a symptom description and return the predicted condition label."""
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():  # inference only, no gradients needed
        model_output = bert_model(**encoded_input)
    embedding = mean_pool(model_output, encoded_input['attention_mask'])

    # LightGBM expects a NumPy array of shape (n_samples, n_features).
    prediction_id = lgbm_model.predict(embedding.cpu().numpy())
    predicted_condition = label_encoder.inverse_transform(prediction_id)[0]
    return predicted_condition


# --- EXAMPLE ---
symptoms = "I have a burning sensation in my stomach that gets worse when I haven't eaten."
prediction = predict_condition(symptoms)
print(f"Predicted Condition: {prediction}")
```