Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import tensorflow as tf | |
# On TensorFlow 2.x, rebind `tf` to the TF1-style API (Session, ConfigProto,
# GraphDef, gfile, ...) that ImageMattingPipeline relies on.
# Compare the integer major version: a plain string comparison such as
# '10.0' >= '2.0' is lexicographic and would wrongly evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
class ImageMattingPipeline:
    """Run a frozen TensorFlow graph that produces an image-matting output.

    The graph is loaded from ``<model_dir>/tf_graph.pb`` into a dedicated
    ``tf.Graph`` and executed through a ``tf.Session`` owned by the instance.
    """

    def __init__(self, model_dir: str, input_name: str = 'input_image:0', output_name: str = 'output_png:0'):
        """Load the frozen graph and open a session bound to it.

        Args:
            model_dir: directory expected to contain ``tf_graph.pb``.
            input_name: name of the graph tensor fed with the preprocessed image.
            output_name: name of the graph tensor fetched by ``forward``.

        Raises:
            FileNotFoundError: if ``tf_graph.pb`` does not exist in ``model_dir``.
        """
        graph_file = os.path.join(model_dir, 'tf_graph.pb')
        if not os.path.exists(graph_file):
            raise FileNotFoundError("Model file not found at {}".format(graph_file))

        session_config = tf.ConfigProto(allow_soft_placement=True)
        # Let TF grow GPU memory on demand rather than reserving it all up front.
        session_config.gpu_options.allow_growth = True

        self.graph = tf.Graph()
        with self.graph.as_default():
            # Created inside the context so the session runs self.graph.
            self._session = tf.Session(config=session_config)
            with tf.gfile.FastGFile(graph_file, 'rb') as handle:
                serialized_graph = handle.read()
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(graph_def, name='')
            self.output = self._session.graph.get_tensor_by_name(output_name)
        self.input_name = input_name

    def preprocess(self, input_image):
        """Build the feed payload for ``forward``.

        The image is converted to an ndarray, its channels are swapped with
        ``cv2.COLOR_RGB2BGR``, and the pixels are cast to Python ``float``.
        Returns ``{'img': <float array>}``.
        """
        swapped = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
        return {'img': swapped.astype(float)}

    def forward(self, input, output_mask=False, alpha_threshold=128):
        """Run the graph on a preprocessed image.

        Args:
            input: dict produced by ``preprocess`` (reads the ``'img'`` key).
            output_mask: when true, also derive a binary mask from channel 3.
            alpha_threshold: channel-3 values ``>=`` this count as foreground.

        Returns:
            ``{'output_img': <session output>}``, plus ``'mask'`` (uint8, 255 for
            foreground, 0 otherwise) when ``output_mask`` is true. In that case
            channel 3 of ``output_img`` is zeroed in place wherever the mask is 0.
            NOTE(review): indexing ``[:, :, 3]`` assumes an HxWxC output with
            at least 4 channels (presumably RGBA) — confirm against the model.
        """
        with self.graph.as_default(), self._session.as_default():
            prediction = self._session.run(self.output, feed_dict={self.input_name: input['img']})
        outputs = {'output_img': prediction}
        if not output_mask:
            return outputs
        binary_mask = np.where(prediction[:, :, 3] >= alpha_threshold, 255, 0).astype(np.uint8)
        # Hard-cut the alpha channel so it agrees with the thresholded mask.
        prediction[binary_mask == 0, 3] = 0
        outputs['mask'] = binary_mask
        return outputs
