|
\documentclass{article} |
|
\usepackage{graphicx} |
|
\usepackage{hyperref} |
|
\usepackage{amsmath} |
|
\usepackage{caption} |
|
\usepackage{tgtermes} |
|
\usepackage{float} |
|
\usepackage[a4paper, margin=1in]{geometry} |
|
\usepackage{booktabs} |
|
\usepackage{algorithm} |
|
\usepackage{algorithmicx} |
|
\usepackage{algpseudocode} |
|
\date{} |
|
|
|
\begin{document} |
|
|
|
{\LARGE \bfseries Parallelize Muon with FSDP2 \par} |
|
\vspace{1em} |
|
|
|
\section*{Motivation} |
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=0.8\textwidth]{distributed_muon.png} |
|
\caption*{Distributed Muon by Moonlight} |
|
\end{figure} |
|
|
|
While a distributed version of Muon is available, it has the drawback of redundant computation: every DP rank runs the Newton-Schulz iteration on the same gathered gradient.
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=1.0\textwidth]{distributed_muon_execution.png} |
|
\caption*{Execution timeline of Distributed Muon} |
|
\end{figure} |
|
|
|
\begin{itemize} |
|
\item \texttt{C[i]} : Compute Newton-Schulz($\mathbf{G}$) for the $i$-th gradient

\item \texttt{AG[i]} : AllGather the $i$-th gradient

\item \texttt{G[i]} : Gather the $i$-th gradient

\item \texttt{SC[i]} : Scatter the $i$-th update
|
\end{itemize} |
|
\clearpage |
|
\section*{Algorithm} |
|
|
|
\subsection*{Parallel Muon} |
|
|
|
\begin{algorithm} |
|
\caption{Parallel Muon} |
|
\textbf{Require:} DP-partitioned gradient $\mathbf{g}$, DP-partitioned momentum $\mathbf{m}$, DP-partitioned parameter $\mathbf{p}$, momentum coefficient $\mu$, local rank $\mathbf{r}$
|
\begin{algorithmic}[1] |
|
\State \texttt{// Apply momentum to $\mathbf{g}$ using local partitioned momentum $\mathbf{m}$} |
|
\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$ |
|
\State \texttt{// Schedule $\mathbf{g'}$ to rank $\mathbf{R}$} |
|
\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$ |
|
\State \texttt{// Gather $\mathbf{g'}$ across the DP group into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
|
\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$ |
|
\State \texttt{// Calculate Newton-Schulz only on rank $\mathbf{R}$}

\If{$\mathbf{r} = \mathbf{R}$}

\State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$

\Else

\State $\mathbf{u} \gets \text{None}$
|
\EndIf |
|
|
|
\State \texttt{// Scatter the full update matrix $\mathbf{u}$ across the DP group}
|
\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u},\text{dp\_group},\text{src=}\mathbf{R})$ |
|
\State \texttt{// Apply the DP-partitioned $\mathbf{u'}$ to $\mathbf{p}$}
|
\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$ |
|
\State \textbf{return $\mathbf{p'}$} |
|
\end{algorithmic} |
|
\end{algorithm} |
|
|
|
We eliminate redundant computation by assigning each parameter to a specific GPU. |
|
|
|
However, without proper scheduling, this optimization can lead to poor GPU utilization: while the assigned rank runs Newton-Schulz for a parameter, all other ranks sit idle, waiting for the scatter communication to complete before they can proceed.
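For illustration, the per-parameter step can be written with plain \texttt{torch.distributed} collectives. The following is a minimal sketch, not the actual implementation in \texttt{optimizer/torch-ext/optimizer/muon.py}: the Newton-Schulz variant shown is the quintic iteration from the public Muon reference implementation (the coefficients and step count used here may differ), and even row-wise (dim-0) sharding is assumed.

\begin{verbatim}
import torch
import torch.distributed as dist

def newton_schulz(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz iteration; coefficients are taken from the
    # public Muon reference implementation and may differ from ours.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if X.size(0) > X.size(1):
        X = X.T
    X = X / (X.norm() + eps)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X.to(G.dtype)

def parallel_muon_step(p_shard, g_shard, m_shard, lr, mu, owner, group):
    # One Parallel Muon step for a single parameter whose gradient,
    # momentum, and parameter are sharded row-wise across the DP group.
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    # 1. Momentum on the local shard; here g' is simply the updated
    #    momentum buffer (Nesterov-style variants differ).
    m_shard.mul_(mu).add_(g_shard)
    g_prime = m_shard

    # 2. Gather the full momentum-updated gradient on the owner rank R.
    if rank == owner:
        shards = [torch.empty_like(g_prime) for _ in range(world)]
    else:
        shards = None
    dist.gather(g_prime, gather_list=shards, dst=owner, group=group)

    # 3. Only the owner runs Newton-Schulz; the other ranks are free to
    #    start working on the parameters they own.
    if rank == owner:
        G_full = torch.cat(shards, dim=0)
        U = newton_schulz(G_full)
        out = list(U.chunk(world, dim=0))  # assumes evenly divisible rows
    else:
        out = None

    # 4. Scatter the orthogonalized update so each rank gets its shard.
    u_shard = torch.empty_like(g_prime)
    dist.scatter(u_shard, scatter_list=out, src=owner, group=group)

    # 5. Apply the sharded update to the local parameter shard.
    p_shard.add_(u_shard, alpha=-lr)
    return p_shard
\end{verbatim}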
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=1.0\textwidth]{naive_execution.png} |
|
\caption*{Execution timeline of Parallel Muon} |
|
\end{figure} |
|
|
|
\subsection*{Scheduling Sub-Operations} |
|
|
|
We can reschedule these sub-operations, as shown below, for the following reasons:
|
\begin{itemize} |
|
\item There are no dependencies between parameters. |
|
\item GPUs can execute computation and communication concurrently. |
|
\end{itemize} |
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=1.0\textwidth]{pipelined.png} |
|
\caption*{Execution timeline of re-scheduled Parallel Muon} |
|
\end{figure} |
|
|
|
We set the chunk size $C$ to the number of GPUs and schedule the sub-operations in batches of size $C$. With this schedule, each GPU can keep computing while collective communication for other parameters is still in flight.
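The following is a minimal sketch of this chunked schedule, reusing the hypothetical \texttt{parallel\_muon\_step} from the earlier sketch. Within a chunk, parameter $j$ is assigned to rank $j$, so each rank owns exactly one Newton-Schulz computation per chunk; in the actual implementation, communication and computation run on separate CUDA streams (see Implementation) so that they overlap.

\begin{verbatim}
import torch.distributed as dist

def schedule_in_chunks(param_states, group, lr, mu):
    # Chunk size C = number of GPUs. Within a chunk, parameter j is owned
    # by rank j, so every rank performs exactly one Newton-Schulz per chunk
    # while the gathers/scatters of the other C - 1 parameters are in flight.
    C = dist.get_world_size(group)
    for start in range(0, len(param_states), C):
        chunk = param_states[start:start + C]
        for owner, st in enumerate(chunk):
            # st bundles the local shards of one parameter (illustrative).
            parallel_muon_step(st.p_shard, st.g_shard, st.m_shard,
                               lr=lr, mu=mu, owner=owner, group=group)
\end{verbatim}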
|
|
|
\textbf{[Algorithm]} (To be written) |
|
\clearpage |
|
\subsection*{Load Balancing} |
|
|
|
If the parameters in a chunk have imbalanced computation loads, idle bubbles may occur. To mitigate this, we apply load balancing based on per-parameter FLOPs.
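A minimal sketch of such a FLOPs-based assignment is shown below (names are illustrative, not the actual code): parameters are sorted by estimated Newton-Schulz cost and greedily placed on the currently least-loaded rank.

\begin{verbatim}
import heapq

def balance_by_flops(param_shapes, world_size):
    # One Newton-Schulz iteration on an (m, n) matrix costs on the order
    # of m * n * min(m, n) FLOPs, so we use that as the balancing weight.
    def cost(shape):
        m, n = shape
        return m * n * min(m, n)

    # Each bucket is (accumulated load, rank, assigned shapes).
    buckets = [(0, rank, []) for rank in range(world_size)]
    heapq.heapify(buckets)
    for shape in sorted(param_shapes, key=cost, reverse=True):
        load, rank, assigned = heapq.heappop(buckets)
        assigned.append(shape)
        heapq.heappush(buckets, (load + cost(shape), rank, assigned))
    return {rank: assigned for _, rank, assigned in buckets}
\end{verbatim}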
|
|
|
\vspace{1em} |
|
\textbf{Imbalanced (Round Robin)} |
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=1.0\textwidth]{imbalance.png} |
|
\end{figure} |
|
|
|
\textbf{After Load Balancing} |
|
|
|
\begin{figure}[H] |
|
\centering |
|
\includegraphics[width=1.0\textwidth]{balanced.png} |
|
\end{figure} |
|
|
|
\section*{Implementation} |
|
|
|
The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}. |
|
To enable concurrent computation and communication, we run them on separate \texttt{torch.cuda.Stream}s (one compute stream and one communication stream) and synchronize sub-operations with \texttt{torch.cuda.Event}s.
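The following is a minimal sketch of this pattern; \texttt{gather\_fn}, \texttt{compute\_fn}, and \texttt{scatter\_fn} are placeholders, and the actual \texttt{muon.py} may organize its streams and events differently.

\begin{verbatim}
import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()

def overlapped_step(gather_fn, compute_fn, scatter_fn):
    gathered = torch.cuda.Event()
    computed = torch.cuda.Event()

    # Issue the gather on the communication stream.
    with torch.cuda.stream(comm_stream):
        full_grad = gather_fn()
        gathered.record(comm_stream)

    # Newton-Schulz runs on the compute stream, but only after the gather
    # has finished; work already queued on comm_stream keeps running.
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(gathered)
        update = compute_fn(full_grad)
        computed.record(compute_stream)

    # The scatter goes back on the communication stream once compute is done.
    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(computed)
        return scatter_fn(update)
\end{verbatim}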
|
|
|
Thanks to the simplicity of \texttt{DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.
|
|
|
\section*{Evaluation} |
|
We evaluated performance on a 10B-parameter model currently in development, achieving 151 TFLOPS per GPU during the optimizer step.
|
|
|
\begin{table}[H] |
|
\centering |
|
\begin{tabular}{@{}lllll@{}} |
|
\toprule |
|
Model Size & Muon FLOPs (TFLOP) & GPUs & Elapsed time (s) & TFLOPS/GPU \\

\midrule

10B & 847.45 & 4$\times$MI250 (8 devices) & 1.4 & 151 \\
|
\bottomrule |
|
\end{tabular} |
|
\end{table} |
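For reference, the per-GPU figure follows from the table if each MI250 (two GCDs) is counted as one GPU:

\[
\frac{847.45\ \text{TFLOP}}{1.4\ \text{s}} \approx 605\ \text{TFLOPS},
\qquad
\frac{605\ \text{TFLOPS}}{4\ \text{MI250s}} \approx 151\ \text{TFLOPS per GPU}.
\]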
|
Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead. |
|
|
|
\end{document} |