diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 43f2553eed..d61ea8eb75 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -97,7 +97,9 @@ def reference_make_row_id_map( # Compute total tokens per expert and expert offsets tokens_per_expert = jnp.sum(routing_map, axis=0) - expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) + expert_offsets = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)] + ) # Compute destination rows for all (token, expert) pairs # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1 @@ -115,7 +117,9 @@ def reference_make_row_id_map( # Gather the sorted destination rows and expert indices using advanced indexing # Create indices for gathering - token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts)) + token_idx = jnp.broadcast_to( + jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts) + ) sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices] # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed] @@ -373,11 +377,15 @@ def reference_make_chunk_sort_map( Row ID map for chunk sorting of shape [num_tokens,]. """ # Compute source chunk boundaries (cumulative sum of original split_sizes) - src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) + src_cumsum = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)] + ) # Compute destination chunk boundaries based on sorted order sorted_sizes = split_sizes[sorted_indices] - dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)]) + dest_cumsum = jnp.concatenate( + [jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)] + ) # For each source chunk, compute its destination offset # inverse_indices[i] = position of chunk i in sorted order @@ -386,7 +394,7 @@ def reference_make_chunk_sort_map( # Create row_id_map: for each token position, compute its destination # First, figure out which chunk each position belongs to - position_indices = jnp.arange(num_tokens) + position_indices = jnp.arange(num_tokens, dtype=jnp.int32) # chunk_ids[i] = which chunk position i belongs to chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")