\documentclass{article}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{amsmath}
\usepackage{caption}
\usepackage{tgtermes}
\usepackage{float}
\usepackage[a4paper, margin=1in]{geometry}
\usepackage{booktabs}
\usepackage{algorithm}
\usepackage{algorithmicx}
\usepackage{algpseudocode}
\date{}
\begin{document}
{\LARGE \bfseries Parallelize Muon with FSDP2 \par}
\vspace{1em} % adjust spacing below the title
\section*{Motivation}
\begin{figure}[H]
\centering
\includegraphics[width=0.8\textwidth]{distributed_muon.png}
\caption*{Distributed Muon by Moonlight}
\end{figure}
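A distributed version of Muon (Distributed Muon, from Moonlight) is available, but it computes redundantly. Every GPU all-gathers the full gradient matrix and runs Newton-Schulz on it, even though each GPU keeps only its own partition of the result.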
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{distributed_muon_execution.png}
\caption*{Execution timeline of Distributed Muon}
\end{figure}
\begin{itemize}
\item \texttt{C[i]} : Newton-Schulz computation on the $i$-th gradient
\item \texttt{AG[i]} : AllGather of the $i$-th gradient
\item \texttt{G[i]} : Gather of the $i$-th gradient
\item \texttt{SC[i]} : Scatter of the $i$-th update
\end{itemize}
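Each \texttt{C[i]} box runs the Newton-Schulz iteration on one full gradient matrix. The pure-Python sketch below shows what that computation looks like. It assumes the quintic variant with coefficients $(a, b, c) = (3.4445, -4.7750, 2.0315)$, 5 iterations and Frobenius-norm pre-scaling, as used in common open-source Muon code. These choices are not read from \texttt{muon.py}.

\begin{verbatim}
# Pure-Python sketch of the quintic Newton-Schulz iteration used by Muon-style
# optimizers. Assumed (not read from muon.py): coefficients (a, b, c) below,
# 5 iterations, and Frobenius-norm pre-scaling.

A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    cols = list(zip(*y))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in x]


def transpose(x):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*x)]


def newton_schulz(g, steps=5, eps=1e-7):
    """Map g to an approximately orthogonal matrix with the same singular vectors."""
    norm = sum(v * v for row in g for v in row) ** 0.5
    # Frobenius scaling guarantees spectral norm <= 1, inside the convergence region.
    x = [[v / (norm + eps) for v in row] for row in g]
    transposed = len(x) > len(x[0])
    if transposed:
        # Work on the wide orientation so the Gram matrix X X^T is the smaller one.
        x = transpose(x)
    for _ in range(steps):
        gram = matmul(x, transpose(x))                      # A = X X^T
        gram_sq = matmul(gram, gram)                        # A @ A
        poly = [[B_COEF * p + C_COEF * q for p, q in zip(r1, r2)]
                for r1, r2 in zip(gram, gram_sq)]           # B = b*A + c*A@A
        px = matmul(poly, x)                                # B @ X
        x = [[A_COEF * p + q for p, q in zip(r1, r2)]
             for r1, r2 in zip(x, px)]                      # X = a*X + B @ X
    return transpose(x) if transposed else x
\end{verbatim}

The three GEMMs per iteration (\texttt{X X\^{}T}, \texttt{A @ A}, \texttt{B @ X}) are why Newton-Schulz dominates the optimizer step. With these coefficients the singular values settle in a band around 1 rather than converging to exactly 1.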
\clearpage
\section*{Algorithm}
\subsection*{Parallel Muon}
\begin{algorithm}
\caption{Parallel Muon}
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum buffer $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
\begin{algorithmic}[1]
\State \texttt{// Apply momentum to $\mathbf{g}$ using the local momentum partition $\mathbf{m}$}
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
\State \texttt{// Choose the rank $\mathbf{R}$ that computes the full update for $\mathbf{g'}$}
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
\State \texttt{// Gather $\mathbf{g'}$ across the DP group into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
\State \texttt{// Run Newton-Schulz only on rank $\mathbf{R}$}
\If{$\mathbf{r} = \mathbf{R}$}
\State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$
\Else
\State $\mathbf{u} \gets \text{None}$
\EndIf
\State \texttt{// Scatter the full matrix $\mathbf{u}$ from rank $\mathbf{R}$ across the DP group}
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$
\State \texttt{// Apply the DP-partitioned update $\mathbf{u'}$ to $\mathbf{p}$}
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
\State \Return $\mathbf{p'}$
\end{algorithmic}
\end{algorithm}
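The steps of Algorithm 1 can be traced with a single-process simulation in which ranks are list indices. This is a hypothetical sketch, not \texttt{muon.py}. It rests on several assumptions. Every matrix is sharded by contiguous rows via \texttt{split\_rows}. \texttt{update\_with\_momentum} is plain heavy-ball ($\mathbf{m} \gets \mu\mathbf{m} + \mathbf{g}$, $\mathbf{g'} = \mathbf{m}$). \texttt{schedule} picks owners round-robin. \texttt{apply\_update} is $\mathbf{p} - \text{lr}\cdot\mathbf{u'}$, with \texttt{lr} an extra argument the algorithm does not show. The counter \texttt{ns\_calls} makes the saving visible. It ends up equal to the number of parameters, whereas a scheme in which every rank runs Newton-Schulz would multiply it by the world size.

\begin{verbatim}
# Single-process simulation of one Parallel Muon step (Algorithm 1).
# Hypothetical simplifications, not taken from muon.py:
#   * ranks are list indices, and every matrix is sharded by contiguous rows;
#   * update_with_momentum is plain heavy-ball: m <- mu * m + g, g' = m;
#   * schedule() picks owners round-robin over sorted parameter names;
#   * apply_update is p <- p - lr * u (lr is an assumed extra argument).


def split_rows(matrix, world_size):
    """Shard a matrix (list of rows) into world_size contiguous row blocks."""
    per_rank = -(-len(matrix) // world_size)  # ceiling division
    return [matrix[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def parallel_muon_step(params, grads, moms, mu, lr, world_size, orthogonalize):
    """params, grads, moms map a name to its list of per-rank row shards.

    Returns the updated parameter shards and how many times orthogonalize
    (Newton-Schulz) ran across all ranks. moms is updated in place.
    """
    ns_calls = 0
    new_params = {}
    for idx, name in enumerate(sorted(params)):
        # Momentum: each rank touches only its local shard.
        g_prime = []
        for r in range(world_size):
            m = [[mu * mv + gv for mv, gv in zip(m_row, g_row)]
                 for m_row, g_row in zip(moms[name][r], grads[name][r])]
            moms[name][r] = m
            g_prime.append(m)
        owner = idx % world_size                              # schedule(): rank R
        full = [row for shard in g_prime for row in shard]    # gather to R
        u = None
        for r in range(world_size):
            if r == owner:                                    # only R computes
                u = orthogonalize(full)
                ns_calls += 1
        u_shards = split_rows(u, world_size)                  # scatter from R
        new_params[name] = [                                  # local update
            [[pv - lr * uv for pv, uv in zip(p_row, u_row)]
             for p_row, u_row in zip(params[name][r], u_shards[r])]
            for r in range(world_size)
        ]
    return new_params, ns_calls
\end{verbatim}

Passing the Newton-Schulz function as \texttt{orthogonalize} keeps the communication pattern separate from the math. Passing an identity function reduces the step to momentum SGD, which is convenient for checking the gather and scatter bookkeeping.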
We eliminate this redundant computation by assigning each parameter to a single GPU, which alone runs Newton-Schulz on the full gradient.
Without proper scheduling, however, GPU utilization is poor. While the owning rank computes, every other rank sits idle, and all ranks must wait for the scatter to complete before proceeding.
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{naive_execution.png}
\caption*{Execution timeline of Parallel Muon}
\end{figure}
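The idle time can be quantified with a simplified model. This is our own assumption, not a measurement. Suppose $P$ parameters are processed one after another as in the timeline above, on $N$ ranks. Owners are assigned round-robin, with $P$ divisible by $N$. Per-parameter times are uniform: $t_g$ for the gather, $t_c$ for Newton-Schulz and $t_s$ for the scatter. Each rank then computes for $P\,t_c / N$ out of a total of $P\,(t_g + t_c + t_s)$, so its utilization is
\[
\frac{P\,t_c / N}{P\,(t_g + t_c + t_s)} = \frac{t_c}{N\,(t_g + t_c + t_s)} \le \frac{1}{N}.
\]
Even with free communication ($t_g = t_s = 0$), each of eight ranks would be busy at most $1/8$ of the time.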
\subsection*{Scheduling Sub-Operations}
The sub-operations of different parameters can be reordered and overlapped as shown below, because:
\begin{itemize}
\item There are no dependencies between parameters.
\item GPUs can execute computation and communication concurrently.
\end{itemize}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{pipelined.png}
\caption*{Execution timeline of re-scheduled Parallel Muon}
\end{figure}
We set the chunk size $C$ to the number of GPUs and issue each sub-operation in batches of $C$ parameters, so each GPU owns one parameter per chunk. With this schedule, each GPU can keep computing while collective communication for other chunks is still in flight.
\textbf{[Algorithm]} (To be written)
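The benefit of overlapping can be estimated with a small two-stream timing model. This is a simplified sketch under our own assumptions, not the scheduler in \texttt{muon.py}. Each chunk needs one gather (\texttt{t\_gather}) and one scatter (\texttt{t\_scatter}) on a single communication stream. A chunk's compute time is that of its slowest rank. The communication stream issues operations in the order \texttt{G[0]}, \texttt{G[1]}, \texttt{SC[0]}, \texttt{G[2]}, \texttt{SC[1]}, and so on, so the gather for chunk $k+1$ overlaps the computation of chunk $k$.

\begin{verbatim}
# Two-stream timing model for chunked Parallel Muon (illustrative assumptions):
#   * one communication stream and one compute stream per GPU;
#   * every chunk needs a gather (t_gather) and a scatter (t_scatter);
#   * chunk_compute[k] is the slowest rank's Newton-Schulz time in chunk k;
#   * comm-stream issue order: G0, G1, SC0, G2, SC1, ..., SC(n-1).


def sequential_time(chunk_compute, t_gather, t_scatter):
    """No overlap: gather, compute, and scatter each chunk back to back."""
    return sum(t_gather + c + t_scatter for c in chunk_compute)


def pipelined_time(chunk_compute, t_gather, t_scatter):
    """Overlap the next chunk's gather with the current chunk's compute."""
    n = len(chunk_compute)
    if n == 0:
        return 0.0
    gather_done = [0.0] * n
    comm = gather_done[0] = t_gather  # comm stream: G0 goes first
    comp = 0.0                        # time at which the compute stream is free
    for k in range(n):
        # C[k] waits for its gather (an event) and for the compute stream.
        comp = max(comp, gather_done[k]) + chunk_compute[k]
        if k + 1 < n:
            # Prefetch G[k+1] on the comm stream; it does not wait for C[k].
            comm += t_gather
            gather_done[k + 1] = comm
        # SC[k] waits for C[k] and for the comm stream.
        comm = max(comm, comp) + t_scatter
    return comm
\end{verbatim}

Consider three chunks with compute time 2 and unit gather and scatter times. The model gives 12 time units without overlap and 8 when pipelined, because the compute stream then runs without gaps between the first gather and the last scatter.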
\clearpage
\subsection*{Load Balancing}
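If the parameters within a chunk have imbalanced computation loads, ranks that finish early sit idle until the slowest rank in the chunk is done, creating bubbles.
To mitigate this, we balance the load across ranks using the estimated FLOPs of each parameter.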
\vspace{1em}
\textbf{Imbalanced (Round Robin)}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{imbalance.png}
\end{figure}
\textbf{After Load Balancing}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{balanced.png}
\end{figure}
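The sketch below illustrates one plausible FLOP-based strategy. It is hypothetical and is not the rule used in \texttt{muon.py}. Newton-Schulz cost is estimated per parameter, parameters are sorted by cost, and consecutive parameters fill each chunk, so ranks in the same chunk receive similar work. The FLOP estimate assumes the quintic iteration on the wide orientation of an $m \times n$ matrix ($m \le n$). It counts $4m^2n + 2m^3$ FLOPs per iteration and ignores element-wise work. The helper names \texttt{ns\_flops}, \texttt{make\_chunks} and \texttt{bubble} are illustrative.

\begin{verbatim}
# Hypothetical FLOP-based load balancing for chunked Parallel Muon.
# Assumptions (not read from muon.py): quintic Newton-Schulz with `steps`
# iterations, 2*m*k*n FLOPs per (m x k) @ (k x n) GEMM, element-wise work
# ignored, and one parameter per rank in every chunk.


def ns_flops(rows, cols, steps=5):
    """Approximate Newton-Schulz FLOPs for one rows x cols gradient."""
    m, n = min(rows, cols), max(rows, cols)
    gram = 2 * m * m * n      # X X^T on the wide (m x n) orientation
    square = 2 * m ** 3       # A @ A
    update = 2 * m * m * n    # B @ X
    return steps * (gram + square + update)


def make_chunks(costs, chunk_size, balance=True):
    """Group parameter indices into chunks; position within a chunk = rank."""
    order = list(range(len(costs)))
    if balance:
        # After sorting, neighbours have similar cost, so chunks are uniform.
        order.sort(key=lambda i: costs[i], reverse=True)
    return [order[i:i + chunk_size] for i in range(0, len(order), chunk_size)]


def bubble(costs, chunks, chunk_size):
    """Total idle rank-time: each chunk lasts as long as its costliest member."""
    total = 0
    for chunk in chunks:
        longest = max(costs[i] for i in chunk)
        total += longest * chunk_size - sum(costs[i] for i in chunk)
    return total
\end{verbatim}

Take costs \texttt{[8, 1, 8, 1]} on two ranks. The round-robin chunks \texttt{[[0, 1], [2, 3]]} leave 14 units of idle rank-time, while the sorted chunks \texttt{[[0, 2], [1, 3]]} leave none. In practice the costs would come from \texttt{ns\_flops} applied to each parameter's shape.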
\section*{Implementation}
The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
To overlap computation and communication, we run them on separate streams (\texttt{torch.cuda.Stream}) and synchronize dependent sub-operations with \texttt{torch.cuda.Event}.
Thanks to the simplicity of \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation stays short and straightforward.
\section*{Evaluation}
We evaluated performance on a 10B-parameter model currently in development and measured 151 TFLOPS per GPU during the optimizer step.
\begin{table}[H]
\centering
\begin{tabular}{@{}lllll@{}}
\toprule
Model size & Muon work (TFLOP) & GPUs & Step time & TFLOPS/GPU \\
\midrule
10B & 847.45 & 4$\times$MI250 (8 devices) & 1.4 s & 151 \\
\bottomrule
\end{tabular}
\end{table}
In the time breakdown, 7\% of the step is spent updating sharded gradients and parameters, 78\% on GEMM operations, and the remaining 15\% on communication that is not overlapped with computation.
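The reported figures can be checked against each other with simple arithmetic. The per-GPU throughput matches counting each MI250 as one GPU (4 GPUs rather than 8 devices):
\[
\frac{847.45\ \text{TFLOP}}{1.4\ \text{s} \times 4\ \text{GPUs}} \approx 151.3\ \text{TFLOPS/GPU}.
\]
The 1.4 s step then splits into roughly $0.78 \times 1.4 \approx 1.09$ s of GEMMs, $0.15 \times 1.4 = 0.21$ s of non-overlapped communication, and $0.07 \times 1.4 \approx 0.10$ s of sharded gradient and parameter updates.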
\end{document}