import torch
import inspect
import importlib
import transformers
import types
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, Optional, Union, Any, List
from .functions_2_patch import _validate_model_kwargs, llama_atten_forward


def get_full_class_import_path(obj):
    """Return ``module.QualName`` for the class of *obj*.

    Nested-class separators in the qualified name (``Outer.Inner``) are
    replaced with underscores (``Outer_Inner``), so exactly one dot separates
    the module path from the class part of the result.
    """
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__
    # str.replace is a no-op for non-nested classes, so no branch is needed.
    return f"{module}.{qualname.replace('.', '_')}"


def get_importable_class_path(obj):
    """Return a dotted path from which the class of *obj* can be imported.

    Special cases:

    * classes whose ``__module__`` is missing or ``None`` yield ``""``;
    * built-in types (module ``builtins``) yield just the qualified name;
    * nested classes yield ``module.Outer.Inner`` when that attribute chain
      resolves after importing ``module``; otherwise they fall back to
      ``module.Outer_Inner`` (which generally is NOT importable, so a caller
      trying to import it will fail).
    """
    cls = obj.__class__
    # Read __module__ defensively; accessing it directly before checking for
    # its presence made the "no module" branch unreachable.
    module = getattr(cls, "__module__", None)
    qualname = cls.__qualname__

    if module is None:
        return ""
    if module == "builtins":
        return qualname

    if "." in qualname:
        # Validate the nested path by walking it from the imported module.
        try:
            current = importlib.import_module(module)
            for part in qualname.split("."):
                current = getattr(current, part)
        except (ImportError, AttributeError):
            return f"{module}.{qualname.replace('.', '_')}"
    return f"{module}.{qualname}"


def monkey_patch_by_class_path(model, new_forward):
    """Replace ``forward`` on the class of *model* with *new_forward*.

    The class is located via :func:`get_importable_class_path` and imported
    dynamically. The class's own original ``forward`` is stored once as
    ``_original_forward``, so patching the same class repeatedly (e.g. once
    per layer) keeps the true original. The instance's ``forward`` is also
    rebound so *model* uses the new method immediately.

    Note: the patch is applied at class level, so every instance of that
    class in the process is affected, not only *model*.

    Returns:
        A status string: ``"Successful Monkey Patch: <path>.forward"`` on
        success, or ``"Patch Failed: <reason>"`` if the class could not be
        resolved (ImportError / AttributeError / ValueError are reported,
        not raised).
    """
    class_path = get_importable_class_path(model)
    try:
        # ValueError here means the path has no dot (e.g. "" or a builtin).
        module_path, class_name = class_path.rsplit(".", 1)
        target_class = getattr(importlib.import_module(module_path), class_name)

        # Check the class's *own* namespace: hasattr() would also see an
        # inherited _original_forward (e.g. from an already-patched base
        # class) and skip saving this class's real original forward.
        if "_original_forward" not in vars(target_class):
            target_class._original_forward = target_class.forward

        target_class.forward = new_forward
        # Bind on the instance as well so this object picks up the patch.
        model.forward = types.MethodType(target_class.forward, model)
        return f"Successful Monkey Patch: {class_path}.forward"
    except (ImportError, AttributeError, ValueError) as e:
        return f"Patch Failed: {str(e)}"


def _iter_readable_attributes(obj):
    """Yield ``(name, value)`` for every name in ``dir(obj)`` that can be read.

    Reading an arbitrary attribute may run a property getter that raises;
    such attributes are skipped so one bad property does not abort a search.
    """
    for attr_name in dir(obj):
        try:
            attr_value = getattr(obj, attr_name)
        except Exception:  # any getter failure just means "not a candidate"
            continue
        yield attr_name, attr_value


def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    """Return an inner attribute of *obj*, looked up by name first, then by type.

    Args:
        obj: Object to search.
        attr_name_list: Candidate attribute names, tried in order. The first
            one that exists is returned regardless of its type.
        default_type: Fallback: the first attribute (in ``dir(obj)`` order)
            that is an instance of this type is returned.

    Raises:
        AttributeError: if neither lookup finds anything.
    """
    for target_attr_name in attr_name_list:
        if hasattr(obj, target_attr_name):
            return getattr(obj, target_attr_name)

    for _, attr_value in _iter_readable_attributes(obj):
        if isinstance(attr_value, default_type):
            return attr_value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")


def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    """Return the first attribute of *obj* matching a name pattern.

    An attribute qualifies when it is an instance of *match_type* and either
    its class name or its attribute name (case-insensitively) contains one of
    *name_pattern_list* while containing NONE of *exclude_pattern_list*. The
    exclusion check applies to the same string that matched. Attributes are
    visited in ``dir(obj)`` order. For each pattern, the class name is tested
    before the attribute name.

    Raises:
        AttributeError: if no attribute qualifies.
    """
    patterns = [p.lower() for p in name_pattern_list]
    # An empty (or None) exclude list now means "exclude nothing"; previously
    # it made every match impossible.
    excludes = [p.lower() for p in (exclude_pattern_list or [])]

    for attr_name, attr_value in _iter_readable_attributes(obj):
        if not isinstance(attr_value, match_type):
            continue
        labels = (attr_value.__class__.__name__.lower(), attr_name.lower())
        for pattern in patterns:
            for label in labels:
                # Require ALL exclude patterns to be absent, not just one.
                if pattern in label and not any(ex in label for ex in excludes):
                    return attr_value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")


def monkey_patching(model_obj, model_atten_forward, verbose=True):
    """Patch the attention ``forward`` of every decoder layer in *model_obj*.

    Steps:

    1. Globally replace ``transformers.generation.GenerationMixin.
       _validate_model_kwargs`` with the ``_validate_model_kwargs`` imported
       from ``.functions_2_patch`` (a process-wide side effect).
    2. Locate the inner model (attribute ``model``/``transformer``/
       ``gpt_neox``, else any ``PreTrainedModel`` attribute) and its layer
       container (``layers``/``h``, else any ``nn.ModuleList`` attribute).
    3. In each layer, find the attention submodule (name or class containing
       ``attention``/``self_attn`` but not ``norm``/``layer``) and patch its
       class's ``forward`` with *model_atten_forward*.

    Args:
        model_obj: The top-level model object.
        model_atten_forward: Replacement ``forward`` function for the
            attention class.
        verbose: If true, print one status line per layer.

    Returns:
        The layer container found in step 2.

    Raises:
        AttributeError: if the inner model, layers, or an attention module
            cannot be located.
    """
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    # Locate the inner (base) model.
    possible_inner_model_names = ["model", "transformer", "gpt_neox"]
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, PreTrainedModel)

    # Locate the list of decoder layers.
    possible_layers_names = ["layers", "h"]
    model_layers = find_inner_attribute(inner_model, possible_layers_names, nn.ModuleList)

    atten_attr_name_pattern_list = ["attention", "self_attn"]
    # Excluded so layer norms such as "post_attention_layernorm" are skipped.
    atten_attr_name_pattern_exclude = ["norm", "layer"]

    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(
            decoder_layer,
            atten_attr_name_pattern_list,
            atten_attr_name_pattern_exclude,
            nn.Module,
        )
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers