Skip to content

[Executorch][Export][3/N] Modularize export recipes #13057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions backends/xnnpack/recipes/xnnpack_recipe_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from executorch.export import (
BackendRecipeProvider,
ExportRecipe,
LoweringRecipe,
QuantizationRecipe,
RecipeType,
)
Expand Down Expand Up @@ -88,12 +89,19 @@ def create_recipe(
)
return None

def _get_xnnpack_lowering_recipe(
self, precision_type: Optional[ConfigPrecisionType] = None
) -> LoweringRecipe:
return LoweringRecipe(
partitioners=[XnnpackPartitioner(precision_type=precision_type)],
edge_compile_config=get_xnnpack_edge_compile_config(),
)

def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
return ExportRecipe(
name=recipe_type.value,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _build_quantized_recipe(
Expand All @@ -120,9 +128,8 @@ def _build_quantized_recipe(
return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
)

def _build_int8da_intx_weight_recipe(
Expand Down Expand Up @@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe(
return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
lowering_recipe=self._get_xnnpack_lowering_recipe(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
Expand Down
3 changes: 2 additions & 1 deletion export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
"""

from .export import export, ExportSession
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
from .recipe_provider import BackendRecipeProvider
from .recipe_registry import recipe_registry
from .types import StageType

__all__ = [
"StageType",
"ExportRecipe",
"LoweringRecipe",
"QuantizationRecipe",
"ExportSession",
"export",
Expand Down
21 changes: 8 additions & 13 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tabulate import tabulate
from torch import nn

from .recipe import ExportRecipe, QuantizationRecipe
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
from .stages import (
EdgeTransformAndLowerStage,
ExecutorchStage,
Expand Down Expand Up @@ -143,6 +143,10 @@ def __init__(
self._export_recipe.quantization_recipe
)

self._lowering_recipe: Optional[LoweringRecipe] = (
self._export_recipe.lowering_recipe
)

# Stages to run
self._pipeline_stages = (
self._export_recipe.pipeline_stages or self._get_default_pipeline()
Expand Down Expand Up @@ -192,20 +196,11 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
)
stage = TorchExportStage(pre_edge_passes)
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
stage = EdgeTransformAndLowerStage(
partitioners=self._export_recipe.partitioners,
transform_passes=self._export_recipe.edge_transform_passes,
compile_config=self._export_recipe.edge_compile_config,
)
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EDGE:
stage = ToEdgeStage(
edge_compile_config=self._export_recipe.edge_compile_config
)
stage = ToEdgeStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_BACKEND:
stage = ToBackendStage(
partitioners=self._export_recipe.partitioners,
transform_passes=self._export_recipe.edge_transform_passes,
)
stage = ToBackendStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EXECUTORCH:
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
else:
Expand Down
32 changes: 22 additions & 10 deletions export/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]:
return self.quantizers


@dataclass
class LoweringRecipe:
"""
Configuration recipe for lowering and partitioning.

This class holds the configuration parameters for lowering a model
to backend-specific representations.

Attributes:
partitioners: Optional list of partitioners for model partitioning
edge_transform_passes: Optional sequence of transformation passes to apply
edge_compile_config: Optional edge compilation configuration
"""

partitioners: Optional[List[Partitioner]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None


@experimental(
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
)
Expand All @@ -103,26 +123,18 @@ class ExportRecipe:
Attributes:
name: Optional name for the recipe
quantization_recipe: Optional quantization recipe for model quantization
edge_compile_config: Optional edge compilation configuration
pre_edge_transform_passes: Optional function to apply transformation passes
before edge lowering
edge_transform_passes: Optional sequence of transformation passes to apply
during edge lowering
transform_check_ir_validity: Whether to check IR validity during transformation
partitioners: Optional list of partitioners for model partitioning
lowering_recipe: Optional lowering recipe for model lowering and partitioning
executorch_backend_config: Optional backend configuration for ExecuTorch
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
mode: Export mode (debug or release)
"""

name: Optional[str] = None
quantization_recipe: Optional[QuantizationRecipe] = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
edge_transform_passes: Optional[Sequence[PassType]] = None
transform_check_ir_validity: bool = True
partitioners: Optional[List[Partitioner]] = None
lowering_recipe: Optional[LoweringRecipe] = None
# pyre-ignore[11]: Type not defined
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
pipeline_stages: Optional[List[StageType]] = None
Expand Down
36 changes: 35 additions & 1 deletion export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.program import to_edge, to_edge_transform_and_lower
from executorch.exir.program._program import _transform
from executorch.export.recipe import QuantizationRecipe
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.types import StageType
from torch import nn
from torch._export.pass_base import PassType
Expand Down Expand Up @@ -168,6 +168,19 @@ def __init__(
self._transform_passes = transform_passes
self._compile_config = compile_config

@classmethod
def from_recipe(
cls, lowering_recipe: Optional["LoweringRecipe"]
) -> "EdgeTransformAndLowerStage":
if lowering_recipe is None:
return cls()

return cls(
partitioners=lowering_recipe.partitioners,
transform_passes=lowering_recipe.edge_transform_passes,
compile_config=lowering_recipe.edge_compile_config,
)

@property
def stage_type(self) -> str:
return StageType.TO_EDGE_TRANSFORM_AND_LOWER
Expand Down Expand Up @@ -369,6 +382,15 @@ def __init__(
super().__init__()
self._edge_compile_config = edge_compile_config

@classmethod
def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStage":
if lowering_recipe is None:
return cls()

return cls(
edge_compile_config=lowering_recipe.edge_compile_config,
)

@property
def stage_type(self) -> str:
return StageType.TO_EDGE
Expand Down Expand Up @@ -415,6 +437,18 @@ def __init__(
self._partitioners = partitioners
self._transform_passes = transform_passes

@classmethod
def from_recipe(
cls, lowering_recipe: Optional["LoweringRecipe"]
) -> "ToBackendStage":
if lowering_recipe is None:
return cls()

return cls(
partitioners=lowering_recipe.partitioners,
transform_passes=lowering_recipe.edge_transform_passes,
)

@property
def stage_type(self) -> str:
return StageType.TO_BACKEND
Expand Down
46 changes: 46 additions & 0 deletions export/tests/test_export_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
from executorch.export import ExportRecipe, ExportSession
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.stages import PipelineArtifact
from executorch.export.types import StageType

Expand Down Expand Up @@ -434,3 +435,48 @@ def test_save_to_pte_invalid_name(self) -> None:

with self.assertRaises(AssertionError):
session.save_to_pte(None) # pyre-ignore


class TestExportSessionPipelineBuilding(unittest.TestCase):
"""Test pipeline building and stage configuration."""

def setUp(self) -> None:
self.model = SimpleTestModel()
self.example_inputs = [(torch.randn(2, 10),)]

def test_pipeline_building_with_all_recipes(self) -> None:
"""Test pipeline building with quantization and lowering recipes."""
# Create comprehensive recipes
quant_recipe = QuantizationRecipe(
ao_base_config=[Mock()],
quantizers=[Mock()],
)
lowering_recipe = LoweringRecipe(
partitioners=[Mock()],
edge_transform_passes=[Mock()],
edge_compile_config=Mock(),
)
recipe = ExportRecipe(
name="comprehensive_test",
quantization_recipe=quant_recipe,
lowering_recipe=lowering_recipe,
executorch_backend_config=Mock(),
)

session = ExportSession(
model=self.model,
example_inputs=self.example_inputs,
export_recipe=recipe,
)

registered_stages = session.get_all_registered_stages()

self.assertEqual(len(registered_stages), 5)
expected_types = [
StageType.SOURCE_TRANSFORM,
StageType.QUANTIZE,
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.TO_EXECUTORCH,
]
self.assertListEqual(list(registered_stages.keys()), expected_types)
Loading