tt1225 committed
Commit 5a1d866 · verified · 1 Parent(s): abec836

Update modeling_prismatic.py

Files changed (1): modeling_prismatic.py +5 -19
modeling_prismatic.py CHANGED
@@ -1,12 +1,9 @@
 """
 modeling_prismatic.py
-
 Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
 from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
 logic in `prismatic.models.vlms.prismatic.py`.
-
 Note =>> for the time being, not adding the custom HF "docstring" formatting.
-
 References [LLaVa, IDEFICS-2]:
 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
@@ -411,7 +408,7 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=True,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                return_dict=True,
             )
 
         # === Otherwise =>> Assume Invalid! ===
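For illustration: hard-coding `return_dict=True` on the inner language-model call guarantees a `ModelOutput` (e.g. `CausalLMOutputWithPast`), so the `language_model_output.loss` / `.logits` attribute access further down cannot break when a caller passes `return_dict=None` or `False`. A minimal, self-contained sketch of the difference, using a small stand-in checkpoint rather than the VLA backbone:

# Sketch only (stand-in "gpt2" model, not part of this commit): with
# return_dict=False a transformers causal LM returns a plain tuple, so
# attribute access on .loss / .logits would fail; return_dict=True
# guarantees a CausalLMOutputWithPast.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer("a robot picks up the cup", return_tensors="pt")
as_tuple = model(**batch, return_dict=False)  # tuple -- positional access only
as_output = model(**batch, return_dict=True)  # ModelOutput -- attribute access works
print(type(as_tuple).__name__, type(as_output).__name__)
print(as_output.logits.shape)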
@@ -437,9 +434,6 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
 
             return language_model_output
 
-        print("Forward")
-
-
         return PrismaticCausalLMOutputWithPast(
             loss=language_model_output.loss,
             logits=language_model_output.logits,
@@ -485,11 +479,6 @@ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
             }
         )
 
-        # model_inputs["output_attentions"] = True
-
-        print("Prepare")
-        print(model_inputs.keys())
-
         return model_inputs
 
     # Defer to Language Model (all handle this differently, with different return types)
@@ -523,10 +512,8 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
         )
 
         # Run VLA inference
-        print("=" * 100)
-        model_outputs = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), return_dict=True, **kwargs)
-        print(model_outputs.shape)
-        print("=" * 100)
+        model_outputs = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
+        print(model_outputs.keys())
 
         # Extract predicted action tokens and translate into (normalized) continuous actions
         predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
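The unchanged lines around this hunk map the trailing `action_dim` token ids back to continuous actions. A self-contained sketch of that de-tokenization scheme as it works in OpenVLA-style models (all constants and bounds below are illustrative placeholders, not values from this file):

import numpy as np

# Illustrative placeholders -- the real values come from the model config and norm_stats.
vocab_size, n_bins, action_dim = 32000, 256, 7
action_low = np.full(action_dim, -1.0)   # e.g. per-dimension 1st-percentile bounds
action_high = np.full(action_dim, 1.0)   # e.g. per-dimension 99th-percentile bounds

# 256 uniform bins over [-1, 1]; actions are predicted as bin indices mapped
# onto the tail of the vocabulary.
bins = np.linspace(-1.0, 1.0, n_bins)
bin_centers = (bins[:-1] + bins[1:]) / 2.0

def decode_actions(predicted_action_token_ids: np.ndarray) -> np.ndarray:
    """Map the last `action_dim` generated token ids to continuous actions."""
    discretized = vocab_size - predicted_action_token_ids          # undo the vocab-tail offset
    discretized = np.clip(discretized - 1, 0, bin_centers.shape[0] - 1)
    normalized = bin_centers[discretized]                          # values in [-1, 1]
    # Un-normalize from [-1, 1] into the dataset's action range.
    return 0.5 * (normalized + 1) * (action_high - action_low) + action_low

print(decode_actions(np.arange(31873, 31880)))  # dummy token ids near the vocab tail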
@@ -544,7 +531,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
             normalized_actions,
         )
 
-        return actions, generated_ids
+        return actions
 
     @staticmethod
     def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
@@ -570,6 +557,5 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
     def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
         """Get all the logged statistics for the given dataset."""
         unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
-        print("keys:", self.norm_stats.keys())
-        print("items:", self.norm_stats)
         return self.norm_stats[unnorm_key]["action"]
+
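With the `return actions, generated_ids` tuple replaced by a single `return actions`, callers now receive just the unnormalized action vector. A usage sketch under that assumption, following the published OpenVLA example (checkpoint id, image path, prompt, and `unnorm_key` are illustrative):

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Illustrative checkpoint and inputs -- adjust to your setup.
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b", torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda:0")

image = Image.open("frame.png")  # placeholder camera frame
prompt = "In: What action should the robot take to pick up the cup?\nOut:"
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)

# Single return value after this commit (previously: actions, generated_ids).
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
print(action)  # 7-DoF continuous action vector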