|
1 | 1 | import os |
2 | 2 | import importlib |
3 | 3 | import sys |
| 4 | +from functools import wraps |
4 | 5 |
|
5 | | -from mrt.mir.optype import _DEFAULT_TYPE_INFER |
| 6 | +from mrt.mir.symbol import * |
| 7 | +from mrt.common.types import * |
6 | 8 |
|
7 | | -FRONTEND = os.environ.get("FRONTEND", "pytorch") |
class Singleton(object):
    """Base class implementing the singleton pattern via ``__new__``.

    The first instantiation creates the instance and caches it on the
    class; every later call returns the cached instance.  NOTE(review):
    the cache lookup uses ``hasattr``, which follows the MRO — a subclass
    instantiated after its parent inherits the parent's instance; verify
    this is intended if more Singleton subclasses are added.
    """
    def __new__(cls, *args, **kw):
        if not hasattr(cls, '_instance'):
            # BUG FIX: do not forward *args/**kw to object.__new__ —
            # in Python 3 object.__new__ raises TypeError when given
            # extra arguments while __new__ is overridden, so any
            # subclass taking constructor args would crash here.
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance
| 15 | + |
class DynamicModule(Singleton):
    """Singleton registry of frontend-dispatched functions.

    Stubs declared with :meth:`typedef_mod_function` record their names
    here; :meth:`load_mod` imports ``mrt.frontend.<frontend>`` and binds
    each recorded name to the same-named attribute of that module.
    """
    def __init__(self):
        # function name -> implementation (None until load_mod binds it)
        self._funcs = {}

    def load_mod(self, frontend):
        """Import the frontend package and bind all registered names.

        :param frontend: frontend module name under ``mrt.frontend``
            (e.g. ``"pytorch"``).
        :return: ``self`` (on both success and failure) so calls chain.
        """
        try:
            # BUG FIX: use the `frontend` parameter; the original read the
            # module-level FRONTEND global, silently ignoring the argument.
            frontend_module = importlib.import_module(
                f".{frontend}", package="mrt.frontend")
        except ImportError as e:
            print(f"Error: Frontend '{frontend}' cannot be imported: {e}")
            # Consistency fix: return self (not None) on the error path too.
            return self

        for fname in self._funcs:
            if hasattr(frontend_module, fname):
                self._funcs[fname] = getattr(frontend_module, fname)
            else:
                print(f"Error: function '{fname}' not found in frontend '{frontend}'")

        return self

    def typedef_mod_function(self, func):
        """Decorator: register ``func.__name__`` and dispatch calls.

        The decorated stub's body still runs (its return value is
        discarded); the bound frontend implementation's return value is
        what the caller receives.  Calling before :meth:`load_mod` has
        bound the name raises AssertionError.
        """
        fname = func.__name__
        self._funcs.setdefault(fname, None)

        @wraps(func)
        def _func_impl(*args, **kwargs):
            assert self._funcs[fname] is not None, \
                f"func:{fname} not registered in mod: {self._funcs.keys()}"
            func(*args, **kwargs)
            return self._funcs[fname](*args, **kwargs)
        return _func_impl
| 45 | + |
| 46 | + |
| 47 | +mod = DynamicModule() |
| 48 | + |
@mod.typedef_mod_function
def create_executor(
        symbol: MultiHeadSymbol, params: ParametersT,
        device: str = "cpu",
        target: str = "",  # no use in pytorch frontend
        ):
    """ Create Runtime Executor for Model Inference.

    Declarative stub: the decorator dispatches calls to the frontend's
    ``create_executor``; defaults shown here are documentation only
    (the dispatch wrapper forwards *args/**kwargs verbatim).

    :param symbol: model graph to execute.
    :param params: model parameters keyed by name.
    :param device: device string, "cpu" by default — presumably also
        accepts accelerator strings; verify against the frontend impl.
    :param target: build target; unused by the pytorch frontend.
    """
    pass
