xinjie.wang committed
Commit 517c236 · 0 Parent(s)
Initial clean commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: EmbodiedGen Gallery Explorer
3
+ emoji: 🐨
4
+ colorFrom: green
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: 🏛️ EmbodiedGen 3D Asset Gallery Explorer
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,734 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+
19
+ gradio_tmp_dir = os.path.join(
20
+ os.path.dirname(os.path.abspath(__file__)), "gradio_cache"
21
+ )
22
+ os.makedirs(gradio_tmp_dir, exist_ok=True)
23
+ os.environ["GRADIO_TEMP_DIR"] = gradio_tmp_dir
24
+
25
+ import shutil
26
+ import uuid
27
+ import xml.etree.ElementTree as ET
28
+ from pathlib import Path
29
+ from typing import Tuple
30
+
31
+ import gradio as gr
32
+ import pandas as pd
33
+ import yaml
34
+ from app_style import custom_theme  # lighting_css is defined locally below
35
+
36
+ try:
37
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client
38
+
39
+ gpt_client.check_connection()
40
+ GPT_AVAILABLE = True
41
+ except Exception as e:
42
+ gpt_client = None
43
+ GPT_AVAILABLE = False
44
+ print(
45
+ f"Warning: GPT client could not be initialized. Search will be disabled. Error: {e}"
46
+ )
47
+
48
+
49
+ # --- Configuration & Data Loading ---
50
+ VERSION = "v0.1.5"
51
+ RUNNING_MODE = "hf_remote" # local or hf_remote
52
+ CSV_FILE = "dataset_index.csv"
53
+ import spaces
54
+ @spaces.GPU
55
+ def fake_gpu_init():
56
+ pass
57
+ fake_gpu_init()
58
+
59
+ if RUNNING_MODE == "local":
60
+ DATA_ROOT = "/horizon-bucket/robot_lab/datasets/embodiedgen/assets"
61
+ elif RUNNING_MODE == "hf_remote":
62
+ from huggingface_hub import snapshot_download
63
+
64
+ snapshot_download(
65
+ repo_id="HorizonRobotics/EmbodiedGenData",
66
+ repo_type="dataset",
67
+ allow_patterns="dataset/**",
68
+ local_dir="EmbodiedGenData",
69
+ local_dir_use_symlinks=False,
70
+ )
71
+ DATA_ROOT = "EmbodiedGenData/dataset"
72
+ else:
73
+ raise ValueError(
74
+ f"Unknown RUNNING_MODE: {RUNNING_MODE}, must be 'local' or 'hf_remote'."
75
+ )
76
+
77
+ csv_path = os.path.join(DATA_ROOT, CSV_FILE)
78
+ df = pd.read_csv(csv_path)
79
+ TMP_DIR = os.path.join(
80
+ os.path.dirname(os.path.abspath(__file__)), "sessions/asset_viewer"
81
+ )
82
+ os.makedirs(TMP_DIR, exist_ok=True)
83
+
84
+
85
+ # --- Custom CSS for Styling ---
86
+ css = """
87
+ .gradio-container .gradio-group { box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important; }
88
+ #asset-gallery { border: 1px solid #E5E7EB; border-radius: 8px; padding: 8px; background-color: #F9FAFB; }
89
+ """
90
+
91
+ lighting_css = """
92
+ <style>
93
+ #visual_mesh canvas { filter: brightness(2.2) !important; }
94
+ #collision_mesh_a canvas, #collision_mesh_b canvas { filter: brightness(1.0) !important; }
95
+ </style>
96
+ """
97
+
98
+ _prev_temp = {}
99
+
100
+
101
+ def _unique_path(
102
+ src_path: str | None, session_hash: str, kind: str
103
+ ) -> str | None:
104
+ """Link/copy src to GRADIO_TEMP_DIR/session_hash with random filename. Always return a fresh URL."""
105
+ if not src_path:
106
+ return None
107
+ tmp_root = (
108
+ Path(os.environ.get("GRADIO_TEMP_DIR", "/tmp"))
109
+ / "model3d-cache"
110
+ / session_hash
111
+ )
112
+ tmp_root.mkdir(parents=True, exist_ok=True)
113
+
114
+ # rolling cleanup for same kind
115
+ prev = _prev_temp.get(session_hash, {})
116
+ old = prev.get(kind)
117
+ if old and old.exists():
118
+ old.unlink()
119
+
120
+ ext = Path(src_path).suffix or ".glb"
121
+ dst = tmp_root / f"{kind}-{uuid.uuid4().hex}{ext}"
122
+ shutil.copy2(src_path, dst)
123
+
124
+ prev[kind] = dst
125
+ _prev_temp[session_hash] = prev
126
+ return str(dst)
127
+
128
+
129
+ # --- Helper Functions (data filtering) ---
130
+ def get_primary_categories():
131
+ return sorted(df["primary_category"].dropna().unique())
132
+
133
+
134
+ def get_secondary_categories(primary):
135
+ if not primary:
136
+ return []
137
+ return sorted(
138
+ df[df["primary_category"] == primary]["secondary_category"]
139
+ .dropna()
140
+ .unique()
141
+ )
142
+
143
+
144
+ def get_categories(primary, secondary):
145
+ if not primary or not secondary:
146
+ return []
147
+ return sorted(
148
+ df[
149
+ (df["primary_category"] == primary)
150
+ & (df["secondary_category"] == secondary)
151
+ ]["category"]
152
+ .dropna()
153
+ .unique()
154
+ )
155
+
156
+
157
+ def get_assets(primary, secondary, category):
158
+ if not primary or not secondary:
159
+ return [], gr.update(interactive=False), pd.DataFrame()
160
+
161
+ subset = df[
162
+ (df["primary_category"] == primary)
163
+ & (df["secondary_category"] == secondary)
164
+ ]
165
+ if category:
166
+ subset = subset[subset["category"] == category]
167
+
168
+ items = []
169
+ for row in subset.itertuples():
170
+ asset_dir = os.path.join(DATA_ROOT, row.asset_dir)
171
+ video_path = None
172
+ if pd.notna(row.asset_dir) and os.path.exists(asset_dir):
173
+ for f in os.listdir(asset_dir):
174
+ if f.lower().endswith(".mp4"):
175
+ video_path = os.path.join(asset_dir, f)
176
+ break
177
+ items.append(
178
+ video_path
179
+ if video_path
180
+ else "https://dummyimage.com/512x512/cccccc/000000&text=No+Preview"
181
+ )
182
+
183
+ return items, gr.update(interactive=True), subset
184
+
185
+
186
+ def search_assets(query: str, top_k: int):
187
+ if not GPT_AVAILABLE or not query:
188
+ gr.Warning(
189
+ "GPT client is not available or query is empty. Cannot perform search."
190
+ )
191
+ return [], gr.update(interactive=False), pd.DataFrame()
192
+
193
+ gr.Info(f"Searching for assets matching: '{query}'...")
194
+
195
+ keywords = query.split()
196
+ keyword_filter = pd.Series([False] * len(df), index=df.index)
197
+ for keyword in keywords:
198
+ keyword_filter |= df['description'].str.contains(
199
+ keyword, case=False, na=False
200
+ )
201
+
202
+ candidates = df[keyword_filter]
203
+
204
+ if len(candidates) > 100:
205
+ candidates = candidates.head(100)
206
+
207
+ if candidates.empty:
208
+ gr.Warning("No assets found matching the keywords.")
209
+ return [], gr.update(interactive=True), pd.DataFrame()
210
+
211
+ try:
212
+ descriptions = [
213
+ f"{idx}: {desc}" for idx, desc in candidates['description'].items()
214
+ ]
215
+ descriptions_text = "\n".join(descriptions)
216
+
217
+ prompt = f"""
218
+ A user is searching for 3D assets with the query: "{query}".
219
+ Below is a list of available assets, each with an ID and a description.
220
+ Please evaluate how well each asset description matches the user's query and rate them on a scale from 0 to 10, where 10 is a perfect match.
221
+
222
+ Your task is to return a list of the top {top_k} asset IDs, sorted from the most relevant to the least relevant.
223
+ The output format must be a simple comma-separated list of IDs, for example: "123,45,678". Do not add any other text.
224
+
225
+ Asset Descriptions:
226
+ {descriptions_text}
227
+
228
+ User Query: "{query}"
229
+
230
+ Top {top_k} sorted asset IDs:
231
+ """
232
+ response = gpt_client.query(prompt)
233
+ sorted_ids_str = response.strip().split(',')
234
+ sorted_ids = [
235
+ int(id_str.strip())
236
+ for id_str in sorted_ids_str
237
+ if id_str.strip().isdigit()
238
+ ]
239
+ top_assets = df.loc[[i for i in sorted_ids if i in df.index]].head(top_k)  # drop hallucinated IDs
240
+ except Exception as e:
241
+ gr.Warning(f"GPT ranking failed, falling back to keyword matches: {e}")
242
+ top_assets = candidates.head(top_k)
243
+
244
+ items = []
245
+ for row in top_assets.itertuples():
246
+ asset_dir = os.path.join(DATA_ROOT, row.asset_dir)
247
+ video_path = None
248
+ if pd.notna(row.asset_dir) and os.path.exists(asset_dir):
249
+ for f in os.listdir(asset_dir):
250
+ if f.lower().endswith(".mp4"):
251
+ video_path = os.path.join(asset_dir, f)
252
+ break
253
+ items.append(
254
+ video_path
255
+ if video_path
256
+ else "https://dummyimage.com/512x512/cccccc/000000&text=No+Preview"
257
+ )
258
+
259
+ gr.Info(f"Found {len(items)} assets.")
260
+ return items, gr.update(interactive=True), top_assets
261
+
262
+
263
+ # --- Mesh extraction ---
264
+ def _extract_mesh_paths(row) -> Tuple[str | None, str | None, str]:
265
+ desc = row["description"]
266
+ urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
267
+ asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
268
+ visual_mesh_path = None
269
+ collision_mesh_path = None
270
+
271
+ if pd.notna(row["urdf_path"]) and os.path.exists(urdf_path):
272
+ try:
273
+ tree = ET.parse(urdf_path)
274
+ root = tree.getroot()
275
+
276
+ visual_mesh_element = root.find('.//visual/geometry/mesh')
277
+ if visual_mesh_element is not None:
278
+ visual_mesh_filename = visual_mesh_element.get('filename')
279
+ if visual_mesh_filename:
280
+ glb_filename = (
281
+ os.path.splitext(visual_mesh_filename)[0] + ".glb"
282
+ )
283
+ potential_path = os.path.join(asset_dir, glb_filename)
284
+ if os.path.exists(potential_path):
285
+ visual_mesh_path = potential_path
286
+
287
+ collision_mesh_element = root.find('.//collision/geometry/mesh')
288
+ if collision_mesh_element is not None:
289
+ collision_mesh_filename = collision_mesh_element.get(
290
+ 'filename'
291
+ )
292
+ if collision_mesh_filename:
293
+ potential_collision_path = os.path.join(
294
+ asset_dir, collision_mesh_filename
295
+ )
296
+ if os.path.exists(potential_collision_path):
297
+ collision_mesh_path = potential_collision_path
298
+
299
+ except ET.ParseError:
300
+ desc = f"Error: Failed to parse URDF at {urdf_path}. {desc}"
301
+ except Exception as e:
302
+ desc = f"An error occurred while processing URDF: {str(e)}. {desc}"
303
+
304
+ return visual_mesh_path, collision_mesh_path, desc
305
+
306
+
307
+ def show_asset_from_gallery(
308
+ evt: gr.SelectData,
309
+ primary: str,
310
+ secondary: str,
311
+ category: str,
312
+ search_query: str,
313
+ gallery_df: pd.DataFrame,
314
+ ):
315
+ """Parse the selected asset and return raw paths + metadata."""
316
+ index = evt.index
317
+
318
+ if search_query and gallery_df is not None and not gallery_df.empty:
319
+ subset = gallery_df
320
+ else:
321
+ if not primary or not secondary:
322
+ return (
323
+ None, # visual_path
324
+ None, # collision_path
325
+ "Error: Primary or secondary category not selected.",
326
+ None, # asset_dir
327
+ None, # urdf_path
328
+ "N/A",
329
+ "N/A",
330
+ "N/A",
331
+ "N/A",
332
+ )
333
+
334
+ subset = df[
335
+ (df["primary_category"] == primary)
336
+ & (df["secondary_category"] == secondary)
337
+ ]
338
+ if category:
339
+ subset = subset[subset["category"] == category]
340
+
341
+ if subset.empty or index >= len(subset):
342
+ return (
343
+ None,
344
+ None,
345
+ "Error: Selection index is out of bounds or data is missing.",
346
+ None,
347
+ None,
348
+ "N/A",
349
+ "N/A",
350
+ "N/A",
351
+ "N/A",
352
+ )
353
+
354
+ row = subset.iloc[index]
355
+ visual_path, collision_path, desc = _extract_mesh_paths(row)
356
+
357
+ urdf_path = os.path.join(DATA_ROOT, row["urdf_path"])
358
+ asset_dir = os.path.join(DATA_ROOT, row["asset_dir"])
359
+
360
+ # read extra info
361
+ est_type_text = "N/A"
362
+ est_height_text = "N/A"
363
+ est_mass_text = "N/A"
364
+ est_mu_text = "N/A"
365
+
366
+ if pd.notna(row["urdf_path"]) and os.path.exists(urdf_path):
367
+ try:
368
+ tree = ET.parse(urdf_path)
369
+ root = tree.getroot()
370
+ category_elem = root.find('.//extra_info/category')
371
+ if category_elem is not None and category_elem.text:
372
+ est_type_text = category_elem.text.strip()
373
+ height_elem = root.find('.//extra_info/real_height')
374
+ if height_elem is not None and height_elem.text:
375
+ est_height_text = height_elem.text.strip()
376
+ mass_elem = root.find('.//extra_info/min_mass')
377
+ if mass_elem is not None and mass_elem.text:
378
+ est_mass_text = mass_elem.text.strip()
379
+ mu_elem = root.find('.//collision/gazebo/mu2')
380
+ if mu_elem is not None and mu_elem.text:
381
+ est_mu_text = mu_elem.text.strip()
382
+ except Exception:
383
+ pass
384
+
385
+ return (
386
+ visual_path,
387
+ collision_path,
388
+ desc,
389
+ asset_dir,
390
+ urdf_path,
391
+ est_type_text,
392
+ est_height_text,
393
+ est_mass_text,
394
+ est_mu_text,
395
+ )
396
+
397
+
398
+ def render_meshes(
399
+ visual_path: str | None,
400
+ collision_path: str | None,
401
+ switch_viewer: bool,
402
+ req: gr.Request,
403
+ ):
404
+ session_hash = getattr(req, "session_hash", "default")
405
+
406
+ if switch_viewer:
407
+ yield (
408
+ gr.update(value=None),
409
+ gr.update(value=None, visible=False),
410
+ gr.update(value=None, visible=True),
411
+ True,
412
+ )
413
+ else:
414
+ yield (
415
+ gr.update(value=None),
416
+ gr.update(value=None, visible=True),
417
+ gr.update(value=None, visible=False),
418
+ True,
419
+ )
420
+
421
+ visual_unique = (
422
+ _unique_path(visual_path, session_hash, "visual")
423
+ if visual_path
424
+ else None
425
+ )
426
+ collision_unique = (
427
+ _unique_path(collision_path, session_hash, "collision")
428
+ if collision_path
429
+ else None
430
+ )
431
+
432
+ if switch_viewer:
433
+ yield (
434
+ gr.update(value=visual_unique),
435
+ gr.update(value=None, visible=False),
436
+ gr.update(value=collision_unique, visible=True),
437
+ False,
438
+ )
439
+ else:
440
+ yield (
441
+ gr.update(value=visual_unique),
442
+ gr.update(value=collision_unique, visible=True),
443
+ gr.update(value=None, visible=False),
444
+ True,
445
+ )
446
+
447
+
448
+ def create_asset_zip(asset_dir: str, req: gr.Request):
449
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
450
+ os.makedirs(user_dir, exist_ok=True)
451
+
452
+ asset_folder_name = os.path.basename(os.path.normpath(asset_dir))
453
+ zip_path_base = os.path.join(user_dir, asset_folder_name)
454
+
455
+ archive_path = shutil.make_archive(
456
+ base_name=zip_path_base, format='zip', root_dir=asset_dir
457
+ )
458
+ gr.Info(f"✅ {asset_folder_name}.zip is ready and can be downloaded.")
459
+
460
+ return archive_path
461
+
462
+
463
+ def start_session(req: gr.Request) -> None:
464
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
465
+ os.makedirs(user_dir, exist_ok=True)
466
+
467
+
468
+ def end_session(req: gr.Request) -> None:
469
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
470
+ if os.path.exists(user_dir):
471
+ shutil.rmtree(user_dir)
472
+
473
+
474
+ # --- UI ---
475
+ with gr.Blocks(
476
+ theme=custom_theme,
477
+ css=css,
478
+ title="3D Asset Library",
479
+ ) as demo:
480
+ # gr.HTML(lighting_css, visible=False)
481
+ gr.Markdown(
482
+ """
483
+ ## 🏛️ ***EmbodiedGen***: 3D Asset Gallery Explorer
484
+
485
+ **Version**: {VERSION}
486
+ <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
487
+ <a href="https://horizonrobotics.github.io/EmbodiedGen">
488
+ <img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
489
+ </a>
490
+ <a href="https://arxiv.org/abs/2506.10600">
491
+ <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
492
+ </a>
493
+ <a href="https://github.com/HorizonRobotics/EmbodiedGen">
494
+ <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
495
+ </a>
496
+ <a href="https://www.youtube.com/watch?v=rG4odybuJRk">
497
+ <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
498
+ </a>
499
+ </p>
500
+
501
+ Browse and visualize the EmbodiedGen 3D asset database. Select categories to filter and click on a preview to load the model.
502
+
503
+ """.format(
504
+ VERSION=VERSION
505
+ ),
506
+ elem_classes=["header"],
507
+ )
508
+
509
+ primary_list = get_primary_categories()
510
+ primary_val = primary_list[0] if primary_list else None
511
+ secondary_list = get_secondary_categories(primary_val)
512
+ secondary_val = secondary_list[0] if secondary_list else None
513
+ category_list = get_categories(primary_val, secondary_val)
514
+ category_val = category_list[0] if category_list else None
515
+ asset_folder = gr.State(value=None)
516
+ gallery_df_state = gr.State()
517
+
518
+ switch_viewer_state = gr.State(value=False)
519
+
520
+ with gr.Row(equal_height=False):
521
+ with gr.Column(scale=1, min_width=350):
522
+ with gr.Group():
523
+ gr.Markdown("### Search Asset with Descriptions")
524
+ search_box = gr.Textbox(
525
+ label="🔎 Enter your search query",
526
+ placeholder="e.g., 'a red chair with four legs'",
527
+ interactive=GPT_AVAILABLE,
528
+ )
529
+ top_k_slider = gr.Slider(
530
+ minimum=1,
531
+ maximum=50,
532
+ value=10,
533
+ step=1,
534
+ label="Number of results",
535
+ interactive=GPT_AVAILABLE,
536
+ )
537
+ search_button = gr.Button(
538
+ "Search", variant="primary", interactive=GPT_AVAILABLE
539
+ )
540
+ if not GPT_AVAILABLE:
541
+ gr.Markdown(
542
+ "<p style='color: #ff4b4b;'>⚠️ GPT client not available. Search is disabled.</p>"
543
+ )
544
+
545
+ with gr.Group():
546
+ gr.Markdown("### Select Asset Category")
547
+ primary = gr.Dropdown(
548
+ choices=primary_list,
549
+ value=primary_val,
550
+ label="🗂️ Primary Category",
551
+ )
552
+ secondary = gr.Dropdown(
553
+ choices=secondary_list,
554
+ value=secondary_val,
555
+ label="📂 Secondary Category",
556
+ )
557
+ category = gr.Dropdown(
558
+ choices=category_list,
559
+ value=category_val,
560
+ label="🏷️ Asset Category",
561
+ )
562
+
563
+ with gr.Group():
564
+ initial_assets, _, initial_df = get_assets(
565
+ primary_val, secondary_val, category_val
566
+ )
567
+ gallery = gr.Gallery(
568
+ value=initial_assets,
569
+ label="🖼️ Asset Previews",
570
+ columns=3,
571
+ height="auto",
572
+ allow_preview=True,
573
+ elem_id="asset-gallery",
574
+ interactive=bool(category_val),
575
+ )
576
+
577
+ with gr.Column(scale=2, min_width=500):
578
+ with gr.Group():
579
+ with gr.Tabs():
580
+ with gr.TabItem("Visual Mesh") as t1:
581
+ viewer = gr.Model3D(
582
+ label="🧊 3D Model Viewer",
583
+ height=500,
584
+ clear_color=[0.95, 0.95, 0.95],
585
+ elem_id="visual_mesh",
586
+ )
587
+ with gr.TabItem("Collision Mesh") as t2:
588
+ collision_viewer_a = gr.Model3D(
589
+ label="🧊 Collision Mesh",
590
+ height=500,
591
+ clear_color=[0.95, 0.95, 0.95],
592
+ elem_id="collision_mesh_a",
593
+ visible=True,
594
+ )
595
+ collision_viewer_b = gr.Model3D(
596
+ label="🧊 Collision Mesh",
597
+ height=500,
598
+ clear_color=[0.95, 0.95, 0.95],
599
+ elem_id="collision_mesh_b",
600
+ visible=False,
601
+ )
602
+
603
+ t1.select(
604
+ fn=lambda: None,
605
+ js="() => { window.dispatchEvent(new Event('resize')); }",
606
+ )
607
+ t2.select(
608
+ fn=lambda: None,
609
+ js="() => { window.dispatchEvent(new Event('resize')); }",
610
+ )
611
+
612
+ with gr.Row():
613
+ est_type_text = gr.Textbox(
614
+ label="Asset category", interactive=False
615
+ )
616
+ est_height_text = gr.Textbox(
617
+ label="Real height(.m)", interactive=False
618
+ )
619
+ est_mass_text = gr.Textbox(
620
+ label="Mass(.kg)", interactive=False
621
+ )
622
+ est_mu_text = gr.Textbox(
623
+ label="Friction coefficient", interactive=False
624
+ )
625
+ with gr.Row():
626
+ desc_box = gr.Textbox(
627
+ label="📝 Asset Description", interactive=False
628
+ )
629
+ with gr.Accordion(label="Asset Details", open=False):
630
+ urdf_file = gr.Textbox(
631
+ label="URDF File Path", interactive=False, lines=2
632
+ )
633
+ with gr.Row():
634
+ extract_btn = gr.Button(
635
+ "📥 Extract Asset",
636
+ variant="primary",
637
+ interactive=False,
638
+ )
639
+ download_btn = gr.DownloadButton(
640
+ label="⬇️ Download Asset",
641
+ variant="primary",
642
+ interactive=False,
643
+ )
644
+
645
+ search_button.click(
646
+ fn=search_assets,
647
+ inputs=[search_box, top_k_slider],
648
+ outputs=[gallery, gallery, gallery_df_state],
649
+ )
650
+ search_box.submit(
651
+ fn=search_assets,
652
+ inputs=[search_box, top_k_slider],
653
+ outputs=[gallery, gallery, gallery_df_state],
654
+ )
655
+
656
+ def update_on_primary_change(p):
657
+ s_choices = get_secondary_categories(p)
658
+ initial_assets, gallery_update, initial_df = get_assets(p, None, None)
659
+ return (
660
+ gr.update(choices=s_choices, value=None),
661
+ gr.update(choices=[], value=None),
662
+ initial_assets,
663
+ gallery_update,
664
+ initial_df,
665
+ )
666
+
667
+ def update_on_secondary_change(p, s):
668
+ c_choices = get_categories(p, s)
669
+ asset_previews, gallery_update, gallery_df = get_assets(p, s, None)
670
+ return (
671
+ gr.update(choices=c_choices, value=None),
672
+ asset_previews,
673
+ gallery_update,
674
+ gallery_df,
675
+ )
676
+
677
+ def update_assets(p, s, c):
678
+ asset_previews, gallery_update, gallery_df = get_assets(p, s, c)
679
+ return asset_previews, gallery_update, gallery_df
680
+
681
+ primary.change(
682
+ fn=update_on_primary_change,
683
+ inputs=[primary],
684
+ outputs=[secondary, category, gallery, gallery, gallery_df_state],
685
+ )
686
+ secondary.change(
687
+ fn=update_on_secondary_change,
688
+ inputs=[primary, secondary],
689
+ outputs=[category, gallery, gallery, gallery_df_state],
690
+ )
691
+ category.change(
692
+ fn=update_assets,
693
+ inputs=[primary, secondary, category],
694
+ outputs=[gallery, gallery, gallery_df_state],
695
+ )
696
+
697
+ gallery.select(
698
+ fn=show_asset_from_gallery,
699
+ inputs=[primary, secondary, category, search_box, gallery_df_state],
700
+ outputs=[
701
+ (visual_path_state := gr.State()),
702
+ (collision_path_state := gr.State()),
703
+ desc_box,
704
+ asset_folder,
705
+ urdf_file,
706
+ est_type_text,
707
+ est_height_text,
708
+ est_mass_text,
709
+ est_mu_text,
710
+ ],
711
+ ).then(
712
+ fn=render_meshes,
713
+ inputs=[visual_path_state, collision_path_state, switch_viewer_state],
714
+ outputs=[
715
+ viewer,
716
+ collision_viewer_a,
717
+ collision_viewer_b,
718
+ switch_viewer_state,
719
+ ],
720
+ ).success(
721
+ lambda: (gr.Button(interactive=True), gr.Button(interactive=False)),
722
+ outputs=[extract_btn, download_btn],
723
+ )
724
+
725
+ extract_btn.click(
726
+ fn=create_asset_zip, inputs=[asset_folder], outputs=[download_btn]
727
+ ).success(fn=lambda: gr.update(interactive=True), outputs=download_btn)
728
+
729
+ demo.load(start_session)
730
+ demo.unload(end_session)
731
+
732
+
733
+ if __name__ == "__main__":
734
+ demo.launch()
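
For reference, app.py assumes dataset_index.csv provides at least the columns it reads above: primary_category, secondary_category, category, description, asset_dir, and urdf_path (the last two resolved relative to DATA_ROOT). A minimal sketch of a compatible index file; the row values are hypothetical:

import pandas as pd

# One hypothetical row with the columns app.py reads; asset_dir and urdf_path
# are joined onto DATA_ROOT at load time.
pd.DataFrame(
    [
        {
            "primary_category": "furniture",
            "secondary_category": "seating",
            "category": "chair",
            "description": "a red chair with four legs",
            "asset_dir": "assets/chair_0001",
            "urdf_path": "assets/chair_0001/chair.urdf",
        }
    ]
).to_csv("dataset_index.csv", index=False)
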
app_style.py ADDED
@@ -0,0 +1,27 @@
1
+ from gradio.themes import Soft
2
+ from gradio.themes.utils.colors import gray, stone
3
+
4
+ lighting_css = """
5
+ <style>
6
+ #lighter_mesh canvas {
7
+ filter: brightness(1.9) !important;
8
+ }
9
+ </style>
10
+ """
11
+
12
+ image_css = """
13
+ <style>
14
+ .image_fit .image-frame {
15
+ object-fit: contain !important;
16
+ height: 100% !important;
17
+ }
18
+ </style>
19
+ """
20
+
21
+ custom_theme = Soft(
22
+ primary_hue=stone,
23
+ secondary_hue=gray,
24
+ radius_size="md",
25
+ text_size="sm",
26
+ spacing_size="sm",
27
+ )
embodied_gen/utils/gpt_clients.py ADDED
@@ -0,0 +1,249 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import math
18
+ import base64
19
+ import logging
20
+ import os
21
+ from io import BytesIO
22
+ from typing import Optional
23
+
24
+ import yaml
25
+ from openai import AzureOpenAI, OpenAI # pip install openai
26
+ from PIL import Image
27
+ from tenacity import (
28
+ retry,
29
+ stop_after_attempt,
30
+ stop_after_delay,
31
+ wait_random_exponential,
32
+ )
33
+
34
+ logging.getLogger("httpx").setLevel(logging.WARNING)
35
+ logging.basicConfig(level=logging.WARNING)
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ __all__ = [
40
+ "GPTclient",
41
+ ]
42
+
43
+ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
44
+
45
+
46
+ def combine_images_to_grid(
47
+ images: list[str | Image.Image],
48
+ cat_row_col: tuple[int, int] | None = None,
49
+ target_wh: tuple[int, int] = (512, 512),
50
+ image_mode: str = "RGB",
51
+ ) -> list[Image.Image]:
52
+ n_images = len(images)
53
+ if n_images == 1:
54
+ return images
55
+
56
+ if cat_row_col is None:
57
+ n_col = math.ceil(math.sqrt(n_images))
58
+ n_row = math.ceil(n_images / n_col)
59
+ else:
60
+ n_row, n_col = cat_row_col
61
+
62
+ images = [
63
+ Image.open(p).convert(image_mode) if isinstance(p, str) else p
64
+ for p in images
65
+ ]
66
+ images = [img.resize(target_wh) for img in images]
67
+
68
+ grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
69
+ grid = Image.new(image_mode, (grid_w, grid_h), (0, 0, 0))
70
+
71
+ for idx, img in enumerate(images):
72
+ row, col = divmod(idx, n_col)
73
+ grid.paste(img, (col * target_wh[0], row * target_wh[1]))
74
+
75
+ return [grid]
76
+
77
+
78
+ class GPTclient:
79
+ """A client to interact with the GPT model via OpenAI or Azure API."""
80
+
81
+ def __init__(
82
+ self,
83
+ endpoint: str,
84
+ api_key: str,
85
+ model_name: str = "yfb-gpt-4o",
86
+ api_version: str | None = None,
87
+ check_connection: bool = True,
88
+ verbose: bool = False,
89
+ ):
90
+ if api_version is not None:
91
+ self.client = AzureOpenAI(
92
+ azure_endpoint=endpoint,
93
+ api_key=api_key,
94
+ api_version=api_version,
95
+ )
96
+ else:
97
+ self.client = OpenAI(
98
+ base_url=endpoint,
99
+ api_key=api_key,
100
+ )
101
+
102
+ self.endpoint = endpoint
103
+ self.model_name = model_name
104
+ self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
105
+ self.verbose = verbose
106
+ if check_connection:
107
+ self.check_connection()
108
+
109
+ logger.info(f"Using GPT model: {self.model_name}.")
110
+
111
+ @retry(
112
+ wait=wait_random_exponential(min=1, max=20),
113
+ stop=(stop_after_attempt(10) | stop_after_delay(30)),
114
+ )
115
+ def completion_with_backoff(self, **kwargs):
116
+ return self.client.chat.completions.create(**kwargs)
117
+
118
+ def query(
119
+ self,
120
+ text_prompt: str,
121
+ image_base64: Optional[list[str | Image.Image]] = None,
122
+ system_role: Optional[str] = None,
123
+ params: Optional[dict] = None,
124
+ ) -> Optional[str]:
125
+ """Queries the GPT model with a text and optional image prompts.
126
+
127
+ Args:
128
+ text_prompt (str): The main text input that the model responds to.
129
+ image_base64 (Optional[List[str]]): A list of image base64 strings
130
+ or local image paths or PIL.Image to accompany the text prompt.
131
+ system_role (Optional[str]): Optional system-level instructions
132
+ that specify the behavior of the assistant.
133
+ params (Optional[dict]): Additional parameters for GPT setting.
134
+
135
+ Returns:
136
+ Optional[str]: The response content generated by the model based on
137
+ the prompt. Returns `None` if an error occurs.
138
+ """
139
+ if system_role is None:
140
+ system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
141
+
142
+ content_user = [
143
+ {
144
+ "type": "text",
145
+ "text": text_prompt,
146
+ },
147
+ ]
148
+
149
+ # Process images if provided
150
+ if image_base64 is not None:
151
+ if not isinstance(image_base64, list):
152
+ image_base64 = [image_base64]
153
+ # Temporary workaround: openrouter endpoints can't accept multiple images, so tile them into one grid.
154
+ if "openrouter" in self.endpoint:
155
+ image_base64 = combine_images_to_grid(image_base64)
156
+ for img in image_base64:
157
+ if isinstance(img, Image.Image):
158
+ buffer = BytesIO()
159
+ img.save(buffer, format=img.format or "PNG")
160
+ buffer.seek(0)
161
+ image_binary = buffer.read()
162
+ img = base64.b64encode(image_binary).decode("utf-8")
163
+ elif (
164
+ isinstance(img, str)  # splitext always returns 2 items; check the type instead
165
+ and os.path.splitext(img)[-1].lower() in self.image_formats
166
+ ):
167
+ if not os.path.exists(img):
168
+ raise FileNotFoundError(f"Image file not found: {img}")
169
+ with open(img, "rb") as f:
170
+ img = base64.b64encode(f.read()).decode("utf-8")
171
+
172
+ content_user.append(
173
+ {
174
+ "type": "image_url",
175
+ "image_url": {"url": f"data:image/png;base64,{img}"},
176
+ }
177
+ )
178
+
179
+ payload = {
180
+ "messages": [
181
+ {"role": "system", "content": system_role},
182
+ {"role": "user", "content": content_user},
183
+ ],
184
+ "temperature": 0.1,
185
+ "max_tokens": 500,
186
+ "top_p": 0.1,
187
+ "frequency_penalty": 0,
188
+ "presence_penalty": 0,
189
+ "stop": None,
190
+ "model": self.model_name,
191
+ }
192
+
193
+ if params:
194
+ payload.update(params)
195
+
196
+ response = None
197
+ try:
198
+ response = self.completion_with_backoff(**payload)
199
+ response = response.choices[0].message.content
200
+ except Exception as e:
201
+ logger.error(f"Error GPTclint {self.endpoint} API call: {e}")
202
+ response = None
203
+
204
+ if self.verbose:
205
+ logger.info(f"Prompt: {text_prompt}")
206
+ logger.info(f"Response: {response}")
207
+
208
+ return response
209
+
210
+ def check_connection(self) -> None:
211
+ """Check whether the GPT API connection is working."""
212
+ try:
213
+ response = self.completion_with_backoff(
214
+ messages=[
215
+ {"role": "system", "content": "You are a test system."},
216
+ {"role": "user", "content": "Hello"},
217
+ ],
218
+ model=self.model_name,
219
+ temperature=0,
220
+ max_tokens=100,
221
+ )
222
+ content = response.choices[0].message.content
223
+ logger.info(f"Connection check success.")
224
+ except Exception as e:
225
+ raise ConnectionError(
226
+ f"Failed to connect to GPT API at {self.endpoint}, "
227
+ f"please check setting in `{CONFIG_FILE}` and `README`."
228
+ ) from e
229
+
230
+
231
+ with open(CONFIG_FILE, "r") as f:
232
+ config = yaml.safe_load(f)
233
+
234
+ agent_type = config["agent_type"]
235
+ agent_config = config.get(agent_type, {})
236
+
237
+ # Prefer environment variables, fallback to YAML config
238
+ endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
239
+ api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
240
+ api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
241
+ model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))
242
+
243
+ GPT_CLIENT = GPTclient(
244
+ endpoint=endpoint,
245
+ api_key=api_key,
246
+ api_version=api_version,
247
+ model_name=model_name,
248
+ check_connection=False,
249
+ )
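
A minimal usage sketch for GPTclient, assuming the OpenRouter-style settings from gpt_config.yaml below; the endpoint, key, and image path are placeholders:

from embodied_gen.utils.gpt_clients import GPTclient

# Placeholder credentials; real values come from gpt_config.yaml or the
# ENDPOINT / API_KEY environment variables.
client = GPTclient(
    endpoint="https://openrouter.ai/api/v1",
    api_key="sk-or-v1-xxx",
    model_name="qwen/qwen2.5-vl-72b-instruct:free",
    check_connection=False,
)
answer = client.query(
    "Estimate the real-world height of this object in meters.",
    image_base64=["render.png"],  # hypothetical path; base64 strings and PIL images also work
)
print(answer)  # query() returns None when the API call fails
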
embodied_gen/utils/gpt_config.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # gpt_config.yaml
2
+ agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
3
+
4
+ gpt-4o:
5
+ endpoint: https://xxx.openai.azure.com
6
+ api_key: xxx
7
+ api_version: 2025-xx-xx
8
+ model_name: yfb-gpt-4o
9
+
10
+ qwen2.5-vl:
11
+ endpoint: https://openrouter.ai/api/v1
12
+ api_key: sk-or-v1-xxx
13
+ api_version: null
14
+ model_name: qwen/qwen2.5-vl-72b-instruct:free
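
As the module-level code in gpt_clients.py shows, environment variables take precedence over this file, so real credentials need not be committed. A sketch with placeholder values:

import os

# Placeholder values; set them before importing embodied_gen.utils.gpt_clients,
# which reads ENDPOINT/API_KEY/API_VERSION/MODEL_NAME at import time and falls
# back to gpt_config.yaml for anything unset.
os.environ["ENDPOINT"] = "https://openrouter.ai/api/v1"
os.environ["API_KEY"] = "sk-or-v1-xxx"
os.environ["MODEL_NAME"] = "qwen/qwen2.5-vl-72b-instruct:free"

from embodied_gen.utils.gpt_clients import GPT_CLIENT  # noqa: E402
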
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ gradio==5.12.0
2
+ pandas
3
+ openai==1.58.1
4
+ tenacity