Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions awex/util/tensor_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def group_tensors_by_shape_and_dtype(
final_tensor_groups = []
metadata = []

for (_shape, _dtype), group_tensors in tensor_groups.items():
for _, group_tensors in tensor_groups.items():
# Sort by original index to maintain order
group_tensors.sort(key=lambda x: x[0])
# Split into multiple groups - try to group tensors efficiently
Expand All @@ -227,7 +227,8 @@ def group_tensors_by_shape_and_dtype(
# Check if this tensor can fit in current group
if current_group_size > max_tensor_size:
# Finalize current group and start new one
concatenated = torch.cat(current_group, dim=0).contiguous()
# Use clone() to ensure a copy so caller can safely release original tensors
concatenated = torch.cat(current_group, dim=0).clone()
final_tensor_groups.append(concatenated)
# Record metadata for tensors in this group
offset_elements = 0
Expand All @@ -251,7 +252,8 @@ def group_tensors_by_shape_and_dtype(

# Finalize any remaining group
if current_group:
concatenated = torch.cat(current_group, dim=0).contiguous()
# Use clone() to ensure a copy so caller can safely release original tensors
concatenated = torch.cat(current_group, dim=0).clone()
final_tensor_groups.append(concatenated)
# Record metadata for tensors in this group
offset_elements = 0
Expand Down