Skip to content

Commit e11913b

Browse files
authored
[Target] Add target to all TVM callbacks (apache#14939)
* [Target] Add target to all TVM callbacks

  This PR adds an extra parameter `target` to every `tvm_callback_*` function so that each callback can decide its own behavior by querying the target it is compiling for.

* fix lint

* fix lint
1 parent dbcd198 commit e11913b

20 files changed

+76
-77
lines changed

apps/ios_rpc/tests/ios_rpc_mobilenet.py

+14-13
Original file line number | Diff line number | Diff line change
@@ -15,24 +15,24 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import tvm
19-
from tvm import rpc, relay
20-
from tvm.contrib.download import download_testdata
21-
from tvm.relay.expr_functor import ExprMutator
22-
from tvm.relay import transform
23-
from tvm.relay.op.annotation import compiler_begin, compiler_end
24-
from tvm.relay.quantize.quantize import prerequisite_optimize
25-
from tvm.contrib import utils, xcode, graph_executor, coreml_runtime
26-
from tvm.contrib.target import coreml as _coreml
27-
18+
import argparse
2819
import os
2920
import re
3021
import sys
22+
23+
import coremltools
3124
import numpy as np
25+
import tvm
3226
from mxnet import gluon
3327
from PIL import Image
34-
import coremltools
35-
import argparse
28+
from tvm import relay, rpc
29+
from tvm.contrib import coreml_runtime, graph_executor, utils, xcode
30+
from tvm.contrib.download import download_testdata
31+
from tvm.contrib.target import coreml as _coreml
32+
from tvm.relay import transform
33+
from tvm.relay.expr_functor import ExprMutator
34+
from tvm.relay.op.annotation import compiler_begin, compiler_end
35+
from tvm.relay.quantize.quantize import prerequisite_optimize
3636

3737
# Change target configuration, this is setting for iphone6s
3838
# arch = "x86_64"
@@ -43,9 +43,10 @@
4343

4444
MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect}
4545

46+
4647
# override metal compiler to compile to iphone
4748
@tvm.register_func("tvm_callback_metal_compile")
48-
def compile_metal(src):
49+
def compile_metal(src, target):
4950
return xcode.compile_metal(src, sdk=sdk)
5051

5152

apps/ios_rpc/tests/ios_rpc_test.py

+7-6
Original file line number | Diff line number | Diff line change
@@ -20,15 +20,15 @@
2020
And configure the proxy host field as commented.
2121
"""
2222

23-
import tvm
24-
from tvm import te
23+
import argparse
2524
import os
2625
import re
2726
import sys
28-
from tvm import rpc
29-
from tvm.contrib import utils, xcode
27+
3028
import numpy as np
31-
import argparse
29+
import tvm
30+
from tvm import rpc, te
31+
from tvm.contrib import utils, xcode
3232

3333
# Change target configuration, this is setting for iphone6s
3434
arch = "arm64"
@@ -37,9 +37,10 @@
3737

3838
MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect}
3939

40+
4041
# override metal compiler to compile to iphone
4142
@tvm.register_func("tvm_callback_metal_compile")
42-
def compile_metal(src):
43+
def compile_metal(src, target):
4344
return xcode.compile_metal(src, sdk=sdk)
4445

4546

apps/topi_recipe/broadcast/test_broadcast_map.py

+5-7
Original file line number | Diff line number | Diff line change
@@ -15,20 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import os
18+
19+
import numpy as np
1820
import tvm
19-
from tvm import te
21+
from tvm import te, topi
2022
from tvm.contrib import nvcc
21-
import numpy as np
22-
23-
from tvm import topi
24-
2523

2624
TASK = "reduce_map"
2725
USE_MANUAL_CODE = False
2826

2927

3028
@tvm.register_func("tvm_callback_cuda_compile", override=True)
31-
def tvm_callback_cuda_compile(code):
29+
def tvm_callback_cuda_compile(code, target):
3230
ptx = nvcc.compile_cuda(code, target_format="ptx")
3331
return ptx
3432

@@ -39,7 +37,7 @@ def write_code(code, fname):
3937

4038

4139
@tvm.register_func
42-
def tvm_callback_cuda_postproc(code):
40+
def tvm_callback_cuda_postproc(code, target):
4341
if not os.path.exists("perf"):
4442
os.mkdir("perf")
4543
write_code(code, "perf/%s_generated.cu" % TASK)

apps/topi_recipe/conv/depthwise_conv2d_test.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -15,25 +15,24 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import os
18-
import tvm
19-
from tvm import te
18+
2019
import numpy as np
20+
import tvm
2121
from scipy import signal
22+
from tvm import te, topi
2223
from tvm.contrib import nvcc
23-
24-
from tvm import topi
25-
from tvm.topi.utils import get_const_tuple
2624
from tvm.topi.cuda.depthwise_conv2d import (
2725
schedule_depthwise_conv2d_nchw,
2826
schedule_depthwise_conv2d_nhwc,
2927
)
28+
from tvm.topi.utils import get_const_tuple
3029

3130
TASK = "depthwise_conv2d"
3231
USE_MANUAL_CODE = False
3332

3433

3534
@tvm.register_func("tvm_callback_cuda_compile", override=True)
36-
def tvm_callback_cuda_compile(code):
35+
def tvm_callback_cuda_compile(code, target):
3736
ptx = nvcc.compile_cuda(code, target_format="ptx")
3837
return ptx
3938

@@ -44,7 +43,7 @@ def write_code(code, fname):
4443

4544

4645
@tvm.register_func
47-
def tvm_callback_cuda_postproc(code):
46+
def tvm_callback_cuda_postproc(code, target):
4847
if not os.path.exists("perf"):
4948
os.mkdir("perf")
5049
write_code(code, "perf/%s_generated.cu" % TASK)

apps/topi_recipe/conv/test_conv2d_hwcn_map.py

+4-5
Original file line number | Diff line number | Diff line change
@@ -16,20 +16,19 @@
1616
# under the License.
1717
"""Example code to do convolution."""
1818
import os
19+
1920
import numpy as np
20-
import scipy.signal
2121
import tvm
22-
from tvm import te
22+
from tvm import te, topi
2323
from tvm.contrib import nvcc
24-
from tvm import topi
2524
from tvm.topi.utils import get_const_tuple
2625

2726
TASK = "conv2d_hwcn_map"
2827
USE_MANUAL_CODE = False
2928

3029

3130
@tvm.register_func("tvm_callback_cuda_compile", override=True)
32-
def tvm_callback_cuda_compile(code):
31+
def tvm_callback_cuda_compile(code, target):
3332
ptx = nvcc.compile_cuda(code, target_format="ptx")
3433
return ptx
3534

@@ -40,7 +39,7 @@ def write_code(code, fname):
4039

4140

4241
@tvm.register_func
43-
def tvm_callback_cuda_postproc(code):
42+
def tvm_callback_cuda_postproc(code, target):
4443
if not os.path.exists("perf"):
4544
os.mkdir("perf")
4645
write_code(code, "perf/%s_generated.cu" % TASK)

apps/topi_recipe/reduce/test_reduce_map.py

+4-6
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import os
18+
19+
import numpy as np
1820
import tvm
19-
from tvm import te
21+
from tvm import te, topi
2022
from tvm.contrib import nvcc
21-
import numpy as np
22-
23-
from tvm import topi
24-
2523

2624
TASK = "reduce_map"
2725
USE_MANUAL_CODE = False
@@ -33,7 +31,7 @@ def write_code(code, fname):
3331

3432

3533
@tvm.register_func
36-
def tvm_callback_cuda_postproc(code):
34+
def tvm_callback_cuda_postproc(code, target):
3735
if not os.path.exists("perf"):
3836
os.mkdir("perf")
3937
write_code(code, "perf/%s_generated.cu" % TASK)

apps/topi_recipe/rnn/lstm.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,12 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""LSTM Example, still work in progress.."""
18+
import os
19+
20+
import numpy as np
1821
import tvm
1922
from tvm import te
20-
import os
2123
from tvm.contrib import nvcc
22-
import numpy as np
2324

2425
# Quick knobs
2526
TASK = "lstm"
@@ -31,7 +32,7 @@
3132

3233

3334
@tvm.register_func("tvm_callback_cuda_compile", override=True)
34-
def tvm_callback_cuda_compile(code):
35+
def tvm_callback_cuda_compile(code, target):
3536
"""Use nvcc compiler for better perf."""
3637
ptx = nvcc.compile_cuda(code, target_format="ptx")
3738
return ptx
@@ -43,7 +44,7 @@ def write_code(code, fname):
4344

4445

4546
@tvm.register_func
46-
def tvm_callback_cuda_postproc(code):
47+
def tvm_callback_cuda_postproc(code, target):
4748
if not os.path.exists("perf"):
4849
os.mkdir("perf")
4950
write_code(code, "perf/%s_generated.cu" % TASK)

apps/topi_recipe/rnn/matexp.py

