diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index c844d5935623..1b660b8fecc5 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
   }
 
   Array<Expr> RemapArgs(const Array<Expr>& args) {
-    Array<Expr> new_args;
-    for (const auto& arg : args) {
-      new_args.push_back(VarReplacer::Replace(arg, var_remap_));
-    }
-    return new_args;
+    return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); });
   }
 
   // Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
       ReEmitBinding(binding, call_node->args[0]);
       return;
     }
-    DataType to;
-    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+
+    Call new_call = GetRef<Call>(call_node);
+
     // We first to remap the args to the current vars according to the var_remap_
-    new_call->args = std::move(RemapArgs(call_node->args));
+    new_call.CopyOnWrite()->args = RemapArgs(new_call->args);
+
     // Then we rewrite the args according to the policy
+    std::optional<DataType> opt_new_dtype = std::nullopt;
+
     if (policy == kAlways) {
-      to = fp16_;
+      opt_new_dtype = fp16_;
       auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
       ICHECK(attr_map.count(op));
-      auto f = attr_map[op];
-      new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
+      new_call = attr_map[op](new_call, output_dtype_);
     } else if (policy == kFollow) {
-      to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
+      opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
     } else if (policy == kNever) {
-      to = fp32_;
+      // An upstream operation may have changed the datatype of the
+      // arguments. Because this operation must be provided with
+      // exactly the same dtype as it previously had, it may require a
+      // cast back to the original datatype.
+
+      if (!new_call->args.same_as(call_node->args)) {
+        Array<Expr> new_typed_args;
+        for (size_t i = 0; i < call_node->args.size(); i++) {
+          auto arg = new_call->args[i];
+          auto old_ntype = NTypeFrom(call_node->args[i]);
+          new_typed_args.push_back(RewriteExpr(arg, old_ntype));
+        }
+        new_call.CopyOnWrite()->args = new_typed_args;
+      }
+
     } else {
       LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
     }
-    new_call->args = std::move(RewriteArgs(new_call->args, to));
-    new_call->struct_info_ = NullOpt;
-    Expr new_value = builder_->Normalize(Call(new_call));
-    if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
-      // kAlways: store the tensors to fp16
-      // But global vars will be stored to the original dtype anyway (see below)
-      new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
-    }
-    if (!binding->var->IsInstance<DataflowVarNode>()) {
-      // Global var: store the tensors to the original dtype
-      NType to = NTypeFrom(binding->var);
-      new_value = RewriteExpr(new_value, to);
+
+    Expr new_value = new_call;
+    if (opt_new_dtype) {
+      auto new_dtype = opt_new_dtype.value();
+      new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
+      new_call.CopyOnWrite()->struct_info_ = NullOpt;
+
+      new_value = builder_->Normalize(Call(new_call));
+
+      if (!binding->var->IsInstance<DataflowVarNode>()) {
+        // Non-Dataflow var: store the tensors to the original dtype
+        new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
+      } else if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
+        // kAlways: store the tensors to fp16
+        // But non-dataflow vars will be stored to the original dtype anyway (see above)
+        new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
+      }
     }
+
     ReEmitBinding(binding, builder_->Normalize(new_value));
   }
 
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 4ddf47b462ad..ed10fc95c723 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -20,7 +20,7 @@
 from tvm import relax
 import tvm.testing
 from tvm.relax.transform import ToMixedPrecision
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 
 def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def main(
             x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
         ) -> R.Tensor(None, "float32", ndim=4):
             with R.dataflow():
-                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
-                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
+                gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
+                gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1)
                 gv2 = R.add(gv, gv1)
                 R.output(gv2)
             return gv2
@@ -1036,5 +1036,33 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_with_float16_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([64], "float16")):
+            cls = Before
+            with R.dataflow():
+                B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
+                C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
+                R.output(C)
+            return C
+
+        @T.prim_func
+        def tir_identity(
+            Input: T.Buffer(64, "float16"),
+            Output: T.Buffer(64, "float16"),
+        ):
+            for i in range(64):
+                with T.block("copy"):
+                    vi = T.axis.remap("S", [i])
+                    Output[vi] = Input[vi]
+
+    Expected = Before
+
+    After = ToMixedPrecision()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
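
Usage sketch (not part of the patch): the snippet below builds a small float32 module in the same style as the tests above, runs ToMixedPrecision over it, and prints the result so the inserted float16 casts can be inspected. It assumes a TVM build with Relax enabled; the module name `Example` and the lack of an assertion on the printed IR are illustrative choices, not part of the test suite.

# Illustrative only; mirrors the test file's module style.
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R


@I.ir_module
class Example:  # hypothetical module name, for illustration
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
    ) -> R.Tensor((2, 3, 26, 26), "float32"):
        with R.dataflow():
            # conv2d is registered for the "always cast" policy, so its arguments
            # are expected to be rewritten to float16 while accumulation stays float32.
            gv = R.nn.conv2d(x, w)
            R.output(gv)
        return gv


After = ToMixedPrecision(out_dtype="float32")(Example)
After.show()  # inspect where casts were inserted; the output binding stays float32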