alperenunlu committed
Commit 94e16fc · verified · 1 Parent(s): 49fb85d

Push model

Files changed (28)
  1. .gitattributes +11 -0
  2. README.md +63 -0
  3. dqn_atari.py +561 -0
  4. dqn_atari_1000000.safetensors +3 -0
  5. dqn_atari_2000000.safetensors +3 -0
  6. dqn_atari_3000000.safetensors +3 -0
  7. dqn_atari_4000000.safetensors +3 -0
  8. dqn_atari_5000000.safetensors +3 -0
  9. dqn_atari_6000000.safetensors +3 -0
  10. dqn_atari_7000000.safetensors +3 -0
  11. dqn_atari_8000000.safetensors +3 -0
  12. dqn_atari_9000000.safetensors +3 -0
  13. dqn_atari_final.safetensors +3 -0
  14. events.out.tfevents.1755876604.dc81b9906985.2178.0 +3 -0
  15. hyperparameters.json +25 -0
  16. pyproject.toml +16 -0
  17. replay.mp4 +3 -0
  18. uv.lock +0 -0
  19. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-0.mp4 +3 -0
  20. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-1.mp4 +3 -0
  21. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-2.mp4 +3 -0
  22. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-3.mp4 +3 -0
  23. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-4.mp4 +3 -0
  24. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-5.mp4 +3 -0
  25. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-6.mp4 +3 -0
  26. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-7.mp4 +3 -0
  27. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-8.mp4 +3 -0
  28. videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-9.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-8.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-2.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-7.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-4.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-0.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-3.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-6.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-5.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-1.mp4 filter=lfs diff=lfs merge=lfs -text
+ videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-9.mp4 filter=lfs diff=lfs merge=lfs -text
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,63 @@
+ ---
+ tags:
+ - reinforcement-learning
+ - deep-reinforcement-learning
+ - DQN
+ - SpaceInvadersNoFrameskip-v4
+ library_name: rlhell
+ model-index:
+ - name: DQN
+   results:
+   - task:
+       type: reinforcement-learning
+       name: reinforcement-learning
+     dataset:
+       name: SpaceInvadersNoFrameskip-v4
+       type: SpaceInvadersNoFrameskip-v4
+     metrics:
+     - type: mean_reward
+       value: 1813.50 +/- 579.77
+       name: mean_reward
+       verified: false
+ ---
+
+ # (RLHell) **DQN** Agent Playing **SpaceInvadersNoFrameskip-v4**
+
+ This model was trained using [rlhell](https://github.com/alperenunlu/rlhell).
+
+ ## Command to reproduce the training
+
+ ```bash
+ curl -OL https://huggingface.co/alperenunlu/SpaceInvadersNoFrameskip-v4-dqn_atari/raw/main/dqn_atari.py
+ curl -OL https://huggingface.co/alperenunlu/SpaceInvadersNoFrameskip-v4-dqn_atari/raw/main/pyproject.toml
+ curl -OL https://huggingface.co/alperenunlu/SpaceInvadersNoFrameskip-v4-dqn_atari/raw/main/uv.lock
+ uv run dqn_atari.py
+ ```
+
+ ## Hyperparameters
+ ```python
+ {'batch_size': 32,
+  'buffer_size': 1_000_000,
+  'env_id': 'SpaceInvadersNoFrameskip-v4',
+  'eval_episodes': 10,
+  'evaluate': True,
+  'exp_name': 'dqn_atari',
+  'exploration_fraction': 0.1,
+  'final_exploration': 0.01,
+  'gamma': 0.99,
+  'hf_entity': 'alperenunlu',
+  'initial_exploration': 1,
+  'learning_rate': 0.0001,
+  'learning_start': 80_000,
+  'log_interval': 100,
+  'n_envs': 8,
+  'push_model': True,
+  'save_times': 10,
+  'seed': 0,
+  'target_network_update_frequency': 1_000,
+  'tau': 1.0,
+  'total_timesteps': 10_000_000,
+  'train_frequency': 4,
+  'video_capture_frequency': 5}
+ ```
+
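A minimal sketch of how a pushed checkpoint could be loaded back for inference, assuming `dqn_atari.py` and a checkpoint such as `dqn_atari_final.safetensors` sit in the working directory (it reuses `QNetwork` and `make_env` from the script below):

```python
import gymnasium as gym
import torch
from safetensors.torch import load_model

from dqn_atari import QNetwork, make_env

# One vectorized env with the same preprocessing/wrappers used during training
# (video capture disabled, hypothetical run name "eval").
envs = gym.vector.SyncVectorEnv(
    [make_env("SpaceInvadersNoFrameskip-v4", seed=0, idx=0, video_capture_frequency=0, run_name="eval")]
)
q_network = QNetwork(envs)
load_model(q_network, "dqn_atari_final.safetensors")  # any of the pushed checkpoints
q_network.eval()

obs, _ = envs.reset(seed=0)
with torch.no_grad():
    # Greedy action from the Q-network; obs is uint8, the network rescales internally.
    actions = q_network(torch.from_numpy(obs)).argmax(dim=1).numpy()
obs, rewards, terminations, truncations, infos = envs.step(actions)
```

`SyncVectorEnv` stands in here for the `AsyncVectorEnv` used during training, purely to keep the sketch single-process; any of the intermediate `dqn_atari_*.safetensors` checkpoints can be loaded the same way.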
dqn_atari.py ADDED
@@ -0,0 +1,561 @@
+ import os
+ import random
+ import time
+ import warnings
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any, Callable, SupportsFloat
+
+ import ale_py  # noqa: F401
+ import gymnasium as gym
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import tyro
+ from safetensors.torch import save_model
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm.auto import tqdm
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ device = torch.device(
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+
+ @dataclass
+ class HyperParams:
+     env_id: str = "SpaceInvadersNoFrameskip-v4"
+     """The ID of the environment to train on."""
+     exp_name: str = os.path.basename(__file__)[: -len(".py")]
+     """The name of the experiment, used for saving models and logs."""
+     n_envs: int = 8
+     """The number of parallel environments to run."""
+     seed: int = 0
+     """The random seed for reproducibility."""
+     total_timesteps: int = 10_000_000
+     """The total number of timesteps to train the agent."""
+     buffer_size: int = 1_000_000
+     """The size of the replay buffer to store transitions."""
+     video_capture_frequency: int = 5
+     """The interval (in episodes) to record videos of the agent's performance."""
+
+     initial_exploration: float = 1
+     """The initial exploration rate for the epsilon-greedy policy."""
+     final_exploration: float = 0.01
+     """The final exploration rate after annealing."""
+     exploration_fraction: float = 0.1
+     """The fraction of total timesteps over which to anneal the exploration rate."""
+
+     learning_start: int = 80_000
+     """The number of timesteps before starting to learn."""
+     train_frequency: int = 4
+     """The frequency (in timesteps) to update the Q-network."""
+     batch_size: int = 32
+     """The batch size for sampling from the replay buffer."""
+     gamma: float = 0.99
+     """The discount factor (gamma) for future rewards."""
+     learning_rate: float = 1e-4
+     """The learning rate for the optimizer."""
+     target_network_update_frequency: int = 1_000
+     """The frequency (in timesteps) to update the target network."""
+     tau: float = 1.0
+     """The rate at which to update the target network towards the Q-network."""
+
+     log_interval: int = 100
+     """The interval (in timesteps) to log training statistics."""
+     save_times: int = 10
+     """The number of times to save the model during training."""
+
+     evaluate: bool = True
+     """Whether to evaluate the agent after training."""
+     eval_episodes: int = 10
+     """The number of episodes to run for evaluation."""
+
+     push_model: bool = True
+     hf_entity: str = "alperenunlu"
+
+
+ class ReplayBuffer:
+     def __init__(
+         self,
+         buffer_size: int,
+         observation_space: gym.Space,
+         action_space: gym.Space,
+         device: torch.device | str = "auto",
+         n_envs: int = 1,
+         optimize_memory_usage: bool = True,
+     ) -> None:
+         self.buffer_size = max(buffer_size // n_envs, 1)
+         self.n_envs = n_envs
+         self.optimize_memory_usage = optimize_memory_usage
+
+         self.obs_shape = self.get_obs_shape(observation_space)
+         self.action_dim = self.get_action_dim(action_space)
+
+         self.device = self.get_device(device)
+
+         self.observations = np.zeros(
+             (self.buffer_size, self.n_envs, *self.obs_shape),
+             dtype=observation_space.dtype,
+         )
+         if not self.optimize_memory_usage:
+             self.next_observations = np.zeros(
+                 (self.buffer_size, self.n_envs, *self.obs_shape),
+                 dtype=observation_space.dtype,
+             )
+         else:
+             self.next_observations = None
+
+         self.actions = np.zeros(
+             (self.buffer_size, self.n_envs, self.action_dim),
+             dtype=action_space.dtype,
+         )
+         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+         self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+         self.index = 0
+         self.size = 0
+
+     def add(
+         self,
+         obs: np.ndarray,
+         next_obs: np.ndarray,
+         actions: np.ndarray,
+         rewards: np.ndarray,
+         dones: np.ndarray,
+     ) -> None:
+         """Add a new transition to the replay buffer."""
+         self.observations[self.index] = np.asarray(obs)
+
+         if self.optimize_memory_usage:
+             # Store next_obs at the *next* slot to save memory
+             self.observations[(self.index + 1) % self.buffer_size] = np.asarray(
+                 next_obs
+             )
+         else:
+             self.next_observations[self.index] = np.asarray(next_obs)
+
+         self.actions[self.index] = np.asarray(actions).reshape(
+             self.n_envs, self.action_dim
+         )
+         self.rewards[self.index] = np.asarray(rewards)
+         self.dones[self.index] = np.asarray(dones)
+
+         self.index = (self.index + 1) % self.buffer_size
+         self.size = min(self.size + 1, self.buffer_size)
+
+     def sample(self, batch_size: int) -> dict[str, torch.Tensor]:
+         """Sample a batch of data from the replay buffer."""
+         if not self.optimize_memory_usage:
+             batch_idx = np.random.randint(0, self.size, size=batch_size)
+         else:
+             # Do not sample the write index because its (obs,next_obs) pair is invalid
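+             # (With the memory optimization, observations[self.index] was just
+             # overwritten with the newest next_obs, while actions/rewards/dones at
+             # that slot still belong to a transition from one full cycle ago.)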
+             if self.size == self.buffer_size:
+                 batch_idx = (
+                     np.random.randint(1, self.buffer_size, size=batch_size) + self.index
+                 ) % self.buffer_size
+             else:
+                 batch_idx = np.random.randint(0, self.index, size=batch_size)
+
+         env_idx = np.random.randint(0, self.n_envs, size=batch_size)
+
+         if self.optimize_memory_usage:
+             next_obs = self.observations[
+                 (batch_idx + 1) % self.buffer_size, env_idx, ...
+             ]
+         else:
+             next_obs = self.next_observations[batch_idx, env_idx, ...]
+
+         data = dict(
+             obs=self.observations[batch_idx, env_idx, ...],
+             next_obs=next_obs,
+             actions=self.actions[batch_idx, env_idx, ...],
+             rewards=self.rewards[batch_idx, env_idx, ...],
+             dones=self.dones[batch_idx, env_idx, ...],
+         )
+         return self.to_torch(data)
+
+     def to_torch(self, data: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
+         """Convert numpy arrays to torch tensors and move them to the specified device."""
+         tensor_data = dict()
+         for k, v in data.items():
+             tensor_data[k] = torch.from_numpy(v).to(device=self.device)
+         return tensor_data
+
+     @staticmethod
+     def get_device(device: torch.device | str = "auto") -> torch.device:
+         """Get the device to use for computations."""
+         if device == "auto":
+             return torch.device(
+                 "cuda"
+                 if torch.cuda.is_available()
+                 else "mps"
+                 if torch.backends.mps.is_available()
+                 else "cpu"
+             )
+         else:
+             return torch.device(device)
+
+     @staticmethod
+     def get_obs_shape(
+         observation_space: gym.Space,
+     ) -> tuple[int, ...]:
+         """Get the shape of the observation space."""
+         if isinstance(observation_space, gym.spaces.Box):
+             return observation_space.shape
+         elif isinstance(observation_space, gym.spaces.Discrete):
+             return (1,)
+         elif isinstance(observation_space, gym.spaces.MultiDiscrete):
+             return (int(len(observation_space.nvec)),)
+         elif isinstance(observation_space, gym.spaces.MultiBinary):
+             return observation_space.shape
+         else:
+             raise NotImplementedError(
+                 f"{observation_space} observation space is not supported"
+             )
+
+     @staticmethod
+     def get_action_dim(action_space: gym.spaces.Space) -> int:
+         """Get the dimension of the action space."""
+         if isinstance(action_space, gym.spaces.Box):
+             return int(np.prod(action_space.shape))
+         elif isinstance(action_space, gym.spaces.Discrete):
+             return 1
+         elif isinstance(action_space, gym.spaces.MultiDiscrete):
+             return int(len(action_space.nvec))
+         else:
+             raise NotImplementedError(f"{action_space} action space is not supported")
+
+
+ class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
+     """
+     Take action on reset for environments that are fixed until firing.
+
+     :param env: Environment to wrap
+     """
+
+     def __init__(self, env: gym.Env) -> None:
+         super().__init__(env)
+         assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
+         assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]
+
+     def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
+         self.env.reset(**kwargs)
+         obs, _, terminated, truncated, info = self.env.step(1)
+         if terminated or truncated:
+             self.env.reset(**kwargs)
+         obs, _, terminated, truncated, info = self.env.step(2)
+         if terminated or truncated:
+             self.env.reset(**kwargs)
+         return obs, info
+
+
+ class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
+     """
+     Make end-of-life == end-of-episode, but only reset on true game over.
+     Done by DeepMind for the DQN and co. since it helps value estimation.
+
+     :param env: Environment to wrap
+     """
+
+     def __init__(self, env: gym.Env) -> None:
+         super().__init__(env)
+         self.lives = 0
+         self.was_real_done = True
+
+     def step(
+         self, action: int
+     ) -> tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]:
+         obs, reward, terminated, truncated, info = self.env.step(action)
+         self.was_real_done = terminated or truncated
+         lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
+         if 0 < lives < self.lives:
+             terminated = True
+         self.lives = lives
+         return obs, reward, terminated, truncated, info
+
+     def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
+         """
+         Calls the Gym environment reset, only when lives are exhausted.
+         This way all states are still reachable even though lives are episodic,
+         and the learner need not know about any of this behind-the-scenes.
+
+         :param kwargs: Extra keywords passed to env.reset() call
+         :return: the first observation of the environment
+         """
+         if self.was_real_done:
+             obs, info = self.env.reset(**kwargs)
+         else:
+             obs, _, terminated, truncated, info = self.env.step(0)
+
+             if terminated or truncated:
+                 obs, info = self.env.reset(**kwargs)
+         self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
+         return obs, info
+
+
+ def make_env(
+     env_id: str, seed: int, idx: int, video_capture_frequency: int, run_name: str
+ ) -> Callable[[], gym.Env]:
+     """Create a gym environment with specific configurations.
+
+     Args:
+         env_id (str): The ID of the environment to create.
+         seed (int): The seed for random number generation.
+         idx (int): The index of the environment (for vectorized environments).
+         video_capture_frequency (int): Frequency of recording videos (0 to disable).
+         run_name (str): The name of the run for saving videos.
+
+     Returns:
+         Callable[[], gym.Env]: A function that returns the configured environment.
+     """
+
+     def _thunk() -> gym.Env:
+         if video_capture_frequency > 0 and idx == 0:
+             env = gym.make(env_id, render_mode="rgb_array")
+             env = gym.wrappers.RecordVideo(
+                 env,
+                 video_folder=f"videos/{run_name}",
+                 episode_trigger=lambda x: x % video_capture_frequency == 0,
+                 name_prefix=env_id,
+             )
+         else:
+             env = gym.make(env_id)
+         env = gym.wrappers.RecordEpisodeStatistics(env)
+         env = gym.wrappers.AtariPreprocessing(
+             env,
+             noop_max=30,
+             frame_skip=4,
+             screen_size=(84, 84),
+             terminal_on_life_loss=False,
+             grayscale_obs=True,
+             grayscale_newaxis=False,
+             scale_obs=False,
+         )
+         env = EpisodicLifeEnv(env)
+         if "FIRE" in env.unwrapped.get_action_meanings():  # type: ignore[attr-defined]
+             env = FireResetEnv(env)
+         env = gym.wrappers.ClipReward(env, -1, +1)
+         env = gym.wrappers.FrameStackObservation(env, stack_size=4)
+         env.action_space.seed(seed)
+         return env
+
+     return _thunk
+
+
+ class QNetwork(nn.Module):
+     def __init__(self, env) -> None:
+         super().__init__()
+         self.network = nn.Sequential(
+             nn.Conv2d(4, 32, 8, stride=4),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, 64, 4, stride=2),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, 64, 3, stride=1),
+             nn.ReLU(inplace=True),
+             nn.Flatten(),
+             nn.Linear(7 * 7 * 64, 512),
+             nn.ReLU(inplace=True),
+             nn.Linear(512, env.single_action_space.n),
+         )
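+         # 84x84 input -> 20x20 after conv1 (kernel 8, stride 4) -> 9x9 after conv2
+         # (kernel 4, stride 2) -> 7x7 after conv3 (kernel 3, stride 1), hence 7 * 7 * 64.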
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.network(x / 255.0)
+
+
+ def linear_schedule(t: int, start_e: float, end_e: float, duration: int) -> float:
+     """Linear annealing from start_e to end_e over duration steps.
+
+     Args:
+         t (int): Current step.
+         start_e (float): Initial exploration rate.
+         end_e (float): Final exploration rate.
+         duration (int): Duration of the annealing in steps.
+
+     Returns:
+         float: The exploration rate at step t.
+     """
+     slope = (end_e - start_e) / duration
+     return max(slope * t + start_e, end_e)
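+     # With this file's defaults (start_e=1.0, end_e=0.01, duration=0.1 * 10_000_000),
+     # epsilon is 0.505 at t=500_000 and stays clamped at 0.01 once t >= 1_000_000.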
+
+
+ def main() -> None:
+     args = tyro.cli(HyperParams)
+     run_name = f"{args.env_id}_{args.exp_name.replace('/', '_')}_{args.seed}_{datetime.now().strftime('%y%m%d_%H%M%S')}"
+     print(run_name)
+     writer = SummaryWriter(f"runs/{run_name}")
+     keyval_str = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
+     writer.add_text(
+         "hyperparameters",
+         f"|param|value|\n|-|-|\n{keyval_str}",
+     )
+
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+
+     envs = gym.vector.AsyncVectorEnv(
+         [
+             make_env(
+                 args.env_id, args.seed + i, i, args.video_capture_frequency, run_name
+             )
+             for i in range(int(args.n_envs))
+         ],
+         autoreset_mode=gym.vector.AutoresetMode.DISABLED,
+     )
+     assert isinstance(envs.single_action_space, gym.spaces.Discrete)
+     envs.action_space.seed(args.seed)
+
+     q_network = QNetwork(envs).to(device)
+     optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
+     target_network = QNetwork(envs).to(device)
+     target_network.load_state_dict(q_network.state_dict())
+     target_network.eval()
+
+     rb = ReplayBuffer(
+         args.buffer_size,
+         envs.single_observation_space,
+         envs.single_action_space,
+         device=device,
+         n_envs=envs.num_envs,
+     )
+
+     start_time = time.time()
+
+     obs, _ = envs.reset(seed=args.seed)
+     step_pbar = tqdm(total=args.total_timesteps)
+     postfix_dict = dict()
+     for step in range(0, args.total_timesteps, envs.num_envs):
+         epsilon = linear_schedule(
+             step,
+             args.initial_exploration,
+             args.final_exploration,
+             int(args.total_timesteps * args.exploration_fraction),
+         )
+         writer.add_scalar("charts/epsilon", epsilon, step)
+         postfix_dict.update(e=epsilon)
+
+         with torch.no_grad():
+             q_values = q_network(torch.from_numpy(obs).to(device))
+             greedy_actions = torch.argmax(q_values, dim=1).cpu().numpy()
+
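+         # Vectorized epsilon-greedy: each env independently takes a random action
+         # with probability epsilon, otherwise the greedy (argmax-Q) action.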
+         randn_actions = envs.action_space.sample()
+         mask_actions = np.random.rand(envs.num_envs) < epsilon
+         actions = np.where(mask_actions, randn_actions, greedy_actions)
+
+         next_obs, rewards, terminations, truncations, infos = envs.step(actions)
+
+         if "episode" in infos:
+             mask = infos["_episode"]
+             r_mean = infos["episode"]["r"][mask].mean()
+             l_mean = infos["episode"]["l"][mask].mean()
+
+             writer.add_scalar("charts/episodic_return", r_mean, step)
+             writer.add_scalar("charts/episodic_length", l_mean, step)
+             postfix_dict.update(
+                 r=r_mean,
+                 l=l_mean,
+             )
+
+         rb.add(obs, next_obs, actions, rewards, terminations)
+
+         dones = np.logical_or(terminations, truncations)
+         if dones.any():
+             obs, _ = envs.reset(options={"reset_mask": dones})
+         else:
+             obs = next_obs
+
+         for parallel_step in range(
+             step, min(step + envs.num_envs, args.total_timesteps)
+         ):
+             if parallel_step > args.learning_start:
+                 if parallel_step % args.train_frequency == 0:
+                     data = rb.sample(args.batch_size)
+                     with torch.no_grad():
+                         target_max, _ = target_network(data["next_obs"]).max(dim=1)
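+                         # Bellman target: y = r + gamma * max_a' Q_target(s', a') * (1 - done)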
+                         td_target = data[
+                             "rewards"
+                         ].flatten() + args.gamma * target_max * (
+                             1 - data["dones"].flatten()
+                         )
+                     q_val = q_network(data["obs"]).gather(1, data["actions"]).squeeze()
+                     loss = F.smooth_l1_loss(q_val, td_target)
+
+                     if step % args.log_interval < envs.num_envs:
+                         writer.add_scalar("losses/td_loss", loss, parallel_step)
+                         writer.add_scalar(
+                             "losses/q_values", q_val.mean().item(), parallel_step
+                         )
+                         writer.add_scalar(
+                             "charts/SPS",
+                             step // (time.time() - start_time),
+                             parallel_step,
+                         )
+                         postfix_dict.update(
+                             td_loss=loss.item(),
+                             q_val_mean=q_val.mean().item(),
+                             sps=step // (time.time() - start_time),
+                         )
+
+                     optimizer.zero_grad()
+                     loss.backward()
+                     optimizer.step()
+
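+                 # With tau=1.0 (the default here), lerp_ copies the online weights
+                 # outright, i.e. a hard target-network update.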
+                 if parallel_step % args.target_network_update_frequency == 0:
+                     for target_network_param, q_network_param in zip(
+                         target_network.parameters(), q_network.parameters()
+                     ):
+                         target_network_param.data.lerp_(q_network_param.data, args.tau)
+
+                 if parallel_step % (args.total_timesteps // args.save_times) == 0:
+                     save_model(
+                         model=q_network,
+                         filename=f"runs/{run_name}/{args.exp_name}_{step}.safetensors",
+                     )
+         step_pbar.set_postfix(postfix_dict)
+         step_pbar.update(envs.num_envs)
+     envs.close()
+     step_pbar.close()
+
+     if args.evaluate:
+         run_name_eval = f"{run_name}_eval"
+         final_model_path = f"runs/{run_name}/{args.exp_name}_final.safetensors"
+         save_model(model=q_network, filename=final_model_path)
+         from evals.dqn_eval import evaluate
+
+         episode_rewards = evaluate(
+             final_model_path=final_model_path,
+             make_env=make_env,
+             env_id=args.env_id,
+             run_name_eval=run_name_eval,
+             QNetwork=QNetwork,
+             device=device,
+             eval_episodes=args.eval_episodes,
+             epsilon=args.final_exploration,
+         )
+         for i, r in enumerate(episode_rewards):
+             writer.add_scalar("eval/episodic_return_eval", r, i)
+
+         if args.push_model:
+             from utils import push_model
+
+             push_model(
+                 args=args,
+                 episode_rewards=episode_rewards,
+                 algo_name="DQN",
+                 run_path=f"runs/{run_name}",
+                 video_folder_path=f"videos/{run_name_eval}",
+             )
+
+     writer.close()
+
+
+ if __name__ == "__main__":
+     main()
dqn_atari_1000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7dd068380a83b737c2f6ba15ce7b54742a97e328733448bb6052bcc76d1abac
+ size 6749632
dqn_atari_2000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c4ed5df14bf26a840e1090978e037112bbb02a0f56d3ef229d1c8885082abc1e
+ size 6749632
dqn_atari_3000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f51ada39ff7512a9180606dc367d7e5fe7a7c02692f645e3f5f91d1e937b3e4
+ size 6749632
dqn_atari_4000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bfc97d59160b805d1dd56a0372d8634dbcb7fb17b7ba1ba7239c533a5f7a001
+ size 6749632
dqn_atari_5000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:852b6388c16dec1cece86432c9cb1b159263b6a7d33ee84a212de7685cf3148c
+ size 6749632
dqn_atari_6000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33978695e05d872451f0262f3a3381da1e36fa721de54e98bd3850740612e936
+ size 6749632
dqn_atari_7000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6175016bb27fa7fd4b65268a1e9dd3791590ef11525f8805c01a608ba1e1a389
+ size 6749632
dqn_atari_8000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e74417b7e7fc1025fb6ab60bafda4b26eacb59f5207893f9930f07a8488f1977
+ size 6749632
dqn_atari_9000000.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a416598792b7b6968310f07ad932efdb8e12ba5eee63ce13e0df5f40d06703f
+ size 6749632
dqn_atari_final.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6caadb9e62f0806a480f84d4c90ec8f3e7b96dd1f1fb51b7c75223cb936222a2
+ size 6749632
events.out.tfevents.1755876604.dc81b9906985.2178.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b4ac76ec60c8469e0d0183cefc32741464287ced7fec20e2daca8bbd816b6f8
+ size 101547150
hyperparameters.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "env_id": "SpaceInvadersNoFrameskip-v4",
+     "exp_name": "dqn_atari",
+     "n_envs": 8,
+     "seed": 0,
+     "total_timesteps": 10000000,
+     "buffer_size": 1000000,
+     "video_capture_frequency": 5,
+     "initial_exploration": 1,
+     "final_exploration": 0.01,
+     "exploration_fraction": 0.1,
+     "learning_start": 80000,
+     "train_frequency": 4,
+     "batch_size": 32,
+     "gamma": 0.99,
+     "learning_rate": 0.0001,
+     "target_network_update_frequency": 1000,
+     "tau": 1.0,
+     "log_interval": 100,
+     "save_times": 10,
+     "evaluate": true,
+     "eval_episodes": 10,
+     "push_model": true,
+     "hf_entity": "alperenunlu"
+ }
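Because dqn_atari.py reads `HyperParams` through `tyro.cli`, each of these values should be overridable from the command line, presumably with tyro's default dashed flag names, e.g. `uv run dqn_atari.py --seed 1 --total-timesteps 1000000`; `uv run dqn_atari.py --help` prints the full list.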
pyproject.toml ADDED
@@ -0,0 +1,16 @@
+ [project]
+ name = "rlhell"
+ version = "0.1.0"
+ description = "RL Implementations Without Dependency Hell Cutting Edge Versions and Vectorized Training"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "gymnasium[all]>=1.2.0",
+     "huggingface-hub>=0.34.4",
+     "safetensors>=0.6.2",
+     "swig>=4.3.1",
+     "tensorboard>=2.20.0",
+     "torch>=2.8.0",
+     "tqdm>=4.67.1",
+     "tyro>=0.9.28",
+ ]
replay.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0aa46fb3860c1865495e256b64def56e973ca1029ae0aa5b29d0eb1b3cc9fcf4
+ size 1390433
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-0.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0aa46fb3860c1865495e256b64def56e973ca1029ae0aa5b29d0eb1b3cc9fcf4
+ size 1390433
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-1.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81058e6ce5baa8ba1699e95f7f88c12356292f47f8744e5f248ccd425712ab56
+ size 622516
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-2.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1dd804e02102171a0ebf380b45108fee1fc2c631d2e97cf7e58c7a0ab6420059
+ size 670602
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-3.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b3332e807ee00821010ebfca57343a8bba4f04f85449d1f0bd672212b130f5d
+ size 1259808
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-4.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b5accd7fb013d92d5ec6ffb0d200693cc316f21f54a1ca7bae593087303f6baa
+ size 1182288
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-5.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7822ac3dc85aff825689f7022cd6c64ebe64101a4a7e7f5be702c3ebf68731a
+ size 728035
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-6.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a2c826830c9e0bea8736b213fb57792676fe16191580d50727f65e40ec1871f
+ size 1118185
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-7.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0660e438dfacf48485940b7388833d7261cfe9b261bf33333aa1db39f8bfae08
+ size 805407
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-8.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1c041e5c76f0bc0c8ed6fd433557af20613544e4455e8ae9e5c1ef9a00675b6
+ size 593165
videos/SpaceInvadersNoFrameskip-v4_dqn_atari_0_250822_153004_eval/SpaceInvadersNoFrameskip-v4-episode-9.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99016ba1754bbf677baf13297f08aa533f18072271769d5aa8112da36e7d78e1
+ size 903444