+7-6
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,14 @@
2323
X[t] = dot(X[t-1], W)
2424
```
2525
"""
26+
import argparse
27+
import os
28+
import time
29+
30+
import numpy as np
2631
import tvm
2732
from tvm import te
28-
import time
29-
import os
30-
import argparse
3133
from tvm.contrib import nvcc
32-
import numpy as np
3334

3435
# Quick knobs
3536
TASK = "matexp"
@@ -40,7 +41,7 @@
4041

4142

4243
@tvm.register_func("tvm_callback_cuda_compile", override=True)
43-
def tvm_callback_cuda_compile(code):
44+
def tvm_callback_cuda_compile(code, target):
4445
"""Use nvcc compiler for better perf."""
4546
ptx = nvcc.compile_cuda(code, target_format="ptx")
4647
return ptx
@@ -52,7 +53,7 @@ def write_code(code, fname):
5253

5354

5455
@tvm.register_func
55-
def tvm_callback_cuda_postproc(code):
56+
def tvm_callback_cuda_postproc(code, target):
5657
if not os.path.exists("perf"):
5758
os.mkdir("perf")
5859
write_code(code, "perf/%s_generated.cu" % TASK)

jvm/core/src/test/scripts/test_add_gpu.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,11 @@
1818

1919
import tvm
2020
from tvm import te
21-
from tvm.contrib import cc, utils, nvcc
21+
from tvm.contrib import cc, nvcc, utils
2222

2323

2424
@tvm.register_func("tvm_callback_cuda_compile", override=True)
25-
def tvm_callback_cuda_compile(code):
25+
def tvm_callback_cuda_compile(code, target):
2626
ptx = nvcc.compile_cuda(code, target_format="ptx")
2727
return ptx
2828

python/tvm/contrib/nvcc.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -18,15 +18,15 @@
1818
"""Utility to invoke nvcc compiler in the system"""
1919
from __future__ import absolute_import as _abs
2020

21-
import subprocess
2221
import os
22+
import subprocess
2323
import warnings
2424

2525
import tvm._ffi
2626
from tvm.target import Target
2727

28-
from . import utils
2928
from .._ffi.base import py_str
29+
from . import utils
3030

3131

3232
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
@@ -184,7 +184,7 @@ def get_cuda_version(cuda_path=None):
184184

185185

186186
@tvm._ffi.register_func
187-
def tvm_callback_cuda_compile(code):
187+
def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
188188
"""use nvcc to generate fatbin code for better optimization"""
189189
ptx = compile_cuda(code, target_format="fatbin")
190190
return ptx

python/tvm/contrib/sdaccel.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,16 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Utility for Interacting with SDAccel Tools"""
18-
import subprocess
1918
import os
19+
import subprocess
2020

2121
import tvm._ffi
22+
2223
from . import utils
2324

2425

2526
@tvm._ffi.register_func("tvm_callback_sdaccel_compile")
26-
def compile_vhls(kernel_info, device_name):
27+
def compile_vhls(kernel_info, target):
2728
"""Compile Vivado HLS code for SDAccel.
2829
2930
Parameters
@@ -32,14 +33,15 @@ def compile_vhls(kernel_info, device_name):
3233
List of kernel information. The kernel information is a tuple of
3334
function name and source code.
3435
35-
device_name : str
36-
The name of the target device
36+
target : tvm.target.Target
37+
The compilation target
3738
3839
Return
3940
------
4041
xclbin : bytearray
4142
The bytearray of the xclbin
4243
"""
44+
device_name = target.attrs.get("device", "")
4345
tmp_dir = utils.tempdir()
4446

4547
sdk = os.environ.get("XILINX_SDX", None)

src/target/opt/build_cuda_on.cc

+2-2
Original file line number | Diff line number | Diff line change
@@ -143,14 +143,14 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
143143
std::string code = cg.Finish();
144144

145145
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
146-
code = (*f)(code).operator std::string();
146+
code = (*f)(code, target).operator std::string();
147147
}
148148
std::string fmt = "ptx";
149149
std::string ptx;
150150
const auto* f_enter = Registry::Get("target.TargetEnterScope");
151151
(*f_enter)(target);
152152
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
153-
ptx = (*f)(code).operator std::string();
153+
ptx = (*f)(code, target).operator std::string();
154154
// Dirty matching to check PTX vs cubin.
155155
// TODO(tqchen) more reliable checks
156156
if (ptx[0] != '/') fmt = "cubin";

src/target/source/codegen_aocl.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
5151

5252
std::string code = cg.Finish();
5353
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
54-
code = (*f)(code).operator std::string();
54+
code = (*f)(code, target).operator std::string();
5555
}
5656

5757
// Write a .cl file.

src/target/source/codegen_metal.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -365,7 +365,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
365365
std::string fsource = cg.Finish();
366366
source_maker << fsource << "\n";
367367
if (fmetal_compile) {
368-
fsource = (*fmetal_compile)(fsource).operator std::string();
368+
fsource = (*fmetal_compile)(fsource, target).operator std::string();
369369
}
370370
smap[func_name] = fsource;
371371
}

0 commit comments

Comments
 (0)