import importlib
import types

import torch.nn as nn
import transformers
from transformers.modeling_utils import PreTrainedModel
from typing import List

from .functions_2_patch import _validate_model_kwargs, llama_atten_forward


def get_full_class_import_path(obj):
    """Get the complete class import path of an object"""
    # Get the class of the object
    cls = obj.__class__
    
    # Get the module name where the class is defined
    module = cls.__module__
    
    # Get the qualified name of the class (including outer classes)
    qualname = cls.__qualname__
    
    # Handle nested classes (e.g., ClassA.ClassB); note that the
    # underscore-joined form is a readable label, not a directly importable path
    if '.' in qualname:
        class_path = f"{module}.{qualname.replace('.', '_')}"
    else:
        class_path = f"{module}.{qualname}"
    
    return class_path
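
# Illustrative example: for an instance of a hypothetical nested class
# `Outer.Inner` defined in a module `pkg.mod`, get_full_class_import_path
# returns 'pkg.mod.Outer_Inner'; a top-level class returns e.g. 'pkg.mod.Outer'.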


def get_importable_class_path(obj):
    """Get the directly importable class path (handling special cases and dynamic classes)"""
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__
    
    # Handle built-in types
    if module == 'builtins':
        return qualname

    # Handle classes whose defining module cannot be determined
    # (e.g., some dynamically generated classes)
    if module is None:
        return f"<dynamic class {qualname}>"

    # Handle nested classes
    if '.' in qualname:
        # Try to import the parent module to validate the path
        try:
            parent_module = importlib.import_module(module)
            
            # Follow the qualified name path
            parts = qualname.split('.')
            current = parent_module
            for part in parts:
                current = getattr(current, part)
            
            # If successful access, return the original path
            return f"{module}.{qualname}"
        except (ImportError, AttributeError):
            # Fallback: use underscore connection
            return f"{module}.{qualname.replace('.', '_')}"
    
    return f"{module}.{qualname}"



def monkey_patch_by_class_path(model, new_forward):
    """Perform monkey patching through class path"""
    # Get the complete class path
    class_path = get_importable_class_path(model)
    
    # Dynamically import the class; note that patching the class-level
    # `forward` affects every instance of that class, not just `model`
    try:
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        target_class = getattr(module, class_name)
        
        # Save the original method
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward
        
        # Apply the patch
        target_class.forward = new_forward
        
        # Update the method binding of the current instance
        model.forward = types.MethodType(target_class.forward, model)
        
        return f"Successful Monkey Patch: {class_path}.forward"
    
    except (ImportError, AttributeError, ValueError) as e:
        return f"Patch Failed: {str(e)}"




def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    """Return the first attribute of `obj` whose name appears in
    `attr_name_list`; otherwise fall back to the first attribute whose
    type is `default_type`."""
    for target_attr_name in attr_name_list:
        if hasattr(obj, target_attr_name):
            return getattr(obj, target_attr_name)

    # Fallback: scan all attributes for one of the expected type
    for attr_name in dir(obj):
        try:
            attr_value = getattr(obj, attr_name)
        except AttributeError:
            continue
        if isinstance(attr_value, default_type):
            return attr_value

    raise AttributeError(
        f"In the {obj.__class__.__name__} object, there is no attribute whose "
        f"name matches any name in {attr_name_list} or whose type is {default_type}."
    )


def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    """Return the first attribute of `obj` of type `match_type` whose class name
    or attribute name contains one of `name_pattern_list` and none of
    `exclude_pattern_list` (matching is case-insensitive)."""
    excludes = [p.lower() for p in exclude_pattern_list]
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if not isinstance(attr_value, match_type):
            continue
        for candidate in (attr_value.__class__.__name__.lower(), attr_name.lower()):
            # A candidate matches only if some name pattern hits and *no* exclude pattern does
            if any(p.lower() in candidate for p in name_pattern_list) and all(ex not in candidate for ex in excludes):
                return attr_value

    raise AttributeError(
        f"In the {obj.__class__.__name__} object, there is no attribute of type {match_type} "
        f"whose name matches any pattern in {name_pattern_list} and excludes every "
        f"pattern in {exclude_pattern_list}."
    )
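
# Illustrative lookups (assuming a loaded LlamaForCausalLM named `model`):
#   inner  = find_inner_attribute(model, ["model", "transformer"])          # -> model.model
#   layers = find_inner_attribute(inner, ["layers", "h"], nn.ModuleList)    # -> inner.layers
#   attn   = find_attribute_name(layers[0], ["attention", "self_attn"],
#                                ["norm", "layer"])                         # -> layers[0].self_attn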



def monkey_patching(model_obj, model_atten_forward, verbose=True):
    """Find every decoder layer's attention module in `model_obj` and patch its
    class-level `forward` with `model_atten_forward`."""
    # Replace GenerationMixin._validate_model_kwargs (presumably so that extra
    # kwargs consumed by the patched attention forward pass validation)
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## get the inner base model (e.g., `.model` for Llama, `.transformer` for GPT-2)
    possible_inner_model_names = ["model", "transformer", "gpt_neox"]
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    ## get the decoder layer stack (e.g., `.layers` for Llama, `.h` for GPT-2)
    possible_layers_names = ["layers", "h"]
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    ## patterns that identify the attention submodule while excluding names
    ## such as `input_layernorm` or the decoder layer itself
    atten_attr_name_pattern_list = ["attention", "self_attn"]
    atten_attr_name_pattern_exclude = ["norm", "layer"]

    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
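
# Minimal end-to-end sketch (illustrative; assumes a Llama-style checkpoint and
# the `llama_atten_forward` replacement from `functions_2_patch`):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#   monkey_patching(model, llama_atten_forward, verbose=True)
#   # For Layer 0's `transformers.models.llama.modeling_llama.LlamaDecoderLayer`: Successful Monkey Patch: ...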