[Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263)

Prior to this commit, while an operator with the
`MixedPrecisionPolicyKind::kNever` attribute would not have its arguments
updated from `float32` to `float16`, they would be erroneously updated
from `float16` to `float32`.

This commit updates `ToMixedPrecision` to preserve the datatype of any
arguments used in a `kNever` operation, rather than forcing them to a
`float32` datatype.
Lunderberg committed Aug 11, 2024 · 1 parent bed66d2 · commit b3d01c2
Showing 2 changed files with 75 additions and 28 deletions.
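
The fix can be exercised end to end with the sketch below, condensed from the regression test added in this commit: a `float16` tensor flows into `R.call_tir`, which receives the `kNever` treatment, and `ToMixedPrecision` should leave its dtype alone rather than forcing a cast to `float32`. The single-call `Module` here is an illustrative variant of the test's `Before` module, not part of the commit itself.

# Minimal sketch, condensed from the regression test added below.
import tvm
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([64], "float16")):
        cls = Module
        with R.dataflow():
            # The call_tir argument is float16 and, under the kNever
            # handling, must keep that dtype after the pass.
            B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
            R.output(B)
        return B

    @T.prim_func
    def tir_identity(Input: T.Buffer(64, "float16"), Output: T.Buffer(64, "float16")):
        for i in range(64):
            with T.block("copy"):
                vi = T.axis.remap("S", [i])
                Output[vi] = Input[vi]


# With the dtype preserved, the pass leaves the module unchanged; before the
# fix it would have forced the argument to float32, inserting casts.
After = ToMixedPrecision()(Module)
tvm.ir.assert_structural_equal(Module, After)
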
69 changes: 44 additions & 25 deletions src/relax/transform/to_mixed_precision.cc
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
}

Array<Expr> RemapArgs(const Array<Expr>& args) {
Array<Expr> new_args;
for (const auto& arg : args) {
new_args.push_back(VarReplacer::Replace(arg, var_remap_));
}
return new_args;
return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); });
}

// Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
ReEmitBinding(binding, call_node->args[0]);
return;
}
DataType to;
ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);

Call new_call = GetRef<Call>(call_node);

// We first to remap the args to the current vars according to the var_remap_
new_call->args = std::move(RemapArgs(call_node->args));
new_call.CopyOnWrite()->args = RemapArgs(new_call->args);

// Then we rewrite the args according to the policy
std::optional<DataType> opt_new_dtype = std::nullopt;

if (policy == kAlways) {
to = fp16_;
opt_new_dtype = fp16_;
auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
ICHECK(attr_map.count(op));
auto f = attr_map[op];
new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
new_call = attr_map[op](new_call, output_dtype_);
} else if (policy == kFollow) {
to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
} else if (policy == kNever) {
to = fp32_;
// An upstream operation may have changed the datatype of the
// arguments. Because this operation must be provided with
// exactly the same dtype as it previously had, it may require a
// cast back to the original datatype.

if (!new_call->args.same_as(call_node->args)) {
Array<Expr> new_typed_args;
for (size_t i = 0; i < call_node->args.size(); i++) {
auto arg = new_call->args[i];
auto old_ntype = NTypeFrom(call_node->args[i]);
new_typed_args.push_back(RewriteExpr(arg, old_ntype));
}
new_call.CopyOnWrite()->args = new_typed_args;
}

} else {
LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
}
new_call->args = std::move(RewriteArgs(new_call->args, to));
new_call->struct_info_ = NullOpt;
Expr new_value = builder_->Normalize(Call(new_call));
if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
// kAlways: store the tensors to fp16
// But global vars will be stored to the original dtype anyway (see below)
new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
}
if (!binding->var->IsInstance<DataflowVarNode>()) {
// Global var: store the tensors to the original dtype
NType to = NTypeFrom(binding->var);
new_value = RewriteExpr(new_value, to);

Expr new_value = new_call;
if (opt_new_dtype) {
auto new_dtype = opt_new_dtype.value();
new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
new_call.CopyOnWrite()->struct_info_ = NullOpt;

new_value = builder_->Normalize(Call(new_call));

if (!binding->var->IsInstance<DataflowVarNode>()) {
// Non-Dataflow var: store the tensors to the original dtype
new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
} else if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
// kAlways: store the tensors to fp16
// But non-dataflow vars will be stored to the original dtype anyway (see above)
new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
}
}

ReEmitBinding(binding, builder_->Normalize(new_value));
}

34 changes: 31 additions & 3 deletions tests/python/relax/test_transform_to_mixed_precision.py
@@ -20,7 +20,7 @@
from tvm import relax
import tvm.testing
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R
from tvm.script.parser import ir as I, relax as R, tir as T


def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def main(
x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
) -> R.Tensor(None, "float32", ndim=4):
with R.dataflow():
gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1)
gv2 = R.add(gv, gv1)
R.output(gv2)
return gv2
@@ -1036,5 +1036,33 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


def test_call_tir_with_float16_args():
@I.ir_module
class Before:
@R.function
def main(A: R.Tensor([64], "float16")):
cls = Before
with R.dataflow():
B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
R.output(C)
return C

@T.prim_func
def tir_identity(
Input: T.Buffer(64, "float16"),
Output: T.Buffer(64, "float16"),
):
for i in range(64):
with T.block("copy"):
vi = T.axis.remap("S", [i])
Output[vi] = Input[vi]

Expected = Before

After = ToMixedPrecision()(Before)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()
