Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
fc44d4b
1
Parent(s):
dfeea18
remove_type_annotator
Browse files- app.py +2 -3
- example.py +2 -2
- setup.sh +2 -11
- triplaneturbo_executable/extern/sd_dual_triplane_modules.py +10 -11
- triplaneturbo_executable/models/networks.py +2 -3
- triplaneturbo_executable/pipelines/triplaneturbo_text_to_3d.py +38 -43
- triplaneturbo_executable/utils/general_utils.py +29 -12
- triplaneturbo_executable/utils/mesh.py +75 -63
- triplaneturbo_executable/utils/mesh_exporter.py +26 -30
- triplaneturbo_executable/utils/saving.py +34 -53
app.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import subprocess
|
| 3 |
import sys
|
| 4 |
try:
|
| 5 |
import spaces
|
| 6 |
except:
|
| 7 |
pass
|
| 8 |
-
os.environ["PYDANTIC_STRICT_TYPE_CHECKING"] = "0"
|
| 9 |
|
| 10 |
# Check if setup has been run
|
| 11 |
setup_marker = ".setup_complete"
|
|
@@ -23,7 +23,6 @@ if not os.path.exists(setup_marker):
|
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import gradio as gr
|
| 26 |
-
from typing import Tuple, List, Dict, Any, Optional
|
| 27 |
from collections import deque
|
| 28 |
from diffusers import StableDiffusionPipeline
|
| 29 |
|
|
@@ -58,7 +57,7 @@ def initialize_pipeline():
|
|
| 58 |
return PIPELINE
|
| 59 |
|
| 60 |
@spaces.GPU
|
| 61 |
-
def generate_3d_mesh(prompt
|
| 62 |
"""Generate 3D mesh from text prompt"""
|
| 63 |
global PIPELINE, OBJ_FILE_QUEUE
|
| 64 |
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
import subprocess
|
| 4 |
import sys
|
| 5 |
try:
|
| 6 |
import spaces
|
| 7 |
except:
|
| 8 |
pass
|
|
|
|
| 9 |
|
| 10 |
# Check if setup has been run
|
| 11 |
setup_marker = ".setup_complete"
|
|
|
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import gradio as gr
|
|
|
|
| 26 |
from collections import deque
|
| 27 |
from diffusers import StableDiffusionPipeline
|
| 28 |
|
|
|
|
| 57 |
return PIPELINE
|
| 58 |
|
| 59 |
@spaces.GPU
|
| 60 |
+
def generate_3d_mesh(prompt):
|
| 61 |
"""Generate 3D mesh from text prompt"""
|
| 62 |
global PIPELINE, OBJ_FILE_QUEUE
|
| 63 |
|
example.py
CHANGED
|
@@ -17,8 +17,8 @@ from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline, TriplaneTurb
|
|
| 17 |
|
| 18 |
# Initialize configuration and parameters
|
| 19 |
prompt = "a beautiful girl"
|
| 20 |
-
output_dir = "
|
| 21 |
-
adapter_name_or_path = "
|
| 22 |
num_results_per_prompt = 1
|
| 23 |
seed = 42
|
| 24 |
device = "cuda"
|
|
|
|
| 17 |
|
| 18 |
# Initialize configuration and parameters
|
| 19 |
prompt = "a beautiful girl"
|
| 20 |
+
output_dir = "output"
|
| 21 |
+
adapter_name_or_path = "pretrained/triplane_turbo_sd_v1.pth"
|
| 22 |
num_results_per_prompt = 1
|
| 23 |
seed = 42
|
| 24 |
device = "cuda"
|
setup.sh
CHANGED
|
@@ -17,12 +17,12 @@ pip install --force-reinstall -v "numpy==1.25.2"
|
|
| 17 |
# cd ..
|
| 18 |
# cd ..
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
echo "Installing other requirements..."
|
| 23 |
pip install -r requirements.txt
|
| 24 |
|
| 25 |
-
|
| 26 |
echo "Installing pre-compiled DISO wheel package..."
|
| 27 |
huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
|
| 28 |
--include "diso-0.1.4-*.whl" \
|
|
@@ -30,12 +30,3 @@ huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
|
|
| 30 |
--local-dir-use-symlinks False
|
| 31 |
|
| 32 |
pip install ./diso_package/diso-0.1.4-*.whl
|
| 33 |
-
echo "Setup completed successfully!"
|
| 34 |
-
|
| 35 |
-
echo "Installing compatible dependency versions..."
|
| 36 |
-
pip uninstall -y pydantic
|
| 37 |
-
pip install pydantic==1.10.8 # Install compatible older version
|
| 38 |
-
|
| 39 |
-
# Ensure Gradio and other dependencies are installed correctly
|
| 40 |
-
pip install "gradio>=4.0.0,<5.0.0"
|
| 41 |
-
pip install "fastapi<0.103.0" # Ensure compatible FastAPI version
|
|
|
|
| 17 |
# cd ..
|
| 18 |
# cd ..
|
| 19 |
|
| 20 |
+
echo "Setup completed successfully!"
|
| 21 |
|
| 22 |
echo "Installing other requirements..."
|
| 23 |
pip install -r requirements.txt
|
| 24 |
|
| 25 |
+
# 从您的Hugging Face仓库下载并安装预编译的DISO wheel
|
| 26 |
echo "Installing pre-compiled DISO wheel package..."
|
| 27 |
huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
|
| 28 |
--include "diso-0.1.4-*.whl" \
|
|
|
|
| 30 |
--local-dir-use-symlinks False
|
| 31 |
|
| 32 |
pip install ./diso_package/diso-0.1.4-*.whl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
triplaneturbo_executable/extern/sd_dual_triplane_modules.py
CHANGED
|
@@ -2,7 +2,6 @@ import re
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from dataclasses import dataclass
|
| 5 |
-
from typing import Optional, Union, Tuple
|
| 6 |
|
| 7 |
from diffusers.models.attention_processor import Attention
|
| 8 |
from diffusers import (
|
|
@@ -39,9 +38,9 @@ class LoRALinearLayerwBias(nn.Module):
|
|
| 39 |
in_features: int,
|
| 40 |
out_features: int,
|
| 41 |
rank: int = 4,
|
| 42 |
-
network_alpha
|
| 43 |
-
device
|
| 44 |
-
dtype
|
| 45 |
with_bias: bool = False
|
| 46 |
):
|
| 47 |
super().__init__()
|
|
@@ -105,10 +104,10 @@ class TriplaneLoRAConv2dLayer(nn.Module):
|
|
| 105 |
in_features: int,
|
| 106 |
out_features: int,
|
| 107 |
rank: int = 4,
|
| 108 |
-
kernel_size
|
| 109 |
-
stride
|
| 110 |
-
padding
|
| 111 |
-
network_alpha
|
| 112 |
with_bias: bool = False,
|
| 113 |
locon_type: str = "hexa_v1", #hexa_v2, vanilla_v1, vanilla_v2
|
| 114 |
):
|
|
@@ -220,7 +219,7 @@ class TriplaneSelfAttentionLoRAAttnProcessor(nn.Module):
|
|
| 220 |
self,
|
| 221 |
hidden_size: int,
|
| 222 |
rank: int = 4,
|
| 223 |
-
network_alpha
|
| 224 |
with_bias: bool = False,
|
| 225 |
lora_type: str = "hexa_v1", # vanilla,
|
| 226 |
):
|
|
@@ -492,7 +491,7 @@ class TriplaneCrossAttentionLoRAAttnProcessor(nn.Module):
|
|
| 492 |
hidden_size: int,
|
| 493 |
cross_attention_dim: int,
|
| 494 |
rank: int = 4,
|
| 495 |
-
network_alpha
|
| 496 |
with_bias: bool = False,
|
| 497 |
lora_type: str = "hexa_v1", # vanilla,
|
| 498 |
):
|
|
@@ -713,7 +712,7 @@ class OneStepTriplaneDualStableDiffusion(nn.Module):
|
|
| 713 |
"""
|
| 714 |
def __init__(
|
| 715 |
self,
|
| 716 |
-
config
|
| 717 |
vae: AutoencoderKL,
|
| 718 |
unet: UNet2DConditionModel,
|
| 719 |
):
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
from diffusers.models.attention_processor import Attention
|
| 7 |
from diffusers import (
|
|
|
|
| 38 |
in_features: int,
|
| 39 |
out_features: int,
|
| 40 |
rank: int = 4,
|
| 41 |
+
network_alpha=None,
|
| 42 |
+
device=None,
|
| 43 |
+
dtype=None,
|
| 44 |
with_bias: bool = False
|
| 45 |
):
|
| 46 |
super().__init__()
|
|
|
|
| 104 |
in_features: int,
|
| 105 |
out_features: int,
|
| 106 |
rank: int = 4,
|
| 107 |
+
kernel_size = (1, 1),
|
| 108 |
+
stride = (1, 1),
|
| 109 |
+
padding = 0,
|
| 110 |
+
network_alpha = None,
|
| 111 |
with_bias: bool = False,
|
| 112 |
locon_type: str = "hexa_v1", #hexa_v2, vanilla_v1, vanilla_v2
|
| 113 |
):
|
|
|
|
| 219 |
self,
|
| 220 |
hidden_size: int,
|
| 221 |
rank: int = 4,
|
| 222 |
+
network_alpha=None,
|
| 223 |
with_bias: bool = False,
|
| 224 |
lora_type: str = "hexa_v1", # vanilla,
|
| 225 |
):
|
|
|
|
| 491 |
hidden_size: int,
|
| 492 |
cross_attention_dim: int,
|
| 493 |
rank: int = 4,
|
| 494 |
+
network_alpha = None,
|
| 495 |
with_bias: bool = False,
|
| 496 |
lora_type: str = "hexa_v1", # vanilla,
|
| 497 |
):
|
|
|
|
| 712 |
"""
|
| 713 |
def __init__(
|
| 714 |
self,
|
| 715 |
+
config,
|
| 716 |
vae: AutoencoderKL,
|
| 717 |
unet: UNet2DConditionModel,
|
| 718 |
):
|
triplaneturbo_executable/models/networks.py
CHANGED
|
@@ -3,7 +3,6 @@ import torch.nn as nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from ..utils.general_utils import config_to_primitive
|
| 5 |
from dataclasses import dataclass
|
| 6 |
-
from typing import Optional, Literal
|
| 7 |
|
| 8 |
def get_activation(name):
|
| 9 |
if name is None:
|
|
@@ -21,7 +20,7 @@ def get_activation(name):
|
|
| 21 |
|
| 22 |
|
| 23 |
class VanillaMLP(nn.Module):
|
| 24 |
-
def __init__(self, dim_in
|
| 25 |
super().__init__()
|
| 26 |
# Convert dict to MLPConfig if needed
|
| 27 |
if isinstance(config, dict):
|
|
@@ -70,7 +69,7 @@ class MLPConfig:
|
|
| 70 |
n_neurons: int = 64
|
| 71 |
n_hidden_layers: int = 2
|
| 72 |
|
| 73 |
-
def get_mlp(input_dim
|
| 74 |
"""Create MLP network based on config"""
|
| 75 |
# Convert dict to MLPConfig
|
| 76 |
if isinstance(config, dict):
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from ..utils.general_utils import config_to_primitive
|
| 5 |
from dataclasses import dataclass
|
|
|
|
| 6 |
|
| 7 |
def get_activation(name):
|
| 8 |
if name is None:
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class VanillaMLP(nn.Module):
|
| 23 |
+
def __init__(self, dim_in, dim_out, config):
|
| 24 |
super().__init__()
|
| 25 |
# Convert dict to MLPConfig if needed
|
| 26 |
if isinstance(config, dict):
|
|
|
|
| 69 |
n_neurons: int = 64
|
| 70 |
n_hidden_layers: int = 2
|
| 71 |
|
| 72 |
+
def get_mlp(input_dim, output_dim, config):
|
| 73 |
"""Create MLP network based on config"""
|
| 74 |
# Convert dict to MLPConfig
|
| 75 |
if isinstance(config, dict):
|
triplaneturbo_executable/pipelines/triplaneturbo_text_to_3d.py
CHANGED
|
@@ -4,7 +4,6 @@ import json
|
|
| 4 |
from tqdm import tqdm
|
| 5 |
|
| 6 |
import torch
|
| 7 |
-
from typing import *
|
| 8 |
from dataclasses import dataclass, field
|
| 9 |
from diffusers import StableDiffusionPipeline
|
| 10 |
|
|
@@ -21,11 +20,6 @@ class TriplaneTurboTextTo3DPipelineConfig:
|
|
| 21 |
# Basic pipeline settings
|
| 22 |
base_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"
|
| 23 |
|
| 24 |
-
num_inference_steps: int = 4
|
| 25 |
-
num_results_per_prompt: int = 1
|
| 26 |
-
latent_channels: int = 4
|
| 27 |
-
latent_height: int = 64
|
| 28 |
-
latent_width: int = 64
|
| 29 |
|
| 30 |
# Training/sampling settings
|
| 31 |
num_steps_sampling: int = 4
|
|
@@ -72,7 +66,7 @@ class TriplaneTurboTextTo3DPipelineConfig:
|
|
| 72 |
color_activation: str = "sigmoid-mipnerf"
|
| 73 |
|
| 74 |
@classmethod
|
| 75 |
-
def from_pretrained(cls, pretrained_path
|
| 76 |
"""Load config from pretrained path"""
|
| 77 |
config_path = os.path.join(pretrained_path, "config.json")
|
| 78 |
if os.path.exists(config_path):
|
|
@@ -91,11 +85,11 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 91 |
|
| 92 |
def __init__(
|
| 93 |
self,
|
| 94 |
-
geometry
|
| 95 |
-
material
|
| 96 |
-
base_pipeline
|
| 97 |
-
sample_scheduler
|
| 98 |
-
isosurface_helper
|
| 99 |
**kwargs,
|
| 100 |
):
|
| 101 |
super().__init__()
|
|
@@ -116,7 +110,7 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 116 |
@classmethod
|
| 117 |
def from_pretrained(
|
| 118 |
cls,
|
| 119 |
-
pretrained_model_name_or_path
|
| 120 |
**kwargs,
|
| 121 |
):
|
| 122 |
"""
|
|
@@ -197,10 +191,10 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 197 |
|
| 198 |
def encode_prompt(
|
| 199 |
self,
|
| 200 |
-
prompt
|
| 201 |
-
device
|
| 202 |
-
num_results_per_prompt
|
| 203 |
-
)
|
| 204 |
"""
|
| 205 |
Encodes the prompt into text encoder hidden states.
|
| 206 |
|
|
@@ -227,14 +221,13 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 227 |
@torch.no_grad()
|
| 228 |
def __call__(
|
| 229 |
self,
|
| 230 |
-
prompt
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
colorize
|
| 237 |
-
**kwargs,
|
| 238 |
):
|
| 239 |
# Implementation similar to Zero123Pipeline
|
| 240 |
# Reference code from: https://github.com/zero123/zero123-diffusers
|
|
@@ -251,15 +244,18 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 251 |
# Get the device from the first available module
|
| 252 |
|
| 253 |
# Generate latents if not provided
|
| 254 |
-
if
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
)
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
# Process text prompt through geometry module
|
| 262 |
-
text_embed, _ = self.encode_prompt(prompt,
|
| 263 |
|
| 264 |
# Run diffusion process
|
| 265 |
# Set up timesteps for sampling
|
|
@@ -282,7 +278,7 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 282 |
pred = self.geometry.denoise(
|
| 283 |
noisy_input=noisy_latent_input,
|
| 284 |
text_embed=text_embed,
|
| 285 |
-
timestep=t.to(
|
| 286 |
)
|
| 287 |
|
| 288 |
# Update latents
|
|
@@ -311,20 +307,19 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
|
|
| 311 |
activation=self.material,
|
| 312 |
)
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
return mesh_list
|
| 323 |
|
| 324 |
def _set_timesteps(
|
| 325 |
self,
|
| 326 |
scheduler,
|
| 327 |
-
num_steps
|
| 328 |
):
|
| 329 |
"""Set up timesteps for sampling.
|
| 330 |
|
|
|
|
| 4 |
from tqdm import tqdm
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 7 |
from dataclasses import dataclass, field
|
| 8 |
from diffusers import StableDiffusionPipeline
|
| 9 |
|
|
|
|
| 20 |
# Basic pipeline settings
|
| 21 |
base_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Training/sampling settings
|
| 25 |
num_steps_sampling: int = 4
|
|
|
|
| 66 |
color_activation: str = "sigmoid-mipnerf"
|
| 67 |
|
| 68 |
@classmethod
|
| 69 |
+
def from_pretrained(cls, pretrained_path):
|
| 70 |
"""Load config from pretrained path"""
|
| 71 |
config_path = os.path.join(pretrained_path, "config.json")
|
| 72 |
if os.path.exists(config_path):
|
|
|
|
| 85 |
|
| 86 |
def __init__(
|
| 87 |
self,
|
| 88 |
+
geometry,
|
| 89 |
+
material,
|
| 90 |
+
base_pipeline,
|
| 91 |
+
sample_scheduler,
|
| 92 |
+
isosurface_helper,
|
| 93 |
**kwargs,
|
| 94 |
):
|
| 95 |
super().__init__()
|
|
|
|
| 110 |
@classmethod
|
| 111 |
def from_pretrained(
|
| 112 |
cls,
|
| 113 |
+
pretrained_model_name_or_path,
|
| 114 |
**kwargs,
|
| 115 |
):
|
| 116 |
"""
|
|
|
|
| 191 |
|
| 192 |
def encode_prompt(
|
| 193 |
self,
|
| 194 |
+
prompt,
|
| 195 |
+
device,
|
| 196 |
+
num_results_per_prompt = 1,
|
| 197 |
+
):
|
| 198 |
"""
|
| 199 |
Encodes the prompt into text encoder hidden states.
|
| 200 |
|
|
|
|
| 221 |
@torch.no_grad()
|
| 222 |
def __call__(
|
| 223 |
self,
|
| 224 |
+
prompt,
|
| 225 |
+
num_results_per_prompt=1,
|
| 226 |
+
generator=None,
|
| 227 |
+
device=None,
|
| 228 |
+
return_dict=True,
|
| 229 |
+
num_inference_steps=4,
|
| 230 |
+
colorize = True,
|
|
|
|
| 231 |
):
|
| 232 |
# Implementation similar to Zero123Pipeline
|
| 233 |
# Reference code from: https://github.com/zero123/zero123-diffusers
|
|
|
|
| 244 |
# Get the device from the first available module
|
| 245 |
|
| 246 |
# Generate latents if not provided
|
| 247 |
+
if device is None:
|
| 248 |
+
device = self.device
|
| 249 |
+
if generator is None:
|
| 250 |
+
generator = torch.Generator(device=device)
|
| 251 |
+
latents = torch.randn(
|
| 252 |
+
(batch_size * 6, 4, 32, 32), # hard-coded for now
|
| 253 |
+
generator=generator,
|
| 254 |
+
device=device,
|
| 255 |
+
)
|
| 256 |
|
| 257 |
# Process text prompt through geometry module
|
| 258 |
+
text_embed, _ = self.encode_prompt(prompt, device, num_results_per_prompt)
|
| 259 |
|
| 260 |
# Run diffusion process
|
| 261 |
# Set up timesteps for sampling
|
|
|
|
| 278 |
pred = self.geometry.denoise(
|
| 279 |
noisy_input=noisy_latent_input,
|
| 280 |
text_embed=text_embed,
|
| 281 |
+
timestep=t.to(device),
|
| 282 |
)
|
| 283 |
|
| 284 |
# Update latents
|
|
|
|
| 307 |
activation=self.material,
|
| 308 |
)
|
| 309 |
|
| 310 |
+
if return_dict:
|
| 311 |
+
return {
|
| 312 |
+
"space_cache": space_cache,
|
| 313 |
+
"latents": latents,
|
| 314 |
+
"mesh": mesh_list,
|
| 315 |
+
}
|
| 316 |
+
else:
|
| 317 |
+
return mesh_list
|
|
|
|
| 318 |
|
| 319 |
def _set_timesteps(
|
| 320 |
self,
|
| 321 |
scheduler,
|
| 322 |
+
num_steps,
|
| 323 |
):
|
| 324 |
"""Set up timesteps for sampling.
|
| 325 |
|
triplaneturbo_executable/utils/general_utils.py
CHANGED
|
@@ -2,17 +2,28 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
-
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
|
| 9 |
-
def config_to_primitive(config
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def scale_tensor(
|
| 13 |
-
dat
|
| 14 |
-
inp_scale
|
| 15 |
-
tgt_scale
|
| 16 |
):
|
| 17 |
if inp_scale is None:
|
| 18 |
inp_scale = (0, 1)
|
|
@@ -25,8 +36,8 @@ def scale_tensor(
|
|
| 25 |
return dat
|
| 26 |
|
| 27 |
def contract_to_unisphere_custom(
|
| 28 |
-
x
|
| 29 |
-
)
|
| 30 |
if unbounded:
|
| 31 |
x = scale_tensor(x, bbox, (-1, 1))
|
| 32 |
x = x * 2 - 1 # aabb is at [-1, 1]
|
|
@@ -81,7 +92,7 @@ def project_onto_planes(planes, coordinates):
|
|
| 81 |
projections = torch.bmm(coordinates, inv_planes)
|
| 82 |
return projections[..., :2]
|
| 83 |
|
| 84 |
-
def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat
|
| 85 |
assert padding_mode == 'zeros'
|
| 86 |
N, n_planes, C, H, W = plane_features.shape
|
| 87 |
_, M, _ = coordinates.shape
|
|
@@ -101,4 +112,10 @@ def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mod
|
|
| 101 |
output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
| 102 |
output_features = output_features.permute(0, 2, 1, 3).reshape(N, M, n_planes*C)
|
| 103 |
|
| 104 |
-
return output_features.contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
+
import numpy as np
|
| 6 |
+
from dataclasses import asdict, is_dataclass
|
| 7 |
+
import gc
|
| 8 |
|
| 9 |
+
def config_to_primitive(config):
|
| 10 |
+
"""Convert a dataclass config to a dictionary recursively."""
|
| 11 |
+
if is_dataclass(config):
|
| 12 |
+
config_dict = asdict(config)
|
| 13 |
+
return {k: config_to_primitive(v) for k, v in config_dict.items()}
|
| 14 |
+
elif isinstance(config, dict):
|
| 15 |
+
return {k: config_to_primitive(v) for k, v in config.items()}
|
| 16 |
+
elif isinstance(config, list):
|
| 17 |
+
return [config_to_primitive(v) for v in config]
|
| 18 |
+
elif isinstance(config, tuple):
|
| 19 |
+
return tuple(config_to_primitive(v) for v in config)
|
| 20 |
+
else:
|
| 21 |
+
return config
|
| 22 |
|
| 23 |
def scale_tensor(
|
| 24 |
+
dat,
|
| 25 |
+
inp_scale,
|
| 26 |
+
tgt_scale
|
| 27 |
):
|
| 28 |
if inp_scale is None:
|
| 29 |
inp_scale = (0, 1)
|
|
|
|
| 36 |
return dat
|
| 37 |
|
| 38 |
def contract_to_unisphere_custom(
|
| 39 |
+
x, bbox, unbounded = False
|
| 40 |
+
):
|
| 41 |
if unbounded:
|
| 42 |
x = scale_tensor(x, bbox, (-1, 1))
|
| 43 |
x = x * 2 - 1 # aabb is at [-1, 1]
|
|
|
|
| 92 |
projections = torch.bmm(coordinates, inv_planes)
|
| 93 |
return projections[..., :2]
|
| 94 |
|
| 95 |
+
def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat = None):
|
| 96 |
assert padding_mode == 'zeros'
|
| 97 |
N, n_planes, C, H, W = plane_features.shape
|
| 98 |
_, M, _ = coordinates.shape
|
|
|
|
| 112 |
output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
| 113 |
output_features = output_features.permute(0, 2, 1, 3).reshape(N, M, n_planes*C)
|
| 114 |
|
| 115 |
+
return output_features.contiguous()
|
| 116 |
+
|
| 117 |
+
def cleanup():
|
| 118 |
+
"""Cleanup torch memory."""
|
| 119 |
+
gc.collect()
|
| 120 |
+
torch.cuda.empty_cache()
|
| 121 |
+
torch.cuda.ipc_collect()
|
triplaneturbo_executable/utils/mesh.py
CHANGED
|
@@ -1,77 +1,54 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
-
from typing import Any, Dict, Optional, Union
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
from jaxtyping import Float, Integer
|
| 11 |
-
from torch import Tensor
|
| 12 |
|
| 13 |
def dot(x, y):
|
| 14 |
return torch.sum(x * y, -1, keepdim=True)
|
| 15 |
|
| 16 |
class Mesh:
|
| 17 |
def __init__(
|
| 18 |
-
self, v_pos
|
| 19 |
-
)
|
| 20 |
-
self.v_pos
|
| 21 |
-
self.t_pos_idx
|
| 22 |
-
self.
|
| 23 |
-
self.
|
| 24 |
-
self.
|
| 25 |
-
self.
|
| 26 |
-
self.
|
| 27 |
-
self.
|
| 28 |
-
self.
|
| 29 |
-
|
| 30 |
-
self.add_extra(k, v)
|
| 31 |
|
| 32 |
def add_extra(self, k, v) -> None:
|
| 33 |
self.extras[k] = v
|
| 34 |
|
| 35 |
-
def remove_outlier(self,
|
| 36 |
-
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
components = mesh.split(only_watertight=False)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
n_faces_threshold: int
|
| 52 |
-
if isinstance(outlier_n_faces_threshold, float):
|
| 53 |
-
# set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
|
| 54 |
-
n_faces_threshold = int(
|
| 55 |
-
max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
|
| 56 |
-
)
|
| 57 |
-
else:
|
| 58 |
-
# set the threshold directly to outlier_n_faces_threshold
|
| 59 |
-
n_faces_threshold = outlier_n_faces_threshold
|
| 60 |
-
|
| 61 |
-
# remove the components with less than n_face_threshold faces
|
| 62 |
-
components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
|
| 63 |
-
|
| 64 |
-
# merge the components
|
| 65 |
-
mesh = trimesh.util.concatenate(components)
|
| 66 |
-
|
| 67 |
-
# convert back to our mesh format
|
| 68 |
-
v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
|
| 69 |
-
t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
|
| 70 |
-
|
| 71 |
-
clean_mesh = Mesh(v_pos, t_pos_idx)
|
| 72 |
-
# keep the extras unchanged
|
| 73 |
-
|
| 74 |
-
return clean_mesh
|
| 75 |
|
| 76 |
@property
|
| 77 |
def requires_grad(self):
|
|
@@ -245,8 +222,8 @@ class Mesh:
|
|
| 245 |
edges = torch.unique(edges, dim=0)
|
| 246 |
return edges
|
| 247 |
|
| 248 |
-
def normal_consistency(self)
|
| 249 |
-
edge_nrm
|
| 250 |
nc = (
|
| 251 |
1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
|
| 252 |
).mean()
|
|
@@ -279,10 +256,45 @@ class Mesh:
|
|
| 279 |
# correct diagonal
|
| 280 |
return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
|
| 281 |
|
| 282 |
-
def laplacian(self)
|
| 283 |
with torch.no_grad():
|
| 284 |
L = self._laplacian_uniform()
|
| 285 |
loss = L.mm(self.v_pos)
|
| 286 |
loss = loss.norm(dim=1)
|
| 287 |
loss = loss.mean()
|
| 288 |
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
import trimesh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def dot(x, y):
|
| 7 |
return torch.sum(x * y, -1, keepdim=True)
|
| 8 |
|
| 9 |
class Mesh:
|
| 10 |
def __init__(
|
| 11 |
+
self, v_pos, t_pos_idx, material=None
|
| 12 |
+
):
|
| 13 |
+
self.v_pos = v_pos
|
| 14 |
+
self.t_pos_idx = t_pos_idx
|
| 15 |
+
self.material = material
|
| 16 |
+
self._v_nrm = None
|
| 17 |
+
self._v_tng = None
|
| 18 |
+
self._v_tex = None
|
| 19 |
+
self._t_tex_idx = None
|
| 20 |
+
self._v_rgb = None
|
| 21 |
+
self._edges = None
|
| 22 |
+
self.extras = {}
|
|
|
|
| 23 |
|
| 24 |
def add_extra(self, k, v) -> None:
|
| 25 |
self.extras[k] = v
|
| 26 |
|
| 27 |
+
def remove_outlier(self, n_face_threshold=5):
|
| 28 |
+
"""Remove outlier components with fewer faces than threshold."""
|
| 29 |
+
# Convert to trimesh
|
| 30 |
+
trimesh_mesh = self.as_trimesh()
|
| 31 |
+
|
| 32 |
+
# Split into connected components
|
| 33 |
+
components = trimesh_mesh.split(only_watertight=False)
|
| 34 |
+
|
| 35 |
+
# Filter components with few faces
|
| 36 |
+
valid_components = [c for c in components if len(c.faces) > n_face_threshold]
|
| 37 |
+
|
| 38 |
+
if len(valid_components) == 0:
|
| 39 |
+
# If no valid components, return the original mesh
|
| 40 |
+
return self
|
| 41 |
+
|
| 42 |
+
# Combine valid components
|
| 43 |
+
combined = trimesh.util.concatenate(valid_components)
|
| 44 |
+
|
| 45 |
+
# Convert back to our Mesh format
|
| 46 |
+
new_mesh = Mesh(
|
| 47 |
+
torch.tensor(combined.vertices, dtype=self.v_pos.dtype, device=self.v_pos.device),
|
| 48 |
+
torch.tensor(combined.faces, dtype=self.t_pos_idx.dtype, device=self.t_pos_idx.device)
|
| 49 |
)
|
| 50 |
+
|
| 51 |
+
return new_mesh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
@property
|
| 54 |
def requires_grad(self):
|
|
|
|
| 222 |
edges = torch.unique(edges, dim=0)
|
| 223 |
return edges
|
| 224 |
|
| 225 |
+
def normal_consistency(self):
|
| 226 |
+
edge_nrm = self.v_nrm[self.edges]
|
| 227 |
nc = (
|
| 228 |
1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
|
| 229 |
).mean()
|
|
|
|
| 256 |
# correct diagonal
|
| 257 |
return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
|
| 258 |
|
| 259 |
+
def laplacian(self):
|
| 260 |
with torch.no_grad():
|
| 261 |
L = self._laplacian_uniform()
|
| 262 |
loss = L.mm(self.v_pos)
|
| 263 |
loss = loss.norm(dim=1)
|
| 264 |
loss = loss.mean()
|
| 265 |
return loss
|
| 266 |
+
|
| 267 |
+
def to(self, device):
|
| 268 |
+
v_pos = self.v_pos.to(device)
|
| 269 |
+
t_pos_idx = self.t_pos_idx.to(device)
|
| 270 |
+
return Mesh(v_pos, t_pos_idx)
|
| 271 |
+
|
| 272 |
+
def as_trimesh(self):
|
| 273 |
+
vertices = self.v_pos.detach().cpu().numpy()
|
| 274 |
+
faces = self.t_pos_idx.detach().cpu().numpy()
|
| 275 |
+
|
| 276 |
+
mesh = trimesh.Trimesh(
|
| 277 |
+
vertices=vertices,
|
| 278 |
+
faces=faces,
|
| 279 |
+
process=False
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Add texture if available
|
| 283 |
+
if hasattr(self, 'albedo_map') and self.albedo_map is not None:
|
| 284 |
+
# Create texture visuals
|
| 285 |
+
uv = self.v_tex.detach().cpu().numpy()
|
| 286 |
+
|
| 287 |
+
# Create texture visuals
|
| 288 |
+
visual = trimesh.visual.texture.TextureVisuals(
|
| 289 |
+
uv=uv,
|
| 290 |
+
material=trimesh.visual.material.SimpleMaterial()
|
| 291 |
+
)
|
| 292 |
+
mesh.visual = visual
|
| 293 |
+
|
| 294 |
+
return mesh
|
| 295 |
+
|
| 296 |
+
def scale_tensor(x, input_range, target_range):
|
| 297 |
+
"""Scale tensor from input_range to target_range."""
|
| 298 |
+
x_unit = (x - input_range[0]) / (input_range[1] - input_range[0])
|
| 299 |
+
x_scaled = x_unit * (target_range[1] - target_range[0]) + target_range[0]
|
| 300 |
+
return x_scaled
|
triplaneturbo_executable/utils/mesh_exporter.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
from typing import Callable, Dict, List, Optional, Tuple, Any
|
| 2 |
-
from jaxtyping import Float
|
| 3 |
-
from torch import Tensor
|
| 4 |
from dataclasses import dataclass
|
| 5 |
|
| 6 |
import torch
|
|
@@ -16,36 +13,35 @@ from ..utils.general_utils import scale_tensor
|
|
| 16 |
class ExporterOutput:
|
| 17 |
save_name: str
|
| 18 |
save_type: str
|
| 19 |
-
params:
|
| 20 |
|
| 21 |
|
| 22 |
class IsosurfaceHelper(nn.Module):
|
| 23 |
-
points_range
|
| 24 |
|
| 25 |
@property
|
| 26 |
-
def grid_vertices(self)
|
| 27 |
raise NotImplementedError
|
| 28 |
|
| 29 |
class DiffMarchingCubeHelper(IsosurfaceHelper):
|
| 30 |
def __init__(
|
| 31 |
self,
|
| 32 |
-
resolution
|
| 33 |
-
point_range
|
| 34 |
-
)
|
| 35 |
super().__init__()
|
| 36 |
self.resolution = resolution
|
| 37 |
self.points_range = point_range
|
| 38 |
|
| 39 |
from diso import DiffMC
|
| 40 |
-
self.mc_func
|
| 41 |
-
self._grid_vertices
|
| 42 |
-
self._dummy: Float[Tensor, "..."]
|
| 43 |
self.register_buffer(
|
| 44 |
"_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
|
| 45 |
)
|
| 46 |
|
| 47 |
@property
|
| 48 |
-
def grid_vertices(self)
|
| 49 |
if self._grid_vertices is None:
|
| 50 |
# keep the vertices on CPU so that we can support very large resolution
|
| 51 |
x, y, z = (
|
|
@@ -62,10 +58,10 @@ class DiffMarchingCubeHelper(IsosurfaceHelper):
|
|
| 62 |
|
| 63 |
def forward(
|
| 64 |
self,
|
| 65 |
-
level
|
| 66 |
-
deformation
|
| 67 |
isovalue=0.0,
|
| 68 |
-
)
|
| 69 |
level = level.view(self.resolution, self.resolution, self.resolution)
|
| 70 |
if deformation is not None:
|
| 71 |
deformation = deformation.view(self.resolution, self.resolution, self.resolution, 3)
|
|
@@ -76,17 +72,17 @@ class DiffMarchingCubeHelper(IsosurfaceHelper):
|
|
| 76 |
|
| 77 |
|
| 78 |
def isosurface(
|
| 79 |
-
space_cache
|
| 80 |
-
forward_field
|
| 81 |
-
isosurface_helper
|
| 82 |
-
)
|
| 83 |
|
| 84 |
# the isosurface is dependent on the space cache
|
| 85 |
# randomly detach isosurface method if it is differentiable
|
| 86 |
# get the batchsize
|
| 87 |
if torch.is_tensor(space_cache): #space cache
|
| 88 |
batch_size = space_cache.shape[0]
|
| 89 |
-
elif isinstance(space_cache,
|
| 90 |
# Dict[str, List[Float[Tensor, "B ..."]]]
|
| 91 |
for key in space_cache.keys():
|
| 92 |
batch_size = space_cache[key][0].shape[0]
|
|
@@ -141,11 +137,11 @@ def isosurface(
|
|
| 141 |
return mesh_list
|
| 142 |
|
| 143 |
def colorize_mesh(
|
| 144 |
-
space_cache
|
| 145 |
-
export_fn
|
| 146 |
-
mesh_list
|
| 147 |
-
activation
|
| 148 |
-
)
|
| 149 |
"""Colorize the mesh using the geometry's export function and space cache.
|
| 150 |
|
| 151 |
Args:
|
|
@@ -199,10 +195,10 @@ class MeshExporter(SaverMixin):
|
|
| 199 |
return x
|
| 200 |
|
| 201 |
def export_obj(
|
| 202 |
-
mesh
|
| 203 |
-
save_path
|
| 204 |
-
save_normal
|
| 205 |
-
)
|
| 206 |
"""
|
| 207 |
Export mesh data to OBJ file format.
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 13 |
class ExporterOutput:
|
| 14 |
save_name: str
|
| 15 |
save_type: str
|
| 16 |
+
params: dict
|
| 17 |
|
| 18 |
|
| 19 |
class IsosurfaceHelper(nn.Module):
|
| 20 |
+
points_range = (0, 1)
|
| 21 |
|
| 22 |
@property
|
| 23 |
+
def grid_vertices(self):
|
| 24 |
raise NotImplementedError
|
| 25 |
|
| 26 |
class DiffMarchingCubeHelper(IsosurfaceHelper):
|
| 27 |
def __init__(
|
| 28 |
self,
|
| 29 |
+
resolution,
|
| 30 |
+
point_range = (0, 1)
|
| 31 |
+
):
|
| 32 |
super().__init__()
|
| 33 |
self.resolution = resolution
|
| 34 |
self.points_range = point_range
|
| 35 |
|
| 36 |
from diso import DiffMC
|
| 37 |
+
self.mc_func = DiffMC(dtype=torch.float32)
|
| 38 |
+
self._grid_vertices = None
|
|
|
|
| 39 |
self.register_buffer(
|
| 40 |
"_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
|
| 41 |
)
|
| 42 |
|
| 43 |
@property
|
| 44 |
+
def grid_vertices(self):
|
| 45 |
if self._grid_vertices is None:
|
| 46 |
# keep the vertices on CPU so that we can support very large resolution
|
| 47 |
x, y, z = (
|
|
|
|
| 58 |
|
| 59 |
def forward(
|
| 60 |
self,
|
| 61 |
+
level,
|
| 62 |
+
deformation = None,
|
| 63 |
isovalue=0.0,
|
| 64 |
+
):
|
| 65 |
level = level.view(self.resolution, self.resolution, self.resolution)
|
| 66 |
if deformation is not None:
|
| 67 |
deformation = deformation.view(self.resolution, self.resolution, self.resolution, 3)
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def isosurface(
|
| 75 |
+
space_cache,
|
| 76 |
+
forward_field,
|
| 77 |
+
isosurface_helper,
|
| 78 |
+
):
|
| 79 |
|
| 80 |
# the isosurface is dependent on the space cache
|
| 81 |
# randomly detach isosurface method if it is differentiable
|
| 82 |
# get the batchsize
|
| 83 |
if torch.is_tensor(space_cache): #space cache
|
| 84 |
batch_size = space_cache.shape[0]
|
| 85 |
+
elif isinstance(space_cache, dict): #hyper net
|
| 86 |
# Dict[str, List[Float[Tensor, "B ..."]]]
|
| 87 |
for key in space_cache.keys():
|
| 88 |
batch_size = space_cache[key][0].shape[0]
|
|
|
|
| 137 |
return mesh_list
|
| 138 |
|
| 139 |
def colorize_mesh(
|
| 140 |
+
space_cache,
|
| 141 |
+
export_fn,
|
| 142 |
+
mesh_list,
|
| 143 |
+
activation,
|
| 144 |
+
):
|
| 145 |
"""Colorize the mesh using the geometry's export function and space cache.
|
| 146 |
|
| 147 |
Args:
|
|
|
|
| 195 |
return x
|
| 196 |
|
| 197 |
def export_obj(
|
| 198 |
+
mesh,
|
| 199 |
+
save_path,
|
| 200 |
+
save_normal = False,
|
| 201 |
+
):
|
| 202 |
"""
|
| 203 |
Export mesh data to OBJ file format.
|
| 204 |
|
triplaneturbo_executable/utils/saving.py
CHANGED
|
@@ -13,22 +13,15 @@ import wandb
|
|
| 13 |
from matplotlib import cm
|
| 14 |
from matplotlib.colors import LinearSegmentedColormap
|
| 15 |
from PIL import Image, ImageDraw
|
| 16 |
-
# from pytorch_lightning.loggers import WandbLogger
|
| 17 |
|
| 18 |
-
from ..utils.mesh import Mesh
|
| 19 |
-
|
| 20 |
-
from typing import Dict, List, Optional, Union, Any
|
| 21 |
-
from omegaconf import DictConfig
|
| 22 |
-
from jaxtyping import Float
|
| 23 |
-
from torch import Tensor
|
| 24 |
|
| 25 |
import threading
|
| 26 |
|
| 27 |
class SaverMixin:
|
| 28 |
-
_save_dir
|
| 29 |
-
# _wandb_logger
|
| 30 |
|
| 31 |
-
def set_save_dir(self, save_dir
|
| 32 |
self._save_dir = save_dir
|
| 33 |
|
| 34 |
def get_save_dir(self):
|
|
@@ -58,17 +51,6 @@ class SaverMixin:
|
|
| 58 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 59 |
return save_path
|
| 60 |
|
| 61 |
-
# def create_loggers(self, cfg_loggers: DictConfig) -> None:
|
| 62 |
-
# if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable:
|
| 63 |
-
# self._wandb_logger = WandbLogger(
|
| 64 |
-
# project=cfg_loggers.wandb.project, name=cfg_loggers.wandb.name
|
| 65 |
-
# )
|
| 66 |
-
|
| 67 |
-
# def get_loggers(self) -> List:
|
| 68 |
-
# if self._wandb_logger:
|
| 69 |
-
# return [self._wandb_logger]
|
| 70 |
-
# else:
|
| 71 |
-
# return []
|
| 72 |
|
| 73 |
DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
|
| 74 |
DEFAULT_UV_KWARGS = {
|
|
@@ -119,8 +101,8 @@ class SaverMixin:
|
|
| 119 |
img,
|
| 120 |
data_format,
|
| 121 |
data_range,
|
| 122 |
-
name
|
| 123 |
-
step
|
| 124 |
):
|
| 125 |
img = self.get_rgb_image_(img, data_format, data_range)
|
| 126 |
cv2.imwrite(filename, img)
|
|
@@ -138,8 +120,8 @@ class SaverMixin:
|
|
| 138 |
img,
|
| 139 |
data_format=DEFAULT_RGB_KWARGS["data_format"],
|
| 140 |
data_range=DEFAULT_RGB_KWARGS["data_range"],
|
| 141 |
-
name
|
| 142 |
-
step
|
| 143 |
) -> str:
|
| 144 |
save_path = self.get_save_path(filename)
|
| 145 |
self._save_rgb_image(save_path, img, data_format, data_range, name, step)
|
|
@@ -231,8 +213,8 @@ class SaverMixin:
|
|
| 231 |
img,
|
| 232 |
data_range,
|
| 233 |
cmap,
|
| 234 |
-
name
|
| 235 |
-
step
|
| 236 |
):
|
| 237 |
img = self.get_grayscale_image_(img, data_range, cmap)
|
| 238 |
cv2.imwrite(filename, img)
|
|
@@ -250,8 +232,8 @@ class SaverMixin:
|
|
| 250 |
img,
|
| 251 |
data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
|
| 252 |
cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
|
| 253 |
-
name
|
| 254 |
-
step
|
| 255 |
) -> str:
|
| 256 |
save_path = self.get_save_path(filename)
|
| 257 |
self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
|
|
@@ -308,9 +290,9 @@ class SaverMixin:
|
|
| 308 |
filename,
|
| 309 |
imgs,
|
| 310 |
align=DEFAULT_GRID_KWARGS["align"],
|
| 311 |
-
name
|
| 312 |
-
step
|
| 313 |
-
texts
|
| 314 |
):
|
| 315 |
save_path = self.get_save_path(filename)
|
| 316 |
img = self.get_image_grid_(imgs, align=align)
|
|
@@ -404,8 +386,8 @@ class SaverMixin:
|
|
| 404 |
# matcher,
|
| 405 |
# save_format="mp4",
|
| 406 |
# fps=30,
|
| 407 |
-
# name
|
| 408 |
-
# step
|
| 409 |
# ) -> str:
|
| 410 |
# assert save_format in ["gif", "mp4"]
|
| 411 |
# if not filename.endswith(save_format):
|
|
@@ -442,9 +424,9 @@ class SaverMixin:
|
|
| 442 |
matcher,
|
| 443 |
save_format="mp4",
|
| 444 |
fps=30,
|
| 445 |
-
name
|
| 446 |
-
step
|
| 447 |
-
multithreaded
|
| 448 |
) -> str:
|
| 449 |
assert save_format in ["gif", "mp4"]
|
| 450 |
if not filename.endswith(save_format):
|
|
@@ -494,20 +476,19 @@ class SaverMixin:
|
|
| 494 |
|
| 495 |
def save_obj(
|
| 496 |
self,
|
| 497 |
-
filename
|
| 498 |
-
mesh
|
| 499 |
-
save_mat
|
| 500 |
-
save_normal
|
| 501 |
-
save_uv
|
| 502 |
-
save_vertex_color
|
| 503 |
-
map_Kd
|
| 504 |
-
map_Ks
|
| 505 |
-
map_Bump
|
| 506 |
-
map_Pm
|
| 507 |
-
map_Pr
|
| 508 |
-
map_format
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
if not filename.endswith(".obj"):
|
| 512 |
filename += ".obj"
|
| 513 |
save_path = self.get_save_path(filename)
|
|
@@ -658,8 +639,8 @@ class SaverMixin:
|
|
| 658 |
map_Pm=None,
|
| 659 |
map_Pr=None,
|
| 660 |
map_format="jpg",
|
| 661 |
-
step
|
| 662 |
-
)
|
| 663 |
mtl_save_path = self.get_save_path(filename)
|
| 664 |
save_paths = [mtl_save_path]
|
| 665 |
mtl_str = f"newmtl {matname}\n"
|
|
|
|
| 13 |
from matplotlib import cm
|
| 14 |
from matplotlib.colors import LinearSegmentedColormap
|
| 15 |
from PIL import Image, ImageDraw
|
|
|
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
import threading
|
| 19 |
|
| 20 |
class SaverMixin:
|
| 21 |
+
_save_dir = None
|
| 22 |
+
# _wandb_logger = None
|
| 23 |
|
| 24 |
+
def set_save_dir(self, save_dir):
|
| 25 |
self._save_dir = save_dir
|
| 26 |
|
| 27 |
def get_save_dir(self):
|
|
|
|
| 51 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 52 |
return save_path
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
|
| 56 |
DEFAULT_UV_KWARGS = {
|
|
|
|
| 101 |
img,
|
| 102 |
data_format,
|
| 103 |
data_range,
|
| 104 |
+
name=None,
|
| 105 |
+
step=None,
|
| 106 |
):
|
| 107 |
img = self.get_rgb_image_(img, data_format, data_range)
|
| 108 |
cv2.imwrite(filename, img)
|
|
|
|
| 120 |
img,
|
| 121 |
data_format=DEFAULT_RGB_KWARGS["data_format"],
|
| 122 |
data_range=DEFAULT_RGB_KWARGS["data_range"],
|
| 123 |
+
name=None,
|
| 124 |
+
step=None,
|
| 125 |
) -> str:
|
| 126 |
save_path = self.get_save_path(filename)
|
| 127 |
self._save_rgb_image(save_path, img, data_format, data_range, name, step)
|
|
|
|
| 213 |
img,
|
| 214 |
data_range,
|
| 215 |
cmap,
|
| 216 |
+
name=None,
|
| 217 |
+
step=None,
|
| 218 |
):
|
| 219 |
img = self.get_grayscale_image_(img, data_range, cmap)
|
| 220 |
cv2.imwrite(filename, img)
|
|
|
|
| 232 |
img,
|
| 233 |
data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
|
| 234 |
cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
|
| 235 |
+
name=None,
|
| 236 |
+
step=None,
|
| 237 |
) -> str:
|
| 238 |
save_path = self.get_save_path(filename)
|
| 239 |
self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
|
|
|
|
| 290 |
filename,
|
| 291 |
imgs,
|
| 292 |
align=DEFAULT_GRID_KWARGS["align"],
|
| 293 |
+
name=None,
|
| 294 |
+
step=None,
|
| 295 |
+
texts=None,
|
| 296 |
):
|
| 297 |
save_path = self.get_save_path(filename)
|
| 298 |
img = self.get_image_grid_(imgs, align=align)
|
|
|
|
| 386 |
# matcher,
|
| 387 |
# save_format="mp4",
|
| 388 |
# fps=30,
|
| 389 |
+
# name=None,
|
| 390 |
+
# step=None,
|
| 391 |
# ) -> str:
|
| 392 |
# assert save_format in ["gif", "mp4"]
|
| 393 |
# if not filename.endswith(save_format):
|
|
|
|
| 424 |
matcher,
|
| 425 |
save_format="mp4",
|
| 426 |
fps=30,
|
| 427 |
+
name=None,
|
| 428 |
+
step=None,
|
| 429 |
+
multithreaded=False
|
| 430 |
) -> str:
|
| 431 |
assert save_format in ["gif", "mp4"]
|
| 432 |
if not filename.endswith(save_format):
|
|
|
|
| 476 |
|
| 477 |
def save_obj(
|
| 478 |
self,
|
| 479 |
+
filename,
|
| 480 |
+
mesh,
|
| 481 |
+
save_mat=False,
|
| 482 |
+
save_normal=False,
|
| 483 |
+
save_uv=False,
|
| 484 |
+
save_vertex_color=False,
|
| 485 |
+
map_Kd=None,
|
| 486 |
+
map_Ks=None,
|
| 487 |
+
map_Bump=None,
|
| 488 |
+
map_Pm=None,
|
| 489 |
+
map_Pr=None,
|
| 490 |
+
map_format="jpg",
|
| 491 |
+
):
|
|
|
|
| 492 |
if not filename.endswith(".obj"):
|
| 493 |
filename += ".obj"
|
| 494 |
save_path = self.get_save_path(filename)
|
|
|
|
| 639 |
map_Pm=None,
|
| 640 |
map_Pr=None,
|
| 641 |
map_format="jpg",
|
| 642 |
+
step=None,
|
| 643 |
+
):
|
| 644 |
mtl_save_path = self.get_save_path(filename)
|
| 645 |
save_paths = [mtl_save_path]
|
| 646 |
mtl_str = f"newmtl {matname}\n"
|