Update app.py
Browse files
app.py
CHANGED
|
@@ -174,14 +174,16 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
|
|
| 174 |
# model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
|
| 175 |
# model_sketch.eval()
|
| 176 |
|
| 177 |
-
|
| 178 |
global pipeline
|
| 179 |
global MultiResNetModel
|
| 180 |
global cur_style
|
| 181 |
-
cur_style = 'line + shadow'
|
| 182 |
|
| 183 |
@spaces.GPU
|
| 184 |
def load_ckpt():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
weight_dtype = torch.float16
|
| 186 |
|
| 187 |
block_out_channels = [128, 128, 256, 512, 512]
|
|
@@ -291,10 +293,8 @@ def load_ckpt():
|
|
| 291 |
|
| 292 |
print('loaded pipeline')
|
| 293 |
|
| 294 |
-
return pipeline, MultiResNetModel
|
| 295 |
-
|
| 296 |
|
| 297 |
-
|
| 298 |
|
| 299 |
@spaces.GPU
|
| 300 |
def change_ckpt(style):
|
|
@@ -311,6 +311,10 @@ def change_ckpt(style):
|
|
| 311 |
else:
|
| 312 |
raise ValueError("Invalid style: {}".format(style))
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
cur_style = style
|
| 315 |
|
| 316 |
MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
|
|
@@ -349,6 +353,7 @@ def process_multi_images(files):
|
|
| 349 |
|
| 350 |
@spaces.GPU
|
| 351 |
def extract_lines(image):
|
|
|
|
| 352 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
| 353 |
|
| 354 |
rows = int(np.ceil(src.shape[0] / 16)) * 16
|
|
@@ -384,6 +389,7 @@ def extract_line_image(query_image_, resolution):
|
|
| 384 |
|
| 385 |
@spaces.GPU
|
| 386 |
def extract_sketch_line_image(query_image_, input_style):
|
|
|
|
| 387 |
if input_style != cur_style:
|
| 388 |
change_ckpt(input_style)
|
| 389 |
|
|
@@ -425,6 +431,10 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
|
|
| 425 |
reference_images = process_multi_images(reference_images)
|
| 426 |
fix_random_seeds(seed)
|
| 427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
tar_width, tar_height = resolution
|
| 429 |
|
| 430 |
gr.Info("Image retrieval in progress...")
|
|
|
|
| 174 |
# model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
|
| 175 |
# model_sketch.eval()
|
| 176 |
|
|
|
|
| 177 |
global pipeline
|
| 178 |
global MultiResNetModel
|
| 179 |
global cur_style
|
|
|
|
| 180 |
|
| 181 |
@spaces.GPU
|
| 182 |
def load_ckpt():
|
| 183 |
+
global pipeline
|
| 184 |
+
global MultiResNetModel
|
| 185 |
+
global cur_style
|
| 186 |
+
cur_style = 'line + shadow'
|
| 187 |
weight_dtype = torch.float16
|
| 188 |
|
| 189 |
block_out_channels = [128, 128, 256, 512, 512]
|
|
|
|
| 293 |
|
| 294 |
print('loaded pipeline')
|
| 295 |
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
load_ckpt()
|
| 298 |
|
| 299 |
@spaces.GPU
|
| 300 |
def change_ckpt(style):
|
|
|
|
| 311 |
else:
|
| 312 |
raise ValueError("Invalid style: {}".format(style))
|
| 313 |
|
| 314 |
+
global pipeline
|
| 315 |
+
global MultiResNetModel
|
| 316 |
+
global cur_style
|
| 317 |
+
|
| 318 |
cur_style = style
|
| 319 |
|
| 320 |
MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
|
|
|
|
| 353 |
|
| 354 |
@spaces.GPU
|
| 355 |
def extract_lines(image):
|
| 356 |
+
global line_model
|
| 357 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
| 358 |
|
| 359 |
rows = int(np.ceil(src.shape[0] / 16)) * 16
|
|
|
|
| 389 |
|
| 390 |
@spaces.GPU
|
| 391 |
def extract_sketch_line_image(query_image_, input_style):
|
| 392 |
+
global cur_style
|
| 393 |
if input_style != cur_style:
|
| 394 |
change_ckpt(input_style)
|
| 395 |
|
|
|
|
| 431 |
reference_images = process_multi_images(reference_images)
|
| 432 |
fix_random_seeds(seed)
|
| 433 |
|
| 434 |
+
global pipeline
|
| 435 |
+
global MultiResNetModel
|
| 436 |
+
global cur_style
|
| 437 |
+
|
| 438 |
tar_width, tar_height = resolution
|
| 439 |
|
| 440 |
gr.Info("Image retrieval in progress...")
|