\documentclass{article}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{amsmath}
\usepackage{caption}
\usepackage{tgtermes}
\usepackage{float}
\usepackage[a4paper, margin=1in]{geometry}
\usepackage{booktabs}
\usepackage{algorithm}
\usepackage{algorithmicx}
\usepackage{algpseudocode}
\date{}

\begin{document}

{\LARGE \bfseries Parallelize Muon with FSDP2 \par}
\vspace{1em} % adjust spacing below the title

\section*{Motivation}

\begin{figure}[H]
    \centering
    \includegraphics[width=0.8\textwidth]{distributed_muon.png}
    \caption*{Distributed Muon by Moonlight}
\end{figure}

While a distributed version of Muon is available, it has the drawback of redundant computations across GPUs.

\begin{figure}[H]
    \centering
    \includegraphics[width=1.0\textwidth]{distributed_muon_execution.png}
    \caption*{Execution timeline of Distributed Muon}
\end{figure}

\begin{itemize}
    \item \texttt{C[i]}: compute Newton-Schulz on the $i$-th gradient (see the sketch after this list)
    \item \texttt{AG[i]}: all-gather the $i$-th gradient
    \item \texttt{G[i]}: gather the $i$-th gradient
    \item \texttt{SC[i]}: scatter the $i$-th gradient
\end{itemize}
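
For reference, the Newton-Schulz computation that each \texttt{C[i]} performs is typically the quintic iteration below. This is a minimal PyTorch sketch following the public Muon reference implementation; the coefficients and step count are assumptions rather than values taken from \texttt{muon.py}.

\begin{verbatim}
import torch

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration."""
    # Coefficients from the public Muon reference implementation (assumed here).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.bfloat16)
    X = X / (X.norm() + eps)        # scale so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:                  # iterate on the wide orientation
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
\end{verbatim}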
\clearpage
\section*{Algorithm}

\subsection*{Parallel Muon}

\begin{algorithm}
\caption{Parallel Muon}
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum buffer $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
\begin{algorithmic}[1]
\State \texttt{// Apply momentum to $\mathbf{g}$ using the local partitioned momentum $\mathbf{m}$}
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
\State \texttt{// Select the rank $\mathbf{R}$ that will orthogonalize $\mathbf{g'}$}
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
\State \texttt{// Run Newton-Schulz only on rank $\mathbf{R}$}
\If{$\mathbf{r} = \mathbf{R}$}
  \State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$
\Else
  \State $\mathbf{u} \gets \text{None}$
\EndIf

\State \texttt{// Scatter the full matrix $\mathbf{u}$ across DP}
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$
\State \texttt{// Apply the DP-partitioned update $\mathbf{u'}$ to $\mathbf{p}$}
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
\State \textbf{return $\mathbf{p'}$}
\end{algorithmic}
\end{algorithm}

We eliminate redundant computation by assigning each parameter to a specific GPU.
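
A minimal sketch of this per-parameter update using plain \texttt{torch.distributed} collectives is shown below. It assumes each shard is the parameter's slice along dimension 0 and that \texttt{owner\_rank} is a global rank; the momentum variant and the \texttt{newton\_schulz} routine from the Motivation section stand in for the corresponding steps of the pseudocode and are not the exact code in \texttt{muon.py}, which operates on \texttt{DTensor} shards.

\begin{verbatim}
import torch
import torch.distributed as dist

def parallel_muon_step(g_shard, m_shard, p_shard, mu, lr, owner_rank, dp_group):
    """One Parallel Muon update for a single DP-sharded parameter (sketch)."""
    rank = dist.get_rank()
    world = dist.get_world_size(dp_group)

    # 1. Apply momentum locally on the shard (variant assumed).
    m_shard.mul_(mu).add_(g_shard)
    g_prime = m_shard

    # 2. Gather the full gradient matrix onto the owner rank only.
    shards = [torch.empty_like(g_prime) for _ in range(world)] \
        if rank == owner_rank else None
    dist.gather(g_prime, gather_list=shards, dst=owner_rank, group=dp_group)

    # 3. Orthogonalize on the owner rank; other ranks skip the computation.
    if rank == owner_rank:
        G = torch.cat(shards, dim=0)   # full matrix (sharding layout assumed)
        U = newton_schulz(G)
        out = list(U.chunk(world, dim=0))
    else:
        out = None

    # 4. Scatter the orthogonalized update back to every rank.
    u_shard = torch.empty_like(g_prime)
    dist.scatter(u_shard, scatter_list=out, src=owner_rank, group=dp_group)

    # 5. Apply the sharded update to the sharded parameter.
    p_shard.add_(u_shard, alpha=-lr)
    return p_shard
\end{verbatim}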

However, without proper scheduling, this optimization can lead to poor GPU utilization. Although assigning each parameter to a specific rank avoids redundant computation, it introduces idle time: every other rank must wait for the scatter communication to complete before proceeding.

\begin{figure}[H]
    \centering
    \includegraphics[width=1.0\textwidth]{naive_execution.png}
    \caption*{Execution timeline of Parallel Muon}
\end{figure}

\subsection*{Scheduling Sub-Operations}

We can reorder the sub-operations across parameters, as shown below, for two reasons:
\begin{itemize}
    \item There are no dependencies between parameters.
    \item GPUs can execute computation and communication concurrently.
\end{itemize}

\begin{figure}[H]
    \centering
    \includegraphics[width=1.0\textwidth]{pipelined.png}
    \caption*{Execution timeline of re-scheduled Parallel Muon}
\end{figure}

We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This scheduling allows each GPU to continue computation even while waiting for collective communication to complete.

\textbf{[Algorithm]} (To be written)
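
As a rough illustration of the chunked schedule (a sketch under the assumptions of the previous listing, not the final algorithm), each chunk of $C$ parameters is gathered asynchronously, each rank runs Newton-Schulz on the single parameter it owns within the chunk, and the updates are scattered back. The helpers \texttt{start\_gather} and \texttt{scatter\_and\_apply} are hypothetical wrappers around the collectives from the previous sketch.

\begin{verbatim}
import torch.distributed as dist

def run_chunked_muon(items, dp_group, lr):
    """Process parameters in chunks of C = world size (sketch).

    Each item is assumed to carry its momentum-updated gradient shard; the
    chunk index of a parameter doubles as its owner rank, assuming dp_group
    spans all ranks."""
    world = dist.get_world_size(dp_group)
    rank = dist.get_rank()

    for start in range(0, len(items), world):
        chunk = items[start:start + world]

        # 1. Launch the gathers for the whole chunk with async_op=True so
        #    collectives for later parameters overlap with computation.
        handles = [start_gather(item, dst=owner, group=dp_group)
                   for owner, item in enumerate(chunk)]

        # 2. As each gather completes, the owning rank runs Newton-Schulz on
        #    its parameter while the remaining gathers are still in flight.
        for owner, (item, handle) in enumerate(zip(chunk, handles)):
            handle.wait()
            if rank == owner:
                item.update = newton_schulz(item.full_grad)

        # 3. Scatter the computed updates and apply them to the local shards.
        for owner, item in enumerate(chunk):
            scatter_and_apply(item, src=owner, group=dp_group, lr=lr)
\end{verbatim}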
\clearpage
\subsection*{Load Balancing}

If the parameters in a chunk have imbalanced computation loads, idle bubbles may occur. To mitigate this, we apply load balancing based on per-parameter FLOPs (see the sketch after the figures below).

\vspace{1em}
\textbf{Imbalanced (Round Robin)}

\begin{figure}[H]
    \centering
    \includegraphics[width=1.0\textwidth]{imbalance.png}
\end{figure}

\textbf{After Load Balancing}

\begin{figure}[H]
    \centering
    \includegraphics[width=1.0\textwidth]{balanced.png}
\end{figure}
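
A minimal sketch of the FLOPs-based assignment, assuming a greedy heuristic that estimates the Newton-Schulz cost of each parameter from its shape and always places the next-heaviest parameter on the least-loaded rank; the exact policy used in \texttt{muon.py} may differ.

\begin{verbatim}
def estimate_ns_flops(shape, steps=5):
    """Rough Newton-Schulz cost for an (m, n) matrix: each of the `steps`
    iterations performs three matmuls of roughly 2*m*n*min(m, n) FLOPs
    (cost model assumed for illustration)."""
    m, n = shape
    return steps * 3 * 2.0 * m * n * min(m, n)

def assign_parameters(param_shapes, world_size):
    """Greedy load balancing: place the heaviest remaining parameter on the
    currently least-loaded rank."""
    loads = [0.0] * world_size
    assignment = {}
    order = sorted(range(len(param_shapes)),
                   key=lambda i: estimate_ns_flops(param_shapes[i]),
                   reverse=True)
    for i in order:
        rank = min(range(world_size), key=loads.__getitem__)
        assignment[i] = rank
        loads[rank] += estimate_ns_flops(param_shapes[i])
    return assignment
\end{verbatim}

A mapping like this could back the \texttt{schedule} call in the Parallel Muon pseudocode.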

\section*{Implementation}

The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and synchronize sub-operations with \texttt{torch.cuda.Event}.

Thanks to the simplicity of \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.
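
The stream/event pattern looks roughly like the following sketch; \texttt{gather\_fn}, \texttt{compute\_fn}, and \texttt{scatter\_fn} are hypothetical callables standing in for the sub-operations of a single parameter.

\begin{verbatim}
import torch

comm_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

def overlapped_update(gather_fn, compute_fn, scatter_fn):
    """Issue communication on one stream and Newton-Schulz on another,
    ordering them with events instead of blocking the whole device."""
    gather_done = torch.cuda.Event()
    compute_done = torch.cuda.Event()

    with torch.cuda.stream(comm_stream):
        gathered = gather_fn()           # collective issued on the comm stream
        gather_done.record(comm_stream)

    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(gather_done)   # start NS after the gather
        update = compute_fn(gathered)
        compute_done.record(compute_stream)

    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(compute_done)     # scatter only after NS
        scatter_fn(update)
\end{verbatim}

Because the ordering is expressed through stream events rather than host-side synchronization, the CPU can keep enqueuing the next parameter's sub-operations while earlier ones are still running.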

\section*{Evaluation}
We evaluated performance on a 10B-parameter model currently in development, achieving 151 TFLOPS per GPU during the optimizer step.

\begin{table}[H]
    \centering
    \begin{tabular}{@{}lllll@{}}
        \toprule
        Model Size & TFLOPs for Muon & GPUs & Elapsed Time & TFLOPS/GPU \\
        \midrule
        10B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\
        \bottomrule
    \end{tabular}
\end{table}
Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead.

\end{document}