\documentclass{article}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{amsmath}
\usepackage{caption}
\usepackage{tgtermes}
\usepackage{float}
\usepackage[a4paper, margin=1in]{geometry}
\usepackage{booktabs}
\usepackage{algorithm}
\usepackage{algorithmicx}
\usepackage{algpseudocode}
\date{}
\begin{document}
{\LARGE \bfseries Parallelize Muon with FSDP2 \par}
\vspace{1em} % adjust spacing below the title
\section*{Motivation}
\begin{figure}[H]
\centering
\includegraphics[width=0.8\textwidth]{distributed_muon.png}
\caption*{Distributed Muon by Moonlight}
\end{figure}
While a distributed version of Muon is available, it has the drawback of redundant computations across GPUs.
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{distributed_muon_execution.png}
\caption*{Execution timeline of Distributed Muon}
\end{figure}
\begin{itemize}
\item \texttt{C[i]} : compute Newton-Schulz($G$) for the $i$-th gradient
\item \texttt{AG[i]} : all-gather the $i$-th gradient
\item \texttt{G[i]} : gather the $i$-th gradient
\item \texttt{SC[i]} : scatter the $i$-th gradient
\end{itemize}
\clearpage
\section*{Algorithm}
\subsection*{Parallel Muon}
\begin{algorithm}
\caption{Parallel Muon}
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
\begin{algorithmic}[1]
\State \texttt{// Apply momentum to $\mathbf{g}$ using local partitioned momentum $\mathbf{m}$}
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
\State \texttt{// Schedule $\mathbf{g'}$ to rank $\mathbf{R}$}
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
\State \texttt{// Gather $\mathbf{g'}$ across the DP group into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
\State \texttt{// Run Newton-Schulz only on rank $\mathbf{R}$}
\If{$\mathbf{r}$ == $\mathbf{R}$}
\State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$
\Else
\State $\mathbf{u} \gets \text{None}$
\EndIf
\State \texttt{// Scatter the full matrix $\mathbf{u}$ across the DP group}
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$
\State \texttt{// Apply DP partitioned $\mathbf{u'}$ to $\mathbf{p}$}
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
\State \textbf{return $\mathbf{p'}$}
\end{algorithmic}
\end{algorithm}
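For concreteness, the following is a minimal PyTorch sketch of a single Parallel Muon update for one DP-sharded parameter. The Newton--Schulz iteration follows the standard quintic formulation used by Muon; the row-wise (dim-0) sharding, equal shard sizes, the use of the default process group for rank numbering, and the helper structure are simplifying assumptions for illustration and do not mirror the exact code in \texttt{optimizer/torch-ext/optimizer/muon.py}.
\begin{verbatim}
import torch
import torch.distributed as dist

def newton_schulz(G, steps=5, eps=1e-7):
    # Standard quintic Newton-Schulz iteration used by Muon to
    # approximately orthogonalize the full gradient matrix G.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X = X / (X.norm() + eps)
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X.to(G.dtype)

def parallel_muon_step(p, g, m, mu, lr, owner_rank, dp_group=None):
    # p, g, m: this rank's row shards (dim 0) of the parameter,
    # gradient and momentum buffer; owner_rank: rank assigned by the
    # scheduler to run Newton-Schulz for this parameter.
    rank = dist.get_rank(dp_group)
    world = dist.get_world_size(dp_group)

    # 1) Apply momentum locally on the shard (Nesterov-style, as in Muon).
    m.mul_(mu).add_(g)
    g2 = g.add(m, alpha=mu)

    # 2) Gather the full gradient matrix onto the owner rank only.
    shards = ([torch.empty_like(g2) for _ in range(world)]
              if rank == owner_rank else None)
    dist.gather(g2, shards, dst=owner_rank, group=dp_group)

    # 3) Newton-Schulz runs on the owner rank alone.
    if rank == owner_rank:
        u_full = newton_schulz(torch.cat(shards, dim=0))
        u_shards = [t.contiguous() for t in u_full.chunk(world, dim=0)]
    else:
        u_shards = None

    # 4) Scatter the orthogonalized update back to every rank's shard.
    u_local = torch.empty_like(g2)
    dist.scatter(u_local, u_shards, src=owner_rank, group=dp_group)

    # 5) Apply the sharded update to the sharded parameter.
    p.add_(u_local, alpha=-lr)
    return p
\end{verbatim}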
We eliminate redundant computation by assigning each parameter to a specific GPU.
However, without proper scheduling this optimization leads to poor GPU utilization: while the assigned rank runs Newton-Schulz for a parameter, every other rank sits idle until the scatter communication completes.
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{naive_execution.png}
\caption*{Execution timeline of Parallel Muon}
\end{figure}
\subsection*{Scheduling Sub-Operations}
We are free to re-schedule the individual sub-operations across parameters for two reasons:
\begin{itemize}
\item There are no dependencies between parameters.
\item GPUs can execute computation and communication concurrently.
\end{itemize}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{pipelined.png}
\caption*{Execution timeline of re-scheduled Parallel Muon}
\end{figure}
We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This scheduling allows each GPU to continue computation even while waiting for collective communication to complete.
\textbf{[Algorithm]} (To be written)
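While the formal algorithm above is still to be written, the sketch below conveys the schedule: the gathers for one chunk of $C$ parameters are issued asynchronously, and while they are in flight the previous chunk is finished (Newton--Schulz on the owner rank, then scatter and update). The round-robin owner assignment, dim-0 sharding with equal shard sizes, and the \texttt{newton\_schulz} / \texttt{apply\_update} callables are illustrative assumptions, not the actual implementation.
\begin{verbatim}
import torch
import torch.distributed as dist

def run_chunked_muon(params, grads, newton_schulz, apply_update, dp_group=None):
    # Chunk size C equals the number of GPUs in the DP group.
    rank = dist.get_rank(dp_group)
    C = dist.get_world_size(dp_group)
    chunks = [list(range(i, min(i + C, len(params))))
              for i in range(0, len(params), C)]

    pending = None
    for chunk in chunks + [None]:
        if chunk is not None:
            # Launch asynchronous gathers for every parameter in this chunk.
            # Owners are assigned round robin here; the real scheduler
            # balances per-parameter FLOPs (see Load Balancing).
            launched = []
            for i in chunk:
                owner = i % C
                bufs = ([torch.empty_like(grads[i]) for _ in range(C)]
                        if rank == owner else None)
                work = dist.gather(grads[i], bufs, dst=owner,
                                   group=dp_group, async_op=True)
                launched.append((i, owner, bufs, work))

        if pending is not None:
            # Finish the previous chunk while the new gathers are in flight.
            for i, owner, bufs, work in pending:
                work.wait()
                if rank == owner:
                    u_full = newton_schulz(torch.cat(bufs, dim=0))
                    u_shards = [t.contiguous() for t in u_full.chunk(C, dim=0)]
                else:
                    u_shards = None
                u_local = torch.empty_like(grads[i])
                dist.scatter(u_local, u_shards, src=owner, group=dp_group)
                apply_update(params[i], u_local)

        pending = launched if chunk is not None else None
\end{verbatim}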
\clearpage
\subsection*{Load Balancing}
If the parameters within a chunk have imbalanced computation loads, idle bubbles may occur. \\
To mitigate this, we apply load balancing based on per-parameter FLOPs, as sketched after the figures below.
\vspace{1em}
\textbf{Imbalanced (Round Robin)}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{imbalance.png}
\end{figure}
\textbf{After Load Balancing}
\begin{figure}[H]
\centering
\includegraphics[width=1.0\textwidth]{balanced.png}
\end{figure}
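One simple way to realize this is a greedy longest-processing-time assignment: estimate the Newton--Schulz cost of each parameter from its shape and always hand the next-heaviest parameter to the currently least-loaded rank. The cost model and function names below are illustrative assumptions, not the exact heuristic in \texttt{muon.py}.
\begin{verbatim}
import heapq

def newton_schulz_flops(shape, steps=5):
    # Rough relative cost of the Newton-Schulz iterations for an m x n
    # matrix (m <= n after the optional transpose): each step is a few
    # matmuls whose cost scales with m * m * n. Constant factors are
    # ignored; only relative cost matters for balancing.
    m, n = min(shape), max(shape)
    return steps * m * m * n

def balance(param_shapes, num_ranks):
    # Greedy LPT: heaviest parameter first, always to the least-loaded rank.
    heap = [(0, r) for r in range(num_ranks)]  # (accumulated FLOPs, rank)
    heapq.heapify(heap)
    assignment = {}
    order = sorted(range(len(param_shapes)),
                   key=lambda i: newton_schulz_flops(param_shapes[i]),
                   reverse=True)
    for i in order:
        load, r = heapq.heappop(heap)
        assignment[i] = r
        heapq.heappush(heap, (load + newton_schulz_flops(param_shapes[i]), r))
    return assignment  # parameter index -> owner rank
\end{verbatim}
Applied per chunk (or globally before chunking), the intent is to keep the per-rank Newton--Schulz FLOPs close to uniform, which removes the idle bubbles shown above.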
\section*{Implementation}
The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and use \texttt{torch.cuda.Event} to synchronize between sub-operations.
Thanks to the simplicity of \texttt{DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.
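Schematically, the synchronization pattern looks like the sketch below; the three callables stand in for the actual gather, Newton--Schulz, and scatter launches. An event recorded on the communication stream lets the compute stream wait for the gather without blocking the host thread, so the host can keep enqueuing sub-operations for other parameters.
\begin{verbatim}
import torch

comm_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

def overlapped(gather_fn, compute_fn, scatter_fn):
    # Run gather -> compute -> scatter for one parameter with stream/event
    # synchronization; gather_fn, compute_fn, scatter_fn are placeholders
    # that enqueue work on the current stream.
    gather_done = torch.cuda.Event()
    compute_done = torch.cuda.Event()

    with torch.cuda.stream(comm_stream):
        gather_fn()
        gather_done.record(comm_stream)

    with torch.cuda.stream(compute_stream):
        # Device-side dependency only; the host returns immediately and can
        # launch the next parameter's gather on the communication stream.
        compute_stream.wait_event(gather_done)
        compute_fn()
        compute_done.record(compute_stream)

    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(compute_done)
        scatter_fn()
\end{verbatim}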
\section*{Evaluation}
We evaluated performance with a 10B-parameter model currently in development, achieving 151 TFLOPS per GPU during the optimizer step.
\begin{table}[H]
\centering
\begin{tabular}{@{}lllll@{}}
\toprule
Model size & Total Muon TFLOPs & GPUs & Elapsed time & TFLOPS/GPU \\
\midrule
10B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\
\bottomrule
\end{tabular}
\end{table}
Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead.
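As a rough consistency check on the throughput figure (treating each MI250 package, i.e.\ two devices, as one GPU, consistent with the table):
\[
\frac{847.45\ \text{TFLOPs}}{1.4\ \text{s} \times 4\ \text{GPUs}} \approx 151\ \text{TFLOPS per GPU}.
\]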
\end{document} |