Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions shardy/integrations/python/jax/mpmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
from shardy.integrations.python.jax.mpmd.ops import named_computation
from shardy.integrations.python.jax.mpmd.ops import named_tensor
from shardy.integrations.python.jax.mpmd.ops import reduce
from shardy.integrations.python.jax.mpmd.pipeline import FragmentInfo
from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRule
from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRules
from shardy.integrations.python.jax.mpmd.pipeline import FragmentOrigin
from shardy.integrations.python.jax.mpmd.stages import MpmdCompiled as Compiled
from shardy.integrations.python.jax.mpmd.stages import MpmdExecutable as Executable
from shardy.integrations.python.jax.mpmd.stages import MpmdJitShardingInfo
from shardy.integrations.python.jax.mpmd.stages import MpmdLowered as Lowered
from shardy.integrations.python.jax.mpmd.types import FragmentInfo
from shardy.integrations.python.jax.mpmd.types import FragmentMergeRule
from shardy.integrations.python.jax.mpmd.types import FragmentMergeRules
from shardy.integrations.python.jax.mpmd.types import FragmentOrigin
from shardy.integrations.python.jax.mpmd.types import FunctionIOMeshAssignment
from shardy.integrations.python.jax.mpmd.types import make_config
from shardy.integrations.python.jax.mpmd.types import mesh_names
Expand Down
30 changes: 16 additions & 14 deletions shardy/integrations/python/jax/mpmd/jaxlib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for converting between Python types and jaxlib pybind types."""
"""Utilities for converting between Python types and jaxlib pybind pipeline."""

from jaxlib import _sdy_mpmd as jaxlib_mpmd

from shardy.integrations.python.jax.mpmd import types
from shardy.integrations.python.jax.mpmd import pipeline


def _to_jaxlib_split_type(
split_type: types.SplitFragmentType | None,
split_type: pipeline.SplitFragmentType | None,
) -> jaxlib_mpmd.SplitFragmentType | None:
"""Convert native Python enum to pybinded enum."""
if split_type is None:
return None
if split_type == types.SplitFragmentType.KEEP_TRANSFERRED:
if split_type == pipeline.SplitFragmentType.KEEP_TRANSFERRED:
return jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED
elif split_type == types.SplitFragmentType.DROP_TRANSFERRED:
elif split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED:
return jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED
else:
raise ValueError(f'Unknown SplitFragmentType: {split_type}')


def _from_jaxlib_split_type(
split_type: jaxlib_mpmd.SplitFragmentType | None,
) -> types.SplitFragmentType | None:
) -> pipeline.SplitFragmentType | None:
"""Convert pybinded enum to native Python enum."""
if split_type is None:
return None
if split_type == jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED:
return types.SplitFragmentType.KEEP_TRANSFERRED
return pipeline.SplitFragmentType.KEEP_TRANSFERRED
elif split_type == jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED:
return types.SplitFragmentType.DROP_TRANSFERRED
return pipeline.SplitFragmentType.DROP_TRANSFERRED
else:
raise ValueError(f'Unknown jaxlib_mpmd.SplitFragmentType: {split_type}')


def convert_fragment_info_to_pybind(
fragment: types.FragmentInfo,
fragment: pipeline.FragmentInfo,
) -> jaxlib_mpmd.FragmentInfo:
"""Converts FragmentInfo to jaxlib_mpmd.FragmentInfo."""
return jaxlib_mpmd.FragmentInfo(
Expand All @@ -67,11 +67,13 @@ def convert_fragment_info_to_pybind(

def convert_pybind_fragment_info_to_types(
fragment: jaxlib_mpmd.FragmentInfo,
) -> types.FragmentInfo:
) -> pipeline.FragmentInfo:
"""Converts jaxlib_mpmd.FragmentInfo to FragmentInfo."""
return types.FragmentInfo(
return pipeline.FragmentInfo(
origins=tuple(
types.FragmentOrigin(origin.computation_name, origin.transpose_count)
pipeline.FragmentOrigin(
origin.computation_name, origin.transpose_count
)
for origin in fragment.origins
),
stage_id=fragment.stage_id,
Expand All @@ -82,7 +84,7 @@ def convert_pybind_fragment_info_to_types(


def convert_fragment_merge_rules_to_pybind(
fragment_merge_rules: types.FragmentMergeRules,
fragment_merge_rules: pipeline.FragmentMergeRules,
) -> list[jaxlib_mpmd.FragmentMergeRule]:
"""Converts fragment merge rules to jaxlib_mpmd.FragmentMergeRules."""
pybind_fragment_merge_rules = []
Expand All @@ -100,7 +102,7 @@ def convert_fragment_merge_rules_to_pybind(


def convert_fragment_schedule_rules_to_pybind(
fragment_schedule_rules: types.FragmentScheduleRules,
fragment_schedule_rules: pipeline.FragmentScheduleRules,
) -> list[jaxlib_mpmd.FragmentScheduleRule]:
"""Converts fragment schedule rules to jaxlib_mpmd.FragmentScheduleRules."""
pybind_fragment_schedule_rules = []
Expand Down
53 changes: 28 additions & 25 deletions shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from jaxlib import _sdy_mpmd as jaxlib_mpmd

from shardy.integrations.python.jax.mpmd import jaxlib_utils
from shardy.integrations.python.jax.mpmd import types
from shardy.integrations.python.jax.mpmd import pipeline


class SplitTypeConversionTest(parameterized.TestCase):
Expand All @@ -28,12 +28,12 @@ class SplitTypeConversionTest(parameterized.TestCase):
@parameterized.named_parameters(
(
'keep_transferred',
types.SplitFragmentType.KEEP_TRANSFERRED,
pipeline.SplitFragmentType.KEEP_TRANSFERRED,
jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED,
),
(
'drop_transferred',
types.SplitFragmentType.DROP_TRANSFERRED,
pipeline.SplitFragmentType.DROP_TRANSFERRED,
jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED,
),
('none', None, None),
Expand All @@ -52,27 +52,27 @@ class FragmentInfoConversionTest(parameterized.TestCase):
@parameterized.named_parameters(
(
'single_origin',
types.FragmentInfo(
origins=(types.FragmentOrigin('comp1', 0),), mesh_name='mesh1'
pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('comp1', 0),), mesh_name='mesh1'
),
),
(
'multiple_origins',
types.FragmentInfo(
pipeline.FragmentInfo(
origins=(
types.FragmentOrigin('comp1', 0),
types.FragmentOrigin('comp2', 1),
pipeline.FragmentOrigin('comp1', 0),
pipeline.FragmentOrigin('comp2', 1),
),
mesh_name='mesh1',
),
),
(
'all_fields',
types.FragmentInfo(
origins=(types.FragmentOrigin('comp1', 2),),
pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('comp1', 2),),
stage_id=5,
call_counter=3,
split_type=types.SplitFragmentType.KEEP_TRANSFERRED,
split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED,
mesh_name='mesh2',
),
),
Expand All @@ -89,18 +89,21 @@ class FragmentMergeRulesConversionTest(absltest.TestCase):

def test_single_rule(self):
"""Test converting single merge rule."""
f1 = types.FragmentInfo(
origins=(types.FragmentOrigin('f1', 0),), mesh_name='m1'
f1 = pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('f1', 0),), mesh_name='m1'
)
f2 = types.FragmentInfo(
origins=(types.FragmentOrigin('f2', 0),), mesh_name='m1'
f2 = pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('f2', 0),), mesh_name='m1'
)
target = types.FragmentInfo(
origins=(types.FragmentOrigin('f1', 0), types.FragmentOrigin('f2', 0)),
target = pipeline.FragmentInfo(
origins=(
pipeline.FragmentOrigin('f1', 0),
pipeline.FragmentOrigin('f2', 0),
),
mesh_name='m1',
)

rule = types.FragmentMergeRule(sources={f1, f2}, target=target)
rule = pipeline.FragmentMergeRule(sources={f1, f2}, target=target)
result = jaxlib_utils.convert_fragment_merge_rules_to_pybind([rule])

self.assertLen(result, 1)
Expand All @@ -114,18 +117,18 @@ class FragmentScheduleRulesConversionTest(absltest.TestCase):
def test_preserves_order(self):
"""Test that ordered_fragments order is preserved."""
frags = [
types.FragmentInfo(
origins=(types.FragmentOrigin('first', 0),), mesh_name='m1'
pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('first', 0),), mesh_name='m1'
),
types.FragmentInfo(
origins=(types.FragmentOrigin('second', 0),), mesh_name='m1'
pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('second', 0),), mesh_name='m1'
),
types.FragmentInfo(
origins=(types.FragmentOrigin('third', 0),), mesh_name='m1'
pipeline.FragmentInfo(
origins=(pipeline.FragmentOrigin('third', 0),), mesh_name='m1'
),
]

rule = types.FragmentScheduleRule(ordered_fragments=frags)
rule = pipeline.FragmentScheduleRule(ordered_fragments=frags)
result = jaxlib_utils.convert_fragment_schedule_rules_to_pybind([rule])

self.assertEqual(
Expand Down
Loading
Loading