diff --git a/.gitignore b/.gitignore
index ed8ebf5..82c3be6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,22 @@
-__pycache__
\ No newline at end of file
+__pycache__
+*.pyc
+*.py~
+*.swp
+*.swo
+*~
+
+# Output directories
+visualization/
+outputs/
+*.ply
+*.glb
+*.obj
+
+# IDE
+.vscode/
+.idea/
+*.code-workspace
+
+# OS
+.DS_Store
+Thumbs.db
\ No newline at end of file
diff --git a/README.md b/README.md
index 59a01fa..e062430 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
+
 # SAM 3D
 
-SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you’re looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
+SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you're looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
 
 # SAM 3D Objects
 
@@ -67,6 +68,140 @@ For more details and multi-object reconstruction, please take a look at our two notebooks
 * [single object](notebook/demo_single_object.ipynb)
 * [multi object](notebook/demo_multi_object.ipynb)
 
+## Multi-View 3D Reconstruction
+
+This contribution adds **training-free multi-view 3D reconstruction** to SAM 3D Objects using a multidiffusion approach. It lets you generate a consistent 3D model from multiple input images of the same object, taken from different viewpoints, without retraining the model.
+
+### Results Comparison
+
+The following comparison demonstrates the improvement of multi-view reconstruction over single-view reconstruction:
+
+| Single-View (View 3) | Single-View (View 6) | Multi-View (All 8 Views) |
+| --- | --- | --- |
+| Input image | Input image | Input images (all 8 views) |
+| ↓ 3D Reconstruction ↓ | ↓ 3D Reconstruction ↓ | ↓ 3D Reconstruction ↓ |
+| 3D result (GIF) | 3D result (GIF) | 3D result (GIF) |
+| **Analysis:** Due to occlusion in the input image, the red collar on the dog is not visible, resulting in its absence in the generated 3D model. | **Analysis:** Many frontal parts of the dog are occluded or not visible from this angle, leading to structural errors in the front-facing regions of the generated model. | **Analysis:** By combining information from all 8 views, the multi-view reconstruction produces a complete and accurate 3D model that closely matches the actual object. |
+
+*(The rendered README embeds the input images and result GIFs from `data/example/` in this table.)*
+
+### Quick Start
+
+Use the `run_inference.py` script for both single-view and multi-view reconstruction:
+
+```bash
+# Multi-view reconstruction (mask_prompt=None, images and masks in same directory)
+python run_inference.py --input_path ./data/images_and_masks
+
+# Single-view reconstruction (specify a single image name)
+python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+# Multi-view reconstruction (mask_prompt!=None, images in images/, masks in {mask_prompt}/)
+python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+# Specify multiple image names (can be any filename without extension)
+python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+```
+
+### Data Structure
+
+Multi-view data can be organized in two ways:
+
+**Structure 1** (when `mask_prompt=None`): Images and masks in the same directory
+```
+input_path/
+  ├── 1.png        # Original image (PNG format)
+  ├── 1_mask.png   # Mask (RGBA format, alpha channel stores mask info)
+  ├── 2.png
+  ├── 2_mask.png
+  └── ...
+```
+
+**Structure 2** (when `mask_prompt!=None`, e.g., `mask_prompt="stuffed_toy"`): Images and masks in separate directories
+```
+input_path/
+  ├── images/
+  │   ├── 1.png
+  │   ├── 2.png
+  │   └── ...
+  └── stuffed_toy/    (or {mask_prompt}/)
+      ├── 1.png       (or 1_mask.png)
+      ├── 2.png       (or 2_mask.png)
+      └── ...
+```
+
+**Mask Format**: RGBA, where the alpha channel stores the mask (alpha=255 for object, alpha=0 for background). A minimal conversion sketch follows this section.
+
+### Command Line Options
+
+Run `python run_inference.py --help` for full documentation. Key parameters:
+
+- `--input_path`: Path to input directory (required)
+- `--mask_prompt`: Mask folder name. If None, images and masks are in the same directory; if specified, images are in `input_path/images/` and masks are in `input_path/{mask_prompt}/`
+- `--image_names`: Image names (without extension), e.g., `"image1,view_a"` or `"1,2"` or `"image1"`. Can specify multiple, comma-separated. If not specified, uses all available images
+- `--decode_formats`: Output formats, e.g., `"gaussian,mesh"` or `"gaussian"` (default: `gaussian,mesh`)
+- `--seed`: Random seed (default: 42)
+- `--stage1_steps`: Stage 1 inference steps (default: 50)
+- `--stage2_steps`: Stage 2 inference steps (default: 25)
+- `--model_tag`: Model tag (default: hf)
+
+The script automatically detects whether to use single-view or multi-view inference based on the number of views provided. Multi-view reconstruction uses a training-free multidiffusion approach to fuse predictions from all views.
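Because the loader reads the mask from the alpha channel, any binary mask can be converted offline into the expected `{name}_mask.png`. A minimal sketch of that conversion (not part of this PR; the `object_mask.npy` input file is hypothetical):

```python
# Hypothetical helper: write a binary mask as the RGBA file the loader expects.
# alpha=255 marks the object, alpha=0 the background (see "Mask Format" above).
import numpy as np
from PIL import Image

mask = np.load("object_mask.npy").astype(bool)         # (H, W) binary mask, hypothetical input
rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)      # RGB content is ignored by the loader
rgba[..., 3] = mask.astype(np.uint8) * 255             # mask lives in the alpha channel
Image.fromarray(rgba, mode="RGBA").save("1_mask.png")  # matches the {name}_mask.png convention
```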
 ## SAM 3D Body
diff --git a/data/example/images/1.png b/data/example/images/1.png
new file mode 100644
index 0000000..0cd8986
Binary files /dev/null and b/data/example/images/1.png differ
diff --git a/data/example/images/2.png b/data/example/images/2.png
new file mode 100644
index 0000000..015325c
Binary files /dev/null and b/data/example/images/2.png differ
diff --git a/data/example/images/3.png b/data/example/images/3.png
new file mode 100644
index 0000000..c4aeb80
Binary files /dev/null and b/data/example/images/3.png differ
diff --git a/data/example/images/4.png b/data/example/images/4.png
new file mode 100644
index 0000000..a1f7ef0
Binary files /dev/null and b/data/example/images/4.png differ
diff --git a/data/example/images/5.png b/data/example/images/5.png
new file mode 100644
index 0000000..5678d60
Binary files /dev/null and b/data/example/images/5.png differ
diff --git a/data/example/images/6.png b/data/example/images/6.png
new file mode 100644
index 0000000..4e99c6a
Binary files /dev/null and b/data/example/images/6.png differ
diff --git a/data/example/images/7.png b/data/example/images/7.png
new file mode 100644
index 0000000..af77b9b
Binary files /dev/null and b/data/example/images/7.png differ
diff --git a/data/example/images/8.png b/data/example/images/8.png
new file mode 100644
index 0000000..7e98bd0
Binary files /dev/null and b/data/example/images/8.png differ
diff --git a/data/example/stuffed_toy/1_mask.png b/data/example/stuffed_toy/1_mask.png
new file mode 100644
index 0000000..b4a0f8e
Binary files /dev/null and b/data/example/stuffed_toy/1_mask.png differ
diff --git a/data/example/stuffed_toy/2_mask.png b/data/example/stuffed_toy/2_mask.png
new file mode 100644
index 0000000..31dc219
Binary files /dev/null and b/data/example/stuffed_toy/2_mask.png differ
diff --git a/data/example/stuffed_toy/3_mask.png b/data/example/stuffed_toy/3_mask.png
new file mode 100644
index 0000000..85c00e2
Binary files /dev/null and b/data/example/stuffed_toy/3_mask.png differ
diff --git a/data/example/stuffed_toy/4_mask.png b/data/example/stuffed_toy/4_mask.png
new file mode 100644
index 0000000..48a1759
Binary files /dev/null and b/data/example/stuffed_toy/4_mask.png differ
diff --git a/data/example/stuffed_toy/5_mask.png b/data/example/stuffed_toy/5_mask.png
new file mode 100644
index 0000000..0e44e2c
Binary files /dev/null and b/data/example/stuffed_toy/5_mask.png differ
diff --git a/data/example/stuffed_toy/6_mask.png b/data/example/stuffed_toy/6_mask.png
new file mode 100644
index 0000000..41682e6
Binary files /dev/null and b/data/example/stuffed_toy/6_mask.png differ
diff --git a/data/example/stuffed_toy/7_mask.png b/data/example/stuffed_toy/7_mask.png
new file mode 100644
index 0000000..d310642
Binary files /dev/null and b/data/example/stuffed_toy/7_mask.png differ
diff --git a/data/example/stuffed_toy/8_mask.png b/data/example/stuffed_toy/8_mask.png
new file mode 100644
index 0000000..071dc8d
Binary files /dev/null and b/data/example/stuffed_toy/8_mask.png differ
diff --git a/data/example/visualization_results/all_views_cropped.gif b/data/example/visualization_results/all_views_cropped.gif
new file mode 100644
index 0000000..4df3097
Binary files /dev/null and b/data/example/visualization_results/all_views_cropped.gif differ
diff --git a/data/example/visualization_results/all_views_cropped.mp4 b/data/example/visualization_results/all_views_cropped.mp4
new file mode 100644
index 0000000..c15bed6
Binary files /dev/null and b/data/example/visualization_results/all_views_cropped.mp4 differ
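With the example data added above, the new loader can be exercised directly. A minimal smoke-test sketch (assumes it is run from the repository root, with `notebook/` made importable the same way `run_inference.py` arranges it):

```python
# Hypothetical smoke test for the loader against data/example/ from this PR.
import sys
from pathlib import Path

sys.path.append("notebook")  # same trick run_inference.py uses
from load_images_and_masks import load_images_and_masks_from_path

images, masks = load_images_and_masks_from_path(
    input_path=Path("data/example"),
    mask_prompt="stuffed_toy",      # masks in data/example/stuffed_toy/{name}_mask.png
    image_names=["1", "2", "3"],    # a subset of the eight example views
)
print(len(images), images[0].shape, masks[0].dtype)  # e.g. 3, (H, W, 3), bool
```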
diff --git a/data/example/visualization_results/view3_cropped.gif b/data/example/visualization_results/view3_cropped.gif
new file mode 100644
index 0000000..5650893
Binary files /dev/null and b/data/example/visualization_results/view3_cropped.gif differ
diff --git a/data/example/visualization_results/view3_cropped.mp4 b/data/example/visualization_results/view3_cropped.mp4
new file mode 100644
index 0000000..5c607b4
Binary files /dev/null and b/data/example/visualization_results/view3_cropped.mp4 differ
diff --git a/data/example/visualization_results/view6_cropped.gif b/data/example/visualization_results/view6_cropped.gif
new file mode 100644
index 0000000..9a64a9e
Binary files /dev/null and b/data/example/visualization_results/view6_cropped.gif differ
diff --git a/data/example/visualization_results/view6_cropped.mp4 b/data/example/visualization_results/view6_cropped.mp4
new file mode 100644
index 0000000..625b77e
Binary files /dev/null and b/data/example/visualization_results/view6_cropped.mp4 differ
diff --git a/demo.py b/demo.py
index befdd84..062a281 100644
--- a/demo.py
+++ b/demo.py
@@ -11,7 +11,7 @@
 # load image (RGBA only, mask is embedded in the alpha channel)
 image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
-mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
+mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=11)
 
 # run model
 output = inference(image, mask, seed=42)
diff --git a/notebook/load_images_and_masks.py b/notebook/load_images_and_masks.py
new file mode 100644
index 0000000..64de3a5
--- /dev/null
+++ b/notebook/load_images_and_masks.py
@@ -0,0 +1,305 @@
+"""
+Load multi-view data from a specified path.
+
+Supports two data structures:
+1. mask_prompt=None: all images and masks in the same directory, named xxxx.png and xxxx_mask.png
+2. mask_prompt!=None: images in input_path/images/, masks in input_path/{mask_prompt}/
+"""
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+from loguru import logger
+
+
+def load_image(path: Path) -> np.ndarray:
+    """Load an image as a numpy array."""
+    img = Image.open(path)
+    return np.array(img).astype(np.uint8)
+
+
+def load_mask_from_rgba(path: Path) -> np.ndarray:
+    """
+    Load a mask from an RGBA image (extracted from the alpha channel).
+
+    Args:
+        path: RGBA image file path
+
+    Returns:
+        mask: Binary mask, shape (H, W), bool dtype
+    """
+    img = Image.open(path)
+    img_array = np.array(img)
+
+    if img.mode == 'RGBA' and img_array.ndim == 3 and img_array.shape[2] >= 4:
+        mask = img_array[..., 3] > 0
+    elif img.mode == 'RGB':
+        logger.warning(f"Mask file {path} is RGB format, not RGBA. Using all pixels as mask.")
+        mask = np.ones((img_array.shape[0], img_array.shape[1]), dtype=bool)
+    else:
+        logger.warning(f"Unexpected image mode {img.mode} for mask file {path}")
+        mask = np.ones((img_array.shape[0], img_array.shape[1]), dtype=bool)
+
+    return mask
+
+
+def load_images_and_masks(
+    images_and_masks_dir: Path,
+    image_names: Optional[List[str]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+    """
+    Load multi-view data from an images_and_masks folder.
+
+    Data structure:
+        images_and_masks/
+        ├── 1.png (or image1.png, view_a.png, etc.)
+        ├── 1_mask.png (or image1_mask.png, view_a_mask.png)
+        ├── 2.png
+        ├── 2_mask.png
+        └── ...
+
+    Args:
+        images_and_masks_dir: Path to the images_and_masks folder
+        image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"];
+            if None, auto-detect all
+
+    Returns:
+        images: List of images (numpy arrays)
+        masks: List of masks (numpy arrays, bool dtype)
+    """
+    if not images_and_masks_dir.exists():
+        raise FileNotFoundError(f"Directory does not exist: {images_and_masks_dir}")
+
+    if not images_and_masks_dir.is_dir():
+        raise ValueError(f"Path is not a directory: {images_and_masks_dir}")
+
+    if image_names is None:
+        image_files = sorted(images_and_masks_dir.glob("*.png")) + sorted(images_and_masks_dir.glob("*.jpg"))
+        image_files = [f for f in image_files if "_mask" not in f.name]
+        image_names = [f.stem for f in image_files]
+        logger.info(f"Auto-detected {len(image_names)} images: {image_names}")
+
+    images = []
+    masks = []
+
+    for image_name in image_names:
+        image_candidates = [
+            images_and_masks_dir / f"{image_name}.png",
+            images_and_masks_dir / f"{image_name}.jpg",
+        ]
+
+        mask_candidates = [
+            images_and_masks_dir / f"{image_name}_mask.png",
+            images_and_masks_dir / f"{image_name}_mask.jpg",
+        ]
+
+        image_path = None
+        for candidate in image_candidates:
+            if candidate.exists():
+                image_path = candidate
+                break
+
+        mask_path = None
+        for candidate in mask_candidates:
+            if candidate.exists():
+                mask_path = candidate
+                break
+
+        if image_path is None:
+            logger.warning(f"Image file not found for '{image_name}', skipping")
+            continue
+
+        if mask_path is None:
+            logger.warning(f"Mask file not found for '{image_name}', skipping")
+            continue
+
+        try:
+            image = load_image(image_path)
+            mask = load_mask_from_rgba(mask_path)
+
+            images.append(image)
+            masks.append(mask)
+
+            logger.info(f"Loaded '{image_name}': image={image.shape}, mask={mask.shape}")
+
+        except Exception as e:
+            logger.error(f"Failed to load '{image_name}': {e}")
+            continue
+
+    if len(images) == 0:
+        raise ValueError(f"No valid images and masks found in {images_and_masks_dir}")
+
+    logger.info(f"Successfully loaded {len(images)} images")
+    return images, masks
+
+
+def load_from_segmentation_structure(
+    segmentation_base_dir: Path,
+    prompt: Optional[str] = None,
+    view_indices: Optional[List[int]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+    """
+    Load data from the full segmentation structure (kept for backward compatibility).
+
+    Structure:
+        visualization/
+        └── {folder_name}_segmentation/
+            └── {prompt}/
+                └── images_and_masks/
+
+    Args:
+        segmentation_base_dir: Segmentation base directory (e.g., visualization/joy_segmentation)
+        prompt: Prompt subfolder name (e.g., "stuffed_toy"); if None, auto-detect
+        view_indices: List of view indices; if None, auto-detect all
+
+    Returns:
+        images: List of images
+        masks: List of masks
+    """
+    if prompt:
+        segmentation_dir = segmentation_base_dir / prompt
+    else:
+        prompt_dirs = [d for d in segmentation_base_dir.iterdir()
+                       if d.is_dir() and d.name != "all_masks"]
+        if len(prompt_dirs) == 0:
+            raise ValueError(f"No prompt subdirectories found in {segmentation_base_dir}")
+        elif len(prompt_dirs) == 1:
+            segmentation_dir = prompt_dirs[0]
+            logger.info(f"Auto-detected prompt directory: {segmentation_dir.name}")
+        else:
+            raise ValueError(
+                f"Multiple prompt directories found in {segmentation_base_dir}. "
+                f"Please specify --prompt. Found: {[d.name for d in prompt_dirs]}"
+            )
+
+    images_and_masks_dir = segmentation_dir / "images_and_masks"
+
+    if not images_and_masks_dir.exists():
+        raise FileNotFoundError(
+            f"images_and_masks directory not found: {images_and_masks_dir}"
+        )
+
+    if view_indices is not None and len(view_indices) > 0 and isinstance(view_indices[0], int):
+        image_names = [str(idx) for idx in view_indices]
+    else:
+        image_names = view_indices
+    return load_images_and_masks(images_and_masks_dir, image_names=image_names)
+
+
+def load_images_and_masks_from_path(
+    input_path: Path,
+    mask_prompt: Optional[str] = None,
+    image_names: Optional[List[str]] = None,
+) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+    """
+    Load multi-view data from a specified path (supports two data structures).
+
+    Data structure 1 (mask_prompt=None):
+        input_path/
+        ├── 1.png
+        ├── 1_mask.png
+        ├── 2.png
+        ├── 2_mask.png
+        └── ...
+
+    Data structure 2 (mask_prompt!=None, e.g., mask_prompt="stuffed_toy"):
+        input_path/
+        ├── images/
+        │   ├── 1.png
+        │   ├── 2.png
+        │   └── ...
+        └── stuffed_toy/ (or {mask_prompt}/)
+            ├── 1.png (or 1_mask.png)
+            ├── 2.png (or 2_mask.png)
+            └── ...
+
+    Args:
+        input_path: Input path
+        mask_prompt: Mask folder name; if None, images and masks are in the same directory
+        image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"];
+            if None, auto-detect all
+
+    Returns:
+        images: List of images
+        masks: List of masks
+    """
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input path does not exist: {input_path}")
+
+    if not input_path.is_dir():
+        raise ValueError(f"Input path is not a directory: {input_path}")
+
+    if mask_prompt is None:
+        logger.info(f"Loading from single directory: {input_path}")
+        return load_images_and_masks(input_path, image_names=image_names)
+    else:
+        images_dir = input_path / "images"
+        masks_dir = input_path / mask_prompt
+
+        if not images_dir.exists():
+            raise FileNotFoundError(f"Images directory does not exist: {images_dir}")
+        if not masks_dir.exists():
+            raise FileNotFoundError(f"Mask directory does not exist: {masks_dir}")
+
+        logger.info(f"Loading images from: {images_dir}")
+        logger.info(f"Loading masks from: {masks_dir}")
+
+        if image_names is None:
+            image_files = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
+            image_names = [f.stem for f in image_files]
+            logger.info(f"Auto-detected {len(image_names)} images: {image_names}")
+
+        images = []
+        masks = []
+
+        for image_name in image_names:
+            image_candidates = [
+                images_dir / f"{image_name}.png",
+                images_dir / f"{image_name}.jpg",
+            ]
+
+            mask_candidates = [
+                masks_dir / f"{image_name}.png",
+                masks_dir / f"{image_name}_mask.png",
+                masks_dir / f"{image_name}.jpg",
+                masks_dir / f"{image_name}_mask.jpg",
+            ]
+
+            image_path = None
+            for candidate in image_candidates:
+                if candidate.exists():
+                    image_path = candidate
+                    break
+
+            mask_path = None
+            for candidate in mask_candidates:
+                if candidate.exists():
+                    mask_path = candidate
+                    break
+
+            if image_path is None:
+                logger.warning(f"Image file not found for '{image_name}', skipping")
+                continue
+
+            if mask_path is None:
+                logger.warning(f"Mask file not found for '{image_name}', skipping")
+                continue
+
+            try:
+                image = load_image(image_path)
+                mask = load_mask_from_rgba(mask_path)
+
+                images.append(image)
+                masks.append(mask)
+
+                logger.info(f"Loaded '{image_name}': image={image.shape}, mask={mask.shape}")
+
+            except Exception as e:
+                logger.error(f"Failed to load '{image_name}': {e}")
+                continue
+
+        if len(images) == 0:
+            raise ValueError(f"No valid images and masks found in {input_path}")
+
+        logger.info(f"Successfully loaded {len(images)} images")
+        return images, masks
diff --git a/run_inference.py b/run_inference.py
new file mode 100644
index 0000000..c06df2c
--- /dev/null
+++ b/run_inference.py
@@ -0,0 +1,321 @@
+"""
+SAM 3D Objects inference script.
+Supports both single-view and multi-view 3D reconstruction.
+
+Usage:
+    # Multi-view inference (mask_prompt=None, images and masks in same directory, use all images)
+    python run_inference.py --input_path ./data/images_and_masks
+
+    # Single-view inference (specify a single image name)
+    python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+    # Multi-view inference (mask_prompt!=None, images in images/, masks in specified folder)
+    python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+    # Specify multiple image names (can be any filename without extension)
+    python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+"""
+import sys
+import argparse
+from pathlib import Path
+from typing import List, Optional
+
+from loguru import logger
+
+# Make the notebook inference helpers importable
+sys.path.append("notebook")
+from inference import Inference
+from load_images_and_masks import load_images_and_masks_from_path
+
+
+def parse_image_names(image_names_str: Optional[str]) -> Optional[List[str]]:
+    """
+    Parse a comma-separated image names string.
+
+    Args:
+        image_names_str: Image names string, e.g., "image1,view_a" or "1,2" or "image1".
+            Entries can be any filename (without extension) or numbers.
+
+    Returns:
+        image_names: List of image names (without extension); None means use all available images
+    """
+    if image_names_str is None or image_names_str == "":
+        return None
+
+    names = [x.strip() for x in image_names_str.split(",") if x.strip()]
+    return names if names else None
+
+
+def get_output_dir(
+    input_path: Path,
+    mask_prompt: Optional[str] = None,
+    image_names: Optional[List[str]] = None,
+    is_single_view: bool = False,
+) -> Path:
+    """
+    Create an output directory based on the input path and parameters.
+
+    Args:
+        input_path: Input path
+        mask_prompt: Mask folder name (if using the separated directory structure)
+        image_names: List of image names
+        is_single_view: Whether this is single-view inference
+
+    Returns:
+        output_dir: Path to the visualization/{mask_prompt_or_dirname}_{image_names}/ directory
+    """
+    visualization_dir = Path("visualization")
+    visualization_dir.mkdir(exist_ok=True)
+
+    if mask_prompt:
+        dir_name = mask_prompt
+    else:
+        dir_name = input_path.name if input_path.is_dir() else input_path.parent.name
+
+    if is_single_view:
+        if image_names and len(image_names) == 1:
+            safe_name = image_names[0].replace("/", "_").replace("\\", "_")
+            dir_name = f"{dir_name}_{safe_name}"
+        else:
+            dir_name = f"{dir_name}_single"
+    elif image_names:
+        if len(image_names) == 1:
+            safe_name = image_names[0].replace("/", "_").replace("\\", "_")
+            dir_name = f"{dir_name}_{safe_name}"
+        else:
+            safe_names = [name.replace("/", "_").replace("\\", "_") for name in image_names]
+            dir_name = f"{dir_name}_{'_'.join(safe_names[:3])}"
+            if len(safe_names) > 3:
+                dir_name += f"_and_{len(safe_names)-3}_more"
+    else:
+        dir_name = f"{dir_name}_multiview"
+
+    output_dir = visualization_dir / dir_name
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    logger.info(f"Output directory: {output_dir}")
+    return output_dir
+
+
+def run_inference(
+    input_path: Path,
+    mask_prompt: Optional[str] = None,
+    image_names: Optional[List[str]] = None,
+    seed: int = 42,
+    stage1_steps: int = 50,
+    stage2_steps: int = 25,
+    decode_formats: Optional[List[str]] = None,
+    model_tag: str = "hf",
+):
+    """
+    Run inference.
+
+    Args:
+        input_path: Input path
+        mask_prompt: Mask folder name; if None, images and masks are in the same directory
+        image_names: List of image names (without extension), e.g., ["image1", "view_a"] or ["1", "2"];
+            None means use all available images
+        seed: Random seed
+        stage1_steps: Stage 1 inference steps
+        stage2_steps: Stage 2 inference steps
+        decode_formats: List of decode formats
+        model_tag: Model tag
+    """
+    config_path = f"checkpoints/{model_tag}/pipeline.yaml"
+    if not Path(config_path).exists():
+        raise FileNotFoundError(f"Model config file not found: {config_path}")
+
+    logger.info(f"Loading model: {config_path}")
+    inference = Inference(config_path, compile=False)
+
+    if hasattr(inference._pipeline, 'rendering_engine'):
+        if inference._pipeline.rendering_engine != "pytorch3d":
+            logger.warning(f"Rendering engine is set to {inference._pipeline.rendering_engine}, changing to pytorch3d")
+            inference._pipeline.rendering_engine = "pytorch3d"
+
+    logger.info(f"Loading data: {input_path}")
+    if mask_prompt:
+        logger.info(f"Mask prompt: {mask_prompt} (images in images/, masks in {mask_prompt}/)")
+    else:
+        logger.info("Mask prompt: None (images and masks in same directory)")
+
+    view_images, view_masks = load_images_and_masks_from_path(
+        input_path=input_path,
+        mask_prompt=mask_prompt,
+        image_names=image_names,
+    )
+
+    num_views = len(view_images)
+    logger.info(f"Successfully loaded {num_views} views")
+
+    is_single_view = num_views == 1
+
+    if is_single_view:
+        logger.info("Single-view inference mode")
+        image = view_images[0]
+        mask = view_masks[0] if view_masks else None
+        result = inference(image, mask, seed=seed)
+    else:
+        logger.info("Multi-view inference mode")
+        decode_formats = decode_formats or ["gaussian", "mesh"]
+
+        result = inference._pipeline.run_multi_view(
+            view_images=view_images,
+            view_masks=view_masks,
+            seed=seed,
+            mode="multidiffusion",
+            stage1_inference_steps=stage1_steps,
+            stage2_inference_steps=stage2_steps,
+            decode_formats=decode_formats,
+            with_mesh_postprocess=False,
+            with_texture_baking=False,
+            use_vertex_color=True,
+        )
+
+    output_dir = get_output_dir(input_path, mask_prompt, image_names, is_single_view)
+    saved_files = []
+
+    print(f"\n{'='*60}")
+    print("Inference completed!")
+    print(f"Generated coordinates: {result['coords'].shape[0] if 'coords' in result else 'N/A'}")
+    print(f"{'='*60}")
+
+    if 'glb' in result and result['glb'] is not None:
+        output_path = output_dir / "result.glb"
+        result['glb'].export(str(output_path))
+        saved_files.append("result.glb")
+        print(f"✓ GLB file saved to: {output_path}")
+
+    if 'gs' in result:
+        output_path = output_dir / "result.ply"
+        result['gs'].save_ply(str(output_path))
+        saved_files.append("result.ply")
+        print(f"✓ Gaussian Splatting (PLY) saved to: {output_path}")
+    elif 'gaussian' in result:
+        if isinstance(result['gaussian'], list) and len(result['gaussian']) > 0:
+            output_path = output_dir / "result.ply"
+            result['gaussian'][0].save_ply(str(output_path))
+            saved_files.append("result.ply")
+            print(f"✓ Gaussian Splatting (PLY) saved to: {output_path}")
+
+    if 'mesh' in result:
+        print("✓ Mesh information generated (included in GLB)")
+
+    print(f"\n{'='*60}")
+    print(f"All output files saved to: {output_dir}")
+    print(f"Saved files: {', '.join(saved_files)}")
+    print(f"{'='*60}")
+
+    print("\nFile descriptions:")
+    print("- PLY file: Gaussian Splatting format with position and color information")
+    print("  * Recommended to use specialized Gaussian Splatting viewers")
+    print("- GLB file: Complete 3D mesh model; can be viewed in Blender, Three.js, etc.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="SAM 3D Objects inference script - supports single-view and multi-view 3D reconstruction",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+    # Multi-view inference (mask_prompt=None, images and masks in same directory, use all images)
+    python run_inference.py --input_path ./data/images_and_masks
+
+    # Single-view inference (specify a single image name)
+    python run_inference.py --input_path ./data/images_and_masks --image_names image1
+
+    # Multi-view inference (mask_prompt!=None, images in images/, masks in specified folder)
+    python run_inference.py --input_path ./data --mask_prompt stuffed_toy
+
+    # Specify multiple image names (can be any filename without extension)
+    python run_inference.py --input_path ./data --mask_prompt stuffed_toy --image_names image1,view_a,2
+        """
+    )
+
+    parser.add_argument(
+        "--input_path",
+        type=str,
+        required=True,
+        help="Input path. If mask_prompt=None, images and masks are in this directory; "
+             "if mask_prompt!=None, images are in input_path/images/, masks in input_path/{mask_prompt}/"
+    )
+    parser.add_argument(
+        "--mask_prompt",
+        type=str,
+        default=None,
+        help="Mask folder name. If None, images and masks are in the same directory "
+             "(naming format: xxxx.png and xxxx_mask.png); "
+             "if not None, images are in input_path/images/, masks in input_path/{mask_prompt}/"
+    )
+    parser.add_argument(
+        "--image_names",
+        type=str,
+        default=None,
+        help="Image names (without extension), e.g., 'image1,view_a' or '1,2' or 'image1'. "
+             "Can specify multiple, comma-separated. If not specified, use all available images"
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed (default: 42)"
+    )
+    parser.add_argument(
+        "--stage1_steps",
+        type=int,
+        default=50,
+        help="Stage 1 inference steps (default: 50)"
+    )
+    parser.add_argument(
+        "--stage2_steps",
+        type=int,
+        default=25,
+        help="Stage 2 inference steps (default: 25)"
+    )
+
+    parser.add_argument(
+        "--decode_formats",
+        type=str,
+        default="gaussian,mesh",
+        help="Decode formats, comma-separated, e.g., 'gaussian,mesh' or 'gaussian' (default: gaussian,mesh)"
+    )
+
+    parser.add_argument(
+        "--model_tag",
+        type=str,
+        default="hf",
+        help="Model tag (default: hf)"
+    )
+
+    args = parser.parse_args()
+
+    input_path = Path(args.input_path)
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input path does not exist: {input_path}")
+
+    image_names = parse_image_names(args.image_names)
+
+    decode_formats = [fmt.strip() for fmt in args.decode_formats.split(",") if fmt.strip()]
+    if not decode_formats:
+        decode_formats = ["gaussian", "mesh"]
+
+    try:
+        run_inference(
+            input_path=input_path,
+            mask_prompt=args.mask_prompt,
+            image_names=image_names,
+            seed=args.seed,
+            stage1_steps=args.stage1_steps,
+            stage2_steps=args.stage2_steps,
+            decode_formats=decode_formats,
+            model_tag=args.model_tag,
+        )
+    except Exception as e:
+        logger.error(f"Inference failed: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py b/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
index de7a4da..1f60ace 100644
--- a/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
+++ b/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py
@@ -615,12 +615,17 @@ def to_glb(
 
     if with_mesh_postprocess:
         # mesh postprocess
+        # If rendering_engine is pytorch3d, disable fill_holes because it requires nvdiffrast
+        effective_fill_holes = fill_holes and rendering_engine == "nvdiffrast"
+        if fill_holes and rendering_engine == "pytorch3d":
+            logger.warning("fill_holes is disabled because rendering_engine is 'pytorch3d' (requires nvdiffrast)")
+
         vertices, faces = postprocess_mesh(
             vertices,
             faces,
             simplify=simplify > 0,
             simplify_ratio=simplify,
-            fill_holes=fill_holes,
+            fill_holes=effective_fill_holes,
             fill_holes_max_hole_size=fill_holes_max_size,
             fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
             fill_holes_resolution=1024,
diff --git a/sam3d_objects/pipeline/inference_pipeline.py b/sam3d_objects/pipeline/inference_pipeline.py
index e4b81fe..15cb0ef 100644
--- a/sam3d_objects/pipeline/inference_pipeline.py
+++ b/sam3d_objects/pipeline/inference_pipeline.py
@@ -20,7 +20,7 @@ def set_attention_backend():
 
 set_attention_backend()
 
-from typing import List, Union
+from typing import List, Union, Optional, Literal
 from hydra.utils import instantiate
 from omegaconf import OmegaConf
 import numpy as np
@@ -842,3 +842,393 @@ def _get_dtype(dtype):
         return torch.float32
     else:
         raise NotImplementedError
+
+    def get_multi_view_condition_input(
+        self,
+        condition_embedder,
+        view_input_dicts: List[dict],
+        input_mapping,
+    ):
+        """
+        Prepare conditioning for multi-view inputs.
+
+        Args:
+            condition_embedder: Condition embedder
+            view_input_dicts: List of per-view input dicts
+            input_mapping: Input key mapping
+
+        Returns:
+            condition_args: Condition args (condition tokens for all views)
+            condition_kwargs: Condition keyword args
+        """
+        # Extract the condition for each view separately
+        view_conditions = []
+        for view_input_dict in view_input_dicts:
+            condition_args = self.map_input_keys(view_input_dict, input_mapping)
+            condition_kwargs = {
+                k: v for k, v in view_input_dict.items() if k not in input_mapping
+            }
+            embedded_cond, _, _ = self.embed_condition(
+                condition_embedder, *condition_args, **condition_kwargs
+            )
+            if embedded_cond is not None:
+                view_conditions.append(embedded_cond)
+            else:
+                # If nothing was embedded, fall back to the raw condition args
+                view_conditions.append(condition_args)
+
+        # Stack the conditions from all views together
+        # Shape: (num_views, batch_size, num_tokens, dim)
+        if isinstance(view_conditions[0], torch.Tensor):
+            # Tensors are stacked
+            all_conditions = torch.stack(view_conditions, dim=0)
+        else:
+            # Anything else is kept as a list
+            all_conditions = view_conditions
+
+        return (all_conditions,), {}
+
+    def sample_sparse_structure_multi_view(
+        self,
+        view_ss_input_dicts: List[dict],
+        inference_steps=None,
+        use_distillation=False,
+        mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+    ):
+        """
+        Multi-view sparse structure generation.
+
+        Args:
+            view_ss_input_dicts: List of per-view input dicts
+            inference_steps: Number of inference steps
+            use_distillation: Whether to use distillation
+            mode: 'stochastic' or 'multidiffusion'
+        """
+        from sam3d_objects.pipeline.multi_view_utils import inject_generator_multi_view
+
+        ss_generator = self.models["ss_generator"]
+        ss_decoder = self.models["ss_decoder"]
+        num_views = len(view_ss_input_dicts)
+
+        if use_distillation:
+            ss_generator.no_shortcut = False
+            ss_generator.reverse_fn.strength = 0
+            ss_generator.reverse_fn.strength_pm = 0
+        else:
+            ss_generator.no_shortcut = True
+            ss_generator.reverse_fn.strength = self.ss_cfg_strength
+            ss_generator.reverse_fn.strength_pm = self.ss_cfg_strength_pm
+
+        prev_inference_steps = ss_generator.inference_steps
+        if inference_steps:
+            ss_generator.inference_steps = inference_steps
+
+        image = view_ss_input_dicts[0]["image"]
+        bs = image.shape[0]
+        logger.info(
+            f"Sampling sparse structure with {num_views} views: "
+            f"inference_steps={ss_generator.inference_steps}, mode={mode}"
+        )
+
+        with torch.no_grad():
+            with torch.autocast(device_type="cuda", dtype=self.shape_model_dtype):
+                if self.is_mm_dit():
+                    latent_shape_dict = {
+                        k: (bs,) + (v.pos_emb.shape[0], v.input_layer.in_features)
+                        for k, v in ss_generator.reverse_fn.backbone.latent_mapping.items()
+                    }
+                    logger.info(f"[Stage 1] Latent shape (MM-DiT): {latent_shape_dict}")
+                else:
+                    latent_shape_dict = (bs,) + (4096, 8)
+                    logger.info(f"[Stage 1] Latent shape: {latent_shape_dict}")
+
+                # Prepare the multi-view conditions
+                condition_args, condition_kwargs = self.get_multi_view_condition_input(
+                    self.condition_embedders["ss_condition_embedder"],
+                    view_ss_input_dicts,
+                    self.ss_condition_input_mapping,
+                )
+
+                # Inject multi-view support
+                with inject_generator_multi_view(
+                    ss_generator,
+                    num_views=num_views,
+                    num_steps=ss_generator.inference_steps,
+                    mode=mode,
+                ):
+                    return_dict = ss_generator(
+                        latent_shape_dict,
+                        image.device,
+                        *condition_args,
+                        **condition_kwargs,
+                    )
+
+                if not self.is_mm_dit():
+                    return_dict = {"shape": return_dict}
+
+                shape_latent = return_dict["shape"]
+                logger.info(f"[Stage 1 Multi-view] Generated shape_latent shape: {shape_latent.shape}")
+                ss = ss_decoder(
+                    shape_latent.permute(0, 2, 1)
+                    .contiguous()
+                    .view(shape_latent.shape[0], 8, 16, 16, 16)
+                )
+                logger.info(f"[Stage 1 Multi-view] Decoded sparse structure shape: {ss.shape}")
+                coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()
+                logger.info(f"[Stage 1 Multi-view] Extracted coords shape: {coords.shape}")
+
+                # downsample output
+                return_dict["coords_original"] = coords
+                original_shape = coords.shape
+                if self.downsample_ss_dist > 0:
+                    coords = prune_sparse_structure(
+                        coords,
+                        max_neighbor_axes_dist=self.downsample_ss_dist,
+                    )
+                coords, downsample_factor = downsample_sparse_structure(coords)
+                logger.info(
+                    f"[Stage 1 Multi-view] Downsampled coords from {original_shape[0]} to {coords.shape[0]}"
+                )
+                return_dict["coords"] = coords
+                return_dict["downsample_factor"] = downsample_factor
+
+        ss_generator.inference_steps = prev_inference_steps
+        return return_dict
+
+    def sample_slat_multi_view(
+        self,
+        view_slat_input_dicts: List[dict],
+        coords: torch.Tensor,
+        inference_steps=25,
+        use_distillation=False,
+        mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+    ) -> sp.SparseTensor:
+        """
+        Multi-view structured latent generation.
+
+        Args:
+            view_slat_input_dicts: List of per-view input dicts
+            coords: Coordinates (obtained from Stage 1)
+            inference_steps: Number of inference steps
+            use_distillation: Whether to use distillation
+            mode: 'stochastic' or 'multidiffusion'
+        """
+        from sam3d_objects.pipeline.multi_view_utils import inject_generator_multi_view
+
+        image = view_slat_input_dicts[0]["image"]
+        DEVICE = image.device
+        slat_generator = self.models["slat_generator"]
+        num_views = len(view_slat_input_dicts)
+        latent_shape = (image.shape[0],) + (coords.shape[0], 8)
+        logger.info(f"[Stage 2] Coords shape: {coords.shape}")
+        logger.info(f"[Stage 2] Latent shape: {latent_shape}")
+        prev_inference_steps = slat_generator.inference_steps
+        if inference_steps:
+            slat_generator.inference_steps = inference_steps
+        if use_distillation:
+            slat_generator.no_shortcut = False
+            slat_generator.reverse_fn.strength = 0
+        else:
+            slat_generator.no_shortcut = True
+            slat_generator.reverse_fn.strength = self.slat_cfg_strength
+
+        logger.info(
+            f"Sampling sparse latent with {num_views} views: "
+            f"inference_steps={slat_generator.inference_steps}, mode={mode}"
+        )
+
+        with torch.autocast(device_type="cuda", dtype=self.dtype):
+            with torch.no_grad():
+                # Prepare the multi-view conditions
+                condition_args, condition_kwargs = self.get_multi_view_condition_input(
+                    self.condition_embedders["slat_condition_embedder"],
+                    view_slat_input_dicts,
+                    self.slat_condition_input_mapping,
+                )
+                condition_args += (coords.cpu().numpy(),)
+
+                # Inject multi-view support
+                with inject_generator_multi_view(
+                    slat_generator,
+                    num_views=num_views,
+                    num_steps=slat_generator.inference_steps,
+                    mode=mode,
+                ):
+                    slat = slat_generator(
+                        latent_shape, DEVICE, *condition_args, **condition_kwargs
+                    )
+
+                logger.info(f"[Stage 2] Generated slat shape (before SparseTensor): {slat[0].shape if isinstance(slat, (list, tuple)) else slat.shape}")
+                slat = sp.SparseTensor(
+                    coords=coords,
+                    feats=slat[0],
+                ).to(DEVICE)
+                slat = slat * self.slat_std.to(DEVICE) + self.slat_mean.to(DEVICE)
+                logger.info(f"[Stage 2] Final slat: coords={slat.coords.shape}, feats={slat.feats.shape}")
+
+        slat_generator.inference_steps = prev_inference_steps
+        return slat
+
+    def run_multi_view(
+        self,
+        view_images: List[Union[np.ndarray, Image.Image]],
+        view_masks: List[Optional[Union[None, np.ndarray, Image.Image]]] = None,
+        num_samples: int = 1,
+        seed: Optional[int] = None,
+        stage1_inference_steps: Optional[int] = None,
+        stage2_inference_steps: Optional[int] = None,
+        use_stage1_distillation: bool = False,
+        use_stage2_distillation: bool = False,
+        decode_formats: Optional[List[str]] = None,
+        with_mesh_postprocess: bool = True,
+        with_texture_baking: bool = True,
+        use_vertex_color: bool = False,
+        stage1_only: bool = False,
+        mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+    ) -> dict:
+        """
+        Main entry point for multi-view inference.
+
+        Args:
+            view_images: List of per-view images
+            view_masks: List of per-view masks (optional)
+            num_samples: Number of samples to generate
+            seed: Random seed
+            stage1_inference_steps: Stage 1 inference steps
+            stage2_inference_steps: Stage 2 inference steps
+            use_stage1_distillation: Whether to use Stage 1 distillation
+            use_stage2_distillation: Whether to use Stage 2 distillation
+            decode_formats: Decode formats
+            with_mesh_postprocess: Whether to run mesh postprocessing
+            with_texture_baking: Whether to run texture baking
+            use_vertex_color: Whether to use vertex colors
+            stage1_only: Whether to run only Stage 1
+            mode: 'stochastic' or 'multidiffusion'
+        """
+        num_views = len(view_images)
+        if view_masks is None:
+            view_masks = [None] * num_views
+        assert len(view_masks) == num_views, "Number of masks must match number of images"
+
+        if seed is not None:
+            torch.manual_seed(seed)
+
+        logger.info(f"Running multi-view inference with {num_views} views, mode={mode}")
+
+        # Preprocess each view.
+        # Note: the mask must first be merged into the image's alpha channel before calling preprocess_image.
+        view_ss_input_dicts = []
+        view_slat_input_dicts = []
+        for i, (image, mask) in enumerate(zip(view_images, view_masks)):
+            logger.info(f"Preprocessing view {i+1}/{num_views}")
+
+            # Merge the mask into the image's alpha channel (RGBA format).
+            # If the image is already RGBA (mask carried in its alpha channel), mask may be None.
+            if mask is not None:
+                # Ensure image and mask are numpy arrays (also converts PIL images)
+                image = np.array(image)
+                mask = np.array(mask)
+
+                # If the mask is bool, convert it to uint8
+                if mask.dtype == bool:
+                    mask = mask.astype(np.uint8) * 255
+                elif mask.dtype != np.uint8:
+                    # If the mask is a float in [0, 1], scale it to [0, 255]
+                    if mask.max() <= 1.0:
+                        mask = (mask * 255).astype(np.uint8)
+                    else:
+                        mask = mask.astype(np.uint8)
+
+                if mask.ndim == 2:
+                    mask = mask[..., None]
+
+                # Merge the mask into the alpha channel
+                if image.shape[-1] == 3:  # RGB
+                    rgba_image = np.concatenate([image, mask], axis=-1).astype(np.uint8)
+                elif image.shape[-1] == 4:  # already RGBA, replace the alpha channel
+                    rgba_image = np.concatenate([image[..., :3], mask], axis=-1).astype(np.uint8)
+                else:
+                    raise ValueError(f"Unexpected image shape: {image.shape}")
+            else:
+                # Without a mask, assume the image is already RGBA
+                rgba_image = np.array(image)
+
+            # Convert to a PIL Image (required by preprocess_image)
+            rgba_image_pil = Image.fromarray(rgba_image)
+
+            # Call preprocess_image (note: InferencePipelinePointMap needs a pointmap),
+            # so first check whether this pipeline is an InferencePipelinePointMap
+            if hasattr(self, 'compute_pointmap'):
+                # This is an InferencePipelinePointMap; compute the pointmap
+                pointmap_dict = self.compute_pointmap(rgba_image_pil, pointmap=None)
+                pointmap = pointmap_dict["pointmap"]
+                ss_input_dict = self.preprocess_image(
+                    rgba_image_pil, self.ss_preprocessor, pointmap=pointmap
+                )
+                slat_input_dict = self.preprocess_image(
+                    rgba_image_pil, self.slat_preprocessor
+                )
+            else:
+                # This is a plain InferencePipeline; no pointmap needed
+                ss_input_dict = self.preprocess_image(
+                    rgba_image_pil, self.ss_preprocessor
+                )
+                slat_input_dict = self.preprocess_image(
+                    rgba_image_pil, self.slat_preprocessor
+                )
+
+            view_ss_input_dicts.append(ss_input_dict)
+            view_slat_input_dicts.append(slat_input_dict)
+
+        # Stage 1: generate the sparse structure
+        logger.info("Stage 1: Sampling sparse structure...")
+        ss_return_dict = self.sample_sparse_structure_multi_view(
+            view_ss_input_dicts,
+            inference_steps=stage1_inference_steps,
+            use_distillation=use_stage1_distillation,
+            mode=mode,
+        )
+
+        ss_return_dict.update(self.pose_decoder(ss_return_dict))
+
+        if "scale" in ss_return_dict:
+            logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']}")
+            ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"]
+
+        if stage1_only:
+            logger.info("Finished!")
+            ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5
+            return ss_return_dict
+
+        # Stage 2: generate the structured latent
+        coords = ss_return_dict["coords"]
+        logger.info("Stage 2: Sampling structured latent...")
+        slat = self.sample_slat_multi_view(
+            view_slat_input_dicts,
+            coords,
+            inference_steps=stage2_inference_steps,
+            use_distillation=use_stage2_distillation,
+            mode=mode,
+        )
+
+        # Decode
+        outputs = self.decode_slat(
+            slat, self.decode_formats if decode_formats is None else decode_formats
+        )
+        outputs = self.postprocess_slat_output(
+            outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
+        )
+        logger.info("Finished!")
+
+        return {
+            **ss_return_dict,
+            **outputs,
+        }
diff --git a/sam3d_objects/pipeline/multi_view_utils.py b/sam3d_objects/pipeline/multi_view_utils.py
new file mode 100644
index 0000000..1f56613
--- /dev/null
+++ b/sam3d_objects/pipeline/multi_view_utils.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+"""
+Multi-view multidiffusion utilities for SAM 3D Objects.
+Adapted from the TRELLIS implementation to fit SAM 3D Objects' two-stage structure.
+"""
+from contextlib import contextmanager
+from typing import List, Literal
+
+import torch
+from loguru import logger
+
+
+@contextmanager
+def inject_generator_multi_view(
+    generator,
+    num_views: int,
+    num_steps: int,
+    mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
+):
+    """
+    Inject multi-view support into a generator.
+
+    Args:
+        generator: SAM 3D Objects generator (ss_generator or slat_generator)
+        num_views: Number of views
+        num_steps: Number of inference steps
+        mode: 'stochastic' or 'multidiffusion'
+
+    Yields:
+        None
+    """
+    original_dynamics = generator._generate_dynamics
+
+    if mode == 'stochastic':
+        if num_views > num_steps:
+            logger.warning(
+                f"Number of views ({num_views}) is greater than number of steps ({num_steps}). "
+                "This may lead to performance degradation."
+            )
+
+        cond_indices = (torch.arange(num_steps) % num_views).tolist()
+        cond_idx_counter = [0]
+
+        def _new_dynamics_stochastic(x_t, t, *args_conditionals, **kwargs_conditionals):
+            """Stochastic mode: select one view per time step."""
+            cond_idx = cond_indices[cond_idx_counter[0] % len(cond_indices)]
+            cond_idx_counter[0] += 1
+
+            if len(args_conditionals) > 0:
+                cond_tokens = args_conditionals[0]
+                if isinstance(cond_tokens, (list, tuple)):
+                    cond_i = cond_tokens[cond_idx:cond_idx+1] if isinstance(cond_tokens[0], torch.Tensor) else [cond_tokens[cond_idx]]
+                    new_args = (cond_i,) + args_conditionals[1:]
+                elif isinstance(cond_tokens, torch.Tensor) and cond_tokens.shape[0] == num_views:
+                    cond_i = cond_tokens[cond_idx:cond_idx+1]
+                    new_args = (cond_i,) + args_conditionals[1:]
+                else:
+                    new_args = args_conditionals
+            else:
+                new_args = args_conditionals
+
+            return original_dynamics(x_t, t, *new_args, **kwargs_conditionals)
+
+        generator._generate_dynamics = _new_dynamics_stochastic
+
+    elif mode == 'multidiffusion':
+        def _new_dynamics_multidiffusion(x_t, t, *args_conditionals, **kwargs_conditionals):
+            """Multidiffusion mode: fuse predictions from all views at each time step."""
+            cond_idx = 0
+            if len(args_conditionals) > 0:
+                if isinstance(args_conditionals[0], (int, float)) or (isinstance(args_conditionals[0], torch.Tensor) and args_conditionals[0].numel() == 1):
+                    cond_idx = 1
+
+            if len(args_conditionals) > cond_idx:
+                cond_tokens = args_conditionals[cond_idx]
+
+                if not hasattr(_new_dynamics_multidiffusion, '_logged_cond_shape'):
+                    logger.info(f"[Multidiffusion] args_conditionals length: {len(args_conditionals)}")
+                    logger.info(f"[Multidiffusion] cond_idx: {cond_idx}")
+                    if isinstance(cond_tokens, torch.Tensor):
+                        logger.info(f"[Multidiffusion] Condition tokens shape: {cond_tokens.shape}")
+                    elif isinstance(cond_tokens, (list, tuple)):
+                        logger.info(f"[Multidiffusion] Condition tokens type: {type(cond_tokens)}, length: {len(cond_tokens)}")
+                        if len(cond_tokens) > 0 and isinstance(cond_tokens[0], torch.Tensor):
+                            logger.info(f"[Multidiffusion] First condition token shape: {cond_tokens[0].shape}")
+                    else:
+                        logger.info(f"[Multidiffusion] Condition tokens type: {type(cond_tokens)}")
+                    _new_dynamics_multidiffusion._logged_cond_shape = True
+
+                if isinstance(cond_tokens, (list, tuple)):
+                    view_conditions = cond_tokens
+                elif isinstance(cond_tokens, torch.Tensor) and cond_tokens.shape[0] == num_views:
+                    view_conditions = []
+                    for i in range(num_views):
+                        view_cond = cond_tokens[i]
+                        view_conditions.append(view_cond)
+                else:
+                    logger.warning(f"Condition tokens shape {cond_tokens.shape if isinstance(cond_tokens, torch.Tensor) else type(cond_tokens)} not organized by views, using same condition for all views")
+                    view_conditions = [cond_tokens] * num_views
+
+                preds = []
+                for view_idx in range(num_views):
+                    view_cond = view_conditions[view_idx]
+                    if cond_idx < len(args_conditionals):
+                        new_args = args_conditionals[:cond_idx] + (view_cond,) + args_conditionals[cond_idx+1:]
+                    else:
+                        new_args = args_conditionals + (view_cond,)
+                    pred = original_dynamics(x_t, t, *new_args, **kwargs_conditionals)
+                    preds.append(pred)
+
+                if not hasattr(_new_dynamics_multidiffusion, '_logged_shape'):
+                    if isinstance(x_t, dict):
+                        logger.info(f"[Multidiffusion] Latent shape (dict): {[(k, v.shape if isinstance(v, torch.Tensor) else type(v)) for k, v in x_t.items()]}")
+                    elif isinstance(x_t, (list, tuple)):
+                        logger.info(f"[Multidiffusion] Latent shape (tuple/list): {[v.shape if isinstance(v, torch.Tensor) else type(v) for v in x_t]}")
+                    else:
+                        logger.info(f"[Multidiffusion] Latent shape: {x_t.shape if isinstance(x_t, torch.Tensor) else type(x_t)}")
+
+                    if isinstance(preds[0], dict):
+                        logger.info(f"[Multidiffusion] Pred shape (dict): {[(k, v.shape if isinstance(v, torch.Tensor) else type(v)) for k, v in preds[0].items()]}")
+                    elif isinstance(preds[0], (list, tuple)):
+                        logger.info(f"[Multidiffusion] Pred shape (tuple/list): {[v.shape if isinstance(v, torch.Tensor) else type(v) for v in preds[0]]}")
+                    else:
+                        logger.info(f"[Multidiffusion] Pred shape: {preds[0].shape if isinstance(preds[0], torch.Tensor) else type(preds[0])}")
+                    logger.info(f"[Multidiffusion] Number of views: {num_views}, fusing {len(preds)} predictions")
+                    _new_dynamics_multidiffusion._logged_shape = True
+
+                if isinstance(preds[0], dict):
+                    fused_pred = {}
+                    for key in preds[0].keys():
+                        fused_pred[key] = torch.stack([p[key] for p in preds]).mean(dim=0)
+                    return fused_pred
+                elif isinstance(preds[0], (list, tuple)):
+                    fused_pred = tuple(
+                        torch.stack([p[i] for p in preds]).mean(dim=0)
+                        for i in range(len(preds[0]))
+                    )
+                    return fused_pred
+                else:
+                    fused_pred = torch.stack(preds).mean(dim=0)
+                    return fused_pred
+            else:
+                return original_dynamics(x_t, t, *args_conditionals, **kwargs_conditionals)
+
+        generator._generate_dynamics = _new_dynamics_multidiffusion
+
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+    try:
+        yield
+    finally:
+        generator._generate_dynamics = original_dynamics
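For intuition, the multidiffusion path above reduces to a simple rule: at each denoising step the same shared latent is denoised once per view-specific condition, and the per-view predictions are averaged before the next step. A minimal self-contained sketch of that fusion rule (the `denoise_fn`, shapes, and step schedule here are toy illustrations, not the repository's API; the real step function is `generator._generate_dynamics`):

```python
# Toy illustration of the multidiffusion fusion used above:
# one latent, V view conditions, predictions averaged at every step.
import torch

def denoise_fn(x_t: torch.Tensor, t: float, cond: torch.Tensor) -> torch.Tensor:
    # Stand-in for a single-view dynamics call (pulls the latent toward the condition).
    return x_t - 0.1 * (x_t - cond)

def multidiffusion_step(x_t: torch.Tensor, t: float, view_conditions) -> torch.Tensor:
    preds = [denoise_fn(x_t, t, c) for c in view_conditions]  # one prediction per view
    return torch.stack(preds).mean(dim=0)                     # fuse by averaging

x_t = torch.randn(1, 16)                        # shared latent
conds = [torch.randn(1, 16) for _ in range(8)]  # eight view conditions
for t in torch.linspace(1.0, 0.0, 50):
    x_t = multidiffusion_step(x_t, float(t), conds)
```

The `stochastic` mode replaces the average with a round-robin choice of a single view per step (`cond_indices` above), which is cheaper per step but, as the warning notes, can degrade when there are more views than steps.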