# stable-diffusion-implementation/main/myenv/lib/python3.10/site-packages/lightning_fabric/utilities/init.py
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from torch.overrides import TorchFunctionMode
from typing_extensions import override

from lightning_fabric.utilities.rank_zero import rank_zero_warn
from lightning_fabric.utilities.types import _DEVICE
# From https://lernapparat.de/faster-model-init by Thomas Viehmann | |
class _EmptyInit(TorchFunctionMode): | |
"""Initialize `nn.Module` with empty tensors, i.e., uninitialized memory. | |
Example:: | |
with _EmptyInit(): | |
model = BigModel() | |
model.load_state_dict(torch.load("checkpoint.pt")) | |
""" | |
def __init__(self, enabled: bool = True) -> None: | |
super().__init__() | |
self.enabled = enabled | |
def __torch_function__( | |
self, | |
func: Callable, | |
types: Sequence, | |
args: Sequence[Any] = (), | |
kwargs: Optional[dict] = None, | |
) -> Any: | |
kwargs = kwargs or {} | |
if not self.enabled: | |
return func(*args, **kwargs) | |
if getattr(func, "__module__", None) == "torch.nn.init": | |
if "tensor" in kwargs: | |
return kwargs["tensor"] | |
return args[0] | |
return func(*args, **kwargs) | |
def _materialize(module: Module, device: _DEVICE) -> None: | |
"""Materialize a module.""" | |
module.to_empty(device=device, recurse=False) | |
if not hasattr(module, "reset_parameters"): | |
raise TypeError( | |
f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented." | |
" This method is used to initialize any children parameters or buffers in this module." | |
) | |
if callable(module.reset_parameters): | |
module.reset_parameters() | |
def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
    """Materialize all tensors in a given module.

    Walks every submodule (including ``module`` itself) and materializes each
    one that directly owns meta-device parameters or buffers.
    """
    # Distinct loop variable: the parameter `module` must not be shadowed/rebound
    # while iterating its own `.modules()` generator.
    for submodule in module.modules():
        if _has_meta_device_parameters_or_buffers(submodule, recurse=False):
            _materialize(submodule, device)
def _materialize_distributed_module(module: Module, device: torch.device) -> None:
    # Reference: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md#meta-device-initialization
    # TODO: Introduce `Fabric.materialize(module)` to give user control when materialization should happen
    # TODO: Make `torchmetrics.Metric` compatible with the `to_empty()` + `reset_parameters()` semantics
    if not _has_meta_device_parameters_or_buffers(module):
        # Nothing lives on the meta device, so there is nothing to materialize.
        return
    module.to_empty(device=device)  # has to be called on the root module
    modules_without_reset = set()
    for submodule in module.modules():
        own_tensors = itertools.chain(submodule.parameters(recurse=False), submodule.buffers(recurse=False))
        if not any(True for _ in own_tensors):
            # Submodule owns no parameters or buffers of its own; skip it.
            continue
        reset_method = getattr(submodule, "reset_parameters", None)
        if callable(reset_method):
            reset_method()
        else:
            # Allocated by `to_empty` but never given values: remember for the warning.
            modules_without_reset.add(type(submodule).__name__)
    if modules_without_reset:
        rank_zero_warn(
            "Parameter initialization incomplete. The following modules have parameters or buffers with uninitialized"
            " memory because they don't define a `reset_parameters()` method for re-initialization:"
            f" {', '.join(modules_without_reset)}"
        )
def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool: | |
if isinstance(obj, Optimizer): | |
return any( | |
t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter) | |
) | |
if isinstance(obj, Module): | |
return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse))) | |
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") | |