license: apache-2.0
language: en
datasets:
- gretelai/symptom_to_diagnosis
metrics:
- f1
pipeline_tag: text-classification
widget:
- text: >-
    I have a sharp pain in my chest, difficulty breathing, and a persistent
    cough.
Symptom-to-Condition Classifier
This repository contains the artefacts for a LightGBM classification model that predicts a likely medical condition based on a user's textual description of their symptoms.
This model is a proof-of-concept for a portfolio project and is NOT a medical diagnostic tool.
Model Details
This is not a standard end-to-end transformer model. It is a classical machine learning pipeline that uses pre-trained transformers for feature extraction.
- Feature Extractor: The model uses embeddings from emilyalsentzer/Bio_ClinicalBERT. Specifically, it generates a 768-dimensional vector for each symptom description by applying mean pooling to the last hidden state of the BERT model.
- Classifier: The actual classification is performed by a LightGBM (Light Gradient Boosting Machine) model trained on the embeddings.
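The mean-pooling step described above can be illustrated on a toy array. This is a minimal NumPy sketch, not the code shipped in this repository: the function name `masked_mean_pool` and the tiny input are made up for illustration, and the hidden size is 2 here rather than the 768 dimensions produced by Bio_ClinicalBERT.

```python
import numpy as np

def masked_mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sequence, ignoring padding positions.

    token_embeddings: (batch, seq_len, hidden) array
    attention_mask:   (batch, seq_len) array with 1 for real tokens, 0 for padding
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                     # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                     # avoid division by zero
    return summed / counts

# Toy input: one sequence of 3 positions with hidden size 2; the last position is padding.
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])

pooled = masked_mean_pool(tokens, mask)
print(pooled)  # [[2. 3.]]
```

The padded position contributes nothing to the average, which is why the attention mask is applied before dividing. The result is one fixed-length vector per description, which is what a tabular classifier such as LightGBM expects as input.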
Intended Use
This model is intended for educational and demonstration purposes only. It takes a string of text describing symptoms and outputs a predicted medical condition from a predefined list of 22 classes.
Ethical Considerations & Limitations
- ⚠️ Not for Medical Use: This model should NEVER be used to diagnose, treat, or provide medical advice for real-world health issues. It is not a substitute for consultation with a qualified healthcare professional.
- Data Bias: The model's knowledge is limited to the gretelai/symptom_to_diagnosis dataset. It cannot predict any condition outside of its 22-class training data and may perform poorly on symptom descriptions that are stylistically different from the training set.
- Correlation, Not Causation: The model learns statistical correlations between words and labels. It has no true understanding of biology or medicine.
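Because the classifier always returns one of its 22 known labels, a demo may want to show how uncertain a prediction is. The sketch below is one possible approach, not part of this repository: the helper `top_k_with_threshold`, the class names, and the 0.5 cutoff are all hypothetical. It assumes per-class probabilities are available, for example from a `predict_proba` call if the saved classifier is a scikit-learn style LightGBM model, which is an assumption. A low top probability is only a rough signal and does not reliably detect out-of-scope inputs.

```python
def top_k_with_threshold(probabilities, class_names, k=3, min_confidence=0.5):
    """Return the k most probable classes and whether the top one clears the cutoff.

    probabilities:  per-class probabilities for a single input
    class_names:    labels in the same order as the probabilities
    min_confidence: hypothetical cutoff; would need tuning on held-out data
    """
    ranked = sorted(zip(class_names, probabilities), key=lambda pair: pair[1], reverse=True)
    top = ranked[:k]
    is_confident = top[0][1] >= min_confidence
    return top, is_confident

# Hypothetical probabilities for three placeholder classes.
names = ["condition_a", "condition_b", "condition_c"]
probs = [0.40, 0.35, 0.25]

top, is_confident = top_k_with_threshold(probs, names, k=2)
print(top)           # [('condition_a', 0.4), ('condition_b', 0.35)]
print(is_confident)  # False
```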
How to Use
To use this model, you must load the feature extractor (Bio_ClinicalBERT), the LightGBM classifier, and the label encoder.
Training Data
This model was trained on the gretelai/symptom_to_diagnosis dataset, which contains ~1000 symptom descriptions across 22 balanced classes.
Evaluation
The model achieves a Macro F1-score of 0.834 and an Accuracy of 0.835 on the test set.
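For readers unfamiliar with the two metrics, the following sketch shows how accuracy and macro F1 are computed from true and predicted labels. It is a plain-Python illustration with made-up labels, not the evaluation script behind the reported numbers; the function name `accuracy_and_macro_f1` is hypothetical.

```python
from collections import Counter

def accuracy_and_macro_f1(y_true, y_pred):
    """Compute accuracy and the unweighted mean of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted class gains a false positive
            fn[truth] += 1  # true class gains a false negative

    f1_scores = []
    for label in labels:
        denom = 2 * tp[label] + fp[label] + fn[label]
        f1_scores.append(2 * tp[label] / denom if denom else 0.0)

    accuracy = sum(tp.values()) / len(y_true)
    macro_f1 = sum(f1_scores) / len(f1_scores)
    return accuracy, macro_f1

# Made-up labels for three placeholder classes.
y_true = ["A", "A", "B", "B", "C", "C"]
y_pred = ["A", "B", "B", "B", "C", "A"]

accuracy, macro_f1 = accuracy_and_macro_f1(y_true, y_pred)
print(f"accuracy={accuracy:.3f}, macro_f1={macro_f1:.3f}")  # accuracy=0.667, macro_f1=0.656
```

Macro F1 weights every class equally regardless of how many examples it has. On a dataset with balanced classes, it tends to land close to accuracy, which is consistent with the two values reported above.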
import torch
import joblib
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
HF_REPO_ID = "<YourUsername>/Symptom-to-Condition-Classifier" # Replace with your repo ID
LGBM_MODEL_FILENAME = "lgbm_disease_classifier.joblib"
LABEL_ENCODER_FILENAME = "label_encoder.joblib"
BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
# --- LOAD ARTIFACTS ---
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
lgbm_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LGBM_MODEL_FILENAME)
label_encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LABEL_ENCODER_FILENAME)
lgbm_model = joblib.load(lgbm_model_path)
label_encoder = joblib.load(label_encoder_path)
# --- INFERENCE PIPELINE ---
def mean_pool(model_output, attention_mask):
    # Average the token embeddings, using the attention mask to ignore padding.
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def predict_condition(text):
    # Tokenize the description and embed it with Bio_ClinicalBERT.
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    with torch.no_grad():
        model_output = bert_model(**encoded_input)
    embedding = mean_pool(model_output, encoded_input['attention_mask'])
    # Classify the embedding and map the class ID back to its label.
    prediction_id = lgbm_model.predict(embedding.cpu().numpy())
    predicted_condition = label_encoder.inverse_transform(prediction_id)[0]
    return predicted_condition
# --- EXAMPLE ---
symptoms = "I have a burning sensation in my stomach that gets worse when I haven't eaten."
prediction = predict_condition(symptoms)
print(f"Predicted Condition: {prediction}")