diff --git a/examples/offline_inference/mammothmodal2_preview/mammoth_moda2_image_summary.yaml b/examples/offline_inference/mammothmodal2_preview/mammoth_moda2_image_summary.yaml new file mode 100644 index 000000000..042e12afb --- /dev/null +++ b/examples/offline_inference/mammothmodal2_preview/mammoth_moda2_image_summary.yaml @@ -0,0 +1,18 @@ +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 16 + engine_args: + model_stage: ar + model_arch: MammothModa2ForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 8192 + gpu_memory_utilization: 0.5 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + enable_prefix_caching: false + final_output: true + final_output_type: text diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summary.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summary.py new file mode 100644 index 000000000..8b4ea9f2d --- /dev/null +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summary.py @@ -0,0 +1,147 @@ +""" +Offline inference example: MammothModa2 image summarization (single AR stage). + +Example: + uv run python examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summary.py \ + --model /data/datasets/models-hf/MammothModa2-Preview \ + --image /path/to/input.jpg \ + --question "Please summarize the content of this image." +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path + +from PIL import Image +from vllm import SamplingParams +from vllm.multimodal.image import convert_image_mode + +from vllm_omni import Omni + +DEFAULT_SYSTEM = "You are a helpful assistant." +DEFAULT_QUESTION = "Please summarize the content of this image." 
+ + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="MammothModa2 image summarization (offline, AR only).") + parser.add_argument( + "--model", + type=str, + default="/data/datasets/models-hf/MammothModa2-Preview", + help="Path to model directory or model id.", + ) + parser.add_argument( + "--stage-config", + type=str, + default=str(Path(__file__).with_name("mammoth_moda2_image_summary.yaml")), + help="Path to stage config yaml (single-stage AR->text).", + ) + parser.add_argument( + "--image", + type=str, + required=True, + help="Path to input image.", + ) + parser.add_argument( + "--question", + type=str, + default=DEFAULT_QUESTION, + help="Question/instruction for the model.", + ) + parser.add_argument( + "--system", + type=str, + default=DEFAULT_SYSTEM, + help="System prompt.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=512, + help="Max new tokens to generate.", + ) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top-p", type=float, default=0.9) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument( + "--out", + type=str, + default="image_summary.txt", + help="Path to save output text.", + ) + return parser.parse_args() + + +def build_prompt(system: str, question: str) -> str: + return ( + f"<|im_start|>system\n{system}<|im_end|>\n" + "<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + +def main() -> None: + args = parse_args() + + if not os.path.exists(args.image): + raise FileNotFoundError(f"Image file not found: {args.image}") + + os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True) + + pil_image = Image.open(args.image) + image_data = convert_image_mode(pil_image, "RGB") + prompt = build_prompt(args.system, args.question) + + omni = Omni( + model=args.model, + stage_configs_path=args.stage_config, + trust_remote_code=args.trust_remote_code, + ) + try: + sp = SamplingParams( + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=-1, + max_tokens=int(args.max_tokens), + seed=int(args.seed), + detokenize=True, + ) + outputs = omni.generate( + [ + { + "prompt": prompt, + "multi_modal_data": {"image": image_data}, + } + ], + [sp], + ) + finally: + omni.close() + + if not isinstance(outputs, list): + outputs = [outputs] + + lines: list[str] = [] + for stage_outputs in outputs: + req_outputs = getattr(stage_outputs, "request_output", stage_outputs) + req_outputs = req_outputs if isinstance(req_outputs, list) else [req_outputs] + for ro in req_outputs: + text = ro.outputs[0].text if getattr(ro, "outputs", None) else str(ro) + lines.append(f"request_id: {getattr(ro, 'request_id', 'unknown')}\n") + lines.append("answer:\n") + lines.append(text.strip() + "\n") + lines.append("\n") + + with open(args.out, "w", encoding="utf-8") as f: + f.writelines(lines) + + print(f"[OK] Saved summary to: {args.out}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py new file mode 100644 index 000000000..b41c4cbf4 --- /dev/null +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py @@ -0,0 +1,254 @@ +""" +Offline inference example for MammothModa2 Text-to-Image (T2I) generation. 
+This script uses the vllm_omni.Omni pipeline with a multi-stage configuration. + +Workflow: +1. Stage 0 (AR): Generates visual tokens and their corresponding hidden states. +2. Stage 1 (DiT): Consumes the hidden states as conditions to perform diffusion + and VAE decoding to produce the final image. + +Example Usage: + uv run python examples/offline_inference/run_mammothmoda2_t2i.py \ + --model /path/to/MammothModa2-Preview \ + --prompt "A stylish woman riding a motorcycle in NYC, movie poster style" \ + --out output.png +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +from pathlib import Path + +import torch +from PIL import Image +from vllm.sampling_params import SamplingParams + +from vllm_omni import Omni + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +def load_t2i_generation_config(model_dir: str) -> tuple[int, int, int]: + """Load T2I token ranges from t2i_generation_config.json.""" + cfg_path = Path(model_dir) / "t2i_generation_config.json" + if not cfg_path.exists(): + raise FileNotFoundError(f"Config not found: {cfg_path}") + + with cfg_path.open("r", encoding="utf-8") as f: + cfg = json.load(f) + + return ( + int(cfg["eol_token_id"]), + int(cfg["visual_token_start_id"]), + int(cfg["visual_token_end_id"]), + ) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Run MammothModa2 T2I (AR -> DiT) with vLLM-Omni.") + p.add_argument( + "--model", + type=str, + default="/data/datasets/models-hf/MammothModa2-Preview", + help="Path to the model directory.", + ) + p.add_argument( + "--stage-config", + type=str, + default="vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml", + help="Path to the multi-stage YAML configuration.", + ) + p.add_argument( + "--prompt", + type=str, + action="append", + default=None, + help=( + "Text prompt for image generation. Can be provided multiple times " + "to generate multiple images with shared height/width/CFG settings." 
+ ), + ) + p.add_argument( + "--height", + type=int, + default=1024, + help="Output image height (must be a multiple of 16).", + ) + p.add_argument( + "--width", + type=int, + default=1024, + help="Output image width (must be a multiple of 16).", + ) + p.add_argument( + "--num-inference-steps", + type=int, + default=50, + help="Number of diffusion steps for the DiT stage.", + ) + p.add_argument( + "--text-guidance-scale", + type=float, + default=9.0, + help="Classifier-Free Guidance (CFG) scale for DiT.", + ) + p.add_argument( + "--cfg-range", + type=float, + nargs=2, + default=(0.0, 1.0), + help="Relative step range [start, end] where CFG is active.", + ) + p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.") + p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.") + args = p.parse_args() + if not args.prompt: + args.prompt = ["A stylish woman with sunglasses riding a motorcycle in NYC."] + return args + + +def tensor_to_pil(image: torch.Tensor) -> Image.Image: + """Convert a normalized torch tensor [-1, 1] to a PIL Image.""" + if image.ndim == 4: + image = image[0] + image = image.detach().to("cpu") + image = (image / 2 + 0.5).clamp(0, 1) + image = (image * 255).to(torch.uint8) + image = image.permute(1, 2, 0).contiguous().numpy() + return Image.fromarray(image) + + +def main() -> None: + args = parse_args() + os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True) + + if args.height <= 0 or args.width <= 0: + raise ValueError(f"Height and width must be positive, got {args.height}x{args.width}") + if args.height % 16 != 0 or args.width % 16 != 0: + raise ValueError(f"Height and width must be multiples of 16, got {args.height}x{args.width}") + + ar_height = args.height // 16 + ar_width = args.width // 16 + + eol_token_id, visual_start, visual_end = load_t2i_generation_config(args.model) + expected_grid_tokens = ar_height * (ar_width + 1) + + def _format_prompt(user_prompt: str) -> str: + return ( + "<|im_start|>system\nYou are a helpful image generator.<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + "<|im_start|>assistant\n" + f"<|image start|>{ar_width}*{ar_height}<|image token|>" + ) + + logger.info("Initializing Omni pipeline...") + omni = Omni(model=args.model, stage_configs_path=args.stage_config, trust_remote_code=args.trust_remote_code) + + try: + ar_sampling = SamplingParams( + temperature=1.0, + top_p=1.0, + top_k=2048, + # +1 for generating eoi, +1 for generating hidden state of eoi + max_tokens=max(1, expected_grid_tokens + 1 + 1), + detokenize=False, + ) + + dit_sampling = SamplingParams( + temperature=0.0, + top_p=1.0, + top_k=-1, + max_tokens=1, + detokenize=False, + ) + + logger.info("Starting generation...") + shared_additional_information = { + "omni_task": ["t2i"], + "ar_width": [ar_width], + "ar_height": [ar_height], + "eol_token_id": [eol_token_id], + "visual_token_start_id": [visual_start], + "visual_token_end_id": [visual_end], + "image_height": [args.height], + "image_width": [args.width], + "num_inference_steps": [args.num_inference_steps], + "text_guidance_scale": [args.text_guidance_scale], + "cfg_range": [args.cfg_range[0], args.cfg_range[1]], + } + inputs = [ + { + "prompt": _format_prompt(p), + "additional_information": dict(shared_additional_information), + } + for p in args.prompt + ] + + outputs = omni.generate(inputs, [ar_sampling, dit_sampling]) + + # `outputs` may contain one or multiple request results. 
+ if not isinstance(outputs, list): + outputs = [outputs] + + logger.info("Post-processing and saving image(s)...") + out_base, out_ext = os.path.splitext(args.out) + saved_paths: list[str] = [] + + # Flatten to (image_tensor, suffix) list so we can decide filenames. + images_to_save: list[tuple[torch.Tensor, str]] = [] + for out_idx, out in enumerate(outputs): + ro = getattr(out, "request_output", out) + ro_list = ro if isinstance(ro, list) else [ro] + if not ro_list: + raise RuntimeError("Empty request_output from final stage.") + + req_id = getattr(out, "request_id", None) + req_suffix = f"_{req_id}" if isinstance(req_id, str) and req_id else f"_{out_idx}" + + for sample_idx, ro_item in enumerate(ro_list): + mm = getattr(ro_item, "multimodal_output", None) + if not isinstance(mm, dict) or "image" not in mm: + raise RuntimeError(f"Unexpected final output payload: {type(mm)} {mm}") + + img_payload = mm["image"] + img_list = img_payload if isinstance(img_payload, list) else [img_payload] + for img_idx, img_tensor in enumerate(img_list): + if not isinstance(img_tensor, torch.Tensor): + raise TypeError(f"Expected image tensor, got {type(img_tensor)}") + suffix_parts = [req_suffix] + if len(ro_list) > 1: + suffix_parts.append(f"_s{sample_idx}") + if len(img_list) > 1: + suffix_parts.append(f"_i{img_idx}") + images_to_save.append((img_tensor, "".join(suffix_parts))) + + # If there's only one image, respect `--out` exactly. + if len(images_to_save) == 1: + img_tensor, _ = images_to_save[0] + pil = tensor_to_pil(img_tensor) + pil.save(args.out) + saved_paths.append(args.out) + else: + if not out_ext: + out_ext = ".png" + for img_tensor, suffix in images_to_save: + out_path = f"{out_base}{suffix}{out_ext}" + pil = tensor_to_pil(img_tensor) + pil.save(out_path) + saved_paths.append(out_path) + + for p in saved_paths: + logger.info(f"Successfully saved generated image to: {p}") + + except Exception as e: + logger.exception(f"An error occurred during generation: {e}") + finally: + omni.close() + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index 4c1ee5388..4ef20533c 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -52,6 +52,43 @@ def _process_tokens( return inputs + def _process_text( + self, + parsed_content: TextPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> OmniTokenInputs | MultiModalInputs: + prompt_text = parsed_content["prompt"] + additional_information = parsed_content.get("additional_information") + + inputs: OmniTokenInputs | MultiModalInputs + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs_omni( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + additional_information=additional_information, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + if additional_information is not None and "additional_information" not in inputs: + inputs["additional_information"] = additional_information + + return inputs + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, diff --git a/vllm_omni/model_executor/models/__init__.py 
b/vllm_omni/model_executor/models/__init__.py index 0b2629b4a..5b5907ccb 100644 --- a/vllm_omni/model_executor/models/__init__.py +++ b/vllm_omni/model_executor/models/__init__.py @@ -1,4 +1,5 @@ +from .mammoth_moda2.config import Mammothmoda2Config # noqa: F401 registers AutoConfig from .qwen3_omni import Qwen3OmniMoeForConditionalGeneration from .registry import OmniModelRegistry # noqa: F401 -__all__ = ["Qwen3OmniMoeForConditionalGeneration"] +__all__ = ["Qwen3OmniMoeForConditionalGeneration", "Mammothmoda2Config"] diff --git a/vllm_omni/model_executor/models/mammoth_moda2/__init__.py b/vllm_omni/model_executor/models/mammoth_moda2/__init__.py new file mode 100644 index 000000000..8e952cbaa --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/__init__.py @@ -0,0 +1,6 @@ +from .config import Mammothmoda2Config # noqa: F401 registers AutoConfig +from .mammoth_moda2_ar import MammothModa2ARForConditionalGeneration + +__all__ = [ + "MammothModa2ARForConditionalGeneration", +] diff --git a/vllm_omni/model_executor/models/mammoth_moda2/config.py b/vllm_omni/model_executor/models/mammoth_moda2/config.py new file mode 100644 index 000000000..407ef5e89 --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/config.py @@ -0,0 +1,286 @@ +from typing import ClassVar, Literal + +from transformers import AutoConfig, AutoTokenizer, PretrainedConfig +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, + Qwen2_5_VLTextConfig, + Qwen2_5_VLVisionConfig, +) + +from .tokenizer import MammothUTokenizer + +__all__ = [ + "Mammothmoda2Config", + "Mammothmoda2Qwen2_5_VLConfig", + "Mammothmoda2Qwen2_5_VLTextConfig", + "Mammothmoda2Qwen2_5_VLVisionConfig", +] + + +class Mammothmoda2Qwen2_5_VLVisionConfig(Qwen2_5_VLVisionConfig): + + model_type = "mammothmoda2_qwen2_5_vl_vision" + base_config_key = "vision_config" + + def __init__( + self, + depth: int = 32, + hidden_size: int = 3584, + hidden_act: str = "silu", + intermediate_size: int = 3420, + num_heads: int = 16, + in_channels: int = 3, + patch_size: int = 14, + spatial_merge_size: int = 2, + temporal_patch_size: int = 2, + tokens_per_second: int = 4, + window_size: int = 112, + out_hidden_size: int = 3584, + fullatt_block_indexes: list[int] | None = None, + initializer_range: float = 0.02, + **kwargs, + ) -> None: + super().__init__( + depth=depth, + hidden_size=hidden_size, + hidden_act=hidden_act, + intermediate_size=intermediate_size, + num_heads=num_heads, + in_channels=in_channels, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + temporal_patch_size=temporal_patch_size, + tokens_per_second=tokens_per_second, + window_size=window_size, + out_hidden_size=out_hidden_size, + fullatt_block_indexes=fullatt_block_indexes or [7, 15, 23, 31], + initializer_range=initializer_range, + **kwargs, + ) + + +class Mammothmoda2Qwen2_5_VLTextConfig(Qwen2_5_VLTextConfig): + + model_type = "mammothmoda2_qwen2_5_vl_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = Qwen2_5_VLTextConfig.base_model_tp_plan + base_model_pp_plan = Qwen2_5_VLTextConfig.base_model_pp_plan + + def __init__( + self, + vocab_size: int = 152064, + hidden_size: int = 8192, + intermediate_size: int = 29568, + num_hidden_layers: int = 80, + num_attention_heads: int = 64, + num_key_value_heads: int | None = 8, + hidden_act: str = "silu", + max_position_embeddings: int = 32768, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-05, + use_cache: bool = 
True, + tie_word_embeddings: bool = False, + rope_theta: float = 1000000.0, + use_sliding_window: bool = False, + sliding_window: int = 4096, + max_window_layers: int = 80, + layer_types: list[str] | None = None, + attention_dropout: float = 0.0, + rope_scaling: dict | None = None, + image_token_id: int | None = None, + video_token_id: int | None = None, + extra_gen_vocab: bool = True, + gen_vocab_size: int = 32800, + gen_vocab_start_index: int | None = None, + moe_type: str = "ffn", + **kwargs, + ) -> None: + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + max_window_layers=max_window_layers, + layer_types=layer_types, + attention_dropout=attention_dropout, + rope_scaling=rope_scaling, + **kwargs, + ) + + self.extra_gen_vocab = extra_gen_vocab + self.gen_vocab_size = gen_vocab_size + self.moe_type = moe_type + if gen_vocab_start_index is None: + self.gen_vocab_start_index = ( + self.vocab_size if self.extra_gen_vocab else self.vocab_size - self.gen_vocab_size + ) + else: + self.gen_vocab_start_index = gen_vocab_start_index + + # NOTE: vLLM V1 uses `hf_text_config.vocab_size` for sampling parameter validation + # (e.g., allowed_token_ids). Although MammothModa2's gen vocab is implemented via + # independent gen_embed/gen_head, the overall vocab size should still cover the + # gen vocab token ID range from the perspective of "output logits dimension". + if self.extra_gen_vocab: + self.vocab_size = int(self.gen_vocab_start_index) + int(self.gen_vocab_size) + + # Extra token IDs for multi-modal placeholders. 
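+        # They default to None at this level; the composite Mammothmoda2Qwen2_5_VLConfig below
+        # uses the usual Qwen2.5-VL IDs (image 151655, video 151656) as its own defaults.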
+ self.image_token_id = image_token_id + self.video_token_id = video_token_id + +class Mammothmoda2Qwen2_5_VLConfig(Qwen2_5_VLConfig): + """Combined configuration: text_config + vision_config.""" + + model_type = "mammothmoda2_qwen2_5_vl" + sub_configs = { + "vision_config": Mammothmoda2Qwen2_5_VLVisionConfig, + "text_config": Mammothmoda2Qwen2_5_VLTextConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config: dict | PretrainedConfig | None = None, + vision_config: dict | PretrainedConfig | None = None, + image_token_id: int = 151655, + video_token_id: int = 151656, + vision_start_token_id: int = 151652, + vision_end_token_id: int = 151653, + extra_gen_vocab: bool = True, + gen_vocab_size: int = 32800, + gen_vocab_start_index: int | None = None, + moe_type: str = "ffn", + **kwargs, + ) -> None: + if isinstance(vision_config, dict): + vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + vision_config = self.sub_configs["vision_config"]() + + text_extra_kwargs = { + "extra_gen_vocab": extra_gen_vocab, + "gen_vocab_size": gen_vocab_size, + "moe_type": moe_type, + "gen_vocab_start_index": gen_vocab_start_index, + } + if isinstance(text_config, dict): + for key, val in text_extra_kwargs.items(): + text_config.setdefault(key, val) + text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + text_config = self.sub_configs["text_config"](**{**text_extra_kwargs, **kwargs}) + elif isinstance(text_config, PretrainedConfig): + text_config = text_config + + super().__init__( + text_config=text_config, + vision_config=vision_config, + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + **kwargs, + ) + + if not hasattr(self, "text_config"): + self.text_config = text_config + if not hasattr(self, "vision_config"): + self.vision_config = vision_config + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.extra_gen_vocab = getattr(self.text_config, "extra_gen_vocab", extra_gen_vocab) + self.gen_vocab_size = getattr(self.text_config, "gen_vocab_size", gen_vocab_size) + self.moe_type = getattr(self.text_config, "moe_type", moe_type) + self.gen_vocab_start_index = getattr(self.text_config, "gen_vocab_start_index", gen_vocab_start_index) + self.tokenizer_class = "MammothUTokenizer" + +class Mammothmoda2Config(PretrainedConfig): + """Top-level MammothModa2 composition configuration""" + + model_type = "mammothmoda2" + is_composition = True + sub_configs: ClassVar = {"llm_config": AutoConfig} + + def __init__( + self, + *, + llm_config: dict | None = None, + gen_vae_config: dict | None = None, + gen_dit_config: dict | None = None, + gen_condition_mode: Literal["text", "image", "text_image"] = "image", + gen_image_condition_refiner_config: dict | None = None, + gen_axes_dim_rope: list[int] | None = None, + gen_axes_lens: list[int] | None = None, + gen_transport_config: dict | None = None, + initializer_range: float = 0.02, + architectures: list[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.llm_config = AutoConfig.for_model(**llm_config) if llm_config is not None else None + self.gen_vae_config = gen_vae_config + self.gen_dit_config = gen_dit_config + + self.gen_condition_mode = gen_condition_mode + self.gen_image_condition_refiner_config = gen_image_condition_refiner_config + self.gen_axes_dim_rope = gen_axes_dim_rope or 
[40, 40, 40] + self.gen_axes_lens = gen_axes_lens or [10000, 10000, 10000] + self.gen_transport_config = gen_transport_config or {} + self.initializer_range = initializer_range + self.tokenizer_class = "MammothUTokenizer" + self.architectures = ["Mammothmoda2Model"] + + def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002 + return self.llm_config + + def _require_llm_config(self) -> PretrainedConfig: + if self.llm_config is None: + raise AttributeError("Mammothmoda2Config.llm_config is None") + return self.llm_config + + # ---- Proxy attrs for vLLM multimodal/mrope helpers ---- + # vllm_omni/model_executor/layers/mrope.py expects these fields on `hf_config`. + # In MammothModa2, they live in the nested `llm_config` (VL config), so we + # expose them here to make the top-level composition config compatible. + @property + def vision_config(self): + return self._require_llm_config().vision_config + + @property + def image_token_id(self) -> int: + return int(self._require_llm_config().image_token_id) + + @property + def video_token_id(self) -> int: + return int(self._require_llm_config().video_token_id) + + @property + def vision_start_token_id(self) -> int: + return int(self._require_llm_config().vision_start_token_id) + + @property + def vision_end_token_id(self) -> int: + return int(self._require_llm_config().vision_end_token_id) + + +# Register model_type -> config class for AutoConfig +AutoConfig.register(Mammothmoda2Config.model_type, Mammothmoda2Config) +AutoConfig.register(Mammothmoda2Qwen2_5_VLConfig.model_type, Mammothmoda2Qwen2_5_VLConfig) +AutoConfig.register(Mammothmoda2Qwen2_5_VLTextConfig.model_type, Mammothmoda2Qwen2_5_VLTextConfig) +AutoConfig.register(Mammothmoda2Qwen2_5_VLVisionConfig.model_type, Mammothmoda2Qwen2_5_VLVisionConfig) + +AutoTokenizer.register(config_class=Mammothmoda2Config, slow_tokenizer_class=MammothUTokenizer) +AutoTokenizer.register(config_class=Mammothmoda2Qwen2_5_VLConfig, slow_tokenizer_class=MammothUTokenizer) diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py new file mode 100644 index 000000000..6ceef8d06 --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import torch +from torch import nn +from vllm.config import VllmConfig +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration +from vllm.model_executor.models.utils import init_vllm_registered_model, maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY + +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights + +from .mammoth_moda2_ar import ( + MammothModa2ARDummyInputsBuilder, + MammothModa2ARMultiModalProcessor, + MammothModa2ARProcessingInfo, +) + + +@MULTIMODAL_REGISTRY.register_processor( + MammothModa2ARMultiModalProcessor, + info=MammothModa2ARProcessingInfo, + dummy_inputs=MammothModa2ARDummyInputsBuilder, +) +class MammothModa2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + # Ensure vllm_omni/worker/gpu_model_runner.py's `extract_multimodal_outputs` follows + # the OmniOutput branch to retrieve text_hidden_states as a pure torch.Tensor, + # preventing errors in `hidden_states[logit_indices]` due to type mismatch (list/tuple). 
+ have_multimodal_outputs = True + + multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} + merge_by_field_config = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + # Consistent with Qwen2_5OmniForConditionalGeneration: instance-level flag. + self.have_multimodal_outputs = True + self.vllm_config = vllm_config + cfg = vllm_config.model_config.hf_config + self.model_stage = vllm_config.model_config.model_stage + self.multimodal_config = vllm_config.model_config.multimodal_config + + # For debugging/alignment with qwen2.5-omni: explicitly nullify unused stages. + self.ar = None + self.dit = None + self.vae = None + + if self.model_stage == "ar": + # AR stage: multi-modal + MoE text. + self.ar = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "ar"), + hf_config=cfg.llm_config if hasattr(cfg, "llm_config") else cfg.text_config, + architectures=["MammothModa2ARForConditionalGeneration"], + ) + self.model = self.ar + elif self.model_stage == "dit": + self.dit = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "dit"), + # NOTE: init_vllm_registered_model -> VllmConfig.with_hf_config requires a + # transformers.PretrainedConfig; however, Mammothmoda2Config.gen_dit_config + # is a dict (diffusers config). The DiT stage hf_config still uses the + # top-level Mammothmoda2Config, and the DiT module reads its own + # gen_dit_config / gen_vae_config dicts. + hf_config=cfg, + architectures=["MammothModa2DiTForConditionalGeneration"], + ) + self.model = self.dit + elif self.model_stage == "vae": + # Reserved: VAEs not implemented yet; raise explicit error. + raise NotImplementedError("MammothModa2 VAE stage not implemented yet.") + else: + raise ValueError(f"Unsupported model_stage: {self.model_stage}") + + # Expose intermediate tensor factory for PP if provided by the submodule. + self.make_empty_intermediate_tensors = getattr(self.model, "make_empty_intermediate_tensors", lambda: None) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int): # noqa: ARG003 + return Qwen2_5_VLForConditionalGeneration.get_placeholder_str(modality, i) + + def get_language_model(self) -> nn.Module: + if hasattr(self.model, "get_language_model"): + return self.model.get_language_model() + return self.model + + def get_multimodal_embeddings(self, **kwargs: object): + # Backward compatibility: route through embed_multimodal. + return self.embed_multimodal(**kwargs) + + def embed_multimodal(self, **kwargs: object): + if hasattr(self.model, "embed_multimodal"): + return self.model.embed_multimodal(**kwargs) + if hasattr(self.model, "get_multimodal_embeddings"): + return self.model.get_multimodal_embeddings(**kwargs) + return [] + + def get_input_embeddings(self, input_ids: torch.Tensor, multimodal_embeddings=None) -> torch.Tensor: + if hasattr(self.model, "get_input_embeddings"): + return self.model.get_input_embeddings(input_ids, multimodal_embeddings=multimodal_embeddings) + # DiT stage does not consume token embeddings from `input_ids`; it uses + # condition embeddings passed via additional_information. + # However, vLLM's generation runner may still request token embeddings + # to populate `inputs_embeds` buffers, so we provide a dummy tensor. 
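+        # The placeholder is all zeros with shape (num_tokens, hidden_size); the DiT stage
+        # ignores it and reads its conditions from runtime_additional_information instead.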
+        if self.model_stage == "dit":
+            hidden_size = int(self.vllm_config.model_config.get_hidden_size())
+            try:
+                target_dtype = next(self.model.parameters()).dtype
+            except StopIteration:
+                target_dtype = self.vllm_config.model_config.dtype
+            return torch.zeros(
+                (input_ids.numel(), hidden_size),
+                device=input_ids.device,
+                dtype=target_dtype,
+            )
+        raise NotImplementedError("Underlying model does not implement get_input_embeddings")
+
+    def forward(self, *args, **kwargs) -> OmniOutput | torch.Tensor:
+        out = self.model(*args, **kwargs)
+        if isinstance(out, OmniOutput):
+            return out
+        # Submodules may return a raw tensor / list directly; keep backward compatibility.
+        if isinstance(out, list):
+            out = out[0]
+        return OmniOutput(text_hidden_states=out, multimodal_outputs={}, intermediate_tensors=None)
+
+    def compute_logits(self, hidden_states: torch.Tensor | OmniOutput):
+        if isinstance(hidden_states, OmniOutput):
+            hidden_states = hidden_states.text_hidden_states
+        if hasattr(self.model, "compute_logits"):
+            return self.model.compute_logits(hidden_states)
+        return None
+
+    def get_dummy_runtime_additional_information(self, num_reqs: int) -> list[dict[str, object]]:
+        if self.model_stage != "dit":
+            raise RuntimeError(
+                f"get_dummy_runtime_additional_information only valid for dit stage, got {self.model_stage}"
+            )
+        if self.dit is None:
+            raise RuntimeError("dit stage model is not initialized")
+        if not hasattr(self.dit, "get_dummy_runtime_additional_information"):
+            raise AttributeError("dit model missing get_dummy_runtime_additional_information")
+        return self.dit.get_dummy_runtime_additional_information(num_reqs)
+
+    def load_weights(self, weights):
+        # Following Qwen2_5OmniForConditionalGeneration: hand the weights to the submodule that
+        # matches the current stage, then prepend the correct prefix to the set of loaded
+        # parameter names it returns so that DefaultModelLoader's strict check passes.
+        if self.model_stage == "ar":
+            if self.ar is None or not hasattr(self.ar, "load_weights"):
+                return set()
+            loaded = self.ar.load_weights(weights)
+            return add_prefix_to_loaded_weights(loaded, "ar")
+        if self.model_stage == "dit":
+            if self.dit is None or not hasattr(self.dit, "load_weights"):
+                return set()
+            loaded = self.dit.load_weights(weights)
+            return add_prefix_to_loaded_weights(loaded, "dit")
+        if self.model_stage == "vae":
+            return set()
+        raise ValueError(f"Unsupported model_stage: {self.model_stage}")
diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_ar.py b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_ar.py
new file mode 100644
index 000000000..fd51d22ea
--- /dev/null
+++ b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_ar.py
@@ -0,0 +1,642 @@
+from collections.abc import Callable, Iterable, Mapping
+from itertools import islice
+from typing import Any
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.qwen2 import Qwen2Attention, Qwen2MLP
+from vllm.model_executor.models.qwen2_5_vl
import ( + Qwen2_5_VLDummyInputsBuilder, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLMultiModalProcessor, + Qwen2_5_VLProcessingInfo, +) +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser +from vllm.model_executor.models.utils import ( + PPMissingLayer, + WeightsMapper, + init_vllm_registered_model, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import ( + is_interleaved, + patch_rope_parameters, + set_default_rope_theta, +) + +from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Config +from vllm_omni.model_executor.models.output_templates import OmniOutput + + +def moe_forward( + hidden_states: torch.Tensor, + und_expert: Callable[[torch.Tensor], torch.Tensor], + gen_expert: Callable[[torch.Tensor], torch.Tensor] | None, + gen_token_mask: torch.Tensor | None = None, +) -> torch.Tensor: + if gen_expert is None: + return und_expert(hidden_states) + + if gen_token_mask is None or not gen_token_mask.any(): + return und_expert(hidden_states) + if gen_token_mask.all(): + return gen_expert(hidden_states) + + if hidden_states.ndim == 2: + flat_hid = hidden_states + d_model = hidden_states.shape[-1] + total_tokens = hidden_states.shape[0] + elif hidden_states.ndim == 3: + d_model = hidden_states.shape[-1] + flat_hid = hidden_states.reshape(-1, d_model) # (B*L, D) + total_tokens = flat_hid.shape[0] + else: + raise ValueError(f"Unexpected hidden_states shape: {tuple(hidden_states.shape)}") + + # mask: [num_tokens] or [B, L] -> flatten to [total_tokens] + flat_mask = gen_token_mask.reshape(-1) # type: ignore[union-attr] + if flat_mask.numel() != total_tokens: + raise ValueError( + "gen_token_mask shape mismatch: " + f"mask={tuple(gen_token_mask.shape)}, hidden_states={tuple(hidden_states.shape)}" + ) + gen_pos = torch.where(flat_mask)[0] + und_pos = torch.where(~flat_mask)[0] + permute_order = torch.cat([gen_pos, und_pos], dim=0) + inverse_order = torch.argsort(permute_order) + gen_token_num = int(flat_mask.sum().item()) + gen_hid, und_hid = flat_hid[permute_order].split([gen_token_num, total_tokens - gen_token_num], dim=0) + + # 1.1 Generation tokens (True) + gen_out = gen_expert(gen_hid) # (N_gen, D) + + # 1.2 Understanding tokens (False) + und_out = und_expert(und_hid) # (N_und, D) + out_dim = und_out.shape[-1] + + merged = torch.cat([gen_out, und_out], dim=0) + merged = merged[inverse_order] + + if hidden_states.ndim == 2: + return merged.view(total_tokens, out_dim).contiguous() + return merged.view(*hidden_states.shape[:-1], out_dim).contiguous() + +class Mammothmoda2Processor(Qwen2_5_VLProcessor): + """Qwen2.5-VL Processor with MammothU tokenizer.""" + + tokenizer_class = ("MammothUTokenizer", None) + + +class MammothModa2ARProcessingInfo(Qwen2_5_VLProcessingInfo): + """Processes multi-modal information for MammothModa2 AR, returning the VL sub-configuration.""" + + def get_hf_config(self): + mammoth_cfg: Mammothmoda2Config = self.ctx.get_hf_config(Mammothmoda2Config) + llm_cfg = getattr(mammoth_cfg, "llm_config", None) + return llm_cfg + + def get_hf_processor(self, **kwargs: object) -> Mammothmoda2Processor: + return self.ctx.get_hf_processor( + Mammothmoda2Processor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + # MammothModa2 currently supports only image input, not 
video. + return {"image": None} + + +class MammothModa2ARDummyInputsBuilder(Qwen2_5_VLDummyInputsBuilder): + """Reuse Qwen2.5-VL's dummy input generation logic.""" + + +class MammothModa2ARMultiModalProcessor(Qwen2_5_VLMultiModalProcessor): + """Reuse Qwen2.5-VL's multi-modal processing,""" + + def _get_data_parser(self) -> Qwen2VLMultiModalDataParser: + return Qwen2VLMultiModalDataParser( + spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size + ) + + +class Mammoth2DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen2Config, + layer_idx: int, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + patch_rope_parameters(config) + set_default_rope_theta(config, default_theta=1000000) + dual_chunk_attention_config = getattr(config, "dual_chunk_attention_config", None) + + attn_type = AttentionType.DECODER + + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + rope_parameters=config.rope_parameters, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + if 14 <= layer_idx < 28: + self.gen_mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.gen_mlp", + ) + else: + self.gen_mlp = None + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + gen_token_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + # hidden_states = self.mlp(hidden_states) + hidden_states = moe_forward(hidden_states, self.mlp, self.gen_mlp, gen_token_mask) + return hidden_states, residual + + +class MammothModa2Qwen2ForCausalLM(nn.Module): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", decoder_layer_type: type[nn.Module] = Mammoth2DecoderLayer + ): + super().__init__() + + hf_config = vllm_config.model_config.hf_config + if hasattr(hf_config, "get_text_config"): + config = hf_config.get_text_config() + elif hasattr(hf_config, "text_config"): + config = hf_config.text_config + else: + config = hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.prefix = prefix + + if is_interleaved(vllm_config.model_config.hf_text_config): + assert config.max_window_layers == config.num_hidden_layers, ( + "Sliding window for some but all layers is not supported. 
" + f"This model uses sliding window but `max_window_layers` = {config.max_window_layers} " + f"is less than `num_hidden_layers` = {config.num_hidden_layers}. Please open an issue " + "to discuss this feature." + ) + + self.config = config + self.quant_config = quant_config + # NOTE: MammothModa2 supports extra generation vocabulary (for image tokens). + # Token ID range: [gen_vocab_start_index, gen_vocab_start_index + gen_vocab_size). + # vLLM sampler/processor expects "last dimension of logits == model_config.get_vocab_size()", + # so we output base+gen logits in compute_logits, and embeddings must accept these IDs. + self.extra_gen_vocab = bool(getattr(config, "extra_gen_vocab", False)) + # Starting index for generation tokens (used for gen_token_mask). + self.gen_vocab_start_index = getattr(hf_config, "gen_vocab_start_index", None) or getattr( + config, "gen_vocab_start_index", None + ) + self.gen_vocab_size = int(getattr(config, "gen_vocab_size", 0) or 0) + + self.base_vocab_size = int(self.gen_vocab_start_index) if self.extra_gen_vocab else int(config.vocab_size) + # The configuration level (hf_text_config.vocab_size) has been extended to base+gen + # by the upstream config class. Use config.vocab_size as the total vocab size. + self.total_vocab_size = int(getattr(config, "vocab_size", self.base_vocab_size)) + + if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.base_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + if self.extra_gen_vocab: + if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): + self.gen_embed_tokens = VocabParallelEmbedding( + self.gen_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.gen_embed_tokens", + ) + else: + self.gen_embed_tokens = PPMissingLayer() + else: + self.gen_embed_tokens = None + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.base_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + else: + self.lm_head = PPMissingLayer() + + if self.extra_gen_vocab: + if get_pp_group().is_last_rank: + self.gen_head = ParallelLMHead( + self.gen_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.gen_head", + ) + else: + self.gen_head = PPMissingLayer() + else: + self.gen_head = None + + self.logits_processor = LogitsProcessor(self.base_vocab_size) + self.gen_logits_processor = LogitsProcessor(self.gen_vocab_size) if self.extra_gen_vocab else None + + decoder_layer_type = decoder_layer_type or Mammoth2DecoderLayer + + def _make_decoder_layer(*, prefix: str) -> nn.Module: + try: + layer_idx = int(prefix.rsplit(".", 1)[-1]) + except Exception: + layer_idx = 0 + return decoder_layer_type( + config=config, + layer_idx=layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + _make_decoder_layer, + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + @property + def model(self) -> "MammothModa2Qwen2ForCausalLM": 
+ return self + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + if not self.extra_gen_vocab or self.gen_embed_tokens is None: + return self.embed_tokens(input_ids) + + gen_mask = input_ids >= int(self.gen_vocab_start_index) + if not gen_mask.any(): + return self.embed_tokens(input_ids) + if gen_mask.all(): + gen_ids = input_ids - int(self.gen_vocab_start_index) + return self.gen_embed_tokens(gen_ids) + + flat_ids = input_ids.reshape(-1) + flat_mask = gen_mask.reshape(-1) + out = torch.empty( + (flat_ids.shape[0], self.config.hidden_size), + dtype=self.embed_tokens.weight.dtype, # type: ignore[attr-defined] + device=flat_ids.device, + ) + + base_pos = torch.where(~flat_mask)[0] + gen_pos = torch.where(flat_mask)[0] + if base_pos.numel() > 0: + out[base_pos] = self.embed_tokens(flat_ids[base_pos]) + if gen_pos.numel() > 0: + gen_ids = flat_ids[gen_pos] - int(self.gen_vocab_start_index) + out[gen_pos] = self.gen_embed_tokens(gen_ids) + return out.view(*input_ids.shape, -1).contiguous() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + assert input_ids is not None + hidden_states = self.get_input_embeddings(input_ids) + # gen_token_mask: True indicates image generation tokens, which use gen_mlp. + # In vLLM v1 path, only inputs_embeds might be provided, with input_ids set to None. + # In this case, gen tokens cannot be distinguished by ID, falling back to und_expert. 
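+            # Example with the default config (gen_vocab_start_index == 152064):
+            #   input_ids [151644, 152100, 42] -> gen_token_mask [False, True, False],
+            #   so only the middle (visual) token is routed through gen_mlp.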
+ if self.gen_vocab_start_index is None or input_ids is None: + gen_token_mask = None + else: + gen_token_mask = input_ids >= self.gen_vocab_start_index + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + gen_token_mask = None + + for idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): + hidden_states, residual = layer(positions, hidden_states, residual, gen_token_mask) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + if isinstance(self.lm_head, PPMissingLayer): + return None + base_logits = self.logits_processor(self.lm_head, hidden_states) + if not self.extra_gen_vocab: + return base_logits + if self.gen_head is None or isinstance(self.gen_head, PPMissingLayer): + return base_logits + assert self.gen_logits_processor is not None + gen_logits = self.gen_logits_processor(self.gen_head, hidden_states) + if base_logits is None or gen_logits is None: + return None + return torch.cat([base_logits, gen_logits], dim=-1) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name)): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + +@MULTIMODAL_REGISTRY.register_processor( + MammothModa2ARMultiModalProcessor, + info=MammothModa2ARProcessingInfo, + dummy_inputs=MammothModa2ARDummyInputsBuilder, +) +class MammothModa2ARForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): + """Replaces the language backbone with MoE within the Qwen2_5_VLForConditionalGeneration multi-modal framework.""" + + have_multimodal_outputs = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Skip generation-side (DiT/VAE) weights as they do not belong to the AR stage. + "gen_image_condition_refiner.": None, + "gen_transformer.": None, + "gen_vae.": None, + # LLM backbone: checkpoint uses the llm_model.* prefix. + # Extra generation vocab (image tokens) weights: mapped separately to the vLLM language_model submodule. + "llm_model.model.language_model.gen_embed_tokens.": "language_model.gen_embed_tokens.", + "llm_model.gen_head.": "language_model.gen_head.", + "llm_model.model.language_model.": "language_model.", + "llm_model.model.visual.": "visual.", + "llm_model.lm_head.": "language_model.lm_head.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Switch hf_config to the AR sub-config to ensure the Qwen2.5-VL path receives the correct type. + mammoth_cfg = vllm_config.model_config.hf_config + ar_hf_config = getattr(mammoth_cfg, "llm_config", mammoth_cfg) + ar_vllm_config = vllm_config.with_hf_config(ar_hf_config, architectures=vllm_config.model_config.architectures) + # Initialize multi-modal components like the vision tower first. + super().__init__(vllm_config=ar_vllm_config, prefix=prefix) + # Replace with the custom MoE language model. + lm_hf_config = getattr( + ar_vllm_config.model_config.hf_config, "text_config", ar_vllm_config.model_config.hf_config + ) + self.language_model = init_vllm_registered_model( + vllm_config=ar_vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=lm_hf_config, + architectures=["MammothModa2Qwen2ForCausalLM"], + ) + self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors + + # -------- t2i (AR grid) token constraints -------- + # Constraint logic depends on per-step sampling_metadata + runtime_additional_information. + # These are passed by the vllm-omni runner via kwargs, so caching them in the model is sufficient. + self._last_runtime_additional_information: list[dict[str, Any]] | None = None + + def _apply_t2i_token_constraints(self, logits: torch.Tensor) -> torch.Tensor: + """Applies per-request token constraints. 
+ + - For T2I requests: constrain AR grid tokens (force EOL at row end and + restrict intra-row sampling to visual token range). + - For non-T2I (text/understanding/chat) requests: disallow sampling + from the extra generation vocabulary (image tokens) to prevent + accidentally emitting visual-token sequences. + """ + if logits is None or not isinstance(logits, torch.Tensor): + return logits + + runtime_infos = self._last_runtime_additional_information + + if runtime_infos is None: + # There is no runtime info in dummy/profile run + return logits + + neg_inf = -float("inf") + num_reqs = int(logits.shape[0]) + for i in range(num_reqs): + runtime_info = runtime_infos[i] if isinstance(runtime_infos[i], dict) else {} + if runtime_info["omni_task"][0] != "t2i": + # Text/understanding/chat: forbid sampling from the extra gen vocab. + logits[i, self.language_model.base_vocab_size :] = neg_inf + continue + + ar_width = runtime_info["ar_width"][0] + ar_height = runtime_info["ar_height"][0] + eol_token_id = runtime_info["eol_token_id"][0] + visual_start = runtime_info["visual_token_start_id"][0] + visual_end = runtime_info["visual_token_end_id"][0] + generated_len = runtime_info["generated_len"] + + expected_token_num = (ar_width + 1) * ar_height + + row = logits[i] + column_id = generated_len % (ar_width + 1) + if column_id == ar_width: + # End-of-row token: only allow eol. + eol_logit = row[eol_token_id].clone() + row.fill_(neg_inf) + row[eol_token_id] = eol_logit + else: + # Intra-row tokens: only allow visual tokens (explicitly forbid eol). + row[:visual_start] = neg_inf + row[visual_end:] = neg_inf + row[eol_token_id] = neg_inf + + if generated_len >= expected_token_num: + row.fill_(neg_inf) + end_of_image_id = 152071 + row[end_of_image_id] = 1.0 # Allow only end_of_image_id after expected tokens num + + return logits + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ): + # vllm-omni runner passes sampling_metadata and runtime_additional_information + # in each forward step. compute_logits is called immediately after + # forward, so caching here enables step-by-step dynamic token constraints. + runtime_infos = kwargs.get("runtime_additional_information") + self._last_runtime_additional_information = runtime_infos if isinstance(runtime_infos, list) else None + hidden_states = super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + # NOTE: gpu_model_runner._dummy_run performs hidden_states[logit_indices] after forward. + # We must ensure text_hidden_states is a torch.Tensor to avoid errors when + # indexing (which happens if it's a list/tuple). 
+ if isinstance(hidden_states, IntermediateTensors): + text_hidden_states = hidden_states["hidden_states"] + out_intermediate_tensors = hidden_states + elif isinstance(hidden_states, list): + text_hidden_states = hidden_states[0] + out_intermediate_tensors = None + else: + text_hidden_states = hidden_states + out_intermediate_tensors = None + + return OmniOutput( + text_hidden_states=text_hidden_states, + multimodal_outputs={}, + intermediate_tensors=out_intermediate_tensors, + ) + + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput): + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + logits = super().compute_logits(hidden_states) + if isinstance(logits, torch.Tensor): + logits = self._apply_t2i_token_constraints(logits) + return logits diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_dit.py b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_dit.py new file mode 100644 index 000000000..944851f9a --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_dit.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from vllm.config import VllmConfig +from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper + +from vllm_omni.model_executor.models.mammoth_moda2.config import Mammothmoda2Config +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .mammothmoda2_dit_layer import ( + FlowMatchEulerDiscreteScheduler, + RotaryPosEmbedReal, + SimpleQFormerImageRefiner, + Transformer2DModel, +) + + +class MammothModa2DiTForConditionalGeneration(nn.Module): + """ + MammothModa2 DiT + VAE generation stage (non-autoregressive). + + This stage expects "image condition token hidden states" from the upstream AR stage, + and outputs image tensors via diffusion transformer + VAE decode. + + """ + + have_multimodal_outputs = True + + # Load only gen_* weights; ignore llm_model.* to prevent loading the entire LLM backbone in the DiT stage. 
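+    # For example, a checkpoint key like "llm_model.model.language_model.layers.0.mlp.gate_proj.weight"
+    # is dropped, while "gen_transformer.*" / "gen_vae.*" keys pass through unchanged.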
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "llm_model.": None, + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + del prefix + + hf_config = vllm_config.model_config.hf_config + if not isinstance(hf_config, Mammothmoda2Config): + raise TypeError(f"Expected Mammothmoda2Config, got {type(hf_config)}") + + self.config = hf_config + + # --- Build DiT / VAE modules (names must match checkpoint keys) --- + if self.config.gen_vae_config is None or self.config.gen_dit_config is None: + raise ValueError("Mammothmoda2Config.gen_vae_config / gen_dit_config must not be None") + + self.gen_vae = AutoencoderKL.from_config(self.config.gen_vae_config) + self.gen_transformer = Transformer2DModel.from_config(self.config.gen_dit_config) + + llm_hidden_size = int(getattr(self.config.llm_config, "hidden_size", 0) or 0) + if llm_hidden_size <= 0: + raise ValueError("Failed to infer llm hidden_size from Mammothmoda2Config.llm_config.hidden_size") + self._reinit_caption_embedder(llm_hidden_size) + + # Optional: image condition refiner (Q-Former) + if self.config.gen_image_condition_refiner_config is not None: + self.gen_image_condition_refiner = SimpleQFormerImageRefiner( + hidden_size=llm_hidden_size, + **self.config.gen_image_condition_refiner_config, + ) + else: + self.gen_image_condition_refiner = None + + # Precompute rotary freqs for diffusion transformer + # IMPORTANT: follow upstream mammothmoda: use top-level `config.gen_axes_*` + # (the checkpoint's `gen_dit_config.axes_lens` can be as small as 1024, + # which is insufficient for vLLM dummy-run/cudagraph warmup). + self.gen_freqs_cis = RotaryPosEmbedReal.get_freqs_real( + tuple(self.config.gen_axes_dim_rope), + tuple(self.config.gen_axes_lens), + theta=10000, + ) + + # vLLM PP interface compatibility + self.make_empty_intermediate_tensors = lambda: None + + self._llm_hidden_size = llm_hidden_size + + def _reinit_caption_embedder(self, in_features: int) -> None: + # Align with upstream Mammothmoda2Model's `reinit_caption_embedder`: + # Use Qwen2RMSNorm(in_features) + Linear(in_features -> out_features). + out_features = int(getattr(self.gen_transformer, "hidden_size", 0) or self.gen_transformer.config.hidden_size) + self.gen_transformer.time_caption_embed.caption_embedder = nn.Sequential( + Qwen2RMSNorm(in_features, eps=1e-5), + nn.Linear(in_features, out_features, bias=True), + ) + + def get_dummy_runtime_additional_information(self, num_reqs: int) -> list[dict[str, object]]: + num_reqs = 1 # TODO: support num_reqs > 1 + if num_reqs <= 0: + raise ValueError(f"num_reqs must be positive, got {num_reqs}") + text_prompt_embeds = torch.zeros((1, self._llm_hidden_size), dtype=torch.float32) + image_prompt_embeds = torch.zeros((1, self._llm_hidden_size), dtype=torch.float32) + negative_prompt_embeds = torch.zeros((0, self._llm_hidden_size), dtype=torch.float32) + info = { + "text_prompt_embeds": text_prompt_embeds, + "image_prompt_embeds": image_prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": [], + "image_height": [512], + "image_width": [512], + "text_guidance_scale": [1.0], + "cfg_range": [0.0, 1.0], + "num_inference_steps": [1], + } + return [info for _ in range(num_reqs)] + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + # DiT stage does not consume token embeddings; return a dummy tensor. 
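+        # The zeros returned here only satisfy the runner's embedding interface
+        # (e.g. during profiling/warmup); the actual conditioning for this stage
+        # arrives via `runtime_additional_information` in forward().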
+ try: + dtype = next(self.parameters()).dtype + except StopIteration: + dtype = torch.float32 + return torch.zeros( + (input_ids.numel(), self._llm_hidden_size), + device=input_ids.device, + dtype=dtype, + ) + + @torch.inference_mode() + def forward( + self, + *, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, # noqa: ARG002 + ) -> OmniOutput: + runtime_addi = kwargs.get("runtime_additional_information", None) + info = runtime_addi[0] + text_cond = info["text_prompt_embeds"] + image_cond = info["image_prompt_embeds"] + negative_cond = info.get("negative_prompt_embeds") + negative_attention_mask = info.get("negative_prompt_attention_mask") + image_hw = info["image_height"][0], info["image_width"][0] + text_guidance_scale = info["text_guidance_scale"][0] + cfg_range = info["cfg_range"][0], info["cfg_range"][1] + num_inference_steps = info["num_inference_steps"][0] + + # Move to model device/dtype. + model_device = next(self.parameters()).device + if self.gen_image_condition_refiner is not None: + target_dtype = next(self.gen_image_condition_refiner.parameters()).dtype + else: + target_dtype = next(self.gen_transformer.parameters()).dtype + + def _ensure_2d(x: torch.Tensor, name: str) -> torch.Tensor: + if x.ndim == 3 and x.shape[0] == 1: + x = x[0] + if x.ndim != 2: + raise ValueError(f"Expected {name} to be 2D [T,H], got shape={tuple(x.shape)}") + return x + + text_cond = _ensure_2d(text_cond, "text_prompt_embeds") + image_cond = _ensure_2d(image_cond, "image_prompt_embeds") + text_cond = text_cond.to(device=model_device, dtype=target_dtype, non_blocking=True).contiguous() + image_cond = image_cond.to(device=model_device, dtype=target_dtype, non_blocking=True).contiguous() + + text_embeds = text_cond.unsqueeze(0) # [1, T_text, H] + text_attention_mask = torch.ones( + (1, text_embeds.shape[1]), + dtype=torch.bool, + device=text_embeds.device, + ) + + image_embeds = image_cond.unsqueeze(0) # [1, T_img, H] + image_attention_mask = torch.ones( + (1, image_embeds.shape[1]), + dtype=torch.bool, + device=image_embeds.device, + ) + + # Apply optional refiner ONLY on image condition tokens. + if self.gen_image_condition_refiner is not None and image_embeds.shape[1] > 0: + image_embeds = self.gen_image_condition_refiner(image_embeds, ~image_attention_mask.bool()) + image_attention_mask = torch.ones( + image_embeds.shape[:2], + dtype=torch.bool, + device=image_embeds.device, + ) + + prompt_embeds = torch.cat([text_embeds, image_embeds], dim=1) + prompt_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + + # Prepare negative prompt (for CFG). If none provided, fall back to unconditional. 
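+        # Recap of the guidance math used below (illustrative):
+        #   pred = pred_uncond + s * (pred_cond - pred_uncond)
+        # With s == 1.0 this reduces to the conditional prediction, so the negative/
+        # unconditional branch is skipped entirely; `cfg_range` further limits the
+        # fraction of steps at which s > 1.0 is applied.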
+        negative_prompt_embeds = None
+        negative_prompt_attention_mask = None
+        if text_guidance_scale > 1.0:
+            if negative_cond is not None:
+                negative_cond = _ensure_2d(negative_cond, "negative_prompt_embeds")
+                negative_prompt_embeds = (
+                    negative_cond.to(device=model_device, dtype=target_dtype, non_blocking=True)
+                    .contiguous()
+                    .unsqueeze(0)
+                )
+                if isinstance(negative_attention_mask, torch.Tensor):
+                    neg_mask = negative_attention_mask
+                elif isinstance(negative_attention_mask, list):
+                    neg_mask = torch.tensor(negative_attention_mask, dtype=torch.bool)
+                else:
+                    neg_mask = None
+                if neg_mask is None:
+                    negative_prompt_attention_mask = torch.ones(
+                        (1, negative_prompt_embeds.shape[1]),
+                        dtype=torch.bool,
+                        device=negative_prompt_embeds.device,
+                    )
+                else:
+                    neg_mask = neg_mask.to(device=negative_prompt_embeds.device, dtype=torch.bool)
+                    if neg_mask.ndim == 1:
+                        neg_mask = neg_mask.unsqueeze(0)
+                    negative_prompt_attention_mask = neg_mask
+            else:
+                hidden_size = int(prompt_embeds.shape[-1])
+                negative_prompt_embeds = torch.zeros(
+                    (1, 0, hidden_size),
+                    dtype=target_dtype,
+                    device=prompt_embeds.device,
+                )
+                negative_prompt_attention_mask = torch.zeros(
+                    (1, 0),
+                    dtype=torch.bool,
+                    device=prompt_embeds.device,
+                )
+
+        # Output image size (px), passed from the stage input processor.
+        height, width = image_hw
+        if height <= 0 or width <= 0:
+            raise ValueError(f"Invalid image size: {height}x{width}")
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"Image size must be a multiple of 16, got {height}x{width}")
+        vae_scale_factor = 16
+
+        latent_channels = int(self.gen_transformer.config.in_channels)
+        shape = (1, latent_channels, 2 * height // vae_scale_factor, 2 * width // vae_scale_factor)
+        latents = randn_tensor(shape, device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        scheduler.set_timesteps(
+            num_inference_steps=num_inference_steps,
+            device=prompt_embeds.device,
+            num_tokens=latents.shape[-2] * latents.shape[-1],
+        )
+
+        # Run diffusion loop (CFG supported when text_guidance_scale > 1.0)
+        total_steps = max(1, len(scheduler.timesteps))
+        for i, t in enumerate(scheduler.timesteps):
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+            model_pred = self.gen_transformer(
+                hidden_states=latents,
+                timestep=timestep,
+                text_hidden_states=prompt_embeds,
+                text_attention_mask=prompt_attention_mask,
+                ref_image_hidden_states=None,
+                freqs_cis=self.gen_freqs_cis,
+            )
+            guidance_scale = text_guidance_scale if cfg_range[0] <= i / total_steps <= cfg_range[1] else 1.0
+            if guidance_scale > 1.0 and negative_prompt_embeds is not None:
+                model_pred_uncond = self.gen_transformer(
+                    hidden_states=latents,
+                    timestep=timestep,
+                    text_hidden_states=negative_prompt_embeds,
+                    text_attention_mask=negative_prompt_attention_mask,
+                    ref_image_hidden_states=None,
+                    freqs_cis=self.gen_freqs_cis,
+                )
+                model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+            latents = scheduler.step(model_pred, t, latents, return_dict=False)[0]
+        latents = latents.to(dtype=prompt_embeds.dtype)
+
+        # VAE decode
+        if self.gen_vae.config.scaling_factor is not None:
+            latents = latents / self.gen_vae.config.scaling_factor
+        if self.gen_vae.config.shift_factor is not None:
+            latents = latents + self.gen_vae.config.shift_factor
+        image = self.gen_vae.decode(latents, return_dict=False)[0]
+
+        return OmniOutput(
+            text_hidden_states=inputs_embeds,  # placeholder; the runner does not use it
+            multimodal_outputs=image,
+            intermediate_tensors=None,
+        )
+
+    def 
compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: # noqa: ARG002 + return None + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/__init__.py b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/__init__.py new file mode 100644 index 000000000..bde02b700 --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mammothmoda2_dit_model import Lumina2CombinedTimestepCaptionEmbedding +from .mammothmoda2_dit_model import Transformer2DModel +from .mammothmoda2_dit_model import SimpleQFormerImageRefiner +from .rope_real import RotaryPosEmbedReal +from .schedulers import FlowMatchEulerDiscreteScheduler + +__all__ = [ + "FlowMatchEulerDiscreteScheduler", + "Lumina2CombinedTimestepCaptionEmbedding", + "RotaryPosEmbedReal", + "SimpleQFormerImageRefiner", + "Transformer2DModel", +] diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/mammothmoda2_dit_model.py b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/mammothmoda2_dit_model.py new file mode 100644 index 000000000..7d02ad8ec --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/mammothmoda2_dit_model.py @@ -0,0 +1,717 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange, repeat +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm + +from .rope_real import RotaryPosEmbedReal + + +try: + from transformers.modeling_flash_attention_utils import ( # type: ignore + flash_attn_varlen_func, # pyright: ignore[reportAttributeAccessIssue] + is_flash_attn_available, + ) +except Exception: # pragma: no cover - best-effort compatibility + flash_attn_varlen_func = None # type: ignore[assignment] + + def is_flash_attn_available() -> bool: # type: ignore[override] + return False + +from .rope_real import apply_real_rotary_emb + +_HAS_FLASH_ATTN_VARLEN = bool(is_flash_attn_available()) and flash_attn_varlen_func is not None + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = Qwen2RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + +class LuminaFeedForward(nn.Module): + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: int | None = 256, + ffn_dim_multiplier: float | None = None, + ): + super().__init__() + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def swiglu(self, x, y): + return F.silu(x.float(), inplace=False).to(x.dtype) * y + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
+ elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: int | None = None, + ): + super().__init__() + + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = Qwen2RMSNorm(embedding_dim, eps=eps) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 2520, + text_feat_dim: int = 3584, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + Qwen2RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + def forward( + self, + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed + +class SimpleQFormerImageRefiner(nn.Module): + + def __init__( + self, + hidden_size: int, + num_queries: int = 128, + num_layers: int = 2, + num_heads: int | None = None, + dropout: float = 0.0, + norm_eps: float = 1e-5, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_queries = num_queries + # ensure num_heads divides hidden_size + if num_heads is None: + num_heads = max(1, hidden_size // 128) + self.num_heads = self._choose_valid_num_heads(hidden_size, num_heads) + self.input_proj = nn.Sequential( + Qwen2RMSNorm(hidden_size, eps=norm_eps), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + # Learnable query embeddings + scale = hidden_size**-0.5 + self.query = nn.Parameter(scale * torch.randn(1, num_queries, hidden_size)) + + # Decoder layers + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + nn.ModuleDict( + dict( + ln_q1=Qwen2RMSNorm(hidden_size, eps=norm_eps), + self_attn=nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=self.num_heads, dropout=dropout, batch_first=True + ), + ln_q2=Qwen2RMSNorm(hidden_size, eps=norm_eps), + cross_attn=nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=self.num_heads, dropout=dropout, batch_first=True + ), + ln_ffn=Qwen2RMSNorm(hidden_size, eps=norm_eps), + ffn=LuminaFeedForward(dim=hidden_size, inner_dim=4 * hidden_size), + ) + ) + ) + + @staticmethod + def _choose_valid_num_heads(hidden_size: int, proposed_heads: int, preferred_head_dim: int = 128) -> int: + """Pick a number of heads that divides hidden_size, close to proposed or preferred.""" + # If 
proposed is valid, use it + if proposed_heads > 0 and hidden_size % proposed_heads == 0: + return proposed_heads + # target based on preferred head dim + target = max(1, round(hidden_size / preferred_head_dim)) + # collect divisors up to 128 heads (more than enough) + max_heads_cap = min(128, hidden_size) + divisors = [d for d in range(1, max_heads_cap + 1) if hidden_size % d == 0] + # choose closest to target + best = min(divisors, key=lambda d: (abs(d - target), -d)) + return best + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor of shape (batch, seq_len, input_dim) + Returns: + Tensor of shape (batch, num_queries, hidden_size) + """ + batch, _, _ = x.shape + kv = self.input_proj(x) + q = self.query.repeat(batch, 1, 1).to(kv.dtype) + + for layer in self.layers: + # Self-attention on queries + q_norm = layer["ln_q1"](q) + attn_out, _ = layer["self_attn"](q_norm, q_norm, q_norm, need_weights=False) + q = q + attn_out + + # Cross-attention: queries attend to inputs + q_norm = layer["ln_q2"](q) + cross_out, _ = layer["cross_attn"](q_norm, kv, kv, need_weights=False, key_padding_mask=attention_mask) + q = q + cross_out + + # Feed-forward + q = q + layer["ffn"](layer["ln_ffn"](q)) + + return q + +class AttnProcessor: + def __init__(self) -> None: + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor requires PyTorch 2.0+ (F.scaled_dot_product_attention).") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + base_sequence_length: int | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_real_rotary_emb(query, image_rotary_emb[0], image_rotary_emb[1]) + key = apply_real_rotary_emb(key, image_rotary_emb[0], image_rotary_emb[1]) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + if _HAS_FLASH_ATTN_VARLEN and attention_mask is not None and hidden_states.is_cuda: + # Flash-Attn varlen expects packed tokens + cu_seqlens. Here we only need + # the self-attention case (q/k/v share the same padding mask). 
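+            # Worked example (illustrative): for per-row valid lengths [3, 5],
+            # seqlens = [3, 5], cu_seqlens = [0, 3, 8], max_seqlen = 5, and the
+            # packed q/k/v tensors keep only the 3 + 5 = 8 valid token rows
+            # selected by `indices` below.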
+ attention_mask = attention_mask.to(torch.bool) + seqlens = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen = int(seqlens.max().item()) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + + query_states = query.reshape(batch_size * sequence_length, attn.heads, head_dim)[indices] + key_states = key.reshape(batch_size * sequence_length, kv_heads, head_dim)[indices] + value_states = value.reshape(batch_size * sequence_length, kv_heads, head_dim)[indices] + + if kv_heads < attn.heads: + key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + + out = torch.zeros( + (batch_size * sequence_length, attn.heads, head_dim), + device=attn_output_unpad.device, + dtype=attn_output_unpad.dtype, + ) + out[indices] = attn_output_unpad + hidden_states = out.view(batch_size, sequence_length, attn.heads, head_dim).flatten(-2) + hidden_states = hidden_states.type_as(query) + else: + # PyTorch SDPA path. + attn_mask = None + if attention_mask is not None: + attention_mask = attention_mask.to(torch.bool) + attn_mask = attention_mask.view(batch_size, 1, 1, -1) + + query = query.transpose(1, 2) # [B, H, S, D] + key = key.transpose(1, 2) # [B, H_kv, S, D] + value = value.transpose(1, 2) + + if kv_heads < attn.heads: + key = key.repeat_interleave(attn.heads // kv_heads, dim=1) + value = value.repeat_interleave(attn.heads // kv_heads, dim=1) + + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + scale=softmax_scale, + ) + + if attention_mask is not None: + # Keep padding tokens consistent with the flash-varlen path (zero output). 
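+                # (SDPA's attn_mask above only hides padded *keys*; padded *query*
+                # rows still produce non-zero outputs, so they are zeroed here to
+                # match flash_attn_varlen, which never computes them.)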
+                hidden_states = hidden_states * attention_mask[:, None, :, None]
+
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.type_as(query)
+
+        # Apply output projection
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+class TransformerBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        modulation: bool = True,
+    ) -> None:
+        """Initialize the transformer block."""
+        super().__init__()
+        self.head_dim = dim // num_attention_heads
+        self.modulation = modulation
+
+        processor = AttnProcessor()
+
+        # Initialize attention layer
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            qk_norm=None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=processor,
+        )
+        # Explicitly use transformers' Qwen2RMSNorm instead of relying on the RMSNorm
+        # created internally by diffusers and recursively replacing it afterwards.
+        self.attn.norm_q = Qwen2RMSNorm(self.head_dim, eps=1e-5)
+        self.attn.norm_k = Qwen2RMSNorm(self.head_dim, eps=1e-5)
+
+        # Initialize feed-forward network
+        self.feed_forward = LuminaFeedForward(
+            dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier
+        )
+
+        # Initialize normalization layers
+        if modulation:
+            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True)
+        else:
+            self.norm1 = Qwen2RMSNorm(dim, eps=norm_eps)
+
+        self.ffn_norm1 = Qwen2RMSNorm(dim, eps=norm_eps)
+        self.norm2 = Qwen2RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm2 = Qwen2RMSNorm(dim, eps=norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        temb: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if self.modulation:
+            if temb is None:
+                raise ValueError("temb must be provided when modulation is enabled")
+
+            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            hidden_states = hidden_states + self.norm2(attn_output)
+            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+        return hidden_states
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+    """MammothModa2 DiT transformer (used for inference)."""
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        out_channels: int | None = None,
+        hidden_size: int = 2304,
+        num_layers: int = 26,
+        num_refiner_layers: int = 2,
+        num_attention_heads: int = 24,
+        num_kv_heads: int = 8,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: float | None = None,
+        norm_eps: float = 1e-5,
+        axes_dim_rope: tuple[int, 
int, int] = (32, 32, 32), + axes_lens: tuple[int, int, int] = (300, 512, 512), + text_feat_dim: int = 1024, + timestep_scale: float = 1.0, + ) -> None: + """Initialize the transformer model.""" + super().__init__() + self.hidden_size = hidden_size + + # Validate configuration + if (hidden_size // num_attention_heads) != sum(axes_dim_rope): + raise ValueError( + f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " + f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" + ) + + self.out_channels = out_channels or in_channels + # Initialize embeddings + self.rope_embedder = RotaryPosEmbedReal( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.ref_image_patch_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=text_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, + ) + + # Initialize transformer blocks + self.noise_refiner = nn.ModuleList( + [ + TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.ref_image_refiner = nn.ModuleList( + [ + TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + # Add learnable embeddings to distinguish different images + self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + text_attention_mask: torch.Tensor, + ref_image_hidden_states: list[list[torch.Tensor]] | None = None, + return_dict: bool = False, + ) -> torch.Tensor: + if return_dict: + raise ValueError("return_dict=True is not supported in vLLM inference.") + if ref_image_hidden_states is not None: + raise ValueError("ref_image_hidden_states is not supported in vLLM inference.") + if hidden_states.ndim != 4: + raise ValueError(f"Expected hidden_states to be 4D [B,C,H,W], got shape={tuple(hidden_states.shape)}") + + batch_size, _channels, height, width = hidden_states.shape + if batch_size != text_hidden_states.shape[0] or batch_size != text_attention_mask.shape[0]: + raise ValueError( + "Batch size mismatch: " + f"hidden_states={batch_size}, text_hidden_states={text_hidden_states.shape[0]}, " + f"text_attention_mask={text_attention_mask.shape[0]}" + ) + + p = self.config.patch_size + if height % p != 0 or width % p != 0: + raise ValueError(f"Input latent H/W must be divisible by patch_size={p}, got {height}x{width}") + + device = hidden_states.device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states.dtype) + + img_tokens = rearrange(hidden_states, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) + img_tokens = self.x_embedder(img_tokens) + + img_len = (height // p) * (width // p) + img_mask = torch.ones((batch_size, img_len), dtype=torch.bool, device=device) + l_effective_img_len = [img_len for _ in range(batch_size)] + img_sizes = [(height, width) for _ in range(batch_size)] + + l_effective_ref_img_len = [[] for _ in range(batch_size)] + ref_img_sizes = [None for _ in range(batch_size)] + + ( + context_rotary_emb, + _ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder( + freqs_cis, + text_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + + for layer in self.noise_refiner: + img_tokens = layer(img_tokens, img_mask, noise_rotary_emb, temb) + + max_seq_len = max(seq_lengths) + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len : encoder_seq_len + img_len] = img_tokens[i, :img_len] + + hidden_states = joint_hidden_states + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + + hidden_states = self.norm_out(hidden_states, temb) + + img_hidden_states = torch.stack( + [ + hidden_states[i, encoder_seq_len : 
encoder_seq_len + img_len] + for i, encoder_seq_len in enumerate(encoder_seq_lengths) + ], + dim=0, + ) + output = rearrange( + img_hidden_states, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + h=height // p, + w=width // p, + p1=p, + p2=p, + ) + return output diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/rope_real.py b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/rope_real.py new file mode 100644 index 000000000..d16181a69 --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/rope_real.py @@ -0,0 +1,250 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from diffusers.models.embeddings import get_1d_rotary_pos_embed as _get_1d_rotary_pos_embed +from einops import repeat +from torch import nn + + +def apply_real_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor: + """ + Apply real-valued rotary embeddings to input tensor. + + Args: + x: Input tensor of shape [..., seq_len, num_heads, dim] or [..., seq_len, dim] + freqs_cos: Cosine frequencies of shape [batch, seq_len, dim] or [seq_len, dim] + freqs_sin: Sine frequencies of shape [batch, seq_len, dim] or [seq_len, dim] + + Returns: + Tensor with rotary embeddings applied + """ + # x: [batch, seq_len, num_heads, dim] or [batch, seq_len, dim] + # freqs_cos: [batch, seq_len, dim] or [seq_len, dim] + # freqs_sin: [batch, seq_len, dim] or [seq_len, dim] + + x_shape = x.shape + if len(x_shape) == 4: + batch, seq_len, num_heads, dim = x_shape + x_reshaped = x.view(batch, seq_len, num_heads, dim // 2, 2) + elif len(x_shape) == 3: + batch, seq_len, dim = x_shape + num_heads = None + x_reshaped = x.view(batch, seq_len, dim // 2, 2) + else: + raise ValueError(f"Unsupported x shape: {x.shape}") + + # freqs_cos/sin: [batch, seq_len, dim] or [seq_len, dim] + # Expand freqs_cos/sin to [batch, seq_len, dim] if needed + if freqs_cos.dim() == 2: + # [seq_len, dim] -> [1, seq_len, dim] + freqs_cos = freqs_cos.unsqueeze(0) + freqs_sin = freqs_sin.unsqueeze(0) + if freqs_cos.shape[0] == 1 and batch > 1: + freqs_cos = freqs_cos.expand(batch, -1, -1) + freqs_sin = freqs_sin.expand(batch, -1, -1) + + # Reshape freqs to [batch, seq_len, dim//2, 2] + freqs_cos_reshaped = freqs_cos.view(batch, seq_len, dim // 2, 2) + freqs_sin_reshaped = freqs_sin.view(batch, seq_len, dim // 2, 2) + + cos_1 = freqs_cos_reshaped[..., 0] # [batch, seq_len, dim//2] + cos_2 = freqs_cos_reshaped[..., 1] # [batch, seq_len, dim//2] + sin_1 = freqs_sin_reshaped[..., 0] # [batch, seq_len, dim//2] + sin_2 = freqs_sin_reshaped[..., 1] # [batch, seq_len, dim//2] + + # Broadcast cos/sin to match x_reshaped + if len(x_shape) == 4: + # [batch, seq_len, 1, dim//2] + cos_1 = cos_1.unsqueeze(2) + cos_2 = cos_2.unsqueeze(2) + sin_1 = sin_1.unsqueeze(2) + sin_2 = sin_2.unsqueeze(2) + + x1 = x_reshaped[..., 0] # [..., seq_len, num_heads, dim//2] or [..., 
seq_len, dim//2] + x2 = x_reshaped[..., 1] # same + + out1 = x1 * cos_1 - x2 * sin_1 + out2 = x1 * sin_2 + x2 * cos_2 + + out = torch.stack([out1, out2], dim=-1) + return out.view(*x_shape) + + +def get_1d_rotary_pos_embed_real( + dim: int, + pos: np.ndarray | int, + theta: float = 10000.0, + linear_factor: float = 1.0, + ntk_factor: float = 1.0, + freqs_dtype: torch.dtype = torch.float32, +): + freqs_cos, freqs_sin = _get_1d_rotary_pos_embed( + dim, + pos, + theta=theta, + use_real=True, + linear_factor=linear_factor, + ntk_factor=ntk_factor, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + return freqs_cos, freqs_sin + + +class RotaryPosEmbedReal(nn.Module): + def __init__( + self, theta: int, axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], patch_size: int = 2 + ): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + @staticmethod + def get_freqs_real( + axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int + ) -> list[tuple[torch.Tensor, torch.Tensor]]: + freqs_real = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + cos_emb, sin_emb = get_1d_rotary_pos_embed_real(d, e, theta=theta, freqs_dtype=freqs_dtype) + freqs_real.append((cos_emb, sin_emb)) + return freqs_real + + def _get_freqs_real( + self, freqs_real: list[tuple[torch.Tensor, torch.Tensor]], ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + + cos_result = [] + sin_result = [] + for i in range(len(self.axes_dim)): + freqs_cos, freqs_sin = freqs_real[i] + freqs_cos = freqs_cos.to(ids.device) + freqs_sin = freqs_sin.to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs_cos.shape[-1]).to(torch.int64) + cos_result.append(torch.gather(freqs_cos.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + sin_result.append(torch.gather(freqs_sin.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + + combined_cos = torch.cat(cos_result, dim=-1).to(device) + combined_sin = torch.cat(sin_result, dim=-1).to(device) + return combined_cos, combined_sin + + def forward( + self, freqs_real, attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device + ): + batch_size = len(attention_mask) + p = self.patch_size + + encoder_seq_len = attention_mask.shape[1] + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + l_effective_cap_len = [int(len) for len in l_effective_cap_len] + seq_lengths = [ + int(cap_len + sum(ref_img_len) + img_len) + for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len) + ] + + max_seq_len = max(seq_lengths) + max_ref_img_len = max([int(sum(ref_img_len)) for ref_img_len in l_effective_ref_img_len]) + max_img_len = int(max(l_effective_img_len)) + + # Create position IDs + position_ids = torch.zeros(batch_size, int(max_seq_len), 3, dtype=torch.int32, device=device) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + # add text position ids + position_ids[i, :cap_seq_len] = repeat( + torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3" + ) + pe_shift = cap_seq_len + pe_shift_len = cap_seq_len + + if ref_img_sizes[i] is not None: + for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): + H, W = ref_img_size + 
ref_H_tokens, ref_W_tokens = H // p, W // p + assert ref_H_tokens * ref_W_tokens == ref_img_len + # add image position ids + + row_ids = repeat( + torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens + ).flatten() + col_ids = repeat( + torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens + ).flatten() + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 0] = pe_shift + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 1] = row_ids + position_ids[i, pe_shift_len : pe_shift_len + ref_img_len, 2] = col_ids + + pe_shift += max(ref_H_tokens, ref_W_tokens) + pe_shift_len += ref_img_len + + H, W = img_sizes[i] + H_tokens, W_tokens = H // p, W // p + assert H_tokens * W_tokens == l_effective_img_len[i] + + row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten() + col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten() + + assert pe_shift_len + l_effective_img_len[i] == seq_len + position_ids[i, pe_shift_len:seq_len, 0] = pe_shift + position_ids[i, pe_shift_len:seq_len, 1] = row_ids + position_ids[i, pe_shift_len:seq_len, 2] = col_ids + + # Get combined rotary embeddings (real version) + freqs_cos, freqs_sin = self._get_freqs_real(freqs_real, position_ids) + + # create separate rotary embeddings for captions and images + cap_freqs_cos = torch.zeros( + batch_size, encoder_seq_len, freqs_cos.shape[-1], device=device, dtype=freqs_cos.dtype + ) + cap_freqs_sin = torch.zeros( + batch_size, encoder_seq_len, freqs_sin.shape[-1], device=device, dtype=freqs_sin.dtype + ) + ref_img_freqs_cos = torch.zeros( + batch_size, max_ref_img_len, freqs_cos.shape[-1], device=device, dtype=freqs_cos.dtype + ) + ref_img_freqs_sin = torch.zeros( + batch_size, max_ref_img_len, freqs_sin.shape[-1], device=device, dtype=freqs_sin.dtype + ) + img_freqs_cos = torch.zeros(batch_size, max_img_len, freqs_cos.shape[-1], device=device, dtype=freqs_cos.dtype) + img_freqs_sin = torch.zeros(batch_size, max_img_len, freqs_sin.shape[-1], device=device, dtype=freqs_sin.dtype) + + for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate( + zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths) + ): + cap_freqs_cos[i, :cap_seq_len] = freqs_cos[i, :cap_seq_len] + cap_freqs_sin[i, :cap_seq_len] = freqs_sin[i, :cap_seq_len] + ref_img_freqs_cos[i, : sum(ref_img_len)] = freqs_cos[i, cap_seq_len : cap_seq_len + sum(ref_img_len)] + ref_img_freqs_sin[i, : sum(ref_img_len)] = freqs_sin[i, cap_seq_len : cap_seq_len + sum(ref_img_len)] + img_freqs_cos[i, :img_len] = freqs_cos[ + i, cap_seq_len + sum(ref_img_len) : cap_seq_len + sum(ref_img_len) + img_len + ] + img_freqs_sin[i, :img_len] = freqs_sin[ + i, cap_seq_len + sum(ref_img_len) : cap_seq_len + sum(ref_img_len) + img_len + ] + + return ( + (cap_freqs_cos, cap_freqs_sin), + (ref_img_freqs_cos, ref_img_freqs_sin), + (img_freqs_cos, img_freqs_sin), + (freqs_cos, freqs_sin), + l_effective_cap_len, + seq_lengths, + ) diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/schedulers.py b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/schedulers.py new file mode 100644 index 000000000..4f7ad3b5e --- /dev/null +++ b/vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit_layer/schedulers.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. on 2025-09-30. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://www.apache.org/licenses/LICENSE-2.0 +# +# This modified file is released under the same license. +# +# --- Upstream header preserved below --- +# +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput: + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler: + + order = 1 + + def __init__(self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = True): + self.num_train_timesteps = int(num_train_timesteps) + self.dynamic_time_shift = bool(dynamic_time_shift) + + timesteps = torch.linspace(0, 1, self.num_train_timesteps + 1, dtype=torch.float32)[:-1] + self.timesteps = timesteps + self._timesteps = torch.cat([timesteps, torch.ones(1, dtype=timesteps.dtype)]) + + self._step_index: int | None = None + self._begin_index: int | None = None + + @property + def step_index(self) -> int | None: + return self._step_index + + @property + def begin_index(self) -> int | None: + return self._begin_index + + def set_begin_index(self, begin_index: int = 0) -> None: + self._begin_index = int(begin_index) + + def index_for_timestep(self, timestep: torch.Tensor, schedule_timesteps: torch.Tensor | None = None) -> int: + schedule_timesteps = self._timesteps if schedule_timesteps is None else schedule_timesteps + indices = (schedule_timesteps == timestep).nonzero() + pos = 1 if len(indices) > 1 else 0 + return int(indices[pos].item()) + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[float] | None = None, + num_tokens: int | None = None, + ) -> None: + if timesteps is None: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` must be provided when `timesteps` is None.") + timesteps_np = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1] + if self.dynamic_time_shift and num_tokens is not None: + m = np.sqrt(float(num_tokens)) / 40.0 + timesteps_np = timesteps_np / (m - m * timesteps_np + timesteps_np) + else: + timesteps_np = np.asarray(timesteps, dtype=np.float32) + + timesteps_t = torch.from_numpy(timesteps_np).to(dtype=torch.float32, device=device) + self.timesteps = timesteps_t + self._timesteps = torch.cat([timesteps_t, torch.ones(1, device=timesteps_t.device, dtype=timesteps_t.dtype)]) + + self._step_index = None + self._begin_index = None + + def _init_step_index(self, timestep: float | torch.Tensor) -> None: + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + else: + timestep = torch.tensor(timestep, 
device=self.timesteps.device, dtype=self.timesteps.dtype)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: float | torch.FloatTensor,
+        sample: torch.FloatTensor,
+        generator: torch.Generator | None = None,  # noqa: ARG002 - kept for API compatibility
+        return_dict: bool = True,
+    ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor]:
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+            raise ValueError(
+                "Passing an integer timestep (e.g. an enumerate index) is not supported; "
+                "pass a value from `scheduler.timesteps` instead.",
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+        assert self._step_index is not None
+
+        sample_fp32 = sample.to(torch.float32)
+        t = self._timesteps[self._step_index]
+        t_next = self._timesteps[self._step_index + 1]
+
+        prev_sample = sample_fp32 + (t_next - t) * model_output
+        prev_sample = prev_sample.to(model_output.dtype)
+
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+    def __len__(self) -> int:
+        return self.num_train_timesteps
diff --git a/vllm_omni/model_executor/models/mammoth_moda2/tokenizer.py b/vllm_omni/model_executor/models/mammoth_moda2/tokenizer.py
new file mode 100644
index 000000000..6dd78b04b
--- /dev/null
+++ b/vllm_omni/model_executor/models/mammoth_moda2/tokenizer.py
@@ -0,0 +1,400 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+#
+# This file has been modified by ByteDance Ltd. on 2025-09-30.
+#
+# Original file was released under Apache-2.0, with the full license text
+# available at https://www.apache.org/licenses/LICENSE-2.0
+#
+# This modified file is released under the same license.
+#
+# --- Upstream header preserved below ---
+#
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
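+# For reference, the special-token ID layout implied by the constants below is:
+#   151643 (SPECIAL_START_ID) .. 151645   <|endoftext|>, <|im_start|>, <|im_end|>
+#   151646 .. 151664                      19 Qwen special tokens
+#   151665 .. 151845                      181 <|extra_*|> tokens (Qwen2.5 tokenizer length 151846)
+#   151846 .. 152063                      218 <|extra_margin_*|> tokens (Qwen2.5 embedding size 152064)
+#   152064 .. 152068                      <|endofline|>, <|endoffile|>, <|gen_placeholder|>, <|useless token|>, <|beginoftext|>
+#   152069 ..                             vision tokens read from mammothu_vision_tokens.txt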
+ + +import base64 +import os +import unicodedata +from collections.abc import Collection + +import tiktoken +from loguru import logger +from transformers import AddedToken, PreTrainedTokenizer + +VOCAB_FILES_NAMES = { + "vocab_file": "mammothu.tiktoken", + "special_tokens_file": "mammothu_vision_tokens.txt", +} + +PAT_STR = ( + r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?""" + r"""[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +) +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" +# as the default behavior is changed to allow special tokens in +# regular texts, the surface forms of special tokens need to be +# as different as possible to minimize the impact +QWEN_SPECIAL_TOKENS = ( + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>", + "", + "", + "<|fim_prefix|>", + "<|fim_middle|>", + "<|fim_suffix|>", + "<|fim_pad|>", + "<|repo_name|>", + "<|file_sep|>", +) + +# align to qwen2.5 tokenizer length (151846) +EXTRAS = [f"<|extra_{i}|>" for i in range(181)] # 205 - 19[len(QWEN_SPECIAL_TOKENS)] - 5 +# align to qwen2.5 embedding size (152064) +EXTRAS += [f"<|extra_margin_{i}|>" for i in range(152064 - 151846)] +# append new token in gen embedding range +EXTRAS += ["<|endofline|>", "<|endoffile|>", "<|gen_placeholder|>", "<|useless token|>", "<|beginoftext|>"] +EXTRAS = tuple(EXTRAS) +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 + + +def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + with open(tiktoken_bpe_file, "rb") as f: + contents = f.read() + return { + base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line) + } + + +class MammothUTokenizer(PreTrainedTokenizer): + """MammothU tokenizer.""" + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file: str, + special_tokens_file: str, + errors: str = "replace", + bos_token: str = "<|beginoftext|>", + eos_token: str = "<|endoftext|>", + pad_token: str = "<|endoftext|>", + img_token: str = "<|image token|>", + boi_token: str = "<|image start|>", + eoi_token: str = "<|image end|>", + eol_token: str = "<|endofline|>", + eof_token: str = "<|endoffile|>", + **kwargs, + ) -> None: + super().__init__(**kwargs) + + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) + vision_tokens = [t.strip() for t in open(special_tokens_file).readlines() if len(t.strip()) > 0] + SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + QWEN_SPECIAL_TOKENS + + EXTRAS + + tuple(vision_tokens) + ), + start=SPECIAL_START_ID, + ) + ) + + self.special_tokens = {token: index for index, token in SPECIAL_TOKENS} + self.special_tokens_set = set(t for _, t in SPECIAL_TOKENS) + + enc = tiktoken.Encoding( + "mammothu", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + + assert len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab, ( + f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" + ) + + self.decoder = {v: k for k, v in self.mergeable_ranks.items()} + self.decoder.update({v: k for k, v in self.special_tokens.items()}) + self.tokenizer = 
enc + self.eod_id = self.tokenizer.eot_token + self.bos_token = bos_token + self.eos_token = eos_token + self.pad_token = pad_token + self.img_token = img_token + self.boi_token = boi_token + self.eoi_token = eoi_token + self.eol_token = eol_token + self.eof_token = eof_token + self.image_content_token = "<|image_pad|>" # come from Qwen2.5-VL + self.gen_image_token = "<|gen_image_pad|>" + self.gen_image_placeholder_token = "<|gen_placeholder|>" + self.visual_tokens = ["<|image_pad|>", "<|video_pad|>", "<|vision_start|>", "<|vision_end|>"] + self.visual_tokens_ids = [self.get_vocab()[token] for token in self.visual_tokens] + + self.vision_range = (self.get_vocab()[self.boi_token], self.tokenizer.n_vocab - 1) + logger.info(f"MammothUTokeniser Vision range: {self.vision_range}") + + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state["tokenizer"] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "mammothu", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + self.tokenizer = enc + + def __len__(self) -> int: + return self.tokenizer.n_vocab + + def get_vocab(self) -> dict[bytes | str, int]: + vocab = self.mergeable_ranks.copy() + vocab.update(self.special_tokens) + return vocab + + @property + def gen_placeholder_id(self): + return self.get_vocab()[self.gen_image_placeholder_token] + + def convert_tokens_to_ids(self, tokens: bytes | str | list[bytes | str]) -> list[int]: + if isinstance(tokens, (str, bytes)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.mergeable_ranks.get(tokens) + + ids = [] + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.mergeable_ranks.get(token)) + return ids + + def _add_tokens( + self, + new_tokens: list[str] | list[AddedToken], + special_tokens: bool = False, + ) -> int: + if not special_tokens and new_tokens: + raise ValueError("Adding regular tokens is not supported") + + added_tokens = 0 + for token in new_tokens: + surface_form = token.content if isinstance(token, AddedToken) else token + if surface_form in self.special_tokens_set: + # Token already exists in our special tokens set + added_tokens += 1 + else: + logger.warning(f"Token {surface_form} is not in the predefined special tokens set and cannot be added") + + return added_tokens + + def add_special_tokens(self, special_tokens_dict: dict[str, str | AddedToken]) -> int: + """ + Add special tokens to the tokenizer and update the special tokens mapping. + Only adds tokens that are already in the special_tokens_set. + + Args: + special_tokens_dict: dictionary of special tokens to add. + The key is the token type and the value is the token to add. + + Returns: + Number of tokens added to the vocabulary. 
+ """ + added_tokens = 0 + + for token_type, token in special_tokens_dict.items(): + if token_type == "additional_special_tokens" and isinstance(token, list): + added_tokens += self._add_tokens(token, special_tokens=True) + else: + token_value = token.content if isinstance(token, AddedToken) else token + if token_value in self.special_tokens_set: + setattr(self, token_type, token_value) + added_tokens += 1 + else: + logger.warning( + f"Token {token_value} is not in the predefined special tokens set and cannot be added" + ) + + return added_tokens + + def save_vocabulary(self, save_directory: str, **kwargs) -> tuple[str]: + """ + Save only the vocabulary of the tokenizer (vocabulary). + + Returns: + `tuple(str)`: Paths to the files saved. + """ + regular_file_path = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) + with open(regular_file_path, "w", encoding="utf8") as w: + for k, v in self.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" + w.write(line) + + excluded_special_tokens = set( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ) + special_file_path = os.path.join(save_directory, self.vocab_files_names["special_tokens_file"]) + with open(special_file_path, "w", encoding="utf8") as w: + for k in self.special_tokens: + if k not in excluded_special_tokens: + print(k, file=w) + + return (regular_file_path, special_file_path) + + def tokenize( + self, + text: str, + allowed_special: set | str = "all", + disallowed_special: Collection | str = (), + **kwargs, + ) -> list[bytes | str]: + """ + Converts a string in a sequence of tokens. + + Args: + text (`str`): + The sequence to be encoded. + allowed_special (`Literal["all"]` or `set`): + The surface forms of the tokens to be encoded as special tokens in regular texts. + Default to "all". + disallowed_special (`Literal["all"]` or `Collection`): + The surface forms of the tokens that should not be in regular texts and trigger errors. + Default to an empty tuple. + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. + + Returns: + `list[bytes|str]`: The list of tokens. + """ + tokens = [] + text = unicodedata.normalize("NFC", text) + + # this implementation takes a detour: text -> token id -> token surface forms + for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special): + tokens.append(self.decoder[t]) + + return tokens + + def convert_tokens_to_string(self, tokens: list[bytes | str]) -> str: + """ + Converts a sequence of tokens in a single string. 
+ """ + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) + return text + + @property + def vocab_size(self) -> int: + return self.tokenizer.n_vocab + + def _convert_id_to_token(self, index: int) -> bytes | str: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: bytes | str) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _decode( + self, + token_ids: int | list[int], + skip_special_tokens: bool = False, + errors: str | None = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + + if skip_special_tokens: + token_ids = [i for i in token_ids if i < self.eod_id] + + return self.tokenizer.decode(token_ids, errors=errors or self.errors) + + def bytes_to_str(self, byte_tokens) -> str: + """Convert byte tokens to string representation. + + Args: + byte_tokens: A dictionary where keys are byte objects and values are integers, + or a single byte object. + + Returns: + If input is a dictionary, returns a new dictionary with byte keys converted to strings. + If input is a single byte object, returns the string representation. + """ + if isinstance(byte_tokens, dict): + return {k.decode("utf-8", errors=self.errors): v for k, v in byte_tokens.items() if isinstance(k, bytes)} + if isinstance(byte_tokens, bytes): + return byte_tokens.decode("utf-8", errors=self.errors) + return byte_tokens diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 56bceae41..28f4e4178 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -1,4 +1,8 @@ -from vllm.model_executor.models.registry import _VLLM_MODELS, _LazyRegisteredModel, _ModelRegistry +from vllm.model_executor.models.registry import ( + _VLLM_MODELS, + _LazyRegisteredModel, + _ModelRegistry, +) _OMNI_MODELS = { "Qwen2_5OmniForConditionalGeneration": ( @@ -48,6 +52,31 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), + "MammothModa2Qwen2ForCausalLM": ( + "mammoth_moda2", + "mammoth_moda2_ar", + "MammothModa2Qwen2ForCausalLM", + ), + "MammothModa2ARForConditionalGeneration": ( + "mammoth_moda2", + "mammoth_moda2_ar", + "MammothModa2ARForConditionalGeneration", + ), + "MammothModa2DiTForConditionalGeneration": ( + "mammoth_moda2", + "mammoth_moda2_dit", + "MammothModa2DiTForConditionalGeneration", + ), + "MammothModa2ForConditionalGeneration": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2ForConditionalGeneration", + ), + "Mammothmoda2Model": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2ForConditionalGeneration", + ), } _VLLM_OMNI_MODELS = { @@ -55,7 +84,6 @@ **_OMNI_MODELS, } - OmniModelRegistry = _ModelRegistry( { **{ diff --git a/vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml b/vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml new file mode 100644 index 000000000..8b492862d --- /dev/null +++ 
b/vllm_omni/model_executor/stage_configs/mammoth_moda2.yaml @@ -0,0 +1,36 @@ +stage_args: + - stage_id: 0 + runtime: + devices: "0" + max_batch_size: 100 + engine_args: + model_stage: ar + model_arch: MammothModa2ForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + max_model_len: 8192 + gpu_memory_utilization: 0.5 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + final_output: false + + - stage_id: 1 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_arch: MammothModa2ForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + enable_prefix_caching: false + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit + final_output: true + final_output_type: image diff --git a/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py b/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py new file mode 100644 index 000000000..4137a69d1 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py @@ -0,0 +1,97 @@ +"""Stage input processor for MammothModa2 (AR -> DiT).""" + +from typing import Any + +import torch +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + + +def ar2dit( + stage_list: list[Any], + engine_input_source: list[int], + prompts: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + """Convert AR stage outputs to DiT stage inputs.""" + + source_stage_id = engine_input_source[0] + ar_outputs = stage_list[source_stage_id].engine_outputs + + dit_inputs: list[OmniTokensPrompt] = [] + for ar_output, prompt in zip(ar_outputs, prompts): + addi_info = prompt["additional_information"] + image_height = addi_info["image_height"][0] + image_width = addi_info["image_width"][0] + text_guidance_scale = addi_info["text_guidance_scale"][0] + cfg_range = addi_info["cfg_range"] + num_inference_steps = addi_info["num_inference_steps"][0] + gen_vocab_start_index = addi_info["visual_token_start_id"][0] + + prompt_token_ids = ar_output.prompt_token_ids + # exclude the last token because it has no corresponding hidden state + gen_token_ids = ar_output.outputs[0].token_ids[:-1] + full_token_ids = prompt_token_ids + gen_token_ids + + full_hidden_states = ar_output.multimodal_output["latent"] + hidden_total = int(full_hidden_states.shape[0]) + assert hidden_total == len(prompt_token_ids) + len(gen_token_ids), ( + f"Hidden states length mismatch: expected {len(prompt_token_ids) + len(gen_token_ids)}, got {hidden_total}" + ) + + mask_device = full_hidden_states.device + full_token_ids_t = torch.tensor(full_token_ids, dtype=torch.long, device=mask_device) + attention_mask = torch.ones_like(full_token_ids_t, dtype=torch.bool) + + L = int(full_token_ids_t.shape[0]) + answer_start_index = max(L - 10, 0) # the last 10 tokens as answer + pos = torch.arange(L, device=mask_device) + questions_mask = pos < answer_start_index + answers_mask = ~questions_mask + + gen_token_mask = full_token_ids_t >= gen_vocab_start_index + + visual_token_mask = torch.zeros_like(gen_token_mask) + 
visual_ids = [ + 151655, + 151656, + 151652, + 151653, + ] # ["<|image_pad|>", "<|video_pad|>", "<|vision_start|>", "<|vision_end|>"] + visual_token_mask = torch.isin( + full_token_ids_t, + torch.tensor(visual_ids, dtype=torch.long, device=mask_device), + ) + + text_condition_token_mask = questions_mask & ~(visual_token_mask | gen_token_mask) & attention_mask + image_condition_token_mask = answers_mask & gen_token_mask & attention_mask + + text_condition = full_hidden_states[text_condition_token_mask] + image_condition = full_hidden_states[image_condition_token_mask] + + text_prompt_embeds = text_condition.to(dtype=torch.float32).contiguous() + image_prompt_embeds = image_condition.to(dtype=torch.float32).contiguous() + + additional_information = { + "text_prompt_embeds": text_prompt_embeds, + "text_prompt_embeds_shape": list(text_prompt_embeds.shape), + "image_prompt_embeds": image_prompt_embeds, + "image_prompt_embeds_shape": list(image_prompt_embeds.shape), + "image_height": [int(image_height)], + "image_width": [int(image_width)], + "text_guidance_scale": [float(text_guidance_scale)], + "cfg_range": [float(cfg_range[0]), float(cfg_range[1])], + "num_inference_steps": [int(num_inference_steps)], + } + + dit_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0], + additional_information=additional_information, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return dit_inputs diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 17740d858..9d8d2ceb1 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -399,6 +399,10 @@ def _dummy_run( input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None + if hasattr(self.model, "get_dummy_runtime_additional_information"): + runtime_addi = self.model.get_dummy_runtime_additional_information(num_reqs) + model_kwargs["runtime_additional_information"] = runtime_addi + if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens_padded] elif self.uses_xdrope_dim > 0: diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index e426e3938..ea840b849 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -644,6 +644,7 @@ def _gather_runtime_additional_information(self) -> list[dict]: per_req_runtime_info = [] for req_id in self.input_batch.req_ids: req_state = self.requests.get(req_id) + generated_len = len(req_state.output_token_ids) if req_state is not None else 0 info = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None if info and isinstance(info, dict): per_req_runtime_info.append(info) @@ -651,6 +652,7 @@ def _gather_runtime_additional_information(self) -> list[dict]: q = info["thinker_reply_part_per_request"] if hasattr(q, "shape"): logger.debug(f"[OMNI] req={req_id} has thinker_reply_part_per_request queue shape: {q.shape}") + info["generated_len"] = generated_len else: per_req_runtime_info.append({}) return per_req_runtime_info
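The ar2dit processor above selects two disjoint slices of the AR hidden states: a text condition taken from the prompt region (excluding vision-marker tokens and generated image tokens) and an image condition taken from the generated visual tokens in the answer region. The standalone sketch below reproduces just that mask arithmetic with toy values; the helper name split_conditions, the token ids, the generation-vocabulary threshold, and the answer length are illustrative assumptions, not part of the MammothModa2 code.

# Toy sketch of the condition-mask selection performed by ar2dit (assumed values).
import torch


def split_conditions(
    token_ids: torch.Tensor,      # (L,) prompt tokens followed by generated tokens
    hidden_states: torch.Tensor,  # (L, H) one AR hidden state per token
    gen_vocab_start_index: int,   # ids >= this value count as generated image tokens
    visual_ids: list[int],        # ids of vision marker tokens (e.g. <|vision_start|>)
    answer_len: int,              # number of trailing tokens treated as the answer
) -> tuple[torch.Tensor, torch.Tensor]:
    length = token_ids.shape[0]
    pos = torch.arange(length)
    questions_mask = pos < max(length - answer_len, 0)
    answers_mask = ~questions_mask

    gen_token_mask = token_ids >= gen_vocab_start_index
    visual_token_mask = torch.isin(token_ids, torch.tensor(visual_ids, dtype=token_ids.dtype))

    # Text condition: question-region tokens that are neither vision markers nor
    # generated image tokens. Image condition: generated image tokens in the answer.
    # ar2dit additionally ANDs with an all-ones attention mask, omitted here.
    text_condition_mask = questions_mask & ~(visual_token_mask | gen_token_mask)
    image_condition_mask = answers_mask & gen_token_mask
    return hidden_states[text_condition_mask], hidden_states[image_condition_mask]


# Toy usage: ids >= 100 stand in for generated image tokens, id 50 for a vision
# marker, and the last four positions play the role of the generated answer.
toy_ids = torch.tensor([1, 2, 50, 3, 4, 100, 101, 102, 103])
toy_hidden = torch.randn(9, 8)
text_cond, image_cond = split_conditions(toy_ids, toy_hidden, 100, [50], 4)
print(text_cond.shape, image_cond.shape)  # torch.Size([4, 8]) torch.Size([4, 8])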