Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions tests/jax/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def reference_make_row_id_map(

# Compute total tokens per expert and expert offsets
tokens_per_expert = jnp.sum(routing_map, axis=0)
- expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
+ expert_offsets = jnp.concatenate(
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
+ )

# Compute destination rows for all (token, expert) pairs
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
Expand All @@ -115,7 +117,9 @@ def reference_make_row_id_map(

# Gather the sorted destination rows and expert indices using advanced indexing
# Create indices for gathering
- token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
+ token_idx = jnp.broadcast_to(
+ jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
+ )
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]

# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
Expand Down Expand Up @@ -373,11 +377,15 @@ def reference_make_chunk_sort_map(
Row ID map for chunk sorting of shape [num_tokens,].
"""
# Compute source chunk boundaries (cumulative sum of original split_sizes)
- src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
+ src_cumsum = jnp.concatenate(
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
+ )

# Compute destination chunk boundaries based on sorted order
sorted_sizes = split_sizes[sorted_indices]
- dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)])
+ dest_cumsum = jnp.concatenate(
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
+ )

# For each source chunk, compute its destination offset
# inverse_indices[i] = position of chunk i in sorted order
Expand All @@ -386,7 +394,7 @@ def reference_make_chunk_sort_map(

# Create row_id_map: for each token position, compute its destination
# First, figure out which chunk each position belongs to
- position_indices = jnp.arange(num_tokens)
+ position_indices = jnp.arange(num_tokens, dtype=jnp.int32)

# chunk_ids[i] = which chunk position i belongs to
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
Expand Down
Loading