drbh committed
Commit e612007 · 1 parent: 40f2269

feat: add build

Files changed (27)
  1. build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/__init__.py +133 -0
  2. build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  3. build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/_ops.py +9 -0
  4. build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/__init__.py +133 -0
  5. build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  6. build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/_ops.py +9 -0
  7. build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/__init__.py +133 -0
  8. build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  9. build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/_ops.py +9 -0
  10. build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/__init__.py +133 -0
  11. build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  12. build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/_ops.py +9 -0
  13. build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/__init__.py +133 -0
  14. build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  15. build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/_ops.py +9 -0
  16. build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/__init__.py +133 -0
  17. build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  18. build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/_ops.py +9 -0
  19. build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/__init__.py +133 -0
  20. build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  21. build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/_ops.py +9 -0
  22. build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/__init__.py +133 -0
  23. build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  24. build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/_ops.py +9 -0
  25. build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/__init__.py +133 -0
  26. build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so +3 -0
  27. build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/_ops.py +9 -0
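The build matrix above covers torch 2.6/2.7, the cxx11 and cxx98 C++ ABIs, and CUDA 11.8/12.4/12.6/12.8, all for x86_64 Linux. The commit itself does not show how a variant directory is selected at import time; the snippet below is only a hedged illustration of how a loader might derive the directory name from the running interpreter, assuming the torchXX-cxxYY-cuZZZ-x86_64-linux naming visible in the file list and a CUDA build of torch. The helper name guess_variant_dir is hypothetical.

# Hypothetical illustration only -- not part of this commit. Assumes the
# directory naming convention seen above and a CUDA-enabled torch build.
import torch

def guess_variant_dir() -> str:
    # "2.6.0+cu124" -> "torch26"
    torch_tag = "torch" + "".join(torch.__version__.split("+")[0].split(".")[:2])
    # the repo appears to use "cxx98" for the pre-cxx11 ABI builds
    abi_tag = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    # torch.version.cuda is None on CPU-only builds; this sketch assumes CUDA
    cuda_tag = "cu" + torch.version.cuda.replace(".", "")
    return f"{torch_tag}-{abi_tag}-{cuda_tag}-x86_64-linux"

print(guess_variant_dir())  # e.g. "torch26-cxx11-cu118-x86_64-linux"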
build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
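For context, here is a minimal usage sketch of the AdamATan2 optimizer defined above. It is not part of the commit: it assumes the build directory is importable as the adam_atan2 package and that a CUDA device is available; the linear model and random batch are placeholders. Note that, unlike torch.optim.AdamW, the constructor exposes no eps argument, and step() takes no closure.

# Minimal usage sketch (assumes a CUDA build of the extension is importable
# as `adam_atan2`; the model and data below are placeholders, not from the commit).
import torch
from adam_atan2 import AdamATan2

model = torch.nn.Linear(16, 4).cuda()
opt = AdamATan2(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

x = torch.randn(8, 16, device="cuda")
loss = model(x).square().mean()
loss.backward()
opt.step()       # step() takes no closure argument in this implementation
opt.zero_grad()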
build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bf69bbdc95e26ae0b98ecd42e238a8fa67c503348ba062c8af18e681b758db3
+size 2900352
build/torch26-cxx11-cu118-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
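This _ops.py shim registers the compiled extension under a hash-suffixed torch.ops namespace so that several build variants can coexist. A short illustration of how the shim resolves names (the op name comes from the call in __init__.py above; running this requires the compiled _adam_atan2_40f2269 extension for the matching build variant):

# Illustration only: resolving names through the shim above.
from adam_atan2._ops import ops, add_op_namespace_prefix

print(add_op_namespace_prefix("adam_atan2_cuda_impl_"))
# -> "_adam_atan2_40f2269::adam_atan2_cuda_impl_"

# The fused kernel is then reachable as an attribute of the ops namespace,
# which is exactly how __init__.py invokes it:
kernel = ops.adam_atan2_cuda_impl_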
build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50887573fa0599bcc94b948faaa38d9c6e06f8a654c066d5f49460b86a109c1b
+size 2929048
build/torch26-cxx11-cu124-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eec0b7d37568f183dfbcb419f359f9b046fd893e075525083f635c2f936c89e0
+size 2933584
build/torch26-cxx11-cu126-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13a70e896d05f084a9a736b17604e859a6b384fb9dc93b5c28486a3a84d2bc93
+size 2897504
build/torch26-cxx98-cu118-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f11e153608ac8a325a312278994b01278dd47c41173ee06241eff26f69637a48
+size 2922152
build/torch26-cxx98-cu124-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:993bdb2c8dd5dc103bdb1c2d632e0f41ace2caf6665e3211b2954bc191eb5bf9
+size 2926688
build/torch26-cxx98-cu126-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47610d48450a101312d6e6335153e925a90efd357e1b252f3eec07b1459cd58f
+size 2900448
build/torch27-cxx11-cu118-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d202c842a41a288e7c56b9063ad1ab1962cde09a647072344f07c76687555f7
+size 2933616
build/torch27-cxx11-cu126-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/__init__.py ADDED
@@ -0,0 +1,133 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+from ._ops import ops
+
+from typing import List, Tuple, Union
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, ParamsT
+
+
+class AdamATan2(Optimizer):
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: Union[float, Tensor] = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_group(
+        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+    ):
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+
+            params_with_grad.append(p)
+            if p.grad.is_sparse:
+                raise RuntimeError("AdamW does not support sparse gradients")
+            grads.append(p.grad)
+
+            state = self.state[p]
+
+            # State initialization
+            if len(state) == 0:
+                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                # This is because kernel launches are costly on CUDA and XLA.
+                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
+                # Exponential moving average of gradient values
+                state["exp_avg"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+                # Exponential moving average of squared gradient values
+                state["exp_avg_sq"] = torch.zeros_like(
+                    p, memory_format=torch.preserve_format
+                )
+
+            exp_avgs.append(state["exp_avg"])
+            exp_avg_sqs.append(state["exp_avg_sq"])
+            state_steps.append(state["step"])
+
+    def step(self):
+        """Perform a single optimization step."""
+        self._cuda_graph_capture_health_check()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            state_steps = []
+            beta1, beta2 = group["betas"]
+
+            self._init_group(
+                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
+            )
+
+            _adam_atan2(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                state_steps,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+            )
+
+
+def _adam_atan2(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    state_steps: List[Tensor],
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+) -> None:
+    if not params:
+        return
+
+    # We only support scalar lr.
+    assert not isinstance(lr, Tensor)
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
+    )
+    for (device, _), (
+        (
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+        ),
+        _,
+    ) in grouped_tensors.items():
+        torch._foreach_add_(device_state_steps, 1)
+        ops.adam_atan2_cuda_impl_(
+            device_params,
+            device_grads,
+            device_exp_avgs,
+            device_exp_avg_sqs,
+            device_state_steps,
+            lr,
+            beta1,
+            beta2,
+            weight_decay,
+        )
build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/_adam_atan2_40f2269.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37cf6b90c834d55ed802734bb91a8a601f6eab6a88e0e8eed7bd4cb449c563fd
+size 3688960
build/torch27-cxx11-cu128-x86_64-linux/adam_atan2/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _adam_atan2_40f2269
+ops = torch.ops._adam_atan2_40f2269
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_adam_atan2_40f2269::{op_name}"