Spaces:
Running
Running
Update hf_merge.py
Browse files- hf_merge.py +4 -1
hf_merge.py
CHANGED
|
@@ -5,6 +5,7 @@ import requests
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
from huggingface_hub import HfApi, hf_hub_download
|
| 7 |
from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
|
|
|
|
| 8 |
import logging
|
| 9 |
|
| 10 |
# Set up logging
|
|
@@ -92,8 +93,10 @@ class ModelMerger:
|
|
| 92 |
repo_manager.delete_repo(repo_path)
|
| 93 |
repo_manager.download_repo(repo_name, repo_path)
|
| 94 |
|
|
|
|
|
|
|
| 95 |
try:
|
| 96 |
-
self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
|
| 97 |
logging.info(f"Merged {repo_name}")
|
| 98 |
except Exception as e:
|
| 99 |
logging.error(f"Error merging {repo_name}: {e}")
|
|
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
from huggingface_hub import HfApi, hf_hub_download
|
| 7 |
from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
|
| 8 |
+
import torch
|
| 9 |
import logging
|
| 10 |
|
| 11 |
# Set up logging
|
|
|
|
| 93 |
repo_manager.delete_repo(repo_path)
|
| 94 |
repo_manager.download_repo(repo_name, repo_path)
|
| 95 |
|
| 96 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 97 |
+
|
| 98 |
try:
|
| 99 |
+
self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val, device)
|
| 100 |
logging.info(f"Merged {repo_name}")
|
| 101 |
except Exception as e:
|
| 102 |
logging.error(f"Error merging {repo_name}: {e}")
|