import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, SupportsFloat

import ale_py  # importing ale_py registers the ALE/Atari environments with Gymnasium
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from safetensors.torch import save_model
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


@dataclass
class HyperParams:
    env_id: str = "SpaceInvadersNoFrameskip-v4"
    """The ID of the environment to train on."""
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """The name of the experiment, used for saving models and logs."""
    n_envs: int = 8
    """The number of parallel environments to run."""
    seed: int = 0
    """The random seed for reproducibility."""
    total_timesteps: int = 10_000_000
    """The total number of timesteps to train the agent."""
    buffer_size: int = 1_000_000
    """The size of the replay buffer to store transitions."""
    video_capture_frequency: int = 5
    """The interval (in episodes) to record videos of the agent's performance."""

    initial_exploration: float = 1.0
    """The initial exploration rate for the epsilon-greedy policy."""
    final_exploration: float = 0.01
    """The final exploration rate after annealing."""
    exploration_fraction: float = 0.1
    """The fraction of total timesteps over which to anneal the exploration rate."""

    learning_start: int = 80_000
    """The number of timesteps to collect before learning starts."""
    train_frequency: int = 4
    """The frequency (in timesteps) to update the Q-network."""
    batch_size: int = 32
    """The batch size for sampling from the replay buffer."""
    gamma: float = 0.99
    """The discount factor (gamma) for future rewards."""
    learning_rate: float = 1e-4
    """The learning rate for the optimizer."""
    target_network_update_frequency: int = 1_000
    """The frequency (in timesteps) to update the target network."""
    tau: float = 1.0
    """The interpolation factor for updating the target network towards the Q-network (1.0 is a hard copy)."""

    log_interval: int = 100
    """The interval (in timesteps) to log training statistics."""
    save_times: int = 10
    """The number of times to save the model during training."""

    evaluate: bool = True
    """Whether to evaluate the agent after training."""
    eval_episodes: int = 10
    """The number of episodes to run for evaluation."""

    push_model: bool = True
    """Whether to push the trained model to the Hugging Face Hub after evaluation."""
    hf_entity: str = "alperenunlu"
    """The Hugging Face username or organization to push the model to."""
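# A note on buffer memory: with the defaults above the buffer holds 1,000,000
# stacked uint8 observations of shape (4, 84, 84), i.e. roughly
# 1e6 * 4 * 84 * 84 bytes ~= 28 GB. The `optimize_memory_usage` flag below
# (mirroring Stable-Baselines3's ReplayBuffer) avoids doubling that for a
# separate `next_observations` array by reading the next observation from the
# slot at (index + 1) % buffer_size instead of storing it twice.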
class ReplayBuffer:
    def __init__(
        self,
        buffer_size: int,
        observation_space: gym.Space,
        action_space: gym.Space,
        device: torch.device | str = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = True,
    ) -> None:
        self.buffer_size = max(buffer_size // n_envs, 1)
        self.n_envs = n_envs
        self.optimize_memory_usage = optimize_memory_usage

        self.obs_shape = self.get_obs_shape(observation_space)
        self.action_dim = self.get_action_dim(action_space)

        self.device = self.get_device(device)

        self.observations = np.zeros(
            (self.buffer_size, self.n_envs, *self.obs_shape),
            dtype=observation_space.dtype,
        )
        if not self.optimize_memory_usage:
            self.next_observations = np.zeros(
                (self.buffer_size, self.n_envs, *self.obs_shape),
                dtype=observation_space.dtype,
            )
        else:
            self.next_observations = None

        self.actions = np.zeros(
            (self.buffer_size, self.n_envs, self.action_dim),
            dtype=action_space.dtype,
        )
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        self.index = 0
        self.size = 0

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
    ) -> None:
        """Add a new transition to the replay buffer."""
        self.observations[self.index] = np.asarray(obs)

        if self.optimize_memory_usage:
            # Store next_obs in the slot that the next transition will occupy,
            # so it can be recovered as observations[(index + 1) % buffer_size].
            self.observations[(self.index + 1) % self.buffer_size] = np.asarray(
                next_obs
            )
        else:
            self.next_observations[self.index] = np.asarray(next_obs)

        self.actions[self.index] = np.asarray(actions).reshape(
            self.n_envs, self.action_dim
        )
        self.rewards[self.index] = np.asarray(rewards)
        self.dones[self.index] = np.asarray(dones)

        self.index = (self.index + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
        """Sample a batch of data from the replay buffer."""
        if not self.optimize_memory_usage:
            batch_idx = np.random.randint(0, self.size, size=batch_size)
        else:
            # Never sample the current index: its "next observation" slot has
            # already been overwritten by the most recent call to `add`.
            if self.size == self.buffer_size:
                batch_idx = (
                    np.random.randint(1, self.buffer_size, size=batch_size) + self.index
                ) % self.buffer_size
            else:
                batch_idx = np.random.randint(0, self.index, size=batch_size)

        env_idx = np.random.randint(0, self.n_envs, size=batch_size)

        if self.optimize_memory_usage:
            next_obs = self.observations[
                (batch_idx + 1) % self.buffer_size, env_idx, ...
            ]
        else:
            next_obs = self.next_observations[batch_idx, env_idx, ...]

        data = dict(
            obs=self.observations[batch_idx, env_idx, ...],
            next_obs=next_obs,
            actions=self.actions[batch_idx, env_idx, ...],
            rewards=self.rewards[batch_idx, env_idx, ...],
            dones=self.dones[batch_idx, env_idx, ...],
        )
        return self.to_torch(data)

    def to_torch(self, data: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
        """Convert numpy arrays to torch tensors and move them to the specified device."""
        tensor_data = dict()
        for k, v in data.items():
            tensor_data[k] = torch.from_numpy(v).to(device=self.device)
        return tensor_data

    @staticmethod
    def get_device(device: torch.device | str = "auto") -> torch.device:
        """Get the device to use for computations."""
        if device == "auto":
            return torch.device(
                "cuda"
                if torch.cuda.is_available()
                else "mps"
                if torch.backends.mps.is_available()
                else "cpu"
            )
        else:
            return torch.device(device)

    @staticmethod
    def get_obs_shape(
        observation_space: gym.Space,
    ) -> tuple[int, ...]:
        """Get the shape of the observation space."""
        if isinstance(observation_space, gym.spaces.Box):
            return observation_space.shape
        elif isinstance(observation_space, gym.spaces.Discrete):
            return (1,)
        elif isinstance(observation_space, gym.spaces.MultiDiscrete):
            return (int(len(observation_space.nvec)),)
        elif isinstance(observation_space, gym.spaces.MultiBinary):
            return observation_space.shape
        else:
            raise NotImplementedError(
                f"{observation_space} observation space is not supported"
            )

    @staticmethod
    def get_action_dim(action_space: gym.spaces.Space) -> int:
        """Get the dimension of the action space."""
        if isinstance(action_space, gym.spaces.Box):
            return int(np.prod(action_space.shape))
        elif isinstance(action_space, gym.spaces.Discrete):
            return 1
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            return int(len(action_space.nvec))
        else:
            raise NotImplementedError(f"{action_space} action space is not supported")
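# Some Atari games (e.g. Breakout) idle on the start screen until the FIRE
# action is pressed after a reset; the wrapper below presses it automatically
# so the agent never gets stuck waiting for the game to begin.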
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(2)
        if terminated or truncated:
            self.env.reset(**kwargs)
        return obs, info


class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(
        self, action: int
    ) -> tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]:
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated
        lives = self.env.unwrapped.ale.lives()
        if 0 < lives < self.lives:
            # A life was lost but the game is not over: signal termination to
            # the agent anyway.
            terminated = True
        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
        """
        Calls the Gym environment reset, only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # Advance with a no-op after a life loss; the game itself continues.
            obs, _, terminated, truncated, info = self.env.step(0)

            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info
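# The wrapper stack built below (AtariPreprocessing + EpisodicLifeEnv +
# FireResetEnv + ClipReward + FrameStackObservation) follows the standard DQN
# Atari setup: each observation is a stack of 4 grayscale 84x84 frames, i.e. a
# uint8 array of shape (4, 84, 84), and rewards are clipped to [-1, 1].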
def make_env(
    env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
) -> Callable[[], gym.Env]:
    """Create a gym environment with specific configurations.

    Args:
        env_id (str): The ID of the environment to create.
        seed (int): The seed for random number generation.
        idx (int): The index of the environment (for vectorized environments).
        video_capture_frequency (int): Frequency (in episodes) of recording videos (0 to disable).
        run_name (str): The name of the run for saving videos.

    Returns:
        Callable[[], gym.Env]: A function that returns the configured environment.
    """

    def _thunk() -> gym.Env:
        # Only the first environment records videos.
        if video_capture_frequency > 0 and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(
                env,
                video_folder=f"videos/{run_name}",
                episode_trigger=lambda x: x % video_capture_frequency == 0,
                name_prefix=env_id,
            )
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.AtariPreprocessing(
            env,
            noop_max=30,
            frame_skip=4,
            screen_size=(84, 84),
            terminal_on_life_loss=False,
            grayscale_obs=True,
            grayscale_newaxis=False,
            scale_obs=False,
        )
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = gym.wrappers.ClipReward(env, -1, +1)
        env = gym.wrappers.FrameStackObservation(env, stack_size=4)
        env.action_space.seed(seed)
        return env

    return _thunk
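# The Q-network below is the Nature-DQN CNN (Mnih et al., 2015). For a
# (4, 84, 84) input: conv 8x8/4 -> 20x20, conv 4x4/2 -> 9x9, conv 3x3/1 -> 7x7,
# hence the 7 * 7 * 64 = 3136-unit flatten feeding the 512-unit hidden layer.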
class QNetwork(nn.Module):
    def __init__(self, env) -> None:
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, env.single_action_space.n),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale uint8 pixels to [0, 1] before the convolutions.
        return self.network(x / 255.0)
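# Worked example for the schedule below with the defaults above: epsilon is
# annealed from 1.0 to 0.01 over the first
# exploration_fraction * total_timesteps = 1,000,000 steps, so
# linear_schedule(500_000, 1.0, 0.01, 1_000_000) = 0.505, and every step past
# 1,000,000 is clamped to 0.01.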
def linear_schedule(t: int, start_e: float, end_e: float, duration: int) -> float:
    """Linear annealing from start_e to end_e over duration steps.

    Args:
        t (int): Current step.
        start_e (float): Initial exploration rate.
        end_e (float): Final exploration rate.
        duration (int): Duration of the annealing in steps.

    Returns:
        float: The exploration rate at step t.
    """
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


def main() -> None:
    args = tyro.cli(HyperParams)
    run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
    print(run_name)
    writer = SummaryWriter(f"runs/{run_name}")
    keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
    writer.add_text(
        "hyperparameters",
        f"|param|value|\n|-|-|\n{keyval_str}",
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    envs = gym.vector.AsyncVectorEnv(
        [
            make_env(
                args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
            )
            for i in range(int(args.n_envs))
        ],
        autoreset_mode=gym.vector.AutoresetMode.DISABLED,
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete)
    envs.action_space.seed(args.seed)

    q_network = QNetwork(envs).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())
    target_network.eval()

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device=device,
        n_envs=envs.num_envs,
    )

    start_time = time.time()

    obs, _ = envs.reset(seed=args.seed)
    step_pbar = tqdm(total=args.total_timesteps)
    postfix_dict = dict()
    for step in range(0, args.total_timesteps, envs.num_envs):
        epsilon = linear_schedule(
            step,
            args.initial_exploration,
            args.final_exploration,
            int(args.total_timesteps * args.exploration_fraction),
        )
        writer.add_scalar("charts/epsilon", epsilon, step)
        postfix_dict.update(e=epsilon)

        # Epsilon-greedy action selection: each environment independently acts
        # randomly with probability epsilon and greedily otherwise.
        with torch.no_grad():
            q_values = q_network(torch.from_numpy(obs).to(device))
            greedy_actions = torch.argmax(q_values, dim=1).cpu().numpy()

        random_actions = envs.action_space.sample()
        mask_actions = np.random.rand(envs.num_envs) < epsilon
        actions = np.where(mask_actions, random_actions, greedy_actions)

        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        if "episode" in infos:
            mask = infos["_episode"]
            r_mean = infos["episode"]["r"][mask].mean()
            l_mean = infos["episode"]["l"][mask].mean()

            writer.add_scalar("charts/episodic_return", r_mean, step)
            writer.add_scalar("charts/episodic_length", l_mean, step)
            postfix_dict.update(
                r=r_mean,
                l=l_mean,
            )

        rb.add(obs, next_obs, actions, rewards, terminations)

        # Autoreset is disabled, so finished environments are reset manually.
        dones = np.logical_or(terminations, truncations)
        if dones.any():
            obs, _ = envs.reset(options={"reset_mask": dones})
        else:
            obs = next_obs
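        # DQN update. The vectorized step above advanced envs.num_envs global
        # steps at once, so the loop below replays them one step at a time to
        # keep the frequency-based schedules in single-step units. After
        # learning_start steps, every train_frequency steps a minibatch is
        # sampled and Q(s, a) is regressed toward the TD target
        #     y = r + gamma * (1 - done) * max_a' Q_target(s', a')
        # with the Huber (smooth L1) loss; every target_network_update_frequency
        # steps the target network is refreshed from the online network.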
        for parallel_step in range(
            step, min(step + envs.num_envs, args.total_timesteps)
        ):
            if parallel_step > args.learning_start:
                if parallel_step % args.train_frequency == 0:
                    data = rb.sample(args.batch_size)
                    with torch.no_grad():
                        target_max, _ = target_network(data["next_obs"]).max(dim=1)
                        td_target = data["rewards"].flatten() + args.gamma * target_max * (
                            1 - data["dones"].flatten()
                        )
                    q_val = q_network(data["obs"]).gather(1, data["actions"]).squeeze()
                    loss = F.smooth_l1_loss(q_val, td_target)

                    if step % args.log_interval < envs.num_envs:
                        writer.add_scalar("losses/td_loss", loss, parallel_step)
                        writer.add_scalar(
                            "losses/q_values", q_val.mean().item(), parallel_step
                        )
                        writer.add_scalar(
                            "charts/SPS",
                            step // (time.time() - start_time),
                            parallel_step,
                        )
                        postfix_dict.update(
                            td_loss=loss.item(),
                            q_val_mean=q_val.mean().item(),
                            sps=step // (time.time() - start_time),
                        )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                if parallel_step % args.target_network_update_frequency == 0:
                    # Polyak update of the target network (a hard copy when tau == 1.0).
                    for target_network_param, q_network_param in zip(
                        target_network.parameters(), q_network.parameters()
                    ):
                        target_network_param.data.lerp_(q_network_param.data, args.tau)

                if parallel_step % (args.total_timesteps // args.save_times) == 0:
                    save_model(
                        model=q_network,
                        filename=f"runs/{run_name}/{args.exp_name}_{step}.safetensors",
                    )
        step_pbar.set_postfix(postfix_dict)
        step_pbar.update(envs.num_envs)
    envs.close()
    step_pbar.close()
    if args.evaluate:
        run_name_eval = f"{run_name}_eval"
        final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
        save_model(model=q_network, filename=final_model_path)
        from evals.dqn_eval import evaluate

        episode_rewards = evaluate(
            final_model_path=final_model_path,
            make_env=make_env,
            env_id=args.env_id,
            run_name_eval=run_name_eval,
            QNetwork=QNetwork,
            device=device,
            eval_episodes=args.eval_episodes,
            epsilon=args.final_exploration,
        )
        for i, r in enumerate(episode_rewards):
            writer.add_scalar("eval/episodic_return_eval", r, i)

        if args.push_model:
            from utils import push_model

            push_model(
                args=args,
                episode_rewards=episode_rewards,
                algo_name="DQN",
                run_path=f"runs/{run_name}",
                video_folder_path=f"videos/{run_name_eval}",
            )

    writer.close()


if __name__ == "__main__":
    main()