Update modeling_prismatic.py
modeling_prismatic.py  +5 -19
@@ -1,12 +1,9 @@
 """
 modeling_prismatic.py
-
 Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
 from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
 logic in `prismatic.models.vlms.prismatic.py`.
-
 Note =>> for the time being, not adding the custom HF "docstring" formatting.
-
 References [LLaVa, IDEFICS-2]:
 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
@@ -411,7 +408,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=True,
                 output_hidden_states=output_hidden_states,
-                return_dict=
+                return_dict=True,
             )

         # === Otherwise =>> Assume Invalid! ===
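The restored `return_dict=True` matters because the code below reads `language_model_output.loss` / `.logits` by attribute. A minimal standalone illustration of the flag's effect on any HF causal LM (GPT-2 here purely as a stand-in for the Prismatic backbone):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-2 is a stand-in backbone purely for illustration.
tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("pick up the block", return_tensors="pt")

# return_dict=True -> a CausalLMOutputWithPast with named fields.
out = lm(**inputs, labels=inputs["input_ids"], return_dict=True)
print(out.loss, out.logits.shape)

# return_dict=False -> a plain tuple; `out.loss`-style access is unavailable.
out = lm(**inputs, labels=inputs["input_ids"], return_dict=False)
print(type(out))
```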
@@ -437,9 +434,6 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):

             return language_model_output

-        print("Forward")
-
-
         return PrismaticCausalLMOutputWithPast(
             loss=language_model_output.loss,
             logits=language_model_output.logits,
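For reference, `PrismaticCausalLMOutputWithPast` is, in the upstream OpenVLA release of this file, a plain `ModelOutput` dataclass, approximately the following (quoted from upstream rather than this fork, so fields may differ slightly):

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Standard causal-LM output fields, plus the projected patch features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    # Extra field: vision-patch features after the multimodal projector.
    projector_features: Optional[torch.FloatTensor] = None
```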
@@ -485,11 +479,6 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
             }
         )

-        # model_inputs["output_attentions"] = True
-
-        print("Prepare")
-        print(model_inputs.keys())
-
         return model_inputs

     # Defer to Language Model (all handle this differently, with different return types)
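Context for this hunk: during `generate`, HF calls `prepare_inputs_for_generation` once per decoding step and splats the returned dict into `forward`, which is why the dict assembled above (and anything debug code stuffs into it, like the commented-out `output_attentions` line) reaches the model call verbatim. A deliberately simplified sketch of that loop, not the real `GenerationMixin` code:

```python
import torch

def greedy_loop(model, input_ids: torch.LongTensor, max_new_tokens: int) -> torch.LongTensor:
    """Simplified stand-in for transformers' GenerationMixin decoding loop."""
    for _ in range(max_new_tokens):
        # Whatever dict this hook returns becomes forward()'s kwargs for this step.
        model_inputs = model.prepare_inputs_for_generation(input_ids)
        outputs = model(**model_inputs)
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```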
@@ -523,10 +512,8 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         )

         # Run VLA inference
-
-        model_outputs
-        print(model_outputs.shape)
-        print("=" * 100)
+        model_outputs = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
+        print(model_outputs.keys())

         # Extract predicted action tokens and translate into (normalized) continuous actions
         predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
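Two details worth flagging in this hunk: the new line assigns to `model_outputs` while the unchanged line below still indexes `generated_ids`, and `self.generate(...)` returns a plain `LongTensor` by default, so the added `print(model_outputs.keys())` only works if `generate` is asked for a dict-style output (e.g., `return_dict_in_generate=True`). For orientation, the elided lines that turn `predicted_action_token_ids` into `normalized_actions` look roughly like this in the upstream OpenVLA file (a sketch of the upstream logic, not necessarily this fork's exact code; `vla` stands in for `self`):

```python
import numpy as np

def decode_actions(vla, predicted_action_token_ids: np.ndarray) -> np.ndarray:
    """Sketch of the upstream token -> normalized-action mapping."""
    # Action tokens are carved out of the end of the vocabulary, one 256-bin
    # discretization per action dimension; invert that mapping first.
    discretized_actions = vla.vocab_size - predicted_action_token_ids
    discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=vla.bin_centers.shape[0] - 1)
    # Each bin index maps back to its bin center in [-1, 1].
    return vla.bin_centers[discretized_actions]
```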
@@ -544,7 +531,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             normalized_actions,
         )

-        return actions
+        return actions

     @staticmethod
     def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
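The `actions` returned here come from `normalized_actions` via a per-dimension affine un-normalization against the dataset's 1st/99th-percentile statistics (`self.get_action_stats(unnorm_key)`); a sketch following the upstream implementation:

```python
import numpy as np

def unnormalize(normalized_actions: np.ndarray, action_norm_stats: dict) -> np.ndarray:
    """Sketch of the upstream un-normalization feeding `return actions` above."""
    action_low = np.array(action_norm_stats["q01"])
    action_high = np.array(action_norm_stats["q99"])
    # `mask` marks dimensions that were normalized during training (gripper often isn't).
    mask = np.array(action_norm_stats.get("mask", np.ones_like(action_low, dtype=bool)))
    # Invert x_norm = 2 * (x - low) / (high - low) - 1 on the masked dimensions.
    return np.where(
        mask,
        0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
        normalized_actions,
    )
```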
@@ -570,6 +557,5 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
     def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
         """Get all the logged statistics for the given dataset."""
         unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
-        print("keys:", self.norm_stats.keys())
-        print("items:", self.norm_stats)
         return self.norm_stats[unnorm_key]["action"]
+
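End to end, `norm_stats` and `get_action_stats` are exercised through `predict_action`; the standard OpenVLA usage pattern looks like the following (checkpoint id, image path, and `unnorm_key` are placeholders; a fine-tuned repo would use its own dataset key):

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Placeholder checkpoint id; any repo shipping this modeling code works the same way.
checkpoint = "openvla/openvla-7b"
processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda:0")

image = Image.open("frame.png")  # current third-person camera frame
prompt = "In: What action should the robot take to pick up the block?\nOut:"

inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
# `unnorm_key` selects which entry of `norm_stats` to un-normalize against.
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
```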