Commit dea98a9

[frontend]: update api for frontend
1 parent 4deb009 commit dea98a9

7 files changed: 143 additions, 54 deletions

python/mrt/frontend/api.py (101 additions, 39 deletions)

@@ -1,45 +1,107 @@
 import os
 import importlib
 import sys
+from functools import wraps
 
-from mrt.mir.optype import _DEFAULT_TYPE_INFER
+from mrt.mir.symbol import *
+from mrt.common.types import *
 
-FRONTEND = os.environ.get("FRONTEND", "pytorch")
+class Singleton(object):
+    def __new__(cls, *args, **kw):
+        if not hasattr(cls, '_instance'):
+            orig = super(Singleton, cls)
+            cls._instance = orig.__new__(cls, *args, **kw)
+        return cls._instance
+
+class DynamicModule(Singleton):
+    def __init__(self):
+        self._funcs = {}
+
+    def load_mod(self, frontend):
+        try:
+            frontend_module = importlib.import_module(f".{frontend}", package="mrt.frontend")
+        except ImportError as e:
+            print(f"Error: Frontend '{frontend}' cannot be imported: {e}")
+            return
+
+        for f in self._funcs:
+            if hasattr(frontend_module, f):
+                self._funcs[f] = getattr(frontend_module, f)
+            else:
+                print(f"Error: function '{f}' not found in frontend '{frontend}'")
+
+        return self
+
+    def typedef_mod_function(self, func):
+        fname = func.__name__
+        self._funcs.setdefault(fname, None)
+
+        @wraps(func)
+        def _func_impl(*args, **kwargs):
+            assert self._funcs[fname] is not None, f"func:{fname} not registered in mod: {self._funcs.keys()}"
+            func(*args, **kwargs)
+            return self._funcs[fname](*args, **kwargs)
+        return _func_impl
+
+
+mod = DynamicModule()
+
+@mod.typedef_mod_function
+def create_executor(
+        symbol: MultiHeadSymbol, params: ParametersT,
+        device: str = "cpu",
+        target: str = "",  # unused by the pytorch frontend
+):
+    """ Create Runtime Executor for Model Inference. """
+    pass
 
-# Dynamically load the frontend module
-try:
-    frontend_module = importlib.import_module(f".{FRONTEND}", package="mrt.frontend")
-except ImportError as e:
-    print(f"Error: Frontend '{FRONTEND}' is not supported or cannot be imported: {e}")
-    sys.exit(1)
-
-# Register default type infer functions
-if hasattr(frontend_module, "type_infer"):
-    _DEFAULT_TYPE_INFER = frontend_module.type_infer
-else:
-    print(f"Error: Required function 'type_infer' not found in frontend '{FRONTEND}'")
-    sys.exit(1)
-
-# Try to get frontend_to_mrt function
-frontend_to_mrt = None
-if hasattr(frontend_module, "from_frontend"):
-    frontend_to_mrt = frontend_module.from_frontend
-elif hasattr(frontend_module, "pytorch_to_mrt"):
-    frontend_to_mrt = frontend_module.pytorch_to_mrt
-elif hasattr(frontend_module, "expr2symbol"):
-    frontend_to_mrt = frontend_module.expr2symbol
-else:
-    print(f"Error: Required function 'frontend_to_mrt' not found in frontend '{FRONTEND}'")
-    sys.exit(1)
-
-# Try to get mrt_to_frontend function
-mrt_to_frontend = None
-if hasattr(frontend_module, "to_frontend"):
-    mrt_to_frontend = frontend_module.to_frontend
-elif hasattr(frontend_module, "mrt_to_pytorch"):
-    mrt_to_frontend = frontend_module.mrt_to_pytorch
-elif hasattr(frontend_module, "symbol2expr"):
-    mrt_to_frontend = frontend_module.symbol2expr
-else:
-    print(f"Error: Required function 'mrt_to_frontend' not found in frontend '{FRONTEND}'")
-    sys.exit(1)
+@mod.typedef_mod_function
+def run_executor(
+        executor,
+        data: typing.Optional[np.ndarray] = None,
+        data_dict: ParametersT = {}
+) -> OpNumpyT:
+    """ Apply data to executor. """
+    pass
+
+@mod.typedef_mod_function
+def infer(
+        graph: MultiHeadSymbol,
+        params: ParametersT,
+        data: typing.Optional[np.ndarray] = None,
+        data_dict: ParametersT = {},
+        device: str = "cpu",
+        **kwargs):
+    """ Convenient method to infer a model. """
+    pass
+
+@mod.typedef_mod_function
+def data_from_frontend(data: typing.Any) -> OpNumpyT:
+    """ Convert Frontend Tensor to MRT DType. """
+    pass
+
+@mod.typedef_mod_function
+def data_to_frontend(data: OpNumpyT):
+    """ Convert MRT DType to Frontend Tensor. """
+    pass
+
+@mod.typedef_mod_function
+def model_from_frontend(
+        fe_model,
+        func_names: typing.List[str] = [ "main", ]
+) -> typing.Tuple[MultiHeadSymbol, ParametersT]:
+    """ Convert Frontend Graph to MRT Symbol/Params. """
+    pass
+
+@mod.typedef_mod_function
+def model_to_frontend(graph: MultiHeadSymbol, params: ParametersT):
+    """ Convert MRT Symbol/Params to Frontend Graph. """
+    pass
+
+@mod.typedef_mod_function
+def type_infer(symbol: Symbol) -> Symbol:
+    """ Shape/DType Inference using the Frontend API. """
+
+
+FRONTEND = os.environ.get("FRONTEND", "pytorch")
+mod.load_mod(FRONTEND)
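
Note on the pattern above: each @mod.typedef_mod_function stub reserves a slot in mod._funcs, and load_mod later rebinds every slot to the same-named function exported by the selected frontend module. A self-contained sketch of the same mechanism, with a toy in-memory backend standing in for a real frontend module (all names below are illustrative, not part of this commit):

    from functools import wraps
    import types

    class DynamicModule:
        def __init__(self):
            self._funcs = {}

        def typedef(self, func):
            # Reserve a slot under the stub's name; the backend fills it later.
            self._funcs.setdefault(func.__name__, None)
            @wraps(func)
            def _impl(*args, **kwargs):
                impl = self._funcs[func.__name__]
                assert impl is not None, f"{func.__name__} not registered"
                return impl(*args, **kwargs)
            return _impl

        def load(self, backend):
            # Rebind every reserved slot to the backend's same-named function.
            for name in self._funcs:
                if hasattr(backend, name):
                    self._funcs[name] = getattr(backend, name)

    mod = DynamicModule()

    @mod.typedef
    def greet(name: str) -> str:
        """Stub: the loaded backend supplies the real implementation."""

    backend = types.SimpleNamespace(greet=lambda name: f"hello, {name}")
    mod.load(backend)
    print(greet("mrt"))  # -> hello, mrt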

python/mrt/frontend/pytorch/__init__.py (9 additions, 1 deletion)

@@ -7,9 +7,17 @@
 """
 
 from .converter import pytorch_to_mrt, mrt_to_pytorch, type_infer
+from .types import data_to_mrt, data_to_torch
+from .vm import create_executor, run_executor, infer
 
 # Expose the required functions for the frontend API
 from_frontend = pytorch_to_mrt
 to_frontend = mrt_to_pytorch
 
-__all__ = ["pytorch_to_mrt", "mrt_to_pytorch", "from_frontend", "to_frontend", "type_infer"]
+model_from_frontend = pytorch_to_mrt
+model_to_frontend = mrt_to_pytorch
+
+data_from_frontend = data_to_mrt    # frontend tensor -> MRT numpy
+data_to_frontend = data_to_torch    # MRT numpy -> frontend tensor
+
+# __all__ = ["pytorch_to_mrt", "mrt_to_pytorch", "from_frontend", "to_frontend", "type_infer"]
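
Since api.py resolves functions purely by name, adding another frontend amounts to a package exporting the same set of names. A hypothetical skeleton (an "onnx" frontend is only an example; none of these stubs exist in this commit):

    # python/mrt/frontend/onnx/__init__.py -- hypothetical, for illustration only

    def model_from_frontend(fe_model, func_names=["main"]):
        raise NotImplementedError("convert an ONNX graph to MRT symbol/params")

    def model_to_frontend(graph, params):
        raise NotImplementedError("convert MRT symbol/params to an ONNX graph")

    def data_from_frontend(data):
        raise NotImplementedError("frontend tensor -> MRT numpy")

    def data_to_frontend(data):
        raise NotImplementedError("MRT numpy -> frontend tensor")

    def type_infer(symbol):
        raise NotImplementedError("shape/dtype inference via the frontend")

    def create_executor(symbol, params, device="cpu", target=""):
        raise NotImplementedError

    def run_executor(executor, data=None, data_dict={}):
        raise NotImplementedError

    def infer(graph, params, data=None, data_dict={}, device="cpu", **kwargs):
        raise NotImplementedError

Selecting it would then be a matter of setting FRONTEND=onnx before the process starts, since api.py reads the variable at import time.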

python/mrt/frontend/pytorch/converter.py (2 additions, 5 deletions)

@@ -139,8 +139,9 @@ def pytorch_to_mrt(
         ep: torch.export.ExportedProgram,
         func_names: typing.List[str] = [ "main", ]
 ) -> typing.Tuple[MultiHeadSymbol, ParametersT]:
-    env: typing.Dict[fx.Node, Symbol] = {}
+    env: typing.Dict[torch.fx.Node, Symbol] = {}
 
+    assert isinstance(ep, torch.export.ExportedProgram), f"input not torch ExportedProgram, but {type(ep)}"
     param_vars, params = create_parameters(ep)
 
     def _retrieve_args(node):
@@ -163,8 +164,6 @@ def _retrieve_args(node):
         if "tensor_meta" in node.meta:
             meta_data = node.meta["tensor_meta"]
             shape = data_to_mrt(meta_data.shape)
-            # shape = [str(s) if isinstance(s, torch.SymInt) else int(s) \
-            #         for s in meta_data.shape]
             dtype = data_to_mrt(meta_data.dtype)
         # else:
         #     print(node.name, "has no tensor meta")
@@ -284,8 +283,6 @@ def _infer_single_op(sym: Symbol, env: typing.Dict[str, F.Tensor]) -> F.Tensor:
 
 def type_infer(symbol: Symbol) -> Symbol:
     """Infer shape and dtype for all symbols in the graph.
-
-    This function works by using torch.Tensor inference.
     """
     env: Dict[str, F.Tensor] = {}
 
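
The new assert narrows pytorch_to_mrt's contract to torch.export.ExportedProgram inputs. For reference, producing one uses the standard torch.export API (the module below is a made-up example):

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1.0

    # torch.export.export traces the module into an ExportedProgram,
    # which is now the only input type pytorch_to_mrt accepts.
    ep = torch.export.export(TinyModel(), args=(torch.randn(3, 224),))
    assert isinstance(ep, torch.export.ExportedProgram)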

python/mrt/frontend/pytorch/vm.py (2 additions, 1 deletion)

@@ -30,7 +30,8 @@ def run_executor(
         data_dict[k] = torch.from_numpy(v).to(device)
     if data is not None:
         data = torch.from_numpy(data).to(device)
-    out = vm(data, **data_dict)
+    with torch.no_grad():
+        out = vm(data, **data_dict)
     return data_to_mrt(out.detach().cpu())
 
 def infer(graph: MultiHeadSymbol, params: ParametersT,
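
The torch.no_grad() wrapper keeps the forward pass from recording autograd history, which saves memory during inference. A minimal illustration of the difference:

    import torch

    w = torch.randn(4, 4, requires_grad=True)
    x = torch.randn(4)

    y = w @ x                  # autograd history is recorded
    print(y.requires_grad)     # True

    with torch.no_grad():
        y = w @ x              # no history: lower memory, no .grad_fn
    print(y.requires_grad)     # False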

python/mrt/mir/optype.py (2 additions, 2 deletions)

@@ -2,6 +2,7 @@
 
 from mrt.common.types import *
 from mrt.common import config
+from mrt.frontend import api
 # from mrt.mir import op, opns
 # from mrt.symbol import Symbol, transform
 
@@ -10,7 +11,6 @@
 from .symbol import Symbol, transform
 
 InferTypeT = typing.Callable[[Symbol], Symbol]
-_DEFAULT_TYPE_INFER = None
 _INFER_TYPE_REG: typing.Dict[str, InferTypeT] = {}
 
 def register_type_infer(
@@ -29,7 +29,7 @@ def infer_single(symbol: Symbol) -> Symbol:
     C = config.LogConfig.G()
 
     out = op.retrieve_operator(symbol)
-    _infer = _INFER_TYPE_REG.get(out.op_name, _DEFAULT_TYPE_INFER)
+    _infer = _INFER_TYPE_REG.get(out.op_name, api.type_infer)
     assert _infer is not None
 
     if symbol.is_near(*C.log_type_infer):
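
The lookup in infer_single is a registry with a fallback: a per-op rule from _INFER_TYPE_REG wins, and anything unregistered is handed to the frontend's type_infer. Sketched in isolation (op names and stubs are illustrative):

    import typing

    _REG: typing.Dict[str, typing.Callable] = {}

    def register(op_name: str):
        def _wrap(fn):
            _REG[op_name] = fn
            return fn
        return _wrap

    def frontend_type_infer(sym):      # stand-in for api.type_infer
        return f"frontend-inferred {sym}"

    @register("nn.dense")
    def dense_infer(sym):
        return f"rule-inferred {sym}"

    # Registered op uses its rule; unknown ops fall back to the frontend.
    print(_REG.get("nn.dense", frontend_type_infer)("y"))    # rule-inferred y
    print(_REG.get("nn.conv2d", frontend_type_infer)("z"))   # frontend-inferred z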

tests/frontend/pytorch/test_resnet50.py (1 addition, 1 deletion)

@@ -162,4 +162,4 @@ def test_resnet50_type_infer():
     print("\n" + "="*60 + "\n")
 
 test_resnet50_infer()
-# test_resnet50_type_infer()
+test_resnet50_type_infer()

tests/frontend/test_frontend_loading.py (26 additions, 5 deletions)

@@ -9,20 +9,41 @@
 # Add the project root to the path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
 
+from mrt.mir.symbol import *
+from mrt.mir import op
+
 def test_frontend_loading():
     print("Testing frontend loading...")
-
+
+    x = op.variable("x", (3, 224), "float")
+    w = op.variable("w", (10, 224), "float")
+    y = op.nn_dense(x, w)
+
     # Test PyTorch frontend
     print("\n1. Testing PyTorch frontend:")
     os.environ["FRONTEND"] = "pytorch"
     try:
-        from mrt.frontend.api import FRONTEND, _DEFAULT_TYPE_INFER, frontend_to_mrt, mrt_to_frontend
+        from mrt.frontend.api import FRONTEND, model_from_frontend, model_to_frontend
         print(f" Loaded frontend: {FRONTEND}")
-        print(f" _DEFAULT_TYPE_INFER: {_DEFAULT_TYPE_INFER}")
-        print(f" frontend_to_mrt: {frontend_to_mrt}")
-        print(f" mrt_to_frontend: {mrt_to_frontend}")
+        print(f" model_from_frontend: {model_from_frontend}")
+        print(f" model_to_frontend: {model_to_frontend}")
+
+        import torch
+        fe_model = model_to_frontend(MultiHeadSymbol(main=y), {})
+        out = fe_model(torch.randn(3, 224), w=torch.rand(10, 224))
+        # out = fe_model(torch.randn(3, 224))
+        # print(out)
+        fe_model = torch.export.export(
+            fe_model,
+            args=(torch.randn(3, 224),),
+            kwargs={
+                "w": torch.randn(10, 224)})
+        # print(fe_model)
+        y1 = model_from_frontend(fe_model)
+
         print(" PyTorch frontend loaded successfully!")
     except Exception as e:
+        # raise e
         print(f" Error loading PyTorch frontend: {e}")
         return False
 