Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import tensorflow as tf | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import neuralgym as ng | |
| from preprocess_image import preprocess_image | |
| from inpaint_model import InpaintCAModel | |
| # ===== Inpainting function ===== / ===== 画像修復処理関数 ===== | |
| def inpaint_image(input_image, watermark_type, checkpoint_dir): | |
| # Convert image from Gradio (PIL format) / Gradioから受け取る画像をPIL形式で処理 | |
| image = input_image.convert("RGB") | |
| # Preprocessing / 前処理 | |
| input_image = preprocess_image(image, watermark_type) | |
| if input_image.shape == (0,): | |
| return None | |
| # Load configuration file / 設定ファイルの読み込み | |
| FLAGS = ng.Config('inpaint.yml') | |
| # Reset TensorFlow graph / TensorFlowグラフをリセット | |
| tf.reset_default_graph() | |
| model = InpaintCAModel() | |
| # GPU configuration / GPU設定 | |
| sess_config = tf.ConfigProto() | |
| sess_config.gpu_options.allow_growth = True | |
| # Start TensorFlow session / TensorFlowセッション開始 | |
| with tf.Session(config=sess_config) as sess: | |
| # Create tensor from image / 画像をテンソルに変換 | |
| input_image_tensor = tf.constant(input_image, dtype=tf.float32) | |
| # Build the model graph / モデルグラフを構築 | |
| output = model.build_server_graph(FLAGS, input_image_tensor) | |
| output = (output + 1.) * 127.5 | |
| output = tf.reverse(output, [-1]) | |
| output = tf.saturate_cast(output, tf.uint8) | |
| # Load model variables / モデル変数を読み込み | |
| vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) | |
| assign_ops = [] | |
| for var in vars_list: | |
| from_name = var.name | |
| var_value = tf.contrib.framework.load_variable(checkpoint_dir, from_name) | |
| assign_ops.append(tf.assign(var, var_value)) | |
| sess.run(assign_ops) | |
| print('Model loaded.') # モデルの読み込み完了 | |
| result = sess.run(output) | |
| result_img = result[0][:, :, ::-1] # Convert BGR to RGB / BGRからRGBに変換 | |
| # Convert numpy array to PIL image / numpy配列をPIL画像に変換 | |
| return Image.fromarray(result_img) | |
| # ===== Gradio User Interface ===== / ===== Gradioユーザーインターフェース ===== | |
| iface = gr.Interface( | |
| fn=inpaint_image, | |
| inputs=[ | |
| gr.Image(label="Input Image / 入力画像", type="pil"), | |
| gr.Radio(["istock", "other"], label="Watermark Type / ウォーターマークタイプ", value="istock"), | |
| gr.Textbox(label="Checkpoint Directory / チェックポイントディレクトリ", value="model/") | |
| ], | |
| outputs=gr.Image(label="Inpainted Image / 修復済み画像"), | |
| title="Watermark Inpainting Model / ウォーターマーク除去モデル", | |
| description="Upload an image to remove the watermark using a TensorFlow model. / TensorFlowモデルを使用してウォーターマークを除去します。", | |
| ) | |
| # Run the app / アプリを起動 | |
| if __name__ == "__main__": | |
| iface.launch() | |