soiz1's picture
Update main.py
fb02422 verified
raw
history blame
3.05 kB
import gradio as gr
import tensorflow as tf
import cv2
import numpy as np
from PIL import Image
import neuralgym as ng
from preprocess_image import preprocess_image
from inpaint_model import InpaintCAModel
# ===== Inpainting function ===== / ===== 画像修復処理関数 =====
def inpaint_image(input_image, watermark_type, checkpoint_dir):
# Convert image from Gradio (PIL format) / Gradioから受け取る画像をPIL形式で処理
image = input_image.convert("RGB")
# Preprocessing / 前処理
input_image = preprocess_image(image, watermark_type)
if input_image.shape == (0,):
return None
# Load configuration file / 設定ファイルの読み込み
FLAGS = ng.Config('inpaint.yml')
# Reset TensorFlow graph / TensorFlowグラフをリセット
tf.reset_default_graph()
model = InpaintCAModel()
# GPU configuration / GPU設定
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
# Start TensorFlow session / TensorFlowセッション開始
with tf.Session(config=sess_config) as sess:
# Create tensor from image / 画像をテンソルに変換
input_image_tensor = tf.constant(input_image, dtype=tf.float32)
# Build the model graph / モデルグラフを構築
output = model.build_server_graph(FLAGS, input_image_tensor)
output = (output + 1.) * 127.5
output = tf.reverse(output, [-1])
output = tf.saturate_cast(output, tf.uint8)
# Load model variables / モデル変数を読み込み
vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = []
for var in vars_list:
from_name = var.name
var_value = tf.contrib.framework.load_variable(checkpoint_dir, from_name)
assign_ops.append(tf.assign(var, var_value))
sess.run(assign_ops)
print('Model loaded.') # モデルの読み込み完了
result = sess.run(output)
result_img = result[0][:, :, ::-1] # Convert BGR to RGB / BGRからRGBに変換
# Convert numpy array to PIL image / numpy配列をPIL画像に変換
return Image.fromarray(result_img)
# ===== Gradio User Interface ===== / ===== Gradioユーザーインターフェース =====
iface = gr.Interface(
fn=inpaint_image,
inputs=[
gr.Image(label="Input Image / 入力画像", type="pil"),
gr.Radio(["istock", "other"], label="Watermark Type / ウォーターマークタイプ", value="istock"),
gr.Textbox(label="Checkpoint Directory / チェックポイントディレクトリ", value="model/")
],
outputs=gr.Image(label="Inpainted Image / 修復済み画像"),
title="Watermark Inpainting Model / ウォーターマーク除去モデル",
description="Upload an image to remove the watermark using a TensorFlow model. / TensorFlowモデルを使用してウォーターマークを除去します。",
)
# Run the app / アプリを起動
if __name__ == "__main__":
iface.launch()