diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py
new file mode 100644
index 0000000..02fb929
--- /dev/null
+++ b/python/mrt/mir/opclass.py
@@ -0,0 +1,1037 @@
+import typing
+import numpy as np
+from dataclasses import dataclass
+
+from mrt.common.utils import N
+from . import opns
+from . import symbol
+from .symbol import SelfSymbol
+
+#SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol")
+
+# A creator is either a factory function returning a Symbol, or a Symbol subclass.
+SymbolCreator = typing.Union[typing.Callable[..., symbol.Symbol], typing.Type[SelfSymbol]]
+
+MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {}
+
+def _register_op_map(op_name: str):
+    def _wrapper(clss: SymbolCreator = None) -> SymbolCreator:
+        if len(op_name) > 0 and clss is not None:
+            if op_name not in MRT_OP_MAP:
+                MRT_OP_MAP[op_name] = clss
+            else:
+                print(f'Warning: "{op_name}" already registered in MRT_OP_MAP, overriding!')
+                MRT_OP_MAP[op_name] = clss
+        return clss
+    return _wrapper
+
+
+# OPs from external sources (not an MRT op): use a custom op_name with the default op_func
+#y = extern_opfunc("tanh")(X)
+def extern_opfunc(op_name: str):
+    def op_func(name, args, attrs, extra_attrs):
+        #return symbol.Symbol(op_name=op_name, *args, **attrs)
+        return symbol.Symbol(name, op_name, args, attrs, extra_attrs)
+    return op_func
+
+
+def _from_dict_attrs(cls, d: dict, attr_keys: typing.List[str] = [], **kwargs):
+    data = cls.default_dict()
+    data.update(d)
+    data.update(kwargs)
+    data = cls.update_dict(data)
+    basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']}
+    attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in attr_keys}
+    return cls(*data['args'], **attrsdata, **basedata)
+
+# OPs without attrs are registered as plain functions (function names should be lower case)
+def var(name=None, op_name=None, shape=(), dtype=float) -> symbol.Symbol:
+    op_name = op_name or opns.VAR
+    assert op_name == opns.VAR
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs={'shape': shape or (), 'dtype': dtype or float})
+
+def relu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.RELU
+    assert op_name == opns.RELU
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {})
+
+def silu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.SILU
+    assert op_name == opns.SILU
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {})
+
+
+@dataclass(init=False)
+class Conv2D(symbol.Symbol):
+    op_name = opns.CONV2D
+
+    @property
+    def strides(self) -> typing.Tuple[int, int]:
+        default_val = (1,1)
+        return self.attrs['strides'] if 'strides' in self.attrs else default_val
+
+    @property
+    def padding(self) -> typing.Tuple[int, int, int, int]:
+        default_val = (0,0,0,0)
+        return self.attrs['padding'] if 'padding' in self.attrs else default_val
+
+    @property
+    def groups(self) -> int:
+        default_val = 1
+        return self.attrs['groups'] if 'groups' in self.attrs else default_val
+
+    @property
+    def dilation(self) -> typing.Tuple[int, int]:
+        default_val = (1,1)
+        return self.attrs['dilation'] if 'dilation' in self.attrs else default_val
+
+    @property
+    def kernel_size(self) -> typing.Tuple[int, int]:
+        assert 'kernel_size' in self.attrs
+        return self.attrs['kernel_size']
+
+    @property
+    def kernel_layout(self) -> str:
+        default_val = 'OIHW'
+        return self.attrs['kernel_layout'] if 'kernel_layout' in self.attrs else default_val
+
+    # Follows (*args, name, **attrs)
+    def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None):
+        op_name = op_name or opns.CONV2D
+        assert op_name == opns.CONV2D
+        assert len(W.shape) == 4, f'Wrong weight shape for Conv2D: {W.shape}'
+        kernel_size = (W.shape[2], W.shape[3])
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        # 'kernel_size' is auto-inferred from the weight shape in __init__
+        return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_layout'], **kwargs)
+
+def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None):
+    return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs)
+
+
+@dataclass(init=False)
+class Dropout(symbol.Symbol):
+    op_name = opns.DROP_OUT
+
+    @property
+    def p(self) -> float:
+        default_val = 0.5
+        return self.attrs['p'] if 'p' in self.attrs else default_val
+
+    def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None):
+        op_name = op_name or opns.DROP_OUT
+        assert op_name == opns.DROP_OUT
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'p': p}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['p'], **kwargs)
+
+def dropout(X, name=None, op_name=None, p:float = 0.5, extra_attrs=None):
+    return Dropout(X, name, op_name, p, extra_attrs)
+
+
+@dataclass(init=False)
+class Clip(symbol.Symbol):
+    op_name = opns.CLIP
+
+    @property
+    def min(self) -> float:
+        assert 'min' in self.attrs
+        return self.attrs['min']
+
+    @property
+    def max(self) -> float:
+        assert 'max' in self.attrs
+        return self.attrs['max']
+
+    def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None):
+        op_name = op_name or opns.CLIP
+        assert op_name == opns.CLIP
+        # NaN never compares equal to itself, so use isnan to require explicit bounds
+        assert not np.isnan(min_)
+        assert not np.isnan(max_)
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['min', 'max'], **kwargs)
+
+def clip(X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None):
+    return Clip(X, name, op_name, min_, max_, extra_attrs)
+
+
+@dataclass(init=False)
+class BatchNorm(symbol.Symbol):
+    op_name = opns.BATCH_NORM
+
+    @property
+    def axis(self) -> int:
+        default_val = 1
+        return self.attrs['axis'] if 'axis' in self.attrs else default_val
+
+    @property
+    def epsilon(self) -> float:
+        default_val = 1e-5
+        return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val
+
+    @property
+    def momentum(self) -> float:
+        default_val = 0.1
+        return self.attrs['momentum'] if 'momentum' in self.attrs else default_val
+
+    @property
+    def center(self) -> bool:
+        default_val = True
+        return self.attrs['center'] if 'center' in self.attrs else default_val
+
+    @property
+    def scale(self) -> bool:
+        default_val = True
+        return self.attrs['scale'] if 'scale' in
self.attrs else default_val + + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): + op_name = op_name or opns.BATCH_NORM + assert op_name == opns.BATCH_NORM + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum', 'center', 'scale'], **kwargs) + +def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): + return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, center, scale, extra_attrs) + + +@dataclass(init=False) +class TupleGetItem(symbol.Symbol): + op_name = opns.TUPLE_GET_ITEM + + @property + def index(self) -> float: + default_val = 0 + return self.attrs['index'] if 'index' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): + op_name = op_name or opns.TUPLE_GET_ITEM + assert op_name == opns.TUPLE_GET_ITEM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'index': index}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['index'], **kwargs) + +def tuple_get_item(X, name=None, op_name=None, index:int = 0, extra_attrs=None): + return TupleGetItem(X, name, op_name, index, extra_attrs) + + +@dataclass(init=False) +class LeakyRelu(symbol.Symbol): + op_name = opns.LEAKY_RELU + + @property + def negative_slope(self) -> float: + default_val = 1e-2 + return self.attrs['negative_slope'] if 'negative_slope' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + op_name = op_name or opns.LEAKY_RELU + assert op_name == opns.LEAKY_RELU + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'negative_slope': negative_slope}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) + +def leaky_relu(X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + return LeakyRelu(X, name, op_name, negative_slope, extra_attrs) + + +def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.DENSE + assert op_name == opns.DENSE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, W, B], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Hardtanh(symbol.Symbol): + op_name = opns.HARDTANH + + @property + def min_val(self) -> float: + default_val = -1.0 + return self.attrs['min_val'] if 'min_val' in self.attrs else default_val + + @property + def max_val(self) -> float: + default_val = 1.0 + return self.attrs['max_val'] if 'max_val' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + op_name = op_name or opns.HARDTANH + assert op_name == opns.HARDTANH + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min_val': min_val, 'max_val':max_val}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, 
d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) + +def hard_tanh(X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + return Hardtanh(X, name, op_name, min_val, max_val, extra_attrs) + +@dataclass(init=False) +class AdaptiveAvgPool2D(symbol.Symbol): + op_name = opns.ADAPTIVE_AVG_POOL2D + + @property + def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: + assert 'output_size' in self.attrs + return self.attrs['output_size'] + + def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): + op_name = op_name or opns.ADAPTIVE_AVG_POOL2D + assert op_name == opns.ADAPTIVE_AVG_POOL2D + assert output_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['output_size'], **kwargs) + +def adaptive_avg_pool2d(X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + return AdaptiveAvgPool2D(X, name, op_name, output_size, extra_attrs) + +@dataclass(init=False) +class AvgPool2D(symbol.Symbol): + op_name = opns.AVG_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + @property + def count_include_pad(self) -> bool: + default_val = True + return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + op_name = op_name or opns.AVG_POOL2D + assert op_name == opns.AVG_POOL2D + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +def avg_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + return AvgPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, count_include_pad, extra_attrs) + + +@dataclass(init=False) +class MaxPool2D(symbol.Symbol): + op_name = opns.MAX_POOL2D + + @property 
+ def pool_size(self) -> typing.Tuple[int, int]: + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): + op_name = op_name or opns.MAX_POOL2D + assert op_name == opns.MAX_POOL2D + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout'], **kwargs) + +def max_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): + return MaxPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, extra_attrs) + + +@dataclass(init=False) +class Softmax(symbol.Symbol): + op_name = opns.SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.SOFTMAX + assert op_name == opns.SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Softmax(X, name, op_name, axis, extra_attrs) + +@dataclass(init=False) +class LogSoftmax(symbol.Symbol): + op_name = opns.LOG_SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.LOG_SOFTMAX + assert op_name == opns.LOG_SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def log_softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return LogSoftmax(X, name, op_name, axis, extra_attrs) + + +def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.EXP + assert op_name == opns.EXP + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, 
extra_attrs=extra_attrs or {}) + +def sigmoid(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SIGMOID + assert op_name == opns.SIGMOID + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Sum(symbol.Symbol): + op_name = opns.SUM + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.SUM + assert op_name == opns.SUM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def sum(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Sum(X, name, op_name, dim, keepdim, extra_attrs) + + +@dataclass(init=False) +class Mean(symbol.Symbol): + op_name = opns.MEAN + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MEAN + assert op_name == opns.MEAN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def mean(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Mean(X, name, op_name, dim, keepdim, extra_attrs) + + +@dataclass(init=False) +class MaxAxis(symbol.Symbol): + op_name = opns.MAX_AXIS + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MAX_AXIS + assert op_name == opns.MAX_AXIS + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def max_axis(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return MaxAxis(X, name, op_name, dim, keepdim, extra_attrs) + + +def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MAXIMUM + assert op_name == opns.MAXIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def minimum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MINIMUM + assert 
op_name == opns.MINIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def repeat(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.REPEAT + assert op_name == opns.REPEAT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Squeeze(symbol.Symbol): + op_name = opns.SQUEEZE + + @property + def dim(self) -> typing.Optional[int]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): + op_name = op_name or opns.SQUEEZE + assert op_name == opns.SQUEEZE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim'], **kwargs) + +def squeeze(X, name=None, op_name=None, dim=None, extra_attrs=None): + return Squeeze(X, name, op_name, dim, extra_attrs) + +@dataclass(init=False) +class Flatten(symbol.Symbol): + op_name = opns.FLATTEN + + @property + def start_dim(self) -> int: + default_val = 0 + return self.attrs['start_dim'] if 'start_dim' in self.attrs else default_val + + @property + def end_dim(self) -> int: + default_val = -1 + return self.attrs['end_dim'] if 'end_dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + op_name = op_name or opns.FLATTEN + assert op_name == opns.FLATTEN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'start_dim': start_dim, 'end_dim':end_dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) + +def flatten(X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + return Flatten(X, name, op_name, start_dim, end_dim, extra_attrs) + + +@dataclass(init=False) +class Reshape(symbol.Symbol): + op_name = opns.RESHAPE + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.RESHAPE + assert op_name == opns.RESHAPE + assert newshape != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def reshape(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return Reshape(X, name, op_name, newshape, extra_attrs) + +@dataclass(init=False) +class Concat(symbol.Symbol): + op_name = opns.CONCAT + + @property + def axis(self) -> int: + default_val = 0 + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.CONCAT + assert op_name == opns.CONCAT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis': axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def concat(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Concat(X, name, op_name, axis, extra_attrs) + 
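+
+# A minimal usage sketch (comments only; the tensor names and shapes below are
+# hypothetical, assuming an NCHW float input): graphs are built by chaining these
+# constructors, and the same creators are reachable by op name through MRT_OP_MAP
+# once they are registered at the bottom of this module.
+#
+#   data = var("data", shape=(1, 3, 224, 224), dtype=float)
+#   weight = var("weight", shape=(8, 3, 3, 3), dtype=float)
+#   out = relu(conv2d(data, weight, padding=(1, 1, 1, 1)))
+#   conv_ctor = MRT_OP_MAP[opns.CONV2D]   # -> the Conv2D class registered below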
+@dataclass(init=False) +class Split(symbol.Symbol): + op_name = opns.SPLIT + + @property + def split_size(self) -> typing.List[int]: + assert 'split_size' in self.attrs + return self.attrs['split_size'] + + @property + def dim(self) -> int: + default_val = 0 + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, split_size=None, dim=0, extra_attrs=None): + op_name = op_name or opns.SPLIT + assert op_name == opns.SPLIT + assert split_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) + +def split(X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + return Split(X, name, op_name, split_size, dim, extra_attrs) + + +@dataclass(init=False) +class Transpose(symbol.Symbol): + op_name = opns.TRANSPOSE + + @property + def dim0(self) -> int: + assert 'dim0' in self.attrs + return self.attrs['dim0'] + + @property + def dim1(self) -> int: + assert 'dim1' in self.attrs + return self.attrs['dim1'] + + def __init__(self, X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): + op_name = op_name or opns.TRANSPOSE + assert op_name == opns.TRANSPOSE + assert dim0 != None + assert dim1 != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) + +def transpose(X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): + return Transpose(X, name, op_name, dim0, dim1, extra_attrs) + + +@dataclass(init=False) +class BroadcastTo(symbol.Symbol): + op_name = opns.BROADCAST_TO + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.BROADCAST_TO + assert op_name == opns.BROADCAST_TO + assert newshape != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def broadcast_to(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return BroadcastTo(X, name, op_name, newshape, extra_attrs) + +@dataclass(init=False) +class ExpandDims(symbol.Symbol): + op_name = opns.EXPAND_DIMS + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.EXPAND_DIMS + assert op_name == opns.EXPAND_DIMS + assert newshape != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def expand_dims(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return ExpandDims(X, name, op_name, newshape, extra_attrs) + +@dataclass(init=False) +class Tile(symbol.Symbol): + op_name = opns.TILE + + @property + def dims(self) -> typing.Tuple[int,...]: + assert 'dims' 
in self.attrs
+        return self.attrs['dims']
+
+    def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None):
+        op_name = op_name or opns.TILE
+        assert op_name == opns.TILE
+        assert dims is not None
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['dims'], **kwargs)
+
+def tile(X, name=None, op_name=None, dims=None, extra_attrs=None):
+    return Tile(X, name, op_name, dims, extra_attrs)
+
+
+def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.WHERE
+    assert op_name == opns.WHERE
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {})
+
+def greater(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.GREATER
+    assert op_name == opns.GREATER
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {})
+
+@dataclass(init=False)
+class NonMaxSuppression(symbol.Symbol):
+    op_name = opns.NON_MAX_SUPRESSION
+
+    @property
+    def iou_threshold(self) -> float:
+        default_val = 0.5
+        return self.attrs['iou_threshold'] if 'iou_threshold' in self.attrs else default_val
+    @property
+    def score_threshold(self) -> typing.Optional[float]:
+        default_val = None
+        return self.attrs['score_threshold'] if 'score_threshold' in self.attrs else default_val
+
+    def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None):
+        op_name = op_name or opns.NON_MAX_SUPRESSION
+        assert op_name == opns.NON_MAX_SUPRESSION
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'iou_threshold': iou_threshold, 'score_threshold': score_threshold}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        # restore this op's own attrs (not Tile's 'dims')
+        return _from_dict_attrs(cls, d, ['iou_threshold', 'score_threshold'], **kwargs)
+
+def non_max_suppression(X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None):
+    return NonMaxSuppression(X, name, op_name, iou_threshold, score_threshold, extra_attrs)
+
+
+def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.CEIL
+    assert op_name == opns.CEIL
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {})
+
+def right_shift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
+    op_name = op_name or opns.RIGHT_SHIFT
+    assert op_name == opns.RIGHT_SHIFT
+    return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {})
+
+@dataclass(init=False)
+class Add(symbol.Symbol):
+    op_name = opns.ADD
+
+    @property
+    def alpha(self) -> int:
+        default_val = 1
+        return self.attrs['alpha'] if 'alpha' in self.attrs else default_val
+
+    def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None):
+        op_name = op_name or opns.ADD
+        assert op_name == opns.ADD
+        super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['alpha'], **kwargs)
+
+def add(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None):
+    return Add(X, Y, name, op_name, alpha, extra_attrs)
+
+@dataclass(init=False)
+class Sub(symbol.Symbol):
+    op_name = opns.SUB
+
+    @property
+    def alpha(self) ->
int: + default_val = 1 + return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + op_name = op_name or opns.SUB + assert op_name == opns.SUB + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['alpha'], **kwargs) + +def sub(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + return Sub(X, Y, name, op_name, alpha, extra_attrs) + + +def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MUL + assert op_name == opns.MUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def mat_mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MATMUL + assert op_name == opns.MATMUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Div(symbol.Symbol): + op_name = opns.DIV + + @property + def rounding_mode(self) -> typing.Optional[str]: + default_val = None + return self.attrs['rounding_mode'] if 'rounding_mode' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + op_name = op_name or opns.DIV + assert op_name == opns.DIV + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'rounding_mode': rounding_mode}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) + +def div(X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + return Div(X, Y, name, op_name, rounding_mode, extra_attrs) + + +def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.NEGATIVE + assert op_name == opns.NEGATIVE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def abs(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ABS + assert op_name == opns.ABS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def log(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.LOG + assert op_name == opns.LOG + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def sqrt(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SQRT + assert op_name == opns.SQRT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def pow(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.POW + assert op_name == opns.POW + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def pass_(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.PASS + assert op_name == opns.PASS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Arange(symbol.Symbol): + op_name = opns.ARANGE + + @property + def end(self) -> int: + 
assert 'end' in self.attrs + return self.attrs['end'] + + @property + def start(self) -> int: + default_val = 0 + return self.attrs['start'] if 'start' in self.attrs else default_val + + @property + def step(self) -> int: + default_val = 1 + return self.attrs['step'] if 'step' in self.attrs else default_val + + def __init__(self, name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): + op_name = op_name or opns.ARANGE + assert op_name == opns.ARANGE + assert end != None + super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) + +def arange(name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): + return Arange(name, op_name, end, start, step, extra_attrs) + + +def zeros_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ZEROS_LIKE + assert op_name == opns.ZEROS_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ONES_LIKE + assert op_name == opns.ONES_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + + +_register_op_map(opns.VAR)(var) +_register_op_map(opns.RELU)(relu) + +_register_op_map(opns.CONV2D)(Conv2D) +_register_op_map(opns.DROP_OUT)(Dropout) +_register_op_map(opns.CLIP)(Clip) +_register_op_map(opns.BATCH_NORM)(BatchNorm) +_register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) + +_register_op_map(opns.LEAKY_RELU)(LeakyRelu) + +_register_op_map(opns.MUL)(mul) +_register_op_map(opns.DENSE)(dense) +_register_op_map(opns.HARDTANH)(Hardtanh) +_register_op_map(opns.SILU)(silu) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(AdaptiveAvgPool2D) +_register_op_map(opns.AVG_POOL2D)(AvgPool2D) +_register_op_map(opns.MAX_POOL2D)(MaxPool2D) +_register_op_map(opns.SOFTMAX)(Softmax) +_register_op_map(opns.LOG_SOFTMAX)(LogSoftmax) +_register_op_map(opns.EXP)(exp) +_register_op_map(opns.SIGMOID)(sigmoid) +_register_op_map(opns.SUM)(Sum) +_register_op_map(opns.MEAN)(Mean) +_register_op_map(opns.MAX_AXIS)(MaxAxis) +_register_op_map(opns.MAXIMUM)(maximum) +_register_op_map(opns.MINIMUM)(minimum) + + +_register_op_map(opns.REPEAT)(repeat) +_register_op_map(opns.SQUEEZE)(Squeeze) +_register_op_map(opns.FLATTEN)(Flatten) +_register_op_map(opns.RESHAPE)(Reshape) +_register_op_map(opns.CONCAT)(Concat) +_register_op_map(opns.SPLIT)(Split) +_register_op_map(opns.TRANSPOSE)(Transpose) +_register_op_map(opns.BROADCAST_TO)(BroadcastTo) +_register_op_map(opns.EXPAND_DIMS)(ExpandDims) +_register_op_map(opns.TILE)(Tile) +_register_op_map(opns.WHERE)(where) +_register_op_map(opns.GREATER)(greater) +_register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) + +_register_op_map(opns.CEIL)(ceil) +_register_op_map(opns.RIGHT_SHIFT)(right_shift) + +_register_op_map(opns.ADD)(Add) +_register_op_map(opns.SUB)(Sub) +_register_op_map(opns.MATMUL)(mat_mul) +_register_op_map(opns.DIV)(Div) +_register_op_map(opns.NEGATIVE)(negative) +_register_op_map(opns.ABS)(abs) +_register_op_map(opns.LOG)(log) +_register_op_map(opns.SQRT)(sqrt) +_register_op_map(opns.POW)(pow) +_register_op_map(opns.PASS)(pass_) +_register_op_map(opns.ARANGE)(Arange) +_register_op_map(opns.ZEROS_LIKE)(zeros_like) 
+_register_op_map(opns.ONES_LIKE)(ones_like) + + +# Add default register Class for MRT OP Not Implemented! +_register_op_map(opns.TUPLE)(extern_opfunc(opns.TUPLE)) +_register_op_map(opns.AS_TYPE)(extern_opfunc(opns.AS_TYPE)) +_register_op_map(opns.ADV_INDEX)(extern_opfunc(opns.ADV_INDEX)) +_register_op_map(opns.CALL_TIR)(extern_opfunc(opns.CALL_TIR)) +_register_op_map(opns.CALL_DPS_PACKED)(extern_opfunc(opns.CALL_DPS_PACKED)) + +_register_op_map(opns.IF)(symbol.Symbol) +_register_op_map(opns.ARGWHERE)(symbol.Symbol) +_register_op_map(opns.REQUANT)(symbol.Symbol) +_register_op_map(opns.PCLIP)(symbol.Symbol) +_register_op_map(opns.RS_PCLIP)(symbol.Symbol) +_register_op_map(opns.LUT)(symbol.Symbol) + +_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) +_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) +_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) +_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index 5b92822..31da253 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -99,3 +99,6 @@ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ + +def Opname2Funcname(op_name: str) -> str: + return op_name.replace('.', '_') diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py new file mode 100644 index 0000000..302da1b --- /dev/null +++ b/python/mrt/mir/simple_pass.py @@ -0,0 +1,345 @@ +from __future__ import annotations +import typing + +from functools import wraps +from dataclasses import dataclass + +from mrt.common import config +#from mrt.runtime import inference +from mrt.common.utils import * +from mrt.common.types import * + +from . import op, opns, opclass +from . import symbol as _symbol + + +# mrt op visits +@dataclass +class SimplePass: + symbol: _symbol.Symbol + + """op-level visit of graph + infer different visit function with different op_name + return: head symbol processed + """ + def graph_visits(self) -> _symbol.Symbol: + env: typing.Dict[str, _symbol.Symbol] = {} + for sym in _symbol.sym2list(self.symbol): + assert sym.name not in env, f'{sym.name} NotIn env!' 
+
+            # Update args with the already-visited symbols from env
+            sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args])
+            assert isinstance(sym, _symbol.Symbol), sym
+            out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym)
+            out = out or sym
+            assert isinstance(out, _symbol.Symbol), out
+            env[sym.name] = out
+        return env[self.symbol.name]
+
+    def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol:
+        return op
+
+    """custom visit of graph
+    calls custom_run for every op_name
+    return: head symbol processed
+    """
+    def custom_visits(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol:
+        with N(name):
+            if once:
+                return custom_run(self.symbol)
+            return _symbol.transform(self.symbol, custom_run)
+
+
+# mrt op visits with params, variables
+@dataclass
+class InferPass(SimplePass):
+    params: ParametersT
+
+    def is_input(self, op_: _symbol.Symbol) -> bool:
+        return op.is_input(op_, self.params)
+    def is_variable(self, op_: _symbol.Symbol) -> bool:
+        return op.is_variable(op_, self.params)
+    def is_operator(self, op_: _symbol.Symbol) -> bool:
+        return op.is_operator(op_, self.params)
+    def is_param(self, op_: _symbol.Symbol) -> bool:
+        return op_.op_name == opns.VAR and op_.name in self.params
+
+    def get_param(self, op_: _symbol.Symbol) -> OpNumpyT:
+        return self.params[op_.name] if self.is_param(op_) else []
+    def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT:
+        assert self.is_param(op_), f"{op_.name} is not a parameter."
+        data = self.params[op_.name]
+        assert isinstance(data, (tuple, list, np.ndarray)), \
+            f"param:{op_.name} is not OpNumpyT, got {type(data)}"
+        return data
+
+    """custom visit of graph
+    calls custom_run for every op_name;
+    depending on how custom_run is implemented, params comes from the argument or from the class property
+    return: head symbol processed
+    """
+    def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol:
+        with N(name):
+            if once:
+                return custom_run(self.symbol, self.params)
+            return _symbol.transform(self.symbol, custom_run, params=self.params)
+
+    # From the original quantization.Transformer
+    def as_parameter(self, data: OpNumpyT, name: str, dtype):
+        def _f(data, dtype):
+            if isinstance(data, list):
+                assert len(data) == len(dtype)
+                return [_f(d, t) for d, t in zip(data, dtype)]
+            assert isinstance(data, np.ndarray), type(data)
+            return data.astype(dtype)
+        array = _f(data, dtype)
+        shape = np.array(array).shape
+        self.params[name] = array
+        # var() takes the parameter name, not the data; the data itself lives in self.params
+        return opclass.var(name, shape=shape, dtype=dtype)
+
+    def from_np_data(self, sym:_symbol.Symbol, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol:
+        name = N.n(prefix=prefix)
+        # some data is a plain np.float/int scalar, so wrap it with np.array
+ data = np.array(data) + self.params[name] = data.astype(dtype) + return opclass.var(name, shape=data.shape, dtype=dtype).like(sym) + + def from_const_data(self, sym:_symbol.Symbol, data: typing.Union[int, float], dtype) -> _symbol.Symbol: + return self.from_np_data(sym, data, dtype) + + +# Register MRT all op's default_visit_op function +for op_name in opclass.MRT_OP_MAP.keys(): + funcSuffix = opns.Opname2Funcname(op_name) + setattr(SimplePass, f"visit_{funcSuffix}", SimplePass._default_visit_op) + #print(f"visit_, {op_name} => {funcSuffix}", getattr(SimplePass, f"visit_{funcSuffix}")) + + +# mrt symbol simple pass +class FuseDropoutPass(SimplePass): + def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: + # make sure op fit again + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + + +class FuseTupleGetItemPass(SimplePass): + def visit_TupleGetItem(self, sym: opclass.TupleGetItem) -> _symbol.Symbol: + #if sym.op_name == opns.TUPLE_GET_ITEM: + # assert sym.index == 0 + # return sym.args[0] + return sym + + +class FuseNaiveSoftmaxPass(SimplePass): + def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.SOFTMAX: + return sym.args[0] + return sym + + def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.LOG_SOFTMAX: + return sym.args[0] + return sym + + +class FuseMeanPass(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.MEAN: + X = sym.args[0] + out = opclass.Sum(X, **sym.attrs).like(sym) + scale = self.from_np_data(sym, np.array( + 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) + out = opclass.mul(out, scale) + return out + return sym + return custom_run + + +class FuseConstantPass(InferPass): + threshold: typing.ClassVar[float] = 1e-5 + + def np_is_zero(self, data) -> float: + return np.abs(data).max() < self.threshold + + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): + data = inference.run_single_params( + sym, [self.get_as_numpy(a) for a in sym.args]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.ADD, opns.SUB): # , BIAS_ADD): + strips = [] + for arg in sym.args: + if self.is_param(arg) and self.np_is_zero(self.get_as_numpy(arg)): + strips.append(arg) + args = [a for a in sym.args if a not in strips] + if len(args) == 1: + return args[0] + elif sym.is_op(opns.SLICE_LIKE): + if not self.is_param(sym.args[0]): + return sym + a, b = sym.args + data = inference.run_single_params( + sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.REQUANT): + if sym.rescale == 1: + return sym.args[0] + elif sym.is_op(opns.ZEROS_LIKE, opns.ONES_LIKE): + data = inference.run_single_params(sym, []) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + return sym + return custom_run + + +class FuseBatchNormPass(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: opclass.BatchNorm, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = 
self.get_param(Mean) + Var = self.get_param(Var) + + assert sym.axis == 1 + Beta = Beta if sym.center else 0 + Gamma = Gamma if sym.scale else 1 + + # (x - mean) / sqrt(var + epsilon) * gamma + beta + Gamma = Gamma / np.sqrt(Var + sym.epsilon) + # (x - mean) * gamma + beta + # x * gamma + (beta - mean * gamma) + bias: np.ndarray = (Beta - Mean * Gamma) + K = Gamma.shape[0] + + if X.is_op(opns.CONV2D): + A, W = X.args + assert X.kernel_layout == "OIHW" + assert W.shape[0] == K + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1, 1, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_conv2d(A, W_sym, **X.attrs) + elif X.is_op(opns.DENSE): + A, W = X.args + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_dense(A, W_sym, **X.attrs) + else: + reshp = [s if i == sym.axis else 1 \ + for i, s in enumerate(X.shape)] + W = self.from_np_data(X, Gamma.reshape(reshp), X.dtype) + out = opclass.mul(X, W) + + bias = bias.reshape([s if i == sym.axis else 1 \ + for i, s in enumerate(out.shape)]) + B = out.like(sym) + B = self.from_np_data(B, bias, dtype=B.dtype) + return opclass.add(out, B).like(sym) + + return sym + return custom_run + + +class FuseDividePass(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.DIV: + argA = sym.args[0] + argB = sym.args[1] + assert self.is_param(argB), f'NotParam: {argB}' + argB = self.from_np_data(sym, 1. / self.get_as_numpy(argB), dtype=argB.dtype) + out = opclass.mul(argA, argB) + return out.like(sym) + return sym + return custom_run + + +class FuseLeakyReLU(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.LEAKY_RELU: + alpha = self.from_const_data(sym, sym.alpha, dtype=float) + X = sym.args[0] + out = opclass.relu(opclass.negative(X)) + out = opclass.mul(alpha, out) + return opclass.sub(opclass.relu(X), out) + return sym + return custom_run + +class FuseAdaptiveAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.ADAPTIVE_AVG_POOL2D: + X = sym.args[0] + assert sym.layout == "NCHW" + inp_shap = X.shape[2:] + out_size = sym.output_size or inp_shap + if not isinstance(out_size, (list, tuple)): + out_size = (out_size, out_size) + sym.output_size = out_size + + assert len(X.shape) == 4 + if all([s == 1 for s in sym.output_size]): + scale = np.array(1 / np.prod(X.shape[-2:])) + out = opclass.Sum(X, dim=list(range(4))[-2:], keepdims=True) + scale = self.from_np_data(sym, scale.astype(X.dtype)) + return opclass.mul(out, scale).like(self) + elif out_size[0] > inp_shap[0] or out_size[1] > inp_shap[1]: + assert all([s == 1 for s in inp_shap]) + # TODO: fix opclass repeat + out = opclass.repeat(X, repeats=out_size[0], axis=-2) + out = opclass.repeat(out, repeats=out_size[1], axis=-1) + return out.like(self) + + # calculate the attributes refers to: + # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work + strides = [i // o for i, o in zip(inp_shap, out_size)] + kernel = [i-(o-1)*s for i, o, s in zip(inp_shap, out_size, strides)] + attrs = { + 
"kernel_size": kernel, + "strides": strides, + "padding": (0, 0), + "dilation": (1, 1), + "data_layout": sym.layout, + "groups": X.shape[1], + "channels": X.shape[1], + } + W_shape = (X.shape[1], 1, *kernel) + W = self.from_np_data(X, np.full(W_shape, 1 / product(kernel)), dtype=X.dtype) + out = opclass.Conv2D(X, W, **attrs) + return out.like(sym) + return sym + return custom_run + + +class FuseAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Spliter(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Merger(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Calibrator(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 5c97cee..07bd8af 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -11,8 +11,7 @@ # from . import config # from .utils import * -# from .types import * -from .opns import * +from . import opns __ALL__ = [ "Symbol", @@ -20,6 +19,8 @@ "filter_operators", ] +SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") + def _format_printer(data): if isinstance(data, dict): data = ["{}={}".format(k, _format_printer(v)) \ @@ -112,7 +113,10 @@ def like(self, other: Symbol, **kwargs) -> Symbol: # assert self.shape == other.shape, "%s vs.\n %s" % (self, other) # assert self.dtype == other.dtype , "%s vs.\n %s" % (self, other) data = other.to_dict() - data.update(self.to_dict()) + data_new = self.to_dict() + data.update(data_new) + + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] # copy extra attrs by default. 
# data["extra_attrs"] = other.extra_attrs return type(other).from_dict(data, **kwargs) @@ -277,34 +281,6 @@ def __hash__(self) -> int: def hash(self) -> int: return hash(str(self)) -# class Convolution2D(Symbol): -# strides: typing.Tuple[int, int] - -# class Dropout(Symbol): -# eps: float = 1e-5 - -# class Pass: -# symbol: Symbol - -# def visit(self, op: Symbol): -# env: typing.Dict[Symbol, Symbol] = {} -# for sym in sym2list(self.symbol): -# out = getattr(self, f"visit_{op.op_name}")(op) or op -# assert isinstance(sym, Symbol) -# env[sym] = out -# return env[op] - -# def _default_visit_op(op): -# return op - -# for op in op_list: -# setattr(Pass, f"visit_{op.op_name}", _default_visit_op) - -# class FuseDropoutPass(Pass): -# def visit_dropout(self, op: Dropout): -# op.eps -# return op.args[0] - def _topo_sort(symbol: Symbol, sym_list: typing.List[Symbol]): assert isinstance(symbol, Symbol), \ f"({type(symbol).__name__}){str(symbol)}" @@ -349,6 +325,7 @@ def load_json(data: _SymbolJsonT, **extra_attrs) -> Symbol: _VisitorT = typing.Callable[[Symbol], None] _TransformerT = typing.Callable[[Symbol], typing.Optional[Symbol]] +_TransformerParamT = typing.Callable[[Symbol, typing.Optional[ParametersT]], Symbol] """ Symbol Transformer Return new symbol to transform old symbol into updated one, @@ -365,7 +342,7 @@ def visit(symbol: Symbol, callback: _VisitorT): if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f">> {sym}") -def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: +def transform(symbol: Symbol, callback: _TransformerParamT, params:typing.Optional[ParametersT] = None) -> Symbol: """ Transform symbol from old to new, with inputs updated. Only the return value indicates mutation, while changing @@ -382,7 +359,7 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = callback(sym) or sym + out = (callback(sym, params) if params else callback(sym)) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name @@ -491,7 +468,7 @@ def as_tuple(self) -> typing.Tuple[typing.List[str], Symbol]: @classmethod def from_tuple(cls, tuple_names, symbol): - assert symbol.is_op(TUPLE), symbol + assert symbol.is_op(opns.TUPLE), symbol mhs = cls(zip(tuple_names, symbol.args)) mhs.origin = symbol return mhs diff --git a/tests/mir/test.infer_pass.py b/tests/mir/test.infer_pass.py new file mode 100644 index 0000000..3d94e93 --- /dev/null +++ b/tests/mir/test.infer_pass.py @@ -0,0 +1,103 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_resnet18_model(): + """Get Resnet18 MRT Model""" + + # Load pre-trained ResNet18 + model = models.resnet18(weights='IMAGENET1K_V1') + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + 
print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseBatchNorm(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseBatchNorm Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseAdaptiveAvgPool2D(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseAdaptiveAvgPool2D Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseTupleGetItem(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseTuple Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt > 0, f'ori model TupleGetItem op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseTupleGetItemPass = simple_pass.FuseTupleGetItemPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseTuple Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass ===") + mrt_graph, mrt_params = _get_resnet18_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseBatchNorm, test_InferPass_FuseAdaptiveAvgPool2D, test_InferPass_FuseTupleGetItem] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_div.py b/tests/mir/test.infer_pass_div.py new file mode 100644 index 0000000..547363c --- /dev/null +++ b/tests/mir/test.infer_pass_div.py @@ -0,0 +1,88 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_fasterrcnn_resnet50_fpn_model(): + """Get Fasterrcnn_resnet50_fpn MRT Model""" + + # Load pre-trained model + model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseDivide(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDivide Pass ===') + symlist = sx.sym2list(symbol) + + divide_op_cnt = 0 + for sym in symlist: + divide_op_cnt += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt > 0, f'ori model divide op cnt {divide_op_cnt} == zero!' + + # init FuseDivide Passer and execute visit + tfs : simple_pass.FuseDividePass = simple_pass.FuseDividePass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseDivide Pass ===') + rlts = sx.sym2list(symbol_passed) + divide_op_cnt_af = 0 + for sym in rlts: + # print(sym) + divide_op_cnt_af += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt_af==0, f'passed model divide op cnt {divide_op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Divide ===") + mrt_graph, mrt_params = _get_fasterrcnn_resnet50_fpn_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseDivide] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!')
+        print("\n" + "="*60 + "\n")
+    print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}')
+
diff --git a/tests/mir/test.infer_pass_mean.py b/tests/mir/test.infer_pass_mean.py
new file mode 100644
index 0000000..d8586ec
--- /dev/null
+++ b/tests/mir/test.infer_pass_mean.py
@@ -0,0 +1,89 @@
+"""
+Test script for MRT InferPass (FuseMean)
+"""
+
+from os import path
+import sys, os
+
+ROOT = path.dirname(path.dirname(path.dirname(
+    path.realpath(__file__))))
+sys.path.insert(0, path.join(ROOT, "python"))
+
+import torch
+import torchvision.models as models
+import numpy as np
+from collections import namedtuple
+
+from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer
+from mrt.frontend.pytorch import vm
+from mrt.mir import helper, symbol as sx
+from mrt.mir import opns
+from mrt.mir import opclass
+from mrt.mir import simple_pass
+
+def _get_shufflenet_model():
+    """Get ShuffleNetV2 MRT Model"""
+
+    # Load pre-trained ShuffleNetV2
+    model = models.shufflenet_v2_x1_0(weights='IMAGENET1K_V1')
+    model.eval()
+
+    # Create example input
+    example_inputs = torch.randn(1, 3, 224, 224)
+
+    # Test inference with original model
+    with torch.no_grad():
+        original_output = model(example_inputs)
+
+    # Convert to MRT
+    print("\nConverting Model to MRT...")
+    ep = torch.export.export(model, (example_inputs,))
+    mrt_graph, mrt_params = pytorch_to_mrt(ep)
+    return mrt_graph, mrt_params
+
+
+def test_InferPass_FuseMean(mrt_graph, mrt_params):
+    symbol = mrt_graph['main']
+    #print(symbol)
+
+    print('\n=== Before FuseMean Pass ===')
+    symlist = sx.sym2list(symbol)
+    #for x in symlist:
+    #    print(x)
+
+    op_cnt = 0
+    for sym in symlist:
+        op_cnt += 1 if sym.op_name == opns.MEAN else 0
+    assert op_cnt > 0, f'original model mean op count is zero ({op_cnt})!'
+
+    # init Passer and execute visit
+    tfs: simple_pass.FuseMeanPass = simple_pass.FuseMeanPass(symbol, mrt_params)
+    symbol_passed = tfs.custom_visits_with_params(tfs.get_run())
+
+    print('\n=== After FuseMean Pass ===')
+    rlts = sx.sym2list(symbol_passed)
+    op_cnt_af = 0
+    for sym in rlts:
+        # print(sym)
+        op_cnt_af += 1 if sym.op_name == opns.MEAN else 0
+    assert op_cnt_af == 0, f'mean op count after pass should be zero, got {op_cnt_af}!'
+
+    return True
+
+
+if __name__ == "__main__":
+    print("=== Testing InferPass Mean ===")
+    mrt_graph, mrt_params = _get_shufflenet_model()
+
+    test_id = 0
+    passed_cnt = 0
+    test_funcs = [test_InferPass_FuseMean]
+    for func_ in test_funcs:
+        rltflag = func_(mrt_graph, mrt_params)
+        test_id += 1
+        passed_cnt += rltflag
+        print("\n" + "="*60 + "\n")
+        print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!')
+        print("\n" + "="*60 + "\n")
+    print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}')
+
diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py
new file mode 100644
index 0000000..7707afb
--- /dev/null
+++ b/tests/mir/test.op_create.py
@@ -0,0 +1,177 @@
+"""
+Test script for MRT opclass operator creation.
+"""
+
+from os import path
+import sys, os
+
+ROOT = path.dirname(path.dirname(path.dirname(
+    path.realpath(__file__))))
+sys.path.insert(0, path.join(ROOT, "python"))
+
+import torch
+import torchvision.models as models
+import numpy as np
+from collections import namedtuple
+
+from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer
+from mrt.frontend.pytorch import vm
+from mrt.mir import helper, symbol as sx
+from mrt.mir import opns
+from mrt.mir import opclass
+
+
+def test_op_func():
+    X = opclass.var(name="var2", shape=(16, 128, 128), dtype="float")
+    ceil0 = opclass.ceil(X)
+    assert isinstance(ceil0, sx.Symbol), 'ceil0 is not a Symbol'
+    assert ceil0.op_name == opns.CEIL
+    assert len(ceil0.name) > 0
+
+    ceil1 = opclass.ceil(X, 'ceil_1')
+    assert ceil1.op_name == opns.CEIL
+    assert ceil1.name == 'ceil_1'
+
+    return True
+
+
+def test_create_conv2d_op():
+
+    X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float")
+    W = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float")
+    assert tuple(X.shape) == (1, 3, 224, 224), f'Wrong X shape {X.shape}'
+    assert X.dtype == "float", f'Wrong X dtype {X.dtype}'
+
+    # Symbol init using an opclass op
+    conv2d_a = opclass.Conv2D(X, W, name='conv2d_a', strides=(2,2))
+    assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a is not a Symbol'
+    assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a is not a Conv2D'
+
+    # attrs hint
+    assert conv2d_a.args is not None
+    assert conv2d_a.attrs is not None
+    assert conv2d_a.strides is not None
+
+    print(f'Got {conv2d_a.name} strides: {conv2d_a.strides}')
+    print(f'Got {conv2d_a.name} padding: {conv2d_a.padding}')
+    print(f'Show {conv2d_a.name} {conv2d_a}')
+
+    # test Conv2D clone mode
+    conv2d_b = conv2d_a.copy()
+    assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b is not a Symbol'
+    assert isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b is not a Conv2D'
+
+    assert conv2d_b.attrs == conv2d_a.attrs, f'a: {conv2d_b.attrs} != b: {conv2d_a.attrs}'
+
+    # test dict lookup to find the class and init
+    conv2d_c = opclass.MRT_OP_MAP[opns.CONV2D](X, W, strides=(2,2))
+    assert isinstance(conv2d_c, opclass.Conv2D), 'conv2d_c is not a Conv2D'
+
+    # test Variable clone mode
+    X1 = X.copy()
+    assert X1.shape == X.shape
+    assert X1.dtype == X.dtype
+
+    # test: Symbol-compatible mode
+    args = [X1, W]
+    attrs = {'strides':(3,3)}
+
+    # Symbol-compatible init
+    conv2d_d = opclass.Conv2D(*args, name='conv2d_d', **attrs)
+    conv2d_e = opclass.Conv2D(*args, **attrs)
+    assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d is not a Conv2D'
+    assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e is not a Conv2D'
+
+    # alias function init
+    conv2d_f = opclass.conv2d(*args, **attrs)
+    assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f is not a Conv2D'
+
+    return True
+
+
+def test_create_symbol_graph():
+    X0 = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float")
+    W0 = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float")
+    conv2d_a = opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1))
+
+    W1 = opclass.var(shape=(16, 3, 12, 12,), dtype="float")
+    conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1))
+    symlist = sx.sym2list(conv2d_b)
+
+    assert symlist[0] == X0
+    assert symlist[1] == W0
+
+    for id_ in range(len(symlist)):
+        print(id_, symlist[id_])
+
+    return True
+
+
+def test_create_batch_norm_op():
+    X = opclass.var(name="x", shape=(1, 32, 128, 128,), dtype="float")
+    Gamma = opclass.var(name="gamma", shape=(32,), dtype="float")
+    Beta = opclass.var(name="beta", shape=(32,), dtype="float")
+    Mean = opclass.var(name="mean", shape=(32,), dtype="float")
+    Var = opclass.var(name="var", shape=(32,), dtype="float")
+    batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4)
+
+    # attrs hint
+    assert batch_norm_a.args is not None
+    assert batch_norm_a.attrs is not None
+    assert batch_norm_a.axis == 1
+
+    # test clone mode
+    batch_norm_b = batch_norm_a.copy()
+    assert isinstance(batch_norm_b, opclass.BatchNorm)
+
+    assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}'
+    assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}'
+
+    return True
+
+
+def test_create_reshape_op():
+    X = opclass.var(name="x", shape=(16, 32, 64, 64,), dtype="float")
+    try:
+        reshape0 = opclass.Reshape(X, name="reshape_0")
+    except Exception:
+        reshape0 = None
+    assert reshape0 is None, "Reshape without the required attr 'newshape' should have failed!"
+
+    reshape1 = opclass.Reshape(X, name="reshape_1", newshape=(16, 8, 128, 128))
+    assert isinstance(reshape1, opclass.Reshape)
+
+    return True
+
+
+def test_op_extern_func():
+
+    # extern_func ops do not need an explicit 'op_name'
+    args = [opclass.var(name="var2", shape=(16, 128, 128), dtype="float")]
+    attrs = {}
+    extra_attrs = {}
+    call_dps_packed = opclass.MRT_OP_MAP[opns.CALL_DPS_PACKED]('packed_0', args, attrs, extra_attrs)
+    assert isinstance(call_dps_packed, sx.Symbol), 'call_dps_packed is not a Symbol'
+    assert call_dps_packed.op_name == opns.CALL_DPS_PACKED
+    return True
+
+
+if __name__ == "__main__":
+    print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys())
+    assert len(opclass.MRT_OP_MAP.keys()) > 0
+
+    assert opns.CONV2D in opclass.MRT_OP_MAP
+    print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D])
+
+    test_id = 0
+    passed_cnt = 0
+    test_funcs = [test_op_func, test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op, test_create_reshape_op, test_op_extern_func]
+    for func_ in test_funcs:
+        rltflag = func_()
+        test_id += 1
+        passed_cnt += rltflag
+        print("\n" + "="*60 + "\n")
+        print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!')
+        print("\n" + "="*60 + "\n")
+    print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}')
+
diff --git a/tests/mir/test.simple_pass.py b/tests/mir/test.simple_pass.py
new file mode 100644
index 0000000..33139d4
--- /dev/null
+++ b/tests/mir/test.simple_pass.py
@@ -0,0 +1,149 @@
+"""
+Test script for MRT FuseDropoutPass and SimplePass custom callbacks on AlexNet.
+""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_alexnet_model(): + """Get Alexnet MRT Model""" + + # Load pre-trained Alexnet + model = models.alexnet(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Alexnet to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + +def test_SimplePass_FuseDropout(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDropout Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + # print(sym) + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' + + # init FuseDropout Passer and execute visit + tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) + #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) + symbol_passed = tfs.graph_visits() + + print('\n=== After FuseDropout Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + # print(sym) + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af==0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + #for sym in symdict: + # print(sym, symdict[sym]) + + #print('\n=== Back To SymList ===') + #rltlist = sx.sym2list(symdict[symbol.name]) + + return True + + +def test_SimplePass_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + conv2d_name_list = [] + def _filter_op(sym: sx.Symbol, params=None) -> sx.Symbol: + if sym.op_name == opns.CONV2D: + conv2d_name_list.append(sym.name) + return sym + + symbol_passed = tfs.custom_visits(_filter_op) + + print('\n=== After CustomFunc Pass ===') + assert len(conv2d_name_list) > 0 + print(conv2d_name_list) + rlts = sx.sym2list(symbol_passed) + + return True + + +def test_SimplePass_FuseDropout_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before FuseDropout CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt > 0, f'ori model dropout op cnt {dropout_op_cnt} == zero!' 
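+    # The variant below removes dropout through a generic SimplePass with a
+    # custom callback (rather than the dedicated FuseDropoutPass): each
+    # nn.dropout node is rewritten to forward its first argument.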
+ + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + def _nn_dropout(sym: sx.Symbol) -> sx.Symbol: + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + symbol_passed = tfs.custom_visits(_nn_dropout) + + print('\n=== After FuseDropout CustomFunc Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af == 0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + + print("=== Testing SymbolPass ===") + mrt_graph, mrt_params = _get_alexnet_model() + + print("Testing FuseDropoutPass for Model AlexNet") + rltflag = test_SimplePass_FuseDropout(mrt_graph, mrt_params) + print("\n" + "="*60 + "\n") + print('Passed Test1!' if rltflag else 'Test1 Failed!') + print("\n" + "="*60 + "\n") + + rltflag = test_SimplePass_CustomFunc(mrt_graph) + print("\n" + "="*60 + "\n") + print('Passed Test2!' if rltflag else 'Test2 Failed!') + print("\n" + "="*60 + "\n") + + print("Testing FuseDropout CustomFunc for Model AlexNet") + rltflag = test_SimplePass_FuseDropout_CustomFunc(mrt_graph) + print("\n" + "="*60 + "\n") + print('Passed Test3!' if rltflag else 'Test3 Failed!') + print("\n" + "="*60 + "\n") +
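Usage note (editorial, not part of the patch): the tests above exercise the new
params-aware transformer signature through SimplePass.custom_visits /
custom_visits_with_params; the sketch below shows the same dropout-removal
callback driven directly through symbol.transform. The callback name and the
mrt_graph / mrt_params bindings are illustrative assumptions borrowed from the
tests, not code introduced by this diff.

    from mrt.mir import opns
    from mrt.mir import symbol as sx

    def _drop_dropout(sym: sx.Symbol, params=None) -> sx.Symbol:
        # Returning a different symbol rewrites the node; returning sym keeps it.
        # Here every nn.dropout node is replaced by its first argument.
        if sym.op_name == opns.DROP_OUT:
            return sym.args[0]
        return sym

    # transform() walks the graph and rebuilds each node with updated inputs;
    # params is forwarded to the callback only when one is supplied.
    fused = sx.transform(mrt_graph['main'], _drop_dropout, params=mrt_params)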