Lech-Iyoko commited on
Commit
986a3fc
·
verified ·
1 Parent(s): dde0c9e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +81 -3
README.md CHANGED
@@ -1,3 +1,81 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language: en
4
+ datasets:
5
+ - gretelai/symptom_to_diagnosis
6
+ metrics:
7
+ - f1
8
+ pipeline_tag: text-classification
9
+ widget:
10
+ - text: "I have a sharp pain in my chest, difficulty breathing, and a persistent cough."
11
+ ---
12
+
13
+ # Symptom-to-Condition Classifier
14
+
15
+ This repository contains the artefacts for a LightGBM classification model that predicts a likely medical condition based on a user's textual description of their symptoms.
16
+
17
+ **This model is a proof-of-concept for a portfolio project and is NOT a medical diagnostic tool.**
18
+
19
+ ## Model Details
20
+
21
+ This is not a standard end-to-end transformer model. It is a classical machine learning pipeline that uses pre-trained transformers for feature extraction.
22
+
23
+ - **Feature Extractor:** The model uses embeddings from `emilyalsentzer/Bio_ClinicalBERT`. Specifically, it generates a 768-dimension vector for each symptom description by applying **mean pooling** to the last hidden state of the BERT model.
24
+ - **Classifier:** The actual classification is performed by a `LightGBM` (Light Gradient Boosting Machine) model trained on the embeddings.
25
+
26
+ ## Intended Use
27
+
28
+ This model is intended for educational and demonstrational purposes only. It takes a string of text describing symptoms and outputs a predicted medical condition from a predefined list of 22 classes.
29
+
30
+ ### Ethical Considerations & Limitations
31
+
32
+ - **⚠️ Not for Medical Use:** This model should **NEVER** be used to diagnose, treat, or provide medical advice for real-world health issues. It is not a substitute for consultation with a qualified healthcare professional.
33
+ - **Data Bias:** The model's knowledge is limited to the `gretelai/symptom_to_diagnosis` dataset. It cannot predict any condition outside of its 22-class training data and may perform poorly on symptom descriptions that are stylistically different from the training set.
34
+ - **Correlation, Not Causation:** The model learns statistical correlations between words and labels. It has no true understanding of biology or medicine.
35
+
36
+ ## How to Use
37
+
38
+ To use this model, you must load the feature extractor (`Bio_ClinicalBERT`), the LightGBM classifier, and the label encoder.
39
+
40
+ ```python
41
+ import torch
42
+ import joblib
43
+ from transformers import AutoTokenizer, AutoModel
44
+ from huggingface_hub import hf_hub_download
45
+
46
+ # --- CONFIGURATION ---
47
+ HF_REPO_ID = "<YourUsername>/Symptom-to-Condition-Classifier" # Replace with your repo ID
48
+ LGBM_MODEL_FILENAME = "lgbm_disease_classifier.joblib"
49
+ LABEL_ENCODER_FILENAME = "label_encoder.joblib"
50
+ BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
51
+
52
+ # --- LOAD ARTIFACTS ---
53
+ tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
54
+ bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME)
55
+ lgbm_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LGBM_MODEL_FILENAME)
56
+ label_encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=LABEL_ENCODER_FILENAME)
57
+ lgbm_model = joblib.load(lgbm_model_path)
58
+ label_encoder = joblib.load(label_encoder_path)
59
+
60
+ # --- INFERENCE PIPELINE ---
61
+ def mean_pool(model_output, attention_mask):
62
+ token_embeddings = model_output.last_hidden_state
63
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
64
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
65
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
66
+ return sum_embeddings / sum_mask
67
+
68
+ def predict_condition(text):
69
+ encoded_input = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
70
+ with torch.no_grad():
71
+ model_output = bert_model(**encoded_input)
72
+ embedding = mean_pool(model_output, encoded_input['attention_mask'])
73
+
74
+ prediction_id = lgbm_model.predict(embedding.cpu().numpy())
75
+ predicted_condition = label_encoder.inverse_transform(prediction_id)[0]
76
+ return predicted_condition
77
+
78
+ # --- EXAMPLE ---
79
+ symptoms = "I have a burning sensation in my stomach that gets worse when I haven't eaten."
80
+ prediction = predict_condition(symptoms)
81
+ print(f"Predicted Condition: {prediction}")