# sep_cache/custom_generate_split_4_backup/monkey_patching_utils.py
import torch
import inspect
import importlib
import transformers
import types
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, Optional, Union, Any, List
from .functions_2_patch import _validate_model_kwargs, llama_atten_forward


def get_full_class_import_path(obj):
"""Get the complete class import path of an object"""
# Get the class of the object
cls = obj.__class__
# Get the module name where the class is defined
module = cls.__module__
# Get the qualified name of the class (including outer classes)
qualname = cls.__qualname__
# Handle nested classes (e.g., ClassA.ClassB)
if '.' in qualname:
# Replace nested class separators
class_path = f"{module}.{qualname.replace('.', '_')}"
else:
class_path = f"{module}.{qualname}"
    return class_path


def get_importable_class_path(obj):
"""Get the directly importable class path (handling special cases and dynamic classes)"""
cls = obj.__class__
module = cls.__module__
qualname = cls.__qualname__
# Handle built-in types
if module == 'builtins':
return qualname
# Handle dynamically generated classes (e.g., functools.partial)
    if module is None:
return f"<dynamic class {qualname}>"
# Handle nested classes
if '.' in qualname:
# Try to import the parent module to validate the path
        try:
            parent_module = importlib.import_module(module)
# Follow the qualified name path
parts = qualname.split('.')
current = parent_module
for part in parts:
current = getattr(current, part)
# If successful access, return the original path
return f"{module}.{qualname}"
except (ImportError, AttributeError):
# Fallback: use underscore connection
return f"{module}.{qualname.replace('.', '_')}"
return f"{module}.{qualname}"
def monkey_patch_by_class_path(model, new_forward):
    """Monkey-patch `forward` on the class of `model`, resolved via its importable class path."""
# Get the complete class path
class_path = get_importable_class_path(model)
# Dynamically import the class
    try:
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
target_class = getattr(module, class_name)
# Save the original method
if not hasattr(target_class, '_original_forward'):
target_class._original_forward = target_class.forward
# Apply the patch
target_class.forward = new_forward
# Update the method binding of the current instance
model.forward = types.MethodType(target_class.forward, model)
return f"Successful Monkey Patch: {class_path}.forward"
except (ImportError, AttributeError, ValueError) as e:
return f"Patch Failed: {str(e)}"
def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    """Return the first attribute whose name is in `attr_name_list`, falling back to the first attribute of type `default_type`."""
    # First, try to find an attribute by one of the names in `attr_name_list`.
for target_attr_name in attr_name_list:
if hasattr(obj, target_attr_name):
return getattr(obj, target_attr_name)
    # Otherwise, fall back to the first attribute whose type is `default_type`.
for attr_name in dir(obj):
attr_value = getattr(obj, attr_name)
if isinstance(attr_value, default_type):
return attr_value
raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")
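# Example for `find_inner_attribute` (assumption: a Llama-style causal LM,
# which exposes its base model as `.model`; GPT-2 exposes `.transformer` and
# GPT-NeoX exposes `.gpt_neox`):
#
#     inner = find_inner_attribute(model, ["model", "transformer", "gpt_neox"])
#     # -> the wrapped base `PreTrainedModel` (e.g. `LlamaModel`)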
def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    """Return the first attribute of type `match_type` whose class name or attribute name
    contains any pattern in `name_pattern_list` and none of the patterns in `exclude_pattern_list`.
    """
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if not isinstance(attr_value, match_type):
            continue
        class_name = attr_value.__class__.__name__.lower()
        lowered_attr_name = attr_name.lower()
        for pattern in name_pattern_list:
            pattern = pattern.lower()
            # Match on the class name or the attribute name, rejecting the attribute
            # if any exclude pattern appears in the matched name (e.g. layernorm modules).
            if pattern in class_name and not any(ex.lower() in class_name for ex in exclude_pattern_list):
                return attr_value
            if pattern in lowered_attr_name and not any(ex.lower() in lowered_attr_name for ex in exclude_pattern_list):
                return attr_value
    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} while containing no pattern from {exclude_pattern_list}, and whose type is {match_type}.")
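# Example for `find_attribute_name` (assumption: a Llama-style decoder layer,
# whose attention module is stored as `self_attn` and whose layernorm
# submodules must be skipped):
#
#     attn = find_attribute_name(decoder_layer,
#                                ["attention", "self_attn"],
#                                ["norm", "layer"], nn.Module)
#     # -> the layer's attention module (e.g. `LlamaAttention`)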
def monkey_patching(model_obj, model_atten_forward, verbose=True):
    """Patch the attention forward of every decoder layer in `model_obj` with `model_atten_forward`."""
    # Replace GenerationMixin's kwarg validation with the patched version.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs
    # Locate the inner base model (e.g. `model`, `transformer`, or `gpt_neox`).
    possible_inner_model_names = ["model", "transformer", "gpt_neox"]
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)
    # Locate the stack of decoder layers (`layers` for Llama-style models, `h` for GPT-style models).
    possible_layers_names = ["layers", "h"]
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)
    # Patterns used to find each layer's attention module while skipping layernorm modules.
    atten_attr_name_pattern_list = ["attention", "self_attn"]
    atten_attr_name_pattern_exclude = ["norm", "layer"]
for i, decoder_layer in enumerate(model_layers):
self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
if verbose:
decoder_class_name = get_importable_class_path(decoder_layer)
print(f"For Layer {i}'s `{decoder_class_name}`: {result}")
return model_layers
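# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the checkpoint name is a placeholder and
# `llama_atten_forward` is the patched forward imported at the top of this
# file from `.functions_2_patch`):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     patched_layers = monkey_patching(model, llama_atten_forward, verbose=True)
#
# After patching, every decoder layer's attention class uses the new forward,
# and the original implementation remains reachable via `_original_forward`.
# ---------------------------------------------------------------------------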