Spaces:
Runtime error
Runtime error
Commit
·
8523150
1
Parent(s):
f9f482a
precomputed all rotations
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -2
- __pycache__/io_utils.cpython-311.pyc +0 -0
- app.py +20 -116
- data/cached_outputs/xr_1_(-10, -10, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, 10).png +0 -0
.gitignore
CHANGED
|
@@ -1,2 +1 @@
|
|
| 1 |
-
|
| 2 |
-
*.ckpt
|
|
|
|
| 1 |
+
app_copy.py
|
|
|
__pycache__/io_utils.cpython-311.pyc
ADDED
|
Binary file (6.62 kB). View file
|
|
|
app.py
CHANGED
|
@@ -3,95 +3,12 @@ import os
|
|
| 3 |
import gradio as gr
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import numpy as np
|
| 6 |
-
import pandas as pd
|
| 7 |
import skimage
|
| 8 |
-
from mediffusion import DiffusionModule
|
| 9 |
import monai as mn
|
| 10 |
import torch
|
| 11 |
|
| 12 |
from io_utils import LoadImageD
|
| 13 |
|
| 14 |
-
# Loading the model for inference
|
| 15 |
-
|
| 16 |
-
model = DiffusionModule("./diffusion_configs.yaml")
|
| 17 |
-
model.load_ckpt("./data/model.ckpt")
|
| 18 |
-
model.eval();
|
| 19 |
-
|
| 20 |
-
# Loading a baseline noise for making predictions
|
| 21 |
-
|
| 22 |
-
seed = 3407
|
| 23 |
-
np.random.seed(seed)
|
| 24 |
-
torch.random.manual_seed(seed)
|
| 25 |
-
torch.backends.cudnn.deterministic = True
|
| 26 |
-
BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
|
| 27 |
-
|
| 28 |
-
# Model helper functions
|
| 29 |
-
|
| 30 |
-
def create_ds(img_paths):
|
| 31 |
-
if type(img_paths) == str:
|
| 32 |
-
img_paths = [img_paths]
|
| 33 |
-
data_list = [{"img": img_path} for img_path in img_paths]
|
| 34 |
-
|
| 35 |
-
# Get the transforms
|
| 36 |
-
Ts_list = [
|
| 37 |
-
LoadImageD(keys=["img"], transpose=True, normalize=True),
|
| 38 |
-
mn.transforms.EnsureChannelFirstD(
|
| 39 |
-
keys=["img"], channel_dim="no_channel"
|
| 40 |
-
),
|
| 41 |
-
mn.transforms.ResizeD(
|
| 42 |
-
keys=["img"],
|
| 43 |
-
spatial_size=(256, 256),
|
| 44 |
-
mode=["bicubic"],
|
| 45 |
-
),
|
| 46 |
-
mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
|
| 47 |
-
mn.transforms.ToTensorD(keys=["img"], track_meta=None),
|
| 48 |
-
mn.transforms.SelectItemsD(keys=["img"]),
|
| 49 |
-
]
|
| 50 |
-
return mn.data.Dataset(data_list, transform=mn.transforms.Compose(Ts_list))
|
| 51 |
-
|
| 52 |
-
def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM5"):
|
| 53 |
-
|
| 54 |
-
global model
|
| 55 |
-
global BASELINE_NOISE
|
| 56 |
-
|
| 57 |
-
# Create the image dataset
|
| 58 |
-
if cls_batch is not None:
|
| 59 |
-
ds = create_ds([img_path]*len(cls_batch))
|
| 60 |
-
else:
|
| 61 |
-
ds = create_ds(img_path)
|
| 62 |
-
dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds)==1 else 4, shuffle=False)
|
| 63 |
-
input_batch = next(iter(dl))
|
| 64 |
-
original_imgs = input_batch["img"].detach().cpu().numpy()
|
| 65 |
-
|
| 66 |
-
# Create the classifier condition if not provided
|
| 67 |
-
if cls_batch is None:
|
| 68 |
-
fp = torch.zeros(768)
|
| 69 |
-
if rotate_to_standard or angles is None:
|
| 70 |
-
angles = [1000, 1000, 1000]
|
| 71 |
-
cls_value = torch.tensor([2, *angles, *fp])
|
| 72 |
-
else:
|
| 73 |
-
cls_value = torch.tensor([1, *angles, *fp])
|
| 74 |
-
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
| 75 |
-
|
| 76 |
-
# Generate noise
|
| 77 |
-
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
| 78 |
-
model_kwargs = {
|
| 79 |
-
"cls": cls_batch,
|
| 80 |
-
"concat": input_batch["img"]
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
# Make predictions
|
| 84 |
-
preds = model.predict(
|
| 85 |
-
noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
|
| 86 |
-
)
|
| 87 |
-
adjusted_preds = list()
|
| 88 |
-
for pred, original_img in zip(preds, original_imgs):
|
| 89 |
-
adjusted_pred = pred.detach().cpu().numpy().squeeze()
|
| 90 |
-
original_img = original_img.squeeze()
|
| 91 |
-
adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img)
|
| 92 |
-
adjusted_preds.append(adjusted_pred)
|
| 93 |
-
return adjusted_preds
|
| 94 |
-
|
| 95 |
# Gradio helper functions
|
| 96 |
|
| 97 |
current_img = None
|
|
@@ -101,65 +18,52 @@ def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
|
|
| 101 |
|
| 102 |
global current_img
|
| 103 |
|
| 104 |
-
angles =
|
| 105 |
-
|
|
|
|
| 106 |
if not add_bone_cmap:
|
| 107 |
-
print(out_img.shape)
|
| 108 |
return out_img
|
| 109 |
cmap = plt.get_cmap('bone')
|
| 110 |
out_img = cmap(out_img)
|
| 111 |
out_img = (out_img[..., :3] * 255).astype(np.uint8)
|
| 112 |
current_img = out_img
|
| 113 |
return out_img
|
| 114 |
-
|
| 115 |
-
def use_current_btn_fn(input_img):
|
| 116 |
-
return input_img
|
| 117 |
-
|
| 118 |
-
def retrieve_examples(examples, inputs):
|
| 119 |
-
global current_img
|
| 120 |
-
if current_img is not None:
|
| 121 |
-
return current_img
|
| 122 |
-
return examples[0]
|
| 123 |
|
| 124 |
css_style = "./style.css"
|
| 125 |
callback = gr.CSVLogger()
|
| 126 |
with gr.Blocks(css=css_style) as app:
|
| 127 |
gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
|
| 128 |
-
gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="
|
| 129 |
-
gr.HTML("Note: This is a proof-of-concept demo
|
| 130 |
|
| 131 |
-
with gr.TabItem("
|
| 132 |
with gr.Row():
|
| 133 |
input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
|
| 134 |
output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
|
| 135 |
with gr.Row():
|
| 136 |
-
gr.
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
)
|
| 148 |
with gr.Row():
|
| 149 |
gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
|
| 150 |
with gr.Row():
|
| 151 |
with gr.Column(scale=1):
|
| 152 |
-
xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-
|
| 153 |
with gr.Column(scale=1):
|
| 154 |
-
yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-
|
| 155 |
with gr.Column(scale=1):
|
| 156 |
-
zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-
|
| 157 |
with gr.Row():
|
| 158 |
rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
|
| 159 |
-
with gr.Row():
|
| 160 |
-
use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
|
| 161 |
rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
|
| 162 |
-
use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)
|
| 163 |
|
| 164 |
try:
|
| 165 |
app.close()
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import skimage
|
|
|
|
| 7 |
import monai as mn
|
| 8 |
import torch
|
| 9 |
|
| 10 |
from io_utils import LoadImageD
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Gradio helper functions
|
| 13 |
|
| 14 |
current_img = None
|
|
|
|
| 18 |
|
| 19 |
global current_img
|
| 20 |
|
| 21 |
+
angles = (xt, yt, zt)
|
| 22 |
+
out_img_path = f'data/cached_outputs/{os.path.basename(img_path)[:-4]}_{angles}.png'
|
| 23 |
+
out_img = skimage.io.imread(out_img_path)
|
| 24 |
if not add_bone_cmap:
|
|
|
|
| 25 |
return out_img
|
| 26 |
cmap = plt.get_cmap('bone')
|
| 27 |
out_img = cmap(out_img)
|
| 28 |
out_img = (out_img[..., :3] * 255).astype(np.uint8)
|
| 29 |
current_img = out_img
|
| 30 |
return out_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
css_style = "./style.css"
|
| 33 |
callback = gr.CSVLogger()
|
| 34 |
with gr.Blocks(css=css_style) as app:
|
| 35 |
gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
|
| 36 |
+
gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="subtitle")
|
| 37 |
+
gr.HTML("Note: This is a proof-of-concept demo running on CPU. All predictions are pre-computed.", elem_classes="note")
|
| 38 |
|
| 39 |
+
with gr.TabItem("Demo"):
|
| 40 |
with gr.Row():
|
| 41 |
input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
|
| 42 |
output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
|
| 43 |
with gr.Row():
|
| 44 |
+
with gr.Column(scale=0.25):
|
| 45 |
+
pass
|
| 46 |
+
with gr.Column(scale=1):
|
| 47 |
+
gr.Examples(
|
| 48 |
+
examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
|
| 49 |
+
inputs = [input_img],
|
| 50 |
+
label = "Xray Examples",
|
| 51 |
+
elem_id='examples',
|
| 52 |
+
)
|
| 53 |
+
with gr.Column(scale=0.25):
|
| 54 |
+
pass
|
|
|
|
| 55 |
with gr.Row():
|
| 56 |
gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
|
| 57 |
with gr.Row():
|
| 58 |
with gr.Column(scale=1):
|
| 59 |
+
xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
| 60 |
with gr.Column(scale=1):
|
| 61 |
+
yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
| 62 |
with gr.Column(scale=1):
|
| 63 |
+
zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
| 64 |
with gr.Row():
|
| 65 |
rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
|
|
|
|
|
|
|
| 66 |
rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
|
|
|
|
| 67 |
|
| 68 |
try:
|
| 69 |
app.close()
|
data/cached_outputs/xr_1_(-10, -10, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -10, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -15, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, -5, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 0, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 10, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, 10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, 15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 15, 5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 5, -10).png
ADDED
|
data/cached_outputs/xr_1_(-10, 5, -15).png
ADDED
|
data/cached_outputs/xr_1_(-10, 5, -5).png
ADDED
|
data/cached_outputs/xr_1_(-10, 5, 0).png
ADDED
|
data/cached_outputs/xr_1_(-10, 5, 10).png
ADDED
|