Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/scripts/trtt_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import sys

from ttrt import main

if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ repos:
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.7
hooks:
Expand Down
2 changes: 1 addition & 1 deletion examples/custom_dm_matmul.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from ttlang.ttl_api import *
from ttlang.utils.correctness import assert_allclose
import torch


@pykernel_gen(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import ttnn
import pytest
import torch

import ttnn
from ttlang.utils.correctness import assert_with_ulp


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
# up to tt-lang spec, not intended to compile or run currently
import ttnn
import pytest
import torch

from ttl import Program, make_circular_buffer_like, copy

from ttlang.utils.correctness import assert_with_ulp
import ttnn
from ttl import Program, copy, make_circular_buffer_like
from ttlang.utils.block_allocation import split_work_to_cores
from ttlang.utils.correctness import assert_with_ulp


def get_number_of_cores(grid_range):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import ttnn
import pytest
import torch
import matplotlib.pyplot as plt
import numpy as np

from ttlang.utils.correctness import assert_with_ulp
import pytest
import torch
import ttnn
from ttlang.utils.block_allocation import get_large_matmul_params
from ttlang.utils.correctness import assert_with_ulp


@pytest.mark.parametrize("M,K,N", [(640, 640, 640)])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import ttnn
import pytest
import torch

from ttl import Program, make_circular_buffer_like, copy, core
import ttnn
from metal_examples.utils import assert_with_ulp
from ttl import Program, copy, core, make_circular_buffer_like


@ttl.kernel(grid=(13, 10))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import ttnn
import pytest
import torch

import ttnn
from ttlang.utils.correctness import assert_with_ulp


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# up to tt-lang spec, not intended to compile or run currently
import sys
from pathlib import Path
import ttnn

import pytest
import torch

from ttl import Program, make_circular_buffer_like, copy

import ttnn
from ttl import Program, copy, make_circular_buffer_like
from ttlang.utils.correctness import assert_with_ulp


Expand Down
2 changes: 1 addition & 1 deletion examples/sim/eltwise_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
# type: ignore
from typing import TYPE_CHECKING
import math
from typing import TYPE_CHECKING

from sim import ttl, ttnn
from sim.testing import assert_pcc
Expand Down
4 changes: 2 additions & 2 deletions examples/sim/eltwise_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
# type: ignore
from typing import TYPE_CHECKING
import math
from typing import TYPE_CHECKING

from sim import ttl, ttnn
from sim.typedefs import Pipe
from sim.testing import assert_pcc
from sim.typedefs import Pipe

if TYPE_CHECKING:
from sim.pykernel_env import granularity
Expand Down
4 changes: 2 additions & 2 deletions examples/sim/eltwise_pipe_core3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
# type: ignore
from typing import TYPE_CHECKING
import math
from typing import TYPE_CHECKING

from sim import ttl, ttnn
from sim.typedefs import Pipe
from sim.testing import assert_pcc
from sim.typedefs import Pipe

if TYPE_CHECKING:
from sim.pykernel_env import granularity
Expand Down
2 changes: 1 addition & 1 deletion examples/test_accessor_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
defined in the DSL are materialized as stream_layout ops immediately
in the Python-generated IR, rather than being added later by a pass.
"""
import torch
from ttlang.ttl_api import *
from ttlang.utils.correctness import assert_allclose
import torch


@pykernel_gen(
Expand Down
2 changes: 1 addition & 1 deletion examples/test_simple_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
Tests the complete data path: host → L1 → compute → L1 → host
"""

from ttlang.ttl_api import *
import torch
from ttlang.ttl_api import *


@pykernel_gen(grid=(1, 1), block_factors=[(1, 1), (1, 1), (1, 1)])
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,12 @@ reportUntypedFunctionDecorator = "error"
reportPrivateUsage = "error"
reportImportCycles = "error"
reportUnnecessaryIsInstance = "warning"

[tool.isort]
profile = "black"
line_length = 88
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
2 changes: 1 addition & 1 deletion python/pykernel/_src/base_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import ast
import inspect

from ttmlir.dialects import emitc, func
from ttmlir.ir import *
from ttmlir.dialects import func, emitc


class PyKernelAstBase(ast.NodeVisitor):
Expand Down
4 changes: 2 additions & 2 deletions python/pykernel/_src/kernel_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import ast
import inspect

from ttmlir.dialects import arith, emitc, func, memref, scf
from ttmlir.ir import *
from ttmlir.dialects import func, scf, arith, memref, emitc

from .kernel_types import ClassRegistry
from .base_ast import PyKernelAstBase
from .kernel_types import ClassRegistry
from .utils import _cast, _get_type_str


Expand Down
3 changes: 2 additions & 1 deletion python/pykernel/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# NOTE: This file was copied from tt-mlir/tools/pykernel/_src/utils.py
# and cleaned up to remove unused code (_discover_dialect_ops).

import textwrap
import inspect
import textwrap
from typing import Callable

from ttmlir.dialects import arith
from ttmlir.ir import *

Expand Down
5 changes: 2 additions & 3 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import pathlib
import shutil
import subprocess

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext
from datetime import datetime

from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext

readme = None

Expand Down
25 changes: 12 additions & 13 deletions python/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,30 @@
sim package: simulation components for TT-Lang including circular buffers, tensors, and copy operations.
"""

from . import ttnnsim as ttnn
from .cbapi import CBAPI, CBStats
from .typedefs import CoreIndex, Shape, Pipe
from .constants import TILE_SHAPE, MAX_CBS
from .copy import copy, CopyTransaction
from .program import Program
from .kernel import core
from .constants import MAX_CBS, TILE_SHAPE
from .copy import CopyTransaction, copy
from .decorators import compute, datamovement
from .kernel import kernel
from .pipe import if_pipe_src, if_pipe_dst, core_in_pipe
from . import ttnnsim as ttnn
from .kernel import core, kernel
from .pipe import core_in_pipe, if_pipe_dst, if_pipe_src
from .program import Program
from .typedefs import CoreIndex, Pipe, Shape


# Create ttl namespace object
class _TTLNamespace:
"""TT-Lang namespace for DSL constructs."""

def __init__(self):
from .kernel import kernel, grid_size, core
from .cb import make_circular_buffer_like
from .constants import TILE_SHAPE
from .copy import copy
from .decorators import compute, datamovement
from .kernel import core, grid_size, kernel
from .pipe import core_in_pipe, if_pipe_dst, if_pipe_src
from .program import Program
from .copy import copy
from .typedefs import Pipe, Size, Shape
from .constants import TILE_SHAPE
from .pipe import if_pipe_src, if_pipe_dst, core_in_pipe
from .typedefs import Pipe, Shape, Size

self.kernel = kernel
self.grid_size = grid_size
Expand Down
8 changes: 5 additions & 3 deletions python/sim/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
"""

import operator as _op
from typing import List, Sequence, Any, Union, Callable
from .typedefs import Size, Index, Span
from typing import Any, Callable, List, Sequence, Union

from pydantic import validate_call

from .cbstate import CBSlot
from .ttnnsim import Tensor
from pydantic import validate_call
from .typedefs import Index, Size, Span


# Notice that get_read_ptr and get_write_ptr return a C++ pointer which does not
Expand Down
8 changes: 4 additions & 4 deletions python/sim/cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
operations.
"""

from typing import Tuple, Optional, Union, List, Callable
from types import TracebackType
from typing import Callable, List, Optional, Tuple, Union

import torch

from .cbapi import CBAPI
from .block import Block
from .typedefs import CBID, Size, Shape
from .ttnnsim import Tensor
from .cbapi import CBAPI
from .constants import TILE_SHAPE
from .ttnnsim import Tensor
from .typedefs import CBID, Shape, Size


class _BlockContextManager:
Expand Down
12 changes: 7 additions & 5 deletions python/sim/cbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
"""

import threading
from typing import List, Optional, Annotated, NamedTuple, Any
from pydantic import validate_call, Field
from .errors import CBContractError, CBTimeoutError
from .constants import MAX_CBS, CB_DEFAULT_TIMEOUT
from .typedefs import Size, CBID
from typing import Annotated, Any, List, NamedTuple, Optional

from pydantic import Field, validate_call

from .block import Block
from .cbstate import CBState
from .constants import CB_DEFAULT_TIMEOUT, MAX_CBS
from .errors import CBContractError, CBTimeoutError
from .typedefs import CBID, Size


class CBStats(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion python/sim/cbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from threading import Condition, RLock, Thread
from typing import List, Optional
from .typedefs import Size, Index, Count, Span

from .errors import CBContractError, CBNotConfigured
from .ttnnsim import Tensor
from .typedefs import Count, Index, Size, Span

# Type alias for circular buffer slots
CBSlot = Optional[Tensor]
Expand Down
5 changes: 3 additions & 2 deletions python/sim/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
Constants for the cbsim module.
"""
from typing import Any, cast
from .typedefs import Shape, CBID

from pydantic.fields import FieldInfo
from annotated_types import Lt # type that holds the 'lt' constraint
from pydantic.fields import FieldInfo

from .typedefs import CBID, Shape


def _extract_max_cbs_from_cbid() -> int:
Expand Down
2 changes: 1 addition & 1 deletion python/sim/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
"""

from .copyhandlers import (
CopyTransferHandler,
CopyEndpoint,
CopyEndpointType,
CopyTransferHandler,
handler_registry,
)

Expand Down
Loading