# lightning_fabric/utilities/load.py
# Copyright 2023 MathInf GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use the files from this repository except in compliance
# with the License reproduced below (also at
# http://www.apache.org/licenses/LICENSE-2.0).
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
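"""Utilities for lazily loading and materializing checkpoints.

``_lazy_load`` reads a ``torch.save`` archive without copying tensor data into
memory: tensors come back as ``_NotYetLoadedTensor`` placeholders backed by
meta tensors, and the bytes are only read from the archive when a tensor is
actually used (or explicitly materialized via ``_materialize_tensors``).
"""
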
import os
import pickle
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter
from typing_extensions import override

from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning_fabric.utilities.types import _PATH, _Stateful

_METADATA_FILENAME = "meta.pt"

if TYPE_CHECKING:
    from torch.storage import TypedStorage


# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _NotYetLoadedTensor:
    def __init__(
        self,
        metatensor: Tensor,
        archiveinfo: "_LazyLoadingUnpickler",
        storageinfo: tuple,
        rebuild_args: tuple,
    ) -> None:
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(
        cls,
        func: Callable,
        new_type: _TensorMeta,
        args: tuple,
        state: dict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Any:
        ret = func(*args)
        if isinstance(ret, _NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor() -> Any:
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            ret._load_tensor = _load_tensor  # type: ignore[method-assign]
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls,
        data: Any,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
        if isinstance(data, _NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor() -> Parameter:
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor  # type: ignore[method-assign]
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage: "TypedStorage",
        storage_offset: int,
        size: tuple,
        stride: tuple,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        metadata: Optional[Any] = None,
        *,
        archiveinfo: "_LazyLoadingUnpickler",
    ) -> "_NotYetLoadedTensor":
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        storageinfo = storage.archiveinfo
        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
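
    # The three rebuild_* classmethods above stand in for torch's rebuild
    # functions during unpickling (see _LazyLoadingUnpickler.find_class below),
    # so loading a checkpoint yields cheap _NotYetLoadedTensor placeholders
    # instead of fully materialized tensors.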

    def _load_tensor(self) -> Tensor:
        from torch.storage import TypedStorage, UntypedStorage

        # storageinfo is the persistent-id tuple from the pickle stream:
        # (tag, storage_cls, key, location, numel)
        _, _, fn, _, size = self.storageinfo
        dtype = self.metatensor.dtype

        storage = self.archiveinfo.file_reader.get_storage_from_record(
            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
        )
        uts = storage._typed_storage()._untyped_storage

        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
        return func(*loaded_args, **kwargs)

    @property
    def device(self) -> torch.device:
        return torch.device(self.storageinfo[3])

    def __getattr__(self, name: str) -> Any:
        # These properties don't require materialization and can be accessed through the meta tensor directly
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "is_meta",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "size",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)

        # materializing these is needed for quantization (see lit-gpt)
        if name in {"contiguous", "cuda", "half", "data", "to"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.metatensor)})"


# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _LazyLoadingUnpickler(pickle.Unpickler):
    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
        super().__init__(file)
        self.file_reader = file_reader

    @override
    def find_class(self, module: str, name: str) -> Any:
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
        if module == "torch._utils" and name == "_rebuild_parameter":
            return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
        return super().find_class(module, name)

    @override
    def persistent_load(self, pid: tuple) -> "TypedStorage":
        from torch.storage import TypedStorage

        _, cls, _, _, _ = pid
        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(dtype=cls().dtype, device="meta")
        storage.archiveinfo = pid
        return storage
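

# Note: persistent_load returns storages on the meta device tagged with their
# archive record info, and find_class reroutes torch's rebuild functions to the
# _NotYetLoadedTensor hooks above, so no tensor bytes are read from disk until
# _load_tensor() is called.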


def _lazy_load(filename: _PATH) -> Any:
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"Path {str(filename)!r} does not exist or is not a file.")
    file_reader = torch.PyTorchFileReader(str(filename))
    with BytesIO(file_reader.get_record("data.pkl")) as pkl:
        mup = _LazyLoadingUnpickler(pkl, file_reader)
        return mup.load()


def _materialize_tensors(collection: Any) -> Any:
    def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
        return t._load_tensor()

    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)
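

# A minimal usage sketch (the checkpoint path and keys below are hypothetical):
#
#     checkpoint = _lazy_load("model.ckpt")
#     weight = checkpoint["state_dict"]["layer.weight"]  # still a _NotYetLoadedTensor
#     weight = _materialize_tensors(weight)  # bytes are read from the archive only here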


def _move_state_into(
    source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None
) -> None:
    """Takes the state from the source dictionary and moves it into the destination dictionary.

    If an object in the destination follows the stateful protocol, it loads the source state via ``load_state_dict``.
    """
    keys = set(source) if keys is None else keys & set(source)
    for key in keys:
        state = source.pop(key)
        if key in destination and isinstance(destination[key], _Stateful):
            destination[key].load_state_dict(state)
        else:
            destination[key] = state
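

# For example (hypothetical keys): if destination["optimizer"] is _Stateful,
# _move_state_into(source, destination) calls
# destination["optimizer"].load_state_dict(source["optimizer"]), while a plain
# entry such as source["epoch"] is copied over directly.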


def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]:
    """Loads a sharded checkpoint saved with ``torch.distributed.checkpoint`` into a full state dict.

    The current implementation assumes that the entire checkpoint fits in CPU memory.
    """
    if not _TORCH_GREATER_EQUAL_2_3:
        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    checkpoint: dict[str, Any] = {}
    _load_state_dict(
        checkpoint,
        storage_reader=FileSystemReader(checkpoint_folder),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )

    # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
    extra_file = checkpoint_folder / _METADATA_FILENAME
    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
    checkpoint.update(extra)

    return checkpoint
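

# Illustrative smoke test, not part of the library API. It assumes a checkpoint
# path is passed on the command line; treating a directory as a
# torch.distributed.checkpoint folder is an assumption made for this sketch.
if __name__ == "__main__":
    import sys

    path = Path(sys.argv[1])
    if path.is_dir():
        # Assumed to be a distributed checkpoint folder (requires torch >= 2.3)
        ckpt = _load_distributed_checkpoint(path)
    else:
        # Treated as a regular torch.save archive; load lazily, then materialize
        ckpt = _materialize_tensors(_lazy_load(path))
    if isinstance(ckpt, dict):
        print(f"Loaded {len(ckpt)} top-level entries: {sorted(map(str, ckpt))}")
    else:
        print(f"Loaded object of type {type(ckpt).__name__}")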