import os
import uuid

import numpy as np
import torch
from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableDiffusionPipeline,
)
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
)

from swarms.models.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
from swarms.tools.base import tool
from swarms.tools.main import BaseToolSet
from swarms.utils.logger import logger
from swarms.utils.main import BaseHandler, get_new_image_name


class MaskFormer(BaseToolSet):
    def __init__(self, device):
        print("Initializing MaskFormer to %s" % device)
        self.device = device
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(device)

    def inference(self, image_path, text):
        threshold = 0.5
        min_area = 0.02
        padding = 20
        original_image = Image.open(image_path)
        image = original_image.resize((512, 512))
        inputs = self.processor(
            text=text, images=image, padding="max_length", return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
        # Discard masks that cover too little of the image to be meaningful.
        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
        if area_ratio < min_area:
            return None
        # Dilate the mask by `padding` pixels around every positive pixel.
        true_indices = np.argwhere(mask)
        mask_array = np.zeros_like(mask, dtype=bool)
        for idx in true_indices:
            padded_slice = tuple(
                slice(max(0, i - padding), i + padding + 1) for i in idx
            )
            mask_array[padded_slice] = True
        visual_mask = (mask_array * 255).astype(np.uint8)
        image_mask = Image.fromarray(visual_mask)
        return image_mask.resize(original_image.size)
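
# Example (illustrative sketch, not part of the module's API): generating a mask
# for a text query. The device string and file paths below are assumptions for
# demonstration only.
#
#     mask_former = MaskFormer(device="cuda:0")
#     mask = mask_former.inference("image/cat.png", "the cat")  # PIL.Image or None
#     if mask is not None:
#         mask.save("image/cat_mask.png")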


class ImageEditing(BaseToolSet):
    def __init__(self, device):
        print("Initializing ImageEditing to %s" % device)
        self.device = device
        self.mask_former = MaskFormer(device=self.device)
        self.revision = "fp16" if "cuda" in device else None
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision=self.revision,
            torch_dtype=self.torch_dtype,
        ).to(device)

    def inference_remove(self, inputs):
        image_path, to_be_removed_txt = inputs.split(",")
        return self.inference_replace(f"{image_path},{to_be_removed_txt},background")

    def inference_replace(self, inputs):
        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
        original_image = Image.open(image_path)
        original_size = original_image.size
        mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
        if mask_image is None:
            # MaskFormer found no region matching the text, so there is nothing to inpaint.
            logger.debug(
                f"\nImageEditing found no region matching {to_be_replaced_txt} in {image_path}"
            )
            return image_path
        updated_image = self.inpaint(
            prompt=replace_with_txt,
            image=original_image.resize((512, 512)),
            mask_image=mask_image.resize((512, 512)),
        ).images[0]
        updated_image_path = get_new_image_name(
            image_path, func_name="replace-something"
        )
        updated_image = updated_image.resize(original_size)
        updated_image.save(updated_image_path)
        logger.debug(
            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} with {replace_with_txt}, "
            f"Output Image: {updated_image_path}"
        )
        return updated_image_path
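
# Example (illustrative sketch): inputs are comma-separated strings, as the agent
# tooling passes them. The device string and file path below are placeholder
# assumptions.
#
#     editor = ImageEditing(device="cuda:0")
#     new_path = editor.inference_replace("image/cat.png,cat,dog")  # replace the cat with a dog
#     new_path = editor.inference_remove("image/cat.png,cat")       # inpaint the cat away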


class InstructPix2Pix(BaseToolSet):
    def __init__(self, device):
        print("Initializing InstructPix2Pix to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        ).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def inference(self, inputs):
        """Change style of image."""
        logger.debug("===> Starting InstructPix2Pix Inference")
        image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
        original_image = Image.open(image_path)
        image = self.pipe(
            text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2
        ).images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
        image.save(updated_image_path)
        logger.debug(
            f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
            f"Output Image: {updated_image_path}"
        )
        return updated_image_path
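
# Example (illustrative sketch): everything after the first comma is treated as
# the edit instruction, so the instruction itself may contain commas. Paths and
# device are assumptions for demonstration.
#
#     pix2pix = InstructPix2Pix(device="cuda:0")
#     new_path = pix2pix.inference("image/cat.png,make it look like a watercolor painting")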


class Text2Image(BaseToolSet):
    def __init__(self, device):
        print("Initializing Text2Image to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype
        )
        self.pipe.to(device)
        self.a_prompt = "best quality, extremely detailed"
        self.n_prompt = (
            "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, "
            "fewer digits, cropped, worst quality, low quality"
        )

    def inference(self, text):
        # Make sure the output directory exists before saving the generated image.
        os.makedirs("image", exist_ok=True)
        image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
        prompt = text + ", " + self.a_prompt
        image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
        image.save(image_filename)
        logger.debug(
            f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}"
        )
        return image_filename
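
# Example (illustrative sketch): generates a PNG under the local "image/" directory
# and returns its path. The device string is a placeholder assumption.
#
#     text2image = Text2Image(device="cuda:0")
#     path = text2image.inference("a red bicycle leaning against a brick wall")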


class VisualQuestionAnswering(BaseToolSet):
    def __init__(self, device):
        print("Initializing VisualQuestionAnswering to %s" % device)
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def inference(self, inputs):
        # Split on the first comma only, so questions containing commas stay intact.
        image_path, question = inputs.split(",", 1)
        raw_image = Image.open(image_path).convert("RGB")
        inputs = self.processor(raw_image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        logger.debug(
            f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
            f"Output Answer: {answer}"
        )
        return answer
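
# Example (illustrative sketch): the question follows the image path after the
# first comma. Paths and device are assumptions for demonstration.
#
#     vqa = VisualQuestionAnswering(device="cuda:0")
#     answer = vqa.inference("image/cat.png,what color is the cat?")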


class ImageCaptioning(BaseHandler):
    def __init__(self, device):
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def handle(self, filename: str):
        # Rescale the image so its longer side is 512 pixels, then re-save it in place as PNG.
        img = Image.open(filename)
        width, height = img.size
        ratio = min(512 / width, 512 / height)
        width_new, height_new = (round(width * ratio), round(height * ratio))
        img = img.resize((width_new, height_new))
        img = img.convert("RGB")
        img.save(filename, "PNG")
        print(f"Resized image from {width}x{height} to {width_new}x{height_new}")

        inputs = self.processor(Image.open(filename), return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        description = self.processor.decode(out[0], skip_special_tokens=True)
        print(
            f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text: {description}"
        )
        return IMAGE_PROMPT.format(filename=filename, description=description)
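
# Example (illustrative sketch): ImageCaptioning is a BaseHandler, so it receives a
# filename and returns a prompt string built from IMAGE_PROMPT. Note that handle()
# resizes and re-saves the file in place. Paths and device are placeholder assumptions.
#
#     captioner = ImageCaptioning(device="cuda:0")
#     prompt = captioner.handle("image/cat.png")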