izhx committed on
Commit
23e2bf9
·
verified ·
1 Parent(s): 89146e4

Update custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +1 -5
custom_st.py CHANGED
@@ -6,7 +6,6 @@ import torch
6
  from PIL import Image
7
  from sentence_transformers.models import Transformer as BaseTransformer
8
  from transformers import AutoModelForVision2Seq, AutoProcessor
9
- from packaging import version
10
  import transformers
11
 
12
  class MultiModalTransformer(BaseTransformer):
@@ -54,10 +53,7 @@ class MultiModalTransformer(BaseTransformer):
54
  self, features: Dict[str, torch.Tensor], **kwargs
55
  ) -> Dict[str, torch.Tensor]:
56
  if features.get("inputs_embeds", None) is None:
57
- if version.parse(transformers.__version__) >= version.parse("4.52.0"):
58
- features["inputs_embeds"] = self.auto_model.base_model.language_model.embed_tokens(features["input_ids"])
59
- else:
60
- features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
61
  if features.get("pixel_values", None) is not None:
62
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
63
  image_embeds = self.auto_model.visual(
 
6
  from PIL import Image
7
  from sentence_transformers.models import Transformer as BaseTransformer
8
  from transformers import AutoModelForVision2Seq, AutoProcessor
 
9
  import transformers
10
 
11
  class MultiModalTransformer(BaseTransformer):
 
53
  self, features: Dict[str, torch.Tensor], **kwargs
54
  ) -> Dict[str, torch.Tensor]:
55
  if features.get("inputs_embeds", None) is None:
56
+ features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
 
 
 
57
  if features.get("pixel_values", None) is not None:
58
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
59
  image_embeds = self.auto_model.visual(