anicolson commited on
Commit
1988ce1
·
verified ·
1 Parent(s): c3491fd

Upload model

Browse files
Files changed (1) hide show
  1. modelling_cxrmate_ed.py +15 -0
modelling_cxrmate_ed.py CHANGED
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
7
  import datasets
8
  import torch
9
  import transformers
 
10
  from torch.nn import CrossEntropyLoss
11
  from torch.utils.data import Subset
12
  from torchvision.io import decode_image
@@ -182,6 +183,20 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
182
 
183
  self.post_init()
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  # @classmethod
187
  # def from_encoder_decoder_pretrained(
 
7
  import datasets
8
  import torch
9
  import transformers
10
+ from huggingface_hub import hf_hub_download
11
  from torch.nn import CrossEntropyLoss
12
  from torch.utils.data import Subset
13
  from torchvision.io import decode_image
 
183
 
184
  self.post_init()
185
 
186
+ @classmethod
187
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
188
+
189
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='tables.json')
190
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='token_type_ids.json')
191
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='lookup_tables.json')
192
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_train_study_ids.json')
193
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_validate_study_ids.json')
194
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_test_study_ids.json')
195
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_train_study_ids.json')
196
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_validate_study_ids.json')
197
+ hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_test_study_ids.json')
198
+
199
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
200
 
201
  # @classmethod
202
  # def from_encoder_decoder_pretrained(