alessandro trinca tornidor
commited on
Commit
·
eec88db
1
Parent(s):
7ad428c
feat: adding explicit gpu init in get_model()
Browse files
lisa_on_cuda/utils/app_helpers.py
CHANGED
|
@@ -169,6 +169,12 @@ def load_model_for_causal_llm_pretrained(
|
|
| 169 |
return _model
|
| 170 |
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
|
| 173 |
"""Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
|
| 174 |
|
|
@@ -186,6 +192,10 @@ def get_model(args_to_parse, internal_logger: logging = None, inference_decorato
|
|
| 186 |
if internal_logger is None:
|
| 187 |
internal_logger = app_logger
|
| 188 |
internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
try:
|
| 190 |
vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
|
| 191 |
logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
|
|
|
|
| 169 |
return _model
|
| 170 |
|
| 171 |
|
| 172 |
+
def gpu_init_zero(internal_logger: logging = None):
|
| 173 |
+
if internal_logger is None:
|
| 174 |
+
internal_logger = app_logger
|
| 175 |
+
internal_logger.info("GPU init...")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None, device_map="auto", device="cpu", device2="cuda"):
|
| 179 |
"""Load model and inference function with arguments. Compatible with ZeroGPU (spaces 0.30.2)
|
| 180 |
|
|
|
|
| 192 |
if internal_logger is None:
|
| 193 |
internal_logger = app_logger
|
| 194 |
internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
|
| 195 |
+
if inference_decorator:
|
| 196 |
+
internal_logger.info(f"try explicit gpu init with decorator {inference_decorator.__name__}...")
|
| 197 |
+
inference_decorator(gpu_init_zero(internal_logger=internal_logger))
|
| 198 |
+
internal_logger.info(f"gpu explicitly initialized!")
|
| 199 |
try:
|
| 200 |
vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
|
| 201 |
logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
|