| """ | |
| modeling_prismatic.py | |
| Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting | |
| from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the | |
| logic in `prismatic.models.vlms.prismatic.py`. | |
| Note =>> for the time being, not adding the custom HF "docstring" formatting. | |
| References [LLaVa, IDEFICS-2]: | |
| => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py | |
| => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py | |
| """ | |
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_prismatic import OpenVLAConfig, PrismaticConfig

# Get Logger
logger = logging.getLogger(__name__)

# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
IGNORE_INDEX = -100


# === Utility Functions for Monkey-Patching ===
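# Note :: `unpack_tuple` wraps `timm`'s `get_intermediate_layers`, which returns a tuple of per-layer
#         outputs; since only a single layer is requested below, unwrapping lets the patched
#         `forward()` return a plain patch tensor.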
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
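# Note :: `ls_apply_patch` re-registers the learned scale under `scale_factor` (so HF checkpoint handling
#         of `gamma`-named parameters, per the comment above, no longer interferes) and rebinds `forward`
#         on the *instance*, leaving the TIMM `LayerScale` class itself untouched.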


# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and Instantiate
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)


# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
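        # Both branches map patch features from `vision_dim` to the LLM hidden size (`llm_dim`), so the
        # output can be concatenated directly with token embeddings:
        #   [bsz, n_patches, vision_dim] -> [bsz, n_patches, llm_dim]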
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features


# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   =>> As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #       https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
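            # Resulting sequence layout: [<BOS>, patch_1 ... patch_K, remaining text tokens], so the
            # attention mask and labels built below must be expanded by K = projected_patch_embeddings.shape[1].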
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogeneous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )
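
    # Illustrative sketch of a multimodal training-style call; the 6-channel, 224x224 `pixel_values`
    # shape assumes a fused two-backbone config at 224px, which is a checkpoint-level assumption rather
    # than something this file fixes:
    #
    #   >>> out = model(
    #   ...     input_ids=input_ids,              # [1, seq_len], starting with <BOS>
    #   ...     attention_mask=attention_mask,    # [1, seq_len]
    #   ...     pixel_values=pixel_values,        # [1, 6, 224, 224] (two 3-channel views stacked)
    #   ...     labels=labels,                    # [1, seq_len] with IGNORE_INDEX on prompt tokens
    #   ...     output_projector_features=True,
    #   ... )
    #   >>> out.loss, out.logits.shape, out.projector_features.shape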

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs
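
    # Note :: during generation, the first call carries the full prompt plus `pixel_values`; on subsequent
    #         calls `input_ids` is trimmed to the last token and `forward()` takes the cached branch
    #         (`input_ids.shape[1] == 1`), so the image is only encoded once per generation.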

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
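
        # Note :: action tokens occupy the last `n_action_bins` ids of this (unpadded) vocabulary; in
        #         `predict_action` below, a generated token id `t` is mapped back to a bin index via
        #         clip(self.vocab_size - t - 1, 0, n_action_bins - 2) before lookup in `bin_centers`.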

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
        # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
        if not torch.all(input_ids[:, -1] == 29871):
            input_ids = torch.cat(
                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
            )

        # Run VLA inference
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Unnormalize actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions
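
    # Illustrative sketch of typical inference; `processor`, `device`, and the `unnorm_key` value are
    # assumptions about how this checkpoint is driven, not part of this file:
    #
    #   >>> inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)
    #   >>> inputs = inputs.to(device, dtype=torch.bfloat16)
    #   >>> action = vla.predict_action(**inputs, unnorm_key="example_dataset", do_sample=False)
    #   >>> action.shape   # (action_dim,) continuous, unnormalized action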

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass an `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )

        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)

        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)

        return self.norm_stats[unnorm_key]["action"]