#0: Fix Qwen on N150 using old reshape syntax, fix in0_block_w for 2d mm
yieldthought committed Feb 10, 2025
1 parent 5c37afe commit 9f614ad
Showing 2 changed files with 4 additions and 4 deletions.
models/demos/llama3/tt/llama_attention.py (3 additions, 3 deletions)
@@ -8,8 +8,6 @@
 import ttnn
 from models.common.lightweightmodule import LightweightModule
 from models.demos.llama3.tt.llama_ccl import tt_all_reduce, tt_all_gather
-from models.demos.llama3.tt.llama_common import first_five
-from models.demos.llama3.tt.load_checkpoints import permute


 class TtLlamaAttention(LightweightModule):
@@ -138,7 +136,9 @@ def __init__(
             )
             # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size
             self.wqkv_bias_prefill = ttnn.reshape(
-                self.wqkv_bias_prefill, ttnn.Shape([1, 1, 1, self.wqkv_bias_prefill.shape[-1]])
+                self.wqkv_bias_prefill,
+                (1, 1, 1, self.wqkv_bias_prefill.shape[-1]),
+                (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]),
             )

             # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size
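Per the commit title, the three-argument call above is the old ttnn.reshape syntax, which takes the logical shape and the tile-padded shape as separate arguments: as_tensor reported the bias as (32, dim) because the row dimension was padded up to one 32-row tile, and the fix restores the logical shape (1, 1, 1, dim) while keeping the padded row count. The sketch below is a plain-Python illustration of that bookkeeping only; TiledTensor, reshape_with_padding, and the dim value are hypothetical stand-ins, not ttnn APIs.

    from dataclasses import dataclass

    TILE = 32  # tile edge length; tile layout pads rows/cols up to multiples of this

    @dataclass
    class TiledTensor:            # hypothetical stand-in for a device tensor
        logical_shape: tuple      # shape the model code reasons about
        padded_shape: tuple       # shape actually stored in tile layout

    def reshape_with_padding(t: TiledTensor, logical: tuple, padded: tuple) -> TiledTensor:
        # Mimics the two-shape reshape in the diff: update both views together.
        return TiledTensor(logical_shape=logical, padded_shape=padded)

    dim = 2048  # hypothetical hidden dimension
    # as_tensor reported (32, dim): the row dimension was padded up to one tile.
    bias = TiledTensor(logical_shape=(32, dim), padded_shape=(32, dim))
    bias = reshape_with_padding(
        bias,
        (1, 1, 1, dim),                      # correct logical shape for a bias row
        (1, 1, bias.padded_shape[-2], dim),  # keep the tile-padded row count (32)
    )
    print(bias.logical_shape, bias.padded_shape)  # (1, 1, 1, 2048) (1, 1, 32, 2048)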
models/demos/llama3/tt/model_config.py (1 addition, 1 deletion)
@@ -1237,7 +1237,7 @@ def matmul_config(
         )  # TODO: Needed for TG hang workaround

         if in0_block_w is None:
-            in0_block_w = min(4, max(1, k // (self.tile_size * grid_size[0])))
+            in0_block_w = self.find_largest_divisor(k // (self.tile_size * grid_size[1]))

         return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
             compute_with_storage_grid_size=grid_size,
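The in0_block_w change picks a block width that evenly divides the per-core K tile count (now computed with grid_size[1] rather than grid_size[0]), presumably because the matmul program config requires in0_block_w to divide the inner dimension assigned to each core; the previous min(4, max(1, ...)) expression could return a non-divisor. find_largest_divisor itself is not shown in this commit, so the sketch below is only a guess at what such a helper typically does, assuming it caps the block width (the cap of 8 and the sample shapes are made up for illustration).

    def find_largest_divisor(n: int, max_divisor: int = 8) -> int:
        # Largest divisor of n no greater than max_divisor (1 always qualifies).
        for d in range(max_divisor, 0, -1):
            if n % d == 0:
                return d
        return 1  # defensive default; unreachable for n >= 1

    # Hypothetical shapes: k = 4608, 32x32 tiles, an 8x8 core grid.
    tile_size, grid_size, k = 32, (8, 8), 4608
    k_tiles_per_core = k // (tile_size * grid_size[1])  # 18 K-tiles per core
    print(min(4, max(1, k_tiles_per_core)))             # old formula -> 4, which does not divide 18
    print(find_largest_divisor(k_tiles_per_core))       # sketch -> 6, which divides 18 evenly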
