
Add workaround for ttnn.where #2196

Merged
merged 1 commit into from
Feb 21, 2025

Conversation

mmanzoorTT
Contributor

@mmanzoorTT mmanzoorTT commented Feb 17, 2025

closes #2195

Ticket

#2195

Problem description

tt-metal uses the predicate's data type to implement the ttnn.where op. If the predicate and the inputs/output have different data types, ttnn.where can generate incorrect results or cause other failures.

What's changed

Add a data type workaround that casts the predicate to the input data type when the two differ.

Checklist

  • New tests provide coverage for changes

@mmanzoorTT mmanzoorTT changed the title Workaround for ttnn.where Add workaround for ttnn.where Feb 17, 2025
Contributor

@azecevicTT azecevicTT left a comment


Correct me if I'm wrong, but it seems that ttnn.where is just syntactic sugar around binary ops (add, mul, gtz and lez). So it's just a thin layer of abstraction; I don't see what could go wrong with ttnn.where itself. Any problem that might arise is probably rooted in one of the mentioned binary ops.

Running TTIR IR from #2195 I'm getting

                 Always |    FATAL | Input output tensor size mismatch in memcpy: 576 * 2 != 576 * 4
LOG_ASSERT @ /localdev/azecevic/tt-mlir/runtime/lib/ttnn/runtime.cpp:378: srcTensor.volume() * srcTensor.element_size() == dstTensor.volume() * dstTensor.element_size()

Which is an assert in the runtime. I suspect that's because the actual result is in the bf16 format and we are expecting f32. If I remember correctly, in TTNN the result of an eltwise binary will always assume the data type of the first operand, which is consistent with the case here. @jnie-TT can you also take a look at this?

This is something that should be dealt with systematically, that's why I'm against merging this into the main.

@mmanzoorTT
Contributor Author

@azecevicTT The first operand of the ttnn.where op is a predicate/condition, which is typically a boolean tensor, but tt-metal does not support boolean types. AFAIK, we have the following scenarios.

  1. In the case of the tt-forge frontend, the predicate type depends on the instruction generating it. If the predicate is generated by comparing two float32 tensors, the output will be float32. We cannot use this float32 predicate in the ttnn.where op if true_value/false_value are not float32.
  2. The tt-xla and tt-torch frontends use the StableHLO -> TTIR conversion, where all booleans are type-converted to bfloat16. If true_value/false_value are not bfloat16, then this op fails.

@azecevicTT
Contributor

azecevicTT commented Feb 18, 2025

@mmanzoorTT Yeah, I get all of that, but as I said, this is just a consequence of the way ttnn deals with binary eltwise ops. For example:

add(bf16, f32) -> bf16
add(f32, bf16) -> f32

I would assume most frontends would expect f32 as a result in this case (and certainly wouldn't expect a type to depend on the order of the operands), so we are back to square one.

I agree ttnn.where is kind of special in its usage pattern, but I would try to rethink the solution, because sooner or later we are going to hit other cases of this problem as well.

@AleksKnezevic
Contributor

@azecevicTT, do you have a proposal for this situation?

I think we'll eventually want a data format legalization pass that will go through the entire graph and take care of all of these situations, so in your example

add(bf16, f32) -> bf16

would become:

cast(bf16) -> f32
add(f32, f32) -> f32

and then a separate AMP pass that would change everything to the data format we actually want to run in (bf16 in almost all situations), but until that's in, I'm ok with a workaround for where, since, as Asif mentioned, where is different as it typically takes in a boolean.

@azecevicTT
Contributor

@AleksKnezevic Yeah, something like that would be a good long-term solution.

I agree that ttnn.where has a different usage pattern than most eltwise ops, but in that regard I believe it would be beneficial to talk with the Metal folks about this. If I'm not mistaken, changing the order of condition and value in https://github.com/tenstorrent/tt-metal/blob/main/ttnn/cpp/ttnn/operations/eltwise/ternary/where.cpp#L37 would have the same effect, so it would benefit all consumers of the TTNN lib (who have the same usage pattern), not just us. I would propose this to the owners of the eltwise ops before I try to land this in TT-MLIR.

@AleksKnezevic
Contributor

I think our usage pattern is somewhat unique, in that the other folks writing ttnn are using the exact data types they want. I can see some pushback on the change, as the two branches of the if would have different operand orders (due to broadcast), which seems unnecessary, i.e.

        if (std::holds_alternative<Tensor>(value)) {
            return ttnn::multiply(queue_id, std::get<Tensor>(value), condition, std::nullopt, output_mem_config);
        } else {
            return ttnn::multiply(queue_id, condition, std::get<float>(value), std::nullopt, output_mem_config);
        }

Also, it could just as easily be broken again in the future, since I wouldn't expect anyone updating the where implementation to worry about the downstream effects of operand order in a commutative eltwise binary, but I suppose that's why this is broken in the first place 😄. Obviously, there's CI, but it's an unexpected consequence of code changes.

Anyway, we have a long-term solution (the data format pass) and a framework in place for intermediate workarounds (that won't be accidentally broken), so why not merge this and remove it when it's no longer needed?

@azecevicTT
Contributor

@AleksKnezevic I've opened an issue tenstorrent/tt-metal#17998 to discuss this with Metal folks.

I think our usage pattern is somewhat unique, in that the other folks writing ttnn are using the exact datatypes they want.

Unfortunately, that might be the case at the moment, because the people involved with the usage are familiar with the implementation details. But if TTNN is to be marketed as a user-friendly library for users familiar with PyTorch, that should change in the future.
There are always multiple ways to tackle a problem, and I won't block this if it is the last resort, but I would first try to tackle the problem at its root.

@AleksKnezevic
Contributor

@azecevicTT, thanks for opening the issue, it's definitely the better solution to the problem. Let's see how long the ttnn fix will take. I propose that if it takes more than a week, we get this workaround in until they resolve it.

@AleksKnezevic
Contributor

@azecevicTT, as mentioned in the linked issue, I spoke to @cmaryanTT; the underlying binary op is being revamped, which is a longer-term change. Let's merge this in and we can revisit once the revamp is finished.

@azecevicTT
Contributor

@AleksKnezevic Fair, I will review this now.

Contributor

@azecevicTT azecevicTT left a comment


Approved, but please address comments before merging.

tt-metal uses the predicate's data type to perform the operation. If the predicate data type does not match the input/output, it can cause failures.
@mmanzoorTT mmanzoorTT enabled auto-merge (squash) February 21, 2025 16:55
@mmanzoorTT mmanzoorTT merged commit 04758dc into main Feb 21, 2025
34 checks passed
@mmanzoorTT mmanzoorTT deleted the mmanzoor/where-op-wa branch February 21, 2025 17:25
Development

Successfully merging this pull request may close these issues.

[TTNN] Workaround for ttnn.where op
3 participants