feat: use property in LoRA parametrization
Browse files- modeling_lora.py +8 -3
modeling_lora.py
CHANGED
|
@@ -116,8 +116,13 @@ class LoRAParametrization(nn.Module):
|
|
| 116 |
def forward(self, X):
|
| 117 |
return self.forward_fn(X)
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
if task is None:
|
| 122 |
self.forward_fn = lambda x: x
|
| 123 |
else:
|
|
@@ -192,7 +197,7 @@ class LoRAParametrization(nn.Module):
|
|
| 192 |
@classmethod
|
| 193 |
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
| 194 |
if isinstance(layer, LoRAParametrization):
|
| 195 |
-
layer.
|
| 196 |
|
| 197 |
|
| 198 |
class BertLoRA(BertPreTrainedModel):
|
|
|
|
| 116 |
def forward(self, X):
|
| 117 |
return self.forward_fn(X)
|
| 118 |
|
| 119 |
+
@property
|
| 120 |
+
def current_task(self):
|
| 121 |
+
return self._current_task
|
| 122 |
+
|
| 123 |
+
@current_task.setter
|
| 124 |
+
def current_task(self, task: Union[None, int]):
|
| 125 |
+
self._current_task = task
|
| 126 |
if task is None:
|
| 127 |
self.forward_fn = lambda x: x
|
| 128 |
else:
|
|
|
|
| 197 |
@classmethod
|
| 198 |
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
| 199 |
if isinstance(layer, LoRAParametrization):
|
| 200 |
+
layer.current_task = task_idx
|
| 201 |
|
| 202 |
|
| 203 |
class BertLoRA(BertPreTrainedModel):
|