| import torch.nn as nn | |
| from typing import List | |
| from mmengine.logging import MMLogger | |
# One-shot logging guards: set_requires_grad() and set_train() only write their
# selection to the logger while the matching flag is still True, and clear it
# afterwards so repeated calls stay quiet.
first_set_requires_grad: bool = True
first_set_train: bool = True
def set_requires_grad(model: nn.Module, keywords: List[str]):
    """Make only the parameters whose name contains a keyword trainable.

    Every parameter yielded by ``model.named_parameters()`` gets
    ``requires_grad = True`` when at least one entry of ``keywords`` occurs
    anywhere in its name (substring test, ``key in name``) and
    ``requires_grad = False`` otherwise.

    While the module-level flag ``first_set_requires_grad`` is true, the
    selected parameter names and a summary of trainable vs. total element
    counts are written to the current ``MMLogger`` instance; the flag is then
    cleared so later calls do not log again.

    Args:
        model: Module whose parameters are updated in place.
        keywords: Substrings to look for in parameter names. Pass a list; a
            bare string would be iterated character by character.
    """
    global first_set_requires_grad

    requires_grad_names = []
    num_params = 0
    num_trainable = 0
    for name, param in model.named_parameters():
        numel = param.numel()
        num_params += numel
        trainable = any(key in name for key in keywords)
        param.requires_grad = trainable
        if trainable:
            requires_grad_names.append(name)
            num_trainable += numel

    if not first_set_requires_grad:
        return

    logger = MMLogger.get_current_instance()
    for name in requires_grad_names:
        logger.info(f"set_requires_grad----{name}")
    # A model without any parameters would otherwise raise ZeroDivisionError
    # while formatting the summary line.
    ratio = num_trainable * 100 / num_params if num_params else 0.0
    logger.info(
        f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{ratio:.1f}%"
    )
    first_set_requires_grad = False
def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
    """Recursively switch children whose name starts with a keyword to train mode.

    Walks ``model.named_children()``. A child whose own (undotted) name starts
    with any entry of ``keywords`` has ``train()`` called on it and is not
    descended into further; every other child is searched recursively.

    Args:
        model: Module whose direct children are inspected.
        keywords: Name prefixes that select a child.
        prefix: Dotted path of ``model`` itself, used only to build the
            returned names. Empty for the root call.

    Returns:
        The dotted names of all children on which ``train()`` was called, in
        traversal order.
    """
    # str.startswith accepts a tuple of candidates, replacing the manual any().
    name_prefixes = tuple(keywords)
    train_names = []
    for name, child in model.named_children():
        # Skip the separator at the root so names read "backbone.layer1"
        # rather than ".backbone.layer1".
        fullname = f"{prefix}.{name}" if prefix else name
        if name.startswith(name_prefixes):
            train_names.append(fullname)
            child.train()
        else:
            train_names.extend(_set_train(child, keywords, prefix=fullname))
    return train_names
def set_train(model: nn.Module, keywords: List[str]):
    """Put the whole model in eval mode, then re-enable training on selected children.

    ``model.train(False)`` is applied first; afterwards ``_set_train`` calls
    ``train()`` on every sub-module whose own name starts with one of
    ``keywords`` (prefix match on the child name, not on the dotted path).

    While the module-level flag ``first_set_train`` is true, the names
    returned by ``_set_train`` are written to the current ``MMLogger``
    instance and the flag is cleared, so later calls do not log again.

    Args:
        model: Module whose training mode is updated in place.
        keywords: Child-name prefixes selecting the sub-modules to train.
    """
    global first_set_train

    model.train(False)
    selected = _set_train(model, keywords)

    if not first_set_train:
        return

    log = MMLogger.get_current_instance()
    for entry in selected:
        log.info(f"set_train----{entry}")
    first_set_train = False