Commit 6f40799

remove explicit ttl.block* since ttl implicitly operates on blocks. Add table of python ttl -> ttl dialect mapping
1 parent 3d22048 commit 6f40799

8 files changed: +286 −4648 lines

docs/TT-Lang.md

Lines changed: 42 additions & 42 deletions
@@ -93,10 +93,10 @@ The ttl.core function returns *core coordinates* of the current Tensix core. Cor
 x = ttl.core(dims = 1)
 
 # for (8, 8, 8) multi-chip grid gets x = [0, 8), y = [0, 64)
-x, y = ttl.core(dims = 2)
+x, y = ttl.core(dims = 2)
 
 # for (8, 8) single-chip gets x = [0, 8), y = [0, 8), z = 0
-x, y, z = ttl.core(dims = 3)
+x, y, z = ttl.core(dims = 3)
 ```
 
 ## Circular buffer
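
The coordinate splits in the comments above (e.g. `dims = 2` on an (8, 8, 8) grid giving x in [0, 8) and y in [0, 64), and missing trailing axes reading as 0) can be modeled with a short Python sketch. Everything here is hypothetical scaffolding rather than the ttl API: `core_coords` assumes a flat core index and a rule that trailing grid axes are flattened into the last requested coordinate.

```python
from math import prod

def core_coords(flat_index: int, grid: tuple[int, ...], dims: int) -> tuple[int, ...]:
    """Hypothetical illustration of how ttl.core(dims=...) could decompose a
    flat core index: the first dims-1 axes keep their grid extents, and all
    remaining axes are flattened into the last coordinate. When dims exceeds
    the grid rank, the extra coordinates have extent 1 and read as 0."""
    extents = list(grid[:dims - 1]) + [prod(grid[dims - 1:])]
    coords = []
    for extent in extents:
        coords.append(flat_index % extent)
        flat_index //= extent
    return tuple(coords)

# (8, 8, 8) multi-chip grid, dims=2: x in [0, 8), y in [0, 64)
assert core_coords(511, (8, 8, 8), 2) == (7, 63)
# (8, 8) single-chip grid, dims=3: x in [0, 8), y in [0, 8), z = 0
assert core_coords(63, (8, 8), 3) == (7, 7, 0)
```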
@@ -108,7 +108,7 @@ There are two acquisition functions on a circular buffer object: wait and reserv
 ## Example
 
 ```py
-x_cb = ttl.make_circular_buffer_like(x,
+x_cb = ttl.make_circular_buffer_like(x,
     shape = (2, 2),
     buffer_factor = 2)
 
@@ -129,7 +129,7 @@ There are two acquisition functions on a circular buffer object: wait and reserv
 x_cb.pop() # explicit
 ```
 
-##
+##
 
 | Type alias/Function | Description |
 | :---- | :---- |
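
The explicit `x_cb.pop()` shown above has an implicit counterpart when acquisition is scoped by `with`. A toy, single-threaded Python model of that contract follows; `ToyCircularBuffer` and its methods are illustrative stand-ins, not the ttl API, and it assumes the buffer holds up to `buffer_factor` blocks, with the push (producer) or pop (consumer) performed on scope exit:

```python
from collections import deque
from contextlib import contextmanager

class ToyCircularBuffer:
    """Toy stand-in for a ttl circular buffer: holds up to buffer_factor
    blocks; scope exit performs the implicit push (producer side) or the
    implicit pop (consumer side)."""
    def __init__(self, buffer_factor: int):
        self.capacity = buffer_factor
        self.slots = deque()

    @contextmanager
    def reserve(self):
        assert len(self.slots) < self.capacity, "a real producer would block here"
        block = {}                  # writable block handed to the producer
        yield block
        self.slots.append(block)    # implicit push on scope exit

    @contextmanager
    def wait(self):
        assert self.slots, "a real consumer would block here"
        yield self.slots[0]         # oldest ready block handed to the consumer
        self.slots.popleft()        # implicit pop on scope exit

x_cb = ToyCircularBuffer(buffer_factor=2)
with x_cb.reserve() as blk:         # producer scope
    blk["tiles"] = [1, 2, 3, 4]
with x_cb.wait() as blk:            # consumer scope
    assert blk["tiles"] == [1, 2, 3, 4]
assert not x_cb.slots               # popped implicitly on scope exit
```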
@@ -152,7 +152,7 @@ A *block* represents memory acquired from a circular buffer. Block size is deter
 # acquire a_blk and b_blk ...
 
 # source is a tensor slice, destination is a block
-a_xf = ttl.copy(a[n], a_blk)
+a_xf = ttl.copy(a[n], a_blk)
 b_xf = ttl.copy(b[n], b_blk)
 a_xf.wait()
 b_xf.wait()
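
The pattern above issues both copies before waiting on either handle, so the two transfers can overlap. That shape can be sketched with a hypothetical `toy_copy` returning a waitable handle; `toy_copy` and `ToyTransfer` are illustrative stand-ins for `ttl.copy` and its transfer object, using a thread pool in place of NOC hardware:

```python
from concurrent.futures import ThreadPoolExecutor

class ToyTransfer:
    """Illustrative handle mirroring the a_xf = ttl.copy(...); a_xf.wait() shape."""
    def __init__(self, fut):
        self._fut = fut
    def wait(self):
        self._fut.result()          # block until the copy completes

_pool = ThreadPoolExecutor()

def toy_copy(src: list, dst: list) -> ToyTransfer:
    # start the copy asynchronously; the caller overlaps other work, then waits
    def run():
        dst[:] = src
    return ToyTransfer(_pool.submit(run))

a, a_blk = [1, 2], [None, None]
b, b_blk = [3, 4], [None, None]
a_xf = toy_copy(a, a_blk)           # both transfers in flight ...
b_xf = toy_copy(b, b_blk)
a_xf.wait()                         # ... before either wait
b_xf.wait()
assert a_blk == [1, 2] and b_blk == [3, 4]
```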
@@ -183,7 +183,7 @@ A *block* represents memory acquired from a circular buffer. Block size is deter
 # acquire a_blk, b_blk and c_blk ...
 
 # source is a tensor slice, destination is a block
-a_xf = ttl.copy(a[0], a_blk)
+a_xf = ttl.copy(a[0], a_blk)
 b_xf = ttl.copy(b[0], b_blk)
 c_xf = ttl.copy(c[0], c_blk)
 a_xf.wait()
@@ -214,7 +214,7 @@ A *block* represents memory acquired from a circular buffer. Block size is deter
 | ttl.math.sqrt(expr: ttl.BlockExpr) \-\> ttl.BlockExpr | Example of math function. |
 | ttl.BlockExpr.\_\_add\_\_( self, other: ttl.BlockExpr) \-\> ttl.BlockExpr | Example of math operator. |
 
-##
+##
 
 ## Pipe
 
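
The two table rows above show the expression style: math functions and operators take and return `ttl.BlockExpr`. A deferred-evaluation sketch of that pattern follows; `ToyBlockExpr`, `block`, and `toy_sqrt` are hypothetical illustrations, and elements here are plain Python floats rather than tiles:

```python
import math

class ToyBlockExpr:
    """Illustrative stand-in for ttl.BlockExpr: operators and math functions
    build a deferred elementwise expression instead of computing eagerly."""
    def __init__(self, evaluate):
        self.evaluate = evaluate    # () -> list of elements

    def __add__(self, other: "ToyBlockExpr") -> "ToyBlockExpr":
        # mirrors ttl.BlockExpr.__add__: combining two exprs yields an expr
        return ToyBlockExpr(
            lambda: [a + b for a, b in zip(self.evaluate(), other.evaluate())])

def block(values) -> ToyBlockExpr:
    return ToyBlockExpr(lambda: list(values))

def toy_sqrt(expr: ToyBlockExpr) -> ToyBlockExpr:
    # mirrors ttl.math.sqrt: a math function of an expr is itself an expr
    return ToyBlockExpr(lambda: [math.sqrt(v) for v in expr.evaluate()])

expr = toy_sqrt(block([9.0, 16.0])) + block([1.0, 2.0])
assert expr.evaluate() == [4.0, 6.0]
```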
@@ -241,7 +241,7 @@ A pipe net is constructed in the scope of the kernel function but can only be us
 | @property ttl.SrcPipeIdentity\[DstT\].dst(self) \-\> DstT | Get destination core or core range for pipe in if\_src. |
 | @property ttl.DstPipeIdentity.src(self) \-\> ttl.CoreCoord | Get source core for pipe in if\_dst. |
 
-##
+##
 
 ## Gather example
 
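
The two properties in this hunk are directional views of the same pipe: inside `if_src` a thread can ask where its data goes (`.dst`), and inside `if_dst` where it comes from (`.src`). A minimal sketch of that shape, with hypothetical class names and plain (x, y) tuples standing in for core coordinates:

```python
from dataclasses import dataclass

CoreCoord = tuple[int, int]   # stand-in for ttl.CoreCoord

@dataclass(frozen=True)
class ToySrcPipeIdentity:
    """Illustrative view of a pipe as seen by its source core (cf. if_src)."""
    _dst: CoreCoord
    @property
    def dst(self) -> CoreCoord:
        return self._dst          # where the data is going

@dataclass(frozen=True)
class ToyDstPipeIdentity:
    """Illustrative view of a pipe as seen by a destination core (cf. if_dst)."""
    _src: CoreCoord
    @property
    def src(self) -> CoreCoord:
        return self._src          # where the data comes from

pipe_at_src = ToySrcPipeIdentity((1, 0))
pipe_at_dst = ToyDstPipeIdentity((0, 0))
assert pipe_at_src.dst == (1, 0)
assert pipe_at_dst.src == (0, 0)
```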
@@ -485,7 +485,7 @@ ttl.copy function expresses a variety of data movements that always have two arg
 
 ## Semaphore
 
-A *semaphore* is a communication primitive for general synchronization between data movement threads on different Tensix cores. Each semaphore has an associated 32-bit unsigned integer *semaphore value* on each Tensix core. This value can be changed (set or incremented) by a data movement thread on the local or a remote core. When changing a semaphore value remotely, either a single core coordinate (unicast) or a core range (multicast) is specified; only setting the value is supported for a multicast change. A data movement thread can wait on a semaphore until its value satisfies a condition, given either as an exact value or as a minimum value. Only local data movement threads can wait on a semaphore.
+A *semaphore* is a communication primitive for general synchronization between data movement threads on different Tensix cores. Each semaphore has an associated 32-bit unsigned integer *semaphore value* on each Tensix core. This value can be changed (set or incremented) by a data movement thread on the local or a remote core. When changing a semaphore value remotely, either a single core coordinate (unicast) or a core range (multicast) is specified; only setting the value is supported for a multicast change. A data movement thread can wait on a semaphore until its value satisfies a condition, given either as an exact value or as a minimum value. Only local data movement threads can wait on a semaphore.
 
 ttl.Semaphore class is constructed with its initial value that defaults to zero. A ttl.Semaphore instance can be constructed in kernel function scope. A ttl.Semaphore instance provides wait\_eq, wait\_ge and set functions for managing the local semaphore value. To change a remote semaphore value, an instance of ttl.UnicastRemoteSemaphore or ttl.MulticastRemoteSemaphore is obtained by calling the get\_remote and get\_remote\_multicast functions respectively. The ttl.UnicastRemoteSemaphore supports inc and set while ttl.MulticastRemoteSemaphore supports only set. Functions that change the value or wait on a condition can be used only in the scope of a data movement thread function. Functions that obtain remote semaphores can be used in scopes of both kernel and data movement thread functions.
 
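
The wait_eq / wait_ge / set / inc behavior described above can be modeled locally with a condition variable. `ToySemaphore` is an illustrative model, not the ttl API; the "remote" increment is simulated with a second Python thread, and the 32-bit wrap-around is an assumption drawn from the value being described as a 32-bit unsigned integer:

```python
import threading

class ToySemaphore:
    """Illustrative model of a per-core 32-bit semaphore value with
    wait_eq / wait_ge (local waits) and set / inc (local or remote changes)."""
    def __init__(self, initial: int = 0):
        self._value = initial
        self._cond = threading.Condition()

    def set(self, value: int):
        with self._cond:
            self._value = value & 0xFFFFFFFF     # assumed 32-bit wrap
            self._cond.notify_all()

    def inc(self, amount: int = 1):
        with self._cond:
            self._value = (self._value + amount) & 0xFFFFFFFF
            self._cond.notify_all()

    def wait_eq(self, value: int):
        with self._cond:
            self._cond.wait_for(lambda: self._value == value)

    def wait_ge(self, value: int):
        with self._cond:
            self._cond.wait_for(lambda: self._value >= value)

sem = ToySemaphore()                              # initial value defaults to zero
t = threading.Thread(target=lambda: sem.inc(2))   # simulated remote increment
t.start()
sem.wait_ge(2)                                    # local thread blocks until value >= 2
t.join()
```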
@@ -564,53 +564,53 @@ ttl.Semaphore class is constructed with its initial value that defaults to zero.
 | *Semaphore* | A communication primitive for general synchronization between data movement threads on different Tensix cores. |
 | *Semaphore value* | A 32-bit unsigned integer value associated with a semaphore on each Tensix core. This value can be set or incremented by a data movement thread on the local or a remote Tensix core. |
 
-#
+#
 
 # Discussion
 
 ## Principles
 
 TT-Lang is a Python-based DSL that enables authoring of programs for TT-hardware at an abstraction level similar to SoTA “tile-level” DSLs for GPUs. Following are the TT-Lang principles, in order of significance:
 
-1. Ability to express optimizations that achieve **performance within close range (95%)** of hand-written TT-Metalium programs;
-2. Robust and safe abstractions capable of representing a simplified model of hardware that **eliminates whole classes of mistakes** that are possible when writing TT-Metalium programs. Specifically:
-   1. Reduce duplication of information that is typical in the multi-threaded separation of kernels;
-   2. Infer CB operations to eliminate errors in asynchronous code that would cause hangs or data races (all inferred in single-threaded code; pop/push guarded by “with” scope in multi-threaded code);
-   3. Infer xxx\_init/xxx\_tile(s) etc. calls based on the functional compute expression;
-   4. Use compile-time memory allocation (DRAM, L1 and DST register) to eliminate OOMs and clobbering at runtime;
-   5. In addition to (d), use relative memory sizing instead of explicit memory sizing to eliminate OOMs at runtime. With such relative memory sizing the actual size can be maximized at compile time by the allocator or autotuned (see below) at runtime;
-3. Allow TT-Lang programs to be **portable across multiple generations** of TT-hardware. Enable generation-specific details to be expressed as autotunable hyper-parameters;
-4. SoTA **ergonomics**. Specifically:
-   1. Functional simulator;
-   2. VSCode (or similar) integration via language server;
-   3. As-you-type compilation errors;
-   4. As-you-type sanitization errors (based on the functional simulator);
-   5. VSCode-integrated line-by-line profiler (a la NSight);
-5. Ability to be authored by **Generative AI** from scratch or in translation from “tile-level” DSLs for GPUs. Ability for the compiler, the sanitizer and the simulator to provide ergonomic errors, warnings, correctness and performance feedback for Generative AI to be able to iterate in an agentic workflow.
-6. Ability to **autotune** within a space of user-defined hyper-parameters;
-7. Ability to **serve as a bootstrap (EmitMetal)** that generates a C++ TT-Metalium program for further optimization;
-8. Ability to **augment TT-NN programs** with custom TT-Lang kernels;
-9. Being **Python-based** so as to support a limited subset of Python to express programs as well as being able to integrate into the Python environment. This makes TT-Lang more familiar and convenient for the target audience;
+1. Ability to express optimizations that achieve **performance within close range (95%)** of hand-written TT-Metalium programs;
+2. Robust and safe abstractions capable of representing a simplified model of hardware that **eliminates whole classes of mistakes** that are possible when writing TT-Metalium programs. Specifically:
+   1. Reduce duplication of information that is typical in the multi-threaded separation of kernels;
+   2. Infer CB operations to eliminate errors in asynchronous code that would cause hangs or data races (all inferred in single-threaded code; pop/push guarded by “with” scope in multi-threaded code);
+   3. Infer xxx\_init/xxx\_tile(s) etc. calls based on the functional compute expression;
+   4. Use compile-time memory allocation (DRAM, L1 and DST register) to eliminate OOMs and clobbering at runtime;
+   5. In addition to (d), use relative memory sizing instead of explicit memory sizing to eliminate OOMs at runtime. With such relative memory sizing the actual size can be maximized at compile time by the allocator or autotuned (see below) at runtime;
+3. Allow TT-Lang programs to be **portable across multiple generations** of TT-hardware. Enable generation-specific details to be expressed as autotunable hyper-parameters;
+4. SoTA **ergonomics**. Specifically:
+   1. Functional simulator;
+   2. VSCode (or similar) integration via language server;
+   3. As-you-type compilation errors;
+   4. As-you-type sanitization errors (based on the functional simulator);
+   5. VSCode-integrated line-by-line profiler (a la NSight);
+5. Ability to be authored by **Generative AI** from scratch or in translation from “tile-level” DSLs for GPUs. Ability for the compiler, the sanitizer and the simulator to provide ergonomic errors, warnings, correctness and performance feedback for Generative AI to be able to iterate in an agentic workflow.
+6. Ability to **autotune** within a space of user-defined hyper-parameters;
+7. Ability to **serve as a bootstrap (EmitMetal)** that generates a C++ TT-Metalium program for further optimization;
+8. Ability to **augment TT-NN programs** with custom TT-Lang kernels;
+9. Being **Python-based** so as to support a limited subset of Python to express programs as well as being able to integrate into the Python environment. This makes TT-Lang more familiar and convenient for the target audience;
 10. Ability to develop TT-Lang programs **out of tree** and without rebuilding TT-NN from source;
 
 ## Outcomes
 
 There are a number of outcomes we are looking for that motivate TT-Lang:
 
-* Adoption by the internal Models Team as a tool that materially speeds up supporting new models in inference;
-* Adoption by the internal Training Team as a tool that enables fast iteration and experimentation without sacrificing performance;
+* Adoption by the internal Models Team as a tool that materially speeds up supporting new models in inference;
+* Adoption by the internal Training Team as a tool that enables fast iteration and experimentation without sacrificing performance;
 * Adoption by external users on inference and training tracks as an authoring tool that leverages their experience with “tile-level” DSLs for GPUs and provides a robust abstraction over multiple generations of TT-hardware.
 
 ## Questions
 
-1) The programming model can be either single-threaded, with a program expressed as a synchronous dataflow using load/store and math operations, or multi-threaded and asynchronous, with separate data movement and compute kernels using abstractions mapped to CBs and NOC transfers. Can both be supported? If so, which one do we start with?
-   1) In the initial milestone we plan for the multi-threaded model, which will allow the author to fully control the pipeline and order of operations as well as require explicit synchronization;
-   2) The single-threaded model will allow the compiler to “design” the pipeline, reorder operations when necessary and infer the necessary synchronization. We will explore the single-threaded model in the context of evaluating the applicability of SoTA “tile-level” DSLs.
-2) Explicit loop nests versus metadata declarations. For-temporal? For-spatial?
-   1) We want to provide a choice of expressing for-temporal loops either as explicit for statements in Python or implicitly as specified in metadata declarations.
-   2) For-spatial loops would only be specified implicitly by grid metadata.
-3) Python code for the DSL can be either analyzed at the AST level or traced. How much empirical runtime code is allowed/needed? What do we need to write a performant FA?
-   1) We will take the approach of working from the AST representation of the kernel's Python code. This will limit what can be used in the kernel’s code to a subset of Python that is representable in the Arith, Scf and TT-Kernel dialects. We will allow utility functions with the same limitation to be called from the kernel’s code.
-4) What is the user experience? Is TT-Lang embedded in TT-NN? In PyTorch? Standalone?
-   1) TT-NN integration that allows mixing TT-NN code with TT-Lang. TT-Lang will be installed as a separate wheel compatible with TT-NN.
-   2) It is unclear if we need standalone mode or PyTorch integration.
+1) The programming model can be either single-threaded, with a program expressed as a synchronous dataflow using load/store and math operations, or multi-threaded and asynchronous, with separate data movement and compute kernels using abstractions mapped to CBs and NOC transfers. Can both be supported? If so, which one do we start with?
+   1) In the initial milestone we plan for the multi-threaded model, which will allow the author to fully control the pipeline and order of operations as well as require explicit synchronization;
+   2) The single-threaded model will allow the compiler to “design” the pipeline, reorder operations when necessary and infer the necessary synchronization. We will explore the single-threaded model in the context of evaluating the applicability of SoTA “tile-level” DSLs.
+2) Explicit loop nests versus metadata declarations. For-temporal? For-spatial?
+   1) We want to provide a choice of expressing for-temporal loops either as explicit for statements in Python or implicitly as specified in metadata declarations.
+   2) For-spatial loops would only be specified implicitly by grid metadata.
+3) Python code for the DSL can be either analyzed at the AST level or traced. How much empirical runtime code is allowed/needed? What do we need to write a performant FA?
+   1) We will take the approach of working from the AST representation of the kernel's Python code. This will limit what can be used in the kernel’s code to a subset of Python that is representable in the Arith, Scf and TT-Kernel dialects. We will allow utility functions with the same limitation to be called from the kernel’s code.
+4) What is the user experience? Is TT-Lang embedded in TT-NN? In PyTorch? Standalone?
+   1) TT-NN integration that allows mixing TT-NN code with TT-Lang. TT-Lang will be installed as a separate wheel compatible with TT-NN.
+   2) It is unclear if we need standalone mode or PyTorch integration.