def apply_filters(mask: np.ndarray, closing_kernel: tuple = (5, 5), opening_kernel: tuple = (5, 5),
                  blur_kernel: tuple = (3, 3), bilateral_params: tuple = (9, 75, 75),
                  min_area: int = 2000) -> np.ndarray:
    """Clean up a mask with morphology and smoothing, keeping only large regions.

    Args:
        mask: 2-D mask; it is cast to uint8 before processing.
        closing_kernel: size of the all-ones kernel used for morphological closing.
        opening_kernel: size of the all-ones kernel used for morphological opening.
        blur_kernel: Gaussian blur kernel size (sigma derived by OpenCV from size).
        bilateral_params: ``(d, sigmaColor, sigmaSpace)`` for ``cv2.bilateralFilter``.
        min_area: connected components (8-connectivity) with fewer pixels than
            this are dropped.

    Returns:
        uint8 array of the same shape: 255 inside every kept component, 0 elsewhere.
    """
    mask = mask.astype(np.uint8)
    # Closing fills small holes, then opening removes small specks.
    closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones(closing_kernel, np.uint8))
    opened_mask = cv2.morphologyEx(closed_mask, cv2.MORPH_OPEN, np.ones(opening_kernel, np.uint8))
    smoothed_mask = cv2.GaussianBlur(opened_mask, blur_kernel, 0)
    edge_smoothed_mask = cv2.bilateralFilter(smoothed_mask, *bilateral_params)
    # connectedComponentsWithStats treats every non-zero pixel as foreground.
    # NOTE(review): the soft values produced by the blur/bilateral steps are
    # discarded by the binarisation below; only their non-zero footprint
    # (a slightly grown region) survives — confirm this is intended.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(edge_smoothed_mask, connectivity=8)
    # Decide keep/drop once per label, then map that onto every pixel in a
    # single pass (a per-label `labels == i` scan costs O(num_labels * pixels)).
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    return np.where(keep[labels], 255, 0).astype(np.uint8)
# Directory expected to contain tf_graph.pb for the matting model.
_MODEL_DIR = 'cv_unet_universal-matting'
# Pipeline shared by all requests; created on first use by _get_pipeline().
_pipeline = None


def _get_pipeline():
    """Return the shared ImageMattingPipeline, loading the model on first use.

    Building the pipeline once avoids re-reading the frozen graph and opening
    a new TF session (which nothing ever closes) on every request.
    """
    global _pipeline
    if _pipeline is None:
        _pipeline = ImageMattingPipeline(model_dir=_MODEL_DIR)
    return _pipeline


def matting_interface(input_image, apply_morphology):
    """Gradio callback: matte one image and return ``(result, mask)`` as PIL images.

    Args:
        input_image: image from the Gradio input (configured with ``type="pil"``),
            or None when the user submits without uploading anything.
        apply_morphology: when true, post-process the mask with ``apply_filters``.

    Returns:
        ``(matting result, mask)`` as PIL images, or ``(None, None)`` if no
        image was provided.
    """
    if input_image is None:
        return None, None
    image = np.array(input_image)
    # Reverse the channel axis; preprocess() swaps channels again with
    # COLOR_RGB2BGR, so the model receives channels in the input's own order.
    image = image[:, :, ::-1]
    pipeline = _get_pipeline()
    preprocessed = pipeline.preprocess(image)
    result = pipeline.forward(preprocessed, output_mask=True)
    if apply_morphology:
        mask = apply_filters(result['mask'])
    else:
        mask = result.get('mask', None)
    # Clip before casting so out-of-range model values saturate at 0/255
    # instead of overflowing the uint8 conversion.
    output_img = np.clip(result['output_img'], 0, 255).astype(np.uint8)
    output_img_pil = Image.fromarray(output_img)
    mask_pil = Image.fromarray(mask) if mask is not None else None
    return output_img_pil, mask_pil
# Gradio UI: one image plus a checkbox in, matting result and mask out.
iface = gr.Interface(
    fn=matting_interface,
    inputs=[
        gr.components.Image(type="pil", image_mode="RGB"),
        gr.components.Checkbox(label="Apply Morphological Processing for Mask"),
    ],
    outputs=[
        gr.components.Image(type="pil", label="Matting Result"),
        gr.components.Image(type="pil", label="Mask"),
    ],
    title="Image Matting and Mask",
    description="Upload an image to get the matting result and mask. "
                "Use the checkbox to enable or disable morphological processing on the mask.",
)

# Only start the (blocking) server when run as a script, so importing this
# module, e.g. from tests, has no side effects.
if __name__ == "__main__":
    iface.launch()