tt1225 commited on
Commit
38474ea
·
verified ·
1 Parent(s): a58aca6

Update modeling_prismatic.py

Browse files
Files changed (1) hide show
  1. modeling_prismatic.py +3 -1
modeling_prismatic.py CHANGED
@@ -397,6 +397,8 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
397
  )
398
  multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
399
 
 
 
400
  # Dispatch to Language Model
401
  language_model_output = self.language_model(
402
  input_ids=None,
@@ -408,7 +410,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
408
  use_cache=use_cache,
409
  output_attentions=True,
410
  output_hidden_states=output_hidden_states,
411
- return_dict=return_dict,
412
  )
413
 
414
  # === Otherwise =>> Assume Invalid! ===
 
397
  )
398
  multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
399
 
400
+ print("Must run this")
401
+
402
  # Dispatch to Language Model
403
  language_model_output = self.language_model(
404
  input_ids=None,
 
410
  use_cache=use_cache,
411
  output_attentions=True,
412
  output_hidden_states=output_hidden_states,
413
+ return_dict=True,
414
  )
415
 
416
  # === Otherwise =>> Assume Invalid! ===