import torch
import inspect
import importlib
import transformers
import types
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, Optional, Union, Any, List
from .functions_2_patch import _validate_model_kwargs, llama_atten_forward


def get_full_class_import_path(obj):
    """Get the complete class import path of an object."""
    # Get the class of the object
    cls = obj.__class__
    # Get the module name where the class is defined
    module = cls.__module__
    # Get the qualified name of the class (including outer classes)
    qualname = cls.__qualname__
    # Handle nested classes (e.g., ClassA.ClassB)
    if '.' in qualname:
        # Replace nested class separators
        class_path = f"{module}.{qualname.replace('.', '_')}"
    else:
        class_path = f"{module}.{qualname}"
    return class_path


def get_importable_class_path(obj):
    """Get the directly importable class path (handling special cases and dynamic classes)."""
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__
    # Handle built-in types
    if module == 'builtins':
        return qualname
    # Handle dynamically generated classes (e.g., functools.partial)
    if not hasattr(cls, '__module__') or module is None:
        return f"<dynamic class {qualname}>"
    # Handle nested classes
    if '.' in qualname:
        # Try to import the parent module to validate the path
        try:
            parent_module = importlib.import_module(module)
            # Follow the qualified name path
            parts = qualname.split('.')
            current = parent_module
            for part in parts:
                current = getattr(current, part)
            # The attributes resolve, so the original dotted path is importable
            return f"{module}.{qualname}"
        except (ImportError, AttributeError):
            # Fallback: join nested class names with underscores
            return f"{module}.{qualname.replace('.', '_')}"
    return f"{module}.{qualname}"
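
# A minimal sketch of what `get_importable_class_path` typically returns, assuming a
# Hugging Face LLaMA-style model is loaded (the checkpoint name below is an illustrative
# assumption, not part of this module):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     attn = model.model.layers[0].self_attn
#     get_importable_class_path(attn)
#     # e.g. "transformers.models.llama.modeling_llama.LlamaAttention"
#     # (the exact class may be an SDPA/FlashAttention variant depending on the
#     #  installed transformers version and attention implementation)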


def monkey_patch_by_class_path(model, new_forward):
    """Perform monkey patching through the class import path."""
    # Get the complete class path
    class_path = get_importable_class_path(model)
    # Dynamically import the class
    try:
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        target_class = getattr(module, class_name)
        # Save the original method so it can be restored later
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward
        # Apply the patch at the class level
        target_class.forward = new_forward
        # Rebind the patched method to the current instance
        model.forward = types.MethodType(target_class.forward, model)
        return f"Successful Monkey Patch: {class_path}.forward"
    except (ImportError, AttributeError, ValueError) as e:
        return f"Patch Failed: {str(e)}"


def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    # First, try to find an attribute whose name is in `attr_name_list`.
    for target_attr_name in attr_name_list:
        if hasattr(obj, target_attr_name):
            return getattr(obj, target_attr_name)
    # Otherwise, fall back to the first attribute of type `default_type`.
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if isinstance(attr_value, default_type):
            return attr_value
    raise AttributeError(
        f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} "
        f"or whose type is {default_type}."
    )


def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    # Return the first attribute of type `match_type` whose class name or attribute name
    # contains one of `name_pattern_list` and none of `exclude_pattern_list`.
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if not isinstance(attr_value, match_type):
            continue
        class_name = attr_value.__class__.__name__.lower()
        for pattern in name_pattern_list:
            if pattern.lower() in class_name and not any(ex.lower() in class_name for ex in exclude_pattern_list):
                return attr_value
            if pattern.lower() in attr_name.lower() and not any(ex.lower() in attr_name.lower() for ex in exclude_pattern_list):
                return attr_value
    raise AttributeError(
        f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list}, "
        f"excludes every pattern in {exclude_pattern_list}, and whose type is {match_type}."
    )
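
# Illustrative behaviour (assuming a LLaMA-style decoder layer): with include patterns
# ["attention", "self_attn"] and exclude patterns ["norm", "layer"], the search returns
# `decoder_layer.self_attn` while skipping `input_layernorm` and `post_attention_layernorm`,
# which are also nn.Module attributes but match an exclude pattern.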


def monkey_patching(model_obj, model_atten_forward, verbose=True):
    # Relax generation kwargs validation so the patched forward can accept extra arguments.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs
    # Get the inner base model (named `model`, `transformer`, or `gpt_neox` depending on the architecture).
    possible_inner_model_names = ["model", "transformer", "gpt_neox"]
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)
    # Locate the decoder layer stack.
    possible_layers_names = ["layers", "h"]
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)
    # Patch the self-attention module of every decoder layer.
    atten_attr_name_pattern_list = ["attention", "self_attn"]
    atten_attr_name_pattern_exclude = ["norm", "layer"]
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")
    return model_layers
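

# A minimal usage sketch, assuming this module lives in a package (so the relative
# import above resolves) and that `llama_atten_forward` matches the attention module's
# forward signature. The checkpoint name is an assumption for illustration.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    # Load any causal LM whose decoder layers expose a self-attention module.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
    # Swap in the patched attention forward on every decoder layer and print the results.
    monkey_patching(model, llama_atten_forward, verbose=True)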