diff --git a/shardy/integrations/python/jax/mpmd/__init__.py b/shardy/integrations/python/jax/mpmd/__init__.py index d5dee7a2..f6fb3379 100644 --- a/shardy/integrations/python/jax/mpmd/__init__.py +++ b/shardy/integrations/python/jax/mpmd/__init__.py @@ -24,14 +24,14 @@ from shardy.integrations.python.jax.mpmd.ops import named_computation from shardy.integrations.python.jax.mpmd.ops import named_tensor from shardy.integrations.python.jax.mpmd.ops import reduce +from shardy.integrations.python.jax.mpmd.pipeline import FragmentInfo +from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRule +from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRules +from shardy.integrations.python.jax.mpmd.pipeline import FragmentOrigin from shardy.integrations.python.jax.mpmd.stages import MpmdCompiled as Compiled from shardy.integrations.python.jax.mpmd.stages import MpmdExecutable as Executable from shardy.integrations.python.jax.mpmd.stages import MpmdJitShardingInfo from shardy.integrations.python.jax.mpmd.stages import MpmdLowered as Lowered -from shardy.integrations.python.jax.mpmd.types import FragmentInfo -from shardy.integrations.python.jax.mpmd.types import FragmentMergeRule -from shardy.integrations.python.jax.mpmd.types import FragmentMergeRules -from shardy.integrations.python.jax.mpmd.types import FragmentOrigin from shardy.integrations.python.jax.mpmd.types import FunctionIOMeshAssignment from shardy.integrations.python.jax.mpmd.types import make_config from shardy.integrations.python.jax.mpmd.types import mesh_names diff --git a/shardy/integrations/python/jax/mpmd/jaxlib_utils.py b/shardy/integrations/python/jax/mpmd/jaxlib_utils.py index 5bbaf028..22e5a7ac 100644 --- a/shardy/integrations/python/jax/mpmd/jaxlib_utils.py +++ b/shardy/integrations/python/jax/mpmd/jaxlib_utils.py @@ -12,22 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ==============================================================================
 """Utilities for converting between Python types and jaxlib pybind types."""
 
 from jaxlib import _sdy_mpmd as jaxlib_mpmd
 
-from shardy.integrations.python.jax.mpmd import types
+from shardy.integrations.python.jax.mpmd import pipeline
 
 
 def _to_jaxlib_split_type(
-    split_type: types.SplitFragmentType | None,
+    split_type: pipeline.SplitFragmentType | None,
 ) -> jaxlib_mpmd.SplitFragmentType | None:
   """Convert native Python enum to pybinded enum."""
   if split_type is None:
     return None
-  if split_type == types.SplitFragmentType.KEEP_TRANSFERRED:
+  if split_type == pipeline.SplitFragmentType.KEEP_TRANSFERRED:
     return jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED
-  elif split_type == types.SplitFragmentType.DROP_TRANSFERRED:
+  elif split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED:
     return jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED
   else:
     raise ValueError(f'Unknown SplitFragmentType: {split_type}')
@@ -35,20 +35,20 @@ def _to_jaxlib_split_type(
 
 def _from_jaxlib_split_type(
     split_type: jaxlib_mpmd.SplitFragmentType | None,
-) -> types.SplitFragmentType | None:
+) -> pipeline.SplitFragmentType | None:
   """Convert pybinded enum to native Python enum."""
   if split_type is None:
     return None
   if split_type == jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED:
-    return types.SplitFragmentType.KEEP_TRANSFERRED
+    return pipeline.SplitFragmentType.KEEP_TRANSFERRED
   elif split_type == jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED:
-    return types.SplitFragmentType.DROP_TRANSFERRED
+    return pipeline.SplitFragmentType.DROP_TRANSFERRED
   else:
     raise ValueError(f'Unknown jaxlib_mpmd.SplitFragmentType: {split_type}')
 
 
 def convert_fragment_info_to_pybind(
-    fragment: types.FragmentInfo,
+    fragment: pipeline.FragmentInfo,
 ) -> jaxlib_mpmd.FragmentInfo:
   """Converts FragmentInfo to jaxlib_mpmd.FragmentInfo."""
   return jaxlib_mpmd.FragmentInfo(
@@ -67,11 +67,13 @@ def convert_fragment_info_to_pybind(
 
 def convert_pybind_fragment_info_to_types(
     fragment: jaxlib_mpmd.FragmentInfo,
-) -> types.FragmentInfo:
+) -> pipeline.FragmentInfo:
   """Converts jaxlib_mpmd.FragmentInfo to FragmentInfo."""
-  return types.FragmentInfo(
+  return pipeline.FragmentInfo(
       origins=tuple(
-          types.FragmentOrigin(origin.computation_name, origin.transpose_count)
+          pipeline.FragmentOrigin(
+              origin.computation_name, origin.transpose_count
+          )
           for origin in fragment.origins
       ),
       stage_id=fragment.stage_id,
@@ -82,7 +84,7 @@
 
 def convert_fragment_merge_rules_to_pybind(
-    fragment_merge_rules: types.FragmentMergeRules,
+    fragment_merge_rules: pipeline.FragmentMergeRules,
 ) -> list[jaxlib_mpmd.FragmentMergeRule]:
   """Converts fragment merge rules to jaxlib_mpmd.FragmentMergeRules."""
   pybind_fragment_merge_rules = []
@@ -100,7 +102,7 @@
 
 def convert_fragment_schedule_rules_to_pybind(
-    fragment_schedule_rules: types.FragmentScheduleRules,
+    fragment_schedule_rules: pipeline.FragmentScheduleRules,
 ) -> list[jaxlib_mpmd.FragmentScheduleRule]:
   """Converts fragment schedule rules to jaxlib_mpmd.FragmentScheduleRules."""
   pybind_fragment_schedule_rules = []
diff --git a/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py b/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py
index 8ef6ef9c..1fb6789e 100644
--- a/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py
+++ 
b/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py @@ -19,7 +19,7 @@ from jaxlib import _sdy_mpmd as jaxlib_mpmd from shardy.integrations.python.jax.mpmd import jaxlib_utils -from shardy.integrations.python.jax.mpmd import types +from shardy.integrations.python.jax.mpmd import pipeline class SplitTypeConversionTest(parameterized.TestCase): @@ -28,12 +28,12 @@ class SplitTypeConversionTest(parameterized.TestCase): @parameterized.named_parameters( ( 'keep_transferred', - types.SplitFragmentType.KEEP_TRANSFERRED, + pipeline.SplitFragmentType.KEEP_TRANSFERRED, jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED, ), ( 'drop_transferred', - types.SplitFragmentType.DROP_TRANSFERRED, + pipeline.SplitFragmentType.DROP_TRANSFERRED, jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED, ), ('none', None, None), @@ -52,27 +52,27 @@ class FragmentInfoConversionTest(parameterized.TestCase): @parameterized.named_parameters( ( 'single_origin', - types.FragmentInfo( - origins=(types.FragmentOrigin('comp1', 0),), mesh_name='mesh1' + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('comp1', 0),), mesh_name='mesh1' ), ), ( 'multiple_origins', - types.FragmentInfo( + pipeline.FragmentInfo( origins=( - types.FragmentOrigin('comp1', 0), - types.FragmentOrigin('comp2', 1), + pipeline.FragmentOrigin('comp1', 0), + pipeline.FragmentOrigin('comp2', 1), ), mesh_name='mesh1', ), ), ( 'all_fields', - types.FragmentInfo( - origins=(types.FragmentOrigin('comp1', 2),), + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('comp1', 2),), stage_id=5, call_counter=3, - split_type=types.SplitFragmentType.KEEP_TRANSFERRED, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, mesh_name='mesh2', ), ), @@ -89,18 +89,21 @@ class FragmentMergeRulesConversionTest(absltest.TestCase): def test_single_rule(self): """Test converting single merge rule.""" - f1 = types.FragmentInfo( - origins=(types.FragmentOrigin('f1', 0),), mesh_name='m1' + f1 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('f1', 0),), mesh_name='m1' ) - f2 = types.FragmentInfo( - origins=(types.FragmentOrigin('f2', 0),), mesh_name='m1' + f2 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('f2', 0),), mesh_name='m1' ) - target = types.FragmentInfo( - origins=(types.FragmentOrigin('f1', 0), types.FragmentOrigin('f2', 0)), + target = pipeline.FragmentInfo( + origins=( + pipeline.FragmentOrigin('f1', 0), + pipeline.FragmentOrigin('f2', 0), + ), mesh_name='m1', ) - rule = types.FragmentMergeRule(sources={f1, f2}, target=target) + rule = pipeline.FragmentMergeRule(sources={f1, f2}, target=target) result = jaxlib_utils.convert_fragment_merge_rules_to_pybind([rule]) self.assertLen(result, 1) @@ -114,18 +117,18 @@ class FragmentScheduleRulesConversionTest(absltest.TestCase): def test_preserves_order(self): """Test that ordered_fragments order is preserved.""" frags = [ - types.FragmentInfo( - origins=(types.FragmentOrigin('first', 0),), mesh_name='m1' + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('first', 0),), mesh_name='m1' ), - types.FragmentInfo( - origins=(types.FragmentOrigin('second', 0),), mesh_name='m1' + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('second', 0),), mesh_name='m1' ), - types.FragmentInfo( - origins=(types.FragmentOrigin('third', 0),), mesh_name='m1' + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('third', 0),), mesh_name='m1' ), ] - rule = types.FragmentScheduleRule(ordered_fragments=frags) + rule = pipeline.FragmentScheduleRule(ordered_fragments=frags) result = 
jaxlib_utils.convert_fragment_schedule_rules_to_pybind([rule])
 
     self.assertEqual(
diff --git a/shardy/integrations/python/jax/mpmd/pipeline.py b/shardy/integrations/python/jax/mpmd/pipeline.py
new file mode 100644
index 00000000..170feb8c
--- /dev/null
+++ b/shardy/integrations/python/jax/mpmd/pipeline.py
@@ -0,0 +1,443 @@
+# Copyright 2025 The MPMD Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Core data structures and helper functions for MPMD pipeline scheduling.
+
+The primary entry point for defining a schedule is the `PipelineSchedule`
+object, which uses rule builders to determine the execution order and
+merging of a program's computation fragments. Rule builders take lists of
+fragments and build concrete scheduling/merging rules.
+
+There are two main approaches to defining pipeline schedules:
+
+1. Predicate-based approach (recommended for simple patterns):
+   Use binary predicates with helper functions to automatically generate rules.
+   `schedule_impl.py` contains implementations of common schedules using these
+   predicates.
+
+2. Direct construction (for complex custom schedules):
+   Explicitly build execution order and merge rules for full control.
+
+This is best shown through example: see `pipeline_test.py` for concrete
+PipelineSchedule definitions using both approaches.
+"""
+
+import collections
+from collections.abc import Collection, Mapping, Sequence, Set
+import dataclasses
+import enum
+from typing import Callable
+
+FragmentMergeRules = Sequence['FragmentMergeRule']
+FragmentScheduleRules = Sequence['FragmentScheduleRule']
+
+# Function that constructs a target FragmentInfo from a sequence of source
+# fragments that will be merged together into the target.
+TargetInfoBuilder = Callable[[Sequence['FragmentInfo']], 'FragmentInfo']
+
+# Function that builds schedule and/or merge rules from fragments and pipeline
+# context.
+ScheduleMergeRuleBuilder = Callable[
+    [Sequence['FragmentInfo'], 'PipelineContext'],
+    tuple[FragmentScheduleRules, FragmentMergeRules],
+]
+
+# Binary predicate determining if two fragments should be merged or scheduled
+# together.
+RuleGeneratorPredicate = Callable[
+    ['FragmentInfo', 'FragmentInfo', 'PipelineContext'], bool
+]
+
+
+@dataclasses.dataclass(frozen=True)
+class FragmentOrigin:
+  """The origin of a fragment."""
+
+  computation_name: str
+  transpose_count: int = 0
+
+
+@enum.unique
+class SplitFragmentType(enum.Enum):
+  """Fragment split behavior for transferred data.
+
+  These values indicate how fragment portions handle transferred data from
+  the original fragment if the fragment is split during compilation:
+  - KEEP_TRANSFERRED: Fragment portion retains transferred data
+  - DROP_TRANSFERRED: Fragment portion drops transferred data
+  """
+
+  KEEP_TRANSFERRED = enum.auto()
+  DROP_TRANSFERRED = enum.auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class FragmentInfo:
+  """A fragment of a computation."""
+
+  origins: tuple[FragmentOrigin, ...]
+ stage_id: int | None = None + call_counter: int | None = None + split_type: SplitFragmentType | None = None + mesh_name: str = '' + + +def validate_fragment_rule_origins( + fragment_collection: Collection[FragmentInfo], +) -> None: + """Validates that all fragments have at least one origin.""" + for fragment in fragment_collection: + if not fragment.origins: + raise ValueError( + f'Each fragment must have at least one origin, but got {fragment} in' + f' {fragment_collection}.' + ) + + +def validate_fragment_rule_meshes( + fragment_collection: Collection[FragmentInfo], +) -> None: + """Validates that all fragments are on the same mesh.""" + first_fragment = next(iter(fragment_collection)) + first_mesh = first_fragment.mesh_name + if not all( + fragment.mesh_name == first_mesh for fragment in fragment_collection + ): + raise ValueError( + 'Fragments being merged/scheduled must be on the same mesh, but got' + f' {fragment_collection}.' + ) + + +@dataclasses.dataclass(frozen=True) +class FragmentMergeRule: + """A rule for merging fragments of a computation. + + Attributes: + sources: The source fragments to be merged. The order does not affect the + final position of the merged fragment. + target: The target fragment metadata that results from merging the sources. + """ + + sources: Set[FragmentInfo] + target: FragmentInfo + + def __post_init__(self): + # Validate the fragment merge rule. + if len(self.sources) < 2: + raise ValueError( + 'FragmentMergeRule must contain at least 2 source fragments, but got' + f' {self}.' + ) + validate_fragment_rule_origins(self.sources) + validate_fragment_rule_meshes(self.sources) + + if not self.target.origins: + raise ValueError( + f'Target fragment must have at least one origin, but got {self}.' + ) + + +@dataclasses.dataclass(frozen=True) +class FragmentScheduleRule: + """A rule for scheduling fragments in a specific execution order. + + Attributes: + ordered_fragments: Fragments in the order they should execute. Must contain + at least 2 fragments, and all fragments must be on the same mesh. + """ + + ordered_fragments: Sequence[FragmentInfo] + + def __post_init__(self): + # Validate the fragment schedule rule. + if len(self.ordered_fragments) < 2: + raise ValueError( + 'FragmentScheduleRule must contain at least 2 fragments, but got' + f' {self}.' + ) + validate_fragment_rule_origins(self.ordered_fragments) + validate_fragment_rule_meshes(self.ordered_fragments) + + +@dataclasses.dataclass(frozen=True) +class PipelineContext: + """Context for pipeline scheduling and merging predicates.""" + + num_meshes: int + + +@dataclasses.dataclass(frozen=True) +class PipelineSchedule: + """A set of rules and options which define an MPMD pipeline. + + Attributes: + schedule_merge_rule_builders: A sequence of functions that build schedule + and/or merge rules for fragments. + required_mpmd_options: A mapping of PartitioningEnvironment flags that are + required for this schedule to function correctly. See + `partitioning_options.py` for available options. Relevant options + include: + + `mpmd_split_bwd_fragments`: Set to True to split backward fragments into + separate weight gradient and activation gradient fragments. This enables + independent scheduling of weight and activation gradients. + + `mpmd_merge_inferred_after_scheduling`: Set to True to defer merging of + inferred fragments until after scheduling. If False (default), inferred + fragments are merged before scheduling, which may create unintended data + dependencies that constrain your scheduling order. 
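+
+  Example (a minimal sketch; `my_before_pred` is a hypothetical
+  RuleGeneratorPredicate supplied by the caller, not part of this module):
+
+    schedule = PipelineSchedule(
+        schedule_merge_rule_builders=[
+            functools.partial(
+                build_schedule_rules_from_predicate,
+                before_pred=my_before_pred,
+            )
+        ],
+        required_mpmd_options={'mpmd_split_bwd_fragments': True},
+    )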
+ """ + + schedule_merge_rule_builders: Sequence[ScheduleMergeRuleBuilder] | None = None + required_mpmd_options: Mapping[str, bool | str] | None = None + + +def fragment_origins_contain(fragment: FragmentInfo, substring: str) -> bool: + """Checks if any computation name in fragment origins contains the substring.""" + return any( + substring in origin.computation_name for origin in fragment.origins + ) + + +def build_schedule_rules_from_predicate( + fragment_infos: Sequence[FragmentInfo], + context: PipelineContext, + *, + before_pred: RuleGeneratorPredicate, +) -> tuple[FragmentScheduleRules, FragmentMergeRules]: + """Builds a list of scheduling rules using a binary predicate function.""" + res = [] + for i, a in enumerate(fragment_infos): + for j, b in enumerate(fragment_infos): + if i == j: + continue + if a.mesh_name != b.mesh_name: + continue + + if before_pred(a, b, context): + res.append(FragmentScheduleRule(ordered_fragments=[a, b])) + return res, [] + + +def union_fragment_origins( + source_fragments: Sequence[FragmentInfo], +) -> tuple[FragmentOrigin, ...]: + """Union all origins from a sequence of fragment infos.""" + merged_origins = [] + seen_origins = set() + for fragment in source_fragments: + for origin in fragment.origins: + origin_key = (origin.computation_name, origin.transpose_count) + if origin_key not in seen_origins: + merged_origins.append(origin) + seen_origins.add(origin_key) + return tuple(merged_origins) + + +def _minimal_create_target_info( + source_fragments: Sequence[FragmentInfo], +) -> FragmentInfo: + """Creates a target FragmentInfo based on a sequence of source FragmentInfos. + + FragmentMergeRule takes in a FragmentInfo which describes the final fragment + metadata after all sources have been merged. This functions creates a target + info with the minimal amount of information needed to create this target + FragmentInfo. + + Args: + source_fragments: List of source fragment infos to create target info from. + + Returns: + FragmentInfo object representing the target fragment info. + + Raises: + ValueError: If `source_fragments` is empty or fragments have inconsistent + `mesh_name` values. + """ + if not source_fragments: + raise ValueError( + 'Cannot create target info from empty source fragments sequence' + ) + + mesh_name = source_fragments[0].mesh_name + for fragment in source_fragments: + if fragment.mesh_name != mesh_name: + raise ValueError( + f'Inconsistent mesh_name values: {mesh_name} vs {fragment.mesh_name}' + ) + + return FragmentInfo( + origins=union_fragment_origins(source_fragments), + stage_id=None, + call_counter=None, + split_type=None, + mesh_name=mesh_name, + ) + + +def build_merge_rules_from_predicate( + fragment_infos: Sequence[FragmentInfo], + context: PipelineContext, + target_info_builder: TargetInfoBuilder = _minimal_create_target_info, + *, + pred: RuleGeneratorPredicate, +) -> tuple[FragmentScheduleRules, FragmentMergeRules]: + """Creates a list of fragment merge rules based on a binary predicate. + + Args: + fragment_infos: List of fragments to create merge rules for. + context: PipelineContext object containing additional context for the + scheduling and merging process. + target_info_builder: Function that creates a target fragment info based on + on a list of source fragment infos. Defaults to create_target_info. + pred: Binary predicate function that determines if fragments should be + merged. + + Returns: + Tuple of (schedule_rules, merge_rules) where schedule_rules is empty. 
+ """ + merge_rules = [] + for i, fragment_a in enumerate(fragment_infos): + # Order of fragments should not matter for merge rules, so we can skip + # checking pairs of fragments that have already been checked. + for fragment_b in fragment_infos[i + 1 :]: + if fragment_a.mesh_name != fragment_b.mesh_name: + continue + + if pred(fragment_a, fragment_b, context): + merge_rules.append( + FragmentMergeRule( + sources={fragment_a, fragment_b}, + target=target_info_builder([fragment_a, fragment_b]), + ) + ) + return [], merge_rules + + +def build_rules_from_pipeline( + fragment_infos: Sequence[FragmentInfo], + pipeline: PipelineSchedule, + context: PipelineContext, +) -> tuple[FragmentScheduleRules, FragmentMergeRules]: + """Builds scheduling and merging rules from a PipelineSchedule. + + Args: + fragment_infos: List of fragments to build rules for. + pipeline: PipelineSchedule containing rule generators and options. + context: PipelineContext with pipeline configuration. + + Returns: + Tuple of (schedule_rules, merge_rules) built from rule builders. + """ + # Create a list of fragments for each mesh once + mesh_fragments = collections.defaultdict(list) + for fragment in fragment_infos: + mesh_fragments[fragment.mesh_name].append(fragment) + + all_schedule_rules = [] + all_merge_rules = [] + + if pipeline.schedule_merge_rule_builders: + for builder in pipeline.schedule_merge_rule_builders: + # Run each builder on fragments from each mesh separately + for _, single_mesh_fragments in mesh_fragments.items(): + schedule_rules, merge_rules = builder(single_mesh_fragments, context) + all_schedule_rules.extend(schedule_rules) + all_merge_rules.extend(merge_rules) + + return all_schedule_rules, all_merge_rules + + +def maybe_unique_transpose_count( + fragment: FragmentInfo, +) -> int | None: + """Returns transpose count if all fragment origins have the same value.""" + if not fragment.origins: + return None + + # Check if all origins have the same transpose count. + transpose_counts = {origin.transpose_count for origin in fragment.origins} + if len(transpose_counts) == 1: + return transpose_counts.pop() + + return None + + +def get_scheduling_unit_info(fragment: FragmentInfo) -> tuple[int, int] | None: + """Returns (call_counter, transpose_count) if fragment is a valid scheduling unit. + + A fragment is a valid scheduling unit if it meets all of the following + conditions: + - It is a user fragment (has origins) + - It has a call_counter + - It has a single transpose_count which is 0 or 1 + + Args: + fragment: Fragment to check scheduling unit for. + + Returns: + A tuple of (call_counter, transpose_count) if valid, None otherwise. + """ + if not fragment.origins: + return None + + if fragment.call_counter is None: + return None + + transpose_count = maybe_unique_transpose_count(fragment) + if transpose_count is not None and ( + transpose_count == 0 or transpose_count == 1 + ): + return (fragment.call_counter, transpose_count) + + return None + + +def get_staged_scheduling_info( + f1: FragmentInfo, f2: FragmentInfo, error_context: str +) -> tuple[int, int, int, int] | None: + """Validates two fragments for scheduling and returns their info. + + Args: + f1: First fragment to validate. + f2: Second fragment to validate. + error_context: Context for the error message if stage_id validation fails, + e.g., "1F1B scheduling". + + Returns: + Tuple of (call_counter_f1, transpose_count_f1, call_counter_f2, + transpose_count_f2) if both fragments are valid scheduling units with + stages, None otherwise. 
+
+  Raises:
+    ValueError: If `stage_id` is not set on either of the fragments.
+  """
+  f1_info = get_scheduling_unit_info(f1)
+  f2_info = get_scheduling_unit_info(f2)
+  if f1_info is None or f2_info is None:
+    return None
+
+  if f1.stage_id is None or f2.stage_id is None:
+    raise ValueError(f'All fragments must have a stage id for {error_context}.')
+
+  call_counter_f1, transpose_count_f1 = f1_info
+  call_counter_f2, transpose_count_f2 = f2_info
+  return (
+      call_counter_f1,
+      transpose_count_f1,
+      call_counter_f2,
+      transpose_count_f2,
+  )
diff --git a/shardy/integrations/python/jax/mpmd/pipeline_registry.py b/shardy/integrations/python/jax/mpmd/pipeline_registry.py
new file mode 100644
index 00000000..d850ffa2
--- /dev/null
+++ b/shardy/integrations/python/jax/mpmd/pipeline_registry.py
@@ -0,0 +1,178 @@
+# Copyright 2025 The MPMD Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Pipeline schedule registry.
+
+Central registry mapping schedule names to PipelineSchedule objects. Each
+schedule defines fragment merging and ordering using binary predicate
+functions.
+
+Usage:
+  schedule = get_pipeline_schedule('ONE_FWD_ONE_BWD')
+  config = make_config(
+      topology=topology,
+      name_to_mesh_assignment=mesh_assignment,
+      pipeline_schedule=schedule,
+  )
+"""
+
+import functools
+
+import immutabledict
+
+from shardy.integrations.python.jax.mpmd import pipeline
+from shardy.integrations.python.jax.mpmd import schedule_impl
+
+ImmutableDict = immutabledict.immutabledict
+
+PIPELINE_SCHEDULES: ImmutableDict[str, pipeline.PipelineSchedule] = (
+    ImmutableDict({
+        'ONE_FWD_ONE_BWD': pipeline.PipelineSchedule(
+            schedule_merge_rule_builders=[
+                functools.partial(
+                    pipeline.build_schedule_rules_from_predicate,
+                    before_pred=schedule_impl.one_fwd_one_bwd_schedule_predicate,
+                )
+            ],
+            required_mpmd_options={'mpmd_pipeline_schedule': '1F1B'},
+        ),
+        'GPIPE': pipeline.PipelineSchedule(
+            schedule_merge_rule_builders=[
+                functools.partial(
+                    pipeline.build_schedule_rules_from_predicate,
+                    before_pred=schedule_impl.gpipe_schedule_predicate,
+                )
+            ],
+            required_mpmd_options={'mpmd_pipeline_schedule': 'GPipe'},
+        ),
+        'GPIPE_BUT_1F1B_FOR_LAST_MESH': pipeline.PipelineSchedule(
+            schedule_merge_rule_builders=[
+                functools.partial(
+                    pipeline.build_schedule_rules_from_predicate,
+                    before_pred=schedule_impl.gpipe_with_1f1b_on_last_mesh_schedule_predicate,
+                )
+            ],
+            required_mpmd_options={
+                'mpmd_pipeline_schedule': 'GPipeBut1F1BForLastMesh'
+            },
+        ),
+        'ZERO_BUBBLE_H1': pipeline.PipelineSchedule(
+            schedule_merge_rule_builders=[
+                functools.partial(
+                    pipeline.build_schedule_rules_from_predicate,
+                    before_pred=schedule_impl.zero_bubble_h1_schedule_predicate,
+                )
+            ],
+            required_mpmd_options={
+                'mpmd_pipeline_schedule': 'ZeroBubbleH1',
+                'mpmd_split_bwd_fragments': True,
+            },
+        ),
+        'ZERO_BUBBLE_H2_ZERO_TX_LATENCY': pipeline.PipelineSchedule(
+            schedule_merge_rule_builders=[
+                functools.partial(
+
pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=0.0, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2ZeroTxLatency', + }, + ), + 'ZERO_BUBBLE_H2_HALF_TX_LATENCY': pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=0.5, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2HalfTxLatency', + }, + ), + 'ZERO_BUBBLE_H2_FULL_TX_LATENCY': pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=1.0, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2FullTxLatency', + }, + ), + 'PARALLEL_PIPELINES_WITH_WRAP_AROUND': pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.parallel_pipelines_with_wraparound_schedule_predicate, + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'ParallelPipelinesWithWrapAround', + }, + ), + 'CIRCULAR': pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.circular_schedule_predicate_base, + reverse_backward=False, + ), + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'Circular', + }, + ), + 'CIRCULAR_WITH_REVERSED_BACKWARD': pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.circular_schedule_predicate_base, + reverse_backward=True, + ), + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'CircularWithReversedBackward', + }, + ), + }) +) + + +def get_pipeline_schedule(schedule_name: str) -> pipeline.PipelineSchedule: + """Get a PipelineSchedule object for the given schedule name.""" + if schedule_name not in PIPELINE_SCHEDULES: + valid_schedules = sorted(PIPELINE_SCHEDULES.keys()) + raise KeyError( + f"Unknown pipeline schedule '{schedule_name}'. " + f'Valid schedules are: {valid_schedules!r}' + ) + return PIPELINE_SCHEDULES[schedule_name] diff --git a/shardy/integrations/python/jax/mpmd/pipeline_test.py b/shardy/integrations/python/jax/mpmd/pipeline_test.py new file mode 100644 index 00000000..abb7fa14 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/pipeline_test.py @@ -0,0 +1,210 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for MPMD pipeline functions.""" + +from collections.abc import Sequence +import functools +import unittest +from absl.testing import parameterized +from shardy.integrations.python.jax.mpmd import pipeline + + +def _make_fragment( + mesh_name: str = "mesh1", + origins: Sequence[pipeline.FragmentOrigin] | None = None, + **kwargs, +) -> pipeline.FragmentInfo: + """Helper to create FragmentInfo with common defaults.""" + # Use None instead of [] to avoid shared mutable default argument + if origins is None: + origins = () + return pipeline.FragmentInfo(origins=origins, mesh_name=mesh_name, **kwargs) + + +class BasicScheduleBuildTest(unittest.TestCase): + + def test_predicate_based_schedule_example(self): + def my_schedule_predicate(f1, f2, _): + return f1.call_counter < f2.call_counter + + def my_merge_predicate(f1, f2, _): + return f1.stage_id == f2.stage_id and f1.call_counter == f2.call_counter + + schedule = pipeline.PipelineSchedule( + schedule_merge_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=my_schedule_predicate, + ), + functools.partial( + pipeline.build_merge_rules_from_predicate, + pred=my_merge_predicate, + ), + ], + required_mpmd_options={}, + ) + + self.assertIsNotNone(schedule) + self.assertEqual(len(schedule.schedule_merge_rule_builders), 2) + + def test_direct_construction_schedule_example(self): + # This test implements the same logic as the predicate-based example above. + def custom_schedule_builder(fragment_infos, _): + forward = sorted( + [ + f + for f in fragment_infos + if f.origins + and pipeline.maybe_unique_transpose_count(f) == 0 + ], + key=lambda f: f.call_counter or 0, + ) + backward = sorted( + [ + f + for f in fragment_infos + if f.origins + and pipeline.maybe_unique_transpose_count(f) == 1 + ], + key=lambda f: f.call_counter or 0, + ) + + execution_order = [] + for fwd, bwd in zip(forward, backward): + execution_order.extend([fwd, bwd]) + + merge_rules = [ + pipeline.FragmentMergeRule( + sources={fwd, bwd}, + target=pipeline._minimal_create_target_info([fwd, bwd]), + ) + for fwd, bwd in zip(forward, backward) + if fwd.stage_id == bwd.stage_id + ] + + return [ + pipeline.FragmentScheduleRule(ordered_fragments=execution_order) + ], merge_rules + + schedule = pipeline.PipelineSchedule( + schedule_merge_rule_builders=[custom_schedule_builder], + required_mpmd_options={}, + ) + + self.assertIsNotNone(schedule) + self.assertEqual(len(schedule.schedule_merge_rule_builders), 1) + + +class MinimalCreateTargetInfoTest(parameterized.TestCase): + + def test_empty_source_fragments_raises_error(self): + with self.assertRaises(ValueError): + pipeline._minimal_create_target_info([]) + + def test_single_fragment(self): + origin = pipeline.FragmentOrigin("comp1", transpose_count=1) + fragment = _make_fragment( + origins=(origin,), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + ) + + result = pipeline._minimal_create_target_info([fragment]) + + self.assertEqual(result.origins, (origin,)) + # minimal_create_target_info always sets these to None + self.assertIsNone(result.stage_id) + self.assertIsNone(result.call_counter) + self.assertIsNone(result.split_type) + self.assertEqual(result.mesh_name, "mesh1") + + def test_origins_union_preserves_all_transpose_counts(self): + origin1 = pipeline.FragmentOrigin("comp1", transpose_count=0) + origin2 = pipeline.FragmentOrigin("comp2", 
transpose_count=1) + origin3 = pipeline.FragmentOrigin( + "comp1", transpose_count=1 + ) # Different transpose_count + + fragment1 = pipeline.FragmentInfo( + origins=(origin1, origin2), mesh_name="mesh1" + ) + fragment2 = pipeline.FragmentInfo(origins=(origin3,), mesh_name="mesh1") + + result = pipeline._minimal_create_target_info([fragment1, fragment2]) + + self.assertCountEqual(result.origins, (origin1, origin2, origin3)) + + def test_origins_union_removes_duplicates(self): + origin1 = pipeline.FragmentOrigin("comp1", transpose_count=0) + origin2 = pipeline.FragmentOrigin("comp2", transpose_count=1) + + fragment1 = pipeline.FragmentInfo( + origins=(origin1, origin2), mesh_name="mesh1" + ) + # `origin1` also exists in fragment1 origins + fragment2 = pipeline.FragmentInfo(origins=(origin1,), mesh_name="mesh1") + + result = pipeline._minimal_create_target_info([fragment1, fragment2]) + # Verify that the duplicate `origin1` does not remain + self.assertCountEqual(result.origins, (origin1, origin2)) + + def test_mesh_name_inconsistency_raises_error(self): + """Test that inconsistent mesh_name values raise ValueError.""" + fragment1 = _make_fragment(mesh_name="mesh1") + fragment2 = _make_fragment(mesh_name="mesh2") + + with self.assertRaises(ValueError) as cm: + pipeline._minimal_create_target_info([fragment1, fragment2]) + self.assertIn( + "Inconsistent mesh_name values: mesh1 vs mesh2", str(cm.exception) + ) + + def test_mesh_name_from_first_fragment(self): + fragment1 = _make_fragment(mesh_name="mesh1") + fragment2 = _make_fragment(mesh_name="mesh1") + + result = pipeline._minimal_create_target_info((fragment1, fragment2)) + + self.assertEqual(result.mesh_name, "mesh1") + + def test_always_returns_none_for_optional_fields(self): + """Test that stage_id, call_counter, and split_type are always None.""" + fragment1 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin("comp1", transpose_count=0),), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + mesh_name="mesh1", + ) + fragment2 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin("comp2", transpose_count=1),), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + mesh_name="mesh1", + ) + + result = pipeline._minimal_create_target_info([fragment1, fragment2]) + + # Regardless of input values, these should always be None + self.assertIsNone(result.stage_id) + self.assertIsNone(result.call_counter) + self.assertIsNone(result.split_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/shardy/integrations/python/jax/mpmd/schedule_impl.py b/shardy/integrations/python/jax/mpmd/schedule_impl.py new file mode 100644 index 00000000..d9d5f988 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/schedule_impl.py @@ -0,0 +1,361 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Implementations of common pipeline scheduling predicates for MPMD.""" + +from typing import Callable + +from shardy.integrations.python.jax.mpmd import pipeline + + +def gpipe_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + _: pipeline.PipelineContext, +) -> bool: + """Returns true if `f1` must happen before `f2` in a GPipe schedule.""" + transpose_count_f1 = pipeline.maybe_unique_transpose_count(f1) + transpose_count_f2 = pipeline.maybe_unique_transpose_count(f2) + if ( + transpose_count_f1 is None + or transpose_count_f2 is None + or f1.call_counter is None + or f2.call_counter is None + ): + return False + + return (transpose_count_f1, f1.call_counter) < ( + transpose_count_f2, + f2.call_counter, + ) + + +def one_fwd_one_bwd_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a 1F1B schedule.""" + result = pipeline.get_staged_scheduling_info(f1, f2, "1F1B scheduling") + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + # The following two conditions guarantee the forward and backward fragments + # are interleaved in the steady state of the pipeline. + + # Example: in mesh/stage 0 of pipeline of depth 4, the backward computation + # of microbatch 0 must be scheduled before the forward computation of + # microbatch 4: 0 == 4 - 4 + 0. + if transpose_count_f1 == 1 and transpose_count_f2 == 0: + return call_counter_f1 == call_counter_f2 - context.num_meshes + f1.stage_id + + # Example: in mesh/stage 0 of pipeline of depth 4, the forward computation of + # microbatch 5 must be scheduled before the backward computation of + # microbatch 2: 5 == 2 + 4 - (0 + 1). + if transpose_count_f1 == 0 and transpose_count_f2 == 1: + return call_counter_f1 == call_counter_f2 + context.num_meshes - ( + f1.stage_id + 1 + ) + + # If the fragments have the same transpose count, guarantee that the + # call_counter ordering is preserved. 
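+  # For example, F(0) must run before F(1) and B(0) before B(1) on each mesh,
+  # where F = forward (transpose count 0) and B = backward (transpose count 1).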
+ if transpose_count_f1 == transpose_count_f2: + return call_counter_f1 < call_counter_f2 + + return False + + +def gpipe_with_1f1b_on_last_mesh_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a GPipe schedule with 1F1B on the last mesh.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "GPipe with 1F1B on the last mesh scheduling" + ) + if result is None: + return False + # Validation successful - delegate to other functions + _ = result + + if f1.stage_id == context.num_meshes - 1: + return one_fwd_one_bwd_schedule_predicate(f1, f2, context) + return gpipe_schedule_predicate(f1, f2, context) + + +def circular_schedule_predicate_base( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + reverse_backward: bool, +) -> bool: + """Returns true if f1 must happen before f2 in circular schedule.""" + # Check that both fragments are scheduling units + result = pipeline.get_staged_scheduling_info( + f1, f2, "circular pipelining scheduling" + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + if transpose_count_f1 != transpose_count_f2: + # Forward fragments always happen before backward fragments + return transpose_count_f1 < transpose_count_f2 + + # Both forward or both backward - use phase-based ordering + phase_f1 = call_counter_f1 // context.num_meshes + phase_f2 = call_counter_f2 // context.num_meshes + + f1_list = [phase_f1, f1.stage_id, call_counter_f1] + f2_list = [phase_f2, f2.stage_id, call_counter_f2] + + # Forward fragments - ascending order + if transpose_count_f1 == 0: + return f1_list < f2_list + + # Backward fragments + if reverse_backward: + # Descending order + return f1_list > f2_list + + # Backward fragments with stage in descending order + f1_list[1], f2_list[1] = f2_list[1], f1_list[1] # Swap stage IDs + return f1_list < f2_list + + +def zero_bubble_h1_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a ZeroBubbleH1 schedule.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "ZeroBubbleH1 scheduling" + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + is_wgrad_f1 = f1.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + is_wgrad_f2 = f2.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + + # The following two conditions guarantee the forward and backward fragments + # are interleaved in the steady state of the pipeline. They are just like + # 1F1B but specialized to actual back-propagation fragments. + + # Clause 1: Ba(i) < F(i + num_meshes - stage_id) + if transpose_count_f1 == 1 and not is_wgrad_f1 and transpose_count_f2 == 0: + return call_counter_f1 == call_counter_f2 - context.num_meshes + f1.stage_id + + # Clause 2: F(i + num_meshes - stage_id - 1) < Ba(i) + if transpose_count_f1 == 0 and transpose_count_f2 == 1 and not is_wgrad_f2: + return call_counter_f1 == call_counter_f2 + context.num_meshes - ( + f1.stage_id + 1 + ) + + # The rest of the conditions position the parameter gradient fragments. + # Clause 3: Bw(i) < F(i + num_meshes) + # e.g. Bw(0) < F(4) above. 
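+  # (Notation: F = forward fragment, Ba = activation-gradient backward
+  # fragment, Bw = weight/parameter-gradient backward fragment, i.e. a
+  # backward split with split_type == DROP_TRANSFERRED; indices are
+  # microbatch call counters.)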
+ if ( + transpose_count_f1 == 1 + and (is_wgrad_f1 or f1.stage_id == 0) + and transpose_count_f2 == 0 + ): + return call_counter_f2 - call_counter_f1 == context.num_meshes + + # Clause 4: Ba(i + stage_id) < Bw(i) + # e.g. + # mesh0: Ba(0) < Bw(0) + # mesh1: Ba(1) < Bw(0) + # mesh2: Ba(2) < Bw(0) + # mesh3: Ba(3) < Bw(0) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 1 + and is_wgrad_f2 + ): + return call_counter_f1 - call_counter_f2 == f1.stage_id + + # This is just needed for transitively completing Clauses 3 and 2, needed for + # the final phase where there may be no remaining forward to anchor to. + # Bw(i) < Ba(i + stage_id + 1) + if ( + transpose_count_f1 == 1 + and is_wgrad_f1 + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return call_counter_f2 - call_counter_f1 == f1.stage_id + 1 + + return False + + +def zero_bubble_h2_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + init_fwd_per_stage_fn: Callable[[int], int], +) -> bool: + """Returns true if f1 must happen before f2 in a ZeroBubbleH2 schedule.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "ZeroBubbleH2 scheduling" + ) + if result is None: + return False + _, transpose_count_f1, _, transpose_count_f2 = result + + is_wgrad_f1 = f1.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + is_wgrad_f2 = f2.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + + # How many fwd we are allowed to stream before entering steady state + init_fwd = init_fwd_per_stage_fn(f1.stage_id) + # The ZeroBubbleH2 pipeline is diagonally symmetric + complement_init_fwd = init_fwd_per_stage_fn( + context.num_meshes - f1.stage_id - 1 + ) + + # Initial phase + # Clause 1: F(i) <= B(_) for i < init_fwd + if ( + transpose_count_f1 == 0 + and transpose_count_f2 == 1 + and f1.call_counter < init_fwd + ): + return True + + # Clause 2: Ba(i) < F(i + init_fwd) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 0 + and f2.call_counter >= init_fwd + ): + return f2.call_counter - f1.call_counter == init_fwd + + # Clause 3: F(i + init_fwd - 1) < Ba(i) + if ( + transpose_count_f1 == 0 + and f1.call_counter >= init_fwd + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return f1.call_counter - f2.call_counter == init_fwd - 1 + + # Clause 4: Ba(i + complement_init_fwd - 1) < Bw(i) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 1 + and is_wgrad_f2 + ): + return f1.call_counter - f2.call_counter == complement_init_fwd - 1 + + # Clause 5: Bw(i) < Ba(i + complement_init_fwd) + if ( + transpose_count_f1 == 1 + and is_wgrad_f1 + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return f2.call_counter - f1.call_counter == complement_init_fwd + + return False + + +def latency_hiding_zero_bubble_h2_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + latency_stage_fraction: float, +) -> bool: + """Returns true if f1 must happen before f2 in a latency-hiding ZeroBubbleH2 schedule. + + Args: + f1: First fragment to compare. + f2: Second fragment to compare. + context: Pipeline context with configuration. + latency_stage_fraction: Float between 0.0 and 1.0 specifying how much time + activation forwarding transfers take compared to a stage compute time. 
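+
+  Returns:
+    True if f1 must be scheduled before f2, False otherwise.
+
+  Raises:
+    ValueError: If `latency_stage_fraction` is outside [0.0, 1.0].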
+ """ + if not (0.0 <= latency_stage_fraction <= 1.0): + raise ValueError("latency_stage_fraction must be between 0.0 and 1.0") + + def init_fwds_per_stage(stage_id: int) -> int: + """Calculate number of forward microbatches before first backward.""" + # Number of transfers from beginning until first backward can execute + num_init_transfers = 2.0 * (context.num_meshes - stage_id - 1) + # Compute that has happened in initial first microbatch path + num_init_compute = 2.0 * (context.num_meshes - stage_id) - 1.0 + return int(num_init_compute + num_init_transfers * latency_stage_fraction) + + return zero_bubble_h2_schedule_predicate(f1, f2, context, init_fwds_per_stage) + + +def parallel_pipelines_with_wraparound_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + _: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in parallel pipelines with wraparound. + + Only supports forward fragments. The entrypoint for mesh{i} is call_counter + {i}. + For each mesh, the order is [F{n}, F{n-1}, ..., F{1}] rotated such that + the leading fragment is F{mesh_index}. + + Args: + f1: First fragment to compare. + f2: Second fragment to compare. + """ + result = pipeline.get_staged_scheduling_info( + f1, f2, "parallel pipelines scheduling" + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + # Only forward fragments supported + if transpose_count_f1 != 0 or transpose_count_f2 != 0: + raise ValueError("Only forward fragments supported for parallel pipelines") + + if call_counter_f1 == call_counter_f2: + raise ValueError( + "Should not have duplicate call counter in parallel pipelines" + ) + + # The entrypoint to stage{i} is call_counter {i}, so this always happens + # before + if call_counter_f1 == f1.stage_id or call_counter_f2 == f1.stage_id: + return call_counter_f1 == f1.stage_id + + # stage_id is the pivot. If both call_counters are on the same side of + # the pivot, we flip the order. But if they are on different + # sides, then we take the order as per normal. + if (call_counter_f1 > f1.stage_id and call_counter_f2 > f1.stage_id) or ( + call_counter_f1 < f1.stage_id and call_counter_f2 < f1.stage_id + ): + return call_counter_f1 > call_counter_f2 + + return call_counter_f1 < call_counter_f2 diff --git a/shardy/integrations/python/jax/mpmd/types.py b/shardy/integrations/python/jax/mpmd/types.py index fa6908f6..561341ed 100644 --- a/shardy/integrations/python/jax/mpmd/types.py +++ b/shardy/integrations/python/jax/mpmd/types.py @@ -15,15 +15,14 @@ """Common types used by PartIR:MPMD.""" -from collections.abc import Collection, Mapping, Sequence, Set +from collections.abc import Mapping import dataclasses -import enum -from typing import Callable import jax import jaxtyping from shardy.integrations.python.jax.mpmd import partitioning_options as part_options +from shardy.integrations.python.jax.mpmd import pipeline PyTree = jaxtyping.PyTree @@ -40,119 +39,6 @@ MeshToCompileOptions = Mapping[str, jax.stages.CompilerOptions] PartitioningOptions = dict[str, bool | str] -## Type aliases for custom scheduling and merging rules - -FragmentMergeRules = Sequence['FragmentMergeRule'] -FragmentScheduleRules = Sequence['FragmentScheduleRule'] - -# Function that constructs a target FragmentInfo from a sequence of source -# fragments that will be merged together into the target. 
-TargetInfoBuilder = Callable[[Sequence['FragmentInfo']], 'FragmentInfo'] - -# Function that builds schedule and/or merge rules from fragments and pipeline -# context. -ScheduleMergeRuleBuilder = Callable[ - [Sequence['FragmentInfo'], 'PipelineContext'], - tuple[FragmentScheduleRules, FragmentMergeRules], -] - -# Binary predicate determining if two fragments should be merged or scheduled -# together. -RuleGeneratorPredicate = Callable[ - ['FragmentInfo', 'FragmentInfo', 'PipelineContext'], bool -] - - -@dataclasses.dataclass(frozen=True) -class FragmentOrigin: - """The origin of a fragment.""" - - computation_name: str - transpose_count: int = 0 - - -@enum.unique -class SplitFragmentType(enum.Enum): - """Fragment split behavior for transferred data. - - These values indicate how fragment portions handle transferred data from - the original fragment if the fragment is split during compilation: - - KEEP_TRANSFERRED: Fragment portion retains transferred data - - DROP_TRANSFERRED: Fragment portion drops transferred data - """ - - KEEP_TRANSFERRED = enum.auto() - DROP_TRANSFERRED = enum.auto() - - -@dataclasses.dataclass(frozen=True) -class FragmentInfo: - """A fragment of a computation.""" - - origins: tuple[FragmentOrigin, ...] - stage_id: int | None = None - call_counter: int | None = None - split_type: SplitFragmentType | None = None - mesh_name: str = '' - - -@dataclasses.dataclass(frozen=True) -class FragmentMergeRule: - """A rule for merging fragments of a computation. - - Attributes: - sources: The source fragments to be merged. The order does not affect the - final position of the merged fragment. - target: The target fragment metadata that results from merging the sources. - """ - - sources: Set[FragmentInfo] - target: FragmentInfo - - def __post_init__(self): - # Validate the fragment merge rule. - if len(self.sources) < 2: - raise ValueError( - 'FragmentMergeRule must contain at least 2 source fragments, but got' - f' {self}.' - ) - validate_fragment_rule_origins(self.sources) - validate_fragment_rule_meshes(self.sources) - - if not self.target.origins: - raise ValueError( - f'Target fragment must have at least one origin, but got {self}.' - ) - - -@dataclasses.dataclass(frozen=True) -class FragmentScheduleRule: - """A rule for scheduling fragments in a specific execution order. - - Attributes: - ordered_fragments: Fragments in the order they should execute. Must contain - at least 2 fragments, and all fragments must be on the same mesh. - """ - - ordered_fragments: Sequence[FragmentInfo] - - def __post_init__(self): - # Validate the fragment schedule rule. - if len(self.ordered_fragments) < 2: - raise ValueError( - 'FragmentScheduleRule must contain at least 2 fragments, but got' - f' {self}.' - ) - validate_fragment_rule_origins(self.ordered_fragments) - validate_fragment_rule_meshes(self.ordered_fragments) - - -@dataclasses.dataclass(frozen=True) -class PipelineContext: - """Context for pipeline scheduling and merging predicates.""" - - num_meshes: int - # LINT.IfChange CPU_MESH_SUFFIX = '/cpu' @@ -171,22 +57,6 @@ def get_schedulable_meshes(topology: Topology) -> list[str]: return [name for name in topology if not mesh_is_on_cpu(name)] -@dataclasses.dataclass(frozen=True) -class PipelineSchedule: - """A set of rules and options which define an MPMD pipeline. - - Attributes: - schedule_merge_rule_builders: A sequence of functions that builds schedule - and/or merge rules for fragments. 
- required_mpmd_options: A mapping of PartitioningEnvironment flags that are - required for this schedule to function correctly. See - partitioning_options.py for available options. - """ - - schedule_merge_rule_builders: Sequence[ScheduleMergeRuleBuilder] | None = None - required_mpmd_options: Mapping[str, bool | str] | None = None - - @dataclasses.dataclass(frozen=True) class MpmdConfig: """Config for constructing an MPMD program with PartIR. @@ -235,8 +105,8 @@ class MpmdConfig: output_mesh_assignment: PyTree[str | None] partitioning_options: PartitioningOptions | None read_input_output_mesh_from_shardings: bool - fragment_merge_rules: FragmentMergeRules | None - fragment_schedule_rules: FragmentScheduleRules | None + fragment_merge_rules: pipeline.FragmentMergeRules | None + fragment_schedule_rules: pipeline.FragmentScheduleRules | None @property def _spmd_mesh(self) -> jax.sharding.Mesh: @@ -304,8 +174,8 @@ def make_config( output_mesh_assignment: PyTree[str | None] = (), partitioning_options: PartitioningOptions | None = None, read_input_output_mesh_from_shardings: bool = False, - fragment_merge_rules: FragmentMergeRules | None = None, - fragment_schedule_rules: FragmentScheduleRules | None = None, + fragment_merge_rules: pipeline.FragmentMergeRules | None = None, + fragment_schedule_rules: pipeline.FragmentScheduleRules | None = None, ) -> MpmdConfig: """Creates a `MpmdConfig`, inferring the tpu topology if not provided. @@ -434,33 +304,6 @@ def validate_input_output_mesh_assignments( ) -def validate_fragment_rule_origins( - fragment_collection: Collection[FragmentInfo], -) -> None: - """Validates that all fragments have at least one origin.""" - for fragment in fragment_collection: - if not fragment.origins: - raise ValueError( - f'Each fragment must have at least one origin, but got {fragment} in' - f' {fragment_collection}.' - ) - - -def validate_fragment_rule_meshes( - fragment_collection: Collection[FragmentInfo], -) -> None: - """Validates that all fragments are on the same mesh.""" - first_fragment = next(iter(fragment_collection)) - first_mesh = first_fragment.mesh_name - if not all( - fragment.mesh_name == first_mesh for fragment in fragment_collection - ): - raise ValueError( - 'Fragments being merged/scheduled must be on the same mesh, but got' - f' {fragment_collection}.' - ) - - def mesh_names( pytree: PyTree[ jax.Array