feat: only apply select_task_for_layer if task has changed
Browse files- modeling_lora.py +5 -4
modeling_lora.py
CHANGED
|
@@ -265,10 +265,11 @@ class BertLoRA(BertPreTrainedModel):
|
|
| 265 |
@current_task.setter
|
| 266 |
def current_task(self, task_idx: Union[None, int]):
|
| 267 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
| 268 |
-
self._task_idx
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
|
| 273 |
def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
|
| 274 |
if current_task is None or current_task >= 0:
|
|
|
|
| 265 |
@current_task.setter
|
| 266 |
def current_task(self, task_idx: Union[None, int]):
|
| 267 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
| 268 |
+
if self._task_idx != task_idx
|
| 269 |
+
self._task_idx = task_idx
|
| 270 |
+
self.apply(
|
| 271 |
+
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
| 272 |
+
)
|
| 273 |
|
| 274 |
def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
|
| 275 |
if current_task is None or current_task >= 0:
|