Skip to content

Commit b0dabcc

Browse files
Co1linganler
andauthored
Add some TF ops and add XLA backend (ise-uiuc#40)
* feat: impl some forward functions for tf * feat: add more dtypes * fix: remove f16 from dtype_floats as it is not supported by many torch operators * feat: add LocalRespNorm for TF * feat: force CPU when making oracle with TF * feat: add XLA as a backend * feat: add tensorflow dump load test * feat: try to add NHWCConv2d for TF * test: add XLA test * fix: disable autocast for keras layers * fix: enable XLA compilation * fix: typo * chore: del commented code * fix: disable some dtypes * fix: disable some dtypes * fix:reshape * ci: complete tflite and xla test * fix: tflite system name * tweak tf logging and fix topset * clean headers * fix dtype reload * hot fix: ci * feat: add pattern to bug report folder * fix: do not allow conv2d with valid padding to output tensor with dim=0 * refact: optimize logging for model gen * pre-commit * instruction for running pre-commit for all files; fix: error switch * many fixes: supress absl, fix tf.conv spec, more logging levels * pre-commit * feat: dtype test with model gen only; align with testing * feat: scm for auto versioning * refact: use scm ver * add coverage info * show pytest log * feat: optimize dependency * fix: warn and relax pygraphviz * skip gpu tests * refact: disable pytest capture * revert capture Co-authored-by: ganler <[email protected]>
1 parent 738464d commit b0dabcc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1085
-277
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ tmp/
1313
*.profraw
1414
fuzz_report/
1515
nnsmith_output/
16+
nnsmith/_version.py
1617

1718
# hydra
1819
outputs/

README.md

+13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
[![](https://github.com/ise-uiuc/nnsmith/actions/workflows/ci.yaml/badge.svg)](https://github.com/ise-uiuc/nnsmith/actions/workflows/ci.yaml) [![](https://img.shields.io/pypi/v/nnsmith?color=g)](https://pypi.org/project/nnsmith/) [![](https://img.shields.io/pypi/l/nnsmith)](https://github.com/ise-uiuc/nnsmith/blob/main/LICENSE)
44

5+
## Backend-Model Support
6+
7+
| Backend\Model | ONNX/PyTorch | TensorFlow |
8+
| ------------- | ------------ | ---------- |
9+
| TVM || |
10+
| ONNXRuntime || |
11+
| TensorRT || |
12+
| TFLite | | ⚠️ |
13+
| XLA | | ⚠️ |
14+
15+
✅: Supported; ⚠️: Beta support; Others are not supported yet -- Contributions are welcome!
16+
517
## Quick Start
618

719
<details><summary><b>Setting up graphviz for debugging</b> <i>[click to expand]</i></summary>
@@ -47,6 +59,7 @@ You can use `pre-commit` to simpify development:
4759

4860
- `pip install -r requirements/dev.txt`;
4961
- `pre-commit install`;
62+
- `pre-commit` will run upon a commit; To explicitly run `pre-commit` for all files: `pre-commit run --all-files`.
5063

5164
<details><summary><b>More notes</b> <i>[click to expand]</i></summary>
5265
<div>

doc/cli.md

+4
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,7 @@ nnsmith.dtype_test model.type="onnx" backend.type="onnxruntime"
6666
```shell
6767
nnsmith.fuzz fuzz.time=30s fuzz.root=fuzz_report debug.viz=true
6868
```
69+
70+
## Misc
71+
72+
TensorFlow logging can be very noisy. Use `TF_CPP_MIN_LOG_LEVEL=3` as environmental variable to depress that.

doc/log-and-err.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@ We support the following logging "keys":
66

77
- `fuzz`: fuzzing loop;
88
- `mgen`: model generation;
9+
- `smt`: constraints in smt solving;
910
- `exec`: model execution;
1011
- `viz`: graphviz visualization;
1112
- `dtest`: dtype_test;
13+
- `core`: seed setting, etc;
1214

1315
The show messages above "INFO" level (see [Python's logging module](https://docs.python.org/3/library/logging.html)). To show debug level message, add `hydra.verbose=[${keys}]` (also see [hydra.logging](https://hydra.cc/docs/1.2/tutorials/basic/running_your_app/logging/)).
1416

1517
```shell
1618
# Show debug information related to `fuzz`:
1719
${NNSMITH_CMD} hydra.verbose=fuzz
1820
# Show debug info for `fuzz` and `exec`:
19-
${NNSMITH_CMD} hydra.verbose=[fuzz,exec]
21+
${NNSMITH_CMD} hydra.verbose="[fuzz,exec]"
2022
```
2123

2224
#### Logging things into file

nnsmith/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
try:
2+
from nnsmith._version import __version__, __version_tuple__
3+
except ImportError:
4+
pass

nnsmith/abstract/arith.py

+25
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,28 @@ def nnsmith_min(left, right):
181181
return min(left, right)
182182
left, right = align_bvs(left, right)
183183
return z3.If(nnsmith_le(left, right), left, right)
184+
185+
186+
def nnsmith_max(left, right):
187+
if isinstance(left, int) and isinstance(right, int):
188+
return min(left, right)
189+
left, right = align_bvs(left, right)
190+
return z3.If(nnsmith_ge(left, right), left, right)
191+
192+
193+
def nnsmith_and(left, right):
194+
if isinstance(left, bool) and isinstance(right, bool):
195+
return left and right
196+
return z3.And(left, right)
197+
198+
199+
def nnsmith_or(left, right):
200+
if isinstance(left, bool) and isinstance(right, bool):
201+
return left or right
202+
return z3.Or(left, right)
203+
204+
205+
def nnsmith_not(expr):
206+
if isinstance(expr, bool):
207+
return not expr
208+
return z3.Not(expr)

nnsmith/abstract/dtype.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -116,25 +116,43 @@ def tensorflow(self) -> "tf.Dtype":
116116
import tensorflow as tf
117117

118118
return {
119+
DType.float16: tf.float16,
119120
DType.float32: tf.float32,
120121
DType.float64: tf.float64,
122+
DType.int8: tf.int8,
123+
DType.int16: tf.int16,
121124
DType.int32: tf.int32,
122125
DType.int64: tf.int64,
126+
DType.complex64: tf.complex64,
127+
DType.complex128: tf.complex128,
128+
DType.bool: tf.bool,
123129
}[self]
124130

125131
@staticmethod
126132
def from_tensorflow(dtype) -> "DType":
127133
import tensorflow as tf
128134

129135
return {
136+
tf.float16: DType.float16,
130137
tf.float32: DType.float32,
131138
tf.float64: DType.float64,
139+
tf.int8: DType.int8,
140+
tf.int16: DType.int16,
132141
tf.int32: DType.int32,
133142
tf.int64: DType.int64,
143+
tf.complex64: DType.complex64,
144+
tf.complex128: DType.complex128,
145+
tf.bool: DType.bool,
134146
}[dtype]
135147

136148

137-
DTYPE_ALL = [DType.float32, DType.float64, DType.int32, DType.int64, DType.bool]
149+
DTYPE_ALL = [
150+
DType.float32,
151+
DType.float64,
152+
DType.int32,
153+
DType.int64,
154+
DType.bool,
155+
]
138156
DTYPE_NON_BOOLS = [dtype for dtype in DTYPE_ALL if dtype != DType.bool]
139157
DTYPE_FLOATS = [DType.float32, DType.float64]
140158
DTYPE_INTS = [DType.int32, DType.int64]

nnsmith/abstract/op.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -1512,7 +1512,7 @@ def __init__(self, *target_shape):
15121512
super().__init__()
15131513
self.inp_ranks = [int_range(1, 4)]
15141514
self.out_ranks = [(len(target_shape),)]
1515-
self.target_shape: List[Union[int, z3.ExprRef]] = target_shape
1515+
self.target_shape: List[Union[int, z3.ExprRef]] = list(target_shape)
15161516

15171517
def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
15181518
__MAX_SOLVE_SYMBOL__ = 8
@@ -1602,16 +1602,6 @@ def deduct_inp_ranks_and_dtype(
16021602
return [(-1, out_abs_tensor[0].dtype)]
16031603

16041604

1605-
@mark_materialize("core")
1606-
class Flatten(Reshape):
1607-
num_var_param = None
1608-
# Inputs are target shape.
1609-
1610-
def __init__(self, dim0: Union[int, z3.ExprRef]):
1611-
super().__init__(dim0)
1612-
self.dim0 = dim0
1613-
1614-
16151605
@mark_materialize("core")
16161606
class Transpose(UnaryOpBase):
16171607
in_dtypes = [(i,) for i in DTYPE_ALL]
@@ -1665,7 +1655,7 @@ class InterpBase(UnaryOpBase):
16651655

16661656
def __init__(self, *size):
16671657
super().__init__()
1668-
self.size = size
1658+
self.size = list(size)
16691659
self.inp_ranks = [(len(size) + 2,)]
16701660
self.out_ranks = [(len(size) + 2,)]
16711661

nnsmith/abstract/tensor.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
class AbsTensor:
1212
def __init__(self, shape: List[Union[int, z3.ExprRef]], dtype: DType):
13+
assert isinstance(
14+
shape, (list, tuple)
15+
), f"Shape must be a list/tuple, but got {shape}"
1316
self.shape = list(shape)
1417
self.dtype = DType(dtype)
1518

nnsmith/backends/factory.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ def make_testcase(
155155
log=traceback.format_exc(),
156156
)
157157

158-
return TestCase(model, Oracle(input=input, output=output))
158+
return TestCase(
159+
model, Oracle(input=input, output=output, provider=self.system_name)
160+
)
159161

160162
@staticmethod
161163
def init(name, device="cpu", optmax=True, catch_process_crash=False, **kwargs):
@@ -188,5 +190,23 @@ def init(name, device="cpu", optmax=True, catch_process_crash=False, **kwargs):
188190
catch_process_crash=catch_process_crash,
189191
**kwargs,
190192
)
193+
elif name == "tflite":
194+
from nnsmith.backends.tflite import TFLiteFactory
195+
196+
return TFLiteFactory(
197+
device=device,
198+
optmax=optmax,
199+
catch_process_crash=catch_process_crash,
200+
**kwargs,
201+
)
202+
elif name == "xla":
203+
from nnsmith.backends.xla import XLAFactory
204+
205+
return XLAFactory(
206+
device=device,
207+
optmax=optmax,
208+
catch_process_crash=catch_process_crash,
209+
**kwargs,
210+
)
191211
else:
192212
raise ValueError(f"unknown backend: {name}")

nnsmith/backends/onnxruntime.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, device, optmax, **kwargs):
2222
super().__init__(device, optmax, **kwargs)
2323
self.opt_level = OPT_LEVELS[-1 if optmax else 0]
2424
self.providers = ["CPUExecutionProvider"]
25-
if device in ["cuda", "gpu"]:
25+
if device == "cuda":
2626
self.providers = [
2727
"CUDAExecutionProvider",
2828
"CPUExecutionProvider",

nnsmith/backends/tensorrt.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ class HostDeviceMem:
2525

2626

2727
class TRTFactory(BackendFactory):
28-
def __init__(self, device="gpu", optmax=True, **kwargs):
28+
def __init__(self, device="cuda", optmax=True, **kwargs):
2929
super().__init__(device, optmax, **kwargs)
3030

31-
if device != "gpu":
31+
if device != "cuda":
3232
raise ValueError("TensorRT backend only supports GPU!")
3333

3434
if optmax is False:
3535
# TODO(@ganler): support non-optimized TensorRT by using performing
3636
# inference over a model that marks all nodes as outputs.
37-
warnings.warn("There is not O0 mode for TensorRT so far.", UserWarning)
37+
warnings.warning("There is not O0 mode for TensorRT so far.", UserWarning)
3838

3939
@property
4040
def system_name(self) -> str:

nnsmith/backends/tflite.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, device, optmax, **kwargs) -> None:
4343

4444
@property
4545
def system_name(self) -> str:
46-
"tflite"
46+
return "tflite"
4747

4848
@dispatch(TFModel)
4949
def make_backend(

nnsmith/backends/tvm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, device="cpu", optmax=True, executor="graph", **kwargs) -> Non
2424
self.opt_level = 4 if optmax else 0
2525
if device == "cpu":
2626
self.target = tvm.target.Target("llvm")
27-
elif device == "cuda" or device == "gpu":
27+
elif device == "cuda":
2828
self.target = tvm.target.Target("cuda")
2929
else:
3030
raise ValueError(f"Unknown device `{device}`")

nnsmith/backends/xla.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from os import PathLike
2+
from typing import Callable, Dict
3+
4+
import numpy as np
5+
import tensorflow as tf # type: ignore
6+
from multipledispatch import dispatch
7+
8+
from nnsmith.backends.factory import BackendCallable, BackendFactory
9+
from nnsmith.materialize.tensorflow import (
10+
TFModel,
11+
TFNetCallable,
12+
np_dict_from_tf,
13+
tf_dict_from_np,
14+
)
15+
16+
17+
class XLAFactory(BackendFactory):
18+
def __init__(self, device="cpu", optmax: bool = False, catch_process_crash=True):
19+
super().__init__(device, optmax, catch_process_crash)
20+
21+
@property
22+
def system_name(self) -> str:
23+
return "xla"
24+
25+
@dispatch(TFModel)
26+
def make_backend(self, model: TFModel) -> BackendCallable:
27+
concrete_net: TFNetCallable = model.concrete_net()
28+
device: tf.device
29+
30+
if self.device == "cpu":
31+
device = tf.device(tf.config.list_logical_devices("CPU")[0].name)
32+
elif self.device == "cuda":
33+
device = tf.device(tf.config.list_logical_devices("GPU")[0].name)
34+
else:
35+
raise ValueError(f"Unknown device: {self.device}")
36+
37+
@tf.function(jit_compile=True)
38+
def compiled_net(**inputs) -> Dict[str, tf.Tensor]:
39+
return concrete_net(**inputs)
40+
41+
def closure(inputs: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
42+
tf.config.run_functions_eagerly(False)
43+
with device:
44+
return np_dict_from_tf(compiled_net(**tf_dict_from_np(inputs)))
45+
46+
return closure

nnsmith/cli/dtype_test.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
@hydra.main(version_base=None, config_path="../config", config_name="main")
1010
def main(cfg: DictConfig):
1111
backend_cfg = cfg["backend"]
12-
factory = BackendFactory.init(
13-
name=backend_cfg["type"],
14-
device=backend_cfg["device"],
15-
optmax=backend_cfg["optmax"],
16-
)
12+
if backend_cfg is not None:
13+
factory = BackendFactory.init(
14+
name=backend_cfg["type"],
15+
device=backend_cfg["device"],
16+
optmax=backend_cfg["optmax"],
17+
)
18+
else:
19+
factory = None
1720
model_type = Model.init(cfg["model"]["type"])
1821
load_topset_from_auto_cache(model_type, factory)
1922

0 commit comments

Comments
 (0)