Commit 3457e15

Signed-off-by: HonestDeng <[email protected]>
1 parent f70cb6a commit 3457e15

File tree: 6 files changed (+34, −18 lines changed)

vllm_omni/entrypoints/omni_stage.py

Lines changed: 5 additions & 2 deletions
@@ -497,13 +497,15 @@ def _stage_worker(
     num_devices_to_lock = len(devices_to_lock)

     logger.debug(
-        "[Stage-%s] Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d; will lock %d devices: %s",
+        "[Stage-%s] Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d "
+        "(devices_per_stage=%d); will lock %d devices: %s",
         stage_id,
         tensor_parallel_size,
         pipeline_parallel_size,
         data_parallel_size,
         prefill_context_parallel_size,
         sequence_parallel_size,
+        num_devices_per_stage,
         num_devices_to_lock,
         devices_to_lock,
     )
@@ -972,7 +974,8 @@ async def _stage_worker_async(
     # Check if we've been waiting too long
     if _time.time() - wait_start > max_wait_time:
         logger.warning(
-            "[Stage-%s] Timeout waiting for device %s initialization lock, proceeding anyway",
+            "[Stage-%s] Timeout waiting for device %s "
+            "initialization lock, proceeding anyway",
             stage_id,
             device_id,
         )
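The only behavioral change in this file is the extra SP and devices_per_stage fields; the long messages are split using Python's implicit concatenation of adjacent string literals. A minimal standalone sketch of that pattern, with made-up values rather than the real stage configuration:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("demo")

# Adjacent string literals merge into one %-style format string at compile time,
# so splitting the message across lines changes nothing at runtime; the number of
# placeholders must still match the number of trailing arguments.
logger.debug(
    "[Stage-%s] Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d "
    "(devices_per_stage=%d); will lock %d devices: %s",
    0, 2, 1, 1, 1, 1, 2, 2, [0, 1],
)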

vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2_ar.py

Lines changed: 3 additions & 2 deletions
@@ -619,8 +619,9 @@ def forward(
         inputs_embeds: torch.Tensor | None = None,
         **kwargs: Any,
     ):
-        # vllm-omni runner passes sampling_metadata and runtime_additional_information in each forward step.
-        # compute_logits is called immediately after forward, so caching here enables step-by-step dynamic token constraints.
+        # vllm-omni runner passes sampling_metadata and runtime_additional_information
+        # in each forward step. compute_logits is called immediately after
+        # forward, so caching here enables step-by-step dynamic token constraints.
         runtime_infos = kwargs.get("runtime_additional_information")
         self._last_runtime_additional_information = runtime_infos if isinstance(runtime_infos, list) else None
         hidden_states = super().forward(
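For context, the reworded comment describes a cache-and-consume pattern: forward() stashes the per-step runtime info, and compute_logits(), called right after, reads it back. A minimal sketch of that pattern with a hypothetical class and placeholder method bodies (not the actual model code):

class ConstrainedDecoderSketch:
    """Illustrative only: cache per-step info from forward for compute_logits."""

    def forward(self, hidden_states, **kwargs):
        runtime_infos = kwargs.get("runtime_additional_information")
        # Keep the info only if it has the expected list form.
        self._last_runtime_additional_information = (
            runtime_infos if isinstance(runtime_infos, list) else None
        )
        return hidden_states  # stand-in for the real transformer forward

    def compute_logits(self, logits):
        infos = self._last_runtime_additional_information
        if infos is not None:
            pass  # e.g. mask tokens that are disallowed at this decoding step
        return logits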

vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit/attention_processor.py

Lines changed: 3 additions & 1 deletion
@@ -458,7 +458,9 @@ def __call__(
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)

-        # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6
+        # Explicitly repeat key and value to match query length; otherwise using
+        # enable_gqa=True can fall back to the MATH backend of SDPA in our
+        # PyTorch 2.6 tests.
         key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
         value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
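A self-contained sketch of the workaround the comment describes: K/V heads are repeated up front so that scaled_dot_product_attention sees matching head counts, instead of relying on enable_gqa=True (the tensor shapes below are made up):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)   # (batch, query_heads, seq, head_dim)
k = torch.randn(2, 2, 16, 64)   # (batch, kv_heads, seq, head_dim)
v = torch.randn(2, 2, 16, 64)

# Repeat the KV heads along dim -3 so they match the number of query heads.
k = k.repeat_interleave(q.size(-3) // k.size(-3), -3)  # 2 -> 8 heads
v = v.repeat_interleave(q.size(-3) // v.size(-3), -3)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])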

vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit/block_lumina2.py

Lines changed: 2 additions & 1 deletion
@@ -121,7 +121,8 @@ def forward(
         x: torch.Tensor,
         conditioning_embedding: torch.Tensor,
     ) -> torch.Tensor:
-        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        # Convert back to the original dtype in case `conditioning_embedding`
+        # is upcasted to float32 (needed for hunyuanDiT).
         scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
         x = self.norm(x) * (1 + scale)[:, None, :]
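A small sketch of the dtype handling the comment describes, with arbitrary shapes and bfloat16 standing in for the model's compute dtype: if the conditioning embedding was upcast to float32 elsewhere, it is cast back to x.dtype before the modulation projection so the scale tensor matches the rest of the computation.

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 64, dtype=torch.bfloat16)                   # activations
conditioning_embedding = torch.randn(2, 64, dtype=torch.float32)   # upcast elsewhere
linear_1 = torch.nn.Linear(64, 64, dtype=torch.bfloat16)

# Cast back to x.dtype so the projection and the later modulation stay in bf16.
scale = linear_1(F.silu(conditioning_embedding).to(x.dtype))
print(scale.dtype)  # torch.bfloat16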

vllm_omni/model_executor/models/mammoth_moda2/mammothmoda2_dit/transport/transport.py

Lines changed: 17 additions & 11 deletions
@@ -90,8 +90,12 @@ def prior_logp(self, z):
         """
         shape = th.tensor(z.size())
         N = th.prod(shape[1:])
-        _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
-        return th.vmap(_fn)(z)
+
+        # Use a nested def (instead of lambda) to satisfy ruff E731.
+        def _prior_logp_one(x):
+            return -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
+
+        return th.vmap(_prior_logp_one)(z)

     def check_interval(
         self,
@@ -154,7 +158,7 @@ def sample(self, x1, process_index, num_processes):
                 t[_] = 0.0
                 # print(t)
         else:
-            raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)
+            raise NotImplementedError(f"Not implemented snr_type {self.snr_type}")

         if self.do_shift:
             if self.dynamic_time_shift:
@@ -306,16 +310,18 @@ def get_score(
         """member function for obtaining score of
         x_t = alpha_t * x + sigma_t * eps"""
         if self.model_type == ModelType.NOISE:
-            score_fn = (
-                lambda x, t, model, **kwargs: model(x, t, **kwargs)
-                / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
-            )
+
+            def score_fn(x, t, model, **kwargs):
+                sigma = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
+                return model(x, t, **kwargs) / -sigma
         elif self.model_type == ModelType.SCORE:
-            score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
+
+            def score_fn(x, t, model, **kwargs):
+                return model(x, t, **kwargs)
         elif self.model_type == ModelType.VELOCITY:
-            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
-                model(x, t, **kwargs), x, t
-            )
+
+            def score_fn(x, t, model, **kwargs):
+                return self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
         else:
             raise NotImplementedError()
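These refactors are mechanical: named lambdas become nested defs (ruff E731) and %-formatting becomes an f-string, with behavior unchanged. As a standalone illustration of the prior_logp pattern, here is the same vmap-over-one-sample idea applied to a toy batch (shapes chosen arbitrarily):

import numpy as np
import torch as th

z = th.randn(4, 3, 8, 8)                      # batch of 4 samples
N = th.prod(th.tensor(z.size())[1:])          # elements per sample (3 * 8 * 8)

def _prior_logp_one(x):
    # Log-density of a standard normal evaluated at one sample.
    return -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0

# torch.vmap maps the per-sample function over the leading batch dimension.
log_p = th.vmap(_prior_logp_one)(z)
print(log_p.shape)  # torch.Size([4])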

vllm_omni/model_executor/models/mammoth_moda2/tokenization_mammothmoda2_qwen2_5_vl.py

Lines changed: 4 additions & 1 deletion
@@ -44,7 +44,10 @@
     "special_tokens_file": "mammothu_vision_tokens.txt",
 }

-PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+PAT_STR = (
+    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?"""
+    r"""[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+)
 ENDOFTEXT = "<|endoftext|>"
 IMSTART = "<|im_start|>"
 IMEND = "<|im_end|>"
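Since PAT_STR is now built from two adjacent raw-string literals, it is worth noting that implicit concatenation reproduces the original one-line pattern exactly. A quick standalone check (plain string comparison only; the pattern is not compiled here because \p{...} classes are not supported by Python's built-in re module):

SPLIT_PAT_STR = (
    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?"""
    r"""[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)
ORIGINAL_PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# Adjacent literals are joined at compile time, so the split form is
# byte-for-byte identical to the original single-line pattern.
assert SPLIT_PAT_STR == ORIGINAL_PAT_STR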
