17 changes: 17 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -142,6 +142,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
window_block_groups: Optional[torch.Tensor] = None
window_block_usage: Optional[torch.Tensor] = None
window_attn_bias: Optional[torch.Tensor] = None
chunked_slot_mapping: Optional[torch.Tensor] = None
chunked_attn_bias: Optional[torch.Tensor] = None
chunked_block_mapping: Optional[torch.Tensor] = None
chunked_block_list: Optional[torch.Tensor] = None
chunked_block_groups: Optional[torch.Tensor] = None
chunked_block_usage: Optional[torch.Tensor] = None


@dataclass
@@ -459,6 +465,8 @@ def __init__(
"is not implemented for "
"HPUAttentionImpl")

self.is_chunked_attention = False

def _maybe_init_alibi_biases(
self,
max_seq_len,
@@ -582,6 +590,9 @@ def forward(
attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
if self.is_chunked_attention and \
hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
attn_bias = attn_metadata.chunked_attn_bias

out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
@@ -602,6 +613,12 @@
block_groups = attn_metadata.window_block_groups
block_mapping = attn_metadata.window_block_mapping
attn_bias = attn_metadata.window_attn_bias
elif self.is_chunked_attention and \
attn_metadata.chunked_block_list is not None:
block_list = attn_metadata.chunked_block_list
block_groups = attn_metadata.chunked_block_groups
block_mapping = attn_metadata.chunked_block_mapping
attn_bias = attn_metadata.chunked_attn_bias
else:
block_list = attn_metadata.block_list
block_groups = attn_metadata.block_groups
6 changes: 6 additions & 0 deletions vllm_gaudi/v1/attention/backends/hpu_attn.py
@@ -87,6 +87,9 @@ def make_decode_metadata(cls,
window_block_list,
window_block_usage,
window_block_groups,
chunked_block_list,
chunked_block_usage,
chunked_block_groups,
query_start_loc=None):
return cls(is_prompt=False,
block_mapping=None,
@@ -100,6 +103,9 @@ def make_decode_metadata(cls,
window_block_list=window_block_list,
window_block_usage=window_block_usage,
window_block_groups=window_block_groups,
chunked_block_list=chunked_block_list,
chunked_block_usage=chunked_block_usage,
chunked_block_groups=chunked_block_groups,
input_positions=input_positions,
slot_mapping=slot_mapping,
block_size=block_size,
149 changes: 133 additions & 16 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -499,10 +499,66 @@ def _set_attn_bias_for_sliding_window(self, attn_metadata, batch_size, seq_len,
attn_metadata = prefill_metadata._replace(window_attn_bias=attn_bias)
return attn_metadata

def _set_block_mapping(self, metadata, batch_size, device, dtype, is_window_block=False):
def _set_attn_bias_for_chunked_attention(self, attn_metadata, batch_size, seq_len, chunk_size, device, dtype):
if (attn_metadata is None or not attn_metadata.is_prompt):
return attn_metadata

prefill_metadata = attn_metadata
shift = 0

if self.prefill_use_fusedsdpa and attn_metadata.block_list is not None:

context_lens_t = prefill_metadata.context_lens_tensor
block_list = prefill_metadata.block_list
max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
query_positions = torch.arange(seq_len, device=device)
total_token_positions = context_lens_t.unsqueeze(-1) + query_positions.unsqueeze(0)
which_chunk = (total_token_positions // chunk_size)
chunk_start_positions = which_chunk * chunk_size
invalid_lens_t = chunk_start_positions - 1

past_indices = torch.arange(max_context_len, device=device)
past_mask = (
(past_indices.unsqueeze(0).unsqueeze(0) > invalid_lens_t.unsqueeze(-1)) &
(past_indices.unsqueeze(0).unsqueeze(0) < context_lens_t.unsqueeze(-1).unsqueeze(-1))).unsqueeze(1)

causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=shift)
query_chunk_ids = which_chunk[0]
same_chunk_mask = query_chunk_ids.unsqueeze(0) == query_chunk_ids.unsqueeze(1)

causal_mask = causal_mask & same_chunk_mask
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)

mask = torch.concat((past_mask, causal_mask), dim=-1)
attn_bias = torch.where(mask, torch.tensor(0.0, dtype=dtype, device=device),
torch.tensor(float('-inf'), dtype=dtype, device=device))
else:
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device, dtype=dtype, fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
idx = torch.arange(seq_len, device=device)
chunk_id = idx // chunk_size
same_chunk = chunk_id.unsqueeze(0) == chunk_id.unsqueeze(1)
same_chunk = same_chunk.unsqueeze(0).unsqueeze(0)
mask = torch.where(same_chunk, mask, torch.tensor(0.0, dtype=dtype, device=device))
attn_bias = torch.log(mask)

attn_metadata = custom_tuple_replace(prefill_metadata, "TrimmedAttentionMetadata", chunked_attn_bias=attn_bias)
return attn_metadata
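
Reviewer note (illustrative only, not part of the diff): a minimal standalone sketch of the mask the non-fused branch of _set_attn_bias_for_chunked_attention builds, assuming seq_len = 6 and attention_chunk_size = 3. A query token may attend only to earlier tokens inside its own chunk; the fused-SDPA branch applies the same rule to cached context tokens, e.g. with context_len = 10 and chunk_size = 8 a query at position 0 sits at absolute position 10 (chunk 1, chunk start 8), so only context tokens 8 and 9 remain visible.

import torch

seq_len, chunk_size = 6, 3
idx = torch.arange(seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
same_chunk = (idx // chunk_size).unsqueeze(0) == (idx // chunk_size).unsqueeze(1)
mask = causal & same_chunk  # rows 0-2 attend only to 0-2, rows 3-5 only to 3-5
attn_bias = torch.where(mask, torch.tensor(0.0), torch.tensor(float('-inf')))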

def _set_block_mapping(self,
metadata,
batch_size,
device,
dtype,
is_window_block=False,
update_for_chunked_attention=False):
if is_window_block:
block_usage = metadata.window_block_usage
block_groups = metadata.window_block_groups
elif update_for_chunked_attention:
block_usage = metadata.chunked_block_usage
block_groups = metadata.chunked_block_groups
else:
block_usage = metadata.block_usage
block_groups = metadata.block_groups
@@ -533,21 +589,36 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype, is_window_bloc
"TrimmedAttentionMetadata",
window_block_mapping=block_mapping,
window_attn_bias=attn_bias)
elif update_for_chunked_attention:
metadata = custom_tuple_replace(metadata,
"TrimmedAttentionMetadata",
chunked_block_mapping=block_mapping,
chunked_attn_bias=attn_bias)
else:
metadata = custom_tuple_replace(metadata,
"TrimmedAttentionMetadata",
block_mapping=block_mapping,
attn_bias=attn_bias)
return metadata

def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype):
def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype, model_has_chunked_attention=False):
if attn_metadata.is_prompt:
attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype)
if self.interleaved_sliding_window and self.sliding_window is not None:
attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len,
self.sliding_window, device, dtype)
if model_has_chunked_attention:
attn_metadata = self._set_attn_bias_for_chunked_attention(
attn_metadata, batch_size, seq_len, self.model.config.text_config.attention_chunk_size, device,
dtype)
else:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype)
if model_has_chunked_attention:
attn_metadata = self._set_block_mapping(attn_metadata,
batch_size,
device,
dtype,
update_for_chunked_attention=True)
if self.interleaved_sliding_window and self.sliding_window is not None:
attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True)
return attn_metadata
Expand All @@ -564,9 +635,11 @@ def forward(self, *args, **kwargs):
if 'warmup_mode' in kwargs:
kwargs.pop('warmup_mode')
input_ids = kwargs['input_ids']
model_has_chunked_attention = kwargs.pop('model_has_chunked_attention', False)
if not self.unified_attn:
kwargs['attn_metadata'] = self._update_metadata(kwargs['attn_metadata'], input_ids.size(0),
input_ids.size(1), input_ids.device, self.dtype)
input_ids.size(1), input_ids.device, self.dtype,
model_has_chunked_attention)
if self._rotary_prepare_cos_sin is not None:
self._rotary_prepare_cos_sin(kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
attn_meta = kwargs.pop('attn_metadata')
@@ -691,7 +764,12 @@ def trim_attn_metadata(metadata: HPUAttentionMetadataV1) -> object:
'window_block_usage',
'window_block_groups',
'window_attn_bias',
])
'chunked_block_mapping',
'chunked_attn_bias',
'chunked_block_list',
'chunked_block_usage',
'chunked_block_groups'
]) # yapf: disable
return attention_metadata


@@ -926,6 +1004,8 @@ def __init__(
self.scheduler_output: SchedulerOutput | None = None
self.warmup_mode: bool = False
self.batch_changed: bool = False
# WA for chunked attention support
self.model_has_chunked_attention = False

assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
@@ -1504,6 +1584,18 @@ def _get_num_decodes(self) -> int:
num_decodes += 1
return num_decodes

def maybe_set_chunked_attention_layers(self, model):
if hasattr(model.config, 'text_config'): # noqa: SIM102
if hasattr(model.config.text_config, 'attention_chunk_size'): # noqa: SIM102
if model.config.text_config.attention_chunk_size > 0:
self.model_has_chunked_attention = True
try:
for layer in model.language_model.model.layers:
if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__:
layer.self_attn.attn.impl.is_chunked_attention = True
except Exception:
pass
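
Reviewer note (illustrative, not part of the diff): for a Llama 4 style text_config the detection above reduces to roughly the sketch below; the getattr form and the 8192 value are assumptions for the example, not what the diff does verbatim.

text_cfg = getattr(model.config, 'text_config', None)
attention_chunk_size = getattr(text_cfg, 'attention_chunk_size', 0) or 0  # e.g. 8192 for Llama 4
model_has_chunked_attention = attention_chunk_size > 0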

def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
@@ -2177,6 +2269,18 @@ def _create_decode_input_data(self,
window_block_tables, slot_mapping.tolist(),
padded_batch_size * num_tokens)

if self.model_has_chunked_attention:
chunk_size = (self.model.model.config.text_config.attention_chunk_size // self.block_size)
seq_lens_block = [len(block_table) for block_table in block_tables_list]
num_seq_chunks = [math.ceil(sl / chunk_size) - 1 for sl in seq_lens_block]
block_tables_chunk = [
block_table[num_seq_chunks[i] * chunk_size:] for i, block_table in enumerate(block_tables_list)
]
chunked_block_list, chunked_block_groups, chunked_block_usage = \
self.get_habana_paged_attn_buffers(
block_tables_chunk, slot_mapping.tolist(),
padded_batch_size * num_tokens)
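
Reviewer note (illustrative, not part of the diff): a quick check of the slicing above under assumed values block_size = 128 and attention_chunk_size = 8192; only the blocks of the chunk currently being attended to survive.

import math

block_size, attention_chunk_size = 128, 8192
chunk_size = attention_chunk_size // block_size                # 64 blocks per chunk
block_table = list(range(100))                                 # a sequence spanning 100 blocks
num_seq_chunks = math.ceil(len(block_table) / chunk_size) - 1  # 1 fully-completed earlier chunk
chunked_blocks = block_table[num_seq_chunks * chunk_size:]     # keeps blocks 64..99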

# CPU<>HPU sync *should not* happen here.
block_list_device = async_h2d_copy(block_list, device=self.device)
block_usage_device = async_h2d_copy(block_usage, device=self.device)
@@ -2191,6 +2295,12 @@
window_block_groups_device = async_h2d_copy(
window_block_groups,
device=self.device) if self.interleaved_sliding_window and self.sliding_window is not None else None
chunked_block_list_device = async_h2d_copy(chunked_block_list,
device=self.device) if self.model_has_chunked_attention else None
chunked_block_usage_device = async_h2d_copy(chunked_block_usage,
device=self.device) if self.model_has_chunked_attention else None
chunked_block_groups_device = async_h2d_copy(chunked_block_groups,
device=self.device) if self.model_has_chunked_attention else None

token_ids_device = async_h2d_copy(token_ids, device=self.device)
# when DP also enabled, some DP ranks will execute dummy run with empty
@@ -2221,21 +2331,26 @@ def _create_decode_input_data(self,
spec_decode_metadata = None
logits_indices_device = async_h2d_copy(logits_indices, device=self.device)

attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
block_list=block_list_device,
block_usage=block_usage_device,
block_groups=block_groups_device,
input_positions=None,
slot_mapping=slot_mapping_device,
block_size=self.block_size,
window_block_list=window_block_list_device,
window_block_usage=window_block_usage_device,
window_block_groups=window_block_groups_device,
chunked_block_list=chunked_block_list_device,
chunked_block_usage=chunked_block_usage_device,
chunked_block_groups=chunked_block_groups_device,
)

return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids_device,
position_ids=positions_device,
logits_indices=logits_indices_device,
attn_metadata=HPUAttentionMetadataV1.make_decode_metadata(
block_list=block_list_device,
block_usage=block_usage_device,
block_groups=block_groups_device,
input_positions=None,
slot_mapping=slot_mapping_device,
block_size=self.block_size,
window_block_list=window_block_list_device,
window_block_usage=window_block_usage_device,
window_block_groups=window_block_groups_device,
),
attn_metadata=attn_metadata,
spec_decode_metadata=spec_decode_metadata)

def _prepare_decode_inputs(self,
@@ -2583,6 +2698,8 @@ def _execute_model_generic(self,
else:
# no hpu graphs for t.compile?
use_graphs = False
if self.model_has_chunked_attention:
additional_kwargs.update({"model_has_chunked_attention": True})
trimmed_attn_metadata = attn_metadata if self.unified_attn else trim_attn_metadata(attn_metadata)
if self.is_driver_worker:
model_event_name = ("model_forward_"
@@ -3637,7 +3754,7 @@ def load_model(self) -> None:
elif not is_fake_hpu():
self.model = self.model.to("hpu")
htcore.mark_step()

self.maybe_set_chunked_attention_layers(self.model)
hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
model_config = getattr(self.model, "config", None)
modify_model_layers(self.model,