Skip to content

Commit 91a0ac9

Browse files
authored
[TC] Fix to graph break inside set_block_mapping (HabanaAI#1143)
This PR is fixing graph break caused by torch compilation of set_block_mapping. Problem is that _replace function of tuple is not supported by torch compilation process (cannot be inlined). This PR is recasted and stripped version of HabanaAI#1073 (outdated).
1 parent f191153 commit 91a0ac9

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

vllm/worker/hpu_model_runner.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ def subtuple(obj: object,
108108
else:
109109
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
110110
if typename not in _TYPE_CACHE:
111-
_TYPE_CACHE[typename] = collections.namedtuple(typename,
112-
' '.join(fields))
113-
return _TYPE_CACHE[typename](**values)
111+
_TYPE_CACHE[typename] = {
112+
'object': collections.namedtuple(typename, ' '.join(fields)),
113+
'fields': fields
114+
}
115+
return _TYPE_CACHE[typename]['object'](**values) # type: ignore
114116

115117

116118
def align_workers(value, op):
@@ -340,8 +342,20 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
340342
block_groups.masked_fill_(oob_values, batch_size)
341343
metadata = metadata._replace(block_groups=block_groups)
342344
block_mapping = block_mapping.to(dtype)
343-
metadata = metadata._replace(block_mapping=block_mapping,
344-
attn_bias=attn_bias)
345+
346+
# Torch compile dynamo doesn't support calling any named tuple
347+
# dynamic methods other than len and get_attr so we need to
348+
# mimic behaviour of tuple._replace manually
349+
TrimmedAttentionMetadata = _TYPE_CACHE['TrimmedAttentionMetadata'][
350+
'object']
351+
fields = _TYPE_CACHE['TrimmedAttentionMetadata']['fields']
352+
metadata_dict = {
353+
field: getattr(metadata, field)
354+
for field in fields # type: ignore
355+
} # type: ignore
356+
metadata_dict['attn_bias'] = attn_bias
357+
metadata_dict['block_mapping'] = block_mapping
358+
metadata = TrimmedAttentionMetadata(**metadata_dict) # type: ignore
345359
return metadata
346360

347361
def _set_indices_and_offsets(self, metadata, block_size, is_prompt):

0 commit comments

Comments
 (0)