Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union

import torch
from torch._ops import OpOverloadPacket
Expand Down Expand Up @@ -576,17 +576,21 @@ def _store_extra_arg(
else:
self._extra_args[name] = None

def _get_unique_value(self, occupied: Sequence[int], max_val: int) -> int:
    """Get an unoccupied value from the range indicated by max_val.

    In addition, this function performs a sanity check to ensure that no value in the occupied
    sequence is out of bounds.

    Args:
        occupied: Values that are already in use. Each one must lie in ``[0, max_val)``.
            Duplicates are allowed, and a set is accepted as well as a sequence.
        max_val: Exclusive upper bound of the candidate range ``range(max_val)``.

    Returns:
        The smallest value in ``range(max_val)`` that does not appear in ``occupied``, or 0
        when every value in the range is occupied (including when ``max_val`` is 0).

    Raises:
        AssertionError: If any value in ``occupied`` is negative or >= ``max_val``.
    """
    # Validate without materializing the full range set
    out_of_range = [v for v in occupied if v < 0 or v >= max_val]
    assert not out_of_range, f"Out of range values: {out_of_range}"

    # Build the lookup set once: testing membership directly on a sequence is a linear scan
    # per candidate, which makes the search below quadratic for long inputs.
    taken = set(occupied)

    # Return the smallest free value; fall back to 0 if none
    return next((candidate for candidate in range(max_val) if candidate not in taken), 0)

@nvtx_range("ad_nest_sequences")
def nest_sequences(
Expand Down Expand Up @@ -632,13 +636,13 @@ def nest_sequences(
cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
page_assignments
)
free_cache_loc = self._get_unique_value(set(cache_loc), self.num_pages)
free_cache_loc = self._get_unique_value(cache_loc, self.num_pages)
self._store_arg("cache_loc", cache_loc, reset_val=free_cache_loc)
self._store_arg("pages_per_seq", pages_per_seq, reset_val=1)

# check for updated slot_idx
if slot_idx is not None:
free_slot_idx = self._get_unique_value(set(slot_idx), self.max_batch_size)
free_slot_idx = self._get_unique_value(slot_idx, self.max_batch_size)
self._store_arg("slot_idx", slot_idx, reset_val=free_slot_idx)

### UPDATE MAIN INPUTS #####################################################################
Expand Down
Loading