diff --git a/graph_net/torch/rp_expr/longest_rp_expr_parser.py b/graph_net/torch/rp_expr/longest_rp_expr_parser.py new file mode 100644 index 00000000..c6c6b17f --- /dev/null +++ b/graph_net/torch/rp_expr/longest_rp_expr_parser.py @@ -0,0 +1,252 @@ +import typing as t +from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser +from graph_net.torch.rp_expr.rp_expr import PrimitiveId, LetsListTokenRpExpr +import numpy as np +import sys + + +class LongestRpExprParser: + def __init__(self, max_window_size=1024, min_window_size=4): + self.max_window_size = max_window_size + self.min_window_size = min_window_size + + def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]): + fold_policy = "default" + rp_expr_parser = RpExprParser( + self.max_window_size, + fold_policy=fold_policy, + fold_times=1, + ) + lets_list_rp_expr, token_id2primitive_id = rp_expr_parser(primitive_id_lists) + for window_size in self._get_sub_window_sizes(): + rp_expr_parser = RpExprParser( + window_size, + fold_policy=fold_policy, + fold_times=1, + ) + cur_primitive_id_lists = [ + [token_id2primitive_id[token_id] for token_id in tensor.tolist()] + for tensor in lets_list_rp_expr.get_pure_primitive_binding_tensors( + token_id2primitive_id + ) + ] + cur_lets_list_rp_expr, cur_token_id2primitive_id = rp_expr_parser( + cur_primitive_id_lists + ) + # cur_lets_list_rp_expr.try_unwrap_body_of_sole_symbol_token() + lets_list_rp_expr = self._merge_lets_list_rp_expr( + inner=cur_lets_list_rp_expr, + outer=lets_list_rp_expr, + inner_token_id2primitive_id=cur_token_id2primitive_id, + outer_token_id2primitive_id=token_id2primitive_id, + ) + lets_list_rp_expr.try_recursive_inline_symbol_sole_used( + token_id2primitive_id=token_id2primitive_id + ) + # lets_list_rp_expr.try_unwrap_body_of_sole_symbol_token() + return lets_list_rp_expr, token_id2primitive_id + + def _merge_lets_list_rp_expr( + self, + inner, + outer, + inner_token_id2primitive_id, + outer_token_id2primitive_id, + ): + def get_inner_token_id2outer_token_id(): + primitive_id2outer_token_id = {} + for token_id, primitive_id in enumerate(outer_token_id2primitive_id): + assert primitive_id not in primitive_id2outer_token_id + primitive_id2outer_token_id[primitive_id] = token_id + return [ + primitive_id2outer_token_id[primitive_id] + for primitive_id in inner_token_id2primitive_id + ] + + kInner = "inner" + kOuter = "outer" + uid2new_symbol_token = self._make_uid2new_symbol_token_id( + inner=inner, + outer=outer, + inner_uid_prefix=kInner, + outer_uid_prefix=kOuter, + outer_primitive_table_size=len(outer_token_id2primitive_id), + ) + inner_symbol_token_ids = self._convert_symbol_token_ids( + symbol_token_ids=inner.symbol_token_ids, + new_token4old_token=( + lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"] + ), + ) + inner_token_id2outer_token_id = get_inner_token_id2outer_token_id() + inner_symbol_token_tensors = self._convert_token_tensors( + inner.symbol_token_tensors, + new_token4old_primitive_token=( + lambda old_token: inner_token_id2outer_token_id[old_token] + ), + new_token4old_symbol_token=( + lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"] + ), + primitive_ids_table_size=len(inner_token_id2primitive_id), + ) + + inner_body_rp_expr = self._convert_token_tensors( + inner.body_rp_expr, + new_token4old_primitive_token=( + lambda old_token: inner_token_id2outer_token_id[old_token] + ), + new_token4old_symbol_token=( + lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"] + ), + 
primitive_ids_table_size=len(inner_token_id2primitive_id), + ) + + inner_symbol_token2token_tensor = { + symbol_token: token_tensor + for symbol_token, token_tensor in zip( + inner_symbol_token_ids, inner_symbol_token_tensors + ) + } + + outer_symbol_token_tensors = self._convert_outer_symbol_binding_token_tensors( + inner_body_rp_expr=inner_body_rp_expr, + inner_symbol_token2token_tensor=inner_symbol_token2token_tensor, + outer_lets_list_rp_expr=outer, + new_token4old_primitive_token=lambda x: x, + new_token4old_symbol_token=( + lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"] + ), + outer_token_id2primitive_id=outer_token_id2primitive_id, + ) + + symbol_token_ids = inner_symbol_token_ids + self._convert_symbol_token_ids( + symbol_token_ids=outer.symbol_token_ids, + new_token4old_token=( + lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"] + ), + ) + + symbol_token_tensors = inner_symbol_token_tensors + outer_symbol_token_tensors + + body_rp_expr = self._convert_token_tensors( + outer.body_rp_expr, + new_token4old_primitive_token=lambda x: x, + new_token4old_symbol_token=( + lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"] + ), + primitive_ids_table_size=len(outer_token_id2primitive_id), + ) + ret_lets_list_token_rp_expr = LetsListTokenRpExpr( + symbol_token_ids=symbol_token_ids, + symbol_token_tensors=symbol_token_tensors, + body_rp_expr=body_rp_expr, + ) + ret_lets_list_token_rp_expr.move_pure_primitive_bindings_front( + outer_token_id2primitive_id + ) + return ret_lets_list_token_rp_expr + + def _convert_outer_symbol_binding_token_tensors( + self, + inner_body_rp_expr, + inner_symbol_token2token_tensor, + outer_lets_list_rp_expr, + new_token4old_primitive_token, + new_token4old_symbol_token, + outer_token_id2primitive_id, + ): + indexes = outer_lets_list_rp_expr.get_pure_primitive_binding_indexes( + outer_token_id2primitive_id + ) + assert len(inner_body_rp_expr) == len(indexes) + index2inner_body_rp_expr_idx = { + index: inner_body_rp_expr_idx + for inner_body_rp_expr_idx, index in enumerate(indexes) + } + old_tensors = outer_lets_list_rp_expr.symbol_token_tensors + return [ + ( + inner_body_rp_expr[index2inner_body_rp_expr_idx[index]] + if index in index2inner_body_rp_expr_idx + else self._convert_token_tensor( + tensor=old_tensors[index], + new_token4old_primitive_token=new_token4old_primitive_token, + new_token4old_symbol_token=new_token4old_symbol_token, + primitive_ids_table_size=len(outer_token_id2primitive_id), + ) + ) + for index in range(len(old_tensors)) + ] + + def _convert_token_tensors( + self, + tensors, + new_token4old_primitive_token, + new_token4old_symbol_token, + primitive_ids_table_size, + ): + return [ + self._convert_token_tensor( + tensor, + new_token4old_primitive_token, + new_token4old_symbol_token, + primitive_ids_table_size, + ) + for tensor in tensors + ] + + def _convert_token_tensor( + self, + tensor, + new_token4old_primitive_token, + new_token4old_symbol_token, + primitive_ids_table_size, + ): + return np.array( + [ + ( + new_token4old_primitive_token(token_id) + if token_id < primitive_ids_table_size + else new_token4old_symbol_token(token_id) + ) + for token_id in tensor.tolist() + ], + dtype=np.int64, + ) + + def _make_uid2new_symbol_token_id( + self, + inner, + outer, + inner_uid_prefix, + outer_uid_prefix, + outer_primitive_table_size, + ): + new_symbol_token_id = outer_primitive_table_size + + def get_new_symbol_token_id(): + nonlocal new_symbol_token_id + ret = new_symbol_token_id + new_symbol_token_id += 1 
+ return ret + + uid2new_symbol_token_id = {} + for inner_symbol_token_id in inner.symbol_token_ids: + uid = f"{inner_uid_prefix}{inner_symbol_token_id}" + uid2new_symbol_token_id[uid] = get_new_symbol_token_id() + for outer_symbol_token_id in outer.symbol_token_ids: + uid = f"{outer_uid_prefix}{outer_symbol_token_id}" + uid2new_symbol_token_id[uid] = get_new_symbol_token_id() + return uid2new_symbol_token_id + + def _convert_symbol_token_ids(self, symbol_token_ids, new_token4old_token): + return [ + new_token4old_token(symbol_token_id) for symbol_token_id in symbol_token_ids + ] + + def _get_sub_window_sizes(self): + min_window_size = max(1, self.min_window_size) + window_size = self.max_window_size // 2 + while window_size > min_window_size: + yield window_size + window_size = window_size // 2 diff --git a/graph_net/torch/rp_expr/nested_range.py b/graph_net/torch/rp_expr/nested_range.py new file mode 100644 index 00000000..9a0680a3 --- /dev/null +++ b/graph_net/torch/rp_expr/nested_range.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +import typing as t + + +@dataclass +class NestedRange: + pass + + +@dataclass +class Range(NestedRange): + start: int + end: int + + def FilterSubTreeRangeBySize(self, min_len: int, max_len: int): + length = self.end - self.start + if (length >= min_len) and (length < max_len): + yield (self.start, self.end) + + +@dataclass +class Tree(NestedRange): + uid: str + node: Range + children: t.List[NestedRange] + + def FilterSubTreeRangeBySize(self, min_len: int, max_len: int): + length = self.node.end - self.node.start + if length < min_len: + yield from () + elif length < max_len: + yield from self.node.FilterSubTreeRangeBySize(min_len, max_len) + else: + yield from ( + node_range + for child in self.children + for node_range in child.FilterSubTreeRangeBySize(min_len, max_len) + ) diff --git a/graph_net/torch/rp_expr/rp_expr.py b/graph_net/torch/rp_expr/rp_expr.py new file mode 100644 index 00000000..9621a3d7 --- /dev/null +++ b/graph_net/torch/rp_expr/rp_expr.py @@ -0,0 +1,476 @@ +from dataclasses import dataclass +import typing as t +import numpy as np +import torch +from collections import defaultdict +import functools + +PrimitiveId = t.TypeVar("PrimitiveId") + +TokenId = int + + +# Repeat Pattern Expression +@dataclass +class RpExpr: + pass + + +@dataclass +class ListRpExpr(RpExpr): + pass + + +@dataclass +class NaiveTokenListRpExpr(ListRpExpr): + tensors: t.List[np.ndarray["N", np.int64]] + + +@dataclass +class TokenizedRpExpr(RpExpr): + token_id2primitive_id: t.List[PrimitiveId] + token_tensors: ListRpExpr + + +@dataclass +class TokenRpExpr(RpExpr): + pass + + +@dataclass +class FlattenedTokenListRpExpr(ListRpExpr): + tensor_list_size: int + flattened_tensor: TokenRpExpr + + +@dataclass +class NaiveTokenRpExpr(TokenRpExpr): + tensor: np.ndarray["N", np.int64] + + +@dataclass +class LetsTokenRpExpr(TokenRpExpr): + symbol_token_ids: t.List[TokenId] + symbol_token_tensors: t.List[np.ndarray["N", np.int64]] + body_rp_expr: NaiveTokenRpExpr + + +@dataclass +class LetsListTokenRpExpr(TokenRpExpr): + symbol_token_ids: t.List[TokenId] + symbol_token_tensors: t.List[np.ndarray["N", np.int64]] + body_rp_expr: t.List[np.ndarray["N", np.int64]] + + def DebugStrings( + self, + token_id2primitive_id: t.List[PrimitiveId], + prefix="sequence", + end_of_line="", + ): + return self._DebugStrings( + token_id2primitive_id, + prefix=prefix, + end_of_line=end_of_line, + ) + + def get_pure_primitive_binding_indexes(self, token_id2primitive_id): + return 
self._get_pure_primitive_binding_indexes(token_id2primitive_id) + + def get_pure_primitive_binding_tensors(self, token_id2primitive_id): + return [ + self.symbol_token_tensors[index] + for index in self._get_pure_primitive_binding_indexes(token_id2primitive_id) + ] + + def inplace_group_consecutive_primitives(self, token_id2primitive_id): + return self._inplace_group_consecutive_primitives(token_id2primitive_id) + + def try_unwrap_body_of_sole_symbol_token(self): + return self._try_unwrap_body_of_sole_symbol_token() + + def try_recursive_inline_symbol_sole_used(self, token_id2primitive_id): + return self._try_recursive_inline_symbol_sole_used(token_id2primitive_id) + + def try_recursive_inline_symbol(self, token_id2primitive_id): + return self._try_recursive_inline_symbol(token_id2primitive_id) + + def _try_recursive_inline_symbol(self, token_id2primitive_id): + while self._try_inline_symbol(token_id2primitive_id): + pass + + def _try_inline_symbol(self, token_id2primitive_id): + pure_primitive_indexes = self.get_pure_primitive_binding_indexes( + token_id2primitive_id + ) + symbol_token2index = { + symbol_token: index + for index, symbol_token in enumerate(self.symbol_token_ids) + if index not in pure_primitive_indexes + } + symbol_token2used_count = { + symbol_token: 0 for symbol_token, _ in symbol_token2index.items() + } + for tensor in self.symbol_token_tensors: + for token in tensor.tolist(): + if token not in symbol_token2used_count: + continue + symbol_token2used_count[token] += 1 + found = False + symbol_token = None + symbol_index = None + symbol_tensor = None + for cur_symbol_token, used_count in symbol_token2used_count.items(): + if used_count >= 1: + found = True + symbol_token = cur_symbol_token + symbol_index = symbol_token2index[symbol_token] + symbol_tensor = self.symbol_token_tensors[symbol_index] + break + if not found: + return False + + def get_self_or_inlined(x): + if x == symbol_token: + yield from symbol_tensor.tolist() + else: + yield x + + def inline_list(lst): + return [x for token in lst for x in get_self_or_inlined(token)] + + def inline_tensor(tensor): + return np.array(inline_list(tensor.tolist()), dtype=np.int64) + + def inline_tensor_list(tensor_list): + return [inline_tensor(tensor) for tensor in tensor_list] + + self.symbol_token_ids.pop(symbol_index) + self.symbol_token_tensors.pop(symbol_index) + self.symbol_token_tensors = inline_tensor_list(self.symbol_token_tensors) + self.body_rp_expr = inline_tensor_list(self.body_rp_expr) + return True + + def _try_recursive_inline_symbol_sole_used(self, token_id2primitive_id): + while self._try_inline_symbol_sole_used(token_id2primitive_id): + pass + + def _try_inline_symbol_sole_used(self, token_id2primitive_id): + pure_primitive_indexes = self.get_pure_primitive_binding_indexes( + token_id2primitive_id + ) + symbol_token2index = { + symbol_token: index + for index, symbol_token in enumerate(self.symbol_token_ids) + if index not in pure_primitive_indexes + } + symbol_token2used_count = { + symbol_token: 0 for symbol_token, _ in symbol_token2index.items() + } + for tensor in self.symbol_token_tensors: + for token in tensor.tolist(): + if token not in symbol_token2used_count: + continue + symbol_token2used_count[token] += 1 + found = False + symbol_token = None + symbol_index = None + symbol_tensor = None + for cur_symbol_token, used_count in symbol_token2used_count.items(): + if used_count == 1: + found = True + symbol_token = cur_symbol_token + symbol_index = symbol_token2index[symbol_token] + symbol_tensor = 
self.symbol_token_tensors[symbol_index] + break + if not found: + return False + + def get_self_or_inlined(x): + if x == symbol_token: + yield from symbol_tensor.tolist() + else: + yield x + + def inline_list(lst): + return [x for token in lst for x in get_self_or_inlined(token)] + + def inline_tensor(tensor): + return np.array(inline_list(tensor.tolist()), dtype=np.int64) + + def inline_tensor_list(tensor_list): + return [inline_tensor(tensor) for tensor in tensor_list] + + self.symbol_token_ids.pop(symbol_index) + self.symbol_token_tensors.pop(symbol_index) + self.symbol_token_tensors = inline_tensor_list(self.symbol_token_tensors) + self.body_rp_expr = inline_tensor_list(self.body_rp_expr) + return True + + def _try_unwrap_body_of_sole_symbol_token(self): + symbol_token2symbol_tensor = { + symbol_token: symbol_tensor + for symbol_token, symbol_tensor in zip( + self.symbol_token_ids, self.symbol_token_tensors + ) + } + token2used_count = {} + for tensor in self.symbol_token_tensors + self.body_rp_expr: + for token in tensor.tolist(): + if token not in token2used_count: + token2used_count[token] = 1 + else: + token2used_count[token] += 1 + sole_symbol_token_body_indexes = [ + i + for i in range(len(self.body_rp_expr)) + for body_item in [self.body_rp_expr[i]] + if body_item.size == 1 + if body_item[0] in symbol_token2symbol_tensor + if token2used_count[body_item[0]] == 1 + ] + symbol_tokens_in_sole_symbol_body = [ + self.body_rp_expr[i][0] for i in sole_symbol_token_body_indexes + ] + symbol_token_and_symbol_tensors = [ + (symbol_token, symbol_tensor) + for symbol_token, symbol_tensor in zip( + self.symbol_token_ids, self.symbol_token_tensors + ) + if symbol_token not in symbol_tokens_in_sole_symbol_body + ] + self.symbol_token_ids = [x[0] for x in symbol_token_and_symbol_tensors] + self.symbol_token_tensors = [x[1] for x in symbol_token_and_symbol_tensors] + self.body_rp_expr = [ + ( + symbol_token2symbol_tensor[self.body_rp_expr[i][0]] + if i in sole_symbol_token_body_indexes + else self.body_rp_expr[i] + ) + for i in range(len(self.body_rp_expr)) + ] + + def move_pure_primitive_bindings_front(self, token_id2primitive_id): + return self._move_pure_primitive_bindings_front(token_id2primitive_id) + + def _move_pure_primitive_bindings_front(self, token_id2primitive_id): + indexes = self.get_pure_primitive_binding_indexes(token_id2primitive_id) + + def reorder(lst): + return [lst[i] for i in range(len(lst)) if i in indexes] + [ + lst[i] for i in range(len(lst)) if i not in indexes + ] + + self.symbol_token_ids = reorder(self.symbol_token_ids) + self.symbol_token_tensors = reorder(self.symbol_token_tensors) + + def _get_pure_primitive_binding_indexes(self, token_id2primitive_id): + primitive_table_size = len(token_id2primitive_id) + ret = [] + for i, tensor in enumerate(self.symbol_token_tensors): + primitive_splited_tensors = self._split_consecutive_primitive( + tensor, primitive_table_size + ) + if ( + len(primitive_splited_tensors) == 1 + and primitive_splited_tensors[0][0] < primitive_table_size + ): + ret.append(i) + return ret + + def _inplace_group_consecutive_primitives(self, token_id2primitive_id): + get_auto_symbol_token_id = self._getter_auto_symbol_token_id( + token_id2primitive_id + ) + primitive_table_size = len(token_id2primitive_id) + ( + primitives_new_token_ids_in_binding, + primitives_new_token_tensors_in_binding, + replaced_tensors_of_bindings, + ) = self._group_token_tensors_consecutive_primitives( + self.symbol_token_tensors, get_auto_symbol_token_id, 
primitive_table_size + ) + ( + primitives_new_token_ids_in_body, + primitives_new_token_tensors_in_body, + replaced_tensors_of_body, + ) = self._group_token_tensors_consecutive_primitives( + self.body_rp_expr, get_auto_symbol_token_id, primitive_table_size + ) + primitives_new_token_ids = ( + primitives_new_token_ids_in_binding + primitives_new_token_ids_in_body + ) + primitives_new_token_tensors = ( + primitives_new_token_tensors_in_binding + + primitives_new_token_tensors_in_body + ) + self.symbol_token_ids = primitives_new_token_ids + self.symbol_token_ids + self.symbol_token_tensors = ( + primitives_new_token_tensors + replaced_tensors_of_bindings + ) + self.body_rp_expr = replaced_tensors_of_body + + def _group_token_tensors_consecutive_primitives( + self, token_tensors, get_auto_symbol_token_id, primitive_table_size + ): + primitives_new_token_ids = [] + primitives_new_token_tensors = [] + ret_token_tensors = [] + for token_tensor in token_tensors: + ( + cur_primitives_new_token_ids, + cur_primitives_new_token_tensors, + cur_ret_token_tensor, + ) = self._group_consecutive_primitives( + token_tensor, get_auto_symbol_token_id, primitive_table_size + ) + primitives_new_token_ids += cur_primitives_new_token_ids + primitives_new_token_tensors += cur_primitives_new_token_tensors + ret_token_tensors.append(cur_ret_token_tensor) + return primitives_new_token_ids, primitives_new_token_tensors, ret_token_tensors + + def _group_consecutive_primitives( + self, token_tensor, get_auto_symbol_token_id, primitive_table_size + ): + primitive_splited_tensors = self._split_consecutive_primitive( + token_tensor, primitive_table_size + ) + if ( + len(primitive_splited_tensors) == 1 + and primitive_splited_tensors[0][0] < primitive_table_size + ): + return [], [], token_tensor + primitives_new_token_ids = [] + primitives_new_token_tensors = [] + ret_token_tensors = [] + for tensor in primitive_splited_tensors: + assert tensor.size > 0 + if tensor[0] < primitive_table_size: + new_token_id = get_auto_symbol_token_id() + primitives_new_token_ids.append(new_token_id) + primitives_new_token_tensors.append(tensor) + ret_token_tensors.append(np.array([new_token_id], dtype=np.int64)) + else: + ret_token_tensors.append(tensor) + ret_token_tensor = np.concatenate(ret_token_tensors) + return primitives_new_token_ids, primitives_new_token_tensors, ret_token_tensor + + def _split_consecutive_primitive(self, token_tensor, primitive_table_size): + is_primitive_tensor = token_tensor < primitive_table_size + consecutive_tensors = consecutive(is_primitive_tensor, stepsize=0) + global_start = 0 + + def get_range(size): + nonlocal global_start + start = global_start + end = start + size + global_start = end + return (start, end) + + return [ + token_tensor[start:end] + for consecutive_tensor in consecutive_tensors + for start, end in [get_range(consecutive_tensor.size)] + ] + + def _getter_auto_symbol_token_id(self, token_id2primitive_id): + start_token_id = len(token_id2primitive_id) + for token_id in self.symbol_token_ids: + start_token_id = max(token_id + 1, start_token_id) + + def get_new_symbol_token_id(): + nonlocal start_token_id + ret = start_token_id + start_token_id += 1 + return ret + + return get_new_symbol_token_id + + def _DebugStrings( + self, + token_id2primitive_id: t.List[PrimitiveId], + prefix="sequence", + end_of_line="", + ): + indexes = self.get_pure_primitive_binding_indexes(token_id2primitive_id) + pure_primitive_symbol_token_set = set( + symbol_token + for index in indexes + for symbol_token in 
[self.symbol_token_ids[index]] + ) + + def IsPrimitive(token_id): + return token_id < len(token_id2primitive_id) + + def SymbolToString(symbol_id): + return ( + f"{prefix}{symbol_id}" + if symbol_id in pure_primitive_symbol_token_set + else f"fold_{prefix}{symbol_id}" + ) + + def ValueToString(token_id): + if IsPrimitive(token_id): + return token_id2primitive_id[token_id] + return f"{SymbolToString(token_id)}()" + + yield from ( + pycode + for symbol_id, tensor in zip( + self.symbol_token_ids, self.symbol_token_tensors + ) + for token_ids in [tensor.tolist()] + for pycode in [ + f"def {SymbolToString(symbol_id)}():", + *[f" {ValueToString(x)}{end_of_line}" for x in token_ids], + "", + ] + ) + yield from [ + f"def main():", + *[ + f" {SymbolToString(int(t[0]))}(){end_of_line}" + for t in self.body_rp_expr + ], + ] + + +def consecutive(data, stepsize=1): + return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1) + + +class TokenIdAllocator: + def __init__(self, next_token_id: int = 0): + self.next_token_id = next_token_id + + def NewTokenId(self): + value = self.next_token_id + self.next_token_id += 1 + return value + + def NextTokenId(self): + return self.next_token_id + + def Skip(self, size): + self.next_token_id += size + + +def Tokenize( + primitive_id_lists: t.List[t.List[PrimitiveId]], +) -> t.Tuple[TokenizedRpExpr, TokenIdAllocator]: + token_id_allocator = TokenIdAllocator() + primitive_id2token_id = defaultdict(token_id_allocator.NewTokenId) + token_tensors = [ + torch.tensor( + [primitive_id2token_id[primitive_id] for primitive_id in primitive_id_list], + dtype=torch.int64, + ) + for primitive_id_list in primitive_id_lists + ] + token_id2primitive_id = [None] * len(primitive_id2token_id) + for primitive_id, token_id in primitive_id2token_id.items(): + token_id2primitive_id[token_id] = primitive_id + return ( + NaiveTokenListRpExpr(token_tensors), + token_id_allocator, + token_id2primitive_id, + ) diff --git a/graph_net/torch/rp_expr/rp_expr_parser.py b/graph_net/torch/rp_expr/rp_expr_parser.py new file mode 100644 index 00000000..0791622f --- /dev/null +++ b/graph_net/torch/rp_expr/rp_expr_parser.py @@ -0,0 +1,62 @@ +import typing as t +import numpy as np +from graph_net.torch.rp_expr.rp_expr import Tokenize, PrimitiveId, LetsListTokenRpExpr +from graph_net.torch.rp_expr.rp_expr_passes import ( + FlattenTokenListPass, + FoldTokensPass, + RecursiveFoldTokensPass, + FoldIfTokenIdGreatEqualPass, + UnflattenAndSubThresholdPass, +) + + +class RpExprParser: + def __init__(self, window_size=8, fold_policy="default", fold_times=None): + self.window_size = window_size + self.fold_policy = fold_policy + self.fold_times = fold_times + + def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]): + token_list, id_allocator, token_id2primitive_id = Tokenize(primitive_id_lists) + flatten_pass = FlattenTokenListPass(id_allocator) + success, flattened_rp_expr = flatten_pass(token_list) + assert success + fold_pass = RecursiveFoldTokensPass( + id_allocator, + self.window_size, + fold_policy=self.fold_policy, + fold_times=self.fold_times, + ) + success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor) + if not success: + primitive_id2token_id = { + primitive_id: token_id + for token_id, primitive_id in enumerate(token_id2primitive_id) + } + lets_list_token_rp_expr = LetsListTokenRpExpr( + symbol_token_ids=[], + symbol_token_tensors=[], + body_rp_expr=[ + np.array( + [ + primitive_id2token_id[primitive_id] + for primitive_id in primitive_ids + ], + dtype=np.int64, + ) + 
for primitive_ids in primitive_id_lists + ], + ) + return lets_list_token_rp_expr, token_id2primitive_id + assert success, f"{self.window_size=}, {self.fold_policy=}, {self.fold_times=}" + threshold = len(primitive_id_lists) + unflatten_pass = UnflattenAndSubThresholdPass( + id_allocator=id_allocator, + threshold_start_token_id=threshold, + ) + success, threshold_fold_rp_expr = unflatten_pass(fold_rp_expr) + assert success + threshold_fold_rp_expr.inplace_group_consecutive_primitives( + token_id2primitive_id + ) + return threshold_fold_rp_expr, token_id2primitive_id diff --git a/graph_net/torch/rp_expr/rp_expr_passes.py b/graph_net/torch/rp_expr/rp_expr_passes.py new file mode 100644 index 00000000..b904eb90 --- /dev/null +++ b/graph_net/torch/rp_expr/rp_expr_passes.py @@ -0,0 +1,378 @@ +from dataclasses import dataclass +import typing as t +import numpy as np +import re +import itertools +import torch +import torch.nn.functional as F +import math +from graph_net.torch.rp_expr.rp_expr import ( + TokenIdAllocator, + NaiveTokenListRpExpr, + FlattenedTokenListRpExpr, + NaiveTokenRpExpr, + LetsTokenRpExpr, + LetsListTokenRpExpr, +) +import itertools +import sys + + +class Pass: + pass + + +class FlattenTokenListPass(Pass): + def __init__(self, id_allocator: TokenIdAllocator): + self.id_allocator = id_allocator + + def __call__(self, token_tensors_rp_expr: NaiveTokenListRpExpr): + tensor_list_size = len(token_tensors_rp_expr.tensors) + self.id_allocator.Skip(tensor_list_size) + + def GetSepTensor(i): + if i == 0: + return [] + return [torch.tensor([i], dtype=torch.int64)] + + token_tensors = [ + tensor + for i, token_tensor in enumerate(token_tensors_rp_expr.tensors) + for tensor in GetSepTensor(i) + [token_tensor + tensor_list_size] + ] + return True, FlattenedTokenListRpExpr( + tensor_list_size=tensor_list_size, + flattened_tensor=NaiveTokenRpExpr( + tensor=torch.cat(token_tensors, dim=0), + ), + ) + + +class FoldTokensPass(Pass): + def __init__(self, id_allocator: TokenIdAllocator, window_size=8, policy="default"): + self.window_size = window_size + self.random_feature_size = 2 + self.id_allocator = id_allocator + size = id_allocator.NextTokenId() + torch.manual_seed(2024) + self.embedding = torch.empty( + size, self.random_feature_size, dtype=torch.float64 + ).uniform_(-1, 1) + self.embedding.requires_grad_(True) + self.policy = policy + + def __call__(self, token_tensor: NaiveTokenRpExpr): + input_tensor = token_tensor.tensor + raw_most_frequent_length, indexes = self.GetMostFrequentPatternLengthAndIndexes( + input_tensor + ) + most_frequent_length = self.GetAdaptivePatternLength( + raw_most_frequent_length, + indexes, + ) + # print(f"{self.window_size=}, {raw_most_frequent_length=}, {most_frequent_length=}") + new_token_id, replacement = self.Replace( + pattern_length=most_frequent_length, + indexes=indexes, + input_tensor=input_tensor, + ) + if new_token_id is None: + return False, token_tensor + start = indexes[0] + return True, LetsTokenRpExpr( + symbol_token_ids=[new_token_id], + symbol_token_tensors=[input_tensor[start : (start + most_frequent_length)]], + body_rp_expr=NaiveTokenRpExpr(tensor=replacement), + ) + + def GetAdaptivePatternLength(self, pattern_length, indexes): + indexes = indexes.numpy().tolist() + kLimit = self.window_size * 2 + while pattern_length > 1: + disjoint_range_starts = [ + start for start in self.GetDisjoint(pattern_length, indexes) + ] + if len(disjoint_range_starts) > 1: + break + if pattern_length > kLimit: + pattern_length = pattern_length // 2 + else: + 
pattern_length -= 1 + return pattern_length + + def Replace( + self, + pattern_length, + indexes, + input_tensor: np.ndarray["N", np.int64], + ) -> t.Tuple[bool, int, np.ndarray["N", np.int64]]: + num_tokens = input_tensor.shape[0] + if pattern_length == 1: + return None, input_tensor + assert indexes.shape[0] > 0 + disjoint_range_starts = [ + start + for start in self.GetDisjoint(pattern_length, indexes.numpy().tolist()) + ] + if len(disjoint_range_starts) <= 1: + return None, input_tensor + assert disjoint_range_starts[-1] + pattern_length <= num_tokens + first_start = disjoint_range_starts[0] + pattern_tensor = input_tensor[first_start : (first_start + pattern_length)] + segment_starts = ( + [0] + + [ + index + for start in disjoint_range_starts + for index in [start, start + pattern_length] + ] + + [num_tokens] + ) + uniqued_segment_starts = torch.unique(torch.tensor(segment_starts)) + segment_lengths = torch.diff(uniqued_segment_starts).numpy().tolist() + + new_token_id = self.id_allocator.NewTokenId() + new_token_tensor = torch.tensor([new_token_id], dtype=torch.int64) + + def ReplaceTensor(tensor): + if tensor.shape != pattern_tensor.shape: + return tensor + if bool(torch.all(tensor == pattern_tensor)): + return new_token_tensor + return tensor + + replaced_segment_tensors = [ + ReplaceTensor(tensor) + for tensor in torch.split(input_tensor, segment_lengths) + ] + output_tensor = torch.cat(replaced_segment_tensors) + return new_token_id, output_tensor + + def GetConv(self, num_tokens): + windows_size = min(num_tokens, self.window_size) + + def GetWeight(): + torch.manual_seed(2024) + weight = torch.empty( + windows_size, windows_size, dtype=torch.float64 + ).uniform_(-1, 1) + weight.requires_grad_(True) + weight = ( + torch.triu(weight) + .transpose(0, 1) + .reshape(windows_size, 1, windows_size) + ) + return weight + + conv_weight = torch.cat( + [GetWeight() for _ in range(self.random_feature_size)], dim=1 + ) + conv = lambda input: F.conv1d(input, conv_weight, padding=0) + return conv, windows_size + + def GetDisjoint(self, gap, indexes): + if len(indexes) == 0: + return + last = indexes[0] + yield last + for current in indexes: + if current >= (last + gap): + yield current + last = current + + def GetMostFrequentPatternLengthAndIndexes( + self, + token_tensor: np.ndarray["N", np.int64], + ): + conv, windows_size = self.GetConv(num_tokens=token_tensor.shape[0]) + input = torch.index_select(self.embedding, 0, token_tensor) + input.requires_grad_(True) + zeros = torch.zeros( + windows_size - 1, self.random_feature_size, dtype=torch.float64 + ) + input = torch.cat([input, zeros]) + input = input.reshape(1, -1, self.random_feature_size).transpose(1, 2) + y = conv(input) + y = y.reshape(windows_size, -1) + y_hash = y.view(torch.int64) + # `pattern_len_sub_1` means `pattern_length - 1` + pattern_len_sub_1 = ( + torch.arange(windows_size).reshape(-1, 1).expand(y_hash.shape) + ) + pattern_len_sub_1_and_hash = torch.cat( + [pattern_len_sub_1.reshape(-1, 1), y_hash.reshape(-1, 1)], dim=1 + ) + unique_pattern_len_sub_1_and_hash, counts = torch.unique( + pattern_len_sub_1_and_hash, dim=0, return_counts=True + ) + if self.policy == "default": + most_frequent_hash_idx = torch.argmax( + unique_pattern_len_sub_1_and_hash[:, 0] * (counts - 1) + ) + elif self.policy == "longest": + most_frequent_hash_idx = torch.argmax( + unique_pattern_len_sub_1_and_hash[:, 0] * (counts > 1).to(torch.int64) + ) + else: + assert False, f"policy {self.policy} not implemneted." 
+ + most_frequent_hash = int( + unique_pattern_len_sub_1_and_hash[most_frequent_hash_idx, 1] + ) + most_frequent_pattern_len_sub_1 = int( + unique_pattern_len_sub_1_and_hash[most_frequent_hash_idx, 0] + ) + indexes = torch.nonzero( + most_frequent_hash == y_hash[most_frequent_pattern_len_sub_1, :] + ).flatten() + return most_frequent_pattern_len_sub_1 + 1, indexes + + +class RecursiveFoldTokensPass(Pass): + def __init__( + self, + id_allocator: TokenIdAllocator, + window_size=8, + fold_policy="default", + fold_times=None, + ): + self.id_allocator = id_allocator + self.window_size = window_size + self.fold_policy = fold_policy + self.fold_times = fold_times if fold_times is not None else sys.maxsize + + def __call__(self, token_tensor: NaiveTokenRpExpr): + fold_pass = FoldTokensPass( + self.id_allocator, self.window_size, policy=self.fold_policy + ) + success, ret = fold_pass(token_tensor) + if not success: + return False, token_tensor + symbol_token_ids = ret.symbol_token_ids + symbol_token_tensors = ret.symbol_token_tensors + token_tensor = ret.body_rp_expr + counter = itertools.count() + kLimit = 9999999 + for _ in range(self.fold_times): + fold_pass = FoldTokensPass( + self.id_allocator, self.window_size, policy=self.fold_policy + ) + success, ret = fold_pass(token_tensor) + if not success: + token_tensor = ret + break + assert ret.body_rp_expr.tensor.shape[0] < token_tensor.tensor.shape[0] + if next(counter) > kLimit: + raise RuntimeError("dead loop detected.") + symbol_token_ids += ret.symbol_token_ids + symbol_token_tensors += ret.symbol_token_tensors + token_tensor = ret.body_rp_expr + return True, LetsTokenRpExpr( + symbol_token_ids=symbol_token_ids, + symbol_token_tensors=symbol_token_tensors, + body_rp_expr=token_tensor, + ) + + +class UnflattenAndSubThresholdPass(Pass): + def __init__( + self, + id_allocator: TokenIdAllocator, + threshold_start_token_id: int, + ): + self.id_allocator = id_allocator + self.threshold_start_token_id = threshold_start_token_id + + def __call__(self, lists_token_rp_expr: LetsTokenRpExpr): + threshold_fold_pass = FoldIfTokenIdGreatEqualPass( + id_allocator=self.id_allocator, + threshold_start_token_id=self.threshold_start_token_id, + ) + success, threshold_fold_rp_expr = threshold_fold_pass( + lists_token_rp_expr.body_rp_expr + ) + assert success + return True, self.MergeAndUnflatten( + lists_token_rp_expr, threshold_fold_rp_expr, self.threshold_start_token_id + ) + + def MergeAndUnflatten(self, fold_rp_expr, threshold_fold_rp_expr, threshold): + assert len(threshold_fold_rp_expr.body_rp_expr) == threshold + return LetsListTokenRpExpr( + symbol_token_ids=[ + x - threshold + for x in ( + fold_rp_expr.symbol_token_ids + + threshold_fold_rp_expr.symbol_token_ids + ) + ], + symbol_token_tensors=[ + x - threshold + for x in ( + fold_rp_expr.symbol_token_tensors + + threshold_fold_rp_expr.symbol_token_tensors + ) + ], + body_rp_expr=[x - threshold for x in threshold_fold_rp_expr.body_rp_expr], + ) + + +class FoldIfTokenIdGreatEqualPass(Pass): + def __init__( + self, + id_allocator: TokenIdAllocator, + threshold_start_token_id: int, + ): + self.id_allocator = id_allocator + self.threshold_start_token_id = threshold_start_token_id + + def __call__(self, token_rp_expr: NaiveTokenRpExpr): + indexes_ge_threshold = self.GetIndexesGeThreshold(token_rp_expr.tensor) + token_ids_ge_threshold = torch.index_select( + token_rp_expr.tensor, 0, indexes_ge_threshold + ) + consecutive_index_range_lengths = self.GetConsecutiveIndexRangeLengths( + 
indexes_ge_threshold=indexes_ge_threshold, + ) + tensors = torch.split(token_ids_ge_threshold, consecutive_index_range_lengths) + + def GetSymbolsValuesBodyTriple(tensor): + if tensor.shape[0] == 1: + return [], [], tensor + new_token_id = self.id_allocator.NewTokenId() + return ( + [new_token_id], + [tensor], + torch.tensor([new_token_id], dtype=torch.int64), + ) + + symbols_values_body_triples = [ + GetSymbolsValuesBodyTriple(tensor) for tensor in tensors + ] + return True, LetsListTokenRpExpr( + symbol_token_ids=[ + token_id + for new_token_ids, _, _ in symbols_values_body_triples + for token_id in new_token_ids + ], + symbol_token_tensors=[ + token_tensor + for _, token_tensors, _ in symbols_values_body_triples + for token_tensor in token_tensors + ], + body_rp_expr=[ + body_tensor for _, _, body_tensor in symbols_values_body_triples + ], + ) + + def GetIndexesGeThreshold(self, token_tensor: np.ndarray["N", np.int64]): + indexes = torch.nonzero(token_tensor >= self.threshold_start_token_id).flatten() + return indexes + + def GetConsecutiveIndexRangeLengths(self, indexes_ge_threshold): + groups = self.GetNumpyConsecutiveGroups(indexes_ge_threshold.numpy()) + return [group.shape[0] for group in groups] + + # reference: https://stackoverflow.com/questions/7352684/how-to-find-the-groups-of-consecutive-elements-in-a-numpy-array + def GetNumpyConsecutiveGroups(self, data, stepsize=1): + return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1) diff --git a/graph_net/torch/rp_expr/rp_expr_util.py b/graph_net/torch/rp_expr/rp_expr_util.py new file mode 100644 index 00000000..dbe20c5e --- /dev/null +++ b/graph_net/torch/rp_expr/rp_expr_util.py @@ -0,0 +1,68 @@ +import typing as t +from graph_net.torch.rp_expr.rp_expr import LetsListTokenRpExpr +from graph_net.torch.rp_expr.nested_range import Range, Tree +from collections import defaultdict +from dataclasses import dataclass + + +def MakeNestedIndexRangeFromLetsListTokenRpExpr( + rp_expr: LetsListTokenRpExpr, uid_prefix: str = "" +) -> t.List[Tree]: + symbol_token_tensors = [ + tuple(token_ids.tolist()) for token_ids in rp_expr.symbol_token_tensors + ] + token2len = defaultdict(lambda: 1) + token2children = defaultdict(lambda: ()) + for token, token_ids in zip(rp_expr.symbol_token_ids, symbol_token_tensors): + token2len[token] = sum(map(lambda x: token2len[x], token_ids)) + token2children[token] = token_ids + + for tensor in rp_expr.body_rp_expr: + assert len(tensor.shape) == 1 + assert tensor.shape[0] == 1 + + return [ + MakeNestedIndexRangeFromRangeTokenCtx( + offset=0, + root_token_id=root_token_id, + token2len=token2len, + token2children=token2children, + uid_prefix=uid_prefix, + ) + for tensor in rp_expr.body_rp_expr + for root_token_id in tensor.tolist() + ] + + +def MakeNestedIndexRangeFromRangeTokenCtx( + offset: int, + root_token_id: int, + token2len: t.Dict[int, int], + token2children: t.Dict[int, t.List[int]], + uid_prefix: str, +): + if len(token2children[root_token_id]) == 0: + assert token2len[root_token_id] == 1 + return Range(offset, offset + token2len[root_token_id]) + + def GetThenIncreaseChildOffset(length): + nonlocal offset + child_offset = offset + offset += length + return child_offset + + return Tree( + uid=f"{uid_prefix}{root_token_id}", + node=Range(offset, offset + token2len[root_token_id]), + children=[ + MakeNestedIndexRangeFromRangeTokenCtx( + offset=child_offfset, + root_token_id=child_token_id, + token2len=token2len, + token2children=token2children, + uid_prefix=uid_prefix, + ) + for child_token_id in 
token2children[root_token_id]
+            for child_offfset in [GetThenIncreaseChildOffset(token2len[child_token_id])]
+        ],
+    )
diff --git a/graph_net/torch/typical_sequence.py b/graph_net/torch/typical_sequence.py
new file mode 100644
index 00000000..ab591ea0
--- /dev/null
+++ b/graph_net/torch/typical_sequence.py
@@ -0,0 +1,210 @@
+"""
+Typical Sequence Extractor
+Identify repeated subgraph patterns in an extracted FX Graph and save them grouped by structure.
+
+python -m graph_net.torch.typical_sequence --model_path /path/to/model_path
+
+Input: a previously extracted whole-model graph, stored as 6 files in the model directory; the core files are model.py and weight_meta.py.
+Output: a subgraph/ subdirectory under the model directory containing a series of folders named typical_seq_<hash>_<instance_idx>; each folder holds the same 6-file structure.
+"""
+import argparse
+import os
+import sys
+import json
+import hashlib
+import ast
+import copy
+from pathlib import Path
+from typing import List, Dict, Tuple, Set, Any
+from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
+from graph_net.torch.rp_expr.rp_expr_util import (
+    MakeNestedIndexRangeFromLetsListTokenRpExpr,
+)
+
+
+def compute_subgraph_hash(subgraph: List[Tuple[str, str]]) -> str:
+    """
+    Generate a unique hash for a subgraph based on op types and topology (ignoring variable names).
+    """
+    return hashlib.sha256(...).hexdigest()[:8]
+
+
+def find_forward_function(module_ast: ast.AST) -> ast.FunctionDef:
+    """
+    Find the 'forward' function in the AST of the GraphModule class.
+    """
+    return None
+
+
+def extract_assignments_from_forward(
+    forward_func: ast.FunctionDef,
+) -> List[Tuple[str, str]]:
+    """
+    Extract all assignment statements from the forward function (simulating the FX Graph node sequence).
+    """
+    return
+
+
+def extract_used_parameters(subgraph: List[Tuple[str, str]]) -> List[str]:
+    """
+    Extract all parameter names used in the subgraph (e.g., L_self_modules_...).
+    """
+    return
+
+
+def extract_input_vars_from_subgraph(subgraph: List[Tuple[str, str]]) -> List[str]:
+    """
+    Extract input variables (those used but not defined in the subgraph).
+    """
+    return
+
+
+def generate_subgraph_model_py(subgraph: List[Tuple[str, str]]) -> str:
+    """
+    Generate model.py code for the subgraph.
+    Returns a simplified GraphModule containing only this subgraph's logic.
+    """
+    return
+
+
+def generate_input_meta_py(inputs: List[str], weight_meta_path: str) -> str:
+    """
+    Generate input_meta.py (reuses the structure from the original weight_meta).
+    """
+    return
+
+
+def generate_tensor_constraints_py(inputs: List[str], weight_meta_path: str) -> str:
+    """
+    Generate input_tensor_constraints.py.
+    """
+    return
+
+
+def generate_weight_meta_py(used_params: List[str], weight_meta_path: str) -> str:
+    """
+    Generate weight_meta.py (containing only the parameters used by the subgraph).
+    """
+    return
+
+
+def extract_typical_sequences(model_path: str, dynamic=True):
+    """
+    Extract typical subgraph sequences from an extracted model directory.
+
+    Args:
+        model_path (str): Model directory path, e.g., "/daiwenhao/graphnet_workspace/resnet18".
+        dynamic (bool): Enable dynamic shape support in torch.compile.
+ """ + + model_py_path = os.path.join(model_path, "model.py") + weight_meta_path = os.path.join(model_path, "weight_meta.py") + if not os.path.exists(model_py_path): + raise FileNotFoundError(f"Missing model.py: {model_py_path}") + if not os.path.exists(weight_meta_path): + raise FileNotFoundError(f"Missing weight_meta.py: {weight_meta_path}") + + # Parse model.py into AST and extract forward function of GraphModule + with open(model_py_path, "r", encoding="utf-8") as f: + source_code = f.read() + module_ast = ast.parse(source_code) + forward_func = find_forward_function(module_ast) + + # Extract assignment sequence (simulating FX Graph node flow) from AST + assignments = extract_assignments_from_forward(forward_func) + + # Use RpExprParser to split into candidate subgraph sequences (to be implemented) + rp_expr_parser = RpExprParser(window_size=64) + lets_list_rp_expr, token_id2primitive_id = rp_expr_parser(assignments) + print("\n".join(lets_list_rp_expr.DebugStrings(token_id2primitive_id))) + candidate_subgraphs = MakeNestedIndexRangeFromLetsListTokenRpExpr(lets_list_rp_expr) + + # Generate structural hash for each subgraph (ignore var names, focus on op types & topology) + subgraph_hashes = [] + for idx, subgraph in enumerate(candidate_subgraphs): + graph_hash = compute_subgraph_hash(subgraph) + subgraph_hashes.append((subgraph, graph_hash)) + + # Group by hash value, same structure belongs to one class + grouped_subgraphs: Dict[str, List[List[Tuple[str, str]]]] = {} + for subgraph, h in subgraph_hashes: + if h not in grouped_subgraphs: + grouped_subgraphs[h] = [] + grouped_subgraphs[h].append(subgraph) + + # Create output directory: graphnet_workspace/resnet18/subgraph/ + subgraph_output_dir = os.path.join(model_path, "subgraph") + os.makedirs(subgraph_output_dir, exist_ok=True) + + # 模仿整图目录结构: 创建 6 个文件 + # Create subgraph directories named typical_seq__ + for graph_hash, instances in grouped_subgraphs.items(): + for instance_idx, subgraph in enumerate(instances): + dir_name = f"typical_seq_{graph_hash}_{instance_idx}" + instance_dir = os.path.join(subgraph_output_dir, dir_name) + os.makedirs(instance_dir, exist_ok=True) + + # 1. Generate model.py for this subgraph instance + subgraph_code = generate_subgraph_model_py(subgraph) + with open( + os.path.join(instance_dir, "model.py"), "w", encoding="utf-8" + ) as f: + f.write(subgraph_code) + + # 2. Generate graph_net.json (metadata, marked as subgraph) + metadata = { + "framework": "torch", + "type": "subgraph", + "parent_model": os.path.basename(model_path), + "structure_hash": graph_hash, + "instance_id": instance_idx, + "dynamic": bool(dynamic), + } + with open( + os.path.join(instance_dir, "graph_net.json"), "w", encoding="utf-8" + ) as f: + json.dump(metadata, f, indent=4) + + # 3. Generate input_meta.py and input_tensor_constraints.py based on inputs + inputs = extract_input_vars_from_subgraph(subgraph) + input_meta_code = generate_input_meta_py(inputs, weight_meta_path) + with open( + os.path.join(instance_dir, "input_meta.py"), "w", encoding="utf-8" + ) as f: + f.write(input_meta_code) + + # 4. Generate input_tensor_constraints.py based on inputs + constraints_code = generate_tensor_constraints_py(inputs, weight_meta_path) + with open( + os.path.join(instance_dir, "input_tensor_constraints.py"), + "w", + encoding="utf-8", + ) as f: + f.write(constraints_code) + + # 5. 
Generate weight_meta.py (parameters used in this subgraph) + used_params = extract_used_parameters(subgraph) + weight_meta_code = generate_weight_meta_py(used_params, weight_meta_path) + with open( + os.path.join(instance_dir, "weight_meta.py"), "w", encoding="utf-8" + ) as f: + f.write(weight_meta_code) + + # 对于整图抽取来说,是 validate 以后得到的 hash + # 6. Optional: Generate graph hash file + with open(os.path.join(instance_dir, "graph_hash.txt"), "w") as f: + f.write(graph_hash + "\n") + + print(f"Subgraph extraction completed, results saved to...: {subgraph_output_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="resnet18") + args = parser.parse_args() + + extract_typical_sequences(args.model_path)
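
Usage sketch (illustrative, not part of the diff): a minimal way to exercise the new repeat-pattern parser on a toy op sequence. The op names and window size below are made-up placeholders; only RpExprParser and the DebugStrings output format come from the modules added above.

    from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser

    # A toy primitive-id sequence; the strings stand in for FX node signatures.
    ops = ["conv", "bn", "relu"] * 4 + ["pool"]
    parser = RpExprParser(window_size=8)
    lets_expr, token_id2primitive_id = parser([ops])
    # DebugStrings yields pseudo-Python showing how the repeated pattern was folded.
    print("\n".join(lets_expr.DebugStrings(token_id2primitive_id)))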