stable-diffusion-implementation
/
main
/myenv
/lib
/python3.10
/site-packages
/huggingface_hub
/dataclasses.py
import inspect | |
from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields | |
from functools import wraps | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
List, | |
Literal, | |
Optional, | |
Tuple, | |
Type, | |
TypeVar, | |
Union, | |
get_args, | |
get_origin, | |
overload, | |
) | |
from .errors import ( | |
StrictDataclassClassValidationError, | |
StrictDataclassDefinitionError, | |
StrictDataclassFieldValidationError, | |
) | |
Validator_T = Callable[[Any], None] | |
T = TypeVar("T") | |
# The overload decorator helps type checkers understand the different return types | |
def strict(cls: Type[T]) -> Type[T]: ... | |
def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ... | |
def strict( | |
cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False | |
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: | |
""" | |
Decorator to add strict validation to a dataclass. | |
This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools | |
recognize the class as a dataclass. | |
Can be used with or without arguments: | |
- `@strict` | |
- `@strict(accept_kwargs=True)` | |
Args: | |
cls: | |
The class to convert to a strict dataclass. | |
accept_kwargs (`bool`, *optional*): | |
If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. | |
Returns: | |
The enhanced dataclass with strict validation on field assignment. | |
Example: | |
```py | |
>>> from dataclasses import dataclass | |
>>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field | |
>>> @as_validated_field | |
>>> def positive_int(value: int): | |
... if not value >= 0: | |
... raise ValueError(f"Value must be positive, got {value}") | |
>>> @strict(accept_kwargs=True) | |
... @dataclass | |
... class User: | |
... name: str | |
... age: int = positive_int(default=10) | |
# Initialize | |
>>> User(name="John") | |
User(name='John', age=10) | |
# Extra kwargs are accepted | |
>>> User(name="John", age=30, lastname="Doe") | |
User(name='John', age=30, *lastname='Doe') | |
# Invalid type => raises | |
>>> User(name="John", age="30") | |
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': | |
TypeError: Field 'age' expected int, got str (value: '30') | |
# Invalid value => raises | |
>>> User(name="John", age=-1) | |
huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': | |
ValueError: Value must be positive, got -1 | |
``` | |
""" | |
def wrap(cls: Type[T]) -> Type[T]: | |
if not hasattr(cls, "__dataclass_fields__"): | |
raise StrictDataclassDefinitionError( | |
f"Class '{cls.__name__}' must be a dataclass before applying @strict." | |
) | |
# List and store validators | |
field_validators: Dict[str, List[Validator_T]] = {} | |
for f in fields(cls): # type: ignore [arg-type] | |
validators = [] | |
validators.append(_create_type_validator(f)) | |
custom_validator = f.metadata.get("validator") | |
if custom_validator is not None: | |
if not isinstance(custom_validator, list): | |
custom_validator = [custom_validator] | |
for validator in custom_validator: | |
if not _is_validator(validator): | |
raise StrictDataclassDefinitionError( | |
f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument." | |
) | |
validators.extend(custom_validator) | |
field_validators[f.name] = validators | |
cls.__validators__ = field_validators # type: ignore | |
# Override __setattr__ to validate fields on assignment | |
original_setattr = cls.__setattr__ | |
def __strict_setattr__(self: Any, name: str, value: Any) -> None: | |
"""Custom __setattr__ method for strict dataclasses.""" | |
# Run all validators | |
for validator in self.__validators__.get(name, []): | |
try: | |
validator(value) | |
except (ValueError, TypeError) as e: | |
raise StrictDataclassFieldValidationError(field=name, cause=e) from e | |
# If validation passed, set the attribute | |
original_setattr(self, name, value) | |
cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign] | |
if accept_kwargs: | |
# (optional) Override __init__ to accept arbitrary keyword arguments | |
original_init = cls.__init__ | |
def __init__(self, **kwargs: Any) -> None: | |
# Extract only the fields that are part of the dataclass | |
dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type] | |
standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} | |
# Call the original __init__ with standard fields | |
original_init(self, **standard_kwargs) | |
# Add any additional kwargs as attributes | |
for name, value in kwargs.items(): | |
if name not in dataclass_fields: | |
self.__setattr__(name, value) | |
cls.__init__ = __init__ # type: ignore[method-assign] | |
# (optional) Override __repr__ to include additional kwargs | |
original_repr = cls.__repr__ | |
def __repr__(self) -> str: | |
# Call the original __repr__ to get the standard fields | |
standard_repr = original_repr(self) | |
# Get additional kwargs | |
additional_kwargs = [ | |
# add a '*' in front of additional kwargs to let the user know they are not part of the dataclass | |
f"*{k}={v!r}" | |
for k, v in self.__dict__.items() | |
if k not in cls.__dataclass_fields__ # type: ignore [attr-defined] | |
] | |
additional_repr = ", ".join(additional_kwargs) | |
# Combine both representations | |
return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr | |
cls.__repr__ = __repr__ # type: ignore [method-assign] | |
# List all public methods starting with `validate_` => class validators. | |
class_validators = [] | |
for name in dir(cls): | |
if not name.startswith("validate_"): | |
continue | |
method = getattr(cls, name) | |
if not callable(method): | |
continue | |
if len(inspect.signature(method).parameters) != 1: | |
raise StrictDataclassDefinitionError( | |
f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument." | |
" Class validators must take only 'self' as an argument. Methods starting with 'validate_'" | |
" are considered to be class validators." | |
) | |
class_validators.append(method) | |
cls.__class_validators__ = class_validators # type: ignore [attr-defined] | |
# Add `validate` method to the class, but first check if it already exists | |
def validate(self: T) -> None: | |
"""Run class validators on the instance.""" | |
for validator in cls.__class_validators__: # type: ignore [attr-defined] | |
try: | |
validator(self) | |
except (ValueError, TypeError) as e: | |
raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e | |
# Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class | |
# (in which case we just override it) | |
validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined] | |
if hasattr(cls, "validate"): | |
if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined] | |
raise StrictDataclassDefinitionError( | |
f"Class '{cls.__name__}' already implements a method called 'validate'." | |
" This method name is reserved when using the @strict decorator on a dataclass." | |
" If you want to keep your own method, please rename it." | |
) | |
cls.validate = validate # type: ignore | |
# Run class validators after initialization | |
initial_init = cls.__init__ | |
def init_with_validate(self, *args, **kwargs) -> None: | |
"""Run class validators after initialization.""" | |
initial_init(self, *args, **kwargs) # type: ignore [call-arg] | |
cls.validate(self) # type: ignore [attr-defined] | |
setattr(cls, "__init__", init_with_validate) | |
return cls | |
# Return wrapped class or the decorator itself | |
return wrap(cls) if cls is not None else wrap | |
def validated_field( | |
validator: Union[List[Validator_T], Validator_T], | |
default: Union[Any, _MISSING_TYPE] = MISSING, | |
default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, | |
init: bool = True, | |
repr: bool = True, | |
hash: Optional[bool] = None, | |
compare: bool = True, | |
metadata: Optional[Dict] = None, | |
**kwargs: Any, | |
) -> Any: | |
""" | |
Create a dataclass field with a custom validator. | |
Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. | |
Args: | |
validator (`Callable` or `List[Callable]`): | |
A method that takes a value as input and raises ValueError/TypeError if the value is invalid. | |
Can be a list of validators to apply multiple checks. | |
**kwargs: | |
Additional arguments to pass to `dataclasses.field()`. | |
Returns: | |
A field with the validator attached in metadata | |
""" | |
if not isinstance(validator, list): | |
validator = [validator] | |
if metadata is None: | |
metadata = {} | |
metadata["validator"] = validator | |
return field( # type: ignore | |
default=default, # type: ignore [arg-type] | |
default_factory=default_factory, # type: ignore [arg-type] | |
init=init, | |
repr=repr, | |
hash=hash, | |
compare=compare, | |
metadata=metadata, | |
**kwargs, | |
) | |
def as_validated_field(validator: Validator_T): | |
""" | |
Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator). | |
Args: | |
validator (`Callable`): | |
A method that takes a value as input and raises ValueError/TypeError if the value is invalid. | |
""" | |
def _inner( | |
default: Union[Any, _MISSING_TYPE] = MISSING, | |
default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, | |
init: bool = True, | |
repr: bool = True, | |
hash: Optional[bool] = None, | |
compare: bool = True, | |
metadata: Optional[Dict] = None, | |
**kwargs: Any, | |
): | |
return validated_field( | |
validator, | |
default=default, | |
default_factory=default_factory, | |
init=init, | |
repr=repr, | |
hash=hash, | |
compare=compare, | |
metadata=metadata, | |
**kwargs, | |
) | |
return _inner | |
def type_validator(name: str, value: Any, expected_type: Any) -> None: | |
"""Validate that 'value' matches 'expected_type'.""" | |
origin = get_origin(expected_type) | |
args = get_args(expected_type) | |
if expected_type is Any: | |
return | |
elif validator := _BASIC_TYPE_VALIDATORS.get(origin): | |
validator(name, value, args) | |
elif isinstance(expected_type, type): # simple types | |
_validate_simple_type(name, value, expected_type) | |
else: | |
raise TypeError(f"Unsupported type for field '{name}': {expected_type}") | |
def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate that value matches one of the types in a Union.""" | |
errors = [] | |
for t in args: | |
try: | |
type_validator(name, value, t) | |
return # Valid if any type matches | |
except TypeError as e: | |
errors.append(str(e)) | |
raise TypeError( | |
f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}" | |
) | |
def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate Literal type.""" | |
if value not in args: | |
raise TypeError(f"Field '{name}' expected one of {args}, got {value}") | |
def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate List[T] type.""" | |
if not isinstance(value, list): | |
raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") | |
# Validate each item in the list | |
item_type = args[0] | |
for i, item in enumerate(value): | |
try: | |
type_validator(f"{name}[{i}]", item, item_type) | |
except TypeError as e: | |
raise TypeError(f"Invalid item at index {i} in list '{name}'") from e | |
def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate Dict[K, V] type.""" | |
if not isinstance(value, dict): | |
raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") | |
# Validate keys and values | |
key_type, value_type = args | |
for k, v in value.items(): | |
try: | |
type_validator(f"{name}.key", k, key_type) | |
type_validator(f"{name}[{k!r}]", v, value_type) | |
except TypeError as e: | |
raise TypeError(f"Invalid key or value in dict '{name}'") from e | |
def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate Tuple type.""" | |
if not isinstance(value, tuple): | |
raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}") | |
# Handle variable-length tuples: Tuple[T, ...] | |
if len(args) == 2 and args[1] is Ellipsis: | |
for i, item in enumerate(value): | |
try: | |
type_validator(f"{name}[{i}]", item, args[0]) | |
except TypeError as e: | |
raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e | |
# Handle fixed-length tuples: Tuple[T1, T2, ...] | |
elif len(args) != len(value): | |
raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") | |
else: | |
for i, (item, expected) in enumerate(zip(value, args)): | |
try: | |
type_validator(f"{name}[{i}]", item, expected) | |
except TypeError as e: | |
raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e | |
def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None: | |
"""Validate Set[T] type.""" | |
if not isinstance(value, set): | |
raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") | |
# Validate each item in the set | |
item_type = args[0] | |
for i, item in enumerate(value): | |
try: | |
type_validator(f"{name} item", item, item_type) | |
except TypeError as e: | |
raise TypeError(f"Invalid item in set '{name}'") from e | |
def _validate_simple_type(name: str, value: Any, expected_type: type) -> None: | |
"""Validate simple type (int, str, etc.).""" | |
if not isinstance(value, expected_type): | |
raise TypeError( | |
f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})" | |
) | |
def _create_type_validator(field: Field) -> Validator_T: | |
"""Create a type validator function for a field.""" | |
# Hacky: we cannot use a lambda here because of reference issues | |
def validator(value: Any) -> None: | |
type_validator(field.name, value, field.type) | |
return validator | |
def _is_validator(validator: Any) -> bool: | |
"""Check if a function is a validator. | |
A validator is a Callable that can be called with a single positional argument. | |
The validator can have more arguments with default values. | |
Basically, returns True if `validator(value)` is possible. | |
""" | |
if not callable(validator): | |
return False | |
signature = inspect.signature(validator) | |
parameters = list(signature.parameters.values()) | |
if len(parameters) == 0: | |
return False | |
if parameters[0].kind not in ( | |
inspect.Parameter.POSITIONAL_OR_KEYWORD, | |
inspect.Parameter.POSITIONAL_ONLY, | |
inspect.Parameter.VAR_POSITIONAL, | |
): | |
return False | |
for parameter in parameters[1:]: | |
if parameter.default == inspect.Parameter.empty: | |
return False | |
return True | |
_BASIC_TYPE_VALIDATORS = { | |
Union: _validate_union, | |
Literal: _validate_literal, | |
list: _validate_list, | |
dict: _validate_dict, | |
tuple: _validate_tuple, | |
set: _validate_set, | |
} | |
__all__ = [ | |
"strict", | |
"validated_field", | |
"Validator_T", | |
"StrictDataclassClassValidationError", | |
"StrictDataclassDefinitionError", | |
"StrictDataclassFieldValidationError", | |
] | |