# stable-diffusion-implementation/main/myenv/lib/python3.10/site-packages/lightning_fabric/cli.py
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging | |
import os | |
import re | |
from argparse import Namespace | |
from typing import Any, Optional | |
import torch | |
from lightning_utilities.core.imports import RequirementCache | |
from typing_extensions import get_args | |
from lightning_fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator | |
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS | |
from lightning_fabric.strategies import STRATEGY_REGISTRY | |
from lightning_fabric.utilities.consolidate_checkpoint import _process_cli_args | |
from lightning_fabric.utilities.device_parser import _parse_gpu_ids | |
from lightning_fabric.utilities.distributed import _suggested_max_num_threads | |
from lightning_fabric.utilities.load import _load_distributed_checkpoint | |
_log = logging.getLogger(__name__)
# Optional-dependency probes: truthy when the corresponding package is importable.
_CLICK_AVAILABLE = RequirementCache("click")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
# Accelerator names the CLI accepts. NOTE(review): not referenced elsewhere in this
# chunk — presumably consumed by click option declarations; confirm against upstream.
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
def _get_supported_strategies() -> list[str]:
    """Return the registry strategy names that can be selected from the CLI.

    Strategies that cannot be launched this way (spawn/fork/notebook variants) or that
    need extra user configuration (xla, tpu, offload) are filtered out.
    """
    blocked = re.compile(r".*(spawn|fork|notebook|xla|tpu|offload).*")
    return [name for name in STRATEGY_REGISTRY.available_strategies() if blocked.match(name) is None]
# `click` is an optional dependency: import it only when installed so that the rest of
# this module can still be imported without it.
if _CLICK_AVAILABLE:
    import click
def _main() -> None:
    # NOTE(review): empty entry point — upstream this is a `click` command group that
    # `_run`/`_consolidate` attach to; the decorators appear stripped in this copy. Confirm.
    pass
def _run(**kwargs: Any) -> None:
    """Run a Lightning Fabric script.

    SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.

    SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
    there.
    """
    # Separate the pass-through script arguments from the launcher's own options.
    passthrough = list(kwargs.pop("script_args", []))
    launcher_args = Namespace(**kwargs)
    main(args=launcher_args, script_args=passthrough)
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
    """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.

    Only supports FSDP sharded checkpoints at the moment.
    """
    cli_args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
    # Validate/normalize the paths before touching the checkpoint on disk.
    config = _process_cli_args(cli_args)
    merged = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(merged, config.output_file)
def _set_env_variables(args: Namespace) -> None: | |
"""Set the environment variables for the new processes. | |
The Fabric connector will parse the arguments set here. | |
""" | |
os.environ["LT_CLI_USED"] = "1" | |
if args.accelerator is not None: | |
os.environ["LT_ACCELERATOR"] = str(args.accelerator) | |
if args.strategy is not None: | |
os.environ["LT_STRATEGY"] = str(args.strategy) | |
os.environ["LT_DEVICES"] = str(args.devices) | |
os.environ["LT_NUM_NODES"] = str(args.num_nodes) | |
if args.precision is not None: | |
os.environ["LT_PRECISION"] = str(args.precision) | |
def _get_num_processes(accelerator: str, devices: str) -> int:
    """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
    # TPU launching is not supported from the CLI; reject it up front.
    if accelerator == "tpu":
        raise ValueError("Launching processes for TPU through the CLI is not supported.")
    if accelerator == "gpu":
        parsed = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
    elif accelerator == "cuda":
        parsed = CUDAAccelerator.parse_devices(devices)
    elif accelerator == "mps":
        parsed = MPSAccelerator.parse_devices(devices)
    else:
        # CPU: `parse_devices` already yields a process count rather than a device list.
        return CPUAccelerator.parse_devices(devices)
    return 0 if parsed is None else len(parsed)
def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
    """This will invoke `torchrun` programmatically to launch the given script in new processes."""
    import torch.distributed.run as torchrun

    # Data-parallel runs in a single process; everything else derives the count from the devices.
    if args.strategy == "dp":
        num_processes = 1
    else:
        num_processes = _get_num_processes(args.accelerator, args.devices)
    launch_cmd = [
        f"--nproc_per_node={num_processes}",
        f"--nnodes={args.num_nodes}",
        f"--node_rank={args.node_rank}",
        f"--master_addr={args.main_address}",
        f"--master_port={args.main_port}",
        args.script,
        *script_args,
    ]
    # set a good default number of threads for OMP to avoid warnings being emitted to the user
    os.environ.setdefault("OMP_NUM_THREADS", str(_suggested_max_num_threads()))
    torchrun.main(launch_cmd)
def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
    """Export the launch settings to the environment, then hand off to torchrun."""
    _set_env_variables(args)
    if script_args is None:
        script_args = []
    _torchrun_launch(args, script_args)
if __name__ == "__main__":
    # Fail with an actionable message instead of an ImportError when click is missing.
    if not _CLICK_AVAILABLE:  # pragma: no cover
        _log.error(
            "To use the Lightning Fabric CLI, you must have `click` installed."
            " Install it by running `pip install -U click`."
        )
        raise SystemExit(1)
    # NOTE(review): `_run()` is called with no kwargs; upstream `_run` is a click command
    # that parses sys.argv itself — confirm the decorators were not lost in this copy.
    _run()