| """ |
| Example usage of the Pi-0 Bolt Nut Sort model |
| """ |
|
|
| from openpi.policies import policy_config |
| from openpi.training import config |
| import numpy as np |
|
|
def load_model(
    checkpoint_path: str,
    *,
    config_name: str = "pi0_bns",
    default_prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Load the Pi-0 bolt nut sort model as an inference policy.

    Args:
        checkpoint_path: Location of the trained checkpoint to restore.
        config_name: Name of the training config passed to
            ``config.get_config``. Defaults to the bolt/nut sorting config
            ``"pi0_bns"``.
        default_prompt: Prompt forwarded to ``create_trained_policy`` as its
            ``default_prompt``. Presumably it is used when an observation
            carries no prompt of its own; confirm against openpi.

    Returns:
        The policy built by ``policy_config.create_trained_policy``. This
        example calls ``policy.infer(obs)`` on it and reads ``["actions"]``
        from the result.
    """
    train_config = config.get_config(config_name)

    policy = policy_config.create_trained_policy(
        train_config,
        checkpoint_path,
        default_prompt=default_prompt,
    )

    return policy
|
|
def create_observation(
    images,
    joint_positions,
    *,
    prompt: str = "sort the bolts and the nuts into separate baskets",
):
    """Create the observation dict consumed by ``policy.infer``.

    Args:
        images: Mapping with keys ``"high"``, ``"left_wrist"`` and
            ``"right_wrist"``. They are renamed to the model's camera keys
            ``cam_high``, ``cam_left_wrist`` and ``cam_right_wrist``. A
            missing key raises ``KeyError``.
        joint_positions: Robot state, passed through unchanged under
            ``"state"``.
        prompt: Language instruction for this observation. Defaults to the
            bolt/nut sorting task.

    Returns:
        A dict with ``"images"``, ``"state"`` and ``"prompt"`` entries.
    """
    return {
        "images": {
            "cam_high": images["high"],
            "cam_left_wrist": images["left_wrist"],
            "cam_right_wrist": images["right_wrist"],
        },
        "state": joint_positions,
        "prompt": prompt,
    }
|
|
| |
def main(checkpoint_path: str = "./checkpoint") -> None:
    """Run one inference step on random dummy inputs and print the action shape.

    Args:
        checkpoint_path: Location of the trained checkpoint to load.
    """
    policy = load_model(checkpoint_path)

    # Fixed seed so repeated runs feed the policy identical dummy inputs.
    rng = np.random.default_rng(0)

    # ``integers`` excludes its upper bound, so 256 (not 255) is needed to
    # cover the full uint8 range 0..255.
    # NOTE(review): frames are built as (H, W, C). openpi's Aloha-style input
    # transforms rearrange images from (C, H, W); confirm which layout the
    # "pi0_bns" data config expects before relying on this shape.
    image_shape = (224, 224, 3)
    images = {
        name: rng.integers(0, 256, size=image_shape, dtype=np.uint8)
        for name in ("high", "left_wrist", "right_wrist")
    }
    # 14-dim state as in the original example (presumably two 7-DoF arms; confirm).
    joint_positions = rng.standard_normal(14).astype(np.float32)

    obs = create_observation(images, joint_positions)

    result = policy.infer(obs)
    actions = result["actions"]

    print(f"Generated actions shape: {actions.shape}")


if __name__ == "__main__":
    main()
|
|