diff --git a/.gitignore b/.gitignore
index ed8ebf5..82c3be6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,22 @@
-__pycache__
\ No newline at end of file
+__pycache__
+*.pyc
+*.py~
+*.swp
+*.swo
+*~
+
+# Output directories
+visualization/
+outputs/
+*.ply
+*.glb
+*.obj
+
+# IDE
+.vscode/
+.idea/
+*.code-workspace
+
+# OS
+.DS_Store
+Thumbs.db
\ No newline at end of file
diff --git a/README.md b/README.md
index 59a01fa..e062430 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
+
# SAM 3D
-SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you’re looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
+SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you're looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
# SAM 3D Objects
@@ -67,6 +68,140 @@ For more details and multi-object reconstruction, please take a look at our two
* [single object](notebook/demo_single_object.ipynb)
* [multi object](notebook/demo_multi_object.ipynb)
+## Multi-View 3D Reconstruction
+
+This contribution adds **training-free multi-view 3D reconstruction** to SAM 3D Objects via a multidiffusion approach: it generates consistent 3D models from multiple images of the same object captured from different viewpoints, without retraining the model.
+
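+At each diffusion step, the multidiffusion mode runs the denoising model once per view and averages the per-view predictions, which is what keeps the reconstruction consistent across views. A minimal sketch of the fusion step (the names here are illustrative; the actual logic lives in `sam3d_objects/pipeline/multi_view_utils.py`):
+
+```python
+import torch
+
+def multidiffusion_step(denoiser, x_t, t, view_conditions):
+    """Run the denoiser once per view condition and average the predictions."""
+    preds = [denoiser(x_t, t, cond) for cond in view_conditions]
+    return torch.stack(preds).mean(dim=0)
+```
+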
+### Results Comparison
+
+The following comparison demonstrates the improvement of multi-view reconstruction over single-view reconstruction:
+
+| Single-View (View 3) | Single-View (View 6) | Multi-View (All 8 Views) |
+| :---: | :---: | :---: |
+| **Input Image**<br>![View 3 input](data/example/images/3.png) | **Input Image**<br>![View 6 input](data/example/images/6.png) | **Input Images**<br>![View 1](data/example/images/1.png) ![View 2](data/example/images/2.png) ![View 3](data/example/images/3.png) ![View 4](data/example/images/4.png)<br>![View 5](data/example/images/5.png) ![View 6](data/example/images/6.png) ![View 7](data/example/images/7.png) ![View 8](data/example/images/8.png) |
+| ↓ 3D Reconstruction ↓ | ↓ 3D Reconstruction ↓ | ↓ 3D Reconstruction ↓ |
+| **3D Result**<br>![3D result from view 3](data/example/visualization_results/view3_cropped.gif) | **3D Result**<br>![3D result from view 6](data/example/visualization_results/view6_cropped.gif) | **3D Result**<br>![3D result from all views](data/example/visualization_results/all_views_cropped.gif) |
+| **Analysis:** Due to occlusion in the input image, the red collar on the dog is not visible, resulting in its absence in the generated 3D model. | **Analysis:** Many frontal parts of the dog are occluded or not visible from this angle, leading to structural errors in the front-facing regions of the generated model. | **Analysis:** By combining information from all 8 views, the multi-view reconstruction produces a complete and accurate 3D model that closely matches the actual object. |
+
+### Quick Start
+
+Use the `run_inference.py` script for both single-view and multi-view reconstruction:
+
+```bash
+# Multi-view reconstruction (mask_prompt=None, images and masks in same directory)
+python run_inference.py --input_path ./data/images_and_masks
+
+# Single-view reconstruction (specify a single image name)
+python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+# Multi-view reconstruction (mask_prompt!=None, images in images/, masks in {mask_prompt}/)
+python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+# Specify multiple image names (can be any filename without extension)
+python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+```
+
+### Data Structure
+
+Multi-view data can be organized in two ways:
+
+**Structure 1** (when `mask_prompt=None`): Images and masks in the same directory
+```
+input_path/
+ ├── 1.png # Original image (PNG format)
+ ├── 1_mask.png # Mask (RGBA format, alpha channel stores mask info)
+ ├── 2.png
+ ├── 2_mask.png
+ └── ...
+```
+
+**Structure 2** (when `mask_prompt!=None`, e.g., `mask_prompt="stuffed_toy"`): Images and masks in separate directories
+```
+input_path/
+ ├── images/
+ │ ├── 1.png
+ │ ├── 2.png
+ │ └── ...
+ └── stuffed_toy/ (or {mask_prompt}/)
+ ├── 1.png (or 1_mask.png)
+ ├── 2.png (or 2_mask.png)
+ └── ...
+```
+
+**Mask Format**: RGBA format where the alpha channel stores mask information (alpha=255 for object, alpha=0 for background).
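+
+For reference, a mask in this format can be written with Pillow (a minimal sketch; the array values are illustrative):
+
+```python
+import numpy as np
+from PIL import Image
+
+mask = np.zeros((512, 512), dtype=bool)  # your boolean object mask
+mask[128:384, 128:384] = True
+
+# Pack the mask into the alpha channel of an RGBA PNG
+rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
+rgba[..., 3] = mask.astype(np.uint8) * 255  # alpha=255 for object, 0 for background
+Image.fromarray(rgba).save("1_mask.png")
+```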
+
+### Command Line Options
+
+Run `python run_inference.py --help` for full documentation. Key parameters:
+
+- `--input_path`: Path to input directory (required)
+- `--mask_prompt`: Mask folder name. If None, images and masks are in the same directory; if specified, images are in `input_path/images/` and masks are in `input_path/{mask_prompt}/`
+- `--image_names`: Image names (without extension), e.g., `"image1,view_a"` or `"1,2"` or `"image1"`. Can specify multiple, comma-separated. If not specified, uses all available images
+- `--decode_formats`: Output formats, e.g., `"gaussian,mesh"` or `"gaussian"` (default: `gaussian,mesh`)
+- `--seed`: Random seed (default: 42)
+- `--stage1_steps`: Stage 1 inference steps (default: 50)
+- `--stage2_steps`: Stage 2 inference steps (default: 25)
+- `--model_tag`: Model tag (default: hf)
+
+The script automatically detects whether to use single-view or multi-view inference based on the number of views provided. Multi-view reconstruction uses a training-free multidiffusion approach to fuse predictions from all views.
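+
+The same multi-view pipeline can also be driven from Python (a minimal sketch based on `run_inference.py`; paths and parameters are illustrative):
+
+```python
+import sys
+from pathlib import Path
+
+sys.path.append("notebook")
+from inference import Inference
+from load_images_and_masks import load_images_and_masks_from_path
+
+# Load the pipeline and the example multi-view data shipped in data/example/
+inference = Inference("checkpoints/hf/pipeline.yaml", compile=False)
+images, masks = load_images_and_masks_from_path(Path("data/example"), mask_prompt="stuffed_toy")
+
+# Fuse all views with the training-free multidiffusion mode
+result = inference._pipeline.run_multi_view(
+    view_images=images,
+    view_masks=masks,
+    seed=42,
+    mode="multidiffusion",
+    decode_formats=["gaussian", "mesh"],
+    with_mesh_postprocess=False,
+    with_texture_baking=False,
+    use_vertex_color=True,
+)
+if result.get("glb") is not None:
+    result["glb"].export("result.glb")
+```
+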
## SAM 3D Body
diff --git a/data/example/images/1.png b/data/example/images/1.png
new file mode 100644
index 0000000..0cd8986
Binary files /dev/null and b/data/example/images/1.png differ
diff --git a/data/example/images/2.png b/data/example/images/2.png
new file mode 100644
index 0000000..015325c
Binary files /dev/null and b/data/example/images/2.png differ
diff --git a/data/example/images/3.png b/data/example/images/3.png
new file mode 100644
index 0000000..c4aeb80
Binary files /dev/null and b/data/example/images/3.png differ
diff --git a/data/example/images/4.png b/data/example/images/4.png
new file mode 100644
index 0000000..a1f7ef0
Binary files /dev/null and b/data/example/images/4.png differ
diff --git a/data/example/images/5.png b/data/example/images/5.png
new file mode 100644
index 0000000..5678d60
Binary files /dev/null and b/data/example/images/5.png differ
diff --git a/data/example/images/6.png b/data/example/images/6.png
new file mode 100644
index 0000000..4e99c6a
Binary files /dev/null and b/data/example/images/6.png differ
diff --git a/data/example/images/7.png b/data/example/images/7.png
new file mode 100644
index 0000000..af77b9b
Binary files /dev/null and b/data/example/images/7.png differ
diff --git a/data/example/images/8.png b/data/example/images/8.png
new file mode 100644
index 0000000..7e98bd0
Binary files /dev/null and b/data/example/images/8.png differ
diff --git a/data/example/stuffed_toy/1_mask.png b/data/example/stuffed_toy/1_mask.png
new file mode 100644
index 0000000..b4a0f8e
Binary files /dev/null and b/data/example/stuffed_toy/1_mask.png differ
diff --git a/data/example/stuffed_toy/2_mask.png b/data/example/stuffed_toy/2_mask.png
new file mode 100644
index 0000000..31dc219
Binary files /dev/null and b/data/example/stuffed_toy/2_mask.png differ
diff --git a/data/example/stuffed_toy/3_mask.png b/data/example/stuffed_toy/3_mask.png
new file mode 100644
index 0000000..85c00e2
Binary files /dev/null and b/data/example/stuffed_toy/3_mask.png differ
diff --git a/data/example/stuffed_toy/4_mask.png b/data/example/stuffed_toy/4_mask.png
new file mode 100644
index 0000000..48a1759
Binary files /dev/null and b/data/example/stuffed_toy/4_mask.png differ
diff --git a/data/example/stuffed_toy/5_mask.png b/data/example/stuffed_toy/5_mask.png
new file mode 100644
index 0000000..0e44e2c
Binary files /dev/null and b/data/example/stuffed_toy/5_mask.png differ
diff --git a/data/example/stuffed_toy/6_mask.png b/data/example/stuffed_toy/6_mask.png
new file mode 100644
index 0000000..41682e6
Binary files /dev/null and b/data/example/stuffed_toy/6_mask.png differ
diff --git a/data/example/stuffed_toy/7_mask.png b/data/example/stuffed_toy/7_mask.png
new file mode 100644
index 0000000..d310642
Binary files /dev/null and b/data/example/stuffed_toy/7_mask.png differ
diff --git a/data/example/stuffed_toy/8_mask.png b/data/example/stuffed_toy/8_mask.png
new file mode 100644
index 0000000..071dc8d
Binary files /dev/null and b/data/example/stuffed_toy/8_mask.png differ
diff --git a/data/example/visualization_results/all_views_cropped.gif b/data/example/visualization_results/all_views_cropped.gif
new file mode 100644
index 0000000..4df3097
Binary files /dev/null and b/data/example/visualization_results/all_views_cropped.gif differ
diff --git a/data/example/visualization_results/all_views_cropped.mp4 b/data/example/visualization_results/all_views_cropped.mp4
new file mode 100644
index 0000000..c15bed6
Binary files /dev/null and b/data/example/visualization_results/all_views_cropped.mp4 differ
diff --git a/data/example/visualization_results/view3_cropped.gif b/data/example/visualization_results/view3_cropped.gif
new file mode 100644
index 0000000..5650893
Binary files /dev/null and b/data/example/visualization_results/view3_cropped.gif differ
diff --git a/data/example/visualization_results/view3_cropped.mp4 b/data/example/visualization_results/view3_cropped.mp4
new file mode 100644
index 0000000..5c607b4
Binary files /dev/null and b/data/example/visualization_results/view3_cropped.mp4 differ
diff --git a/data/example/visualization_results/view6_cropped.gif b/data/example/visualization_results/view6_cropped.gif
new file mode 100644
index 0000000..9a64a9e
Binary files /dev/null and b/data/example/visualization_results/view6_cropped.gif differ
diff --git a/data/example/visualization_results/view6_cropped.mp4 b/data/example/visualization_results/view6_cropped.mp4
new file mode 100644
index 0000000..625b77e
Binary files /dev/null and b/data/example/visualization_results/view6_cropped.mp4 differ
diff --git a/demo.py b/demo.py
index befdd84..062a281 100644
--- a/demo.py
+++ b/demo.py
@@ -11,7 +11,7 @@
# load image (RGBA only, mask is embedded in the alpha channel)
image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
-mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
+mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=11)
# run model
output = inference(image, mask, seed=42)
diff --git a/notebook/load_images_and_masks.py b/notebook/load_images_and_masks.py
new file mode 100644
index 0000000..64de3a5
--- /dev/null
+++ b/notebook/load_images_and_masks.py
@@ -0,0 +1,305 @@
+"""
+Load multi-view data from specified path
+Supports two data structures:
+1. mask_prompt=None: All images and masks in the same directory, naming format: xxxx.png and xxxx_mask.png
+2. mask_prompt!=None: Images in input_path/images/, masks in input_path/{mask_prompt}/
+"""
+from pathlib import Path
+from typing import List, Optional, Tuple
+import numpy as np
+from PIL import Image
+from loguru import logger
+
+
+def load_image(path: Path) -> np.ndarray:
+ """Load image as numpy array"""
+ img = Image.open(path)
+ return np.array(img).astype(np.uint8)
+
+
+def load_mask_from_rgba(path: Path) -> np.ndarray:
+ """
+ Load mask from RGBA image (extract from alpha channel)
+
+ Args:
+ path: RGBA image file path
+
+ Returns:
+ mask: Binary mask, shape (H, W), bool format
+ """
+ img = Image.open(path)
+ img_array = np.array(img)
+
+ if img.mode == 'RGBA' and img_array.ndim == 3 and img_array.shape[2] >= 4:
+ mask = img_array[..., 3] > 0
+ elif img.mode == 'RGB':
+ logger.warning(f"Mask file {path} is RGB format, not RGBA. Using all pixels as mask.")
+ mask = np.ones((img_array.shape[0], img_array.shape[1]), dtype=bool)
+ else:
+ logger.warning(f"Unexpected image mode {img.mode} for mask file {path}")
+ mask = np.ones((img_array.shape[0], img_array.shape[1]), dtype=bool)
+
+ return mask
+
+
+def load_images_and_masks(
+ images_and_masks_dir: Path,
+ image_names: Optional[List[str]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """
+ Load multi-view data from images_and_masks folder
+
+ Data structure:
+ images_and_masks/
+ ├── 1.png (or image1.png, view_a.png, etc.)
+ ├── 1_mask.png (or image1_mask.png, view_a_mask.png)
+ ├── 2.png
+ ├── 2_mask.png
+ └── ...
+
+ Args:
+ images_and_masks_dir: Path to images_and_masks folder
+ image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"],
+ if None then auto-detect all
+
+ Returns:
+ images: List of images (numpy arrays)
+ masks: List of masks (numpy arrays, bool format)
+ """
+ if not images_and_masks_dir.exists():
+ raise FileNotFoundError(f"Directory does not exist: {images_and_masks_dir}")
+
+ if not images_and_masks_dir.is_dir():
+ raise ValueError(f"Path is not a directory: {images_and_masks_dir}")
+
+ if image_names is None:
+ image_files = sorted(images_and_masks_dir.glob("*.png")) + sorted(images_and_masks_dir.glob("*.jpg"))
+ image_files = [f for f in image_files if "_mask" not in f.name]
+ image_names = [f.stem for f in image_files]
+ logger.info(f"Auto-detected {len(image_names)} images: {image_names}")
+
+ images = []
+ masks = []
+
+ for image_name in image_names:
+ image_candidates = [
+ images_and_masks_dir / f"{image_name}.png",
+ images_and_masks_dir / f"{image_name}.jpg",
+ ]
+
+ mask_candidates = [
+ images_and_masks_dir / f"{image_name}_mask.png",
+ images_and_masks_dir / f"{image_name}_mask.jpg",
+ ]
+
+ image_path = None
+ for candidate in image_candidates:
+ if candidate.exists():
+ image_path = candidate
+ break
+
+ mask_path = None
+ for candidate in mask_candidates:
+ if candidate.exists():
+ mask_path = candidate
+ break
+
+ if image_path is None:
+ logger.warning(f"Image file not found for '{image_name}', skipping")
+ continue
+
+ if mask_path is None:
+ logger.warning(f"Mask file not found for '{image_name}', skipping")
+ continue
+
+ try:
+ image = load_image(image_path)
+ mask = load_mask_from_rgba(mask_path)
+
+ images.append(image)
+ masks.append(mask)
+
+ logger.info(f"Loaded '{image_name}': image={image.shape}, mask={mask.shape}")
+
+ except Exception as e:
+ logger.error(f"Failed to load '{image_name}': {e}")
+ continue
+
+ if len(images) == 0:
+ raise ValueError(f"No valid images and masks found in {images_and_masks_dir}")
+
+ logger.info(f"Successfully loaded {len(images)} images")
+ return images, masks
+
+
+def load_from_segmentation_structure(
+ segmentation_base_dir: Path,
+ prompt: Optional[str] = None,
+ view_indices: Optional[List[int]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """
+ Load data from complete segmentation structure (kept for backward compatibility)
+
+ Structure:
+ visualization/
+ └── {folder_name}_segmentation/
+ └── {prompt}/
+ └── images_and_masks/
+
+ Args:
+ segmentation_base_dir: Segmentation base directory (e.g., visualization/joy_segmentation)
+ prompt: Prompt subfolder name (e.g., "stuffed_toy"), if None then auto-detect
+ view_indices: List of view indices, if None then auto-detect all
+
+ Returns:
+ images: List of images
+ masks: List of masks
+ """
+ if prompt:
+ segmentation_dir = segmentation_base_dir / prompt
+ else:
+ prompt_dirs = [d for d in segmentation_base_dir.iterdir()
+ if d.is_dir() and d.name != "all_masks"]
+ if len(prompt_dirs) == 0:
+ raise ValueError(f"No prompt subdirectories found in {segmentation_base_dir}")
+ elif len(prompt_dirs) == 1:
+ segmentation_dir = prompt_dirs[0]
+ logger.info(f"Auto-detected prompt directory: {segmentation_dir.name}")
+ else:
+ raise ValueError(
+ f"Multiple prompt directories found in {segmentation_base_dir}. "
+ f"Please specify --prompt. Found: {[d.name for d in prompt_dirs]}"
+ )
+
+ images_and_masks_dir = segmentation_dir / "images_and_masks"
+
+ if not images_and_masks_dir.exists():
+ raise FileNotFoundError(
+ f"images_and_masks directory not found: {images_and_masks_dir}"
+ )
+
+ if view_indices is not None and len(view_indices) > 0 and isinstance(view_indices[0], int):
+ image_names = [str(idx) for idx in view_indices]
+ else:
+ image_names = view_indices
+ return load_images_and_masks(images_and_masks_dir, image_names=image_names)
+
+
+def load_images_and_masks_from_path(
+ input_path: Path,
+ mask_prompt: Optional[str] = None,
+ image_names: Optional[List[str]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """
+ Load multi-view data from specified path (supports two data structures)
+
+ Data structure 1 (mask_prompt=None):
+ input_path/
+ ├── 1.png
+ ├── 1_mask.png
+ ├── 2.png
+ ├── 2_mask.png
+ └── ...
+
+ Data structure 2 (mask_prompt!=None, e.g., mask_prompt="stuffed_toy"):
+ input_path/
+ ├── images/
+ │ ├── 1.png
+ │ ├── 2.png
+ │ └── ...
+ └── stuffed_toy/ (or {mask_prompt}/)
+ ├── 1.png (or 1_mask.png)
+ ├── 2.png (or 2_mask.png)
+ └── ...
+
+ Args:
+ input_path: Input path
+ mask_prompt: Mask folder name, if None then images and masks are in the same directory
+ image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"],
+ if None then auto-detect all
+
+ Returns:
+ images: List of images
+ masks: List of masks
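+
+    Example (a sketch using the example data shipped in data/example/):
+        images, masks = load_images_and_masks_from_path(
+            Path("data/example"), mask_prompt="stuffed_toy"
+        )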
+ """
+ if not input_path.exists():
+ raise FileNotFoundError(f"Input path does not exist: {input_path}")
+
+ if not input_path.is_dir():
+ raise ValueError(f"Input path is not a directory: {input_path}")
+
+ if mask_prompt is None:
+ logger.info(f"Loading from single directory: {input_path}")
+ return load_images_and_masks(input_path, image_names=image_names)
+ else:
+ images_dir = input_path / "images"
+ masks_dir = input_path / mask_prompt
+
+ if not images_dir.exists():
+ raise FileNotFoundError(f"Images directory does not exist: {images_dir}")
+ if not masks_dir.exists():
+ raise FileNotFoundError(f"Mask directory does not exist: {masks_dir}")
+
+ logger.info(f"Loading images from: {images_dir}")
+ logger.info(f"Loading masks from: {masks_dir}")
+
+ if image_names is None:
+ image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
+ image_names = [f.stem for f in image_files]
+ logger.info(f"Auto-detected {len(image_names)} images: {image_names}")
+
+ images = []
+ masks = []
+
+ for image_name in image_names:
+ image_candidates = [
+ images_dir / f"{image_name}.png",
+ images_dir / f"{image_name}.jpg",
+ ]
+
+ mask_candidates = [
+ masks_dir / f"{image_name}.png",
+ masks_dir / f"{image_name}_mask.png",
+ masks_dir / f"{image_name}.jpg",
+ masks_dir / f"{image_name}_mask.jpg",
+ ]
+
+ image_path = None
+ for candidate in image_candidates:
+ if candidate.exists():
+ image_path = candidate
+ break
+
+ mask_path = None
+ for candidate in mask_candidates:
+ if candidate.exists():
+ mask_path = candidate
+ break
+
+ if image_path is None:
+ logger.warning(f"Image file not found for '{image_name}', skipping")
+ continue
+
+ if mask_path is None:
+ logger.warning(f"Mask file not found for '{image_name}', skipping")
+ continue
+
+ try:
+ image = load_image(image_path)
+ mask = load_mask_from_rgba(mask_path)
+
+ images.append(image)
+ masks.append(mask)
+
+ logger.info(f"Loaded '{image_name}': image={image.shape}, mask={mask.shape}")
+
+ except Exception as e:
+ logger.error(f"Failed to load '{image_name}': {e}")
+ continue
+
+ if len(images) == 0:
+ raise ValueError(f"No valid images and masks found in {input_path}")
+
+ logger.info(f"Successfully loaded {len(images)} images")
+ return images, masks
+
diff --git a/run_inference.py b/run_inference.py
new file mode 100644
index 0000000..c06df2c
--- /dev/null
+++ b/run_inference.py
@@ -0,0 +1,321 @@
+"""
+SAM 3D Objects Inference Script
+Supports both single-view and multi-view 3D reconstruction
+
+Usage:
+ # Multi-view inference (mask_prompt=None, images and masks in same directory, use all images)
+ python run_inference.py --input_path ./data/images_and_masks
+
+ # Single-view inference (specify a single image name)
+ python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+ # Multi-view inference (mask_prompt!=None, images in images/, masks in specified folder)
+ python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+ # Specify multiple image names (can be any filename without extension)
+ python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+"""
+import sys
+import argparse
+from pathlib import Path
+from typing import List, Optional
+from loguru import logger
+
+# Import the inference helpers from the notebook directory
+sys.path.append("notebook")
+from inference import Inference
+from load_images_and_masks import load_images_and_masks_from_path
+
+
+def parse_image_names(image_names_str: Optional[str]) -> Optional[List[str]]:
+ """
+ Parse image names string
+
+ Args:
+ image_names_str: Image names string, e.g., "image1,view_a" or "1,2" or "image1"
+ Can be any filename (without extension) or numbers
+
+ Returns:
+ image_names: List of image names (without extension), None means use all available images
+ """
+ if image_names_str is None or image_names_str == "":
+ return None
+
+ names = [x.strip() for x in image_names_str.split(",") if x.strip()]
+ return names if names else None
+
+
+def get_output_dir(
+ input_path: Path,
+ mask_prompt: Optional[str] = None,
+ image_names: Optional[List[str]] = None,
+ is_single_view: bool = False
+) -> Path:
+ """
+ Create output directory based on input path and parameters
+
+ Args:
+ input_path: Input path
+ mask_prompt: Mask folder name (if using separated directory structure)
+ image_names: List of image names
+ is_single_view: Whether it's single-view inference
+
+ Returns:
+ output_dir: Path to visualization/{mask_prompt_or_dirname}_{image_names}/ directory
+ """
+ visualization_dir = Path("visualization")
+ visualization_dir.mkdir(exist_ok=True)
+
+ if mask_prompt:
+ dir_name = mask_prompt
+ else:
+ dir_name = input_path.name if input_path.is_dir() else input_path.parent.name
+
+ if is_single_view:
+ if image_names and len(image_names) == 1:
+ safe_name = image_names[0].replace("/", "_").replace("\\", "_")
+ dir_name = f"{dir_name}_{safe_name}"
+ else:
+ dir_name = f"{dir_name}_single"
+ elif image_names:
+ if len(image_names) == 1:
+ safe_name = image_names[0].replace("/", "_").replace("\\", "_")
+ dir_name = f"{dir_name}_{safe_name}"
+ else:
+ safe_names = [name.replace("/", "_").replace("\\", "_") for name in image_names]
+ dir_name = f"{dir_name}_{'_'.join(safe_names[:3])}"
+ if len(safe_names) > 3:
+ dir_name += f"_and_{len(safe_names)-3}_more"
+ else:
+ dir_name = f"{dir_name}_multiview"
+
+ output_dir = visualization_dir / dir_name
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ logger.info(f"Output directory: {output_dir}")
+ return output_dir
+
+
+def run_inference(
+ input_path: Path,
+ mask_prompt: Optional[str] = None,
+ image_names: Optional[List[str]] = None,
+ seed: int = 42,
+ stage1_steps: int = 50,
+ stage2_steps: int = 25,
+    decode_formats: Optional[List[str]] = None,
+ model_tag: str = "hf",
+):
+ """
+ Run inference
+
+ Args:
+ input_path: Input path
+ mask_prompt: Mask folder name, if None then images and masks are in the same directory
+ image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"],
+ None means use all available images
+ seed: Random seed
+ stage1_steps: Stage 1 inference steps
+ stage2_steps: Stage 2 inference steps
+ decode_formats: List of decode formats
+ model_tag: Model tag
+ """
+ config_path = f"checkpoints/{model_tag}/pipeline.yaml"
+ if not Path(config_path).exists():
+ raise FileNotFoundError(f"Model config file not found: {config_path}")
+
+ logger.info(f"Loading model: {config_path}")
+ inference = Inference(config_path, compile=False)
+
+ if hasattr(inference._pipeline, 'rendering_engine'):
+ if inference._pipeline.rendering_engine != "pytorch3d":
+ logger.warning(f"Rendering engine is set to {inference._pipeline.rendering_engine}, changing to pytorch3d")
+ inference._pipeline.rendering_engine = "pytorch3d"
+
+ logger.info(f"Loading data: {input_path}")
+ if mask_prompt:
+ logger.info(f"Mask prompt: {mask_prompt} (images in images/, masks in {mask_prompt}/)")
+ else:
+ logger.info("Mask prompt: None (images and masks in same directory)")
+
+ view_images, view_masks = load_images_and_masks_from_path(
+ input_path=input_path,
+ mask_prompt=mask_prompt,
+ image_names=image_names,
+ )
+
+ num_views = len(view_images)
+ logger.info(f"Successfully loaded {num_views} views")
+
+ is_single_view = num_views == 1
+
+ if is_single_view:
+ logger.info("Single-view inference mode")
+ image = view_images[0]
+ mask = view_masks[0] if view_masks else None
+ result = inference(image, mask, seed=seed)
+ else:
+ logger.info("Multi-view inference mode")
+ decode_formats = decode_formats or ["gaussian", "mesh"]
+
+ result = inference._pipeline.run_multi_view(
+ view_images=view_images,
+ view_masks=view_masks,
+ seed=seed,
+ mode="multidiffusion",
+ stage1_inference_steps=stage1_steps,
+ stage2_inference_steps=stage2_steps,
+ decode_formats=decode_formats,
+ with_mesh_postprocess=False,
+ with_texture_baking=False,
+ use_vertex_color=True,
+ )
+
+ output_dir = get_output_dir(input_path, mask_prompt, image_names, is_single_view)
+ saved_files = []
+
+ print(f"\n{'='*60}")
+ print(f"Inference completed!")
+ print(f"Generated coordinates: {result['coords'].shape[0] if 'coords' in result else 'N/A'}")
+ print(f"{'='*60}")
+
+ if 'glb' in result and result['glb'] is not None:
+ output_path = output_dir / "result.glb"
+ result['glb'].export(str(output_path))
+ saved_files.append("result.glb")
+ print(f"✓ GLB file saved to: {output_path}")
+
+ if 'gs' in result:
+ output_path = output_dir / "result.ply"
+ result['gs'].save_ply(str(output_path))
+ saved_files.append("result.ply")
+ print(f"✓ Gaussian Splatting (PLY) saved to: {output_path}")
+ elif 'gaussian' in result:
+ if isinstance(result['gaussian'], list) and len(result['gaussian']) > 0:
+ output_path = output_dir / "result.ply"
+ result['gaussian'][0].save_ply(str(output_path))
+ saved_files.append("result.ply")
+ print(f"✓ Gaussian Splatting (PLY) saved to: {output_path}")
+
+ if 'mesh' in result:
+ print(f"✓ Mesh information generated (included in GLB)")
+
+ print(f"\n{'='*60}")
+ print(f"All output files saved to: {output_dir}")
+ print(f"Saved files: {', '.join(saved_files)}")
+ print(f"{'='*60}")
+
+ print(f"\nFile descriptions:")
+ print(f"- PLY file: Gaussian Splatting format with position and color information")
+ print(f" * Recommended to use specialized Gaussian Splatting viewers")
+ print(f"- GLB file: Complete 3D mesh model, can be viewed in Blender, Three.js, etc.")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="SAM 3D Objects Inference Script - Supports single-view and multi-view 3D reconstruction",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Multi-view inference (mask_prompt=None, images and masks in same directory, use all images)
+ python run_inference.py --input_path ./data/images_and_masks
+
+ # Single-view inference (specify a single image name)
+ python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+ # Multi-view inference (mask_prompt!=None, images in images/, masks in specified folder)
+ python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+ # Specify multiple image names (can be any filename without extension)
+ python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+ """
+ )
+
+ parser.add_argument(
+ "--input_path",
+ type=str,
+ required=True,
+ help="Input path. If mask_prompt=None, images and masks are in this directory; "
+ "if mask_prompt!=None, images are in input_path/images/, masks in input_path/{mask_prompt}/"
+ )
+ parser.add_argument(
+ "--mask_prompt",
+ type=str,
+ default=None,
+ help="Mask folder name. If None, images and masks are in the same directory "
+ "(naming format: xxxx.png and xxxx_mask.png); "
+ "if not None, images are in input_path/images/, masks in input_path/{mask_prompt}/"
+ )
+ parser.add_argument(
+ "--image_names",
+ type=str,
+ default=None,
+ help="Image names (without extension), e.g., 'image1,view_a' or '1,2' or 'image1'. "
+ "Can specify multiple, comma-separated. If not specified, use all available images"
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Random seed (default: 42)"
+ )
+ parser.add_argument(
+ "--stage1_steps",
+ type=int,
+ default=50,
+ help="Stage 1 inference steps (default: 50)"
+ )
+ parser.add_argument(
+ "--stage2_steps",
+ type=int,
+ default=25,
+ help="Stage 2 inference steps (default: 25)"
+ )
+
+ parser.add_argument(
+ "--decode_formats",
+ type=str,
+ default="gaussian,mesh",
+ help="Decode formats, comma-separated, e.g., 'gaussian,mesh' or 'gaussian' (default: gaussian,mesh)"
+ )
+
+ parser.add_argument(
+ "--model_tag",
+ type=str,
+ default="hf",
+ help="Model tag (default: hf)"
+ )
+
+ args = parser.parse_args()
+
+ input_path = Path(args.input_path)
+ if not input_path.exists():
+ raise FileNotFoundError(f"Input path does not exist: {input_path}")
+
+ image_names = parse_image_names(args.image_names)
+
+ decode_formats = [fmt.strip() for fmt in args.decode_formats.split(",") if fmt.strip()]
+ if not decode_formats:
+ decode_formats = ["gaussian", "mesh"]
+
+ try:
+ run_inference(
+ input_path=input_path,
+ mask_prompt=args.mask_prompt,
+ image_names=image_names,
+ seed=args.seed,
+ stage1_steps=args.stage1_steps,
+ stage2_steps=args.stage2_steps,
+ decode_formats=decode_formats,
+ model_tag=args.model_tag,
+ )
+ except Exception as e:
+ logger.error(f"Inference failed: {e}")
+ import traceback
+ traceback.print_exc()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py b/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
index de7a4da..1f60ace 100644
--- a/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
+++ b/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
@@ -615,12 +615,17 @@ def to_glb(
if with_mesh_postprocess:
# mesh postprocess
+        # If rendering_engine is pytorch3d, disable fill_holes because it requires nvdiffrast
+ effective_fill_holes = fill_holes and rendering_engine == "nvdiffrast"
+ if fill_holes and rendering_engine == "pytorch3d":
+ logger.warning("fill_holes is disabled because rendering_engine is 'pytorch3d' (requires nvdiffrast)")
+
vertices, faces = postprocess_mesh(
vertices,
faces,
simplify=simplify > 0,
simplify_ratio=simplify,
- fill_holes=fill_holes,
+ fill_holes=effective_fill_holes,
fill_holes_max_hole_size=fill_holes_max_size,
fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
fill_holes_resolution=1024,
diff --git a/sam3d_objects/pipeline/inference_pipeline.py b/sam3d_objects/pipeline/inference_pipeline.py
index e4b81fe..15cb0ef 100644
--- a/sam3d_objects/pipeline/inference_pipeline.py
+++ b/sam3d_objects/pipeline/inference_pipeline.py
@@ -20,7 +20,7 @@ def set_attention_backend():
set_attention_backend()
-from typing import List, Union
+from typing import List, Union, Optional, Literal
from hydra.utils import instantiate
from omegaconf import OmegaConf
import numpy as np
@@ -842,3 +842,393 @@ def _get_dtype(dtype):
return torch.float32
else:
raise NotImplementedError
+
+ def get_multi_view_condition_input(
+ self,
+ condition_embedder,
+ view_input_dicts: List[dict],
+ input_mapping
+ ):
+ """
+        Prepare conditioning for multi-view inputs.
+
+        Args:
+            condition_embedder: condition embedder module
+            view_input_dicts: list of per-view input dicts
+            input_mapping: input key mapping
+
+        Returns:
+            condition_args: condition args holding the condition tokens for all views
+            condition_kwargs: condition keyword args
+ """
+        # Extract the condition for each view separately
+ view_conditions = []
+ for view_input_dict in view_input_dicts:
+ condition_args = self.map_input_keys(view_input_dict, input_mapping)
+ condition_kwargs = {
+ k: v for k, v in view_input_dict.items() if k not in input_mapping
+ }
+ embedded_cond, _, _ = self.embed_condition(
+ condition_embedder, *condition_args, **condition_kwargs
+ )
+ if embedded_cond is not None:
+ view_conditions.append(embedded_cond)
+ else:
+                # No embedded condition; fall back to the raw condition args
+ view_conditions.append(condition_args)
+
+        # Stack the conditions from all views together
+        # Shape: (num_views, batch_size, num_tokens, dim)
+ if isinstance(view_conditions[0], torch.Tensor):
+            # Tensors can be stacked directly
+ all_conditions = torch.stack(view_conditions, dim=0)
+ else:
+            # Otherwise keep them as a list
+ all_conditions = view_conditions
+
+ return (all_conditions,), {}
+
+ def sample_sparse_structure_multi_view(
+ self,
+ view_ss_input_dicts: List[dict],
+ inference_steps=None,
+ use_distillation=False,
+ mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+ ):
+ """
+        Multi-view sparse structure generation.
+
+        Args:
+            view_ss_input_dicts: list of per-view input dicts
+            inference_steps: number of inference steps
+            use_distillation: whether to use distillation
+            mode: 'stochastic' or 'multidiffusion'
+ """
+ from sam3d_objects.pipeline.multi_view_utils import inject_generator_multi_view
+
+ ss_generator = self.models["ss_generator"]
+ ss_decoder = self.models["ss_decoder"]
+ num_views = len(view_ss_input_dicts)
+
+ if use_distillation:
+ ss_generator.no_shortcut = False
+ ss_generator.reverse_fn.strength = 0
+ ss_generator.reverse_fn.strength_pm = 0
+ else:
+ ss_generator.no_shortcut = True
+ ss_generator.reverse_fn.strength = self.ss_cfg_strength
+ ss_generator.reverse_fn.strength_pm = self.ss_cfg_strength_pm
+
+ prev_inference_steps = ss_generator.inference_steps
+ if inference_steps:
+ ss_generator.inference_steps = inference_steps
+
+ image = view_ss_input_dicts[0]["image"]
+ bs = image.shape[0]
+ logger.info(
+ f"Sampling sparse structure with {num_views} views: "
+ f"inference_steps={ss_generator.inference_steps}, mode={mode}"
+ )
+
+ with torch.no_grad():
+ with torch.autocast(device_type="cuda", dtype=self.shape_model_dtype):
+ if self.is_mm_dit():
+ latent_shape_dict = {
+ k: (bs,) + (v.pos_emb.shape[0], v.input_layer.in_features)
+ for k, v in ss_generator.reverse_fn.backbone.latent_mapping.items()
+ }
+ logger.info(f"[Stage 1] Latent shape (MM-DiT): {latent_shape_dict}")
+ else:
+ latent_shape_dict = (bs,) + (4096, 8)
+ logger.info(f"[Stage 1] Latent shape: {latent_shape_dict}")
+
+                # Prepare the multi-view conditioning
+ condition_args, condition_kwargs = self.get_multi_view_condition_input(
+ self.condition_embedders["ss_condition_embedder"],
+ view_ss_input_dicts,
+ self.ss_condition_input_mapping,
+ )
+
+                # Inject multi-view support into the generator
+ with inject_generator_multi_view(
+ ss_generator,
+ num_views=num_views,
+ num_steps=ss_generator.inference_steps,
+ mode=mode
+ ):
+ return_dict = ss_generator(
+ latent_shape_dict,
+ image.device,
+ *condition_args,
+ **condition_kwargs,
+ )
+
+ if not self.is_mm_dit():
+ return_dict = {"shape": return_dict}
+
+ shape_latent = return_dict["shape"]
+ logger.info(f"[Stage 1 Multi-view] Generated shape_latent shape: {shape_latent.shape}")
+ ss = ss_decoder(
+ shape_latent.permute(0, 2, 1)
+ .contiguous()
+ .view(shape_latent.shape[0], 8, 16, 16, 16)
+ )
+ logger.info(f"[Stage 1 Multi-view] Decoded sparse structure shape: {ss.shape}")
+ coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()
+ logger.info(f"[Stage 1 Multi-view] Extracted coords shape: {coords.shape}")
+
+ # downsample output
+ return_dict["coords_original"] = coords
+ original_shape = coords.shape
+ if self.downsample_ss_dist > 0:
+ coords = prune_sparse_structure(
+ coords,
+ max_neighbor_axes_dist=self.downsample_ss_dist,
+ )
+ coords, downsample_factor = downsample_sparse_structure(coords)
+ logger.info(
+ f"[Stage 1 Multi-view] Downsampled coords from {original_shape[0]} to {coords.shape[0]}"
+ )
+ return_dict["coords"] = coords
+ return_dict["downsample_factor"] = downsample_factor
+
+ ss_generator.inference_steps = prev_inference_steps
+ return return_dict
+
+ def sample_slat_multi_view(
+ self,
+ view_slat_input_dicts: List[dict],
+ coords: torch.Tensor,
+ inference_steps=25,
+ use_distillation=False,
+ mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+ ) -> sp.SparseTensor:
+ """
+        Multi-view structured latent generation.
+
+        Args:
+            view_slat_input_dicts: list of per-view input dicts
+            coords: sparse coordinates (from Stage 1)
+            inference_steps: number of inference steps
+            use_distillation: whether to use distillation
+            mode: 'stochastic' or 'multidiffusion'
+ """
+ from sam3d_objects.pipeline.multi_view_utils import inject_generator_multi_view
+
+ image = view_slat_input_dicts[0]["image"]
+ DEVICE = image.device
+ slat_generator = self.models["slat_generator"]
+ num_views = len(view_slat_input_dicts)
+ latent_shape = (image.shape[0],) + (coords.shape[0], 8)
+ logger.info(f"[Stage 2] Coords shape: {coords.shape}")
+ logger.info(f"[Stage 2] Latent shape: {latent_shape}")
+ prev_inference_steps = slat_generator.inference_steps
+ if inference_steps:
+ slat_generator.inference_steps = inference_steps
+ if use_distillation:
+ slat_generator.no_shortcut = False
+ slat_generator.reverse_fn.strength = 0
+ else:
+ slat_generator.no_shortcut = True
+ slat_generator.reverse_fn.strength = self.slat_cfg_strength
+
+ logger.info(
+ f"Sampling sparse latent with {num_views} views: "
+ f"inference_steps={slat_generator.inference_steps}, mode={mode}"
+ )
+
+ with torch.autocast(device_type="cuda", dtype=self.dtype):
+ with torch.no_grad():
+                # Prepare the multi-view conditioning
+ condition_args, condition_kwargs = self.get_multi_view_condition_input(
+ self.condition_embedders["slat_condition_embedder"],
+ view_slat_input_dicts,
+ self.slat_condition_input_mapping,
+ )
+ condition_args += (coords.cpu().numpy(),)
+
+                # Inject multi-view support into the generator
+ with inject_generator_multi_view(
+ slat_generator,
+ num_views=num_views,
+ num_steps=slat_generator.inference_steps,
+ mode=mode
+ ):
+ slat = slat_generator(
+ latent_shape, DEVICE, *condition_args, **condition_kwargs
+ )
+
+ logger.info(f"[Stage 2] Generated slat shape (before SparseTensor): {slat[0].shape if isinstance(slat, (list, tuple)) else slat.shape}")
+ slat = sp.SparseTensor(
+ coords=coords,
+ feats=slat[0],
+ ).to(DEVICE)
+ slat = slat * self.slat_std.to(DEVICE) + self.slat_mean.to(DEVICE)
+ logger.info(f"[Stage 2] Final slat: coords={slat.coords.shape}, feats={slat.feats.shape}")
+
+ slat_generator.inference_steps = prev_inference_steps
+ return slat
+
+ def run_multi_view(
+ self,
+ view_images: List[Union[np.ndarray, Image.Image]],
+        view_masks: Optional[List[Union[np.ndarray, Image.Image]]] = None,
+ num_samples: int = 1,
+ seed: Optional[int] = None,
+ stage1_inference_steps: Optional[int] = None,
+ stage2_inference_steps: Optional[int] = None,
+ use_stage1_distillation: bool = False,
+ use_stage2_distillation: bool = False,
+ decode_formats: Optional[List[str]] = None,
+ with_mesh_postprocess: bool = True,
+ with_texture_baking: bool = True,
+ use_vertex_color: bool = False,
+ stage1_only: bool = False,
+ mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+ ) -> dict:
+ """
+        Main entry point for multi-view inference.
+
+        Args:
+            view_images: list of per-view images
+            view_masks: list of per-view masks (optional)
+            num_samples: number of samples to generate
+            seed: random seed
+            stage1_inference_steps: Stage 1 inference steps
+            stage2_inference_steps: Stage 2 inference steps
+            use_stage1_distillation: whether to use Stage 1 distillation
+            use_stage2_distillation: whether to use Stage 2 distillation
+            decode_formats: decode formats
+            with_mesh_postprocess: whether to run mesh postprocessing
+            with_texture_baking: whether to bake textures
+            use_vertex_color: whether to use vertex colors
+            stage1_only: whether to run Stage 1 only
+            mode: 'stochastic' or 'multidiffusion'
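+
+        Example (a minimal sketch; argument values are illustrative):
+            result = pipeline.run_multi_view(
+                view_images=images, view_masks=masks, seed=42,
+                mode="multidiffusion", decode_formats=["gaussian", "mesh"],
+            )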
+ """
+ num_views = len(view_images)
+ if view_masks is None:
+ view_masks = [None] * num_views
+ assert len(view_masks) == num_views, "Number of masks must match number of images"
+
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ logger.info(f"Running multi-view inference with {num_views} views, mode={mode}")
+
+        # Preprocess each view
+        # Note: merge the mask into the image's alpha channel first, then call preprocess_image
+ view_ss_input_dicts = []
+ view_slat_input_dicts = []
+ for i, (image, mask) in enumerate(zip(view_images, view_masks)):
+ logger.info(f"Preprocessing view {i+1}/{num_views}")
+
+            # Merge the mask into the image's alpha channel (RGBA format)
+            # If the image is already RGBA (mask loaded from its alpha channel), mask may be None
+ if mask is not None:
+                # Ensure image is a numpy array (np.array also handles PIL Images)
+                image = np.array(image)
+
+                # Ensure mask is a numpy array
+ mask = np.array(mask)
+
+                # If the mask is boolean, convert to uint8 (0/255)
+ if mask.dtype == bool:
+ mask = mask.astype(np.uint8) * 255
+ elif mask.dtype != np.uint8:
+                    # If the mask is a float in [0, 1], scale to [0, 255]
+ if mask.max() <= 1.0:
+ mask = (mask * 255).astype(np.uint8)
+ else:
+ mask = mask.astype(np.uint8)
+
+ if mask.ndim == 2:
+ mask = mask[..., None]
+
+            # Merge the mask into the alpha channel
+ if image.shape[-1] == 3: # RGB
+ rgba_image = np.concatenate([image, mask], axis=-1).astype(np.uint8)
+            elif image.shape[-1] == 4:  # already RGBA; replace the alpha channel
+ rgba_image = np.concatenate([image[..., :3], mask], axis=-1).astype(np.uint8)
+ else:
+ raise ValueError(f"Unexpected image shape: {image.shape}")
+ else:
+                # No mask provided: assume the image is already RGBA
+                rgba_image = np.array(image)
+
+            # Convert to a PIL Image (required by preprocess_image)
+ rgba_image_pil = Image.fromarray(rgba_image)
+
+            # Call preprocess_image (note: InferencePipelinePointMap also needs a pointmap)
+            # First check whether this is an InferencePipelinePointMap
+ if hasattr(self, 'compute_pointmap'):
+                # This is an InferencePipelinePointMap; compute the pointmap first
+ pointmap_dict = self.compute_pointmap(rgba_image_pil, pointmap=None)
+ pointmap = pointmap_dict["pointmap"]
+ ss_input_dict = self.preprocess_image(
+ rgba_image_pil, self.ss_preprocessor, pointmap=pointmap
+ )
+ slat_input_dict = self.preprocess_image(
+ rgba_image_pil, self.slat_preprocessor
+ )
+ else:
+                # This is a plain InferencePipeline; no pointmap needed
+ ss_input_dict = self.preprocess_image(
+ rgba_image_pil, self.ss_preprocessor
+ )
+ slat_input_dict = self.preprocess_image(
+ rgba_image_pil, self.slat_preprocessor
+ )
+
+ view_ss_input_dicts.append(ss_input_dict)
+ view_slat_input_dicts.append(slat_input_dict)
+
+        # Stage 1: generate the sparse structure
+ logger.info("Stage 1: Sampling sparse structure...")
+ ss_return_dict = self.sample_sparse_structure_multi_view(
+ view_ss_input_dicts,
+ inference_steps=stage1_inference_steps,
+ use_distillation=use_stage1_distillation,
+ mode=mode,
+ )
+
+ ss_return_dict.update(self.pose_decoder(ss_return_dict))
+
+ if "scale" in ss_return_dict:
+ logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']}")
+ ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"]
+
+ if stage1_only:
+ logger.info("Finished!")
+ ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5
+ return ss_return_dict
+
+        # Stage 2: generate the structured latent
+ coords = ss_return_dict["coords"]
+ logger.info("Stage 2: Sampling structured latent...")
+ slat = self.sample_slat_multi_view(
+ view_slat_input_dicts,
+ coords,
+ inference_steps=stage2_inference_steps,
+ use_distillation=use_stage2_distillation,
+ mode=mode,
+ )
+
+        # Decode the structured latent into the requested formats
+ outputs = self.decode_slat(
+ slat, self.decode_formats if decode_formats is None else decode_formats
+ )
+ outputs = self.postprocess_slat_output(
+ outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
+ )
+ logger.info("Finished!")
+
+ return {
+ **ss_return_dict,
+ **outputs,
+ }
diff --git a/sam3d_objects/pipeline/multi_view_utils.py b/sam3d_objects/pipeline/multi_view_utils.py
new file mode 100644
index 0000000..1f56613
--- /dev/null
+++ b/sam3d_objects/pipeline/multi_view_utils.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+"""
+Multi-view multidiffusion utilities for SAM 3D Objects
+Adapted from the TRELLIS implementation to fit SAM 3D Objects' two-stage pipeline
+"""
+from contextlib import contextmanager
+from typing import List, Literal
+import torch
+from loguru import logger
+
+
+@contextmanager
+def inject_generator_multi_view(
+ generator,
+ num_views: int,
+ num_steps: int,
+ mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+):
+ """
+ Inject multi-view support into generator
+
+ Args:
+ generator: SAM 3D Objects generator (ss_generator or slat_generator)
+ num_views: Number of views
+ num_steps: Number of inference steps
+ mode: 'stochastic' or 'multidiffusion'
+
+ Yields:
+ None
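+
+    Example (a minimal sketch; mirrors the call sites in inference_pipeline.py):
+        with inject_generator_multi_view(ss_generator, num_views=8, num_steps=50):
+            return_dict = ss_generator(latent_shape, device, *condition_args)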
+ """
+ original_dynamics = generator._generate_dynamics
+
+ if mode == 'stochastic':
+ if num_views > num_steps:
+ logger.warning(
+                f"Number of views ({num_views}) is greater than number of steps ({num_steps}); "
+                "some views will never be sampled, which may degrade quality."
+ )
+
+ cond_indices = (torch.arange(num_steps) % num_views).tolist()
+ cond_idx_counter = [0]
+
+ def _new_dynamics_stochastic(x_t, t, *args_conditionals, **kwargs_conditionals):
+ """Stochastic mode: select one view per time step"""
+ cond_idx = cond_indices[cond_idx_counter[0] % len(cond_indices)]
+ cond_idx_counter[0] += 1
+
+ if len(args_conditionals) > 0:
+ cond_tokens = args_conditionals[0]
+ if isinstance(cond_tokens, (list, tuple)):
+ cond_i = cond_tokens[cond_idx:cond_idx+1] if isinstance(cond_tokens[0], torch.Tensor) else [cond_tokens[cond_idx]]
+ new_args = (cond_i,) + args_conditionals[1:]
+ elif isinstance(cond_tokens, torch.Tensor) and cond_tokens.shape[0] == num_views:
+ cond_i = cond_tokens[cond_idx:cond_idx+1]
+ new_args = (cond_i,) + args_conditionals[1:]
+ else:
+ new_args = args_conditionals
+ else:
+ new_args = args_conditionals
+
+ return original_dynamics(x_t, t, *new_args, **kwargs_conditionals)
+
+ generator._generate_dynamics = _new_dynamics_stochastic
+
+ elif mode == 'multidiffusion':
+ def _new_dynamics_multidiffusion(x_t, t, *args_conditionals, **kwargs_conditionals):
+ """Multidiffusion mode: fuse predictions from all views at each time step"""
+ cond_idx = 0
+ if len(args_conditionals) > 0:
+ if isinstance(args_conditionals[0], (int, float)) or (isinstance(args_conditionals[0], torch.Tensor) and args_conditionals[0].numel() == 1):
+ cond_idx = 1
+
+ if len(args_conditionals) > cond_idx:
+ cond_tokens = args_conditionals[cond_idx]
+
+ if not hasattr(_new_dynamics_multidiffusion, '_logged_cond_shape'):
+ logger.info(f"[Multidiffusion] args_conditionals length: {len(args_conditionals)}")
+ logger.info(f"[Multidiffusion] cond_idx: {cond_idx}")
+ if isinstance(cond_tokens, torch.Tensor):
+ logger.info(f"[Multidiffusion] Condition tokens shape: {cond_tokens.shape}")
+ elif isinstance(cond_tokens, (list, tuple)):
+ logger.info(f"[Multidiffusion] Condition tokens type: {type(cond_tokens)}, length: {len(cond_tokens)}")
+ if len(cond_tokens) > 0 and isinstance(cond_tokens[0], torch.Tensor):
+ logger.info(f"[Multidiffusion] First condition token shape: {cond_tokens[0].shape}")
+ else:
+ logger.info(f"[Multidiffusion] Condition tokens type: {type(cond_tokens)}")
+ _new_dynamics_multidiffusion._logged_cond_shape = True
+
+ if isinstance(cond_tokens, (list, tuple)):
+ view_conditions = cond_tokens
+ elif isinstance(cond_tokens, torch.Tensor) and cond_tokens.shape[0] == num_views:
+ view_conditions = []
+ for i in range(num_views):
+ view_cond = cond_tokens[i]
+ view_conditions.append(view_cond)
+ else:
+ logger.warning(f"Condition tokens shape {cond_tokens.shape if isinstance(cond_tokens, torch.Tensor) else type(cond_tokens)} not organized by views, using same condition for all views")
+ view_conditions = [cond_tokens] * num_views
+
+ preds = []
+ for view_idx in range(num_views):
+ view_cond = view_conditions[view_idx]
+ if cond_idx < len(args_conditionals):
+ new_args = args_conditionals[:cond_idx] + (view_cond,) + args_conditionals[cond_idx+1:]
+ else:
+ new_args = args_conditionals + (view_cond,)
+ pred = original_dynamics(x_t, t, *new_args, **kwargs_conditionals)
+ preds.append(pred)
+
+ if not hasattr(_new_dynamics_multidiffusion, '_logged_shape'):
+ if isinstance(x_t, dict):
+ logger.info(f"[Multidiffusion] Latent shape (dict): {[(k, v.shape if isinstance(v, torch.Tensor) else type(v)) for k, v in x_t.items()]}")
+ elif isinstance(x_t, (list, tuple)):
+ logger.info(f"[Multidiffusion] Latent shape (tuple/list): {[v.shape if isinstance(v, torch.Tensor) else type(v) for v in x_t]}")
+ else:
+ logger.info(f"[Multidiffusion] Latent shape: {x_t.shape if isinstance(x_t, torch.Tensor) else type(x_t)}")
+
+ if isinstance(preds[0], dict):
+ logger.info(f"[Multidiffusion] Pred shape (dict): {[(k, v.shape if isinstance(v, torch.Tensor) else type(v)) for k, v in preds[0].items()]}")
+ elif isinstance(preds[0], (list, tuple)):
+ logger.info(f"[Multidiffusion] Pred shape (tuple/list): {[v.shape if isinstance(v, torch.Tensor) else type(v) for v in preds[0]]}")
+ else:
+ logger.info(f"[Multidiffusion] Pred shape: {preds[0].shape if isinstance(preds[0], torch.Tensor) else type(preds[0])}")
+ logger.info(f"[Multidiffusion] Number of views: {num_views}, fusing {len(preds)} predictions")
+ _new_dynamics_multidiffusion._logged_shape = True
+
+ if isinstance(preds[0], dict):
+ fused_pred = {}
+ for key in preds[0].keys():
+ fused_pred[key] = torch.stack([p[key] for p in preds]).mean(dim=0)
+ return fused_pred
+ elif isinstance(preds[0], (list, tuple)):
+ fused_pred = tuple(
+ torch.stack([p[i] for p in preds]).mean(dim=0)
+ for i in range(len(preds[0]))
+ )
+ return fused_pred
+ else:
+ fused_pred = torch.stack(preds).mean(dim=0)
+ return fused_pred
+ else:
+ return original_dynamics(x_t, t, *args_conditionals, **kwargs_conditionals)
+
+ generator._generate_dynamics = _new_dynamics_multidiffusion
+
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ try:
+ yield
+ finally:
+ generator._generate_dynamics = original_dynamics
+