diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d6a8f00f406180633968b688d455e151b1905991 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Desk/532328457_1311198870420578_2167456836351167380_n.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Office/Office.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_2.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_3.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_4.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_5.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_6.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_7.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Room_Cat/no_overlap_8.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Sisters_Statue/481869432_646849634388788_2162202232236218000_n.jpg filter=lfs diff=lfs merge=lfs -text
+examples/realistic/Sisters_Statue/481943293_641636221777392_2955401254290735956_n.jpg filter=lfs diff=lfs merge=lfs -text
+examples/stylistic/Cat_Girl/Cat_Girl.jpg filter=lfs diff=lfs merge=lfs -text
+examples/stylistic/Oil_Painting/oil.jpg filter=lfs diff=lfs merge=lfs -text
+examples/stylistic/Panda_Wild_West/panda_orange_cat_wildwest.jpeg filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..11c08b326e326c497634d846a0739f2e76dbad69
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+*.py[cod]
+*$py.class
+.gradio/
+inference_output/
+submodules/gsplat/examples/pycolmap/
\ No newline at end of file
diff --git a/License.txt b/License.txt
new file mode 100644
index 0000000000000000000000000000000000000000..13c89039217923c9c21b29a66a5010c474d775bb
--- /dev/null
+++ b/License.txt
@@ -0,0 +1,82 @@
+TENCENT HUNYUANWORLD-MIRROR COMMUNITY LICENSE AGREEMENT
+Tencent HunyuanWorld-Mirror Release Date: October 22, 2025
+THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent HunyuanWorld-Mirror Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+1. DEFINITIONS.
+a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
+b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent HunyuanWorld-Mirror Works or any portion or element thereof set forth herein.
+c. “Documentation” shall mean the specifications, manuals and documentation for Tencent HunyuanWorld-Mirror made publicly available by Tencent.
+d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
+e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent HunyuanWorld-Mirror Works for any purpose and in any field of use.
+f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent HunyuanWorld-Mirror and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
+g. “Model Derivatives” shall mean all: (i) modifications to Tencent HunyuanWorld-Mirror or any Model Derivative of Tencent HunyuanWorld-Mirror; (ii) works based on Tencent HunyuanWorld-Mirror or any Model Derivative of Tencent HunyuanWorld-Mirror; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent HunyuanWorld-Mirror or any Model Derivative of Tencent HunyuanWorld-Mirror, to that model in order to cause that model to perform similarly to Tencent HunyuanWorld-Mirror or a Model Derivative of Tencent HunyuanWorld-Mirror, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent HunyuanWorld-Mirror or a Model Derivative of Tencent HunyuanWorld-Mirror for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
+h. “Output” shall mean the information and/or content output of Tencent HunyuanWorld-Mirror or a Model Derivative that results from operating or otherwise using Tencent HunyuanWorld-Mirror or a Model Derivative, including via a Hosted Service.
+i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials..
+j. “Tencent HunyuanWorld-Mirror” shall mean the 3D generation models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at [https://github.com/Tencent-Hunyuan/HunyuanWorld-Mirror].
+k. “Tencent HunyuanWorld-Mirror Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
+l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
+m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
+n. “including” shall mean including but not limited to.
+2. GRANT OF RIGHTS.
+We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
+3. DISTRIBUTION.
+You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent HunyuanWorld-Mirror Works, exclusively in the Territory, provided that You meet all of the following conditions:
+a. You must provide all such Third Party recipients of the Tencent HunyuanWorld-Mirror Works or products or services using them a copy of this Agreement;
+b. You must cause any modified files to carry prominent notices stating that You changed the files;
+c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent HunyuanWorld-Mirror Works; and (ii) mark the products or services developed by using the Tencent HunyuanWorld-Mirror Works to indicate that the product/service is “Powered by Tencent Hunyuan”;
+d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent HunyuanWorld-Mirror is licensed under the Tencent HunyuanWorld-Mirror Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate”;
+e. In the event that You use, integrate, implement, or otherwise deploy the Tencent Hunyuan Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
+You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent HunyuanWorld-Mirror Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
+4. ADDITIONAL COMMERCIAL TERMS.
+If, on the Tencent HunyuanWorld-Mirror version release date, the monthly active users of all products or services made available by or for Licensee is greater than 1 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
+Subject to Tencent's written approval, you may request a license for the use of Tencent HunyuanWorld-Mirror by submitting the following information to hunyuan3d@tencent.com:
+a. Your company’s name and associated business sector that plans to use Tencent HunyuanWorld-Mirror.
+b. Your intended use case and the purpose of using Tencent HunyuanWorld-Mirror.
+c. Your plans to modify Tencent HunyuanWorld-Mirror or create Model Derivatives.
+5. RULES OF USE.
+a. Your use of the Tencent HunyuanWorld-Mirror Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent HunyuanWorld-Mirror Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent HunyuanWorld-Mirror Works and You must provide notice to subsequent users to whom You distribute that Tencent HunyuanWorld-Mirror Works are subject to the use restrictions in these Sections 5(a) and 5(b).
+b. You must not use the Tencent HunyuanWorld-Mirror Works or any Output or results of the Tencent HunyuanWorld-Mirror Works to improve any other AI model (other than Tencent HunyuanWorld-Mirror or Model Derivatives thereof).
+c. You must not use, reproduce, modify, distribute, or display the Tencent HunyuanWorld-Mirror Works, Output or results of the Tencent HunyuanWorld-Mirror Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
+6. INTELLECTUAL PROPERTY.
+a. Subject to Tencent’s ownership of Tencent HunyuanWorld-Mirror Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
+b. No trademark licenses are granted under this Agreement, and in connection with the Tencent HunyuanWorld-Mirror Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent HunyuanWorld-Mirror Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
+c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent HunyuanWorld-Mirror Works.
+d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
+7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
+a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent HunyuanWorld-Mirror Works or to grant any license thereto.
+b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUANWORLD-MIRROR WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUANWORLD-MIRROR WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUANWORLD-MIRROR WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
+c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUANWORLD-MIRROR WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+8. SURVIVAL AND TERMINATION.
+a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent HunyuanWorld-Mirror Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
+9. GOVERNING LAW AND JURISDICTION.
+a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
+
+EXHIBIT A
+ACCEPTABLE USE POLICY
+
+Tencent reserves the right to update this Acceptable Use Policy from time to time.
+Last modified: November 5, 2024
+
+Tencent endeavors to promote safe and fair use of its tools and features, including Tencent HunyuanWorld-Mirror. You agree not to use Tencent HunyuanWorld-Mirror or Model Derivatives:
+1. Outside the Territory;
+2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
+3. To harm Yourself or others;
+4. To repurpose or distribute output from Tencent HunyuanWorld-Mirror or any Model Derivatives to harm Yourself or others;
+5. To override or circumvent the safety guardrails and safeguards We have put in place;
+6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
+8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
+9. To intentionally defame, disparage or otherwise harass others;
+10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
+11. To generate or disseminate personal identifiable information with the purpose of harming others;
+12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
+13. To impersonate another individual without consent, authorization, or legal right;
+14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
+15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
+16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
+17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
+18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+19. For military purposes;
+20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
\ No newline at end of file
diff --git a/Notice.txt b/Notice.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bbaebba9e386016bdbde3a8e7f4fcd4de2ad0e2a
--- /dev/null
+++ b/Notice.txt
@@ -0,0 +1,76 @@
+Usage and Legal Notices:
+
+Tencent is pleased to support the open source community by making Tencent HunyuanWorld-Mirror .
+
+Copyright (C) 2025 Tencent. All rights reserved. The below software and/or models in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
+
+Tencent HunyuanWorld-Mirror is licensed under the TENCENT HUNYUANWORLD-MIRROR COMMUNITY LICENSE AGREEMENT except for the third-party components listed below, which is licensed under different terms. TTencent HunyuanWorld-Mirror does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
+
+For avoidance of doubts, Tencent HunyuanWorld-Mirror means inference-enabling code, parameters, and/or weights of this Model, which are made publicly available by Tencent in accordance with TENCENT HUNYUANWORLD-MIRROR COMMUNITY LICENSE AGREEMENT.
+
+
+Other dependencies and licenses:
+
+
+
+Open Source Software Licensed under the Apache-2.0:
+--------------------------------------------------------------------
+1. gsplat
+
+Copyright 2025 Nerfstudio Team.
+
+You can access this component through: https://github.com/nerfstudio-project/gsplat
+
+Terms of the Apache-2.0:
+--------------------------------------------------------------------
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+You must give any other recipients of the Work or Derivative Works a copy of this License; and
+You must cause any modified files to carry prominent notices stating that You changed the files; and
+You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
\ No newline at end of file
diff --git a/README.md b/README.md
index 87a426db82b9391b5bb134ac0bc5cab64f7fedfd..3e2540b3283c9b86e8e0fcbbc949c882e9a0e6af 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,11 @@
---
-title: HunyuanWorld Mirror
-emoji: 🐠
-colorFrom: indigo
-colorTo: pink
+title: HunyuanWorld-Mirror
+emoji: 🌍
+colorFrom: purple
+colorTo: red
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+short_description: Universal 3D World Reconstruction with Any Prior Prompting
+---
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4175c9f40e5099080b448e96ae6b6326f4af00
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1817 @@
+import gc
+import os
+import shutil
+import time
+from datetime import datetime
+import io
+import sys
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+import cv2
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+from PIL import Image
+from pillow_heif import register_heif_opener
+register_heif_opener()
+
+from src.utils.inference_utils import load_and_preprocess_images
+from src.utils.geometry import (
+ depth_edge,
+ normals_edge
+)
+from src.utils.visual_util import (
+ convert_predictions_to_glb_scene,
+ segment_sky,
+ download_file_from_url
+)
+from src.utils.save_utils import save_camera_params, save_gs_ply, process_ply_to_splat, convert_gs_to_ply
+from src.utils.render_utils import render_interpolated_video
+import onnxruntime
+
+
+# Initialize model - this will be done on GPU when needed
+model = None
+
+# Global variable to store current terminal output
+current_terminal_output = ""
+
+# Helper class to capture terminal output
+class TeeOutput:
+ """Capture output while still printing to console"""
+ def __init__(self, max_chars=10000):
+ self.terminal = sys.stdout
+ self.log = io.StringIO()
+ self.max_chars = max_chars # 限制最大字符数
+
+ def write(self, message):
+ global current_terminal_output
+ self.terminal.write(message)
+ self.log.write(message)
+
+ # 获取当前内容并限制长度
+ content = self.log.getvalue()
+ if len(content) > self.max_chars:
+ # 只保留最后 max_chars 个字符
+ content = "...(earlier output truncated)...\n" + content[-self.max_chars:]
+ self.log = io.StringIO()
+ self.log.write(content)
+
+ current_terminal_output = self.log.getvalue()
+
+ def flush(self):
+ self.terminal.flush()
+
+ def getvalue(self):
+ return self.log.getvalue()
+
+ def clear(self):
+ global current_terminal_output
+ self.log = io.StringIO()
+ current_terminal_output = ""
+
+# -------------------------------------------------------------------------
+# Model inference
+# -------------------------------------------------------------------------
+@spaces.GPU()
+def run_model(
+ target_dir,
+ confidence_percentile: float = 10,
+ edge_normal_threshold: float = 5.0,
+ edge_depth_threshold: float = 0.03,
+ apply_confidence_mask: bool = True,
+ apply_edge_mask: bool = True,
+):
+ """
+ Run the WorldMirror model on images in the 'target_dir/images' folder and return predictions.
+ """
+ global model
+ import torch # Ensure torch is available in function scope
+
+ from src.models.models.worldmirror import WorldMirror
+ from src.models.utils.geometry import depth_to_world_coords_points
+
+ print(f"Processing images from {target_dir}")
+
+ # Device check
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Initialize model if not already done
+ if model is None:
+ model = WorldMirror.from_pretrained("tencent/HunyuanWorld-Mirror").to(device)
+ else:
+ model.to(device)
+
+ model.eval()
+
+ # Load images using WorldMirror's load_images function
+ print("Loading images...")
+ image_folder_path = os.path.join(target_dir, "images")
+ image_file_paths = [os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)]
+ img = load_and_preprocess_images(image_file_paths).to(device)
+
+ print(f"Loaded {img.shape[1]} images")
+ if img.shape[1] == 0:
+ raise ValueError("No images found. Check your upload.")
+
+ # Run model inference
+ print("Running inference...")
+ inputs = {}
+ inputs['img'] = img
+ use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+ if use_amp:
+ amp_dtype = torch.bfloat16
+ else:
+ amp_dtype = torch.float32
+ with torch.amp.autocast('cuda', enabled=bool(use_amp), dtype=amp_dtype):
+ predictions = model(inputs)
+
+ # img
+ imgs = inputs["img"].permute(0, 1, 3, 4, 2)
+ imgs = imgs[0].detach().cpu().numpy() # S H W 3
+
+ # depth output
+ depth_preds = predictions["depth"]
+ depth_conf = predictions["depth_conf"]
+ depth_preds = depth_preds[0].detach().cpu().numpy() # S H W 1
+ depth_conf = depth_conf[0].detach().cpu().numpy() # S H W
+
+ # normal output
+ normal_preds = predictions["normals"] # S H W 3
+ normal_preds = normal_preds[0].detach().cpu().numpy() # S H W 3
+
+ # camera parameters
+ camera_poses = predictions["camera_poses"][0].detach().cpu().numpy() # [S,4,4]
+ camera_intrs = predictions["camera_intrs"][0].detach().cpu().numpy() # [S,3,3]
+
+ # points output
+ pts3d_preds = depth_to_world_coords_points(predictions["depth"][0, ..., 0], predictions["camera_poses"][0], predictions["camera_intrs"][0])[0]
+ pts3d_preds = pts3d_preds.detach().cpu().numpy() # S H W 3
+ pts3d_conf = depth_conf # S H W
+
+ # sky mask segmentation
+ if not os.path.exists("skyseg.onnx"):
+ print("Downloading skyseg.onnx...")
+ download_file_from_url(
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
+ )
+ skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
+ sky_mask_list = []
+ for i, img_path in enumerate([os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)]):
+ sky_mask = segment_sky(img_path, skyseg_session)
+ # Resize mask to match H×W if needed
+ if sky_mask.shape[0] != imgs.shape[1] or sky_mask.shape[1] != imgs.shape[2]:
+ sky_mask = cv2.resize(sky_mask, (imgs.shape[2], imgs.shape[1]))
+ sky_mask_list.append(sky_mask)
+ sky_mask = np.stack(sky_mask_list, axis=0) # [S, H, W]
+ sky_mask = sky_mask>0
+
+ # mask computation
+ final_mask_list = []
+ for i in range(inputs["img"].shape[1]):
+ final_mask = None
+ if apply_confidence_mask:
+ # compute confidence mask based on the pointmap confidence
+ confidences = pts3d_conf[i, :, :] # [H, W]
+ percentile_threshold = np.quantile(confidences, confidence_percentile / 100.0)
+ conf_mask = confidences >= percentile_threshold
+ if final_mask is None:
+ final_mask = conf_mask
+ else:
+ final_mask = final_mask & conf_mask
+ if apply_edge_mask:
+ # compute edge mask based on the normalmap
+ normal_pred = normal_preds[i] # [H, W, 3]
+ normal_edges = normals_edge(
+ normal_pred, tol=edge_normal_threshold, mask=final_mask
+ )
+ # compute depth mask based on the depthmap
+ depth_pred = depth_preds[i, :, :, 0] # [H, W]
+ depth_edges = depth_edge(
+ depth_pred, rtol=edge_depth_threshold, mask=final_mask
+ )
+ edge_mask = ~(depth_edges & normal_edges)
+ if final_mask is None:
+ final_mask = edge_mask
+ else:
+ final_mask = final_mask & edge_mask
+ final_mask_list.append(final_mask)
+
+ if final_mask_list[0] is not None:
+ final_mask = np.stack(final_mask_list, axis=0) # [S, H, W]
+ else:
+ final_mask = np.ones(pts3d_conf.shape[:3], dtype=bool) # [S, H, W]
+
+ # gaussian splatting output
+ if "splats" in predictions:
+ splats_dict = {}
+ splats_dict['means'] = predictions["splats"]["means"]
+ splats_dict['scales'] = predictions["splats"]["scales"]
+ splats_dict['quats'] = predictions["splats"]["quats"]
+ splats_dict['opacities'] = predictions["splats"]["opacities"]
+ if "sh" in predictions["splats"]:
+ splats_dict['sh'] = predictions["splats"]["sh"]
+ if "colors" in predictions["splats"]:
+ splats_dict['colors'] = predictions["splats"]["colors"]
+
+ # output lists
+ outputs = {}
+ outputs['images'] = imgs
+ outputs['world_points'] = pts3d_preds
+ outputs['depth'] = depth_preds
+ outputs['normal'] = normal_preds
+ outputs['final_mask'] = final_mask
+ outputs['sky_mask'] = sky_mask
+ outputs['camera_poses'] = camera_poses
+ outputs['camera_intrs'] = camera_intrs
+ if "splats" in predictions:
+ outputs['splats'] = splats_dict
+
+ # Process data for visualization tabs (depth, normal)
+ processed_data = prepare_visualization_data(
+ outputs, inputs
+ )
+
+ # Clean up
+ torch.cuda.empty_cache()
+
+ return outputs, processed_data
+
+
+# -------------------------------------------------------------------------
+# Update and navigation function
+# -------------------------------------------------------------------------
+def update_view_info(current_view, total_views, view_type="Depth"):
+ """Update view information display"""
+ return f"""
+
+ {view_type} View Navigation |
+ Current: View {current_view} / {total_views} views
+
+ """
+
+def update_view_selectors(processed_data):
+ """Update view selector sliders and info displays based on available views"""
+ if processed_data is None or len(processed_data) == 0:
+ num_views = 1
+ else:
+ num_views = len(processed_data)
+
+ # 确保 num_views 至少为 1
+ num_views = max(1, num_views)
+
+ # 更新滑块的最大值和视图信息,使用 gr.update() 而不是创建新组件
+ depth_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1)
+ normal_slider_update = gr.update(minimum=1, maximum=num_views, value=1, step=1)
+
+ # 更新视图信息显示
+ depth_info_update = update_view_info(1, num_views, "Depth")
+ normal_info_update = update_view_info(1, num_views, "Normal")
+
+ return (
+ depth_slider_update, # depth_view_slider
+ normal_slider_update, # normal_view_slider
+ depth_info_update, # depth_view_info
+ normal_info_update, # normal_view_info
+ )
+
+def get_view_data_by_index(processed_data, view_index):
+ """Get view data by index, handling bounds"""
+ if processed_data is None or len(processed_data) == 0:
+ return None
+
+ view_keys = list(processed_data.keys())
+ if view_index < 0 or view_index >= len(view_keys):
+ view_index = 0
+
+ return processed_data[view_keys[view_index]]
+
+def update_depth_view(processed_data, view_index):
+ """Update depth view for a specific view index"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["depth"] is None:
+ return None
+
+ return render_depth_visualization(view_data["depth"], mask=view_data.get("mask"))
+
+def update_normal_view(processed_data, view_index):
+ """Update normal view for a specific view index"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["normal"] is None:
+ return None
+
+ return render_normal_visualization(view_data["normal"], mask=view_data.get("mask"))
+
+def initialize_depth_normal_views(processed_data):
+ """Initialize the depth and normal view displays with the first view data"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, None
+
+ # Use update functions to ensure confidence filtering is applied from the start
+ depth_vis = update_depth_view(processed_data, 0)
+ normal_vis = update_normal_view(processed_data, 0)
+
+ return depth_vis, normal_vis
+
+
+# -------------------------------------------------------------------------
+# File upload and update preview gallery
+# -------------------------------------------------------------------------
+def process_uploaded_files(files, time_interval=1.0):
+ """
+ Process uploaded files by extracting video frames or copying images.
+
+ Args:
+ files: List of uploaded file objects (videos or images)
+ time_interval: Interval in seconds for video frame extraction
+
+ Returns:
+ tuple: (target_dir, image_paths) where target_dir is the output directory
+ and image_paths is a list of processed image file paths
+ """
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Create unique output directory
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ target_dir = f"input_images_{timestamp}"
+ images_dir = os.path.join(target_dir, "images")
+
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(images_dir)
+
+ image_paths = []
+
+ if files is None:
+ return target_dir, image_paths
+
+ video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
+
+ for file_data in files:
+ # Get file path
+ if isinstance(file_data, dict) and "name" in file_data:
+ src_path = file_data["name"]
+ else:
+ src_path = str(file_data)
+
+ ext = os.path.splitext(src_path)[1].lower()
+ base_name = os.path.splitext(os.path.basename(src_path))[0]
+
+ # Process video: extract frames
+ if ext in video_exts:
+ cap = cv2.VideoCapture(src_path)
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ interval = int(fps * time_interval)
+
+ frame_count = 0
+ saved_count = 0
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ frame_count += 1
+ if frame_count % interval == 0:
+ dst_path = os.path.join(images_dir, f"{base_name}_{saved_count:06}.png")
+ cv2.imwrite(dst_path, frame)
+ image_paths.append(dst_path)
+ saved_count += 1
+ cap.release()
+ print(f"Extracted {saved_count} frames from: {os.path.basename(src_path)}")
+
+ # Process HEIC/HEIF: convert to JPEG
+ elif ext in [".heic", ".heif"]:
+ try:
+ with Image.open(src_path) as img:
+ if img.mode not in ("RGB", "L"):
+ img = img.convert("RGB")
+ dst_path = os.path.join(images_dir, f"{base_name}.jpg")
+ img.save(dst_path, "JPEG", quality=95)
+ image_paths.append(dst_path)
+ print(f"Converted HEIC: {os.path.basename(src_path)} -> {os.path.basename(dst_path)}")
+ except Exception as e:
+ print(f"HEIC conversion failed for {src_path}: {e}")
+ dst_path = os.path.join(images_dir, os.path.basename(src_path))
+ shutil.copy(src_path, dst_path)
+ image_paths.append(dst_path)
+
+ # Process regular images: copy directly
+ else:
+ dst_path = os.path.join(images_dir, os.path.basename(src_path))
+ shutil.copy(src_path, dst_path)
+ image_paths.append(dst_path)
+
+ image_paths = sorted(image_paths)
+
+ print(f"Processed files to {images_dir}")
+ return target_dir, image_paths
+
+# Handle file upload and update preview gallery
+def update_gallery_on_upload(input_video, input_images, time_interval=1.0):
+ """
+ Process uploaded files immediately when user uploads or changes files,
+ and display them in the gallery. Returns (target_dir, image_paths).
+ If nothing is uploaded, returns None and empty list.
+ """
+ if not input_video and not input_images:
+ return None, None, None, None
+ target_dir, image_paths = process_uploaded_files(input_video, input_images, time_interval)
+ return (
+ None,
+ target_dir,
+ image_paths,
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+ )
+
+# -------------------------------------------------------------------------
+# Init function
+# -------------------------------------------------------------------------
+def prepare_visualization_data(
+ model_outputs, input_views
+):
+ """Transform model predictions into structured format for display components"""
+ visualization_dict = {}
+
+ # Iterate through each input view
+ nviews = input_views["img"].shape[1]
+ for idx in range(nviews):
+ # Extract RGB image data
+ rgb_image = input_views["img"][0, idx].detach().cpu().numpy()
+
+ # Retrieve 3D coordinate predictions
+ world_coordinates = model_outputs["world_points"][idx]
+
+ # Build view-specific data structure
+ current_view_info = {
+ "image": rgb_image,
+ "points3d": world_coordinates,
+ "depth": None,
+ "normal": None,
+ "mask": None,
+ }
+
+ # Apply final segmentation mask from model
+ segmentation_mask = model_outputs["final_mask"][idx].copy()
+
+ current_view_info["mask"] = segmentation_mask
+ current_view_info["depth"] = model_outputs["depth"][idx].squeeze()
+
+ surface_normals = model_outputs["normal"][idx]
+ current_view_info["normal"] = surface_normals
+
+ visualization_dict[idx] = current_view_info
+
+ return visualization_dict
+
+@spaces.GPU()
+def gradio_demo(
+ target_dir,
+ frame_selector="All",
+ show_camera=False,
+ filter_sky_bg=False,
+ show_mesh=False,
+ filter_ambiguous=False,
+):
+ """
+ Perform reconstruction using the already-created target_dir/images.
+ """
+ # Capture terminal output
+ tee = TeeOutput()
+ old_stdout = sys.stdout
+ sys.stdout = tee
+
+ try:
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+ return None, "No valid target directory found. Please upload first.", None, None, None, None, None, None, None, None, None, None, None, None, terminal_log
+
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Prepare frame_selector dropdown
+ target_dir_images = os.path.join(target_dir, "images")
+ all_files = (
+ sorted(os.listdir(target_dir_images))
+ if os.path.isdir(target_dir_images)
+ else []
+ )
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+ frame_selector_choices = ["All"] + all_files
+
+ print("Running WorldMirror model...")
+ with torch.no_grad():
+ predictions, processed_data = run_model(target_dir)
+
+ # Save predictions
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+ np.savez(prediction_save_path, **predictions)
+
+ # Save camera parameters as JSON
+ camera_params_file = save_camera_params(
+ predictions['camera_poses'],
+ predictions['camera_intrs'],
+ target_dir
+ )
+
+ # Handle None frame_selector
+ if frame_selector is None:
+ frame_selector = "All"
+
+ # Build a GLB file name
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_camera}_mesh{show_mesh}.glb",
+ )
+
+ # Convert predictions to GLB
+ glbscene = convert_predictions_to_glb_scene(
+ predictions,
+ filter_by_frames=frame_selector,
+ show_camera=show_camera,
+ mask_sky_bg=filter_sky_bg,
+ as_mesh=show_mesh, # Use the show_mesh parameter
+ mask_ambiguous=filter_ambiguous
+ )
+ glbscene.export(file_obj=glbfile)
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds")
+ log_msg = (
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+ )
+ # Convert predictions to 3dgs ply
+ gs_file = None
+ splat_mode = 'ply'
+ if "splats" in predictions:
+ # Get Gaussian parameters (already filtered by GaussianSplatRenderer)
+ means = predictions["splats"]["means"][0].reshape(-1, 3)
+ scales = predictions["splats"]["scales"][0].reshape(-1, 3)
+ quats = predictions["splats"]["quats"][0].reshape(-1, 4)
+ colors = (predictions["splats"]["sh"][0] if "sh" in predictions["splats"] else predictions["splats"]["colors"][0]).reshape(-1, 3)
+ opacities = predictions["splats"]["opacities"][0].reshape(-1)
+
+ # Convert to torch tensors if needed
+ if not isinstance(means, torch.Tensor):
+ means = torch.from_numpy(means)
+ if not isinstance(scales, torch.Tensor):
+ scales = torch.from_numpy(scales)
+ if not isinstance(quats, torch.Tensor):
+ quats = torch.from_numpy(quats)
+ if not isinstance(colors, torch.Tensor):
+ colors = torch.from_numpy(colors)
+ if not isinstance(opacities, torch.Tensor):
+ opacities = torch.from_numpy(opacities)
+
+ if splat_mode == 'ply':
+ gs_file = os.path.join(target_dir, "gaussians.ply")
+ save_gs_ply(
+ gs_file,
+ means,
+ scales,
+ quats,
+ colors,
+ opacities
+ )
+ print(f"Saved Gaussian Splatting PLY to: {gs_file}")
+ print(f"File exists: {os.path.exists(gs_file)}")
+ if os.path.exists(gs_file):
+ print(f"File size: {os.path.getsize(gs_file)} bytes")
+ elif splat_mode == 'splat':
+ # Save Gaussian splat
+ plydata = convert_gs_to_ply(
+ means,
+ scales,
+ quats,
+ colors,
+ opacities
+ )
+ gs_file = os.path.join(target_dir, "gaussians.splat")
+ gs_file = process_ply_to_splat(plydata, gs_file)
+
+ # Initialize depth and normal view displays with processed data
+ depth_vis, normal_vis = initialize_depth_normal_views(
+ processed_data
+ )
+
+ # Update view selectors and info displays based on available views
+ depth_slider, normal_slider, depth_info, normal_info = update_view_selectors(
+ processed_data
+ )
+
+ # Automatically generate render video
+ # Generate render video if possible
+ rgb_video_path = None
+ depth_video_path = None
+
+ if "splats" in predictions:
+ # try:
+ from pathlib import Path
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Get camera parameters and image dimensions
+ camera_poses = torch.tensor(predictions['camera_poses']).unsqueeze(0).to(device)
+ camera_intrs = torch.tensor(predictions['camera_intrs']).unsqueeze(0).to(device)
+ H, W = predictions['images'].shape[1], predictions['images'].shape[2]
+
+ # Render video
+ out_path = Path(target_dir) / "rendered_video"
+ render_interpolated_video(
+ model.gs_renderer,
+ predictions["splats"],
+ camera_poses,
+ camera_intrs,
+ (H, W),
+ out_path,
+ interp_per_pair=15,
+ loop_reverse=True,
+ save_mode="split"
+ )
+
+ # Check output files
+ rgb_video_path = str(out_path) + "_rgb.mp4"
+ depth_video_path = str(out_path) + "_depth.mp4"
+
+ if not os.path.exists(rgb_video_path) and not os.path.exists(depth_video_path):
+ rgb_video_path = None
+ depth_video_path = None
+
+ # Cleanup
+ del predictions
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Get terminal output and restore stdout
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+
+ return (
+ glbfile,
+ log_msg,
+ gr.Dropdown(choices=frame_selector_choices, value=frame_selector, interactive=True),
+ processed_data,
+ depth_vis,
+ normal_vis,
+ depth_slider,
+ normal_slider,
+ depth_info,
+ normal_info,
+ camera_params_file,
+ gs_file,
+ rgb_video_path,
+ depth_video_path,
+ terminal_log,
+ )
+
+ except Exception as e:
+ # In case of error, still restore stdout
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+ print(f"Error occurred: {e}")
+ raise
+
+
+# -------------------------------------------------------------------------
+# Helper functions for visualization
+# -------------------------------------------------------------------------
+def render_depth_visualization(depth_map, mask=None):
+ """Generate a color-coded depth visualization image with masking capabilities"""
+ if depth_map is None:
+ return None
+
+ # Create working copy and identify positive depth values
+ depth_copy = depth_map.copy()
+ positive_depth_mask = depth_copy > 0
+
+ # Combine with user-provided mask for filtering
+ if mask is not None:
+ positive_depth_mask = positive_depth_mask & mask
+
+ # Perform percentile-based normalization on valid regions
+ if positive_depth_mask.sum() > 0:
+ valid_depth_values = depth_copy[positive_depth_mask]
+ lower_bound = np.percentile(valid_depth_values, 5)
+ upper_bound = np.percentile(valid_depth_values, 95)
+
+ depth_copy[positive_depth_mask] = (depth_copy[positive_depth_mask] - lower_bound) / (upper_bound - lower_bound)
+
+ # Convert to RGB using matplotlib colormap
+ import matplotlib.pyplot as plt
+
+ color_mapper = plt.cm.turbo_r
+ rgb_result = color_mapper(depth_copy)
+ rgb_result = (rgb_result[:, :, :3] * 255).astype(np.uint8)
+
+ # Mark invalid regions with white color
+ rgb_result[~positive_depth_mask] = [255, 255, 255]
+
+ return rgb_result
+
+def render_normal_visualization(normal_map, mask=None):
+ """Convert surface normal vectors to RGB color representation for display"""
+ if normal_map is None:
+ return None
+
+ # Make a working copy to avoid modifying original data
+ normal_display = normal_map.copy()
+
+ # Handle masking by zeroing out invalid regions
+ if mask is not None:
+ masked_regions = ~mask
+ normal_display[masked_regions] = [0, 0, 0] # Zero out masked pixels
+
+ # Transform from [-1, 1] to [0, 1] range for RGB display
+ normal_display = (normal_display + 1.0) / 2.0
+ normal_display = (normal_display * 255).astype(np.uint8)
+
+ return normal_display
+
+
+def clear_fields():
+ """
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
+ """
+ return None
+
+
+def update_log():
+ """
+ Display a quick log message while waiting.
+ """
+ return "Loading and Reconstructing..."
+
+
+def get_terminal_output():
+ """
+ Get current terminal output for real-time display
+ """
+ global current_terminal_output
+ return current_terminal_output
+
+# -------------------------------------------------------------------------
+# FunctionExample scene metadata extraction
+# -------------------------------------------------------------------------
+def extract_example_scenes_metadata(base_directory):
+ """
+ Extract comprehensive metadata for all scene directories containing valid images.
+
+ Args:
+ base_directory: Root path where example scene directories are located
+
+ Returns:
+ Collection of dictionaries with scene details (title, location, preview, etc.)
+ """
+ from glob import glob
+
+ # Return empty list if base directory is missing
+ if not os.path.exists(base_directory):
+ return []
+
+ # Define supported image format extensions
+ VALID_IMAGE_FORMATS = ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif']
+
+ scenes_data = []
+
+ # Process each subdirectory in the base directory
+ for directory_name in sorted(os.listdir(base_directory)):
+ current_directory = os.path.join(base_directory, directory_name)
+
+ # Filter out non-directory items
+ if not os.path.isdir(current_directory):
+ continue
+
+ # Gather all valid image files within the current directory
+ discovered_images = []
+ for file_format in VALID_IMAGE_FORMATS:
+ # Include both lowercase and uppercase format variations
+ discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format}')))
+ discovered_images.extend(glob(os.path.join(current_directory, f'*.{file_format.upper()}')))
+
+ # Skip directories without any valid images
+ if not discovered_images:
+ continue
+
+ # Ensure consistent image ordering
+ discovered_images.sort()
+
+ # Construct scene metadata record
+ scene_record = {
+ 'name': directory_name,
+ 'path': current_directory,
+ 'thumbnail': discovered_images[0],
+ 'num_images': len(discovered_images),
+ 'image_files': discovered_images,
+ }
+
+ scenes_data.append(scene_record)
+
+ return scenes_data
+
+def load_example_scenes(scene_name, scenes):
+ """
+ Initialize and prepare an example scene for 3D reconstruction processing.
+
+ Args:
+ scene_name: Identifier of the target scene to load
+ scenes: List containing all available scene configurations
+
+ Returns:
+ Tuple containing processed scene data and status information
+ """
+ # Locate the target scene configuration by matching names
+ target_scene_config = None
+ for scene_config in scenes:
+ if scene_config["name"] == scene_name:
+ target_scene_config = scene_config
+ break
+
+ # Handle case where requested scene doesn't exist
+ if target_scene_config is None:
+ return None, None, None, "Scene not found"
+
+ # Prepare image file paths for processing pipeline
+ # Extract all image file paths from the selected scene
+ image_file_paths = []
+ for img_file_path in target_scene_config["image_files"]:
+ image_file_paths.append(img_file_path)
+
+ # Process the scene images through the standard upload pipeline
+ processed_target_dir, processed_image_list = process_uploaded_files(image_file_paths, 1.0)
+
+ # Return structured response with scene data and user feedback
+ status_message = f"Successfully loaded scene '{scene_name}' containing {target_scene_config['num_images']} images. Click 'Reconstruct' to begin 3D processing."
+
+ return (
+ None, # Reset reconstruction visualization
+ None, # Reset gaussian splatting output
+ processed_target_dir, # Provide working directory path
+ processed_image_list, # Update image gallery display
+ status_message,
+ )
+
+
+# -------------------------------------------------------------------------
+# UI and event handling
+# -------------------------------------------------------------------------
+theme = gr.themes.Base()
+
+with gr.Blocks(
+ theme=theme,
+ css="""
+ .custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #a9b8f8 0%, #7081e8 60%, #4254c5 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+ }
+ .normal-weight-btn button,
+ .normal-weight-btn button span,
+ .normal-weight-btn button *,
+ .normal-weight-btn * {
+ font-weight: 400 !important;
+ }
+ .terminal-output {
+ max-height: 400px !important;
+ overflow-y: auto !important;
+ }
+ .terminal-output textarea {
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
+ font-size: 13px !important;
+ line-height: 1.5 !important;
+ color: #333 !important;
+ background-color: #f8f9fa !important;
+ max-height: 400px !important;
+ }
+ .example-gallery {
+ width: 100% !important;
+ }
+ .example-gallery img {
+ width: 100% !important;
+ height: 280px !important;
+ object-fit: contain !important;
+ aspect-ratio: 16 / 9 !important;
+ }
+ .example-gallery .grid-wrap {
+ width: 100% !important;
+ }
+
+ /* 滑块导航样式 */
+ .depth-tab-improved .gradio-slider input[type="range"] {
+ height: 8px !important;
+ border-radius: 4px !important;
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
+ }
+
+ .depth-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
+ height: 20px !important;
+ width: 20px !important;
+ border-radius: 50% !important;
+ background: #fff !important;
+ box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
+ }
+
+ .depth-tab-improved button {
+ transition: all 0.3s ease !important;
+ border-radius: 6px !important;
+ font-weight: 500 !important;
+ }
+
+ .depth-tab-improved button:hover {
+ transform: translateY(-1px) !important;
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
+ }
+
+ .normal-tab-improved .gradio-slider input[type="range"] {
+ height: 8px !important;
+ border-radius: 4px !important;
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
+ }
+
+ .normal-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
+ height: 20px !important;
+ width: 20px !important;
+ border-radius: 50% !important;
+ background: #fff !important;
+ box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
+ }
+
+ .normal-tab-improved button {
+ transition: all 0.3s ease !important;
+ border-radius: 6px !important;
+ font-weight: 500 !important;
+ }
+
+ .normal-tab-improved button:hover {
+ transform: translateY(-1px) !important;
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
+ }
+
+ #depth-view-info, #normal-view-info {
+ animation: fadeIn 0.5s ease-in-out;
+ }
+
+ @keyframes fadeIn {
+ from { opacity: 0; transform: translateY(-10px); }
+ to { opacity: 1; transform: translateY(0); }
+ }
+ """
+) as demo:
+ # State variables for the tabbed interface
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
+ processed_data_state = gr.State(value=None)
+ current_view_index = gr.State(value=0) # Track current view index for navigation
+
+ # Header and description
+ gr.HTML(
+ """
+
+
+
WorldMirror supports any combination of inputs (images, intrinsics, poses, and depth) and multiple outputs including point clouds, camera parameters, depth maps, normal maps, and 3D Gaussian Splatting (3DGS).
+
How to Use:
+
+ - Upload Your Data: Click the "Upload Video or Images" button to add your files. Videos are automatically extracted into frames at one-second intervals.
+ - Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction.
+ - Visualize: Explore multiple reconstruction results across different tabs:
+
+ - 3D View: Interactive point cloud/mesh visualization with camera poses (downloadable as GLB)
+ - 3D Gaussian Splatting: Interactive 3D Gaussian Splatting visualization with RGB and depth videos (downloadable as PLY)
+ - Depth Maps: Per-view depth estimation results (downloadable as PNG)
+ - Normal Maps: Per-view surface orientation visualization (downloadable as PNG)
+ - Camera Parameters: Estimated camera poses and intrinsics (downloadable as JSON)
+
+
+
+
Please note: Loading data and displaying 3D effects may take a moment. For faster performance, we recommend downloading the code from our GitHub and running it locally.
+
+ """)
+
+ output_path_state = gr.Textbox(label="Output Path", visible=False, value="None")
+
+ # Main UI components
+ with gr.Row(equal_height=False):
+ with gr.Column(scale=1):
+ file_upload = gr.File(
+ file_count="multiple",
+ label="Upload Video or Images",
+ interactive=True,
+ file_types=["image", "video"],
+ height="200px",
+ )
+ time_interval = gr.Slider(
+ minimum=0.1,
+ maximum=10.0,
+ value=1.0,
+ step=0.1,
+ label="Video Sample interval",
+ interactive=True,
+ visible=True,
+ scale=4,
+ )
+ resample_btn = gr.Button(
+ "Resample",
+ visible=True,
+ scale=1,
+ elem_classes=["normal-weight-btn"],
+ )
+ image_gallery = gr.Gallery(
+ label="Image Preview",
+ columns=4,
+ height="200px",
+ show_download_button=True,
+ object_fit="contain",
+ preview=True
+ )
+
+ terminal_output = gr.Textbox(
+ label="Terminal Output",
+ lines=6,
+ max_lines=6,
+ interactive=False,
+ show_copy_button=True,
+ container=True,
+ elem_classes=["terminal-output"],
+ autoscroll=True
+ )
+
+ with gr.Column(scale=3):
+ log_output = gr.Markdown(
+ "Upload video or images first, then click Reconstruct to start processing",
+ elem_classes=["custom-log"],
+ )
+
+ with gr.Tabs() as tabs:
+ with gr.Tab("3D Gaussian Splatting", id=1) as gs_tab:
+ with gr.Row():
+ with gr.Column(scale=3):
+ gs_output = gr.Model3D(
+ label="Gaussian Splatting",
+ height=500,
+ )
+ with gr.Column(scale=1):
+ gs_rgb_video = gr.Video(
+ label="Rendered RGB Video",
+ height=250,
+ autoplay=False,
+ loop=False,
+ interactive=False,
+ )
+ gs_depth_video = gr.Video(
+ label="Rendered Depth Video",
+ height=250,
+ autoplay=False,
+ loop=False,
+ interactive=False,
+ )
+ with gr.Tab("Point Cloud/Mesh", id=0):
+ reconstruction_output = gr.Model3D(
+ label="3D Pointmap/Mesh",
+ height=500,
+ zoom_speed=0.4,
+ pan_speed=0.4,
+ )
+ with gr.Tab("Depth", elem_classes=["depth-tab-improved"]):
+ depth_view_info = gr.HTML(
+ value=""
+ "Depth View Navigation | Current: View 1 / 1 views
",
+ elem_id="depth-view-info"
+ )
+ depth_view_slider = gr.Slider(
+ minimum=1,
+ maximum=1,
+ step=1,
+ value=1,
+ label="View Selection Slider",
+ interactive=True,
+ elem_id="depth-view-slider"
+ )
+ depth_map = gr.Image(
+ type="numpy",
+ label="Depth Map",
+ format="png",
+ interactive=False,
+ height=340
+ )
+ with gr.Tab("Normal", elem_classes=["normal-tab-improved"]):
+ normal_view_info = gr.HTML(
+ value=""
+ "Normal View Navigation | Current: View 1 / 1 views
",
+ elem_id="normal-view-info"
+ )
+ normal_view_slider = gr.Slider(
+ minimum=1,
+ maximum=1,
+ step=1,
+ value=1,
+ label="View Selection Slider",
+ interactive=True,
+ elem_id="normal-view-slider"
+ )
+ normal_map = gr.Image(
+ type="numpy",
+ label="Normal Map",
+ format="png",
+ interactive=False,
+ height=340
+ )
+ with gr.Tab("Camera Parameters", elem_classes=["camera-tab"]):
+ with gr.Row():
+ gr.HTML("")
+ camera_params = gr.DownloadButton(
+ label="Download Camera Parameters",
+ scale=1,
+ variant="primary",
+ )
+ gr.HTML("")
+
+ with gr.Row():
+ reconstruct_btn = gr.Button(
+ "Reconstruct",
+ scale=1,
+ variant="primary"
+ )
+ clear_btn = gr.ClearButton(
+ [
+ file_upload,
+ reconstruction_output,
+ log_output,
+ output_path_state,
+ image_gallery,
+ depth_map,
+ normal_map,
+ depth_view_slider,
+ normal_view_slider,
+ depth_view_info,
+ normal_view_info,
+ camera_params,
+ gs_output,
+ gs_rgb_video,
+ gs_depth_video,
+ ],
+ scale=1,
+ )
+
+ with gr.Row():
+ frame_selector = gr.Dropdown(
+ choices=["All"], value="All", label="Show Points of a Specific Frame"
+ )
+
+ gr.Markdown("### Reconstruction Options: (not applied to 3DGS)")
+ with gr.Row():
+ show_camera = gr.Checkbox(label="Show Camera", value=True)
+ show_mesh = gr.Checkbox(label="Show Mesh", value=True)
+ filter_ambiguous = gr.Checkbox(label="Filter low confidence & depth/normal edges", value=True)
+ filter_sky_bg = gr.Checkbox(label="Filter Sky Background", value=False)
+
+ with gr.Column(scale=1):
+ gr.Markdown("### Click to load example scenes")
+ realworld_scenes = extract_example_scenes_metadata("examples/realistic") if os.path.exists("examples/realistic") else extract_example_scenes_metadata("examples")
+ generated_scenes = extract_example_scenes_metadata("examples/stylistic") if os.path.exists("examples/stylistic") else []
+
+ # If no subdirectories exist, fall back to single gallery
+ if not os.path.exists("examples/realistic") and not os.path.exists("examples/stylistic"):
+ # Fallback: use all scenes from examples directory
+ all_scenes = extract_example_scenes_metadata("examples")
+ if all_scenes:
+ gallery_items = [
+ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
+ for scene in all_scenes
+ ]
+
+ example_gallery = gr.Gallery(
+ value=gallery_items,
+ label="Example Scenes",
+ columns=1,
+ rows=None,
+ height=800,
+ object_fit="contain",
+ show_label=False,
+ interactive=True,
+ preview=False,
+ allow_preview=False,
+ elem_classes=["example-gallery"]
+ )
+
+ def handle_example_selection(evt: gr.SelectData):
+ if evt:
+ result = load_example_scenes(all_scenes[evt.index]["name"], all_scenes)
+ return result
+ return (None, None, None, None, "No scene selected")
+
+ example_gallery.select(
+ fn=handle_example_selection,
+ outputs=[
+ reconstruction_output,
+ gs_output,
+ output_path_state,
+ image_gallery,
+ log_output,
+ ],
+ )
+ else:
+ # Tabbed interface for categorized examples
+ with gr.Tabs():
+ with gr.Tab("🌍 Realistic Cases"):
+ if realworld_scenes:
+ realworld_items = [
+ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
+ for scene in realworld_scenes
+ ]
+
+ realworld_gallery = gr.Gallery(
+ value=realworld_items,
+ label="Real-world Examples",
+ columns=1,
+ rows=None,
+ height=750,
+ object_fit="contain",
+ show_label=False,
+ interactive=True,
+ preview=False,
+ allow_preview=False,
+ elem_classes=["example-gallery"]
+ )
+
+ def handle_realworld_selection(evt: gr.SelectData):
+ if evt:
+ result = load_example_scenes(realworld_scenes[evt.index]["name"], realworld_scenes)
+ return result
+ return (None, None, None, None, "No scene selected")
+
+ realworld_gallery.select(
+ fn=handle_realworld_selection,
+ outputs=[
+ reconstruction_output,
+ gs_output,
+ output_path_state,
+ image_gallery,
+ log_output,
+ ],
+ )
+ else:
+ gr.Markdown("No real-world examples available")
+
+ with gr.Tab("🎨 Stylistic Cases"):
+ if generated_scenes:
+ generated_items = [
+ (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
+ for scene in generated_scenes
+ ]
+
+ generated_gallery = gr.Gallery(
+ value=generated_items,
+ label="Generated Examples",
+ columns=1,
+ rows=None,
+ height=750,
+ object_fit="contain",
+ show_label=False,
+ interactive=True,
+ preview=False,
+ allow_preview=False,
+ elem_classes=["example-gallery"]
+ )
+
+ def handle_generated_selection(evt: gr.SelectData):
+ if evt:
+ result = load_example_scenes(generated_scenes[evt.index]["name"], generated_scenes)
+ return result
+ return (None, None, None, None, "No scene selected")
+
+ generated_gallery.select(
+ fn=handle_generated_selection,
+ outputs=[
+ reconstruction_output,
+ gs_output,
+ output_path_state,
+ image_gallery,
+ log_output,
+ ],
+ )
+ else:
+ gr.Markdown("No generated examples available")
+
+ # -------------------------------------------------------------------------
+ # Click logic
+ # -------------------------------------------------------------------------
+ reconstruct_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
+ fn=update_log, inputs=[], outputs=[log_output]
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ output_path_state,
+ frame_selector,
+ show_camera,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ outputs=[
+ reconstruction_output,
+ log_output,
+ frame_selector,
+ processed_data_state,
+ depth_map,
+ normal_map,
+ depth_view_slider,
+ normal_view_slider,
+ depth_view_info,
+ normal_view_info,
+ camera_params,
+ gs_output,
+ gs_rgb_video,
+ gs_depth_video,
+ terminal_output,
+ ],
+ ).then(
+ fn=lambda: "False",
+ inputs=[],
+ outputs=[is_example], # set is_example to "False"
+ )
+
+ # -------------------------------------------------------------------------
+ # Live update logic
+ # -------------------------------------------------------------------------
+ def refresh_3d_scene(
+ workspace_path,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg=False,
+ show_mesh=False,
+ filter_ambiguous=False
+ ):
+ """
+ Refresh 3D scene visualization
+
+ Load prediction data from workspace, generate or reuse GLB scene files based on current parameters,
+ and return file paths needed for the 3D viewer.
+
+ Args:
+ workspace_path: Workspace directory path for reconstruction results
+ frame_selector: Frame selector value for filtering points from specific frames
+ show_camera: Whether to display camera positions
+ is_example: Whether this is an example scene
+ filter_sky_bg: Whether to filter sky background
+ show_mesh: Whether to display as mesh mode
+ filter_ambiguous: Whether to filter low-confidence ambiguous areas
+
+ Returns:
+ tuple: (GLB scene file path, Gaussian point cloud file path, status message)
+ """
+
+ # If example scene is clicked, skip processing directly
+ if is_example == "True":
+ return (
+ gr.update(),
+ gr.update(),
+ "No reconstruction results available. Please click the Reconstruct button first.",
+ )
+
+ # Validate workspace directory path
+ if not workspace_path or workspace_path == "None" or not os.path.isdir(workspace_path):
+ return (
+ gr.update(),
+ gr.update(),
+ "No reconstruction results available. Please click the Reconstruct button first.",
+ )
+
+ # Check if prediction data file exists
+ prediction_file_path = os.path.join(workspace_path, "predictions.npz")
+ if not os.path.exists(prediction_file_path):
+ return (
+ gr.update(),
+ gr.update(),
+ f"Prediction file does not exist: {prediction_file_path}. Please run reconstruction first.",
+ )
+
+ # Load prediction data
+ prediction_data = np.load(prediction_file_path, allow_pickle=True)
+ predictions = {key: prediction_data[key] for key in prediction_data.keys()}
+
+ # Generate GLB scene file path (named based on parameter combination)
+ safe_frame_name = frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')
+ scene_filename = f"scene_{safe_frame_name}_cam{show_camera}_mesh{show_mesh}_edges{filter_ambiguous}_sky{filter_sky_bg}.glb"
+ scene_glb_path = os.path.join(workspace_path, scene_filename)
+
+ # If GLB file doesn't exist, generate new scene file
+ if not os.path.exists(scene_glb_path):
+ scene_model = convert_predictions_to_glb_scene(
+ predictions,
+ filter_by_frames=frame_selector,
+ show_camera=show_camera,
+ mask_sky_bg=filter_sky_bg,
+ as_mesh=show_mesh,
+ mask_ambiguous=filter_ambiguous
+ )
+ scene_model.export(file_obj=scene_glb_path)
+
+ # Find Gaussian point cloud file
+ gaussian_file_path = os.path.join(workspace_path, "gaussians.ply")
+ if not os.path.exists(gaussian_file_path):
+ gaussian_file_path = None
+
+ return (
+ scene_glb_path,
+ gaussian_file_path,
+ "3D scene updated.",
+ )
+
+ def refresh_view_displays_on_filter_update(
+ workspace_dir,
+ sky_background_filter,
+ current_processed_data,
+ depth_slider_position,
+ normal_slider_position,
+ ):
+ """
+ Refresh depth and normal view displays when filter settings change
+
+ When the background filter checkbox state changes, regenerate processed data and update all view displays.
+ This ensures that filter effects are reflected in real-time in the depth map and normal map visualizations.
+
+ Args:
+ workspace_dir: Workspace directory path containing prediction data and images
+ sky_background_filter: Sky background filter enable status
+ current_processed_data: Currently processed visualization data
+ depth_slider_position: Current position of the depth view slider
+ normal_slider_position: Current position of the normal view slider
+
+ Returns:
+ tuple: (updated processed data, depth visualization result, normal visualization result)
+ """
+
+ # Validate workspace directory validity
+ if not workspace_dir or workspace_dir == "None" or not os.path.isdir(workspace_dir):
+ return current_processed_data, None, None
+
+ # Build and check prediction data file path
+ prediction_data_path = os.path.join(workspace_dir, "predictions.npz")
+ if not os.path.exists(prediction_data_path):
+ return current_processed_data, None, None
+
+ try:
+ # Load raw prediction data
+ raw_prediction_data = np.load(prediction_data_path, allow_pickle=True)
+ predictions_dict = {key: raw_prediction_data[key] for key in raw_prediction_data.keys()}
+
+ # Load image data using WorldMirror's load_images function
+ images_directory = os.path.join(workspace_dir, "images")
+ image_file_paths = [os.path.join(images_directory, path) for path in os.listdir(images_directory)]
+ img = load_and_preprocess_images(image_file_paths)
+
+ # Regenerate processed data with new filter settings
+ refreshed_data = {}
+ for view_idx in range(img.shape[1]):
+ view_data = {
+ "image": img[0, view_idx],
+ "points3d": predictions_dict["world_points"][view_idx],
+ "depth": None,
+ "normal": None,
+ "mask": None,
+ }
+ mask = predictions_dict["final_mask"][view_idx].copy()
+ if sky_background_filter:
+ sky_mask = predictions_dict["sky_mask"][view_idx]
+ mask = mask & sky_mask
+ view_data["mask"] = mask
+ view_data["depth"] = predictions_dict["depth"][view_idx].squeeze()
+ view_data["normal"] = predictions_dict["normal"][view_idx]
+ refreshed_data[view_idx] = view_data
+
+ # Get current view indices from slider positions (convert to 0-based indices)
+ current_depth_index = int(depth_slider_position) - 1 if depth_slider_position else 0
+ current_normal_index = int(normal_slider_position) - 1 if normal_slider_position else 0
+
+ # Update depth and normal views with new filter data
+ updated_depth_visualization = update_depth_view(refreshed_data, current_depth_index)
+ updated_normal_visualization = update_normal_view(refreshed_data, current_normal_index)
+
+ return refreshed_data, updated_depth_visualization, updated_normal_visualization
+
+ except Exception as error:
+ print(f"Error occurred while refreshing view displays: {error}")
+ return current_processed_data, None, None
+
+ frame_selector.change(
+ refresh_3d_scene,
+ [
+ output_path_state,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ [reconstruction_output, gs_output, log_output],
+ )
+ show_camera.change(
+ refresh_3d_scene,
+ [
+ output_path_state,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ [reconstruction_output, gs_output, log_output],
+ )
+ show_mesh.change(
+ refresh_3d_scene,
+ [
+ output_path_state,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ [reconstruction_output, gs_output, log_output],
+ )
+
+ filter_sky_bg.change(
+ refresh_3d_scene,
+ [
+ output_path_state,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ [reconstruction_output, gs_output, log_output],
+ ).then(
+ fn=refresh_view_displays_on_filter_update,
+ inputs=[
+ output_path_state,
+ filter_sky_bg,
+ processed_data_state,
+ depth_view_slider,
+ normal_view_slider,
+ ],
+ outputs=[
+ processed_data_state,
+ depth_map,
+ normal_map,
+ ],
+ )
+ filter_ambiguous.change(
+ refresh_3d_scene,
+ [
+ output_path_state,
+ frame_selector,
+ show_camera,
+ is_example,
+ filter_sky_bg,
+ show_mesh,
+ filter_ambiguous
+ ],
+ [reconstruction_output, gs_output, log_output],
+ ).then(
+ fn=refresh_view_displays_on_filter_update,
+ inputs=[
+ output_path_state,
+ filter_sky_bg,
+ processed_data_state,
+ depth_view_slider,
+ normal_view_slider,
+ ],
+ outputs=[
+ processed_data_state,
+ depth_map,
+ normal_map,
+ ],
+ )
+
+ # -------------------------------------------------------------------------
+ # Auto update gallery when user uploads or changes files
+ # -------------------------------------------------------------------------
+ def update_gallery_on_file_upload(files, interval):
+ if not files:
+ return None, None, None, ""
+
+ # Capture terminal output
+ tee = TeeOutput()
+ old_stdout = sys.stdout
+ sys.stdout = tee
+
+ try:
+ target_dir, image_paths = process_uploaded_files(files, interval)
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+
+ return (
+ target_dir,
+ image_paths,
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+ terminal_log,
+ )
+ except Exception as e:
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+ print(f"Error occurred: {e}")
+ raise
+
+ def resample_video_with_new_interval(files, new_interval, current_target_dir):
+ """Resample video with new slider value"""
+ if not files:
+ return (
+ current_target_dir,
+ None,
+ "No files to resample.",
+ "",
+ )
+
+ # Check if we have videos to resample
+ video_extensions = [
+ ".mp4",
+ ".avi",
+ ".mov",
+ ".mkv",
+ ".wmv",
+ ".flv",
+ ".webm",
+ ".m4v",
+ ".3gp",
+ ]
+ has_video = any(
+ os.path.splitext(
+ str(file_data["name"] if isinstance(file_data, dict) else file_data)
+ )[1].lower()
+ in video_extensions
+ for file_data in files
+ )
+
+ if not has_video:
+ return (
+ current_target_dir,
+ None,
+ "No videos found to resample.",
+ "",
+ )
+
+ # Capture terminal output
+ tee = TeeOutput()
+ old_stdout = sys.stdout
+ sys.stdout = tee
+
+ try:
+ # Clean up old target directory if it exists
+ if (
+ current_target_dir
+ and current_target_dir != "None"
+ and os.path.exists(current_target_dir)
+ ):
+ shutil.rmtree(current_target_dir)
+
+ # Process files with new interval
+ target_dir, image_paths = process_uploaded_files(files, new_interval)
+
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+
+ return (
+ target_dir,
+ image_paths,
+ f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
+ terminal_log,
+ )
+ except Exception as e:
+ terminal_log = tee.getvalue()
+ sys.stdout = old_stdout
+ print(f"Error occurred: {e}")
+ raise
+
+ file_upload.change(
+ fn=update_gallery_on_file_upload,
+ inputs=[file_upload, time_interval],
+ outputs=[output_path_state, image_gallery, log_output, terminal_output],
+ )
+
+ resample_btn.click(
+ fn=resample_video_with_new_interval,
+ inputs=[file_upload, time_interval, output_path_state],
+ outputs=[output_path_state, image_gallery, log_output, terminal_output],
+ )
+
+ # -------------------------------------------------------------------------
+ # Navigation for Depth, Normal tabs
+ # -------------------------------------------------------------------------
+ def navigate_with_slider(processed_data, target_view):
+ """Navigate to specified view using slider"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, update_view_info(1, 1)
+
+ # Check if target_view is None or invalid value, and safely convert to int
+ try:
+ if target_view is None:
+ target_view = 1
+ else:
+ target_view = int(float(target_view)) # Convert to float first then int, handle decimal input
+ except (ValueError, TypeError):
+ target_view = 1
+
+ total_views = len(processed_data)
+ # Ensure view index is within valid range
+ view_index = max(1, min(target_view, total_views)) - 1
+
+ # Update depth map
+ depth_vis = update_depth_view(processed_data, view_index)
+
+ # Update view information
+ info_html = update_view_info(view_index + 1, total_views)
+
+ return depth_vis, info_html
+
+ def navigate_with_slider_normal(processed_data, target_view):
+ """Navigate to specified normal view using slider"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, update_view_info(1, 1, "Normal")
+
+ # Check if target_view is None or invalid value, and safely convert to int
+ try:
+ if target_view is None:
+ target_view = 1
+ else:
+ target_view = int(float(target_view)) # Convert to float first then int, handle decimal input
+ except (ValueError, TypeError):
+ target_view = 1
+
+ total_views = len(processed_data)
+ # Ensure view index is within valid range
+ view_index = max(1, min(target_view, total_views)) - 1
+
+ # Update normal map
+ normal_vis = update_normal_view(processed_data, view_index)
+
+ # Update view information
+ info_html = update_view_info(view_index + 1, total_views, "Normal")
+
+ return normal_vis, info_html
+
+ def handle_depth_slider_change(processed_data, target_view):
+ return navigate_with_slider(processed_data, target_view)
+
+ def handle_normal_slider_change(processed_data, target_view):
+ return navigate_with_slider_normal(processed_data, target_view)
+
+ depth_view_slider.change(
+ fn=handle_depth_slider_change,
+ inputs=[processed_data_state, depth_view_slider],
+ outputs=[depth_map, depth_view_info]
+ )
+
+ normal_view_slider.change(
+ fn=handle_normal_slider_change,
+ inputs=[processed_data_state, normal_view_slider],
+ outputs=[normal_map, normal_view_info]
+ )
+
+ # -------------------------------------------------------------------------
+ # Real-time terminal output update
+ # -------------------------------------------------------------------------
+ # Use a timer to periodically update terminal output
+ timer = gr.Timer(value=0.5) # Update every 0.5 seconds
+ timer.tick(
+ fn=get_terminal_output,
+ inputs=[],
+ outputs=[terminal_output]
+ )
+
+ gr.HTML("""
+
+
+ """)
+
+ demo.queue().launch(
+ show_error=True,
+ share=True,
+ ssr_mode=False,
+ )
diff --git a/examples/realistic/Archway_Tunnel/image_0001.jpg b/examples/realistic/Archway_Tunnel/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e2fa05028700ab9793feb255449665afd8a83a3e
Binary files /dev/null and b/examples/realistic/Archway_Tunnel/image_0001.jpg differ
diff --git a/examples/realistic/Archway_Tunnel/image_0030.jpg b/examples/realistic/Archway_Tunnel/image_0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bc5361f334c5eb061f260fb8be47f1a159bde5f1
Binary files /dev/null and b/examples/realistic/Archway_Tunnel/image_0030.jpg differ
diff --git a/examples/realistic/Bright_Room/image_0001.jpg b/examples/realistic/Bright_Room/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..30d6818094e7227a96332728075165b7c12c9d72
Binary files /dev/null and b/examples/realistic/Bright_Room/image_0001.jpg differ
diff --git a/examples/realistic/Bright_Room/image_0035.jpg b/examples/realistic/Bright_Room/image_0035.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0f863ea60ed20f7ea9c9b33ff0297215a1b8b259
Binary files /dev/null and b/examples/realistic/Bright_Room/image_0035.jpg differ
diff --git a/examples/realistic/Desk/530554609_3367433673396747_2161028887770608277_n.jpg b/examples/realistic/Desk/530554609_3367433673396747_2161028887770608277_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..81a14e68d682055689fc27b088b6b1a6da7734d6
Binary files /dev/null and b/examples/realistic/Desk/530554609_3367433673396747_2161028887770608277_n.jpg differ
diff --git a/examples/realistic/Desk/532328457_1311198870420578_2167456836351167380_n.jpg b/examples/realistic/Desk/532328457_1311198870420578_2167456836351167380_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..77d382184d2378fd017ae6466bdb9f48ffbcb8f2
--- /dev/null
+++ b/examples/realistic/Desk/532328457_1311198870420578_2167456836351167380_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc595d2d32d5a87edb75f3e76d08e2ce81a0a13deb4921736263fd96f6d08288
+size 108349
diff --git a/examples/realistic/Dining_Table/image_0001.jpg b/examples/realistic/Dining_Table/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7667165333bd19e3b763ea36d92a2b33c9bbf378
Binary files /dev/null and b/examples/realistic/Dining_Table/image_0001.jpg differ
diff --git a/examples/realistic/Dining_Table/image_0008.jpg b/examples/realistic/Dining_Table/image_0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2033c5479f2c1ad886aec6d808a14a06fcb597a2
Binary files /dev/null and b/examples/realistic/Dining_Table/image_0008.jpg differ
diff --git a/examples/realistic/Dining_Table/image_0012.jpg b/examples/realistic/Dining_Table/image_0012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0f357c20da9c2e48dec1fa76a0925f8d95b6a5cc
Binary files /dev/null and b/examples/realistic/Dining_Table/image_0012.jpg differ
diff --git a/examples/realistic/Dining_Table/image_0016.jpg b/examples/realistic/Dining_Table/image_0016.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0ff58bdc25bbd7396773e3444745a1f3f6c7e4b0
Binary files /dev/null and b/examples/realistic/Dining_Table/image_0016.jpg differ
diff --git a/examples/realistic/Dino/528883410_1456464302336597_4114529568612559572_n.jpg b/examples/realistic/Dino/528883410_1456464302336597_4114529568612559572_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..80b2f752bd684a32648085d38ab5bc105b01cdbc
Binary files /dev/null and b/examples/realistic/Dino/528883410_1456464302336597_4114529568612559572_n.jpg differ
diff --git a/examples/realistic/Dino/530182709_1122456693282934_3373468492106282632_n.jpg b/examples/realistic/Dino/530182709_1122456693282934_3373468492106282632_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bbfdd7a2170a3d22d31862b72e3ddaf7914c7781
Binary files /dev/null and b/examples/realistic/Dino/530182709_1122456693282934_3373468492106282632_n.jpg differ
diff --git a/examples/realistic/Dino/532847807_1055021109949229_8315548832183031452_n.jpg b/examples/realistic/Dino/532847807_1055021109949229_8315548832183031452_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f34cd847367d0b6abc1d484310c51d92a2bf0939
Binary files /dev/null and b/examples/realistic/Dino/532847807_1055021109949229_8315548832183031452_n.jpg differ
diff --git a/examples/realistic/Festival/image_0001.jpg b/examples/realistic/Festival/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fa510ef4b0d84cc4077393747b061400a02e29e1
Binary files /dev/null and b/examples/realistic/Festival/image_0001.jpg differ
diff --git a/examples/realistic/Festival/image_0023.jpg b/examples/realistic/Festival/image_0023.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e06e68207fcc3e05d68315215bbd68ecd04574f5
Binary files /dev/null and b/examples/realistic/Festival/image_0023.jpg differ
diff --git a/examples/realistic/Festival/image_0046.jpg b/examples/realistic/Festival/image_0046.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..70fa3b6438e5b7a2641d29b8b0213c7ff6cbb73f
Binary files /dev/null and b/examples/realistic/Festival/image_0046.jpg differ
diff --git a/examples/realistic/Flower/image_0001.jpg b/examples/realistic/Flower/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..69db71b5c4cd082fa6986edb317d313e7eb36381
Binary files /dev/null and b/examples/realistic/Flower/image_0001.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000000.jpg b/examples/realistic/Great_Wall/great_wall_000000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..49785136b9f773ab73e882458bc5ecface53efe6
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000000.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000001.jpg b/examples/realistic/Great_Wall/great_wall_000001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e8983cd2590999efd1b3e35db65c19e730b69e7d
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000001.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000002.jpg b/examples/realistic/Great_Wall/great_wall_000002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da075726d5663c5ce85cbf22b350e6a3afaf6765
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000002.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000003.jpg b/examples/realistic/Great_Wall/great_wall_000003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0690f141438552fd33d5d238e0b94a408cfc8b3d
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000003.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000004.jpg b/examples/realistic/Great_Wall/great_wall_000004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c656d7fd2c8245f3c2d4d84175d451f316d0c319
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000004.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000005.jpg b/examples/realistic/Great_Wall/great_wall_000005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..238213e39f3f79094f33f6055d84618496143ba1
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000005.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000006.jpg b/examples/realistic/Great_Wall/great_wall_000006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac284721d8c1c8e0537bb994425727c73fe10f3c
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000006.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000007.jpg b/examples/realistic/Great_Wall/great_wall_000007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0ebeb33439d33ec82c86514aed4277e365b069d0
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000007.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000008.jpg b/examples/realistic/Great_Wall/great_wall_000008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..17e09818083f1d4bb0285af051f9be82edd11f4c
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000008.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000009.jpg b/examples/realistic/Great_Wall/great_wall_000009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9633e1f890e5448f0351f5fc1274f8515e9d88b1
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000009.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000010.jpg b/examples/realistic/Great_Wall/great_wall_000010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7f0385b904728c55f7f78c7d6defb509e75fc47c
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000010.jpg differ
diff --git a/examples/realistic/Great_Wall/great_wall_000011.jpg b/examples/realistic/Great_Wall/great_wall_000011.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8a7a8f3e5b813b4599c58a01e1c947314890fe08
Binary files /dev/null and b/examples/realistic/Great_Wall/great_wall_000011.jpg differ
diff --git a/examples/realistic/Hall/image_0001.jpg b/examples/realistic/Hall/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f2f64388f662c84d7316ed63168c47c127d36d3
Binary files /dev/null and b/examples/realistic/Hall/image_0001.jpg differ
diff --git a/examples/realistic/Hall/image_0027.jpg b/examples/realistic/Hall/image_0027.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3d5240e970f11a648074d202b2000c1d1cf05e2b
Binary files /dev/null and b/examples/realistic/Hall/image_0027.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0001.jpg b/examples/realistic/Ireland_Landscape/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..db01d5b6fc7b900f526d5314de47ad4f19887c1b
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0001.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0007.jpg b/examples/realistic/Ireland_Landscape/image_0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fc2913193a474c1079d6393b4323635603444225
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0007.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0010.jpg b/examples/realistic/Ireland_Landscape/image_0010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8860042799f34f45369b0831b49f0e1b3ca3c92a
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0010.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0017.jpg b/examples/realistic/Ireland_Landscape/image_0017.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f678efc626fb477a26dcbe2e2fd938f6a18bb3d4
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0017.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0022.jpg b/examples/realistic/Ireland_Landscape/image_0022.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..719d649f01685168b9f6c375f3617650ed094138
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0022.jpg differ
diff --git a/examples/realistic/Ireland_Landscape/image_0026.jpg b/examples/realistic/Ireland_Landscape/image_0026.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..046e904321738ab6661752da5af9cbb01e776a65
Binary files /dev/null and b/examples/realistic/Ireland_Landscape/image_0026.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/00.jpg b/examples/realistic/Lego_Kitchen/00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7f0063bc20987c40bf2afffcc0bc6bb7b93d3793
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/00.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/01.jpg b/examples/realistic/Lego_Kitchen/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a268a25a1554dfb82c80fbcee035f60989f70d2e
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/01.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/02.jpg b/examples/realistic/Lego_Kitchen/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..83aa1faa24b8bd3d5bc7a58e43474623260b85c3
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/02.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/03.jpg b/examples/realistic/Lego_Kitchen/03.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..45be74c1d5d2f3645ca6fdc3f65e2165b8196585
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/03.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/04.jpg b/examples/realistic/Lego_Kitchen/04.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..73fe913a9b76ed3ad0a681493cd3c979bf105143
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/04.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/05.jpg b/examples/realistic/Lego_Kitchen/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7c347498236437b6af7fe6f53627688b8c922290
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/05.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/06.jpg b/examples/realistic/Lego_Kitchen/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f0d07874e80b04ffb93935dbf4406c2c906810b6
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/06.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/07.jpg b/examples/realistic/Lego_Kitchen/07.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3053cd2277306f7bcb8bbfe406929cc3586ef0aa
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/07.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/08.jpg b/examples/realistic/Lego_Kitchen/08.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..895f72886323b801c2cb96e33424416c7ce6e47b
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/08.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/09.jpg b/examples/realistic/Lego_Kitchen/09.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..950bc13efeace6fb075bb5c694e3049b86e137b8
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/09.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/10.jpg b/examples/realistic/Lego_Kitchen/10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e1d3bed8de7387a29eabc280f7cbe63f1a1e0755
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/10.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/11.jpg b/examples/realistic/Lego_Kitchen/11.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..331fdd3a02d1dfa4d7b83b1454386a61a6dce209
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/11.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/12.jpg b/examples/realistic/Lego_Kitchen/12.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ab2a5cf490611bc7c7923fcdc8d3beaabc5c6d72
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/12.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/13.jpg b/examples/realistic/Lego_Kitchen/13.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e60006d59f695d17114ed41b83405e098a4815d7
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/13.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/14.jpg b/examples/realistic/Lego_Kitchen/14.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..48669906971377005cae5cb9882662652d535c28
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/14.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/15.jpg b/examples/realistic/Lego_Kitchen/15.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f496953201c580e236c605aa38c65afc271abef4
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/15.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/16.jpg b/examples/realistic/Lego_Kitchen/16.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dfbf489b980b5ad2c720df2ec2f8b56804f23af0
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/16.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/17.jpg b/examples/realistic/Lego_Kitchen/17.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9c1268f7dbcb4e0393b85b01b8f65d3ef16ae80b
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/17.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/18.jpg b/examples/realistic/Lego_Kitchen/18.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac8fd3d489116d66252b7127e5f2803270193c42
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/18.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/19.jpg b/examples/realistic/Lego_Kitchen/19.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c566124ce2f19e834997a77e1c4b4a81cec1fbbf
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/19.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/20.jpg b/examples/realistic/Lego_Kitchen/20.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..974be6f2ef083dceaa1b27a922bb714d4421ab4a
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/20.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/21.jpg b/examples/realistic/Lego_Kitchen/21.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1af671de8d08b52bc688be589e6cc4fa8eb7b158
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/21.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/22.jpg b/examples/realistic/Lego_Kitchen/22.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1dc610bc13a473dbec3eb3586a3e1e2cfe1761f9
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/22.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/23.jpg b/examples/realistic/Lego_Kitchen/23.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0e68180ab71e5bd098d4a3e2e97f876e86b58c04
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/23.jpg differ
diff --git a/examples/realistic/Lego_Kitchen/24.jpg b/examples/realistic/Lego_Kitchen/24.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b781bdb7e984e7dc89d35e237881d04277a6520d
Binary files /dev/null and b/examples/realistic/Lego_Kitchen/24.jpg differ
diff --git a/examples/realistic/Living_Room/image_0001.jpg b/examples/realistic/Living_Room/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5cecdb4dfba161d4113a29bf73d9b2b41eb2223e
Binary files /dev/null and b/examples/realistic/Living_Room/image_0001.jpg differ
diff --git a/examples/realistic/Living_Room/image_0012.jpg b/examples/realistic/Living_Room/image_0012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8f3743af7d8b5bd2a1234a5aa28aabdc92a36429
Binary files /dev/null and b/examples/realistic/Living_Room/image_0012.jpg differ
diff --git a/examples/realistic/Office/Office.jpg b/examples/realistic/Office/Office.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3a21eec3de0c64ed4a8ce9cc612145673882d07d
--- /dev/null
+++ b/examples/realistic/Office/Office.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28767640002f93b703b24a34a6d75ca24b1ef093a19f52ef0f9d3b074ef68c61
+size 197508
diff --git a/examples/realistic/Park/image_0001.jpg b/examples/realistic/Park/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..407263b4f4fa1a2e7a1b7d31798cfc2f7349c56e
Binary files /dev/null and b/examples/realistic/Park/image_0001.jpg differ
diff --git a/examples/realistic/Park/image_0008.jpg b/examples/realistic/Park/image_0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f2a1a8a0f74b45fcbfcee60dfbe55263639c84ba
Binary files /dev/null and b/examples/realistic/Park/image_0008.jpg differ
diff --git a/examples/realistic/Park/image_0014.jpg b/examples/realistic/Park/image_0014.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..38b6e3888dbec3d88f42da86613178d9d63c87cd
Binary files /dev/null and b/examples/realistic/Park/image_0014.jpg differ
diff --git a/examples/realistic/Remains/image_0001.jpg b/examples/realistic/Remains/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5339d268c6c08c0e84cbb201ea70b473d9bef1c2
Binary files /dev/null and b/examples/realistic/Remains/image_0001.jpg differ
diff --git a/examples/realistic/Remains/image_0011.jpg b/examples/realistic/Remains/image_0011.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf0eaab72274ade6e0a3e9d765bcf647a4dde448
Binary files /dev/null and b/examples/realistic/Remains/image_0011.jpg differ
diff --git a/examples/realistic/Remains/image_0020.jpg b/examples/realistic/Remains/image_0020.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ea2f29ed1e994bf6525b3bad6f951c0bfdded7cc
Binary files /dev/null and b/examples/realistic/Remains/image_0020.jpg differ
diff --git a/examples/realistic/Remains/image_0030.jpg b/examples/realistic/Remains/image_0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a041264a791aac68572ca2cdaf60e4db583e59f2
Binary files /dev/null and b/examples/realistic/Remains/image_0030.jpg differ
diff --git a/examples/realistic/Remains/image_0038.jpg b/examples/realistic/Remains/image_0038.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..122a38606f5bd65e2c2bafe451c69535ef29786b
Binary files /dev/null and b/examples/realistic/Remains/image_0038.jpg differ
diff --git a/examples/realistic/Room_Cat/no_overlap_1.jpg b/examples/realistic/Room_Cat/no_overlap_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2f5a46fb43fb0c41d3781f35b0914b64e2dc4d26
Binary files /dev/null and b/examples/realistic/Room_Cat/no_overlap_1.jpg differ
diff --git a/examples/realistic/Room_Cat/no_overlap_2.jpg b/examples/realistic/Room_Cat/no_overlap_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9e6731afc699d9f96421ae60828789e9d274c734
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6f2cac1c3271918177eb134fc080d5a43e40f71bf2a50fda946614a4204d3de
+size 275326
diff --git a/examples/realistic/Room_Cat/no_overlap_3.jpg b/examples/realistic/Room_Cat/no_overlap_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..766b14e84a84229637b6e2db4505a3bae9d1f047
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c0e52375887d44657a25a34578d79f744378b870863023fed6f86dcbd84eeb0
+size 249085
diff --git a/examples/realistic/Room_Cat/no_overlap_4.jpg b/examples/realistic/Room_Cat/no_overlap_4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac8e88a24a6ec0c8570a02a67b3d46a4622d88c9
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47c17353216ebf663ca519b2c6e5d515301995386b4aa602a9b2bd508c8bffe0
+size 230462
diff --git a/examples/realistic/Room_Cat/no_overlap_5.jpg b/examples/realistic/Room_Cat/no_overlap_5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3a0e1a795b0794714d093235878a16122497e3d9
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69c2af2f336d84712b2879ea5a334da1dc1f01095eb71d9d30fa6a26e5ad66af
+size 265973
diff --git a/examples/realistic/Room_Cat/no_overlap_6.jpg b/examples/realistic/Room_Cat/no_overlap_6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d46981beb355e12377942421e2f8f8fb0792473
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01cf8937c2da430ef49e9c0cb23a8031a698eebf2dd1261a37a5c1ee28f5a7f5
+size 270884
diff --git a/examples/realistic/Room_Cat/no_overlap_7.jpg b/examples/realistic/Room_Cat/no_overlap_7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fa37dd2712a10dc959a44c7990f40f5042d438f3
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:189a30e8bd6445c972eb6a8c31581e9af6d0bbc03b0345fba5ca023e678f5492
+size 260800
diff --git a/examples/realistic/Room_Cat/no_overlap_8.jpg b/examples/realistic/Room_Cat/no_overlap_8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..37dcb508239b2c679eeb324ee4f5d67bf17fb83f
--- /dev/null
+++ b/examples/realistic/Room_Cat/no_overlap_8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db13926aab6bcc0f7c4903839ccb0dc554ab3276fdcf73e0f304b596f5c15221
+size 191454
diff --git a/examples/realistic/Rooms/image_0001.jpg b/examples/realistic/Rooms/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..25d55b0b1ff2b6663a1abeb40179e8da68103101
Binary files /dev/null and b/examples/realistic/Rooms/image_0001.jpg differ
diff --git a/examples/realistic/Rooms/image_0008.jpg b/examples/realistic/Rooms/image_0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0fc60d59ef82d1a1a7b764fac0f4f7b4cac637ab
Binary files /dev/null and b/examples/realistic/Rooms/image_0008.jpg differ
diff --git a/examples/realistic/Rooms/image_0016.jpg b/examples/realistic/Rooms/image_0016.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f4585a42e136acd89cb139df33c93079834ad1d
Binary files /dev/null and b/examples/realistic/Rooms/image_0016.jpg differ
diff --git a/examples/realistic/Safari_Car/view_0.jpg b/examples/realistic/Safari_Car/view_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3cefb8e3af551df4f2481d4d78dda622d7d14e1d
Binary files /dev/null and b/examples/realistic/Safari_Car/view_0.jpg differ
diff --git a/examples/realistic/Sisters_Statue/481869432_646849634388788_2162202232236218000_n.jpg b/examples/realistic/Sisters_Statue/481869432_646849634388788_2162202232236218000_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..52e1bcdc4ae0b03f91929c9619b65391ea0c83f1
--- /dev/null
+++ b/examples/realistic/Sisters_Statue/481869432_646849634388788_2162202232236218000_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:415d501df6c4449bc9c2d3b597ffd4ea6649cbb7f1c8e833746cf8462063bc31
+size 127831
diff --git a/examples/realistic/Sisters_Statue/481943293_641636221777392_2955401254290735956_n.jpg b/examples/realistic/Sisters_Statue/481943293_641636221777392_2955401254290735956_n.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6560bf0d9ba72d968b95ba123acf7ef75c1012f7
--- /dev/null
+++ b/examples/realistic/Sisters_Statue/481943293_641636221777392_2955401254290735956_n.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:440206e24b18b1a537c8c4acfd43ed86c43275a3dd288d6a5b4200b3a85c166c
+size 139765
diff --git a/examples/realistic/Snow/image_0001.jpg b/examples/realistic/Snow/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cac232018dc021990e38821acb26b1b6196cb7f9
Binary files /dev/null and b/examples/realistic/Snow/image_0001.jpg differ
diff --git a/examples/realistic/Snow/image_0015.jpg b/examples/realistic/Snow/image_0015.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..438719956a59322a38bf1f4736553fdd2e518f16
Binary files /dev/null and b/examples/realistic/Snow/image_0015.jpg differ
diff --git a/examples/realistic/Snow/image_0020.jpg b/examples/realistic/Snow/image_0020.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..52318bec0f8918eb86edc4cb66287d9b9e382277
Binary files /dev/null and b/examples/realistic/Snow/image_0020.jpg differ
diff --git a/examples/realistic/Snow/image_0030.jpg b/examples/realistic/Snow/image_0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b7bdd8f839a4910a16ef6a59a927a5e4c6633660
Binary files /dev/null and b/examples/realistic/Snow/image_0030.jpg differ
diff --git a/examples/realistic/Snow/image_0045.jpg b/examples/realistic/Snow/image_0045.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6c6dd3fbdd587ca15ac75458cd584a751b0f3c2d
Binary files /dev/null and b/examples/realistic/Snow/image_0045.jpg differ
diff --git a/examples/realistic/Snow/image_0064.jpg b/examples/realistic/Snow/image_0064.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4a31310c32ef5d0092496459f0bdf4f4fe9f0dfd
Binary files /dev/null and b/examples/realistic/Snow/image_0064.jpg differ
diff --git a/examples/realistic/Valley/image_0.jpg b/examples/realistic/Valley/image_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0ceaf5592365347129c366167e227972c7f856b6
Binary files /dev/null and b/examples/realistic/Valley/image_0.jpg differ
diff --git a/examples/realistic/Valley/image_1.jpg b/examples/realistic/Valley/image_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..072b0fef12eb4a1661f1b523721817e91209f4be
Binary files /dev/null and b/examples/realistic/Valley/image_1.jpg differ
diff --git a/examples/realistic/Valley/image_10.jpg b/examples/realistic/Valley/image_10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9323fc42eaa454811046297cf7e63cf84cfc98b5
Binary files /dev/null and b/examples/realistic/Valley/image_10.jpg differ
diff --git a/examples/realistic/Valley/image_2.jpg b/examples/realistic/Valley/image_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3c7f9c9be18e2a527413f84050638dee6b14d02a
Binary files /dev/null and b/examples/realistic/Valley/image_2.jpg differ
diff --git a/examples/realistic/Valley/image_3.jpg b/examples/realistic/Valley/image_3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..635a1b74963d2a7c463387c43d7e4ccbbd0bed9a
Binary files /dev/null and b/examples/realistic/Valley/image_3.jpg differ
diff --git a/examples/realistic/Valley/image_4.jpg b/examples/realistic/Valley/image_4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..471602227ef6a4204925559c6d1b134d65f43369
Binary files /dev/null and b/examples/realistic/Valley/image_4.jpg differ
diff --git a/examples/realistic/Valley/image_5.jpg b/examples/realistic/Valley/image_5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..93acd287e632a39fa08f81c000fc6e353ecf1495
Binary files /dev/null and b/examples/realistic/Valley/image_5.jpg differ
diff --git a/examples/realistic/Valley/image_6.jpg b/examples/realistic/Valley/image_6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..51b5f92664ed2c1cab410d2fe766a082dd57ffc3
Binary files /dev/null and b/examples/realistic/Valley/image_6.jpg differ
diff --git a/examples/realistic/Valley/image_7.jpg b/examples/realistic/Valley/image_7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4101e6cf9b746380f7bcacc46d5f2a8935857e7b
Binary files /dev/null and b/examples/realistic/Valley/image_7.jpg differ
diff --git a/examples/realistic/Valley/image_8.jpg b/examples/realistic/Valley/image_8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6e739a2cb33a4e3e9ecdeb01cdcff11bd89ca7a2
Binary files /dev/null and b/examples/realistic/Valley/image_8.jpg differ
diff --git a/examples/realistic/Valley/image_9.jpg b/examples/realistic/Valley/image_9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f1e24715a21067ede0de3f24e78a988cd37e4870
Binary files /dev/null and b/examples/realistic/Valley/image_9.jpg differ
diff --git a/examples/realistic/Workspace/image_0001.jpg b/examples/realistic/Workspace/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87bf8a107a51bf999f39e91bff74e002d5ba0381
Binary files /dev/null and b/examples/realistic/Workspace/image_0001.jpg differ
diff --git a/examples/realistic/Workspace/image_0008.jpg b/examples/realistic/Workspace/image_0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..875917fb76314b8e8556274ec66a87e4f4123722
Binary files /dev/null and b/examples/realistic/Workspace/image_0008.jpg differ
diff --git a/examples/realistic/Workspace/image_0016.jpg b/examples/realistic/Workspace/image_0016.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..668053cfd552b6034c4bf722c8899daa04cf77e4
Binary files /dev/null and b/examples/realistic/Workspace/image_0016.jpg differ
diff --git a/examples/realistic/Workspace/image_0027.jpg b/examples/realistic/Workspace/image_0027.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e67ebee594e8ddf7fe2a7f1d8eed2cc43137e66a
Binary files /dev/null and b/examples/realistic/Workspace/image_0027.jpg differ
diff --git a/examples/stylistic/A_MARCEAU/image_0001.jpg b/examples/stylistic/A_MARCEAU/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a4d89eb434b5e66a956dbf6e7961e4cc30c3329
Binary files /dev/null and b/examples/stylistic/A_MARCEAU/image_0001.jpg differ
diff --git a/examples/stylistic/A_MARCEAU/image_0017.jpg b/examples/stylistic/A_MARCEAU/image_0017.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f9b792d74e86b8352354fa6e77e617c58964d011
Binary files /dev/null and b/examples/stylistic/A_MARCEAU/image_0017.jpg differ
diff --git a/examples/stylistic/A_Stylized_Kitchen/image_0001.jpg b/examples/stylistic/A_Stylized_Kitchen/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0bf1d6a729f9fd8470ae7aeba46f61add56f874c
Binary files /dev/null and b/examples/stylistic/A_Stylized_Kitchen/image_0001.jpg differ
diff --git a/examples/stylistic/A_Stylized_Kitchen/image_0006.jpg b/examples/stylistic/A_Stylized_Kitchen/image_0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..64b6f201af6dcb5e847b730b52b634ea4fd0bbc5
Binary files /dev/null and b/examples/stylistic/A_Stylized_Kitchen/image_0006.jpg differ
diff --git a/examples/stylistic/A_Stylized_Kitchen/image_0012.jpg b/examples/stylistic/A_Stylized_Kitchen/image_0012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d08c38ba956ca7b5d8c8b0ed60f680493f2a8b2b
Binary files /dev/null and b/examples/stylistic/A_Stylized_Kitchen/image_0012.jpg differ
diff --git a/examples/stylistic/A_Stylized_Kitchen/image_0017.jpg b/examples/stylistic/A_Stylized_Kitchen/image_0017.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cadc30218dc617ca2043ad75262b00844f0641a9
Binary files /dev/null and b/examples/stylistic/A_Stylized_Kitchen/image_0017.jpg differ
diff --git a/examples/stylistic/Cat_Girl/Cat_Girl.jpg b/examples/stylistic/Cat_Girl/Cat_Girl.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bbb591b66d21e151a279b5ee56c10acea98e8e6f
--- /dev/null
+++ b/examples/stylistic/Cat_Girl/Cat_Girl.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be53d3cfe77f8d930975e82ad5ed9d7f523c26407b93ce9a2c8e900ad97d8f0d
+size 103673
diff --git a/examples/stylistic/Cottage_Autumn/image_0001.jpg b/examples/stylistic/Cottage_Autumn/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1546536fa8868c170ddc55568feccca09c8b0e0b
Binary files /dev/null and b/examples/stylistic/Cottage_Autumn/image_0001.jpg differ
diff --git a/examples/stylistic/Cottage_Autumn/image_0016.jpg b/examples/stylistic/Cottage_Autumn/image_0016.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..15fc3fc1444140aae0d0d59b069396ddd4ac1b3c
Binary files /dev/null and b/examples/stylistic/Cottage_Autumn/image_0016.jpg differ
diff --git a/examples/stylistic/Oil_Painting/oil.jpg b/examples/stylistic/Oil_Painting/oil.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..240c4d4816eb48c087a18044ccc264d6970b4e1d
--- /dev/null
+++ b/examples/stylistic/Oil_Painting/oil.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2a5f03943a22d7183d202baeff9dba3befd3d17438bba4eb903d12d9466df4
+size 125386
diff --git a/examples/stylistic/Palace/image_0001.jpg b/examples/stylistic/Palace/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c7ddda32b0b37314aa2150c7046a89023745d888
Binary files /dev/null and b/examples/stylistic/Palace/image_0001.jpg differ
diff --git a/examples/stylistic/Palace/image_0017.jpg b/examples/stylistic/Palace/image_0017.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8db124c7ec83002f3e7c3a214f0abd74fbbfe363
Binary files /dev/null and b/examples/stylistic/Palace/image_0017.jpg differ
diff --git a/examples/stylistic/Palace/image_0020.jpg b/examples/stylistic/Palace/image_0020.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7e834ccd5fea8dbcee900e76fe98a53763c152f7
Binary files /dev/null and b/examples/stylistic/Palace/image_0020.jpg differ
diff --git a/examples/stylistic/Panda_Wild_West/panda_orange_cat_wildwest.jpeg b/examples/stylistic/Panda_Wild_West/panda_orange_cat_wildwest.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..fc33b7880d10e1d04c97a02d82def4fc0cb36b0b
--- /dev/null
+++ b/examples/stylistic/Panda_Wild_West/panda_orange_cat_wildwest.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:909fd5cc7f0be2d823e76e9ece7ac0bc6b7d2dbb58c82214701edfbb2256b3ce
+size 472387
diff --git a/examples/stylistic/The_Ancient_Buildings/image_0001.jpg b/examples/stylistic/The_Ancient_Buildings/image_0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cc7166cba5ae0642560fece843ee0c5a253d2e55
Binary files /dev/null and b/examples/stylistic/The_Ancient_Buildings/image_0001.jpg differ
diff --git a/examples/stylistic/The_Ancient_Buildings/image_0015.jpg b/examples/stylistic/The_Ancient_Buildings/image_0015.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..09de807ec235748c584d1814ecea0f187d3f232d
Binary files /dev/null and b/examples/stylistic/The_Ancient_Buildings/image_0015.jpg differ
diff --git a/examples/stylistic/The_Ancient_Buildings/image_0031.jpg b/examples/stylistic/The_Ancient_Buildings/image_0031.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ae4e2a3a2858e76727131610c1f9aab8df315c0c
Binary files /dev/null and b/examples/stylistic/The_Ancient_Buildings/image_0031.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4d95ed7b59f571e7672b5523f2fe6d77561b304a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+gradio==5.49.1
+moviepy==1.0.3
+torch
+torchvision
+tqdm
+omegaconf
+pydantic
+opencv-python
+scipy
+requests
+trimesh
+matplotlib
+spaces
+pillow_heif
+onnxruntime
+einops
+torchmetrics
+uniception
+colorspacious
+safetensors
+plyfile
+hf_transfer
+gsplat @ https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.3/gsplat-1.5.3+pt24cu124-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
diff --git a/skyseg.onnx b/skyseg.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..786323602c883882bc01ee682bcc737b2d082c16
--- /dev/null
+++ b/skyseg.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab9c34c64c3d821220a2886a4a06da4642ffa14d5b30e8d5339056a089aa1d39
+size 175997079
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/heads/camera_head.py b/src/models/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf5aa641608a240ee205d8f30a7cbd0539336984
--- /dev/null
+++ b/src/models/heads/camera_head.py
@@ -0,0 +1,166 @@
+# inspired by https://github.com/facebookresearch/vggt/blob/main/src/models/heads/camera_head.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from src.models.layers import Mlp
+from src.models.layers.block import Block
+
+
+class CameraHead(nn.Module):
+ """
+ Camera head module: predicts camera parameters from token representations using iterative refinement
+
+ Processes dedicated camera tokens through a series of transformer blocks
+ """
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu",
+ ):
+ super().__init__()
+
+ self.out_dim = 9
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.depth = trunk_depth
+
+ # Build refinement network using transformer block sequence
+ self.refine_net = nn.Sequential(
+ *[
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalization for camera tokens and network output
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.out_norm = nn.LayerNorm(dim_in)
+
+ # Learnable initial camera parameter token
+ self.init_token = nn.Parameter(torch.zeros(1, 1, self.out_dim))
+ self.param_embed = nn.Linear(self.out_dim, dim_in)
+
+ # Generate adaptive normalization parameters: shift, scale, and gate
+ self.adapt_norm_gen = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization (no learnable parameters)
+ self.adapt_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.param_predictor = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.out_dim, drop=0)
+
+ def forward(self, feat_seq: list, steps: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters
+
+ Args:
+ feat_seq: List of token tensors from network, last one used for prediction
+ steps: Number of iterative refinement steps, default 4
+
+ Returns:
+ List of predicted camera encodings (post-activation) from each iteration
+ """
+ # Use tokens from last block for camera prediction
+ latest_feat = feat_seq[-1]
+
+ # Extract camera tokens
+ cam_tokens = latest_feat[:, :, 0]
+ cam_tokens = self.token_norm(cam_tokens)
+
+ # Iteratively refine camera pose predictions
+ b, seq_len, feat_dim = cam_tokens.shape # seq_len expected to be 1
+ curr_pred = None
+ pred_seq = []
+
+ for step in range(steps):
+ # Use learned initial token for first iteration
+ if curr_pred is None:
+ net_input = self.param_embed(self.init_token.expand(b, seq_len, -1))
+ else:
+ curr_pred = curr_pred.detach()
+ net_input = self.param_embed(curr_pred)
+ norm_shift, norm_scale, norm_gate = self.adapt_norm_gen(net_input).chunk(3, dim=-1)
+ mod_cam_feat = norm_gate * self.apply_adaptive_modulation(self.adapt_norm(cam_tokens), norm_shift, norm_scale)
+ mod_cam_feat = mod_cam_feat + cam_tokens
+
+ proc_feat = self.refine_net(mod_cam_feat)
+ param_delta = self.param_predictor(self.out_norm(proc_feat))
+
+ if curr_pred is None:
+ curr_pred = param_delta
+ else:
+ curr_pred = curr_pred + param_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view
+ activated_params = self.apply_camera_parameter_activation(curr_pred)
+ pred_seq.append(activated_params)
+
+ return pred_seq
+
+ def apply_camera_parameter_activation(self, params: torch.Tensor) -> torch.Tensor:
+ """
+ Apply activation functions to camera parameter components
+
+ Args:
+ params: Tensor containing camera parameters [translation, quaternion, focal_length]
+
+ Returns:
+ Activated camera parameters tensor
+ """
+ trans_vec = params[..., :3]
+ quat_vec = params[..., 3:7]
+ fl_vec = params[..., 7:] # or field of view
+
+ trans_vec = self.apply_parameter_activation(trans_vec, self.trans_act)
+ quat_vec = self.apply_parameter_activation(quat_vec, self.quat_act)
+ fl_vec = self.apply_parameter_activation(fl_vec, self.fl_act)
+
+ activated_params = torch.cat([trans_vec, quat_vec, fl_vec], dim=-1)
+ return activated_params
+
+ def apply_parameter_activation(self, tensor: torch.Tensor, act_type: str) -> torch.Tensor:
+ """
+ Apply specified activation function to parameter tensor
+
+ Args:
+ tensor: Tensor containing parameter values
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated parameter tensor
+ """
+ if act_type == "linear":
+ return tensor
+ elif act_type == "inv_log":
+ return self.apply_inverse_logarithm_transform(tensor)
+ elif act_type == "exp":
+ return torch.exp(tensor)
+ elif act_type == "relu":
+ return F.relu(tensor)
+ else:
+ raise ValueError(f"Unknown activation_type: {act_type}")
+
+ def apply_inverse_logarithm_transform(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ x: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(x) * (torch.expm1(torch.abs(x)))
+
+ def apply_adaptive_modulation(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Apply adaptive modulation to input tensor using scaling and shifting parameters
+ """
+ # Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
\ No newline at end of file
diff --git a/src/models/heads/dense_head.py b/src/models/heads/dense_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1af55565437546978b2a4a839c37920bd94dcd
--- /dev/null
+++ b/src/models/heads/dense_head.py
@@ -0,0 +1,579 @@
+# inspired by https://github.com/DepthAnything/Depth-Anything-V2
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from src.models.utils.grid import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ # DPT Head for dense prediction tasks.
+
+ # This module implements the DPT (Dense Prediction Transformer) head as proposed in
+ # "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413).
+ # It takes features from a vision transformer backbone and generates dense (per-pixel) predictions
+ # by fusing multi-scale features through a series of projection, upsampling, and refinement blocks.
+
+ # Args:
+ # dim_in (int): Number of input feature channels.
+ # patch_size (int, optional): Patch size used by the backbone, default is 14.
+ # output_dim (int, optional): Number of output channels, default is 4.
+ # activation (str, optional): Activation function type for the output head, default is "inv_log".
+ # conf_activation (str, optional): Activation function type for the confidence/output uncertainty head, default is "expp1".
+ # features (int, optional): Number of channels used in intermediate feature representations, default is 256.
+ # out_channels (List[int], optional): Number of channels for each intermediate multi-scale feature.
+ # intermediate_layer_idx (List[int], optional): Indices specifying which backbone layers to use for multi-scale fusion.
+ # pos_embed (bool, optional): Whether to add positional encoding to the features, default is True.
+ # feature_only (bool, optional): If True, only return intermediate features (skip final prediction and activations).
+ # down_ratio (int, optional): Downsampling ratio of the output predictions, default is 1 (no downsampling).
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log+expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ pos_embed: bool = True,
+ down_ratio: int = 1,
+ is_gsdpt: bool = False
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.pos_embed = pos_embed
+ self.down_ratio = down_ratio
+ self.is_gsdpt = is_gsdpt
+
+ self.norm = nn.LayerNorm(dim_in)
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList([nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels])
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+ self.scratch = _make_scratch(out_channels, features, expand=False)
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if self.is_gsdpt:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+ conv2_in_channels = head_features_1 // 2
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+ self.input_merger = nn.Sequential(
+ nn.Conv2d(3, conv2_in_channels, 7, 1, 3),
+ nn.ReLU()
+ )
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ token_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass with optional frame chunking for memory efficiency.
+
+ Args:
+ token_list: List of token tensors from transformer, each [B, N, C]
+ images: Input images [B, S, 3, H, W], range [0, 1]
+ patch_start_idx: Starting index of patch tokens
+ frames_chunk_size: Number of frames per chunk. If None or >= S, process all at once
+ gradient_checkpoint: Whether to use gradient checkpointing
+
+ Returns:
+ For is_gsdpt: predictions [B, S, ...]
+ Otherwise: (predictions, confidence), [B, S, X, H, W] and [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # Process all frames together if chunk size not specified or large enough
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(token_list, images, patch_start_idx)
+
+ assert frames_chunk_size > 0
+
+ # Process frames in chunks
+ preds_chunks = []
+ conf_chunks = []
+ gs_chunks = []
+
+ for frame_start in range(0, S, frames_chunk_size):
+ frame_end = min(frame_start + frames_chunk_size, S)
+
+ if self.is_gsdpt:
+ gs, preds, conf = self._forward_impl(
+ token_list, images, patch_start_idx, frame_start, frame_end
+ )
+ gs_chunks.append(gs)
+ preds_chunks.append(preds)
+ conf_chunks.append(conf)
+ else:
+ preds, conf = self._forward_impl(
+ token_list, images, patch_start_idx, frame_start, frame_end
+ )
+ preds_chunks.append(preds)
+ conf_chunks.append(conf)
+
+ # Concatenate chunks along frame dimension
+ if self.is_gsdpt:
+ return torch.cat(gs_chunks, dim=1), torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1),
+ else:
+ return torch.cat(preds_chunks, dim=1), torch.cat(conf_chunks, dim=1)
+
+ def _forward_impl(
+ self,
+ token_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frame_start: int = None,
+ frame_end: int = None,
+ ) -> torch.Tensor:
+ """
+ Core forward implementation for DPT head.
+
+ Args:
+ token_list: List of transformer tokens from each layer, [B, S, N, C]
+ images: Input images [B, S, 3, H, W]
+ patch_start_idx: Starting index of patch tokens
+ frame_start: Start index for frame chunking (optional)
+ frame_end: End index for frame chunking (optional)
+
+ Returns:
+ If is_gsdpt: (features, preds, conf)
+ Else: (preds, conf)
+ """
+ # Slice frames if chunking
+ if frame_start is not None and frame_end is not None:
+ images = images[:, frame_start:frame_end].contiguous()
+
+ B, S, _, H, W = images.shape
+ ph = H // self.patch_size # patch height
+ pw = W // self.patch_size # patch width
+
+ # Extract and project multi-level features
+ feats = []
+ for proj, resize, tokens in zip(self.projects, self.resize_layers, token_list):
+ # Extract patch tokens
+ patch_tokens = tokens[:, :, patch_start_idx:]
+ if frame_start is not None and frame_end is not None:
+ patch_tokens = patch_tokens[:, frame_start:frame_end]
+
+ # Reshape to [B*S, N_patches, C]
+ patch_tokens = patch_tokens.reshape(B * S, -1, patch_tokens.shape[-1])
+ patch_tokens = self.norm(patch_tokens)
+
+ # Convert to 2D feature map [B*S, C, ph, pw]
+ feat = patch_tokens.permute(0, 2, 1).reshape(B * S, patch_tokens.shape[-1], ph, pw)
+ feat = proj(feat)
+
+ if self.pos_embed:
+ feat = self._apply_pos_embed(feat, W, H)
+ feat = resize(feat)
+ feats.append(feat)
+
+ # Fuse multi-level features
+ fused = self.scratch_forward(feats)
+ fused = custom_interpolate(
+ fused,
+ size=(
+ int(ph * self.patch_size / self.down_ratio),
+ int(pw * self.patch_size / self.down_ratio)
+ ),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ # Apply positional embedding after upsampling
+ if self.pos_embed:
+ fused = self._apply_pos_embed(fused, W, H)
+
+ # Generate predictions and confidence
+ if self.is_gsdpt:
+ # GSDPT: output features, predictions, and confidence
+ out = self.scratch.output_conv2(fused)
+ preds, conf = self.activate_head(out, activation=self.activation)
+ preds = preds.reshape(B, S, *preds.shape[1:])
+ conf = conf.reshape(B, S, *conf.shape[1:])
+
+ # Merge direct image features
+ img_flat = images.reshape(B * S, -1, H, W)
+ img_feat = self.input_merger(img_flat)
+ fused = fused + img_feat
+ fused = fused.reshape(B, S, *fused.shape[1:])
+ return fused, preds, conf
+ else:
+ # Standard: output predictions and confidence
+ out = self.scratch.output_conv2(fused)
+ preds, conf = self.activate_head(out, activation=self.activation)
+ preds = preds.reshape(B, S, *preds.shape[1:])
+ conf = conf.reshape(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+ def activate_head(self, out_head: torch.Tensor, activation: str = "inv_log+expp1") -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process network output to extract attribute (e.g. points, depth, etc.) and confidence values.
+
+ Args:
+ out_head: Network output tensor (B, C, H, W)
+ activation: Activation type for processing (e.g., "inv_log+expp1")
+
+ Returns:
+ Tuple of (attribute tensor, confidence tensor)
+ """
+ # Parse activation string
+ act_attr, act_conf = (activation.split("+") if "+" in activation else (activation, "expp1"))
+
+ # (B,C,H,W) -> (B,H,W,C)
+ feat = out_head.permute(0, 2, 3, 1)
+ attr, conf = feat[..., :-1], feat[..., -1]
+
+ # Map point activations to lambdas for clarity and conciseness
+ attr_activations = {
+ "norm_exp": lambda x: (x / x.norm(dim=-1, keepdim=True).clamp(min=1e-8)) * torch.expm1(x.norm(dim=-1, keepdim=True)),
+ "norm": lambda x: x / x.norm(dim=-1, keepdim=True),
+ "exp": torch.exp,
+ "relu": F.relu,
+ "inv_log": self._apply_inverse_log_transform,
+ "xy_inv_log": lambda x: torch.cat([
+ x[..., :2] * self._apply_inverse_log_transform(x[..., 2:]),
+ self._apply_inverse_log_transform(x[..., 2:])
+ ], dim=-1),
+ "sigmoid": torch.sigmoid,
+ "linear": lambda x: x
+ }
+
+ if act_attr not in attr_activations:
+ raise ValueError(f"Unknown attribute activation: {act_attr}")
+ attr_out = attr_activations[act_attr](attr)
+
+ # Confidence activation mapping
+ conf_activations = {
+ "expp1": lambda c: 1 + c.exp(),
+ "expp0": torch.exp,
+ "sigmoid": torch.sigmoid
+ }
+ if act_conf not in conf_activations:
+ raise ValueError(f"Unknown confidence activation: {act_conf}")
+ conf_out = conf_activations[act_conf](conf)
+
+ return attr_out, conf_out
+
+ def _apply_inverse_log_transform(self, input_tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Apply inverse logarithm transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ input_tensor: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(input_tensor) * (torch.expm1(torch.abs(input_tensor)))
+
+
+
+################################################################################
+# DPT Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module with skip connection."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Initialize ResidualConvUnit.
+
+ Args:
+ features (int): Number of input/output feature channels
+ activation: Activation function to use
+ bn (bool): Whether to use batch normalization (currently unused)
+ groups (int): Number of groups for grouped convolution
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass with residual connection.
+
+ Args:
+ x (tensor): Input tensor of shape (B, C, H, W)
+
+ Returns:
+ tensor: Output tensor of shape (B, C, H, W) with residual added
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Initialize FeatureFusionBlock.
+
+ Args:
+ features (int): Number of input/output feature channels
+ activation: Activation function to use
+ deconv (bool): Whether to use deconvolution
+ bn (bool): Whether to use batch normalization
+ expand (bool): Whether to expand features (halve output channels)
+ align_corners (bool): Align corners for interpolation
+ size: Target size for upsampling
+ has_residual (bool): Whether to include residual connection
+ groups (int): Number of groups for grouped convolution
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass through the feature fusion block.
+
+ Args:
+ *xs: Variable number of input tensors. First tensor is the main input,
+ second tensor (if present) is used for residual connection.
+ size: Optional target size for upsampling. If None, uses self.size or scale_factor=2.
+
+ Returns:
+ torch.Tensor: Fused and upsampled output tensor.
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolation function to handle large tensors by chunking.
+
+ Avoids INT_MAX overflow issues in nn.functional.interpolate when dealing with
+ very large input tensors by splitting them into smaller chunks.
+
+ Args:
+ x: Input tensor to interpolate
+ size: Target output size (H, W)
+ scale_factor: Scaling factor if size is not provided
+ mode: Interpolation mode (default: "bilinear")
+ align_corners: Whether to align corners in interpolation
+
+ Returns:
+ Interpolated tensor
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
diff --git a/src/models/layers/__init__.py b/src/models/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2190558e6a4fd88d71ac8c731d2d0314350d5744
--- /dev/null
+++ b/src/models/layers/__init__.py
@@ -0,0 +1,5 @@
+from .mlp import Mlp
+from .patch_embed import PatchEmbed, PatchEmbed_Mlp
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/src/models/layers/attention.py b/src/models/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..d87ce7b00636a72550ee4c72e6cf822ee8d66d70
--- /dev/null
+++ b/src/models/layers/attention.py
@@ -0,0 +1,90 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+import torch
+
+# from torch.nn.attention import SDPBackend
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.rope is not None:
+ q = self.rope(q, pos)
+ k = self.rope(k, pos)
+
+ # orig_dtype = q.dtype
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
+
+ # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ # with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ # if x.dtype != orig_dtype:
+ # x = x.to(orig_dtype)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
+ assert pos is None
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/src/models/layers/block.py b/src/models/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..0420a1ed294fbd5c501e8451456b8d917611d96f
--- /dev/null
+++ b/src/models/layers/block.py
@@ -0,0 +1,242 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(2)) + shift.unsqueeze(2)
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+
+ def forward(self, x: Tensor, pos=None) -> Tensor:
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
+ )
+ x = drop_add_residual_stochastic_depth(
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, pos=pos)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/src/models/layers/drop_path.py b/src/models/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..265a429fc518f74bd896494534cc94c683b1e88c
--- /dev/null
+++ b/src/models/layers/drop_path.py
@@ -0,0 +1,29 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/src/models/layers/layer_scale.py b/src/models/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..24f857f8e64d7cc0991f176c90ca7f2fab16e325
--- /dev/null
+++ b/src/models/layers/layer_scale.py
@@ -0,0 +1,17 @@
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/src/models/layers/mlp.py b/src/models/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..96dd129eca9a582a416e3be35ea3174b6268b174
--- /dev/null
+++ b/src/models/layers/mlp.py
@@ -0,0 +1,35 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/src/models/layers/patch_embed.py b/src/models/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbdef99a159a1ec043f0448ae6f34cc9303358a
--- /dev/null
+++ b/src/models/layers/patch_embed.py
@@ -0,0 +1,155 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+from itertools import repeat
+import collections.abc
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+
+class PatchEmbed_Mlp(PatchEmbed):
+ def __init__(self, img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten_embedding=True):
+ super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten_embedding)
+
+ self.proj = nn.Sequential(
+ PixelUnshuffle(patch_size),
+ Permute((0,2,3,1)),
+ Mlp(in_chans * patch_size**2, 4*embed_dim, embed_dim),
+ Permute((0,3,1,2)),
+ )
+
+
+class PixelUnshuffle (nn.Module):
+ def __init__(self, downscale_factor):
+ super().__init__()
+ self.downscale_factor = downscale_factor
+
+ def forward(self, input):
+ if input.numel() == 0:
+ # this is not in the original torch implementation
+ C,H,W = input.shape[-3:]
+ assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0
+ return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor)
+ else:
+ return F.pixel_unshuffle(input, self.downscale_factor)
+
+
+class Permute(torch.nn.Module):
+ dims: tuple[int, ...]
+ def __init__(self, dims: tuple[int, ...]) -> None:
+ super().__init__()
+ self.dims = tuple(dims)
+
+ def __repr__(self):
+ return f"Permute{self.dims}"
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.permute(*self.dims)
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+to_2tuple = _ntuple(2)
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
diff --git a/src/models/layers/rope.py b/src/models/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20c002876fe3de60d140a9fe82f28902447d4e2
--- /dev/null
+++ b/src/models/layers/rope.py
@@ -0,0 +1,182 @@
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0,):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
\ No newline at end of file
diff --git a/src/models/layers/swiglu_ffn.py b/src/models/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca9b0719b9f104c7ec1a42d6789107f17562f2ca
--- /dev/null
+++ b/src/models/layers/swiglu_ffn.py
@@ -0,0 +1,62 @@
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+# if XFORMERS_ENABLED:
+# from xformers.ops import SwiGLU
+
+# XFORMERS_AVAILABLE = True
+# warnings.warn("xFormers is available (SwiGLU)")
+# else:
+# warnings.warn("xFormers is disabled (SwiGLU)")
+# raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
diff --git a/src/models/layers/vision_transformer.py b/src/models/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b76a2a19b5ee5cadabf38c7596f66839627769
--- /dev/null
+++ b/src/models/layers/vision_transformer.py
@@ -0,0 +1,392 @@
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+ self.use_reentrant = False # hardcoded to False
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/src/models/models/rasterization.py b/src/models/models/rasterization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba7cae17c1ac4879c67359edf5074e2a1e7a834e
--- /dev/null
+++ b/src/models/models/rasterization.py
@@ -0,0 +1,623 @@
+from typing import Dict, Tuple
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+from einops import rearrange
+
+from gsplat.rendering import rasterization
+from gsplat.strategy import DefaultStrategy
+
+from src.models.utils.frustum import calculate_unprojected_mask
+from src.models.utils.geometry import depth_to_world_coords_points, closed_form_inverse_se3
+from src.models.utils.camera_utils import vector_to_camera_matrices
+from src.models.utils import sh_utils, act_gs
+
+
+class Rasterizer:
+ def __init__(self, rasterization_mode="classic", packed=True, abs_grad=True, with_eval3d=False,
+ camera_model="pinhole", sparse_grad=False, distributed=False, grad_strategy=DefaultStrategy):
+ self.rasterization_mode = rasterization_mode
+ self.packed = packed
+ self.abs_grad = abs_grad
+ self.camera_model = camera_model
+ self.sparse_grad = sparse_grad
+ self.grad_strategy = grad_strategy
+ self.distributed = distributed
+ self.with_eval3d = with_eval3d
+
+ def rasterize_splats(
+ self,
+ means,
+ quats,
+ scales,
+ opacities,
+ colors,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Dict]:
+ render_colors, render_alphas, _ = rasterization(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.packed,
+ absgrad=(
+ self.abs_grad
+ if isinstance(self.grad_strategy, DefaultStrategy)
+ else False
+ ),
+ sparse_grad=self.sparse_grad,
+ rasterize_mode=self.rasterization_mode,
+ distributed=self.distributed,
+ camera_model=self.camera_model,
+ with_eval3d=self.with_eval3d,
+ render_mode="RGB+ED",
+ **kwargs,
+ )
+ return render_colors[..., :3], render_colors[..., 3:], render_alphas
+
+ def rasterize_batches(self, means, quats, scales, opacities, colors, viewmats, Ks, width, height, **kwargs):
+ rendered_colors, rendered_depths, rendered_alphas = [], [], []
+ batch_size = len(means)
+
+ for i in range(batch_size):
+ means_i = means[i] # [N, 4]
+ quats_i = quats[i] # [N, 4]
+ scales_i = scales[i] # [N, 3]
+ opacities_i = opacities[i] # [N,]
+ colors_i = colors[i] # [N, 3]
+ viewmats_i = viewmats[i] # [V, 4, 4]
+ Ks_i = Ks[i] # [V, 3, 3]
+
+ render_colors_i, render_depths_i, render_alphas_i = self.rasterize_splats(
+ means_i, quats_i, scales_i, opacities_i, colors_i, viewmats_i, Ks_i, width, height, **kwargs
+ )
+
+ rendered_colors.append(render_colors_i) # V H W 3
+ rendered_depths.append(render_depths_i) # V H W 1
+ rendered_alphas.append(render_alphas_i) # V H W 1
+
+ rendered_colors = torch.stack(rendered_colors, dim=0) # B V H W 3
+ rendered_depths = torch.stack(rendered_depths, dim=0) # B V H W 1
+ rendered_alphas = torch.stack(rendered_alphas, dim=0) # B V H W 1
+
+ return rendered_colors, rendered_depths, rendered_alphas
+
+
+class GaussianSplatRenderer(nn.Module):
+ def __init__(
+ self,
+ feature_dim: int = 256, # Output channels of gs_feat_head
+ sh_degree: int = 0,
+ predict_offset: bool = False,
+ predict_residual_sh: bool = True,
+ enable_prune: bool = True,
+ voxel_size: float = 0.002, # Default voxel size for prune_gs
+ using_gtcamera_splat: bool = False,
+ render_novel_views: bool = False,
+ enable_conf_filter: bool = False, # Enable confidence filtering
+ conf_threshold_percent: float = 30.0, # Confidence threshold percentage
+ max_gaussians: int = 5000000, # Maximum number of Gaussians
+ debug=False,
+ ):
+ super().__init__()
+
+ self.feature_dim = feature_dim
+ self.sh_degree = sh_degree
+ self.nums_sh = (sh_degree + 1) ** 2
+ self.predict_offset = predict_offset
+ self.predict_residual_sh = predict_residual_sh
+ self.voxel_size = voxel_size
+ self.enable_prune = enable_prune
+ self.using_gtcamera_splat = using_gtcamera_splat
+ self.render_novel_views = render_novel_views
+ self.enable_conf_filter = enable_conf_filter
+ self.conf_threshold_percent = conf_threshold_percent
+ self.max_gaussians = max_gaussians
+ self.debug = debug
+
+ # Predict Gaussian parameters from GS features (quaternions/scales/opacities/SH/weights/optional offsets)
+ if self.predict_offset:
+ splits_and_inits = [
+ (4, 1.0, 0.0), # quats
+ (3, 0.00003, -7.0), # scales
+ (1, 1.0, -2.0), # opacities
+ (3 * self.nums_sh, 1.0, 0.0), # residual_sh
+ (1, 1.0, -2.0), # weights
+ (3, 0.001, 0.001), # offsets
+ ]
+ gaussian_raw_channels = 4 + 3 + 1 + self.nums_sh * 3 + 1 + 3
+ else:
+ splits_and_inits = [
+ (4, 1.0, 0.0), # quats
+ (3, 0.00003, -7.0), # scales
+ (1, 1.0, -2.0), # opacities
+ (3 * self.nums_sh, 1.0, 0.0), # residual_sh
+ (1, 1.0, -2.0), # weights
+ ]
+ gaussian_raw_channels = 4 + 3 + 1 + self.nums_sh * 3 + 1
+
+ self.gs_head = nn.Sequential(
+ nn.Conv2d(feature_dim // 2, feature_dim, kernel_size=3, padding=1, bias=False),
+ nn.ReLU(True),
+ nn.Conv2d(feature_dim, gaussian_raw_channels, kernel_size=1),
+ )
+ # Initialize weights and biases of the final layer by segments
+ final_conv_layer = self.gs_head[-1]
+ start_channels = 0
+ for out_channel, s, b in splits_and_inits:
+ nn.init.xavier_uniform_(final_conv_layer.weight[start_channels:start_channels+out_channel], s)
+ nn.init.constant_(final_conv_layer.bias[start_channels:start_channels+out_channel], b)
+ start_channels += out_channel
+
+ # Rasterizer
+ self.rasterizer = Rasterizer()
+
+ # ======== Main entry point: Complete GS rendering and fill results back to predictions ========
+ def render(
+ self,
+ gs_feats: torch.Tensor, # [B, S(+V), 3, H, W]
+ images: torch.Tensor, # [B, S+V, 3, H, W]
+ predictions: Dict[str, torch.Tensor], # From WorldMirror: pose/depth/pts3d etc
+ views: Dict[str, torch.Tensor],
+ context_predictions = None,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Returns predictions with the following fields filled:
+ - rendered_colors / rendered_depths / (rendered_alphas during training)
+ - gt_colors / gt_depths / valid_masks
+ - splats / rendered_extrinsics / rendered_intrinsics
+ """
+ H, W = images.shape[-2:]
+ S = views["context_nums"] if "context_nums" in views else images.shape[1]
+ V = images.shape[1] - S
+
+ # 1) Predict GS features from tokens, then convert to Gaussian parameters
+ gs_feats_reshape = rearrange(gs_feats, "b s c h w -> (b s) c h w")
+ gs_params = self.gs_head(gs_feats_reshape)
+
+ # 2) Select cameras (predicted or GT), and organize supervision data (gt_colors, gt_depths, valid_masks)
+ if self.training:
+ if self.render_novel_views and V > 0:
+ pred_all_extrinsic, pred_all_intrinsic = self.prepare_cameras(views, S+V)
+ render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
+ render_images = images
+ gt_colors = render_images.permute(0, 1, 3, 4, 2)
+ gt_depths = views["depthmap"] # [B, S+V, H, W]
+
+ gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
+ gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
+ unproject_masks = calculate_unprojected_mask(views, S) # [B, V, H, W]
+ valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
+ else:
+ # Only render source views
+ render_viewmats, render_Ks = self.prepare_cameras(views, S)
+ render_images = views["img"][:, :S]
+ gt_colors = render_images.permute(0, 1, 3, 4, 2)
+ gt_depths = views["depthmap"][:, :S]
+ gt_valid_masks = views["valid_mask"][:, :S]
+ valid_masks = gt_valid_masks
+ else:
+ # Re-predict cameras for novel views and perform translation/scale alignment
+ Bx = images.shape[0]
+ pred_all_extrinsic, pred_all_intrinsic = self.prepare_prediction_cameras(predictions, S + V, hw=(H, W))
+ pred_all_extrinsic = pred_all_extrinsic.reshape(Bx, S + V, 4, 4)
+ pred_all_source_extrinsic = pred_all_extrinsic[:, :S]
+
+ scale_factor = 1.0
+ if context_predictions is not None:
+ pred_source_extrinsic, _ = self.prepare_prediction_cameras(context_predictions, S, hw=(H, W))
+ pred_source_extrinsic = pred_source_extrinsic.reshape(Bx, S, 4, 4)
+ scale_factor = pred_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) / (
+ pred_all_source_extrinsic[:, :, :3, 3].mean(dim=(1, 2), keepdim=True) + 1e-6
+ )
+
+ pred_all_extrinsic[..., :3, 3] = pred_all_extrinsic[..., :3, 3] * scale_factor
+
+ render_viewmats, render_Ks = pred_all_extrinsic, pred_all_intrinsic
+ render_images = images
+ gt_colors = render_images.permute(0, 1, 3, 4, 2)
+
+ # Handle pure inference case where views may not have ground truth data
+ gt_depths = views.get("depthmap")
+ valid_masks = None
+ if gt_depths is not None:
+ if views.get("gt_depth") is not None and views["gt_depth"]:
+ unproject_masks = calculate_unprojected_mask(views, S)
+ gt_valid_masks_src = views["valid_mask"][:, :S] # [B, S, H, W]
+ gt_valid_masks_tgt = views["valid_mask"][:, S:] # [B, V, H, W]
+ gt_valid_masks = torch.cat([gt_valid_masks_src, (gt_valid_masks_tgt & unproject_masks)], dim=1)
+ else:
+ gt_valid_masks = views.get("valid_mask")
+ valid_masks = gt_valid_masks
+
+ # 3) Generate splats from gs_params + predictions, and perform voxel merging
+ if self.training and self.using_gtcamera_splat:
+ splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+gtcamera", debug=False)
+ else:
+ splats = self.prepare_splats(views, predictions, images, gs_params, S, V, position_from="gsdepth+predcamera", context_predictions=context_predictions, debug=False)
+ splats_raw = {k: v.clone() for k, v in splats.items()}
+
+ # Apply confidence filtering before pruning
+ if self.enable_conf_filter and "depth_conf" in predictions:
+ splats = self.apply_confidence_filter(splats, predictions["depth_conf"])
+
+ if self.enable_prune:
+ splats = self.prune_gs(splats, voxel_size=self.voxel_size)
+
+ # 4) Rasterization rendering (training: chunked rendering + novel view valid mask correction; evaluation: view-by-view)
+ if self.training:
+ if self.render_novel_views and V > 0:
+ indices = np.arange(S+V)
+ else:
+ indices = np.arange(S)
+
+ render_viewmats = render_viewmats[:, indices]
+ render_Ks = render_Ks[:, indices]
+ gt_colors = gt_colors[:, indices]
+ if gt_depths is not None:
+ gt_depths = gt_depths[:, indices]
+ if valid_masks is not None:
+ valid_masks = valid_masks[:, indices]
+
+ # Prevent OOM by using chunked rendering
+ rendered_colors_list, rendered_depths_list, rendered_alphas_list = [], [], []
+ chunk_size = 4
+ for i in range(0, gt_colors.shape[1], chunk_size):
+ end_idx = min(i + chunk_size, gt_colors.shape[1])
+ viewmats_i = render_viewmats[:, i:end_idx]
+ Ks_i = render_Ks[:, i:end_idx]
+
+ rendered_colors, rendered_depths, rendered_alphas = self.rasterizer.rasterize_batches(
+ splats["means"], splats["quats"], splats["scales"], splats["opacities"],
+ splats["sh"] if "sh" in splats else splats["colors"],
+ viewmats_i.detach(), Ks_i.detach(),
+ width=render_images.shape[-1], height=render_images.shape[-2],
+ sh_degree=min(self.sh_degree, 0) if "sh" in splats else None,
+ )
+ rendered_colors_list.append(rendered_colors)
+ rendered_depths_list.append(rendered_depths)
+ rendered_alphas_list.append(rendered_alphas)
+
+ rendered_colors = torch.cat(rendered_colors_list, dim=1)
+ rendered_depths = torch.cat(rendered_depths_list, dim=1)
+ rendered_alphas = torch.cat(rendered_alphas_list, dim=1)
+
+ if self.training and self.render_novel_views and V > 0:
+ nvs_rendered_mask = rendered_alphas[:, S:, ..., 0].detach() > 0.1
+ valid_masks[:, S:] = nvs_rendered_mask & valid_masks[:, S:]
+
+ # 5) return predictions
+ predictions["rendered_colors"] = rendered_colors
+ predictions["rendered_depths"] = rendered_depths
+ predictions["gt_colors"] = gt_colors
+ predictions["gt_depths"] = gt_depths
+ predictions["valid_masks"] = valid_masks
+ predictions["splats"] = splats
+ predictions["splats_raw"] = splats_raw
+ predictions["rendered_extrinsics"] = render_viewmats
+ predictions["rendered_intrinsics"] = render_Ks
+
+ return predictions
+
+ def apply_confidence_filter(self, splats, gs_depth_conf):
+ """
+ Apply confidence filtering to Gaussian splats before pruning.
+ Discard bottom p% confidence points, keep top (100-p)%.
+
+ Args:
+ splats: Dictionary containing Gaussian parameters
+ gs_depth_conf: Confidence tensor [B, S, H, W]
+
+ Returns:
+ Filtered splats dictionary
+ """
+ if not self.enable_conf_filter or gs_depth_conf is None:
+ return splats
+
+ device = splats["means"].device
+ B, N = splats["means"].shape[:2]
+
+ # Flatten confidence: [B, S, H, W] -> [B, N]
+ conf = gs_depth_conf.flatten(1).to(device)
+ # Mask invalid/very small values
+ conf = conf.masked_fill(conf <= 1e-5, float("-inf"))
+
+ # Keep top (100-p)% points, discard bottom p%
+ if self.conf_threshold_percent > 0:
+ keep_from_percent = int(np.ceil(N * (100.0 - self.conf_threshold_percent) / 100.0))
+ else:
+ keep_from_percent = N
+ K = max(1, min(self.max_gaussians, keep_from_percent))
+
+ # Select top-K indices for each batch (deterministic, no randomness)
+ topk_idx = torch.topk(conf, K, dim=1, largest=True, sorted=False).indices # [B, K]
+
+ filtered = {}
+ mask_keys = ["means", "quats", "scales", "opacities", "sh", "weights"]
+
+ for key in splats.keys():
+ if key in mask_keys and key in splats:
+ x = splats[key]
+ if x.ndim == 2: # [B, N]
+ filtered[key] = torch.gather(x, 1, topk_idx)
+ else:
+ # Expand indices to match tensor dimensions
+ expand_idx = topk_idx.clone()
+ for i in range(x.ndim - 2):
+ expand_idx = expand_idx.unsqueeze(-1)
+ expand_idx = expand_idx.expand(-1, -1, *x.shape[2:])
+ filtered[key] = torch.gather(x, 1, expand_idx)
+ else:
+ filtered[key] = splats[key]
+
+ return filtered
+
+ def prune_gs(self, splats, voxel_size=0.002):
+ """
+ Prune Gaussian splats by merging those in the same voxel.
+
+ Args:
+ splats: Dictionary containing Gaussian parameters
+ voxel_size: Size of voxels for spatial grouping
+
+ Returns:
+ Dictionary with pruned splats
+ """
+ B = splats["means"].shape[0]
+ merged_splats_list = []
+ device = splats["means"].device
+
+ for i in range(B):
+ # Extract splats for current batch
+ splats_i = {k: splats[k][i] for k in ["means", "quats", "scales", "opacities", "sh", "weights"]}
+
+ # Compute voxel indices
+ coords = splats_i["means"]
+ voxel_indices = (coords / voxel_size).floor().long()
+ min_indices = voxel_indices.min(dim=0)[0]
+ voxel_indices = voxel_indices - min_indices
+ max_dims = voxel_indices.max(dim=0)[0] + 1
+
+ # Flatten 3D voxel indices to 1D
+ flat_indices = (voxel_indices[:, 0] * max_dims[1] * max_dims[2] +
+ voxel_indices[:, 1] * max_dims[2] +
+ voxel_indices[:, 2])
+
+ # Find unique voxels and inverse mapping
+ unique_voxels, inverse_indices = torch.unique(flat_indices, return_inverse=True)
+ K = len(unique_voxels)
+
+ # Initialize merged splats
+ merged = {
+ "means": torch.zeros((K, 3), device=device),
+ "quats": torch.zeros((K, 4), device=device),
+ "scales": torch.zeros((K, 3), device=device),
+ "opacities": torch.zeros(K, device=device),
+ "sh": torch.zeros((K, self.nums_sh, 3), device=device)
+ }
+
+ # Get weights and compute weight sums per voxel
+ weights = splats_i["weights"]
+ weight_sums = torch.zeros(K, device=device)
+ weight_sums.scatter_add_(0, inverse_indices, weights)
+ weight_sums = torch.clamp(weight_sums, min=1e-8)
+
+ # Merge means (weighted average)
+ for d in range(3):
+ merged["means"][:, d].scatter_add_(0, inverse_indices,
+ splats_i["means"][:, d] * weights)
+ merged["means"] = merged["means"] / weight_sums.unsqueeze(1)
+
+ # Merge spherical harmonics (weighted average)
+ for d in range(3):
+ merged["sh"][:, 0, d].scatter_add_(0, inverse_indices,
+ splats_i["sh"][:, 0, d] * weights)
+ merged["sh"] = merged["sh"] / weight_sums.unsqueeze(-1).unsqueeze(-1)
+
+ # Merge opacities (weighted sum of squares)
+ merged["opacities"].scatter_add_(0, inverse_indices, weights * weights)
+ merged["opacities"] = merged["opacities"] / weight_sums
+
+ # Merge scales (weighted average)
+ for d in range(3):
+ merged["scales"][:, d].scatter_add_(0, inverse_indices,
+ splats_i["scales"][:, d] * weights)
+ merged["scales"] = merged["scales"] / weight_sums.unsqueeze(1)
+
+ # Merge quaternions (weighted average + normalization)
+ for d in range(4):
+ merged["quats"][:, d].scatter_add_(0, inverse_indices,
+ splats_i["quats"][:, d] * weights)
+ quat_norms = torch.norm(merged["quats"], dim=1, keepdim=True)
+ merged["quats"] = merged["quats"] / torch.clamp(quat_norms, min=1e-8)
+
+ merged_splats_list.append(merged)
+
+ # Reorganize output
+ output = {}
+ for key in ["means", "sh", "opacities", "scales", "quats"]:
+ output[key] = [merged[key] for merged in merged_splats_list]
+
+ return output
+
+ def prepare_splats(self, views, predictions, images, gs_params, context_nums, target_nums, context_predictions=None, position_from="gsdepth+predcamera", debug=False):
+ """
+ Prepare Gaussian splats from model predictions and input data.
+
+ Args:
+ views: Dictionary containing view data (camera poses, intrinsics, etc.)
+ predictions: Model predictions including depth, pose_enc, etc.
+ images: Input images [B, S_all, 3, H, W]
+ gs_params: Gaussian splatting parameters from model
+ context_nums: Number of context views (S)
+ target_nums: Number of target views (V)
+ context_predictions: Optional context predictions for camera poses
+ position_from: Method to compute 3D positions ("pts3d", "preddepth+predcamera", "gsdepth+predcamera", "gsdepth+gtcamera")
+ debug: Whether to use debug mode with ground truth data
+
+ Returns:
+ splats: Dictionary containing prepared Gaussian splat parameters
+ """
+ B, S_all, _, H, W = images.shape
+ S, V = context_nums, target_nums
+ splats = {}
+
+ # Only take parameters from source view branch
+ gs_params = rearrange(gs_params, "(b s) c h w -> b s h w c", b=B)[:, :S]
+ splats["gs_feats"] = gs_params.reshape(B, S*H*W, -1)
+
+ # Split Gaussian parameters based on whether offset prediction is enabled
+ if self.predict_offset:
+ quats, scales, opacities, residual_sh, weights, offsets = torch.split(
+ gs_params, [4, 3, 1, self.nums_sh * 3, 1, 3], dim=-1
+ )
+ offsets = act_gs.reg_dense_offsets(offsets.reshape(B, S * H * W, 3))
+ splats["offsets"] = offsets
+ else:
+ quats, scales, opacities, residual_sh, weights = torch.split(
+ gs_params, [4, 3, 1, self.nums_sh * 3, 1], dim=-1
+ )
+ offsets = 0.
+
+ # Apply activation functions to Gaussian parameters
+ splats["quats"] = act_gs.reg_dense_rotation(quats.reshape(B, S * H * W, 4))
+ splats["scales"] = act_gs.reg_dense_scales(scales.reshape(B, S * H * W, 3)).clamp_max(0.3)
+ splats["opacities"] = act_gs.reg_dense_opacities(opacities.reshape(B, S * H * W))
+ residual_sh = act_gs.reg_dense_sh(residual_sh.reshape(B, S * H * W, self.nums_sh * 3))
+
+ # Handle spherical harmonics (SH) coefficients
+ if self.predict_residual_sh:
+ new_sh = torch.zeros_like(residual_sh)
+ new_sh[..., 0, :] = sh_utils.RGB2SH(
+ images[:, :S].permute(0, 1, 3, 4, 2).reshape(B, S * H * W, 3)
+ )
+ splats['sh'] = new_sh + residual_sh
+ splats['residual_sh'] = residual_sh
+ else:
+ splats['sh'] = residual_sh
+
+ splats["weights"] = act_gs.reg_dense_weights(weights.reshape(B, S * H * W))
+
+ # Compute 3D positions based on specified method
+ if position_from == "pts3d":
+ pts3d = predictions["pts3d"][:, :S].reshape(B, S * H * W, 3)
+ splats["means"] = pts3d + offsets
+
+ elif position_from == "preddepth+predcamera":
+ depth = predictions["depth"][:, :S].reshape(B * S, H, W)
+ if context_predictions is not None:
+ pose3x4, intrinsic = vector_to_camera_matrices(
+ context_predictions["camera_params"][:, :S].reshape(B * S, -1), (H, W)
+ )
+ else:
+ pose3x4, intrinsic = vector_to_camera_matrices(
+ predictions["camera_params"][:, :S].reshape(B * S, -1), (H, W)
+ )
+ pose4x4 = torch.eye(4, device=pose3x4.device, dtype=pose3x4.dtype)[None].repeat(B * S, 1, 1)
+ pose4x4[:, :3, :4] = pose3x4
+ extrinsics = closed_form_inverse_se3(pose4x4)
+ pts3d, _, _ = depth_to_world_coords_points(depth, extrinsics.detach(), intrinsic.detach())
+ pts3d = pts3d.reshape(B, S * H * W, 3)
+ splats["means"] = pts3d + offsets
+
+ elif position_from == "gsdepth+predcamera":
+ depth = predictions["gs_depth"][:, :S].reshape(B * S, H, W)
+ if context_predictions is not None:
+ pose3x4, intrinsic = vector_to_camera_matrices(
+ context_predictions["camera_params"][:, :S].reshape(B * S, -1), (H, W)
+ )
+ else:
+ pose3x4, intrinsic = vector_to_camera_matrices(
+ predictions["camera_params"][:, :S].reshape(B * S, -1), (H, W)
+ )
+ pose4x4 = torch.eye(4, device=pose3x4.device, dtype=pose3x4.dtype)[None].repeat(B * S, 1, 1)
+ pose4x4[:, :3, :4] = pose3x4
+ extrinsics = closed_form_inverse_se3(pose4x4)
+ pts3d, _, _ = depth_to_world_coords_points(depth, extrinsics.detach(), intrinsic.detach())
+ pts3d = pts3d.reshape(B, S * H * W, 3)
+ splats["means"] = pts3d + offsets
+
+ elif position_from == "gsdepth+gtcamera":
+ depth = predictions["gs_depth"][:, :S].reshape(B * S, H, W)
+ pose4x4 = views["camera_pose"][:, :S].reshape(B * S, 4, 4)
+ intrinsic = views["camera_intrinsics"][:, :S].reshape(B * S, 3, 3)
+ extrinsics = pose4x4
+ pts3d, _, _ = depth_to_world_coords_points(depth, extrinsics.detach(), intrinsic.detach())
+ pts3d = pts3d.reshape(B, S * H * W, 3)
+ splats["means"] = pts3d + offsets
+
+ else:
+ raise ValueError(f"Invalid position_from={position_from}")
+
+ return splats
+
+ def prepare_cameras(self, views, nums):
+ viewmats = views['camera_pose'][:, :nums]
+ Ks = views['camera_intrinsics'][:, :nums]
+ return viewmats, Ks
+
+ def prepare_prediction_cameras(self, predictions, nums, hw: Tuple[int, int]):
+ """
+ Prepare camera matrices from predicted pose encodings.
+
+ Args:
+ predictions: Dictionary containing pose_enc predictions
+ nums: Number of views to process
+ hw: Tuple of (height, width)
+
+ Returns:
+ viewmats: Camera view matrices [B, S, 4, 4]
+ Ks: Camera intrinsic matrices [B, S, 3, 3]
+ """
+ B = predictions["camera_params"].shape[0]
+ H, W = hw
+
+ # Convert pose encoding to extrinsics and intrinsics
+ pose3x4, intrinsic = vector_to_camera_matrices(
+ predictions["camera_params"][:, :nums].reshape(B * nums, -1), (H, W)
+ )
+
+ # Convert to homogeneous coordinates and compute view matrices
+ pose4x4 = torch.eye(4, device=pose3x4.device, dtype=pose3x4.dtype)[None].repeat(B * nums, 1, 1)
+ pose4x4[:, :3, :4] = pose3x4
+
+ viewmats = closed_form_inverse_se3(pose4x4).reshape(B, nums, 4, 4)
+ Ks = intrinsic.reshape(B, nums, 3, 3)
+
+ return viewmats, Ks
+
+
+
+if __name__ == "__main__":
+ device = "cuda:0"
+ means = torch.randn((100, 3), device=device)
+ quats = torch.randn((100, 4), device=device)
+ scales = torch.rand((100, 3), device=device) * 0.1
+ opacities = torch.rand((100,), device=device)
+ colors = torch.rand((100, 3), device=device)
+
+ viewmats = torch.eye(4, device=device)[None, :, :].repeat(10, 1, 1)
+ Ks = torch.tensor([
+ [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :].repeat(10, 1, 1)
+ width, height = 300, 200
+
+ rasterizer = Rasterizer()
+ splats = {
+ "means": means,
+ "quats": quats,
+ "scales": scales,
+ "opacities": opacities,
+ "colors": colors,
+ }
+ colors, alphas, _ = rasterizer.rasterize_splats(splats, viewmats, Ks, width, height)
+
\ No newline at end of file
diff --git a/src/models/models/visual_transformer.py b/src/models/models/visual_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bab0907dbc8389a23219c7bf5182aad9ce1bd0d
--- /dev/null
+++ b/src/models/models/visual_transformer.py
@@ -0,0 +1,431 @@
+import logging
+import random
+from typing import Tuple, List
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from src.models.layers import PatchEmbed, PatchEmbed_Mlp
+from src.models.layers.block import Block
+from src.models.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from src.models.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class VisualGeometryTransformer(nn.Module):
+ """
+ The VisualGeometryTransformer applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ enable_condition (bool): Whether to enable conditioning inputs.
+ sampling_strategy (str): Sampling strategy for patches.
+ fixed_patch_embed (bool): Whether to fix patch embedding weights.
+ enable_interpolation (bool): Whether to enable position interpolation.
+ max_resolution (int): Maximum resolution for position interpolation.
+ condition_strategy (list[str]): Strategy for each conditioning input.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ enable_cond=False,
+ sampling_strategy="uniform",
+ fixed_patch_embed=False,
+ enable_interpolation=False,
+ max_resolution=2044,
+ condition_strategy=["token", "pow3r", "token"],
+ intermediate_idxs: List[int] = [4, 11, 17, 23]
+ ):
+ super().__init__()
+ # Store config parameters
+ self.enable_cond = enable_cond
+ self.sampling_strategy = sampling_strategy
+ self.cond_methods = condition_strategy
+ self.intermediate_idxs = intermediate_idxs
+ self.depth = depth
+ self.patch_size = patch_size
+
+ # Initialize patch embedding module
+ self.patch_embed = self._init_patch_embedding_module(
+ patch_embed, img_size, patch_size, num_register_tokens,
+ embed_dim=embed_dim, is_fixed=fixed_patch_embed
+ )
+
+ # Initialize conditioning embeddings if enabled
+ if self.enable_cond:
+ self._init_cond_embeddings(embed_dim, img_size, patch_size, num_register_tokens)
+
+ # Initialize rotary position embedding
+ self._init_rotary_position_embedding(rope_freq, enable_interpolation, max_resolution)
+
+ # Initialize transformer blocks
+ self._init_transformer_blocks(block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm)
+
+ # Initialize learnable tokens
+ self._init_learnable_tokens(embed_dim, num_register_tokens)
+
+ # Calculate patch start index based on conditioning
+ if self.enable_cond:
+ self.patch_start_idx = 1 + num_register_tokens + 1 + 1 # camera + register + pose + rays
+ else:
+ self.patch_start_idx = 1 + num_register_tokens # camera + register
+
+ # Register normalization constants
+ for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
+ self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
+
+ self.use_reentrant_checkpointing = False
+
+ def _init_patch_embedding_module(
+ self,
+ patch_embed_type,
+ img_size,
+ patch_size,
+ num_reg_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ is_fixed=False,
+ in_chans=3
+ ):
+ """
+ Create the patch embedding module. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+ if "conv" in patch_embed_type:
+ if 'mlp' in patch_embed_type:
+ patch_embed_module = PatchEmbed_Mlp(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim
+ )
+ else:
+ patch_embed_module = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim
+ )
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ patch_embed_module = vit_models[patch_embed_type](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_reg_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(patch_embed_module, "mask_token"):
+ patch_embed_module.mask_token.requires_grad_(False)
+
+ if is_fixed:
+ for param in patch_embed_module.parameters():
+ param.requires_grad_(False)
+
+ return patch_embed_module
+
+ def _init_cond_embeddings(self, embed_dim, img_size, patch_size, num_reg_tokens):
+ """Initialize conditioning embeddings for camera, depth, and rays."""
+ assert self.cond_methods is not None
+ assert self.cond_methods[0] == "token"
+
+ # Camera pose embedding
+ if self.cond_methods[0] == "token":
+ self.pose_embed = nn.Sequential(
+ nn.Linear(7, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True)
+ )
+ else:
+ raise NotImplementedError
+
+ # Depth map embedding
+ if self.cond_methods[1] == "pow3r":
+ self.depth_embed = self._init_patch_embedding_module(
+ "conv+mlp", img_size, patch_size, num_reg_tokens,
+ embed_dim=embed_dim, in_chans=1
+ )
+ else:
+ raise NotImplementedError
+
+ # Ray direction embedding
+ if self.cond_methods[2] == "token":
+ self.ray_embed = nn.Sequential(
+ nn.Linear(4, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True)
+ )
+ else:
+ raise NotImplementedError
+
+ def _init_rotary_position_embedding(self, rope_freq, enable_interpolation, max_resolution):
+ self.rope = RotaryPositionEmbedding2D(
+ frequency=rope_freq,
+ ) if rope_freq > 0 else None
+ self.pos_getter = PositionGetter() if self.rope is not None else None
+
+ def _init_transformer_blocks(self, block_fn, embed_dim, num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, init_values, qk_norm):
+ self.frame_blocks = nn.ModuleList([
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(self.depth)
+ ])
+
+ self.global_blocks = nn.ModuleList([
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope
+ )
+ for _ in range(self.depth)
+ ])
+
+ def _init_learnable_tokens(self, embed_dim, num_reg_tokens):
+ """Initialize learnable tokens."""
+ self.cam_token = nn.Parameter(torch.zeros(1, 2, 1, embed_dim))
+ self.reg_token = nn.Parameter(torch.zeros(1, 2, num_reg_tokens, embed_dim))
+ nn.init.normal_(self.cam_token, std=1e-6)
+ nn.init.normal_(self.reg_token, std=1e-6)
+
+ def forward(self, images: torch.Tensor, priors: List | None=None, cond_flags: List[int]=[0,0,0], ctx_frames: int=None) -> Tuple[List[torch.Tensor], int]:
+ """
+ Args:
+ images: Input images with shape [B, S, 3, H, W], in range [0, 1]
+ priors: Optional tuple of (depth, rays, poses) for conditioning
+ cond_flags: List indicating which conditions to use [pose, depth, rays]
+ ctx_frames: Number of context frames to use
+
+ Returns:
+ (list[torch.Tensor], int): List of attention block outputs and patch_start_idx
+ """
+ depth_maps, ray_dirs, poses = priors if priors is not None else (None, None, None)
+
+ # Slice to context frames if specified
+ if ctx_frames is not None:
+ for var_name in ['images', 'depth_maps', 'ray_dirs', 'poses']:
+ var = locals()[var_name]
+ if var is not None:
+ locals()[var_name] = var[:, :ctx_frames].clone()
+
+ # Process image tokens
+ b, seq_len, ch, h, w = images.shape
+ if ch != 3:
+ raise ValueError(f"Expected 3 input channels, got {ch}")
+
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+ images = (images - self._resnet_mean) / self._resnet_std
+ images = images.view(b * seq_len, ch, h, w)
+ patch_tokens = self.patch_embed(images)
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ _, patch_count, embed_dim = patch_tokens.shape
+
+ # Prepare special tokens
+ cam_tokens = expand_and_flatten_special_tokens(self.cam_token, b, seq_len)
+ reg_tokens = expand_and_flatten_special_tokens(self.reg_token, b, seq_len)
+
+ # Process all tokens (optional conditioning)
+ if self.enable_cond:
+ pose_tokens, depth_tokens, ray_tokens = self._process_conditioning(depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags)
+ # Add condition tokens to patch tokens
+ patch_tokens = patch_tokens + depth_tokens
+ all_tokens = torch.cat([cam_tokens, reg_tokens, pose_tokens, ray_tokens, patch_tokens], dim=1)
+ else:
+ all_tokens = torch.cat([cam_tokens, reg_tokens, patch_tokens], dim=1)
+
+ _, patch_count, embed_dim = all_tokens.shape
+
+ # Position embedding
+ pos_emb = None
+ if self.rope is not None:
+ pos_emb = self.pos_getter(b * seq_len, h // self.patch_size, w // self.patch_size, device=images.device)
+ if self.patch_start_idx > 0:
+ pos_emb = pos_emb + 1
+ special_pos = torch.zeros(b * seq_len, self.patch_start_idx, 2, device=images.device, dtype=pos_emb.dtype)
+ pos_emb = torch.cat([special_pos, pos_emb], dim=1)
+
+ # Forward through attention blocks
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+ outputs = []
+ global_tokens = None
+ for idx in range(self.depth):
+ local_tokens = self._process_attention_blocks(
+ tokens=all_tokens if global_tokens is None else global_tokens,
+ b=b,
+ seq_len=seq_len,
+ patch_count=patch_count,
+ embed_dim=embed_dim,
+ block_idx=idx,
+ blocks=self.frame_blocks,
+ block_type='frame',
+ pos=pos_emb,
+ )
+ global_tokens = self._process_attention_blocks(
+ tokens=local_tokens,
+ b=b,
+ seq_len=seq_len,
+ patch_count=patch_count,
+ embed_dim=embed_dim,
+ block_idx=idx,
+ blocks=self.global_blocks,
+ block_type='global',
+ pos=pos_emb,
+ )
+
+ # Combine frame and global intermediates
+ if idx in self.intermediate_idxs:
+ combined_out = torch.cat([local_tokens, global_tokens], dim=-1)
+ outputs.append(combined_out)
+
+ return outputs, self.patch_start_idx
+
+ def _process_conditioning(self, depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags):
+ """Process conditioning inputs."""
+ h, w = images.shape[-2:]
+ if self.training:
+ assert self.sampling_strategy is not None
+ if self.sampling_strategy == "uniform":
+ pose_prob = depth_prob = rays_prob = 0.5
+ else:
+ raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")
+
+ # Process camera pose embedding
+ use_poses = (self.training and random.random() < pose_prob) or (not self.training and cond_flags[0] == 1 and poses is not None)
+ if use_poses:
+ poses = poses.view(b*seq_len, -1)
+ pose_tokens = self.pose_embed(poses).unsqueeze(1)
+ else:
+ pose_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
+
+ # Process depth map embedding
+ use_depth = (self.training and random.random() < depth_prob) or (not self.training and cond_flags[1] == 1 and depth_maps is not None)
+ if use_depth:
+ depth_maps = depth_maps.view(b*seq_len, 1, h, w)
+ depth_tokens = self.depth_embed(depth_maps).reshape(b * seq_len, patch_count, embed_dim)
+ else:
+ depth_tokens = torch.zeros((b*seq_len, patch_count, embed_dim), device=images.device, dtype=images.dtype)
+
+ # Process ray direction embedding
+ use_rays = (self.training and random.random() < rays_prob) or (not self.training and cond_flags[2] == 1 and ray_dirs is not None)
+ if use_rays:
+ ray_dirs = ray_dirs.view(b*seq_len, -1)
+ ray_tokens = self.ray_embed(ray_dirs).unsqueeze(1)
+ else:
+ ray_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
+
+ return pose_tokens, depth_tokens, ray_tokens
+
+ def _process_attention_blocks(self, tokens, b, seq_len, patch_count, embed_dim, block_idx, blocks, block_type, pos=None):
+ """Process attention blocks with tokens in shape (B*S, P, C)."""
+ token_shape = (b, seq_len, patch_count, embed_dim)
+ if block_type == 'frame': # local
+ target_shape = (b * seq_len, patch_count, embed_dim)
+ pos_target_shape = (b * seq_len, patch_count, 2) if pos is not None else None
+ else: # global
+ target_shape = (b, seq_len * patch_count, embed_dim)
+ pos_target_shape = (b, seq_len * patch_count, 2) if pos is not None else None
+
+ if tokens.shape != target_shape:
+ tokens = tokens.view(*target_shape)
+
+ if pos is not None and pos.shape != pos_target_shape:
+ pos = pos.view(*pos_target_shape)
+
+ if self.training:
+ tokens = checkpoint(
+ blocks[block_idx],
+ tokens,
+ pos=pos,
+ use_reentrant=self.use_reentrant_checkpointing,
+ )
+ else:
+ tokens = blocks[block_idx](tokens, pos=pos)
+
+ return tokens.view(*token_shape)
+
+
+def expand_and_flatten_special_tokens(token_tensor, b, seq_len):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing.
+ Uses first position for frame 0, second position for remaining frames.
+
+ Args:
+ token_tensor: Input tensor with shape (1, 2, X, C)
+ b: Batch size
+ seq_len: Sequence length
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+ # First frame uses position 0, remaining frames use position 1
+ first_frame_tokens = token_tensor[:, 0:1, ...].expand(b, 1, *token_tensor.shape[2:])
+ remaining_frame_tokens = token_tensor[:, 1:, ...].expand(b, seq_len - 1, *token_tensor.shape[2:])
+
+ # Concatenate and flatten
+ combined_tokens = torch.cat([first_frame_tokens, remaining_frame_tokens], dim=1)
+ return combined_tokens.view(b * seq_len, *combined_tokens.shape[2:])
diff --git a/src/models/models/worldmirror.py b/src/models/models/worldmirror.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bd21cfa32acd8bd4b928ee728d9730db7153d0
--- /dev/null
+++ b/src/models/models/worldmirror.py
@@ -0,0 +1,278 @@
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+
+from src.models.models.visual_transformer import VisualGeometryTransformer
+from src.models.heads.camera_head import CameraHead
+from src.models.heads.dense_head import DPTHead
+from src.models.models.rasterization import GaussianSplatRenderer
+from src.models.utils.camera_utils import vector_to_camera_matrices, extrinsics_to_vector
+from src.models.utils.priors import normalize_depth, normalize_poses
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+class WorldMirror(nn.Module, PyTorchModelHubMixin):
+ def __init__(self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ gs_dim=256,
+ enable_cond=True,
+ enable_cam=True,
+ enable_pts=True,
+ enable_depth=True,
+ enable_norm=True,
+ enable_gs=True,
+ patch_embed="dinov2_vitl14_reg",
+ fixed_patch_embed=False,
+ sampling_strategy="uniform",
+ dpt_gradient_checkpoint=False,
+ condition_strategy=["token", "pow3r", "token"],
+ enable_interpolation=False,
+ max_resolution=2044):
+
+ super().__init__()
+ # Configuration flags
+ self.enable_cam = enable_cam
+ self.enable_pts = enable_pts
+ self.enable_depth = enable_depth
+ self.enable_cond = enable_cond
+ self.enable_norm = enable_norm
+ self.enable_gs = enable_gs
+ self.patch_embed = patch_embed
+ self.sampling = sampling_strategy
+ self.dpt_checkpoint = dpt_gradient_checkpoint
+ self.cond_methods = condition_strategy
+
+ # Visual geometry transformer
+ self.visual_geometry_transformer = VisualGeometryTransformer(
+ img_size=img_size,
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ enable_cond=enable_cond,
+ sampling_strategy=sampling_strategy,
+ patch_embed=patch_embed,
+ fixed_patch_embed=fixed_patch_embed,
+ enable_interpolation=enable_interpolation,
+ max_resolution=max_resolution,
+ condition_strategy=condition_strategy
+ )
+
+ # Initialize prediction heads
+ self._init_heads(embed_dim, patch_size, gs_dim)
+
+ def _init_heads(self, dim, patch_size, gs_dim):
+ """Initialize all prediction heads"""
+
+ # Camera pose prediction head
+ if self.enable_cam:
+ self.cam_head = CameraHead(dim_in=2 * dim)
+
+ # 3D point prediction head
+ if self.enable_pts:
+ self.pts_head = DPTHead(
+ dim_in=2 * dim,
+ output_dim=4,
+ patch_size=patch_size,
+ activation="inv_log+expp1"
+ )
+
+ # Depth prediction head
+ if self.enable_depth:
+ self.depth_head = DPTHead(
+ dim_in=2 * dim,
+ output_dim=2,
+ patch_size=patch_size,
+ activation="exp+expp1",
+ )
+
+ # Surface normal prediction head
+ if self.enable_norm:
+ self.norm_head = DPTHead(
+ dim_in=2 * dim,
+ output_dim=4,
+ patch_size=patch_size,
+ activation="norm+expp1",
+ )
+
+ # Gaussian splatting feature head and renderer
+ if self.enable_gs:
+ self.gs_head = DPTHead(
+ dim_in=2 * dim,
+ output_dim=2,
+ patch_size=patch_size,
+ features=gs_dim,
+ is_gsdpt=True,
+ activation="exp+expp1"
+ )
+ self.gs_renderer = GaussianSplatRenderer(
+ sh_degree=0,
+ predict_offset=False,
+ predict_residual_sh=True,
+ enable_prune=True,
+ voxel_size=0.002,
+ using_gtcamera_splat=True,
+ render_novel_views=True,
+ )
+
+ def forward(self, views: Dict[str, torch.Tensor], cond_flags: List[int]=[0, 0, 0]):
+ """
+ Execute forward pass through the WorldMirror model.
+
+ Args:
+ views: Input data dictionary
+ cond_flags: Conditioning flags [depth, rays, camera]
+
+ Returns:
+ dict: Prediction results dictionary
+ """
+ imgs = views['img']
+
+ # Enable conditional input during training if enabled, or during inference if any cond_flags are set
+ use_cond = (
+ (self.training and self.enable_cond) or
+ (not self.training and sum(cond_flags) > 0)
+ )
+
+ # Extract priors and process features based on conditional input
+ context_token_list = None
+ if use_cond:
+ priors = self.extract_priors(views)
+ token_list, patch_start_idx = self.visual_geometry_transformer(
+ imgs, priors, cond_flags=cond_flags
+ )
+ if self.enable_gs:
+ cnums = views["context_nums"]
+ context_priors = (priors[0][:,:cnums], priors[1][:,:cnums], priors[2][:,:cnums])
+ context_token_list = self.visual_geometry_transformer(
+ imgs[:,:cnums], context_priors, cond_flags=cond_flags
+ )[0]
+ else:
+ token_list, patch_start_idx = self.visual_geometry_transformer(imgs)
+ if self.enable_gs:
+ cnums = views["context_nums"] if "context_nums" in views else imgs.shape[1]
+ context_token_list = self.visual_geometry_transformer(imgs[:,:cnums])[0]
+
+ # Execute predictions
+ with torch.amp.autocast('cuda', enabled=False):
+ # Generate all predictions
+ preds = self._gen_all_preds(
+ token_list, context_token_list, imgs, patch_start_idx, views
+ )
+
+ return preds
+
+ def _gen_all_preds(self, token_list, context_token_list,
+ imgs, patch_start_idx, views):
+ """Generate all enabled predictions"""
+ preds = {}
+ preds['images'] = imgs
+
+ # Camera pose prediction
+ if self.enable_cam:
+ cam_seq = self.cam_head(token_list)
+ cam_params = cam_seq[-1]
+ preds["camera_params"] = cam_params
+ if context_token_list is not None:
+ context_cam_params = self.cam_head(context_token_list)[-1]
+ context_preds = {"camera_params": context_cam_params}
+ ext_mat, int_mat = vector_to_camera_matrices(
+ cam_params, image_hw=(imgs.shape[-2], imgs.shape[-1])
+ )
+ # Create homogeneous transformation matrix
+ homo_row = torch.tensor([0, 0, 0, 1], device=ext_mat.device).view(1, 1, 1, 4)
+ homo_row = homo_row.repeat(ext_mat.shape[0], ext_mat.shape[1], 1, 1)
+ w2c_mat = torch.cat([ext_mat, homo_row], dim=2)
+ c2w_mat = torch.linalg.inv(w2c_mat)
+
+ preds["camera_poses"] = c2w_mat # C2W pose (OpenCV) in world coordinates: [B, S, 4, 4]
+ preds["camera_intrs"] = int_mat # Camera intrinsic matrix: [B, S, 3, 3]
+
+ # Depth prediction
+ if self.enable_depth:
+ depth, depth_conf = self.depth_head(
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
+ )
+ preds["depth"] = depth
+ preds["depth_conf"] = depth_conf
+
+ # 3D point prediction
+ if self.enable_pts:
+ pts, pts_conf = self.pts_head(
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
+ )
+ preds["pts3d"] = pts
+ preds["pts3d_conf"] = pts_conf
+
+ # Normal prediction
+ if self.enable_norm:
+ normals, norm_conf = self.norm_head(
+ token_list, images=imgs, patch_start_idx=patch_start_idx,
+ )
+ preds["normals"] = normals
+ preds["normals_conf"] = norm_conf
+
+ # 3D Gaussian Splatting
+ if self.enable_gs:
+ views['context_nums'] = imgs.shape[1] if "context_nums" not in views else views["context_nums"]
+ gs_feat, gs_depth, gs_depth_conf = self.gs_head(
+ context_token_list, images=imgs[:,:views["context_nums"]], patch_start_idx=patch_start_idx
+ )
+
+ preds["gs_depth"] = gs_depth
+ preds["gs_depth_conf"] = gs_depth_conf
+ preds = self.gs_renderer.render(
+ gs_feats=gs_feat,
+ images=imgs,
+ predictions=preds,
+ views=views,
+ context_predictions=context_preds
+ )
+
+ return preds
+
+ def extract_priors(self, views):
+ """
+ Extract and normalize geometric priors.
+
+ Args:
+ views: Input view data dictionary.
+
+ Returns:
+ tuple: (depths, rays, poses) Normalized priors.
+ """
+ h, w = views['img'].shape[-2:]
+
+ # Initialize prior variables
+ poses = depths = rays = None
+
+ # Extract camera pose
+ if 'camera_pose' in views:
+ extrinsics = views['camera_pose'][:, :, :3]
+ extrinsics = normalize_poses(extrinsics)
+ cam_params = extrinsics_to_vector(extrinsics)
+ poses = cam_params[:, :, :7] # Shape: [B, S, 7]
+
+ # Extract depth map
+ if 'depthmap' in views:
+ depths = normalize_depth(views['depthmap']) # Shape: [B, S, H, W]
+
+ # Extract ray directions
+ if 'camera_intrinsics' in views:
+ intrinsics = views['camera_intrinsics'][:, :, :3, :3]
+ fx, fy = intrinsics[:, :, 0, 0] / w, intrinsics[:, :, 1, 1] / h
+ cx, cy = intrinsics[:, :, 0, 2] / w, intrinsics[:, :, 1, 2] / h
+ rays = torch.stack([fx, fy, cx, cy], dim=-1) # Shape: [B, S, 4]
+
+ return (depths, rays, poses)
+
+
+if __name__ == "__main__":
+ device = "cuda"
+ model = WorldMirror().to(device).eval()
+ x = torch.rand(1, 1, 3, 518, 518).to(device)
+ out = model({'img': x})
+ import pdb; pdb.set_trace()
+
\ No newline at end of file
diff --git a/src/models/utils/act_gs.py b/src/models/utils/act_gs.py
new file mode 100644
index 0000000000000000000000000000000000000000..863e60b753933bb04a7b2f1c6e5d43cc636c667d
--- /dev/null
+++ b/src/models/utils/act_gs.py
@@ -0,0 +1,22 @@
+import torch
+from einops import rearrange
+
+
+def reg_dense_offsets(xyz, shift=6.0):
+ d = xyz.norm(dim=-1, keepdim=True)
+ return xyz / d.clamp(min=1e-8) * (torch.exp(d - shift) - torch.exp(-shift))
+
+def reg_dense_scales(scales):
+ return scales.exp()
+
+def reg_dense_rotation(rotations, eps=1e-8):
+ return rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
+
+def reg_dense_sh(sh):
+ return rearrange(sh, '... (d_sh xyz) -> ... d_sh xyz', xyz=3)
+
+def reg_dense_opacities(opacities):
+ return opacities.sigmoid()
+
+def reg_dense_weights(weights):
+ return weights.sigmoid()
diff --git a/src/models/utils/camera_utils.py b/src/models/utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1757378bd1c1b73e5dcee565b8f30d35cfa84279
--- /dev/null
+++ b/src/models/utils/camera_utils.py
@@ -0,0 +1,75 @@
+import torch
+from .rotation import quat_to_rotmat, rotmat_to_quat
+
+
+def camera_params_to_vector(
+ ext, intr, image_hw=None
+):
+ """Convert camera matrices to a compact vector."""
+ # ext: (..., 3, 4): Camera-to-world extrinsic [R|t]
+ # intr: (..., 3, 3): Intrinsics
+ # image_hw: (h, w)
+ R = ext[..., :3, :3] # Rotation part
+ t = ext[..., :3, 3] # Translation part
+ q = rotmat_to_quat(R) # Quaternion (wxyz)
+ h, w = image_hw
+ fov_v = 2.0 * torch.atan(h * 0.5 / intr[..., 1, 1]) # Vertical FOV
+ fov_u = 2.0 * torch.atan(w * 0.5 / intr[..., 0, 0]) # Horizontal FOV
+ vec = torch.stack([
+ t[..., 0], t[..., 1], t[..., 2],
+ q[..., 0], q[..., 1], q[..., 2], q[..., 3],
+ fov_v, fov_u
+ ], dim=-1).float()
+ return vec
+
+def extrinsics_to_vector(ext):
+ """Convert extrinsics to [t, q] vector."""
+ # ext: (..., 3, 4)
+ R = ext[..., :3, :3]
+ t = ext[..., :3, 3]
+ q = rotmat_to_quat(R)
+ vec = torch.stack([
+ t[..., 0], t[..., 1], t[..., 2],
+ q[..., 0], q[..., 1], q[..., 2], q[..., 3]
+ ], dim=-1).float()
+ return vec
+
+def vector_to_extrinsics(cam_vec):
+ """Convert [t, q] vector to extrinsic matrix."""
+ # cam_vec: (..., 7)
+ q = cam_vec[..., 3:7]
+ t = cam_vec[..., :3]
+ R = quat_to_rotmat(q)
+ ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
+ return ext
+
+def vector_to_camera_matrices(
+ cam_vec, image_hw=None, build_intr=True
+):
+ """Reconstruct extrinsic and intrinsic matrix from vector."""
+ # cam_vec: (..., 9)
+ intr = None
+ # Decompose vector
+ t = cam_vec[..., 0:3]
+ q = cam_vec[..., 3:7]
+ fov_v = cam_vec[..., 7]
+ fov_u = cam_vec[..., 8]
+
+ # Build extrinsic: [R|t]
+ R = quat_to_rotmat(q)
+ ext = torch.cat([R, t.unsqueeze(-1)], dim=-1)
+
+ # Build intrinsic if needed
+ if build_intr:
+ h, w = image_hw
+ fy = h * 0.5 / torch.tan(fov_v * 0.5)
+ fx = w * 0.5 / torch.tan(fov_u * 0.5)
+ shape = cam_vec.shape[:-1] + (3, 3)
+ intr = torch.zeros(shape, device=cam_vec.device, dtype=cam_vec.dtype)
+ intr[..., 0, 0] = fx
+ intr[..., 1, 1] = fy
+ intr[..., 0, 2] = w * 0.5
+ intr[..., 1, 2] = h * 0.5
+ intr[..., 2, 2] = 1.0
+
+ return ext, intr
diff --git a/src/models/utils/frustum.py b/src/models/utils/frustum.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc7d2ebbfc8aa43bf7d6157c4335194cc714383
--- /dev/null
+++ b/src/models/utils/frustum.py
@@ -0,0 +1,196 @@
+import einops
+import torch
+
+
+# Calculate the loss mask for the target views in the batch
+@torch.no_grad()
+def calculate_unprojected_mask(views, context_nums):
+ '''Calcuate the loss mask for the target views in the batch'''
+ target_depth = views["depthmap"][:, context_nums:]
+ target_intrinsics = views["camera_intrinsics"][:, context_nums:]
+ target_c2w = views["camera_pose"][:, context_nums:]
+ context_depth = views["depthmap"][:, :context_nums]
+ context_intrinsics = views["camera_intrinsics"][:, :context_nums]
+ context_c2w = views["camera_pose"][:, :context_nums]
+
+ target_intrinsics = target_intrinsics[..., :3, :3]
+ context_intrinsics = context_intrinsics[..., :3, :3]
+
+ mask = calculate_in_frustum_mask(
+ target_depth, target_intrinsics, target_c2w,
+ context_depth, context_intrinsics, context_c2w
+ )
+ return mask
+
+@torch.no_grad()
+def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2):
+ """
+ A function that takes in the depth, intrinsics and c2w matrices of two sets
+ of views, and then works out which of the pixels in the first set of views
+ has a direct corresponding pixel in any of views in the second set
+
+ Args:
+ depth_1: (b, v1, h, w)
+ intrinsics_1: (b, v1, 3, 3)
+ c2w_1: (b, v1, 4, 4)
+ depth_2: (b, v2, h, w)
+ intrinsics_2: (b, v2, 3, 3)
+ c2w_2: (b, v2, 4, 4)
+
+ Returns:
+ torch.Tensor: valid mask with shape (b, v1, v2, h, w).
+ """
+
+ _, v1, h, w = depth_1.shape
+ _, v2, _, _ = depth_2.shape
+
+ # Unproject the depth to get the 3D points in world space
+ points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1) # (b, v1, h, w, 3)
+
+ # Project the 3D points into the pixel space of all the second views simultaneously
+ camera_points = world_space_to_camera_space(points_3d, c2w_2) # (b, v1, v2, h, w, 3)
+ points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2) # (b, v1, v2, h, w, 2)
+
+ # Calculate the depth of each point
+ rendered_depth = camera_points[..., 2] # (b, v1, v2, h, w)
+
+ # We use three conditions to determine if a point should be masked
+
+ # Condition 1: Check if the points are in the frustum of any of the v2 views
+ in_frustum_mask = (
+ (points_2d[..., 0] > 0) &
+ (points_2d[..., 0] < w) &
+ (points_2d[..., 1] > 0) &
+ (points_2d[..., 1] < h)
+ ) # (b, v1, v2, h, w)
+ in_frustum_mask = in_frustum_mask.any(dim=-3) # (b, v1, h, w)
+
+ # Condition 2: Check if the points have non-zero (i.e. valid) depth in the input view
+ non_zero_depth = depth_1 > 1e-6
+
+ # Condition 3: Check if the points have matching depth to any of the v2
+ # views torch.nn.functional.grid_sample expects the input coordinates to
+ # be normalized to the range [-1, 1], so we normalize first
+ points_2d[..., 0] /= w
+ points_2d[..., 1] /= h
+ points_2d = points_2d * 2 - 1
+ matching_depth = torch.ones_like(rendered_depth, dtype=torch.bool)
+ for b in range(depth_1.shape[0]):
+ for i in range(v1):
+ for j in range(v2):
+ depth = einops.rearrange(depth_2[b, j], 'h w -> 1 1 h w')
+ coords = einops.rearrange(points_2d[b, i, j], 'h w c -> 1 h w c')
+ sampled_depths = torch.nn.functional.grid_sample(depth, coords, align_corners=False)[0, 0]
+ matching_depth[b, i, j] = torch.isclose(rendered_depth[b, i, j], sampled_depths, atol=1e-1)
+
+ matching_depth = matching_depth.any(dim=-3) # (..., v1, h, w)
+
+ mask = in_frustum_mask & non_zero_depth & matching_depth
+ return mask
+
+# --- Projections ---
+def homogenize_points(points):
+ """Append a '1' along the final dimension of the tensor (i.e. convert xyz->xyz1)"""
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+
+
+def normalize_homogenous_points(points):
+ """Normalize the point vectors"""
+ return points / points[..., -1:]
+
+
+def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
+ """
+ Convert pixel space points to camera space points.
+
+ Args:
+ pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2)
+ depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
+ intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
+
+ Returns:
+ torch.Tensor: Camera space points with shape (b, v, h, w, 3).
+ """
+ pixel_space_points = homogenize_points(pixel_space_points)
+ camera_space_points = torch.einsum('b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points)
+ camera_space_points = camera_space_points * depth
+ return camera_space_points
+
+
+def camera_space_to_world_space(camera_space_points, c2w):
+ """
+ Convert camera space points to world space points.
+
+ Args:
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3)
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4)
+
+ Returns:
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
+ """
+ camera_space_points = homogenize_points(camera_space_points)
+ world_space_points = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, camera_space_points)
+ return world_space_points[..., :3]
+
+
+def camera_space_to_pixel_space(camera_space_points, intrinsics):
+ """
+ Convert camera space points to pixel space points.
+
+ Args:
+ camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 3, 3)
+
+ Returns:
+ torch.Tensor: World space points with shape (b, v1, v2, h, w, 2).
+ """
+ camera_space_points = normalize_homogenous_points(camera_space_points)
+ pixel_space_points = torch.einsum('b u i j , b v u h w j -> b v u h w i', intrinsics, camera_space_points)
+ return pixel_space_points[..., :2]
+
+
+def world_space_to_camera_space(world_space_points, c2w):
+ """
+ Convert world space points to pixel space points.
+
+ Args:
+ world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
+ c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4)
+
+ Returns:
+ torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3).
+ """
+ world_space_points = homogenize_points(world_space_points)
+ camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', c2w.inverse(), world_space_points)
+ return camera_space_points[..., :3]
+
+
+def unproject_depth(depth, intrinsics, c2w):
+ """
+ Turn the depth map into a 3D point cloud in world space
+
+ Args:
+ depth: (b, v, h, w, 1)
+ intrinsics: (b, v, 3, 3)
+ c2w: (b, v, 4, 4)
+
+ Returns:
+ torch.Tensor: World space points with shape (b, v, h, w, 3).
+ """
+
+ # Compute indices of pixels
+ h, w = depth.shape[-3], depth.shape[-2]
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(w, device=depth.device, dtype=torch.float32),
+ torch.arange(h, device=depth.device, dtype=torch.float32),
+ indexing='xy'
+ ) # (h, w), (h, w)
+
+ # Compute coordinates of pixels in camera space
+ pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2)
+ camera_points = pixel_space_to_camera_space(pixel_space_points, depth, intrinsics) # (..., h, w, 3)
+
+ # Convert points to world space
+ world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3)
+
+ return world_points
\ No newline at end of file
diff --git a/src/models/utils/geometry.py b/src/models/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1805b308f9d93fbc781d45982aea4a95f5ee09
--- /dev/null
+++ b/src/models/utils/geometry.py
@@ -0,0 +1,138 @@
+import torch
+import numpy as np
+
+
+def depth_to_camera_coords(depthmap, camera_intrinsics):
+ """
+ Convert depth map to 3D camera coordinates.
+
+ Args:
+ depthmap (BxHxW tensor): Batch of depth maps
+ camera_intrinsics (Bx3x3 tensor): Camera intrinsics matrix for each camera
+
+ Returns:
+ X_cam (BxHxWx3 tensor): 3D points in camera coordinates
+ valid_mask (BxHxW tensor): Mask indicating valid depth pixels
+ """
+ B, H, W = depthmap.shape
+ device = depthmap.device
+ dtype = depthmap.dtype
+
+ # Ensure intrinsics are float
+ camera_intrinsics = camera_intrinsics.float()
+
+ # Extract focal lengths and principal points
+ fx = camera_intrinsics[:, 0, 0] # (B,)
+ fy = camera_intrinsics[:, 1, 1] # (B,)
+ cx = camera_intrinsics[:, 0, 2] # (B,)
+ cy = camera_intrinsics[:, 1, 2] # (B,)
+
+ # Generate pixel grid
+ v_grid, u_grid = torch.meshgrid(
+ torch.arange(H, dtype=dtype, device=device),
+ torch.arange(W, dtype=dtype, device=device),
+ indexing='ij'
+ )
+
+ # Reshape for broadcasting: (1, H, W)
+ u_grid = u_grid.unsqueeze(0)
+ v_grid = v_grid.unsqueeze(0)
+
+ # Compute 3D camera coordinates
+ # X = (u - cx) * Z / fx
+ # Y = (v - cy) * Z / fy
+ # Z = depth
+ z_cam = depthmap # (B, H, W)
+ x_cam = (u_grid - cx.view(B, 1, 1)) * z_cam / fx.view(B, 1, 1)
+ y_cam = (v_grid - cy.view(B, 1, 1)) * z_cam / fy.view(B, 1, 1)
+
+ # Stack to form (B, H, W, 3)
+ X_cam = torch.stack([x_cam, y_cam, z_cam], dim=-1)
+
+ # Valid depth mask
+ valid_mask = depthmap > 0.0
+
+ return X_cam, valid_mask
+
+def depth_to_world_coords_points(
+ depth_map: torch.Tensor, extrinsic: torch.Tensor, intrinsic: torch.Tensor, eps=1e-8
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Convert a batch of depth maps to world coordinates.
+
+ Args:
+ depth_map (torch.Tensor): (B, H, W) Depth map
+ extrinsic (torch.Tensor): (B, 4, 4) Camera extrinsic matrix (camera-to-world transformation)
+ intrinsic (torch.Tensor): (B, 3, 3) Camera intrinsic matrix
+
+ Returns:
+ world_coords_points (torch.Tensor): (B, H, W, 3) World coordinates
+ camera_points (torch.Tensor): (B, H, W, 3) Camera coordinates
+ point_mask (torch.Tensor): (B, H, W) Valid depth mask
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask (B, H, W)
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates (B, H, W, 3)
+ camera_points, _ = depth_to_camera_coords(depth_map, intrinsic)
+
+ # Apply extrinsic matrix (camera -> world)
+ R_cam_to_world = extrinsic[:, :3, :3] # (B, 3, 3)
+ t_cam_to_world = extrinsic[:, :3, 3] # (B, 3)
+
+ # Transform (B, H, W, 3) x (B, 3, 3)^T + (B, 3) -> (B, H, W, 3)
+ world_coords_points = torch.einsum('bhwi,bji->bhwj', camera_points, R_cam_to_world) + t_cam_to_world[:, None, None, :]
+
+ return world_coords_points, camera_points, point_mask
+
+
+def closed_form_inverse_se3(se3: torch.Tensor) -> torch.Tensor:
+ """
+ Efficiently invert batched SE(3) matrices of shape (B, 4, 4).
+
+ Args:
+ se3 (torch.Tensor): (B, 4, 4) Transformation matrices
+
+ Returns:
+ out (torch.Tensor): (B, 4, 4) Inverse transformation matrices
+ """
+ assert se3.ndim == 3 and se3.shape[1:] == (4, 4), f"se3 must be (B, 4, 4), got {se3.shape}"
+ R = se3[:, :3, :3] # (B, 3, 3)
+ t = se3[:, :3, 3] # (B, 3)
+ Rt = R.transpose(1, 2) # (B, 3, 3)
+ t_inv = -torch.bmm(Rt, t.unsqueeze(-1)).squeeze(-1) # (B, 3)
+ out = se3.new_zeros(se3.shape)
+ out[:, :3, :3] = Rt
+ out[:, :3, 3] = t_inv
+ out[:, 3, 3] = 1.0
+ return out
+
+
+def create_pixel_coordinate_grid(num_frames, height, width):
+ """
+ Creates a grid of pixel coordinates and frame indices for all frames.
+ Returns:
+ tuple: A tuple containing:
+ - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3)
+ with x, y coordinates and frame indices
+ """
+ # Create coordinate grids for a single frame
+ y_grid, x_grid = np.indices((height, width), dtype=np.float32)
+ x_grid = x_grid[np.newaxis, :, :]
+ y_grid = y_grid[np.newaxis, :, :]
+
+ # Broadcast to all frames
+ x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
+ y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
+
+ # Create frame indices and broadcast
+ f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
+ f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
+
+ # Stack coordinates and frame indices
+ points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
+
+ return points_xyf
\ No newline at end of file
diff --git a/src/models/utils/grid.py b/src/models/utils/grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..f624a7eb44c543d408fbf5dcdd02f39898c9c364
--- /dev/null
+++ b/src/models/utils/grid.py
@@ -0,0 +1,90 @@
+import torch
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+ omega_0: Base frequency for sinusoidal encoding
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ assert embed_dim % 2 == 0
+
+ device = pos_grid.device
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Generate frequency bands
+ omega = torch.arange(embed_dim // 4, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
+ omega /= embed_dim / 4.0
+ omega = 1.0 / omega_0**omega # (D/4,)
+
+ # Process x and y coordinates separately
+ pos_x = pos_flat[:, 0].reshape(-1) # (H*W,)
+ pos_y = pos_flat[:, 1].reshape(-1) # (H*W,)
+
+ # Compute outer products
+ out_x = torch.einsum("m,d->md", pos_x, omega) # (H*W, D/4)
+ out_y = torch.einsum("m,d->md", pos_y, omega) # (H*W, D/4)
+
+ # Apply sin and cos
+ emb_x = torch.cat([torch.sin(out_x), torch.cos(out_x)], dim=1) # (H*W, D/2)
+ emb_y = torch.cat([torch.sin(out_y), torch.cos(out_y)], dim=1) # (H*W, D/2)
+
+ # Combine x and y embeddings
+ emb = torch.cat([emb_x, emb_y], dim=-1) # (H*W, D)
+
+ return emb.float().view(H, W, embed_dim) # [H, W, D]
+
+
+# Inspired by https://github.com/microsoft/moge
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/src/models/utils/priors.py b/src/models/utils/priors.py
new file mode 100644
index 0000000000000000000000000000000000000000..468703a1540b697c126eb3549c17ccb0d89882d0
--- /dev/null
+++ b/src/models/utils/priors.py
@@ -0,0 +1,168 @@
+import torch
+
+
+def normalize_poses(extrinsics, padding=0.1, return_stats=False):
+ """
+ Normalize camera positions to unit cube, processing each batch separately
+
+ Args:
+ extrinsics: Camera extrinsic matrices with shape (B, S, 3, 4)
+ padding: Boundary space within [0,1] range to prevent values near boundaries
+ return_stats: Whether to return normalization statistics
+
+ Returns:
+ normalized_extrinsics: Normalized extrinsic matrices
+ (optional) stats: Dictionary containing scale and translation information
+ """
+ B, S, _, _ = extrinsics.shape
+ device = extrinsics.device
+
+ # Check input validity and handle NaN/Inf values
+ for i in range(B):
+ if torch.isnan(extrinsics[i]).any() or torch.isinf(extrinsics[i]).any():
+ print(f"Warning: dataset sample has NaN/Inf in extrinsics")
+ extrinsics[i] = torch.nan_to_num(
+ extrinsics[i], nan=0.0, posinf=1e6, neginf=-1e6
+ )
+
+ normalized_extrinsics = extrinsics.clone()
+
+ # Store normalization parameters if needed
+ if return_stats:
+ stats = {
+ 'scale_factors': torch.zeros(B, device=device),
+ 'translation_vectors': torch.zeros(B, 3, device=device)
+ }
+
+ for b in range(B):
+ # Extract camera positions for this batch
+ positions = extrinsics[b, :, :3, 3] # (S, 3)
+
+ # Filter valid positions to ignore outliers
+ valid_mask = torch.isfinite(positions).all(dim=1) # (S,)
+
+ if valid_mask.sum() == 0:
+ # No valid positions, use default values
+ print(f"Warning: Batch {b} has no valid camera positions")
+ normalized_extrinsics[b, :, :3, 3] = 0.5 # Place at center
+ if return_stats:
+ stats['scale_factors'][b] = 1.0
+ stats['translation_vectors'][b] = 0.0
+ continue
+
+ valid_positions = positions[valid_mask]
+
+ # Calculate bounds using percentiles for robustness
+ if valid_positions.shape[0] > 10:
+ # Use 5% and 95% percentiles instead of min/max
+ min_pos = torch.quantile(valid_positions, 0.05, dim=0)
+ max_pos = torch.quantile(valid_positions, 0.95, dim=0)
+ else:
+ # Too few samples, use min/max
+ min_pos = torch.min(valid_positions, dim=0)[0]
+ max_pos = torch.max(valid_positions, dim=0)[0]
+
+ # Calculate scale factor considering all dimensions
+ pos_range = max_pos - min_pos
+
+ # Add small epsilon to prevent dimension collapse
+ eps = torch.maximum(
+ torch.tensor(1e-6, device=device),
+ torch.abs(max_pos) * 1e-6
+ )
+ pos_range = torch.maximum(pos_range, eps)
+
+ # Use maximum range as scale factor for uniform scaling
+ scale_factor = torch.max(pos_range)
+ scale_factor = torch.clamp(scale_factor, min=1e-6, max=1e6)
+
+ # Calculate center point for centering
+ center = (min_pos + max_pos) / 2.0
+
+ # Normalize: center first, then scale with padding
+ actual_scale = scale_factor / (1 - 2 * padding)
+ normalized_positions = (positions - center) / actual_scale + 0.5
+
+ # Ensure all values are within valid range
+ normalized_positions = torch.clamp(normalized_positions, 0.0, 1.0)
+
+ # Handle invalid positions by setting them to scene center
+ invalid_mask = ~torch.isfinite(positions).all(dim=1)
+ if invalid_mask.any():
+ normalized_positions[invalid_mask] = 0.5
+
+ normalized_extrinsics[b, :, :3, 3] = normalized_positions
+
+ if return_stats:
+ stats['scale_factors'][b] = actual_scale
+ stats['translation_vectors'][b] = center
+
+ # Final validation
+ assert torch.isfinite(normalized_extrinsics).all(), "Output contains non-finite values"
+
+ if return_stats:
+ return normalized_extrinsics, stats
+ return normalized_extrinsics
+
+
+def normalize_depth(depth, eps=1e-6, min_percentile=1, max_percentile=99):
+ """
+ Normalize depth values to [0, 1] range using percentile-based scaling.
+
+ Args:
+ depth: Input depth tensor with shape (B, S, H, W)
+ eps: Small epsilon value to prevent division by zero
+ min_percentile: Lower percentile for robust min calculation (default: 1)
+ max_percentile: Upper percentile for robust max calculation (default: 99)
+
+ Returns:
+ normalized_depth: Depth tensor normalized to [0, 1] range with same shape (B, S, H, W)
+ """
+ B, S, H, W = depth.shape
+ depth = depth.flatten(0,1) # [B*S, H, W]
+
+ # Handle invalid values
+ depth = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0)
+
+ normalized_list = []
+ for i in range(depth.shape[0]):
+ depth_img = depth[i] # [H, W]
+ depth_flat = depth_img.flatten()
+
+ # Filter out zero values if needed
+ non_zero_mask = depth_flat > 0
+ if non_zero_mask.sum() > 0:
+ values_to_use = depth_flat[non_zero_mask]
+ else:
+ values_to_use = depth_flat
+
+ # Only calculate percentiles when there are enough values
+ if values_to_use.numel() > 100: # Ensure enough samples for percentile calculation
+ # Calculate min and max percentiles
+ depth_min = torch.quantile(values_to_use, min_percentile/100.0)
+ depth_max = torch.quantile(values_to_use, max_percentile/100.0)
+ else:
+ # If too few samples, use min/max values
+ depth_min = values_to_use.min()
+ depth_max = values_to_use.max()
+
+ # Handle case where max equals min
+ if depth_max == depth_min:
+ depth_max = depth_min + 1.0
+
+ # Use relative epsilon
+ scale = torch.abs(depth_max - depth_min)
+ eps_val = max(eps, scale.item() * eps)
+
+ # Perform normalization
+ depth_norm_img = (depth_img - depth_min) / (depth_max - depth_min + eps_val)
+
+ # Ensure output is within [0,1] range
+ depth_norm_img = torch.clamp(depth_norm_img, 0.0, 1.0)
+
+ normalized_list.append(depth_norm_img)
+
+ # Recombine all normalized images
+ depth_norm = torch.stack(normalized_list)
+
+ return depth_norm.reshape(B, S, H, W)
\ No newline at end of file
diff --git a/src/models/utils/rotation.py b/src/models/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ab9315d3dfdd6652c888ea8fd49aaacf772524
--- /dev/null
+++ b/src/models/utils/rotation.py
@@ -0,0 +1,126 @@
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_rotmat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def rotmat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/src/models/utils/sh_utils.py b/src/models/utils/sh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a064d97e6aa2986f635531ebc8eb5b2393992f8
--- /dev/null
+++ b/src/models/utils/sh_utils.py
@@ -0,0 +1,116 @@
+# Copyright 2021 The PlenOctree Authors.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-3 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
\ No newline at end of file
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/utils/build_pycolmap_recon.py b/src/utils/build_pycolmap_recon.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ae474bcec59a7acb39941f2fdbb1a4d3c4e24bd
--- /dev/null
+++ b/src/utils/build_pycolmap_recon.py
@@ -0,0 +1,116 @@
+import numpy as np
+import pycolmap
+
+
+def _create_camera_params(frame_idx, cam_matrices, model_type, distortion_coeffs=None):
+ """Build camera parameter array for different model types."""
+ if model_type == "PINHOLE":
+ return np.array([
+ cam_matrices[frame_idx][0, 0], cam_matrices[frame_idx][1, 1],
+ cam_matrices[frame_idx][0, 2], cam_matrices[frame_idx][1, 2]
+ ])
+ elif model_type == "SIMPLE_PINHOLE":
+ focal_avg = (cam_matrices[frame_idx][0, 0] + cam_matrices[frame_idx][1, 1]) / 2
+ return np.array([focal_avg, cam_matrices[frame_idx][0, 2], cam_matrices[frame_idx][1, 2]])
+ elif model_type == "SIMPLE_RADIAL":
+ raise NotImplementedError("SIMPLE_RADIAL model not supported")
+ else:
+ raise ValueError(f"Unsupported camera model: {model_type}")
+
+
+def _setup_camera_object(frame_idx, cam_matrices, img_dims, model_type, use_shared):
+ """Create and configure camera object."""
+ if use_shared and frame_idx > 0:
+ return None
+
+ params = _create_camera_params(frame_idx, cam_matrices, model_type)
+ return pycolmap.Camera(
+ model=model_type,
+ width=img_dims[0],
+ height=img_dims[1],
+ params=params,
+ camera_id=frame_idx + 1
+ )
+
+
+def _process_frame_points(scene_points, point_coords, frame_idx):
+ """Extract and process 2D points belonging to specific frame."""
+ frame_mask = point_coords[:, 2].astype(np.int32) == frame_idx
+ valid_indices = np.nonzero(frame_mask)[0]
+
+ point2d_list = []
+ for idx, batch_idx in enumerate(valid_indices):
+ point3d_id = batch_idx + 1
+ xy_coords = point_coords[batch_idx][:2]
+ point2d_list.append(pycolmap.Point2D(xy_coords, point3d_id))
+
+ # Update track information
+ track = scene_points.points3D[point3d_id].track
+ track.add_element(frame_idx + 1, idx)
+
+ return point2d_list
+
+
+def build_pycolmap_reconstruction(
+ points,
+ pixel_coords,
+ point_colors,
+ poses,
+ intrinsics,
+ image_size,
+ shared_camera_model=False,
+ camera_model="SIMPLE_PINHOLE",
+):
+ """
+ Convert numpy arrays to pycolmap reconstruction format.
+
+ Creates 3D scene structure without track optimization.
+ Suitable for initialization of neural rendering methods.
+ """
+ num_frames = len(poses)
+ num_points = len(points)
+
+ scene = pycolmap.Reconstruction()
+
+ # Add 3D points to scene
+ for pt_idx in range(num_points):
+ scene.add_point3D(points[pt_idx], pycolmap.Track(), point_colors[pt_idx])
+
+ current_camera = None
+
+ # Process each frame
+ for frame_idx in range(num_frames):
+ # Setup camera if needed
+ if current_camera is None or not shared_camera_model:
+ current_camera = _setup_camera_object(
+ frame_idx, intrinsics, image_size, camera_model, shared_camera_model
+ )
+ scene.add_camera(current_camera)
+
+ # Create pose transformation
+ rotation_matrix = poses[frame_idx][:3, :3]
+ translation_vec = poses[frame_idx][:3, 3]
+ world_to_cam = pycolmap.Rigid3d(pycolmap.Rotation3d(rotation_matrix), translation_vec)
+
+ # Create image object
+ frame_image = pycolmap.Image(
+ id=frame_idx + 1,
+ name=f"frame_{frame_idx + 1}",
+ camera_id=current_camera.camera_id,
+ cam_from_world=world_to_cam
+ )
+
+ # Process 2D points for this frame
+ frame_points = _process_frame_points(scene, pixel_coords, frame_idx)
+
+ # Set image points and registration status
+ try:
+ frame_image.points2D = pycolmap.ListPoint2D(frame_points)
+ frame_image.registered = True
+ except:
+ print(f"Warning: Frame {frame_idx + 1} has no valid points")
+ frame_image.registered = False
+
+ scene.add_image(frame_image)
+
+ return scene
\ No newline at end of file
diff --git a/src/utils/color_map.py b/src/utils/color_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..e751cf970fcd99677a4b2064854a94d760b990b7
--- /dev/null
+++ b/src/utils/color_map.py
@@ -0,0 +1,50 @@
+# References: mvsplat
+
+import torch
+from colorspacious import cspace_convert
+from einops import rearrange
+from jaxtyping import Float
+from matplotlib import cm
+from torch import Tensor
+
+
+def apply_color_map(
+ x: torch.Tensor,
+ color_map: str = "inferno",
+) -> torch.Tensor:
+ cmap = cm.get_cmap(color_map)
+
+ # Convert to NumPy so that Matplotlib color maps can be used.
+ mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3]
+
+ # Convert back to the original format.
+ return torch.tensor(mapped, device=x.device, dtype=torch.float32)
+
+
+def apply_color_map_to_image(
+ image: torch.Tensor, #Float[Tensor, "*batch height width"],
+ color_map: str = "inferno",
+): #-> Float[Tensor, "*batch 3 height with"]
+ image = apply_color_map(image, color_map)
+ return rearrange(image, "... h w c -> ... c h w")
+
+
+def apply_color_map_2d(
+ x, #Float[Tensor, "*#batch"],
+ y, # Float[Tensor, "*#batch"],
+): # -> Float[Tensor, "*batch 3"]
+ red = cspace_convert((189, 0, 0), "sRGB255", "CIELab")
+ blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab")
+ white = cspace_convert((255, 255, 255), "sRGB255", "CIELab")
+ x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None]
+ y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None]
+
+ # Interpolate between red and blue on the x axis.
+ interpolated = x_np * red + (1 - x_np) * blue
+
+ # Interpolate between color and white on the y axis.
+ interpolated = y_np * interpolated + (1 - y_np) * white
+
+ # Convert to RGB.
+ rgb = cspace_convert(interpolated, "CIELab", "sRGB1")
+ return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1)
\ No newline at end of file
diff --git a/src/utils/cropping.py b/src/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..591eb89d6fc9548efec66c7d51d252b562c45a36
--- /dev/null
+++ b/src/utils/cropping.py
@@ -0,0 +1,387 @@
+"""
+Utility functions for cropping and resizing data while maintaining proper cameras.
+
+References: DUSt3R
+"""
+
+import cv2
+import numpy as np
+import PIL.Image
+
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+from src.utils.geometry import (
+ colmap_to_opencv_intrinsics,
+ opencv_to_colmap_intrinsics,
+)
+
+
+class ImageList:
+ """
+ Convenience class to apply the same operation to a whole set of images.
+
+ This class wraps a list of PIL.Image objects and provides methods to perform
+ operations on all images simultaneously.
+ """
+
+ def __init__(self, images):
+ if not isinstance(images, (tuple, list, set)):
+ images = [images]
+ self.images = []
+ for image in images:
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+ self.images.append(image)
+
+ def __len__(self):
+ """Return the number of images in the list."""
+ return len(self.images)
+
+ def to_pil(self):
+ """
+ Convert ImageList back to PIL Image(s).
+
+ Returns:
+ PIL.Image.Image or tuple: Single PIL Image if list contains one image,
+ or tuple of PIL Images if multiple images
+ """
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+ @property
+ def size(self):
+ """
+ Get the size of images in the list.
+
+ Returns:
+ tuple: (width, height) of the images
+
+ Raises:
+ AssertionError: If images have different sizes
+ """
+ sizes = [im.size for im in self.images]
+ assert all(sizes[0] == s for s in sizes), "All images must have the same size"
+ return sizes[0]
+
+ def resize(self, *args, **kwargs):
+ """
+ Resize all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.resize()
+
+ Returns:
+ ImageList: New ImageList containing resized images
+ """
+ return ImageList(self._dispatch("resize", *args, **kwargs))
+
+ def crop(self, *args, **kwargs):
+ """
+ Crop all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.crop()
+
+ Returns:
+ ImageList: New ImageList containing cropped images
+ """
+ return ImageList(self._dispatch("crop", *args, **kwargs))
+
+ def _dispatch(self, func, *args, **kwargs):
+ """
+ Apply a PIL.Image method to all images in the list.
+
+ Args:
+ func (str): Name of the PIL.Image method to call
+ *args, **kwargs: Arguments to pass to the method
+
+ Returns:
+ list: List of results from applying the method to each image
+ """
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def rescale_image_and_other_optional_info(
+ image,
+ output_resolution,
+ depthmap=None,
+ camera_intrinsics=None,
+ force=True,
+ additional_quantities_to_be_resized_with_nearest=None,
+):
+ """
+ Rescale the image and depthmap to the output resolution.
+ If the image is larger than the output resolution, it is rescaled with lanczos interpolation.
+ If force is false and the image is smaller than the output resolution, it is not rescaled.
+ If force is true and the image is smaller than the output resolution, it is rescaled with bicubic interpolation.
+ Depth and other quantities are rescaled with nearest interpolation.
+
+ Args:
+ image (PIL.Image.Image or np.ndarray): The input image to be rescaled.
+ output_resolution (tuple): The desired output resolution as a tuple (width, height).
+ depthmap (np.ndarray, optional): The depth map associated with the image. Defaults to None.
+ camera_intrinsics (np.ndarray, optional): The camera intrinsics matrix. Defaults to None.
+ force (bool, optional): If True, force rescaling even if the image is smaller than the output resolution. Defaults to True.
+ additional_quantities_to_be_resized_with_nearest (list of np.ndarray, optional): Additional quantities to be rescaled using nearest interpolation. Defaults to None.
+
+ Returns:
+ tuple: A tuple containing:
+ - The rescaled image (PIL.Image.Image)
+ - The rescaled depthmap (numpy.ndarray or None)
+ - The updated camera intrinsics (numpy.ndarray or None)
+ - The list of rescaled additional quantities (list of numpy.ndarray or None)
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W, H)
+ output_resolution = np.array(output_resolution)
+ if depthmap is not None:
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ assert all(
+ tuple(additional_quantity.shape[:2]) == image.size[::-1]
+ for additional_quantity in additional_quantities_to_be_resized_with_nearest
+ )
+
+ # Define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ output = (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+ return output
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # First rescale the image so that it contains the crop
+ image = image.resize(
+ tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic
+ )
+ if depthmap is not None:
+ depthmap = cv2.resize(
+ depthmap,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ resized_additional_quantities = []
+ for quantity in additional_quantities_to_be_resized_with_nearest:
+ resized_additional_quantities.append(
+ cv2.resize(
+ quantity,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ )
+ additional_quantities_to_be_resized_with_nearest = resized_additional_quantities
+
+ # No offset here; simple rescaling
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
+ )
+
+ # Return
+ return (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+
+
+def camera_matrix_of_crop(
+ input_camera_matrix,
+ input_resolution,
+ output_resolution,
+ scaling=1,
+ offset_factor=0.5,
+ offset=None,
+):
+ """
+ Calculate the camera matrix for a cropped image.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ input_resolution (tuple or numpy.ndarray): Original image resolution as (width, height)
+ output_resolution (tuple or numpy.ndarray): Target image resolution as (width, height)
+ scaling (float, optional): Scaling factor for the image. Defaults to 1.
+ offset_factor (float, optional): Factor to determine crop offset. Defaults to 0.5 (centered).
+ offset (tuple or numpy.ndarray, optional): Explicit offset to use. If None, calculated from offset_factor.
+
+ Returns:
+ numpy.ndarray: Updated camera matrix for the cropped image
+ """
+ # Margins to offset the origin
+ margins = np.asarray(input_resolution) * scaling - output_resolution
+ assert np.all(margins >= 0.0)
+ if offset is None:
+ offset = offset_factor * margins
+
+ # Generate new camera parameters
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+ output_camera_matrix_colmap[:2, :] *= scaling
+ output_camera_matrix_colmap[:2, 2] -= offset
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+ return output_camera_matrix
+
+
+def crop_image_and_other_optional_info(
+ image,
+ crop_bbox,
+ depthmap=None,
+ camera_intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ Return a crop of the input view and associated data.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be cropped
+ crop_bbox (tuple): Crop bounding box as (left, top, right, bottom)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ camera_intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to crop
+
+ Returns:
+ tuple: A tuple containing:
+ - The cropped image
+ - The cropped depth map (if provided or None)
+ - Updated camera intrinsics (if provided or None)
+ - List of cropped additional quantities (if provided or None)
+ """
+ image = ImageList(image)
+ left, top, right, bottom = crop_bbox
+
+ image = image.crop((left, top, right, bottom))
+ if depthmap is not None:
+ depthmap = depthmap[top:bottom, left:right]
+ if additional_quantities is not None:
+ additional_quantities = [
+ quantity[top:bottom, left:right] for quantity in additional_quantities
+ ]
+
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= left
+ camera_intrinsics[1, 2] -= top
+
+ return (image.to_pil(), depthmap, camera_intrinsics, additional_quantities)
+
+
+def bbox_from_intrinsics_in_out(
+ input_camera_matrix, output_camera_matrix, output_resolution
+):
+ """
+ Calculate the bounding box for cropping based on input and output camera intrinsics.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ output_camera_matrix (numpy.ndarray): Target camera intrinsics matrix
+ output_resolution (tuple): Target resolution as (width, height)
+
+ Returns:
+ tuple: Crop bounding box as (left, top, right, bottom)
+ """
+ out_width, out_height = output_resolution
+ left, top = np.int32(
+ np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])
+ )
+ crop_bbox = (left, top, left + out_width, top + out_height)
+ return crop_bbox
+
+
+def crop_resize_if_necessary(
+ image,
+ resolution,
+ depthmap=None,
+ intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ First downsample image using LANCZOS and then crop if necessary to achieve target resolution.
+
+ This function performs high-quality downsampling followed by cropping to achieve the
+ desired output resolution while maintaining proper camera intrinsics.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be processed
+ resolution (tuple): Target resolution as (width, height)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to process
+
+ Returns:
+ tuple: A tuple containing the processed image and any provided additional data
+ (depthmap, intrinsics, additional_quantities) that have been similarly processed
+ """
+ # Convert image to PIL.Image.Image if necessary
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # Get width and height of image
+ original_width, original_height = image.size
+
+ # High-quality Lanczos down-scaling
+ target_rescale_resolution = np.array(resolution)
+ image, depthmap, intrinsics, additional_quantities = (
+ rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=target_rescale_resolution,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities_to_be_resized_with_nearest=additional_quantities,
+ )
+ )
+
+ # Actual cropping (if necessary)
+ if intrinsics is not None:
+ new_intrinsics = camera_matrix_of_crop(
+ input_camera_matrix=intrinsics,
+ input_resolution=image.size,
+ output_resolution=resolution,
+ offset_factor=0.5,
+ )
+ crop_bbox = bbox_from_intrinsics_in_out(
+ input_camera_matrix=intrinsics,
+ output_camera_matrix=new_intrinsics,
+ output_resolution=resolution,
+ )
+ else:
+ # Create a centered crop if no intrinsics are available
+ w, h = image.size
+ target_w, target_h = resolution
+ left = (w - target_w) // 2
+ top = (h - target_h) // 2
+ crop_bbox = (left, top, left + target_w, top + target_h)
+
+ image, depthmap, new_intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Return the output
+ output = (image,)
+ if depthmap is not None:
+ output += (depthmap,)
+ if new_intrinsics is not None:
+ output += (new_intrinsics,)
+ if additional_quantities is not None:
+ output += (additional_quantities,)
+ return output
diff --git a/src/utils/geometry.py b/src/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34c4154161ccf766cd19adbcd5aea2e483b837e
--- /dev/null
+++ b/src/utils/geometry.py
@@ -0,0 +1,531 @@
+"""
+Utilities for geometry operations.
+
+References: DUSt3R, MoGe
+"""
+
+from numbers import Number
+from typing import Tuple, Union
+
+import numpy as np
+from src.utils.warnings import no_warnings
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+
+ return K
+
+
+def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
+ """
+ Compute angle difference between 3D vectors using NumPy.
+
+ Args:
+ v1 (np.ndarray): First vector of shape (..., 3)
+ v2 (np.ndarray): Second vector of shape (..., 3)
+ eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
+
+ Returns:
+ np.ndarray: Angle differences in radians
+ """
+ return np.arctan2(
+ np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
+ )
+
+
+@no_warnings(category=RuntimeWarning)
+def points_to_normals(
+ point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
+) -> np.ndarray:
+ """
+ Calculate normal map from point map. Value range is [-1, 1].
+
+ Args:
+ point (np.ndarray): shape (height, width, 3), point map
+ mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
+ edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
+
+ Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map.
+ """
+ height, width = point.shape[-3:-1]
+ has_mask = mask is not None
+
+ if mask is None:
+ mask = np.ones_like(point[..., 0], dtype=bool)
+ mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
+ mask_pad[1:-1, 1:-1] = mask
+ mask = mask_pad
+
+ pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
+ pts[1:-1, 1:-1, :] = point
+ up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
+ left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
+ down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
+ right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
+ normal = np.stack(
+ [
+ np.cross(up, left, axis=-1),
+ np.cross(left, down, axis=-1),
+ np.cross(down, right, axis=-1),
+ np.cross(right, up, axis=-1),
+ ]
+ )
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ valid = (
+ np.stack(
+ [
+ mask[:-2, 1:-1] & mask[1:-1, :-2],
+ mask[1:-1, :-2] & mask[2:, 1:-1],
+ mask[2:, 1:-1] & mask[1:-1, 2:],
+ mask[1:-1, 2:] & mask[:-2, 1:-1],
+ ]
+ )
+ & mask[None, 1:-1, 1:-1]
+ )
+ if edge_threshold is not None:
+ view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
+ view_angle = np.minimum(view_angle, np.pi - view_angle)
+ valid = valid & (view_angle < np.deg2rad(edge_threshold))
+
+ normal = (normal * valid[..., None]).sum(axis=0)
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ if has_mask:
+ normal_mask = valid.any(axis=0)
+ normal = np.where(normal_mask[..., None], normal, 0)
+ return normal, normal_mask
+ else:
+ return normal
+
+
+def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
+ """
+ Create a sliding window view of the input array along a specified axis.
+
+ This function creates a memory-efficient view of the input array with sliding windows
+ of the specified size and stride. The window dimension is appended to the end of the
+ output array's shape. This is useful for operations like convolution, pooling, or
+ any analysis that requires examining local neighborhoods in the data.
+
+ Args:
+ x (np.ndarray): Input array with shape (..., axis_size, ...)
+ window_size (int): Size of the sliding window
+ stride (int): Stride of the sliding window (step size between consecutive windows)
+ axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
+
+ Returns:
+ np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
+ where n_windows = (axis_size - window_size + 1) // stride
+
+ Raises:
+ AssertionError: If window_size is larger than the size of the specified axis
+
+ Example:
+ >>> x = np.array([1, 2, 3, 4, 5, 6])
+ >>> sliding_window_1d(x, window_size=3, stride=2)
+ array([[1, 2, 3],
+ [3, 4, 5]])
+ """
+ assert x.shape[axis] >= window_size, (
+ f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
+ )
+ axis = axis % x.ndim
+ shape = (
+ *x.shape[:axis],
+ (x.shape[axis] - window_size + 1) // stride,
+ *x.shape[axis + 1 :],
+ window_size,
+ )
+ strides = (
+ *x.strides[:axis],
+ stride * x.strides[axis],
+ *x.strides[axis + 1 :],
+ x.strides[axis],
+ )
+ x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return x_sliding
+
+
+def sliding_window_nd(
+ x: np.ndarray,
+ window_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Create sliding windows along multiple dimensions of the input array.
+
+ This function applies sliding_window_1d sequentially along multiple axes to create
+ N-dimensional sliding windows. This is useful for operations that need to examine
+ local neighborhoods in multiple dimensions simultaneously.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Tuple[int, ...]): Size of the sliding window for each axis
+ stride (Tuple[int, ...]): Stride of the sliding window for each axis
+ axis (Tuple[int, ...]): Axes to perform sliding window over
+
+ Returns:
+ np.ndarray: Array with sliding windows along the specified dimensions.
+ The window dimensions are appended to the end of the shape.
+
+ Note:
+ The length of window_size, stride, and axis tuples must be equal.
+
+ Example:
+ >>> x = np.random.rand(10, 10)
+ >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
+ >>> # Creates 3x3 sliding windows with stride 2 in both dimensions
+ """
+ axis = [axis[i] % x.ndim for i in range(len(axis))]
+ for i in range(len(axis)):
+ x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
+ return x
+
+
+def sliding_window_2d(
+ x: np.ndarray,
+ window_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+) -> np.ndarray:
+ """
+ Create 2D sliding windows over the input array.
+
+ Convenience function for creating 2D sliding windows, commonly used for image
+ processing operations like convolution, pooling, or patch extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
+ If int, same stride is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform sliding window over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
+ are appended to the end of the shape.
+
+ Example:
+ >>> image = np.random.rand(100, 100)
+ >>> patches = sliding_window_2d(image, window_size=8, stride=4)
+ >>> # Creates 8x8 patches with stride 4 from the image
+ """
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ return sliding_window_nd(x, window_size, stride, axis)
+
+
+def max_pool_1d(
+ x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
+):
+ """
+ Perform 1D max pooling on the input array.
+
+ Max pooling reduces the dimensionality of the input by taking the maximum value
+ within each sliding window. This is commonly used in neural networks and signal
+ processing for downsampling and feature extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (int): Size of the pooling kernel
+ stride (int): Stride of the pooling operation
+ padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
+ axis (int, optional): Axis to perform max pooling over. Defaults to -1.
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axis
+
+ Note:
+ - For floating point arrays, padding is done with np.nan values
+ - For integer arrays, padding is done with the minimum value of the dtype
+ - np.nanmax is used to handle NaN values in the computation
+
+ Example:
+ >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
+ >>> max_pool_1d(x, kernel_size=3, stride=2)
+ array([3, 5, 2])
+ """
+ axis = axis % x.ndim
+ if padding > 0:
+ fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
+ padding_arr = np.full(
+ (*x.shape[:axis], padding, *x.shape[axis + 1 :]),
+ fill_value=fill_value,
+ dtype=x.dtype,
+ )
+ x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
+ a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
+ max_pool = np.nanmax(a_sliding, axis=-1)
+ return max_pool
+
+
+def max_pool_nd(
+ x: np.ndarray,
+ kernel_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ padding: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Perform N-dimensional max pooling on the input array.
+
+ This function applies max_pool_1d sequentially along multiple axes to perform
+ multi-dimensional max pooling. This is useful for downsampling multi-dimensional
+ data while preserving the most important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
+ stride (Tuple[int, ...]): Stride of the pooling operation for each axis
+ padding (Tuple[int, ...]): Amount of padding for each axis
+ axis (Tuple[int, ...]): Axes to perform max pooling over
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axes
+
+ Note:
+ The length of kernel_size, stride, padding, and axis tuples must be equal.
+ Max pooling is applied sequentially along each axis in the order specified.
+
+ Example:
+ >>> x = np.random.rand(10, 10, 10)
+ >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
+ ... padding=(0, 0, 0), axis=(-3, -2, -1))
+ >>> # Reduces each dimension by half with 2x2x2 max pooling
+ """
+ for i in range(len(axis)):
+ x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
+ return x
+
+
+def max_pool_2d(
+ x: np.ndarray,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ padding: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+):
+ """
+ Perform 2D max pooling on the input array.
+
+ Convenience function for 2D max pooling, commonly used in computer vision
+ and image processing for downsampling images while preserving important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
+ If int, same stride is used for both dimensions.
+ padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
+ If int, same padding is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform max pooling over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: 2D max pooled array with reduced size along the specified axes
+
+ Example:
+ >>> image = np.random.rand(64, 64)
+ >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
+ >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
+ """
+ if isinstance(kernel_size, Number):
+ kernel_size = (kernel_size, kernel_size)
+ if isinstance(stride, Number):
+ stride = (stride, stride)
+ if isinstance(padding, Number):
+ padding = (padding, padding)
+ axis = tuple(axis)
+ return max_pool_nd(x, kernel_size, stride, padding, axis)
+
+
+@no_warnings(category=RuntimeWarning)
+def depth_edge(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
+
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff = max_pool_2d(
+ depth, kernel_size, stride=1, padding=kernel_size // 2
+ ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
+ else:
+ diff = max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ ) + max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+def depth_aliasing(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff_max = (
+ max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+ )
+ diff_min = (
+ max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+ )
+ else:
+ diff_max = (
+ max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ - depth
+ )
+ diff_min = (
+ max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ + depth
+ )
+ diff = np.minimum(diff_max, diff_min)
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+@no_warnings(category=RuntimeWarning)
+def normals_edge(
+ normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
+) -> np.ndarray:
+ """
+ Compute the edge mask from normal map.
+
+ Args:
+ normal (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ assert normals.ndim >= 3 and normals.shape[-1] == 3, (
+ "normal should be of shape (..., height, width, 3)"
+ )
+ normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+ padding = kernel_size // 2
+ normals_window = sliding_window_2d(
+ np.pad(
+ normals,
+ (
+ *([(0, 0)] * (normals.ndim - 3)),
+ (padding, padding),
+ (padding, padding),
+ (0, 0),
+ ),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ if mask is None:
+ angle_diff = np.arccos(
+ (normals[..., None, None] * normals_window).sum(axis=-3)
+ ).max(axis=(-2, -1))
+ else:
+ mask_window = sliding_window_2d(
+ np.pad(
+ mask,
+ (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ angle_diff = np.where(
+ mask_window,
+ np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
+ 0,
+ ).max(axis=(-2, -1))
+
+ angle_diff = max_pool_2d(
+ angle_diff, kernel_size, stride=1, padding=kernel_size // 2
+ )
+ edge = angle_diff > np.deg2rad(tol)
+ return edge
diff --git a/src/utils/gs_effects.py b/src/utils/gs_effects.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec327d4994cb27e4fef57b16d484b211baa0b50
--- /dev/null
+++ b/src/utils/gs_effects.py
@@ -0,0 +1,272 @@
+from math import atan2, cos, exp, floor, sin, sqrt
+
+import numpy as np
+import torch
+
+def fract(x):
+ """Get fractional part of a number"""
+ if isinstance(x, torch.Tensor):
+ return x - torch.floor(x)
+ return x - floor(x)
+
+class GSEffects:
+ """Convert GLSL GS render effects to PyTorch - vectorized for batch processing"""
+
+ def __init__(self, start_time=0.0, end_time=10.0):
+ """
+ Initialize effects with time range
+
+ Args:
+ start_time: Animation start time
+ end_time: Animation end time
+ """
+ self.start_time = start_time
+ self.end_time = end_time
+
+ @staticmethod
+ def smoothstep(edge0, edge1, x):
+ """GLSL smoothstep function (vectorized)"""
+ if isinstance(x, torch.Tensor):
+ result = torch.zeros_like(x, dtype=x.dtype)
+ mask_low = x < edge0
+ mask_high = x > edge1
+ mask_mid = ~(mask_low | mask_high)
+
+ t = (x[mask_mid] - edge0) / (edge1 - edge0)
+ result[mask_mid] = t * t * (3.0 - 2.0 * t)
+ result[mask_low] = 0.0
+ result[mask_high] = 1.0
+ return result
+ else:
+ if x < edge0:
+ return 0.0
+ if x > edge1:
+ return 1.0
+ t = (x - edge0) / (edge1 - edge0)
+ return t * t * (3.0 - 2.0 * t)
+
+ @staticmethod
+ def step(edge, x):
+ """GLSL step function (vectorized)"""
+ if isinstance(x, torch.Tensor):
+ return (x >= edge).to(x.dtype)
+ if isinstance(edge, torch.Tensor):
+ return (x >= edge).to(edge.dtype)
+ return 1.0 if x >= edge else 0.0
+
+ @staticmethod
+ def mix(x, y, a):
+ """GLSL mix function (linear interpolation, vectorized)"""
+ return x * (1.0 - a) + y * a
+
+ @staticmethod
+ def clamp(x, min_val, max_val):
+ """Clamp value between min and max (vectorized)"""
+ if isinstance(x, torch.Tensor):
+ return torch.clamp(x, min_val, max_val)
+ return max(min_val, min(max_val, x))
+
+ @staticmethod
+ def length_xz(pos):
+ """Calculate length of XZ components (vectorized)"""
+ if pos.dim() == 1:
+ return torch.sqrt(pos[0]**2 + pos[2]**2)
+ return torch.sqrt(pos[:, 0]**2 + pos[:, 2]**2)
+
+ @staticmethod
+ def length_vec(v):
+ """Calculate vector length (vectorized)"""
+ if v.dim() == 1:
+ return torch.sqrt(torch.sum(v**2))
+ return torch.sqrt(torch.sum(v**2, dim=1))
+
+ @staticmethod
+ def hash(p):
+ """Pseudo-random hash function (vectorized)"""
+ p = fract(p * 0.3183099 + 0.1)
+ p = p * 17.0
+ return torch.stack([
+ fract(p[:, 0] * p[:, 1] * p[:, 2]),
+ fract(p[:, 0] + p[:, 1] * p[:, 2]),
+ fract(p[:, 0] * p[:, 1] + p[:, 2])
+ ], dim=1)
+
+ @staticmethod
+ def noise(p):
+ """3D Perlin-style noise function (vectorized)"""
+ i = torch.floor(p).to(torch.long)
+ f = fract(p)
+ f = f * f * (3.0 - 2.0 * f)
+
+ def get_hash_offset(offset):
+ return GSEffects.hash(i.to(p.dtype) + offset)
+
+ n000 = get_hash_offset(torch.tensor([0, 0, 0], dtype=p.dtype, device=p.device))
+ n100 = get_hash_offset(torch.tensor([1, 0, 0], dtype=p.dtype, device=p.device))
+ n010 = get_hash_offset(torch.tensor([0, 1, 0], dtype=p.dtype, device=p.device))
+ n110 = get_hash_offset(torch.tensor([1, 1, 0], dtype=p.dtype, device=p.device))
+ n001 = get_hash_offset(torch.tensor([0, 0, 1], dtype=p.dtype, device=p.device))
+ n101 = get_hash_offset(torch.tensor([1, 0, 1], dtype=p.dtype, device=p.device))
+ n011 = get_hash_offset(torch.tensor([0, 1, 1], dtype=p.dtype, device=p.device))
+ n111 = get_hash_offset(torch.tensor([1, 1, 1], dtype=p.dtype, device=p.device))
+
+ x0 = GSEffects.mix(n000, n100, f[:, 0:1])
+ x1 = GSEffects.mix(n010, n110, f[:, 0:1])
+ x2 = GSEffects.mix(n001, n101, f[:, 0:1])
+ x3 = GSEffects.mix(n011, n111, f[:, 0:1])
+
+ y0 = GSEffects.mix(x0, x1, f[:, 1:2])
+ y1 = GSEffects.mix(x2, x3, f[:, 1:2])
+
+ return GSEffects.mix(y0, y1, f[:, 2:3])
+
+ @staticmethod
+ def rot_2d(angle):
+ """2D rotation (vectorized)"""
+ if isinstance(angle, torch.Tensor):
+ s = torch.sin(angle)
+ c = torch.cos(angle)
+ rot = torch.stack([torch.stack([c, -s], dim=-1),
+ torch.stack([s, c], dim=-1)], dim=-2).squeeze()
+ else:
+ s = np.sin(angle)
+ c = np.cos(angle)
+ rot = torch.tensor([[c, -s],
+ [s, c]]).cuda().float()
+ return rot
+
+ def twister(self, pos, scale, t):
+ h = self.hash(pos)[:, 0:1] + 0.1
+ pos_xz_len = self.length_xz(pos)
+ s = self.smoothstep(0.0, 8.0, t * t * 0.1 - pos_xz_len * 2.0 + 2.0)[:, None]
+ mask = (torch.linalg.norm(scale, dim=-1, keepdim=True) < 0.05)
+ pos_y = torch.where(mask, (-10. + pos[:, 1:2]) * (s ** (2 * h)), pos[:, 1:2])
+ pos_xz = pos[:, [0, 2]] * torch.exp(-1 * torch.linalg.norm(pos[:, [0, 2]], dim=-1, keepdim=True))
+ pos_xz = torch.einsum("n i, n i j -> n j", pos_xz, self.rot_2d(t * 0.2 + pos[:, 1:2] * 20. * (1 - s)))
+ pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1)
+ return pos_new, s ** 4
+
+ def rain(self, pos, scale, t):
+ h = self.hash(pos)
+ pos_xz_len = self.length_xz(pos)
+ s = self.smoothstep(0.0, 5.0, t * t * 0.1 - pos_xz_len * 2.0 + 1.0) ** (0.5 + h[:, 0])
+ y = pos[:, 1:2]
+ pos_y = torch.minimum(-10. + s[:, None] * 15., pos[:, 1:2])
+ pos_x = pos[:, 0:1] + pos_y * 0.2
+ pos_xz = torch.cat([pos_x, pos[:, 2:3]], dim=-1)
+ pos_xz = pos_xz * torch.matmul(self.rot_2d(t * 0.3), torch.ones_like(pos_xz).unsqueeze(-1)).squeeze(-1)
+ pos_new = torch.cat([pos_xz[:, 0:1], pos_y, pos_xz[:, 1:2]], dim=-1)
+ a = self.smoothstep(-10.0, y.squeeze(), pos_y.squeeze())[:, None]
+ return pos_new, a
+
+ def apply_effect(self, gsplat, t, effect_type, ignore_scale=False):
+ """
+ Apply the effect shader logic (vectorized for batch processing)
+
+ Args:
+ gsplat: Dictionary with:
+ 'means': (n, 3) tensor
+ 'scales': (n, 3) tensor
+ 'colors': (n, 3) tensor
+ 'quats': (n, 4) tensor
+ 'opacities': (n,) tensor
+ t: Current time (normalized based on start_time and end_time)
+ effect_type: 2=Spread
+
+ Returns:
+ Modified gsplat dictionary
+ """
+ # Normalize time to animation range
+ normalized_t = t - self.start_time
+ device = gsplat['means'].device
+ dtype = gsplat['means'].dtype
+
+ output = {
+ 'means': gsplat['means'].clone(),
+ 'quats': gsplat['quats'].clone(),
+ 'scales': gsplat['scales'].clone(),
+ 'opacities': gsplat['opacities'].clone(),
+ 'colors': gsplat['colors'].clone()
+ }
+
+ s = self.smoothstep(0.0, 10.0, normalized_t - 3.2) * 10.0
+ scales = output['scales']
+ local_pos = output['means'].clone()
+ l = self.length_xz(local_pos)
+ smoothstep_val = None
+
+ if effect_type == 2: # Spread Effect
+ border = torch.abs(s - l - 0.5)
+ decay = 1.0 - 0.2 * torch.exp(-20.0 * border)
+ # decay = 1.0 - 0.7 * torch.exp(-10.0 * border)
+ local_pos = local_pos * decay[:, None]
+
+ smoothstep_val = self.smoothstep(s - 0.5, s, l + 0.5)
+ # final_scales = self.mix(scales, 0.002, smoothstep_val[:, None])
+ if not ignore_scale:
+ final_scales = self.mix(scales, 1e-9, smoothstep_val[:, None])
+ else:
+ final_scales = scales
+
+ noise_input = torch.stack([
+ local_pos[:, 0] * 2.0 + normalized_t * 0.5,
+ local_pos[:, 1] * 2.0 + normalized_t * 0.5,
+ local_pos[:, 2] * 2.0 + normalized_t * 0.5
+ ], dim=1)
+ noise_val = self.noise(noise_input)
+
+ output['means'] = local_pos + 0.0 * noise_val * smoothstep_val[:, None]
+ output['scales'] = final_scales
+
+ at = torch.atan2(local_pos[:, 0], local_pos[:, 2]) / 3.1416
+ output['colors'] *= self.step(at, normalized_t - 3.1416)[:, None]
+ output['colors'] += (torch.exp(-20.0 * border) +
+ torch.exp(-50.0 * torch.abs(normalized_t - at - 3.1416)) * 0.5)[:, None]
+ output['opacities'] *= self.step(at, normalized_t - 3.1416)
+ output['opacities'] += (torch.exp(-20.0 * border) +
+ torch.exp(-50.0 * torch.abs(normalized_t - at - 3.1416)) * 0.5)
+
+ # ===== New feature: Randomly mask points based on smoothstep_val =====
+ # Higher smoothstep_val means higher probability of masking
+ mask_prob = smoothstep_val.squeeze() if smoothstep_val.dim() > 1 else smoothstep_val
+ if not hasattr(self, "random_vals"):
+ self.random_vals = torch.rand(mask_prob.shape, device=device, dtype=dtype)
+ mask = self.random_vals < mask_prob*0.8 # True indicates the point is masked
+
+ # Apply mask to various attributes
+ if not ignore_scale:
+ output['means'][mask] *= 0 # Or can be set to other values
+ output['scales'][mask] *= 0 # Set scales to 0 to make points invisible
+ output['opacities'][mask] *= 0 # Set opacity to 0 to make points transparent
+
+ return output, smoothstep_val
+
+
+# Usage example
+if __name__ == "__main__":
+ # Create effects processor with time range from 0 to 10 seconds
+ effects = GSEffects(start_time=0.0, end_time=10.0)
+
+ # Sample gsplat data (batch)
+ n_points = 100
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ sample_gsplat = {
+ 'means': torch.randn(n_points, 3, dtype=torch.float32, device=device),
+ 'quats': torch.randn(n_points, 4, dtype=torch.float32, device=device),
+ 'scales': torch.rand(n_points, 3, dtype=torch.float32, device=device),
+ 'opacities': torch.rand(n_points, dtype=torch.float32, device=device),
+ 'colors': torch.rand(n_points, 3, dtype=torch.float32, device=device)
+ }
+
+ # Apply Magic effect at different time points
+ for t in [0.0, 2.5, 5.0, 7.5, 10.0]:
+ result = effects.apply_effect(sample_gsplat, t, effect_type=2)
+ print(f"\nTime: {t}s")
+ print(f"Center shape: {result['means'].shape}")
+ print(f"Center[0]: {result['means'][0]}")
+ print(f"Scales shape: {result['scales'].shape}")
+ print(f"Scales[0]: {result['scales'][0]}")
+ print(f"RGB shape: {result['colors'].shape}")
+ print(f"RGB[0]: {result['colors'][0]}")
+ print(f"Opacity shape: {result['opacities'].shape}")
+ print(f"Opacity[0]: {result['opacities'][0]}")
\ No newline at end of file
diff --git a/src/utils/inference_utils.py b/src/utils/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2d547ef7c1525b6d477cbdcda374caf26ea3e2
--- /dev/null
+++ b/src/utils/inference_utils.py
@@ -0,0 +1,262 @@
+import torch
+from PIL import Image
+from torchvision import transforms
+
+import glob
+import os
+from src.utils.video_utils import video_to_image_frames
+
+IMAGE_EXTS = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff', '*.webp']
+VIDEO_EXTS = ['.mp4', '.avi', '.mov', '.webm', '.gif']
+
+
+
+def load_and_preprocess_images(image_file_paths, preprocessing_mode="crop", output_size=518):
+ """
+ Transform raw image files into model-ready tensor batches with standardized dimensions.
+
+ This utility function handles the complete pipeline from file paths to batched tensors,
+ ensuring compatibility with neural network requirements while preserving image quality.
+
+ Args:
+ image_file_paths (list): Collection of file system paths pointing to image files
+ preprocessing_mode (str, optional): Image transformation strategy:
+ - "crop" (default): Resize width to 518px, center-crop height if oversized
+ - "pad": Scale largest dimension to 518px, pad smaller dimension to square
+ output_size (int, optional): Target dimension for model input (default: 518)
+
+ Returns:
+ torch.Tensor: Processed image batch with shape (1, N, 3, H, W) ready for model inference
+
+ Raises:
+ ValueError: When input validation fails (empty list or invalid mode)
+
+ Implementation Details:
+ - Automatic alpha channel handling: RGBA images composited onto white backgrounds
+ - Dimension normalization: All outputs divisible by 14 for patch-based processing
+ - Batch consistency: Different-sized images padded to uniform dimensions
+ - Memory optimization: Efficient tensor operations with minimal data copying
+ - Quality preservation: Bicubic resampling maintains visual fidelity
+ """
+ # Input validation and parameter setup
+ if len(image_file_paths) == 0:
+ raise ValueError("At least 1 image is required")
+
+ if preprocessing_mode not in ["crop", "pad"]:
+ raise ValueError("preprocessing_mode must be either 'crop' or 'pad'")
+
+ processed_image_list = []
+ image_dimension_set = set()
+ tensor_converter = transforms.ToTensor()
+ model_target_size = output_size
+
+ # Individual image processing pipeline
+ for image_file_path in image_file_paths:
+ # File system to memory conversion
+ loaded_image = Image.open(image_file_path)
+
+ # Transparency handling for RGBA images
+ if loaded_image.mode == "RGBA":
+ # Generate white canvas matching image dimensions
+ white_background = Image.new("RGBA", loaded_image.size, (255, 255, 255, 255))
+ # Blend transparent pixels with white background
+ loaded_image = Image.alpha_composite(white_background, loaded_image)
+
+ # Format standardization to RGB
+ loaded_image = loaded_image.convert("RGB")
+
+ original_width, original_height = loaded_image.size
+
+ # Dimension calculation based on preprocessing strategy
+ if preprocessing_mode == "pad":
+ # Proportional scaling to fit largest dimension within target
+ if original_width >= original_height:
+ scaled_width = model_target_size
+ scaled_height = round(original_height * (scaled_width / original_width) / 14) * 14 # Patch compatibility
+ else:
+ scaled_height = model_target_size
+ scaled_width = round(original_width * (scaled_height / original_height) / 14) * 14 # Patch compatibility
+ else: # preprocessing_mode == "crop"
+ # Width normalization with proportional height adjustment
+ scaled_width = model_target_size
+ scaled_height = round(original_height * (scaled_width / original_width) / 14) * 14
+
+ # High-quality image resizing
+ loaded_image = loaded_image.resize((scaled_width, scaled_height), Image.Resampling.BICUBIC)
+ image_tensor = tensor_converter(loaded_image) # Normalize to [0, 1] range
+
+ # Height trimming for crop mode (center-based)
+ if preprocessing_mode == "crop" and scaled_height > model_target_size:
+ crop_start_y = (scaled_height - model_target_size) // 2
+ image_tensor = image_tensor[:, crop_start_y : crop_start_y + model_target_size, :]
+
+ # Square padding for pad mode (centered)
+ if preprocessing_mode == "pad":
+ height_padding_needed = model_target_size - image_tensor.shape[1]
+ width_padding_needed = model_target_size - image_tensor.shape[2]
+
+ if height_padding_needed > 0 or width_padding_needed > 0:
+ padding_top = height_padding_needed // 2
+ padding_bottom = height_padding_needed - padding_top
+ padding_left = width_padding_needed // 2
+ padding_right = width_padding_needed - padding_left
+
+ # White padding application (value=1.0 for normalized images)
+ image_tensor = torch.nn.functional.pad(
+ image_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=1.0
+ )
+
+ image_dimension_set.add((image_tensor.shape[1], image_tensor.shape[2]))
+ processed_image_list.append(image_tensor)
+
+ # Cross-image dimension harmonization
+ if len(image_dimension_set) > 1:
+ print(f"Warning: Found images with different shapes: {image_dimension_set}")
+ # Calculate maximum dimensions across the batch
+ maximum_height = max(dimension[0] for dimension in image_dimension_set)
+ maximum_width = max(dimension[1] for dimension in image_dimension_set)
+
+ # Uniform padding to achieve batch consistency
+ uniformly_sized_images = []
+ for image_tensor in processed_image_list:
+ height_padding_needed = maximum_height - image_tensor.shape[1]
+ width_padding_needed = maximum_width - image_tensor.shape[2]
+
+ if height_padding_needed > 0 or width_padding_needed > 0:
+ padding_top = height_padding_needed // 2
+ padding_bottom = height_padding_needed - padding_top
+ padding_left = width_padding_needed // 2
+ padding_right = width_padding_needed - padding_left
+
+ image_tensor = torch.nn.functional.pad(
+ image_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=1.0
+ )
+ uniformly_sized_images.append(image_tensor)
+ processed_image_list = uniformly_sized_images
+
+ # Batch tensor construction
+ batched_images = torch.stack(processed_image_list) # Concatenate along batch dimension
+
+ # Single image batch dimension handling
+ if len(image_file_paths) == 1:
+ # Ensure proper 4D tensor structure (batch, channels, height, width)
+ if batched_images.dim() == 3:
+ batched_images = batched_images.unsqueeze(0)
+
+ return batched_images.unsqueeze(0)
+
+
+def _handle_alpha_channel(img_data):
+ """Process RGBA images by blending with white background."""
+ if img_data.mode == "RGBA":
+ white_bg = Image.new("RGBA", img_data.size, (255, 255, 255, 255))
+ img_data = Image.alpha_composite(white_bg, img_data)
+ return img_data.convert("RGB")
+
+
+def _calculate_resize_dims(orig_w, orig_h, max_dim, resize_strategy, patch_size=14):
+ """Calculate new dimensions based on resize strategy."""
+ if resize_strategy == "pad":
+ if orig_w >= orig_h:
+ new_w = max_dim
+ new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
+ else:
+ new_h = max_dim
+ new_w = round(orig_w * (new_h / orig_h) / patch_size) * patch_size
+ else: # crop strategy
+ new_w = max_dim
+ new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
+ return new_w, new_h
+
+
+def _apply_padding(tensor_img, target_dim):
+ """Apply padding to make tensor square."""
+ h_pad = target_dim - tensor_img.shape[1]
+ w_pad = target_dim - tensor_img.shape[2]
+
+ if h_pad > 0 or w_pad > 0:
+ pad_top, pad_bottom = h_pad // 2, h_pad - h_pad // 2
+ pad_left, pad_right = w_pad // 2, w_pad - w_pad // 2
+ return torch.nn.functional.pad(
+ tensor_img, (pad_left, pad_right, pad_top, pad_bottom),
+ mode="constant", value=1.0
+ )
+ return tensor_img
+
+
+def prepare_images_to_tensor(file_paths, resize_strategy="crop", target_size=518):
+ """
+ Process image files into uniform tensor batch for model input.
+
+ Args:
+ file_paths (list): Paths to image files
+ resize_strategy (str): "crop" or "pad" processing mode
+ target_size (int): Target size for processing
+
+ Returns:
+ torch.Tensor: Processed image batch (1, N, 3, H, W)
+ """
+ if not file_paths:
+ raise ValueError("At least 1 image is required")
+
+ if resize_strategy not in ["crop", "pad"]:
+ raise ValueError("Strategy must be 'crop' or 'pad'")
+
+ tensor_list = []
+ dimension_set = set()
+ converter = transforms.ToTensor()
+
+ # Process each image file
+ for file_path in file_paths:
+ img_data = Image.open(file_path)
+ img_data = _handle_alpha_channel(img_data)
+
+ orig_w, orig_h = img_data.size
+ new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, resize_strategy)
+
+ # Resize and convert to tensor
+ img_data = img_data.resize((new_w, new_h), Image.Resampling.BICUBIC)
+ tensor_img = converter(img_data)
+
+ # Apply center crop for crop strategy
+ if resize_strategy == "crop" and new_h > target_size:
+ crop_start = (new_h - target_size) // 2
+ tensor_img = tensor_img[:, crop_start:crop_start + target_size, :]
+
+ # Apply padding for pad strategy
+ if resize_strategy == "pad":
+ tensor_img = _apply_padding(tensor_img, target_size)
+
+ dimension_set.add((tensor_img.shape[1], tensor_img.shape[2]))
+ tensor_list.append(tensor_img)
+
+ # Handle mixed dimensions
+ if len(dimension_set) > 1:
+ print(f"Warning: Mixed image dimensions found: {dimension_set}")
+ max_h = max(dims[0] for dims in dimension_set)
+ max_w = max(dims[1] for dims in dimension_set)
+
+ tensor_list = [_apply_padding(img, max(max_h, max_w)) if img.shape[1] != max_h or img.shape[2] != max_w
+ else img for img in tensor_list]
+
+ batch_tensor = torch.stack(tensor_list)
+
+ # Ensure proper batch dimensions
+ if batch_tensor.dim() == 3:
+ batch_tensor = batch_tensor.unsqueeze(0)
+
+ return batch_tensor.unsqueeze(0)
+
+
+def extract_load_and_preprocess_images(image_folder_or_video_path, fps=1, target_size=518, mode="crop"):
+ # Support multiple image formats
+ if image_folder_or_video_path.is_file() and image_folder_or_video_path.suffix.lower() in VIDEO_EXTS:
+ frame_paths = video_to_image_frames(str(image_folder_or_video_path), fps=fps)
+ img_paths = sorted(frame_paths)
+ else:
+ img_paths = []
+ for ext in IMAGE_EXTS:
+ img_paths.extend(glob.glob(os.path.join(str(image_folder_or_video_path), ext)))
+ img_paths = sorted(img_paths)
+ images = prepare_images_to_tensor(img_paths, resize_strategy=mode, target_size=target_size)
+ return images
\ No newline at end of file
diff --git a/src/utils/render_utils.py b/src/utils/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc9be8d2a2cc0423c51274d85328af4ea1bb98e
--- /dev/null
+++ b/src/utils/render_utils.py
@@ -0,0 +1,377 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+import moviepy.editor as mpy
+
+from src.models.models.rasterization import GaussianSplatRenderer
+from src.models.utils.sh_utils import RGB2SH, SH2RGB
+from src.utils.gs_effects import GSEffects
+from src.utils.color_map import apply_color_map_to_image
+from tqdm import tqdm
+
+
+def rotation_matrix_to_quaternion(R):
+ """Convert rotation matrix to quaternion"""
+ trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
+
+ q = torch.zeros(R.shape[:-2] + (4,), device=R.device, dtype=R.dtype)
+
+ # Case where trace > 0
+ mask1 = trace > 0
+ s = torch.sqrt(trace[mask1] + 1.0) * 2 # s=4*qw
+ q[mask1, 0] = 0.25 * s # qw
+ q[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s # qx
+ q[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s # qy
+ q[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s # qz
+
+ # Case where R[0,0] > R[1,1] and R[0,0] > R[2,2]
+ mask2 = (~mask1) & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
+ s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2 # s=4*qx
+ q[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s # qw
+ q[mask2, 1] = 0.25 * s # qx
+ q[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s # qy
+ q[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s # qz
+
+ # Case where R[1,1] > R[2,2]
+ mask3 = (~mask1) & (~mask2) & (R[..., 1, 1] > R[..., 2, 2])
+ s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2 # s=4*qy
+ q[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s # qw
+ q[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s # qx
+ q[mask3, 2] = 0.25 * s # qy
+ q[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s # qz
+
+ # Remaining case
+ mask4 = (~mask1) & (~mask2) & (~mask3)
+ s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2 # s=4*qz
+ q[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s # qw
+ q[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s # qx
+ q[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s # qy
+ q[mask4, 3] = 0.25 * s # qz
+
+ return q
+
+
+def quaternion_to_rotation_matrix(q):
+ """Convert quaternion to rotation matrix"""
+ w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
+
+ # Normalize quaternion
+ norm = torch.sqrt(w*w + x*x + y*y + z*z)
+ w, x, y, z = w/norm, x/norm, y/norm, z/norm
+
+ R = torch.zeros(q.shape[:-1] + (3, 3), device=q.device, dtype=q.dtype)
+
+ R[..., 0, 0] = 1 - 2*(y*y + z*z)
+ R[..., 0, 1] = 2*(x*y - w*z)
+ R[..., 0, 2] = 2*(x*z + w*y)
+ R[..., 1, 0] = 2*(x*y + w*z)
+ R[..., 1, 1] = 1 - 2*(x*x + z*z)
+ R[..., 1, 2] = 2*(y*z - w*x)
+ R[..., 2, 0] = 2*(x*z - w*y)
+ R[..., 2, 1] = 2*(y*z + w*x)
+ R[..., 2, 2] = 1 - 2*(x*x + y*y)
+
+ return R
+
+
+def slerp_quaternions(q1, q2, t):
+ """Spherical linear interpolation between quaternions"""
+ # Compute dot product
+ dot = (q1 * q2).sum(dim=-1, keepdim=True)
+
+ # If dot product is negative, slerp won't take the shorter path.
+ # Note that q and -q represent the same rotation, so we can flip one.
+ mask = dot < 0
+ q2 = torch.where(mask, -q2, q2)
+ dot = torch.where(mask, -dot, dot)
+
+ # If the inputs are too close for comfort, linearly interpolate
+ # and normalize the result.
+ DOT_THRESHOLD = 0.9995
+ mask_linear = dot > DOT_THRESHOLD
+
+ result = torch.zeros_like(q1)
+
+ # Linear interpolation for close quaternions
+ if mask_linear.any():
+ result_linear = q1 + t * (q2 - q1)
+ norm = torch.norm(result_linear, dim=-1, keepdim=True)
+ result_linear = result_linear / norm
+ result = torch.where(mask_linear, result_linear, result)
+
+ # Spherical interpolation for distant quaternions
+ mask_slerp = ~mask_linear
+ if mask_slerp.any():
+ theta_0 = torch.acos(torch.abs(dot))
+ sin_theta_0 = torch.sin(theta_0)
+
+ theta = theta_0 * t
+ sin_theta = torch.sin(theta)
+
+ s0 = torch.cos(theta) - dot * sin_theta / sin_theta_0
+ s1 = sin_theta / sin_theta_0
+
+ result_slerp = (s0 * q1) + (s1 * q2)
+ result = torch.where(mask_slerp, result_slerp, result)
+
+ return result
+
+
+def render_interpolated_video(gs_renderer: GaussianSplatRenderer,
+ splats: dict,
+ camtoworlds: torch.Tensor,
+ intrinsics: torch.Tensor,
+ hw: tuple[int, int],
+ out_path: Path,
+ interp_per_pair: int = 20,
+ loop_reverse: bool = True,
+ effects: GSEffects = None,
+ effect_type: int = 2,
+ save_mode: str = "split") -> None:
+ # camtoworlds: [B, S, 4, 4], intrinsics: [B, S, 3, 3]
+ b, s, _, _ = camtoworlds.shape
+ h, w = hw
+
+ # Build interpolated trajectory
+ def build_interpolated_traj(index, nums):
+ exts, ints = [], []
+ tmp_camtoworlds = camtoworlds[:, index]
+ tmp_intrinsics = intrinsics[:, index]
+ for i in range(len(index)-1):
+ exts.append(tmp_camtoworlds[:, i:i+1])
+ ints.append(tmp_intrinsics[:, i:i+1])
+ # Extract rotation and translation
+ R0, t0 = tmp_camtoworlds[:, i, :3, :3], tmp_camtoworlds[:, i, :3, 3]
+ R1, t1 = tmp_camtoworlds[:, i + 1, :3, :3], tmp_camtoworlds[:, i + 1, :3, 3]
+
+ # Convert rotations to quaternions
+ q0 = rotation_matrix_to_quaternion(R0)
+ q1 = rotation_matrix_to_quaternion(R1)
+
+ # Interpolate using smooth quaternion slerp
+ for j in range(1, nums + 1):
+ alpha = j / (nums + 1)
+
+ # Linear interpolation for translation
+ t_interp = (1 - alpha) * t0 + alpha * t1
+
+ # Spherical interpolation for rotation
+ q_interp = slerp_quaternions(q0, q1, alpha)
+ R_interp = quaternion_to_rotation_matrix(q_interp)
+
+ # Create interpolated extrinsic matrix
+ ext = torch.eye(4, device=R_interp.device, dtype=R_interp.dtype)[None].repeat(b, 1, 1)
+ ext[:, :3, :3] = R_interp
+ ext[:, :3, 3] = t_interp
+
+ # Linear interpolation for intrinsics
+ K0 = tmp_intrinsics[:, i]
+ K1 = tmp_intrinsics[:, i + 1]
+ K = (1 - alpha) * K0 + alpha * K1
+
+ exts.append(ext[:, None])
+ ints.append(K[:, None])
+
+ exts = torch.cat(exts, dim=1)[:1]
+ ints = torch.cat(ints, dim=1)[:1]
+ return exts, ints
+
+ # Build wobble trajectory
+ def build_wobble_traj(nums, delta):
+ assert s==1
+ t = torch.linspace(0, 1, nums, dtype=torch.float32, device=camtoworlds.device)
+ t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
+ tf = torch.eye(4, dtype=torch.float32, device=camtoworlds.device)
+ radius = delta * 0.15
+ tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
+ radius = radius[..., None]
+ radius = radius * t
+ tf[..., 0, 3] = torch.sin(2 * torch.pi * t) * radius
+ tf[..., 1, 3] = -torch.cos(2 * torch.pi * t) * radius
+ exts = camtoworlds @ tf
+ ints = intrinsics.repeat(1, exts.shape[1], 1, 1)
+ return exts, ints
+
+ if s > 1:
+ all_ext, all_int = build_interpolated_traj([i for i in range(s)], interp_per_pair)
+ else:
+ all_ext, all_int = build_wobble_traj(interp_per_pair * 12, splats["means"][0].median(dim=0).values.norm(dim=-1)[None])
+
+ rendered_rgbs, rendered_depths = [], []
+ chunk = 40 if effects is None else 1
+ t = 0
+ t_skip = 0
+ if effects is not None:
+ try:
+ pruned_splats = gs_renderer.prune_gs(splats, gs_renderer.voxel_size)
+ except:
+ pruned_splats = splats
+ # indices = [x for x in range(0, all_ext.shape[1], 2)][:4]
+ # add_ext, add_int = build_interpolated_traj(indices, 150)
+ # add_ext = torch.flip(add_ext, dims=[1])
+ # add_int = torch.flip(add_int, dims=[1])
+ add_ext = all_ext[:, :1, :, :].repeat(1, 320, 1, 1)
+ add_int = all_int[:, :1, :, :].repeat(1, 320, 1, 1)
+ shift = pruned_splats["means"][0].median(dim=0).values
+ scale_factor = (pruned_splats["means"][0] - shift).abs().quantile(0.95, dim=0).max()
+ all_ext[0, :, :3, -1] = (all_ext[0, :, :3, -1] - shift) / scale_factor
+ add_ext[0, :, :3, -1] = (add_ext[0, :, :3, -1] - shift) / scale_factor
+ flag = None
+ try:
+ raw_splats = gs_renderer.rasterizer.runner.splats
+ except:
+ pass
+ for st in range(0, add_ext.shape[1]):
+ ed = min(st + 1, add_ext.shape[1])
+ assert gs_renderer.sh_degree == 0
+ if flag is not None and (flag < 0.99).any():
+ break
+ sample_gsplat = {"means": (pruned_splats["means"][0] - shift)/scale_factor, "quats": pruned_splats["quats"][0], "scales": pruned_splats["scales"][0]/scale_factor,
+ "opacities": pruned_splats["opacities"][0],"colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3))}
+ effects_splats, flag = effects.apply_effect(sample_gsplat, t, effect_type=effect_type)
+ t += 0.04
+ effects_splats["sh"] = RGB2SH(effects_splats["colors"]).reshape(-1, 1, 3)
+ try:
+ gs_renderer.rasterizer.runner.splats
+ effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
+ effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
+ effects_splats["scales"] = effects_splats["scales"].log()
+ effects_splats["opacities"] = torch.logit(torch.clamp(effects_splats["opacities"], 1e-6, 1 - 1e-6))
+ gs_renderer.rasterizer.runner.splats = effects_splats
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
+ None, None, None,
+ None, None,
+ add_ext[:, st:ed].to(torch.float32), add_int[:, st:ed].to(torch.float32),
+ width=w, height=h, sh_degree=gs_renderer.sh_degree,
+ )
+ except:
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
+ effects_splats["means"][None], effects_splats["quats"][None], effects_splats["scales"][None],
+ effects_splats["opacities"][None], effects_splats["sh"][None],
+ add_ext[:, st:ed].to(torch.float32), add_int[:, st:ed].to(torch.float32),
+ width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
+ )
+
+ if st > add_ext.shape[1]*0.14:
+ t_skip = t if t_skip == 0 else t_skip
+ # break
+ rendered_rgbs.append(colors)
+ rendered_depths.append(depths)
+ # if (flag == 0).all():
+ # break
+ t_st = t
+ t_ed = 0
+ loop_dir = 1
+ ignore_scale = False
+ for st in tqdm(range(0, all_ext.shape[1], chunk)):
+ ed = min(st + chunk, all_ext.shape[1])
+ if effects is not None:
+ try:
+ sample_gsplat = {"means": (pruned_splats["means"][0] - shift)/scale_factor, "quats": pruned_splats["quats"][0], "scales": pruned_splats["scales"][0]/scale_factor,
+ "opacities": pruned_splats["opacities"][0],"colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3))}
+ except:
+ sample_gsplat = {"means": (pruned_splats["means"][0] - shift)/scale_factor, "quats": pruned_splats["quats"][0], "scales": pruned_splats["scales"][0]/scale_factor,
+ "opacities": pruned_splats["opacities"][0],"colors": SH2RGB(pruned_splats["sh"][0].reshape(-1, 3))}
+ effects_splats, flag = effects.apply_effect(sample_gsplat, t, effect_type=effect_type, ignore_scale=ignore_scale)
+ if loop_dir < 0:
+ t -= 0.04
+ else:
+ t += 0.04
+ if flag.mean() < 0.01 and t_ed == 0:
+ t_ed = t
+ effects_splats["sh"] = RGB2SH(effects_splats["colors"]).reshape(-1, 1, 3)
+ effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
+ effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
+ try:
+ gs_renderer.rasterizer.runner.splats
+ effects_splats["sh0"] = effects_splats["sh"][:, :1, :]
+ effects_splats["shN"] = effects_splats["sh"][:, 1:, :]
+ effects_splats["scales"] = effects_splats["scales"].log()
+ effects_splats["opacities"] = torch.logit(torch.clamp(effects_splats["opacities"], 1e-6, 1 - 1e-6))
+ gs_renderer.rasterizer.runner.splats = effects_splats
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
+ None, None, None,
+ None, None,
+ all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
+ width=w, height=h, sh_degree=gs_renderer.sh_degree,
+ )
+ except:
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
+ effects_splats["means"][None], effects_splats["quats"][None], effects_splats["scales"][None],
+ effects_splats["opacities"][None], effects_splats["sh"][None],
+ all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
+ width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in pruned_splats else None,
+ )
+
+ if t > (all_ext.shape[1]) * 0.04 + t_st - (t_ed - t_st)*2 - 15*0.04 or t < t_st:
+ # ignore_scale = True
+ loop_dir *= -1
+ t = t_ed if loop_dir == -1 else t
+ else:
+ colors, depths, _ = gs_renderer.rasterizer.rasterize_batches(
+ splats["means"][:1], splats["quats"][:1], splats["scales"][:1], splats["opacities"][:1],
+ splats["sh"][:1] if "sh" in splats else splats["colors"][:1],
+ all_ext[:, st:ed].to(torch.float32), all_int[:, st:ed].to(torch.float32),
+ width=w, height=h, sh_degree=gs_renderer.sh_degree if "sh" in splats else None,
+ )
+ rendered_rgbs.append(colors)
+ rendered_depths.append(depths)
+
+
+ rgbs = torch.cat(rendered_rgbs, dim=1)[0] # [N, H, W, 3]
+ depths = torch.cat(rendered_depths, dim=1)[0, ..., 0] # [N, H, W]
+
+
+ def depth_vis(d: torch.Tensor) -> torch.Tensor:
+ valid = d > 0
+ if valid.any():
+ near = d[valid].float().quantile(0.01).log()
+ else:
+ near = torch.tensor(0.0, device=d.device)
+ far = d.flatten().float().quantile(0.99).log()
+ x = d.float().clamp(min=1e-9).log()
+ x = 1.0 - (x - near) / (far - near + 1e-9)
+ return apply_color_map_to_image(x, "turbo")
+
+ frames = []
+ rgb_frames = []
+ depth_frames = []
+
+ for rgb, dep in zip(rgbs, depths):
+ rgb_img = rgb.permute(2, 0, 1) # [3, H, W]
+ depth_img = depth_vis(dep) # [3, H, W]
+
+ if save_mode == 'both':
+ combined = torch.cat([rgb_img, depth_img], dim=1) # [3, 2*H, W]
+ frames.append(combined)
+ elif save_mode == 'split':
+ rgb_frames.append(rgb_img)
+ depth_frames.append(depth_img)
+ else:
+ raise ValueError("save_mode must be 'both' or 'split'")
+
+ def _make_video(frames, path):
+ video = torch.stack(frames).clamp(0, 1) # [N, 3, H, W]
+ video = video.permute(0, 2, 3, 1) # [N, H, W, 3] for moviepy
+ video = (video * 255).to(torch.uint8).cpu().numpy()
+ if loop_reverse and video.shape[0] > 1:
+ video = np.concatenate([video, video[::-1][1:-1]], axis=0)
+ clip = mpy.ImageSequenceClip(list(video), fps=30)
+ clip.write_videofile(str(path), logger=None)
+
+ # Save videos
+ if save_mode == 'both':
+ _make_video(frames, f"{out_path}.mp4")
+ elif save_mode == 'split':
+ _make_video(rgb_frames, f"{out_path}_rgb.mp4")
+ _make_video(depth_frames, f"{out_path}_depth.mp4")
+
+ print(f"Video saved to {out_path} (mode: {save_mode})")
+
+ if effects is not None:
+ try:
+ gs_renderer.rasterizer.runner.splats = raw_splats
+ except:
+ pass
+ torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/src/utils/save_utils.py b/src/utils/save_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c7f9b8150f23685b038b72e93c1795302acce2
--- /dev/null
+++ b/src/utils/save_utils.py
@@ -0,0 +1,286 @@
+# wzw
+"""
+Utilities for saving images, depths, normals, point clouds, and Gaussian splat data.
+tencent
+"""
+from pathlib import Path
+
+import numpy as np
+import torch
+from PIL import Image
+from plyfile import PlyData, PlyElement
+from io import BytesIO
+import json
+import os
+
+def save_camera_params(extrinsics, intrinsics, target_dir):
+ """
+ Save camera parameters (extrinsics and intrinsics) in JSON format
+
+ Args:
+ extrinsics: numpy array, shape [N, 4, 4] - extrinsic matrices for N cameras
+ intrinsics: numpy array, shape [N, 3, 3] - intrinsic matrices for N cameras
+ target_dir: str - directory to save the parameters
+
+ Returns:
+ str: path to the saved file
+ """
+ camera_data = {
+ "num_cameras": int(extrinsics.shape[0]),
+ "extrinsics": [],
+ "intrinsics": []
+ }
+
+ # Convert each camera's parameters to list format
+ for i in range(extrinsics.shape[0]):
+ camera_data["extrinsics"].append({
+ "camera_id": i,
+ "matrix": extrinsics[i].tolist() # [4, 4] -> list
+ })
+ camera_data["intrinsics"].append({
+ "camera_id": i,
+ "matrix": intrinsics[i].tolist() # [3, 3] -> list
+ })
+
+ # Save as JSON file
+ camera_params_path = os.path.join(target_dir, "camera_params.json")
+ with open(camera_params_path, 'w') as f:
+ json.dump(camera_data, f, indent=2)
+
+ return camera_params_path
+
+def save_image_png(path: Path, image_tensor: torch.Tensor) -> None:
+ # image_tensor: [H, W, 3]
+ img = (image_tensor.detach().cpu() * 255.0).to(torch.uint8).numpy()
+ Image.fromarray(img).save(str(path))
+
+
+def save_depth_png(path: Path, depth_tensor: torch.Tensor) -> None:
+ # depth_tensor: [H, W]
+ d = depth_tensor.detach()
+ d = d - d.min()
+ d = d / (d.max() + 1e-9)
+ img = (d.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy()
+ Image.fromarray(img, mode="L").save(str(path))
+
+
+def save_depth_npy(path: Path, depth_tensor: torch.Tensor) -> None:
+ # depth_tensor: [H, W]
+ # Save actual depth values in numpy format
+ d = depth_tensor.detach().cpu().numpy()
+ np.save(str(path), d)
+
+
+def save_normal_png(path: Path, normal_hwc: torch.Tensor) -> None:
+ # normal_hwc: [H, W, 3], in [-1, 1]
+ n = (normal_hwc.detach().cpu() + 1.0) * 0.5
+ img = (n.clamp(0, 1) * 255.0).to(torch.uint8).numpy()
+ Image.fromarray(img).save(str(path))
+
+
+def save_scene_ply(path: Path,
+ points_xyz: torch.Tensor,
+ point_colors: torch.Tensor,
+ valid_mask: torch.Tensor = None) -> None:
+ """Save point cloud to PLY format"""
+ pts = points_xyz.detach().cpu().to(torch.float32).numpy().reshape(-1, 3)
+ colors = point_colors.detach().cpu().to(torch.uint8).numpy().reshape(-1, 3)
+
+ # Filter out invalid points (NaN, Inf)
+ if valid_mask is None:
+ valid_mask = np.isfinite(pts).all(axis=1)
+ else:
+ valid_mask = valid_mask.detach().cpu().numpy().reshape(-1)
+ pts = pts[valid_mask]
+ colors = colors[valid_mask]
+
+ # Handle empty point cloud
+ if len(pts) == 0:
+ pts = np.array([[0, 0, 0]], dtype=np.float32)
+ colors = np.array([[255, 255, 255]], dtype=np.uint8)
+
+ # Create PLY data
+ vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
+ ("red", "u1"), ("green", "u1"), ("blue", "u1")]
+ vertex_elements = np.empty(len(pts), dtype=vertex_dtype)
+ vertex_elements["x"] = pts[:, 0]
+ vertex_elements["y"] = pts[:, 1]
+ vertex_elements["z"] = pts[:, 2]
+ vertex_elements["red"] = colors[:, 0]
+ vertex_elements["green"] = colors[:, 1]
+ vertex_elements["blue"] = colors[:, 2]
+
+ # Write PLY file
+ PlyData([PlyElement.describe(vertex_elements, "vertex")]).write(str(path))
+
+
+def save_points_ply(path: Path, pts_np: np.ndarray, cols_np: np.ndarray) -> None:
+ """Save point cloud to PLY format from numpy arrays"""
+ vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4"),
+ ("red", "u1"), ("green", "u1"), ("blue", "u1")]
+ vertex_elements = np.empty(len(pts_np), dtype=vertex_dtype)
+ vertex_elements["x"] = pts_np[:, 0]
+ vertex_elements["y"] = pts_np[:, 1]
+ vertex_elements["z"] = pts_np[:, 2]
+ vertex_elements["red"] = cols_np[:, 0]
+ vertex_elements["green"] = cols_np[:, 1]
+ vertex_elements["blue"] = cols_np[:, 2]
+
+ # Write PLY file
+ PlyData([PlyElement.describe(vertex_elements, "vertex")]).write(str(path))
+
+
+def save_gs_ply(path: Path,
+ means: torch.Tensor,
+ scales: torch.Tensor,
+ rotations: torch.Tensor,
+ rgbs: torch.Tensor,
+ opacities: torch.Tensor) -> None:
+ """
+ Export Gaussian splat data to PLY format.
+
+ Args:
+ path: Output PLY file path
+ means: Gaussian centers [N, 3]
+ scales: Gaussian scales [N, 3]
+ rotations: Gaussian rotations as quaternions [N, 4]
+ rgbs: RGB colors [N, 3]
+ opacities: Opacity values [N]
+ """
+ # Filter out points with scales greater than the 95th percentile
+ scale_threshold = torch.quantile(scales.max(dim=-1)[0], 0.95, dim=0)
+ filter_mask = scales.max(dim=-1)[0] <= scale_threshold
+
+ # Apply the filter to all tensors
+ means = means[filter_mask].reshape(-1, 3)
+ scales = scales[filter_mask].reshape(-1, 3)
+ rotations = rotations[filter_mask].reshape(-1, 4)
+ rgbs = rgbs[filter_mask].reshape(-1, 3)
+ opacities = opacities[filter_mask].reshape(-1)
+
+ # Construct attribute names
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
+ for i in range(3):
+ attributes.append(f"f_dc_{i}")
+ attributes.append("opacity")
+ for i in range(3):
+ attributes.append(f"scale_{i}")
+ for i in range(4):
+ attributes.append(f"rot_{i}")
+
+ # Prepare PLY data structure
+ dtype_full = [(attribute, "f4") for attribute in attributes]
+ elements = np.empty(means.shape[0], dtype=dtype_full)
+
+ # Concatenate all attributes
+ attributes_data = (
+ means.float().detach().cpu().numpy(),
+ torch.zeros_like(means).float().detach().cpu().numpy(),
+ rgbs.detach().cpu().contiguous().numpy(),
+ opacities[..., None].detach().cpu().numpy(),
+ scales.log().detach().cpu().numpy(),
+ rotations.detach().cpu().numpy(),
+ )
+ attributes_data = np.concatenate(attributes_data, axis=1)
+ elements[:] = list(map(tuple, attributes_data))
+
+ # Write to PLY file
+ PlyData([PlyElement.describe(elements, "vertex")]).write(str(path))
+
+
+def convert_gs_to_ply(means, scales, rotations, rgbs, opacities):
+ """
+ Export Gaussian splat data to PLY format.
+
+ Args:
+ path: Output PLY file path
+ means: Gaussian centers [N, 3]
+ scales: Gaussian scales [N, 3]
+ rotations: Gaussian rotations as quaternions [N, 4]
+ rgbs: RGB colors [N, 3]
+ opacities: Opacity values [N]
+ """
+ # Filter out points with scales greater than the 90th percentile
+ scale_threshold = torch.quantile(scales.max(dim=-1)[0], 0.90, dim=0)
+ filter_mask = scales.max(dim=-1)[0] <= scale_threshold
+
+ # Apply the filter to all tensors
+ means = means[filter_mask].reshape(-1, 3)
+ scales = scales[filter_mask].reshape(-1, 3)
+ rotations = rotations[filter_mask].reshape(-1, 4)
+ rgbs = rgbs[filter_mask].reshape(-1, 3)
+ opacities = opacities[filter_mask].reshape(-1)
+
+ # Construct attribute names
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
+ for i in range(3):
+ attributes.append(f"f_dc_{i}")
+ attributes.append("opacity")
+ for i in range(3):
+ attributes.append(f"scale_{i}")
+ for i in range(4):
+ attributes.append(f"rot_{i}")
+
+ # Prepare PLY data structure
+ dtype_full = [(attribute, "f4") for attribute in attributes]
+ elements = np.empty(means.shape[0], dtype=dtype_full)
+
+ # Concatenate all attributes
+ attributes_data = (
+ means.float().detach().cpu().numpy(),
+ torch.zeros_like(means).float().detach().cpu().numpy(),
+ rgbs.detach().cpu().contiguous().numpy(),
+ opacities[..., None].detach().cpu().numpy(),
+ scales.log().detach().cpu().numpy(),
+ rotations.detach().cpu().numpy(),
+ )
+ attributes_data = np.concatenate(attributes_data, axis=1)
+ elements[:] = list(map(tuple, attributes_data))
+ plydata = PlyData([PlyElement.describe(elements, "vertex")])
+ return plydata
+
+
+def process_ply_to_splat(plydata, output_path):
+ vert = plydata["vertex"]
+ sorted_indices = np.argsort(
+ -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
+ / (1 + np.exp(-vert["opacity"]))
+ )
+ buffer = BytesIO()
+ for idx in sorted_indices:
+ v = plydata["vertex"][idx]
+ position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
+ scales = np.exp(
+ np.array(
+ [v["scale_0"], v["scale_1"], v["scale_2"]],
+ dtype=np.float32,
+ )
+ )
+ rot = np.array(
+ [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
+ dtype=np.float32,
+ )
+ SH_C0 = 0.28209479177387814
+ color = np.array(
+ [
+ 0.5 + SH_C0 * v["f_dc_0"],
+ 0.5 + SH_C0 * v["f_dc_1"],
+ 0.5 + SH_C0 * v["f_dc_2"],
+ 1 / (1 + np.exp(-v["opacity"])),
+ ]
+ )
+ buffer.write(position.tobytes())
+ buffer.write(scales.tobytes())
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+ buffer.write(
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
+ .clip(0, 255)
+ .astype(np.uint8)
+ .tobytes()
+ )
+ value = buffer.getvalue()
+ with open(output_path, "wb") as f:
+ f.write(value)
+
+ return output_path
+
diff --git a/src/utils/video_utils.py b/src/utils/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de434e0a49dc8f72454d4faad24b244b8f55ccf
--- /dev/null
+++ b/src/utils/video_utils.py
@@ -0,0 +1,150 @@
+"""
+Video utilities for visualization.
+
+"""
+
+import os
+import cv2
+import numpy as np
+import subprocess
+from PIL import Image
+
+
+def video_to_image_frames(input_video_path, save_directory=None, frames_per_second=1):
+ """
+ Extracts image frames from a video file at the specified frame rate and saves them as JPEG format.
+ Supports regular video files, webcam captures, WebM files, and GIF files, including incomplete files.
+
+ Args:
+ input_video_path: Path to the input video file
+ save_directory: Directory to save extracted frames (default: None)
+ frames_per_second: Number of frames to extract per second (default: 1)
+
+ Returns: List of file paths to extracted frames
+ """
+ extracted_frame_paths = []
+
+ # For GIF files, use PIL library for better handling
+ if input_video_path.lower().endswith('.gif'):
+ try:
+ print(f"Processing GIF file using PIL: {input_video_path}")
+
+ with Image.open(input_video_path) as gif_img:
+ # Get GIF properties
+ frame_duration_ms = gif_img.info.get('duration', 100) # Duration per frame in milliseconds
+ gif_frame_rate = 1000.0 / frame_duration_ms if frame_duration_ms > 0 else 10.0 # Convert to frame rate
+
+ print(f"GIF properties: {gif_img.n_frames} frames, {gif_frame_rate:.2f} FPS, {frame_duration_ms}ms per frame")
+
+ # Calculate sampling interval
+ sampling_interval = max(1, int(gif_frame_rate / frames_per_second)) if frames_per_second < gif_frame_rate else 1
+
+ saved_count = 0
+ for current_frame_index in range(gif_img.n_frames):
+ gif_img.seek(current_frame_index)
+
+ # Sample frames based on desired frame rate
+ if current_frame_index % sampling_interval == 0:
+ # Convert to RGB format if necessary
+ rgb_frame = gif_img.convert('RGB')
+
+ # Convert PIL image to numpy array
+ frame_ndarray = np.array(rgb_frame)
+
+ # Save frame as JPEG format
+ frame_output_path = os.path.join(save_directory, f"frame_{saved_count:06d}.jpg")
+ pil_image = Image.fromarray(frame_ndarray)
+ pil_image.save(frame_output_path, 'JPEG', quality=95)
+ extracted_frame_paths.append(frame_output_path)
+ saved_count += 1
+
+ if extracted_frame_paths:
+ print(f"Successfully extracted {len(extracted_frame_paths)} frames from GIF using PIL")
+ return extracted_frame_paths
+
+ except Exception as error:
+ print(f"PIL GIF extraction error: {str(error)}, falling back to OpenCV")
+
+ # For WebM files, use FFmpeg directly for more stable processing
+ if input_video_path.lower().endswith('.webm'):
+ try:
+ print(f"Processing WebM file using FFmpeg: {input_video_path}")
+
+ # Create a unique output pattern for the frames
+ output_frame_pattern = os.path.join(save_directory, "frame_%04d.jpg")
+
+ # Use FFmpeg to extract frames at specified frame rate
+ ffmpeg_command = [
+ "ffmpeg",
+ "-i", input_video_path,
+ "-vf", f"fps={frames_per_second}", # Specified frames per second
+ "-q:v", "2", # High quality
+ output_frame_pattern
+ ]
+
+ # Run FFmpeg process
+ ffmpeg_process = subprocess.Popen(
+ ffmpeg_command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE
+ )
+ process_stdout, process_stderr = ffmpeg_process.communicate()
+
+ # Collect all extracted frames
+ for filename in sorted(os.listdir(save_directory)):
+ if filename.startswith("frame_") and filename.endswith(".jpg"):
+ full_frame_path = os.path.join(save_directory, filename)
+ extracted_frame_paths.append(full_frame_path)
+
+ if extracted_frame_paths:
+ print(f"Successfully extracted {len(extracted_frame_paths)} frames from WebM using FFmpeg")
+ return extracted_frame_paths
+
+ print("FFmpeg extraction failed, falling back to OpenCV")
+ except Exception as error:
+ print(f"FFmpeg extraction error: {str(error)}, falling back to OpenCV")
+
+ # Standard OpenCV method for non-WebM files or as fallback
+ try:
+ video_capture = cv2.VideoCapture(input_video_path)
+
+ # For WebM files, try setting more robust decoder options
+ if input_video_path.lower().endswith('.webm'):
+ video_capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'VP80'))
+
+ source_fps = video_capture.get(cv2.CAP_PROP_FPS)
+ extraction_interval = max(1, int(source_fps / frames_per_second)) # Extract at specified frame rate
+ processed_frame_count = 0
+
+ # Set error mode to suppress console warnings
+ cv2.setLogLevel(0)
+
+ while True:
+ read_success, current_frame = video_capture.read()
+ if not read_success:
+ break
+
+ if processed_frame_count % extraction_interval == 0:
+ try:
+ # Additional check for valid frame data
+ if current_frame is not None and current_frame.size > 0:
+ rgb_converted_frame = cv2.cvtColor(current_frame, cv2.COLOR_BGR2RGB)
+ frame_output_path = os.path.join(save_directory, f"frame_{processed_frame_count:06d}.jpg")
+ cv2.imwrite(frame_output_path, cv2.cvtColor(rgb_converted_frame, cv2.COLOR_RGB2BGR))
+ extracted_frame_paths.append(frame_output_path)
+ except Exception as error:
+ print(f"Warning: Failed to process frame {processed_frame_count}: {str(error)}")
+
+ processed_frame_count += 1
+
+ # Safety limit to prevent infinite loops
+ if processed_frame_count > 1000:
+ break
+
+ video_capture.release()
+ print(f"Extracted {len(extracted_frame_paths)} frames from video using OpenCV")
+
+ except Exception as error:
+ print(f"Error extracting frames: {str(error)}")
+
+ return extracted_frame_paths
\ No newline at end of file
diff --git a/src/utils/visual_util.py b/src/utils/visual_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..352f8462928b385fafb445b26ffcf35155fc929d
--- /dev/null
+++ b/src/utils/visual_util.py
@@ -0,0 +1,615 @@
+# wzw
+""" Visual utilities for HuggingFace integration.
+
+References: https://github.com/facebookresearch/vggt
+"""
+
+import copy
+import os
+from typing import Tuple
+
+import cv2
+import matplotlib
+import numpy as np
+import requests
+import trimesh
+
+from scipy.spatial.transform import Rotation
+
+
+def segment_sky(image_path, onnx_session):
+ """
+ Segments sky from an image using an ONNX model.
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
+
+ Args:
+ image_path: Path to input image
+ onnx_session: ONNX runtime session with loaded model
+
+ Returns:
+ np.ndarray: Binary mask where 255 indicates non-sky regions
+ """
+
+ image = cv2.imread(image_path)
+ result_map = run_skyseg(onnx_session, [320, 320], image)
+ # resize the result_map to the original image size
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
+
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
+ output_mask = np.zeros_like(result_map_original)
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+ return output_mask
+
+
+def run_skyseg(onnx_session, input_size, image):
+ """
+ Runs sky segmentation inference using ONNX model.
+
+ Args:
+ onnx_session: ONNX runtime session
+ input_size: Target size for model input (width, height)
+ image: Input image in BGR format
+
+ Returns:
+ np.ndarray: Segmentation mask
+ """
+
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
+ temp_image = copy.deepcopy(image)
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
+ x = np.array(x, dtype=np.float32)
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ x = (x / 255 - mean) / std
+ x = x.transpose(2, 0, 1)
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
+
+ # Inference
+ input_name = onnx_session.get_inputs()[0].name
+ output_name = onnx_session.get_outputs()[0].name
+ onnx_result = onnx_session.run([output_name], {input_name: x})
+
+ # Post process
+ onnx_result = np.array(onnx_result).squeeze()
+ min_value = np.min(onnx_result)
+ max_value = np.max(onnx_result)
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
+ onnx_result *= 255
+ onnx_result = onnx_result.astype("uint8")
+
+ return onnx_result
+
+
+def download_file_from_url(url, filename):
+ """Downloads a file from a Hugging Face model repo, handling redirects."""
+ try:
+ # Get the redirect URL
+ response = requests.get(url, allow_redirects=False)
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
+
+ if response.status_code == 302: # Expecting a redirect
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ else:
+ print(f"Unexpected status code: {response.status_code}")
+ return
+
+ with open(filename, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"Downloaded {filename} successfully.")
+
+ except requests.exceptions.RequestException as e:
+ print(f"Error downloading file: {e}")
+
+
+def create_image_mesh(
+ *image_data: np.ndarray,
+ mask: np.ndarray = None,
+ triangulate: bool = False,
+ return_vertex_indices: bool = False,
+) -> Tuple[np.ndarray, ...]:
+ """
+ Create a mesh from image data using pixel coordinates as vertices and grid connections as faces.
+
+ Args:
+ *image_data (np.ndarray): Image arrays with shape (height, width, [channels])
+ mask (np.ndarray, optional): Boolean mask with shape (height, width). Defaults to None.
+ triangulate (bool): Convert quad faces to triangular faces. Defaults to False.
+ return_vertex_indices (bool): Include vertex indices in output. Defaults to False.
+
+ Returns:
+ faces (np.ndarray): Face connectivity array. Shape (N, 4) for quads or (N, 3) for triangles
+ *vertex_data (np.ndarray): Vertex attributes corresponding to input image_data
+ vertex_indices (np.ndarray, optional): Original vertex indices if return_vertex_indices=True
+ """
+ # Validate inputs
+ assert (len(image_data) > 0) or (mask is not None), "Need at least one image or mask"
+
+ if mask is None:
+ height, width = image_data[0].shape[:2]
+ else:
+ height, width = mask.shape
+
+ # Check all images have same dimensions
+ for img in image_data:
+ assert img.shape[:2] == (height, width), "All images must have same height and width"
+
+ # Create quad faces connecting neighboring pixels
+ base_quad = np.stack([
+ np.arange(0, width - 1, dtype=np.int32), # bottom-left
+ np.arange(width, 2 * width - 1, dtype=np.int32), # top-left
+ np.arange(1 + width, 2 * width, dtype=np.int32), # top-right
+ np.arange(1, width, dtype=np.int32), # bottom-right
+ ], axis=1)
+
+ # Replicate quad pattern for all rows
+ row_offsets = np.arange(0, (height - 1) * width, width, dtype=np.int32)
+ faces = (row_offsets[:, None, None] + base_quad[None, :, :]).reshape((-1, 4))
+
+ if mask is None:
+ # No masking - use all faces and vertices
+ if triangulate:
+ faces = _convert_quads_to_triangles(faces)
+
+ output = [faces]
+ for img in image_data:
+ output.append(img.reshape(-1, *img.shape[2:]))
+
+ if return_vertex_indices:
+ output.append(np.arange(height * width, dtype=np.int32))
+
+ return tuple(output)
+ else:
+ # Apply mask - only keep faces where all 4 corners are valid
+ valid_quads = (
+ mask[:-1, :-1] & mask[1:, :-1] &
+ mask[1:, 1:] & mask[:-1, 1:]
+ ).ravel()
+ faces = faces[valid_quads]
+
+ if triangulate:
+ faces = _convert_quads_to_triangles(faces)
+
+ # Remove unused vertices and remap face indices
+ num_face_vertices = faces.shape[-1]
+ unique_vertices, remapped_indices = np.unique(faces, return_inverse=True)
+ faces = remapped_indices.astype(np.int32).reshape(-1, num_face_vertices)
+
+ output = [faces]
+ for img in image_data:
+ flattened_img = img.reshape(-1, *img.shape[2:])
+ output.append(flattened_img[unique_vertices])
+
+ if return_vertex_indices:
+ output.append(unique_vertices)
+
+ return tuple(output)
+
+
+def _convert_quads_to_triangles(quad_faces: np.ndarray) -> np.ndarray:
+ """Convert quadrilateral faces to triangular faces."""
+ if quad_faces.shape[-1] == 3:
+ return quad_faces # Already triangular
+
+ num_vertices_per_face = quad_faces.shape[-1]
+ triangle_indices = np.stack([
+ np.zeros(num_vertices_per_face - 2, dtype=int), # First vertex
+ np.arange(1, num_vertices_per_face - 1, dtype=int), # Sequential vertices
+ np.arange(2, num_vertices_per_face, dtype=int), # Next sequential vertices
+ ], axis=1)
+
+ return quad_faces[:, triangle_indices].reshape((-1, 3))
+
+
+def convert_predictions_to_glb_scene(
+ predictions,
+ filter_by_frames="all",
+ show_camera=True,
+ mask_sky_bg=False,
+ mask_ambiguous=False,
+ as_mesh=True,
+) -> trimesh.Scene:
+ """
+ Converts model predictions to a 3D scene represented as a GLB file.
+
+ Args:
+ predictions (dict): Dictionary containing model predictions with keys:
+ - world_points: 3D point coordinates (S, H, W, 3)
+ - images: Input images (S, H, W, 3)
+ - camera_poses: Camera extrinsic matrices (S, 3, 4)
+ filter_by_frames (str): Frame filter specification (default: "all")
+ show_camera (bool): Include camera visualization (default: True)
+ mask_sky_bg (bool): Mask out sky background pixels (default: False)
+ mask_ambiguous (bool): Apply final mask to filter ambiguous predictions (default: False)
+ as_mesh (bool): Represent the data as a mesh instead of point cloud (default: False)
+
+ Returns:
+ trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras
+
+ Raises:
+ ValueError: If input predictions structure is invalid
+ """
+ if not isinstance(predictions, dict):
+ raise ValueError("predictions must be a dictionary")
+
+ print("Building GLB scene")
+
+ # Parse frame selection from filter string
+ target_frame_index = None
+ if filter_by_frames not in ["all", "All"]:
+ try:
+ # Extract numeric index before colon separator
+ target_frame_index = int(filter_by_frames.split(":")[0])
+ except (ValueError, IndexError):
+ pass
+
+ # Validate required data in predictions
+ print("Using Pointmap Branch")
+ if "world_points" not in predictions:
+ raise ValueError(
+ "world_points not found in predictions. Pointmap Branch requires 'world_points' key. "
+ "Depthmap and Camera branches have been removed."
+ )
+
+ # Extract prediction data
+ point_cloud_3d = predictions["world_points"]
+ input_images = predictions["images"]
+ extrinsic_matrices = predictions["camera_poses"]
+ ambiguity_mask = predictions["final_mask"]
+ sky_region_mask = predictions["sky_mask"]
+
+ # Filter to single frame if specified
+ if target_frame_index is not None:
+ point_cloud_3d = point_cloud_3d[target_frame_index][None]
+ input_images = input_images[target_frame_index][None]
+ extrinsic_matrices = extrinsic_matrices[target_frame_index][None]
+ ambiguity_mask = ambiguity_mask[target_frame_index][None]
+ sky_region_mask = sky_region_mask[target_frame_index][None]
+
+ # Flatten 3D points to vertex array
+ flattened_vertices = point_cloud_3d.reshape(-1, 3)
+
+ # Convert images to RGB color array
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
+ rgb_colors = np.transpose(input_images, (0, 2, 3, 1))
+ else: # Already in NHWC format
+ rgb_colors = input_images
+ rgb_colors = (rgb_colors.reshape(-1, 3) * 255).astype(np.uint8)
+
+ # Build composite filtering mask
+ valid_points_mask = np.ones(len(flattened_vertices), dtype=bool)
+
+ # Apply ambiguity filtering if requested
+ if mask_ambiguous:
+ flat_ambiguity_mask = ambiguity_mask.reshape(-1)
+ valid_points_mask = valid_points_mask & flat_ambiguity_mask
+
+ # Apply sky region filtering if requested
+ if mask_sky_bg:
+ flat_sky_mask = sky_region_mask.reshape(-1)
+ valid_points_mask = valid_points_mask & flat_sky_mask
+
+ # Apply mask to filter vertices and colors
+ filtered_vertices = flattened_vertices[valid_points_mask].copy()
+ filtered_colors = rgb_colors[valid_points_mask].copy()
+
+ # Handle empty geometry case
+ if filtered_vertices is None or np.asarray(filtered_vertices).size == 0:
+ filtered_vertices = np.array([[1, 0, 0]])
+ filtered_colors = np.array([[255, 255, 255]])
+ scene_scale_factor = 1
+ else:
+ # Compute scene scale from percentile-based bounding box
+ percentile_lower = np.percentile(filtered_vertices, 5, axis=0)
+ percentile_upper = np.percentile(filtered_vertices, 95, axis=0)
+ scene_scale_factor = np.linalg.norm(percentile_upper - percentile_lower)
+
+ # Initialize color mapping for cameras
+ color_palette = matplotlib.colormaps.get_cmap("gist_rainbow")
+
+ # Create empty 3D scene container
+ output_scene = trimesh.Scene()
+
+ # Add geometry to scene based on representation type
+ if as_mesh:
+ # Mesh representation
+ if target_frame_index is not None:
+ # Single frame mesh generation
+ frame_height, frame_width = point_cloud_3d.shape[1:3]
+
+ # Prepare unfiltered data for mesh construction
+ structured_points = point_cloud_3d.reshape(frame_height, frame_width, 3)
+
+ # Convert image data to proper format
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
+ structured_colors = np.transpose(input_images[0], (1, 2, 0))
+ else: # Already in HWC format
+ structured_colors = input_images[0]
+ structured_colors *= 255
+
+ # Get structured mask for mesh creation
+ structured_mask = predictions["final_mask"][target_frame_index].reshape(
+ frame_height, frame_width
+ )
+
+ # Build filtering mask
+ mesh_filter_mask = structured_mask
+
+ # Check for normal data availability
+ mesh_normals = None
+ if "normal" in predictions and predictions["normal"] is not None:
+ # Extract normals for selected frame
+ frame_normal_data = (
+ predictions["normal"][target_frame_index]
+ if target_frame_index is not None
+ else predictions["normal"][0]
+ )
+
+ # Generate mesh with normal information
+ mesh_faces, mesh_vertices, mesh_colors, mesh_normals = create_image_mesh(
+ structured_points * np.array([1, -1, 1], dtype=np.float32),
+ structured_colors / 255.0,
+ frame_normal_data * np.array([1, -1, 1], dtype=np.float32),
+ mask=mesh_filter_mask,
+ triangulate=True,
+ return_vertex_indices=False,
+ )
+
+ # Apply coordinate system transformation to normals
+ mesh_normals = mesh_normals * np.array([1, -1, 1], dtype=np.float32)
+ else:
+ # Generate mesh without normal information
+ mesh_faces, mesh_vertices, mesh_colors = create_image_mesh(
+ structured_points * np.array([1, -1, 1], dtype=np.float32),
+ structured_colors / 255.0,
+ mask=mesh_filter_mask,
+ triangulate=True,
+ return_vertex_indices=False,
+ )
+
+ # Construct trimesh object with optional normals
+ geometry_mesh = trimesh.Trimesh(
+ vertices=mesh_vertices * np.array([1, -1, 1], dtype=np.float32),
+ faces=mesh_faces,
+ vertex_colors=(mesh_colors * 255).astype(np.uint8),
+ vertex_normals=(mesh_normals if mesh_normals is not None else None),
+ process=False,
+ )
+ output_scene.add_geometry(geometry_mesh)
+ else:
+ # Multi-frame mesh generation
+ print("Creating mesh for multi-frame data...")
+
+ for frame_idx in range(point_cloud_3d.shape[0]):
+ frame_height, frame_width = point_cloud_3d.shape[1:3]
+
+ # Extract per-frame data
+ frame_point_data = point_cloud_3d[frame_idx]
+ frame_ambiguity_mask = predictions["final_mask"][frame_idx]
+ frame_sky_mask = predictions["sky_mask"][frame_idx]
+
+ # Extract frame image data
+ if input_images.ndim == 4 and input_images.shape[1] == 3: # NCHW format
+ frame_image_data = np.transpose(input_images[frame_idx], (1, 2, 0))
+ else: # Already in HWC format
+ frame_image_data = input_images[frame_idx]
+ frame_image_data *= 255
+
+ # Build per-frame filtering mask
+ frame_filter_mask = np.ones((frame_height, frame_width), dtype=bool)
+
+ # Apply ambiguity filtering if enabled
+ if mask_ambiguous:
+ frame_filter_mask = frame_filter_mask & frame_ambiguity_mask
+
+ # Apply sky filtering if enabled
+ if mask_sky_bg:
+ frame_filter_mask = frame_filter_mask & frame_sky_mask
+
+ # Generate mesh for current frame
+ frame_faces, frame_vertices, frame_colors = create_image_mesh(
+ frame_point_data * np.array([1, -1, 1], dtype=np.float32),
+ frame_image_data / 255.0,
+ mask=frame_filter_mask,
+ triangulate=True,
+ return_vertex_indices=False,
+ )
+
+ frame_vertices = frame_vertices * np.array([1, -1, 1], dtype=np.float32)
+
+ # Create trimesh object for current frame
+ frame_geometry = trimesh.Trimesh(
+ vertices=frame_vertices,
+ faces=frame_faces,
+ vertex_colors=(frame_colors * 255).astype(np.uint8),
+ process=False,
+ )
+ output_scene.add_geometry(frame_geometry)
+ else:
+ # Point cloud representation
+ point_cloud_geometry = trimesh.PointCloud(vertices=filtered_vertices, colors=filtered_colors)
+ output_scene.add_geometry(point_cloud_geometry)
+
+ # Add camera visualizations if requested
+ num_camera_views = len(extrinsic_matrices)
+
+ if show_camera:
+ # Iterate through all camera views
+ for camera_idx in range(num_camera_views):
+ camera_extrinsic = extrinsic_matrices[camera_idx]
+ camera_color_rgba = color_palette(camera_idx / num_camera_views)
+ camera_color_rgb = tuple(int(255 * x) for x in camera_color_rgba[:3])
+
+ integrate_camera_into_scene(
+ output_scene, camera_extrinsic, camera_color_rgb, scene_scale_factor
+ )
+
+ # Define coordinate system transformation matrices
+ opengl_transform = np.eye(4)
+ opengl_transform[1, 1] = -1 # Flip Y axis
+ opengl_transform[2, 2] = -1 # Flip Z axis
+
+ # Define alignment rotation (180 degrees around Y-axis)
+ alignment_rotation = np.eye(4)
+ alignment_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix()
+
+ # Compute and apply final transformation
+ scene_transformation = (
+ np.linalg.inv(extrinsic_matrices[0])
+ @ opengl_transform
+ @ alignment_rotation
+ )
+ output_scene.apply_transform(scene_transformation)
+
+ print("GLB Scene built")
+ return output_scene
+
+def integrate_camera_into_scene(
+ scene: trimesh.Scene,
+ camera_transform: np.ndarray,
+ camera_color: tuple,
+ scale_factor: float,
+):
+ """
+ Adds a camera visualization mesh to the 3D scene.
+
+ Args:
+ scene (trimesh.Scene): The 3D scene to add the camera visualization.
+ camera_transform (np.ndarray): 4x4 transformation matrix for camera positioning.
+ camera_color (tuple): RGB color tuple for the camera mesh.
+ scale_factor (float): Scaling factor for the camera size relative to scene.
+ """
+ # Define camera dimensions based on scene scale
+ camera_base_width = scale_factor * 0.05
+ camera_cone_height = scale_factor * 0.1
+
+ # Create base cone geometry for camera representation
+ base_cone = trimesh.creation.cone(camera_base_width, camera_cone_height, sections=4)
+
+ # Setup rotation transformation (45 degrees around z-axis)
+ z_rotation_matrix = np.eye(4)
+ z_rotation_matrix[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
+ z_rotation_matrix[2, 3] = -camera_cone_height
+
+ # Setup OpenGL coordinate system conversion
+ opengl_coord_transform = np.eye(4)
+ opengl_coord_transform[1, 1] = -1 # Flip Y axis
+ opengl_coord_transform[2, 2] = -1 # Flip Z axis
+
+ # Combine all transformations
+ final_transform = camera_transform @ opengl_coord_transform @ z_rotation_matrix
+
+ # Create slight rotation for mesh variation
+ minor_rotation = np.eye(4)
+ minor_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
+
+ # Generate multiple vertex sets for complex camera geometry
+ original_vertices = base_cone.vertices
+ scaled_vertices = 0.95 * original_vertices
+ rotated_vertices = apply_transformation_to_points(minor_rotation, original_vertices)
+
+ # Combine all vertex sets
+ all_vertices = np.concatenate([
+ original_vertices,
+ scaled_vertices,
+ rotated_vertices
+ ])
+
+ # Transform vertices to final position
+ transformed_vertices = apply_transformation_to_points(final_transform, all_vertices)
+
+ # Generate faces for the complete camera mesh
+ camera_faces = generate_camera_mesh_faces(base_cone)
+
+ # Create and configure the camera mesh
+ camera_mesh = trimesh.Trimesh(
+ vertices=transformed_vertices,
+ faces=camera_faces
+ )
+ camera_mesh.visual.face_colors[:, :3] = camera_color
+
+ # Add the camera mesh to the scene
+ scene.add_geometry(camera_mesh)
+
+
+def apply_transformation_to_points(
+ transform_matrix: np.ndarray, point_array: np.ndarray, output_dim: int = None
+) -> np.ndarray:
+ """
+ Applies a 4x4 transformation matrix to a collection of 3D points.
+
+ Args:
+ transform_matrix (np.ndarray): 4x4 transformation matrix to apply.
+ point_array (np.ndarray): Array of points to transform.
+ output_dim (int, optional): Target dimension for output points.
+
+ Returns:
+ np.ndarray: Array of transformed points.
+ """
+ point_array = np.asarray(point_array)
+ original_shape = point_array.shape[:-1]
+ target_dim = output_dim or point_array.shape[-1]
+
+ # Transpose transformation matrix for matrix multiplication
+ transposed_transform = transform_matrix.swapaxes(-1, -2)
+
+ # Apply rotation/scaling and translation components
+ transformed_points = (
+ point_array @ transposed_transform[..., :-1, :] +
+ transposed_transform[..., -1:, :]
+ )
+
+ # Extract desired dimensions and restore original shape
+ final_result = transformed_points[..., :target_dim].reshape(*original_shape, target_dim)
+ return final_result
+
+
+def generate_camera_mesh_faces(base_cone_mesh: trimesh.Trimesh) -> np.ndarray:
+ """
+ Generates face indices for a complex camera mesh composed of multiple cone layers.
+
+ Args:
+ base_cone_mesh (trimesh.Trimesh): Base cone geometry used as template.
+
+ Returns:
+ np.ndarray: Array of face indices defining the camera mesh topology.
+ """
+ face_indices = []
+ vertex_count_per_cone = len(base_cone_mesh.vertices)
+
+ # Process each face of the base cone
+ for triangle_face in base_cone_mesh.faces:
+ # Skip faces that include the cone tip (vertex 0)
+ if 0 in triangle_face:
+ continue
+
+ # Get vertex indices for current triangle
+ vertex_a, vertex_b, vertex_c = triangle_face
+
+ # Calculate corresponding vertices in second and third cone layers
+ vertex_a_layer2, vertex_b_layer2, vertex_c_layer2 = triangle_face + vertex_count_per_cone
+ vertex_a_layer3, vertex_b_layer3, vertex_c_layer3 = triangle_face + 2 * vertex_count_per_cone
+
+ # Create connecting faces between cone layers
+ connecting_faces = [
+ (vertex_a, vertex_b, vertex_b_layer2),
+ (vertex_a, vertex_a_layer2, vertex_c),
+ (vertex_c_layer2, vertex_b, vertex_c),
+ (vertex_a, vertex_b, vertex_b_layer3),
+ (vertex_a, vertex_a_layer3, vertex_c),
+ (vertex_c_layer3, vertex_b, vertex_c),
+ ]
+
+ face_indices.extend(connecting_faces)
+
+ # Add reverse-winding faces for proper mesh closure
+ reversed_faces = [(vertex_c, vertex_b, vertex_a) for vertex_a, vertex_b, vertex_c in face_indices]
+ face_indices.extend(reversed_faces)
+
+ return np.array(face_indices)
+
+
diff --git a/src/utils/warnings.py b/src/utils/warnings.py
new file mode 100644
index 0000000000000000000000000000000000000000..8422416bac8ba5893f6a50f2b32125f4f9ab65bb
--- /dev/null
+++ b/src/utils/warnings.py
@@ -0,0 +1,41 @@
+"""
+Wrapper utilities for warnings.
+"""
+
+import warnings
+from functools import wraps
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = "ignore", **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)