8 | 57 |
|
9 | | -# Dynamically load the frontend module |
10 | | -try: |
11 | | - frontend_module = importlib.import_module(f".{FRONTEND}", package="mrt.frontend") |
12 | | -except ImportError as e: |
13 | | - print(f"Error: Frontend '{FRONTEND}' is not supported or cannot be imported: {e}") |
14 | | - sys.exit(1) |
15 | | - |
16 | | -# Register default type infer functions |
17 | | -if hasattr(frontend_module, "type_infer"): |
18 | | - _DEFAULT_TYPE_INFER = frontend_module.type_infer |
19 | | -else: |
20 | | - print(f"Error: Required function 'type_infer' not found in frontend '{FRONTEND}'") |
21 | | - sys.exit(1) |
22 | | - |
23 | | -# Try to get frontend_to_mrt function |
24 | | -frontend_to_mrt = None |
25 | | -if hasattr(frontend_module, "from_frontend"): |
26 | | - frontend_to_mrt = frontend_module.from_frontend |
27 | | -elif hasattr(frontend_module, "pytorch_to_mrt"): |
28 | | - frontend_to_mrt = frontend_module.pytorch_to_mrt |
29 | | -elif hasattr(frontend_module, "expr2symbol"): |
30 | | - frontend_to_mrt = frontend_module.expr2symbol |
31 | | -else: |
32 | | - print(f"Error: Required function 'frontend_to_mrt' not found in frontend '{FRONTEND}'") |
33 | | - sys.exit(1) |
34 | | - |
35 | | -# Try to get mrt_to_frontend function |
36 | | -mrt_to_frontend = None |
37 | | -if hasattr(frontend_module, "to_frontend"): |
38 | | - mrt_to_frontend = frontend_module.to_frontend |
39 | | -elif hasattr(frontend_module, "mrt_to_pytorch"): |
40 | | - mrt_to_frontend = frontend_module.mrt_to_pytorch |
41 | | -elif hasattr(frontend_module, "symbol2expr"): |
42 | | - mrt_to_frontend = frontend_module.symbol2expr |
43 | | -else: |
44 | | - print(f"Error: Required function 'mrt_to_frontend' not found in frontend '{FRONTEND}'") |
45 | | - sys.exit(1) |
@mod.typedef_mod_function
def run_executor(
        executor,
        data: typing.Optional[np.ndarray] = None,
        data_dict: ParametersT = {}
        ) -> OpNumpyT:
    """ Apply data to executor.

    Declarative stub dispatched to the frontend's ``run_executor``.
    NOTE(review): the mutable ``{}`` default is declarative only — the
    dispatch wrapper forwards caller arguments verbatim, so this default
    never binds here; the frontend implementation's own default governs.

    :param executor: executor object returned by ``create_executor``.
    :param data: single input array, or None when using ``data_dict``.
    :param data_dict: named inputs keyed by name.
    """
    pass
| 66 | + |
@mod.typedef_mod_function
def infer(
        graph: MultiHeadSymbol,
        params: ParametersT,
        data: typing.Optional[np.ndarray] = None,
        data_dict: ParametersT = {},
        device: str = "cpu",
        **kwargs):
    """ Convenient method to infer a model.

    Declarative stub dispatched to the frontend's ``infer`` — presumably
    equivalent to ``create_executor`` followed by ``run_executor``;
    confirm against the frontend implementation.  Defaults here are
    documentation only (the dispatch wrapper forwards args verbatim).

    :param graph: model graph.
    :param params: model parameters keyed by name.
    :param data: single input array, or None when using ``data_dict``.
    :param data_dict: named inputs keyed by name.
    :param device: device string, "cpu" by default.
    :param kwargs: extra frontend-specific options.
    """
    pass
| 77 | + |
@mod.typedef_mod_function
def data_from_frontend(data: typing.Any) -> OpNumpyT:
    """ Convert Frontend Tensor to MRT DType.

    Declarative stub dispatched to the frontend's ``data_from_frontend``.
    """
    pass
| 82 | + |
@mod.typedef_mod_function
def data_to_frontend(data: OpNumpyT):
    """ Convert MRT DType to Frontend Tensor.

    Declarative stub dispatched to the frontend's ``data_to_frontend``;
    inverse of ``data_from_frontend``.
    """
    pass
| 87 | + |
@mod.typedef_mod_function
def model_from_frontend(
        fe_model,
        func_names: typing.List[str] = [ "main", ]
        ) -> typing.Tuple[MultiHeadSymbol, ParametersT]:
    """ Convert Frontend Graph to MRT Symbol/Params.

    Declarative stub dispatched to the frontend's ``model_from_frontend``.
    NOTE(review): the mutable list default is declarative only — the
    dispatch wrapper forwards caller arguments verbatim, so this default
    never binds here.

    :param fe_model: frontend-native model object.
    :param func_names: entry-point names to convert, ``["main"]`` by default.
    :return: tuple of (graph, params).
    """
    pass
| 95 | + |
@mod.typedef_mod_function
def model_to_frontend(graph: MultiHeadSymbol, params: ParametersT,):
    """ Convert MRT Symbol/Params to Frontend Graph.

    Declarative stub dispatched to the frontend's ``model_to_frontend``;
    inverse of ``model_from_frontend``.
    """
    pass
| 100 | + |
@mod.typedef_mod_function
def type_infer(symbol: Symbol) -> Symbol:
    """ Shape/DType Inference use Frontend API.

    Declarative stub dispatched to the frontend's ``type_infer``.  The
    docstring is the entire body (valid Python; no ``pass`` needed).
    """
| 104 | + |
| 105 | + |
# Select the frontend implementation from the FRONTEND environment
# variable (defaults to "pytorch") and bind every declared stub to the
# same-named function of that frontend module.
FRONTEND = os.environ.get("FRONTEND", "pytorch")
mod.load_mod(FRONTEND)
0 commit comments