Update modeling_prismatic.py
Browse files- modeling_prismatic.py +1 -1
modeling_prismatic.py
CHANGED
@@ -524,7 +524,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
524 |
|
525 |
# Run VLA inference
|
526 |
print("=" * 100)
|
527 |
-
model_outputs = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key),
|
528 |
print(model_outputs.shape)
|
529 |
print("=" * 100)
|
530 |
|
|
|
524 |
|
525 |
# Run VLA inference
|
526 |
print("=" * 100)
|
527 |
+
model_outputs = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), return_dict=True, **kwargs)
|
528 |
print(model_outputs.shape)
|
529 |
print("=" * 100)
|
530 |
|