import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, SupportsFloat
import ale_py # noqa: F401
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from safetensors.torch import save_model
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
warnings.filterwarnings("ignore", category=UserWarning)
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
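# All hyperparameters below are parsed from the command line via tyro.cli(HyperParams).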
@dataclass
class HyperParams:
env_id: str = "SpaceInvadersNoFrameskip-v4"
"""The ID of the environment to train on."""
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of the experiment, used for saving models and logs."""
n_envs: int = 8
"""The number of parallel environments to run."""
seed: int = 0
"""The random seed for reproducibility."""
total_timesteps: int = 10_000_000
"""The total number of timesteps to train the agent."""
buffer_size: int = 1_000_000
"""The size of the replay buffer to store transitions."""
video_capture_frequency: int = 5
"""The interval (in episodes) to record videos of the agent's performance."""
initial_exploration: float = 1.0
"""The initial exploration rate for the epsilon-greedy policy."""
final_exploration: float = 0.01
"""The final exploration rate after annealing."""
exploration_fraction: float = 0.1
"""The fraction of total timesteps over which to anneal the exploration rate."""
learning_start: int = 80_000
"""The number of timesteps before starting to learn."""
train_frequency: int = 4
"""The frequency (in timesteps) to update the Q-network."""
batch_size: int = 32
"""The batch size for sampling from the replay buffer."""
gamma: float = 0.99
"""The discount factor (gamma) for future rewards."""
learning_rate: float = 1e-4
"""The learning rate for the optimizer."""
target_network_update_frequency: int = 1_000
"""The frequency (in timesteps) to update the target network."""
tau: float = 1.0
"""The rate at which to update the target network towards the Q-network."""
log_interval: int = 100
"""The interval (in timesteps) to log training statistics."""
save_times: int = 10
"""The number of times to save the model during training."""
evaluate: bool = True
"""Whether to evaluate the agent after training."""
eval_episodes: int = 10
"""The number of episodes to run for evaluation."""
push_model: bool = True
"""Whether to push the trained model to the Hugging Face Hub after training."""
hf_entity: str = "alperenunlu"
"""The Hugging Face entity (user or organization) to push the model under."""
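# Circular replay buffer. With optimize_memory_usage=True the next observation is
# read from the following slot instead of being stored separately, roughly halving
# the observation memory footprint.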
class ReplayBuffer:
def __init__(
self,
buffer_size: int,
observation_space: gym.Space,
action_space: gym.Space,
device: torch.device | str = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = True,
) -> None:
self.buffer_size = max(buffer_size // n_envs, 1)
self.n_envs = n_envs
self.optimize_memory_usage = optimize_memory_usage
self.obs_shape = self.get_obs_shape(observation_space)
self.action_dim = self.get_action_dim(action_space)
self.device = self.get_device(device)
self.observations = np.zeros(
(self.buffer_size, self.n_envs, *self.obs_shape),
dtype=observation_space.dtype,
)
if not self.optimize_memory_usage:
self.next_observations = np.zeros(
(self.buffer_size, self.n_envs, *self.obs_shape),
dtype=observation_space.dtype,
)
else:
self.next_observations = None
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim),
dtype=action_space.dtype,
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.index = 0
self.size = 0
def add(
self,
obs: np.ndarray,
next_obs: np.ndarray,
actions: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
) -> None:
"""Add a new transition to the replay buffer."""
self.observations[self.index] = np.asarray(obs)
if self.optimize_memory_usage:
# Store next_obs at the *next* slot to save memory
self.observations[(self.index + 1) % self.buffer_size] = np.asarray(
next_obs
)
else:
self.next_observations[self.index] = np.asarray(next_obs)
self.actions[self.index] = np.asarray(actions).reshape(
self.n_envs, self.action_dim
)
self.rewards[self.index] = np.asarray(rewards)
self.dones[self.index] = np.asarray(dones)
self.index = (self.index + 1) % self.buffer_size
self.size = min(self.size + 1, self.buffer_size)
def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
"""Sample a batch of data from the replay buffer."""
if not self.optimize_memory_usage:
batch_idx = np.random.randint(0, self.size, size=batch_size)
else:
# Do not sample the write index because its (obs,next_obs) pair is invalid
if self.size == self.buffer_size:
batch_idx = (
np.random.randint(1, self.buffer_size, size=batch_size) + self.index
) % self.buffer_size
else:
batch_idx = np.random.randint(0, self.index, size=batch_size)
env_idx = np.random.randint(0, self.n_envs, size=batch_size)
if self.optimize_memory_usage:
next_obs = self.observations[
(batch_idx + 1) % self.buffer_size, env_idx, ...
]
else:
next_obs = self.next_observations[batch_idx, env_idx, ...]
data = dict(
obs=self.observations[batch_idx, env_idx, ...],
next_obs=next_obs,
actions=self.actions[batch_idx, env_idx, ...],
rewards=self.rewards[batch_idx, env_idx, ...],
dones=self.dones[batch_idx, env_idx, ...],
)
return self.to_torch(data)
def to_torch(self, data: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
"""Convert numpy arrays to torch tensors and move them to the specified device."""
tensor_data = dict()
for k, v in data.items():
tensor_data[k] = torch.from_numpy(v).to(device=self.device)
return tensor_data
@staticmethod
def get_device(device: torch.device | str = "auto") -> torch.device:
"""Get the device to use for computations."""
if device == "auto":
return torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
else:
return torch.device(device)
@staticmethod
def get_obs_shape(
observation_space: gym.Space,
) -> tuple[int, ...]:
"""Get the shape of the observation space."""
if isinstance(observation_space, gym.spaces.Box):
return observation_space.shape
elif isinstance(observation_space, gym.spaces.Discrete):
return (1,)
elif isinstance(observation_space, gym.spaces.MultiDiscrete):
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, gym.spaces.MultiBinary):
return observation_space.shape
else:
raise NotImplementedError(
f"{observation_space} observation space is not supported"
)
@staticmethod
def get_action_dim(action_space: gym.spaces.Space) -> int:
"""Get the dimension of the action space."""
if isinstance(action_space, gym.spaces.Box):
return int(np.prod(action_space.shape))
elif isinstance(action_space, gym.spaces.Discrete):
return 1
elif isinstance(action_space, gym.spaces.MultiDiscrete):
return int(len(action_space.nvec))
else:
raise NotImplementedError(f"{action_space} action space is not supported")
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Take action on reset for environments that are fixed until firing.
:param env: Environment to wrap
"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]
def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
self.env.reset(**kwargs)
obs, _, terminated, truncated, info = self.env.step(1)
if terminated or truncated:
self.env.reset(**kwargs)
obs, _, terminated, truncated, info = self.env.step(2)
if terminated or truncated:
self.env.reset(**kwargs)
return obs, info
class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
:param env: Environment to wrap
"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.lives = 0
self.was_real_done = True
def step(
self, action: int
) -> tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]:
obs, reward, terminated, truncated, info = self.env.step(action)
self.was_real_done = terminated or truncated
lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
if 0 < lives < self.lives:
terminated = True
self.lives = lives
return obs, reward, terminated, truncated, info
def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
"""
Calls the Gym environment reset, only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
:param kwargs: Extra keywords passed to env.reset() call
:return: the first observation of the environment
"""
if self.was_real_done:
obs, info = self.env.reset(**kwargs)
else:
obs, _, terminated, truncated, info = self.env.step(0)
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
return obs, info
def make_env(
env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
) -> Callable[[], gym.Env]:
"""Create a gym environment with specific configurations.
Args:
env_id (str): The ID of the environment to create.
seed (int): The seed for random number generation.
idx (int): The index of the environment (for vectorized environments).
video_capture_frequency (int): The interval (in episodes) between recorded videos (0 to disable).
run_name (str): The name of the run for saving videos.
Returns:
Callable[[], gym.Env]: A function that returns the configured environment.
"""
def _thunk() -> gym.Env:
if video_capture_frequency > 0 and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(
env,
video_folder=f"videos/{run_name}",
episode_trigger=lambda x: x % video_capture_frequency == 0,
name_prefix=env_id,
)
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
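# DeepMind-style Atari preprocessing: up to 30 random no-ops at reset, frame skip
# of 4 with max-pooling over the last two frames, and 84x84 grayscale observations.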
env = gym.wrappers.AtariPreprocessing(
env,
noop_max=30,
frame_skip=4,
screen_size=(84, 84),
terminal_on_life_loss=False,
grayscale_obs=True,
grayscale_newaxis=False,
scale_obs=False,
)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined]
env = FireResetEnv(env)
env = gym.wrappers.ClipReward(env, -1, +1)
env = gym.wrappers.FrameStackObservation(env, stack_size=4)
env.action_space.seed(seed)
return env
return _thunk
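# Nature-DQN convolutional network (Mnih et al., 2015): three conv layers followed
# by a 512-unit fully connected layer and one Q-value output per discrete action.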
class QNetwork(nn.Module):
def __init__(self, env) -> None:
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 4, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, stride=1),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(7 * 7 * 64, 512),
nn.ReLU(inplace=True),
nn.Linear(512, env.single_action_space.n),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.network(x / 255.0)
def linear_schedule(t: int, start_e: float, end_e: float, duration: int) -> float:
"""Linear annealing from start_e to end_e over duration steps.
Args:
t (int): Current step.
start_e (float): Initial exploration rate.
end_e (float): Final exploration rate.
duration (int): Duration of the annealing in steps.
Returns:
float: The exploration rate at step t.
"""
slope = (end_e - start_e) / duration
return max(slope * t + start_e, end_e)
def main() -> None:
args = tyro.cli(HyperParams)
run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
print(run_name)
writer = SummaryWriter(f"runs/{run_name}")
keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
writer.add_text(
"hyperparameters",
f"|param|value|\n|-|-|\n{keyval_str}",
)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
envs = gym.vector.AsyncVectorEnv(
[
make_env(
args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
)
for i in range(int(args.n_envs))
],
autoreset_mode=gym.vector.AutoresetMode.DISABLED,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete)
envs.action_space.seed(args.seed)
q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
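# The target network starts as an exact copy of the online Q-network and is only
# updated every `target_network_update_frequency` steps.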
target_network = QNetwork(envs).to(device)
target_network.load_state_dict(q_network.state_dict())
target_network.eval()
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device=device,
n_envs=envs.num_envs,
)
start_time = time.time()
obs, _ = envs.reset(seed=args.seed)
step_pbar = tqdm(total=args.total_timesteps)
postfix_dict = dict()
for step in range(0, args.total_timesteps, envs.num_envs):
epsilon = linear_schedule(
step,
args.initial_exploration,
args.final_exploration,
int(args.total_timesteps * args.exploration_fraction),
)
writer.add_scalar("charts/epsilon", epsilon, step)
postfix_dict.update(e=epsilon)
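# Epsilon-greedy action selection, vectorized over environments: each environment
# takes a random action with probability epsilon, otherwise the greedy action
# according to the Q-network.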
with torch.no_grad():
q_values = q_network(torch.from_numpy(obs).to(device))
greedy_actions = torch.argmax(q_values, dim=1).cpu().numpy()
randn_actions = envs.action_space.sample()
mask_actions = np.random.rand(envs.num_envs) < epsilon
actions = np.where(mask_actions, randn_actions, greedy_actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
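# RecordEpisodeStatistics exposes per-environment episode returns/lengths in
# infos["episode"]; infos["_episode"] masks which environments just finished.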
if "episode" in infos:
mask = infos["_episode"]
r_mean = infos["episode"]["r"][mask].mean()
l_mean = infos["episode"]["l"][mask].mean()
writer.add_scalar("charts/episodic_return", r_mean, step)
writer.add_scalar("charts/episodic_length", l_mean, step)
postfix_dict.update(
r=r_mean,
l=l_mean,
)
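# Only terminations are stored as `dones`: truncated episodes are not true terminal
# states, so the TD target may still bootstrap through them.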
rb.add(obs, next_obs, actions, rewards, terminations)
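# Autoreset is disabled on the vector env, so finished sub-environments are reset
# manually through the `reset_mask` option.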
dones = np.logical_or(terminations, truncations)
if dones.any():
obs, _ = envs.reset(options={"reset_mask": dones})
else:
obs = next_obs
for parallel_step in range(
step, min(step + envs.num_envs, args.total_timesteps)
):
if parallel_step > args.learning_start:
if parallel_step % args.train_frequency == 0:
data = rb.sample(args.batch_size)
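# One-step DQN target: r + gamma * max_a' Q_target(s', a'), with the bootstrap
# term masked out on terminal transitions.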
with torch.no_grad():
target_max, _ = target_network(data["next_obs"]).max(dim=1)
td_target = data[
"rewards"
].flatten() + args.gamma * target_max * (
1 - data["dones"].flatten()
)
q_val = q_network(data["obs"]).gather(1, data["actions"]).squeeze()
loss = F.smooth_l1_loss(q_val, td_target)
if step % args.log_interval < envs.num_envs:
writer.add_scalar("losses/td_loss", loss, parallel_step)
writer.add_scalar(
"losses/q_values", q_val.mean().item(), parallel_step
)
writer.add_scalar(
"charts/SPS",
step // (time.time() - start_time),
parallel_step,
)
postfix_dict.update(
td_loss=loss.item(),
q_val_mean=q_val.mean().item(),
sps=step // (time.time() - start_time),
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
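# With tau=1.0, lerp_ copies the Q-network weights into the target network
# (hard update); tau < 1.0 yields a Polyak-averaged (soft) update.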
if parallel_step % args.target_network_update_frequency == 0:
for target_network_param, q_network_param in zip(
target_network.parameters(), q_network.parameters()
):
target_network_param.data.lerp_(q_network_param.data, args.tau)
if parallel_step % (args.total_timesteps // args.save_times) == 0:
save_model(
model=q_network,
filename=f"runs/{run_name}/{args.exp_name}_{step}.safetensors",
)
step_pbar.set_postfix(postfix_dict)
step_pbar.update(envs.num_envs)
envs.close()
step_pbar.close()
if args.evaluate:
run_name_eval = f"{run_name}_eval"
final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
save_model(model=q_network, filename=final_model_path)
from evals.dqn_eval import evaluate
episode_rewards = evaluate(
final_model_path=final_model_path,
make_env=make_env,
env_id=args.env_id,
run_name_eval=run_name_eval,
QNetwork=QNetwork,
device=device,
eval_episodes=args.eval_episodes,
epsilon=args.final_exploration,
)
for i, r in enumerate(episode_rewards):
writer.add_scalar("eval/episodic_return_eval", r, i)
if args.push_model:
from utils import push_model
push_model(
args=args,
episode_rewards=episode_rewards,
algo_name="DQN",
run_path=f"runs/{run_name}",
video_folder_path=f"videos/{run_name_eval}",
)
writer.close()
if __name__ == "__main__":
main()