# stable-diffusion-implementation/main/myenv/lib/python3.10/site-packages/lightning_fabric/strategies/ddp.py
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
from typing import Any, Literal, Optional, Union

import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from typing_extensions import override

from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.plugins.precision import Precision
from lightning_fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning_fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning_fabric.strategies.parallel import ParallelStrategy
from lightning_fabric.strategies.registry import _StrategyRegistry
from lightning_fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
from lightning_fabric.utilities.distributed import (
    ReduceOp,
    _distributed_is_initialized,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
    "ddp_fork",
    "ddp_notebook",
)


class DDPStrategy(ParallelStrategy):
    """Strategy for multi-process single-device training on one or multiple nodes."""

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision=precision,
        )
        self._num_nodes = 1
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._start_method = start_method
        self._backward_sync_control = _DDPBackwardSyncControl()
        self._ddp_kwargs = kwargs
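        # NOTE: any extra keyword arguments collected in ``self._ddp_kwargs`` are forwarded verbatim to
        # ``DistributedDataParallel`` when the module is wrapped in ``setup_module`` below.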

    @property
    def root_device(self) -> torch.device:
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # note: the world ranks depend on `num_nodes`, so they need to be reset whenever it changes
        self._num_nodes = num_nodes

    @property
    def num_processes(self) -> int:
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    def distributed_sampler_kwargs(self) -> dict[str, Any]:
        return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def _configure_launcher(self) -> None:
        assert self.cluster_environment is not None
        if self._start_method == "popen":
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
        else:
            self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

    def setup_environment(self) -> None:
        super().setup_environment()
        self._setup_distributed()

    def setup_module(self, module: Module) -> DistributedDataParallel:
        """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
        device_ids = self._determine_ddp_device_ids()
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

    def module_to_device(self, module: Module) -> None:
        module.to(self.root_device)

    def all_reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value. If the input is not a tensor, it is returned unchanged.

        """
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor
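        # Example (illustrative, with hypothetical `strategy` and `local_loss` objects): averaging a
        # per-process scalar across all ranks:
        #
        #     global_loss = strategy.all_reduce(local_loss, reduce_op="mean")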

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if not _distributed_is_initialized():
            return
        if torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        if not _distributed_is_initialized():
            return obj
        obj = [obj]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]:
        if isinstance(module, DistributedDataParallel):
            module = module.module
        return super().get_module_state_dict(module)

    def load_module_state_dict(
        self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True
    ) -> None:
        if isinstance(module, DistributedDataParallel):
            module = module.module
        super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        entries = (
            ("ddp", "popen"),
            ("ddp_spawn", "spawn"),
            ("ddp_fork", "fork"),
            ("ddp_notebook", "fork"),
        )
        for name, start_method in entries:
            strategy_registry.register(
                name,
                cls,
                description=f"DDP strategy with `start_method={start_method!r}`",
                start_method=start_method,
            )
        strategy_registry.register(
            "ddp_find_unused_parameters_true",
            cls,
            description="Alias for `find_unused_parameters_true` and `start_method='popen'`",
            find_unused_parameters=True,
            start_method="popen",
        )
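        # The names registered above can be passed to Fabric as strings, e.g. ``Fabric(strategy="ddp_spawn")``
        # (illustrative; `Fabric` is the public entry point, not defined in this file).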

    def _setup_distributed(self) -> None:
        self._set_world_ranks()
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def _set_world_ranks(self) -> None:
        if self.cluster_environment is not None:
            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

    def _determine_ddp_device_ids(self) -> Optional[list[int]]:
        return None if self.root_device.type == "cpu" else [self.root_device.index]


class _DDPBackwardSyncControl(_BackwardSyncControl):
    def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
        """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
        wrapper."""
        if not enabled:
            return nullcontext()
        if not isinstance(module, DistributedDataParallel):
            raise TypeError(
                "Blocking backward sync is only possible if the module passed to"
                f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `DistributedDataParallel`."
                f" Got: {module.__class__.__name__}."
            )
        return module.no_sync()
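

# Illustrative use of the backward-sync control (not part of this file; assumes the public Fabric API and
# hypothetical `model`, `loss`, and `is_last_accumulation_step` objects): gradient accumulation typically
# wraps the backward pass like this so the gradient all-reduce only happens on the final accumulation step:
#
#     with fabric.no_backward_sync(model, enabled=not is_last_accumulation_step):
#         fabric.backward(loss)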