Add workaround for ttnn.where #2196
Correct me if I'm wrong, but it seems that ttnn.where is just syntactic sugar around binary ops (add, mul, gtz and lez). So it's just a thin layer of abstraction; I don't see what could go wrong with ttnn.where itself. Any problem that might arise is probably rooted in one of the mentioned binary ops.
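To make the "syntactic sugar" claim concrete, here is a minimal pure-Python sketch of one plausible way to compose where from add, mul, gtz and lez. The helper names and the list-based "tensors" are hypothetical and only mirror the op names mentioned in the comment; this is not the actual ttnn implementation.

```python
# Toy elementwise ops on plain Python lists (hypothetical stand-ins for the
# binary/unary ops named in the comment, not the real ttnn kernels).
def gtz(x):
    # 1.0 where the element is greater than zero, else 0.0
    return [1.0 if v > 0 else 0.0 for v in x]

def lez(x):
    # 1.0 where the element is less than or equal to zero, else 0.0
    return [1.0 if v <= 0 else 0.0 for v in x]

def mul(x, y):
    return [a * b for a, b in zip(x, y)]

def add(x, y):
    return [a + b for a, b in zip(x, y)]

def toy_where(predicate, if_true, if_false):
    # where(p, a, b) = gtz(p) * a + lez(p) * b
    return add(mul(gtz(predicate), if_true), mul(lez(predicate), if_false))

result = toy_where([1.0, 0.0, 2.0], [10.0, 20.0, 30.0], [-1.0, -2.0, -3.0])
```

In this composition the predicate is the first operand of each multiplication, which is why a "result takes the first operand's data type" rule would surface in where.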
Running TTIR IR from #2195 I'm getting
Always | FATAL | Input output tensor size mismatch in memcpy: 576 * 2 != 576 * 4
LOG_ASSERT @ /localdev/azecevic/tt-mlir/runtime/lib/ttnn/runtime.cpp:378: srcTensor.volume() * srcTensor.element_size() == dstTensor.volume() * dstTensor.element_size()
Which is an assert in the runtime. I suspect that's because the actual result is in the bf16 format and we are expecting f32. If I remember correctly, in TTNN the result of an eltwise binary op will always assume the data type of the first operand, which is consistent with the case here. @jnie-TT can you also take a look at this?
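The numbers in the assert can be reproduced with simple arithmetic. The sketch below is a hypothetical pure-Python model of the size check, assuming 2 bytes per bf16 element and 4 bytes per f32 element; the function names are illustrative and do not come from the tt-mlir runtime.

```python
# Bytes per element for the two data types involved in the failure.
ELEMENT_SIZE = {"bf16": 2, "f32": 4}

def byte_size(volume, dtype):
    # Total bytes occupied by a tensor with `volume` elements of `dtype`.
    return volume * ELEMENT_SIZE[dtype]

def memcpy_sizes_match(volume, src_dtype, dst_dtype):
    # Mirrors the shape of the runtime check:
    # src.volume() * src.element_size() == dst.volume() * dst.element_size()
    return byte_size(volume, src_dtype) == byte_size(volume, dst_dtype)

src_bytes = byte_size(576, "bf16")  # the op produced bf16
dst_bytes = byte_size(576, "f32")   # the caller expected f32
ok = memcpy_sizes_match(576, "bf16", "f32")
```

With the same element count, a bf16 result can never satisfy the check against an f32 destination, which matches the "576 * 2 != 576 * 4" message.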
This is something that should be dealt with systematically, that's why I'm against merging this into the main.
@azecevicTT The first operand in
|
@mmanzoorTT Yeah, I get all of that, but as I said this is just the consequence of the way ttnn is dealing with binary eltwise ops. For example:
I would assume most frontends would expect I agree |
@azecevicTT, do you have a proposal for this situation? I think we'll eventually want a data format legalization pass that will go through the entire graph and take care of all of these situations, so in your example
would become:
and then a separate AMP pass that would change everything to the data format we actually want to run in (bf16 in almost all situations), but until that's in, I'm ok with a workaround for where, since, as Asif mentioned, where is different as it typically takes in a boolean. |
@AleksKnezevic Yeah, something like that would be a good long-term solution. I agree that |
I think our usage pattern is somewhat unique, in that the other folks writing ttnn are using the exact datatypes they want. I can see some pushback with the change as the two branches of the if will have different orders (due to broadcast), which seems unnecessary, i.e.
Also, it could just as easily be broken by in the future, since I wouldn't expect anyone updating the where implementation to worry about downstream effects of operand order of a commutative eltwise binary, but I suppose that's why this is broken in the first place 😄. Obviously, there's CI, but it's an unexpected consequence of code changes. Anyway, we have a long term solution (data format pass), and a framework in place for intermediate workarounds (that won't be accidentally broken), why not merge it and remove when no longer needed? |
@AleksKnezevic I've opened an issue tenstorrent/tt-metal#17998 to discuss this with Metal folks.
Unfortunately, that might be the case at the moment, because the people involved with the usage are familiar with the implementation details. But if TTNN is to be marketed as a user-friendly library for users familiar with PyTorch, that should change in the future. |
@azecevicTT, thanks for opening the issue, it's definitely the better solution to the problem. Let's see how long the ttnn fix will take. I propose that if it takes more than a week, we get this workaround in until they resolve it. |
@azecevicTT, as mentioned in the linked issue, I spoke to @cmaryanTT: the underlying binary op is being revamped, which is a longer-term change. Let's merge this in and we can revisit once the revamp is finished. |
@AleksKnezevic Fair, I will review this now. |
Approved, but please address comments before merging.
test/ttmlir/Dialect/TTNN/Transforms/Workarounds/where_workaround.mlir
tt-metal uses the predicate's data type to perform the operation. If the predicate's data type does not match that of the input/output, it can cause a failure.
closes #2195
Ticket
#2195
Problem description
tt-metal uses the data type of the predicate for the implementation of the ttnn.where op. If the predicate and the inputs/output have different data types, ttnn.where can generate incorrect results or cause other failures.
What's changed
Add a data type workaround that applies the input data type to the predicate if the input data type is not the same as that of the predicate.
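The idea behind the workaround can be sketched in a few lines of Python. Everything here is a hypothetical toy model: ToyTensor, toy_where and cast are illustrative names, and the "output takes the predicate's data type" behaviour is modelled from the problem description above rather than taken from tt-metal code.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyTensor:
    values: tuple
    dtype: str

def toy_where(predicate, if_true, if_false):
    # Assumption from the problem description: the op computes in, and
    # returns, the predicate's data type.
    picked = tuple(
        t if p else f
        for p, t, f in zip(predicate.values, if_true.values, if_false.values)
    )
    return ToyTensor(picked, predicate.dtype)

def cast(tensor, dtype):
    # Toy cast: only relabels the dtype; a real cast would convert the data.
    return ToyTensor(tensor.values, dtype)

def where_with_workaround(predicate, if_true, if_false):
    # The workaround: give the predicate the input data type when they differ.
    if predicate.dtype != if_true.dtype:
        predicate = cast(predicate, if_true.dtype)
    return toy_where(predicate, if_true, if_false)

pred = ToyTensor((1.0, 0.0, 1.0), "bf16")
a = ToyTensor((10.0, 20.0, 30.0), "f32")
b = ToyTensor((-1.0, -2.0, -3.0), "f32")

plain = toy_where(pred, a, b)              # dtype follows the predicate
fixed = where_with_workaround(pred, a, b)  # dtype follows the inputs
```

Without the cast, the toy op reports the predicate's bf16 type even though both inputs are f32; with it, the result type matches what the consumer expects.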