Update modeling_prismatic.py
Browse files- modeling_prismatic.py +3 -1
modeling_prismatic.py
CHANGED
@@ -397,6 +397,8 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
397 |
)
|
398 |
multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
|
399 |
|
|
|
|
|
400 |
# Dispatch to Language Model
|
401 |
language_model_output = self.language_model(
|
402 |
input_ids=None,
|
@@ -408,7 +410,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
|
|
408 |
use_cache=use_cache,
|
409 |
output_attentions=True,
|
410 |
output_hidden_states=output_hidden_states,
|
411 |
-
return_dict=
|
412 |
)
|
413 |
|
414 |
# === Otherwise =>> Assume Invalid! ===
|
|
|
397 |
)
|
398 |
multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
|
399 |
|
400 |
+
print("Must run this")
|
401 |
+
|
402 |
# Dispatch to Language Model
|
403 |
language_model_output = self.language_model(
|
404 |
input_ids=None,
|
|
|
410 |
use_cache=use_cache,
|
411 |
output_attentions=True,
|
412 |
output_hidden_states=output_hidden_states,
|
413 |
+
return_dict=True,
|
414 |
)
|
415 |
|
416 |
# === Otherwise =>> Assume Invalid! ===
|