Skip to content

Commit

Permalink
Add more logging to sharktank data tests. (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd authored Feb 6, 2025
1 parent 3cccc20 commit a40a486
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ jobs:
run: |
source ${VENV_DIR}/bin/activate
pytest \
-v \
--log-cli-level=info \
--with-clip-data \
--with-flux-data \
--with-t5-data \
Expand Down
2 changes: 1 addition & 1 deletion sharktank/sharktank/tools/import_hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def import_hf_dataset(
if output_irpa_file is None:
return dataset

dataset.save(output_irpa_file, io_report_callback=logger.info)
dataset.save(output_irpa_file, io_report_callback=logger.debug)


def main(argv: list[str]):
Expand Down
9 changes: 9 additions & 0 deletions sharktank/tests/models/clip/clip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from parameterized import parameterized
from copy import copy
import logging
import pytest
import torch
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -72,6 +73,8 @@

with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')")

logger = logging.getLogger(__name__)


@pytest.mark.usefixtures("path_prefix")
class ClipTextIreeTest(TempDirTestBase):
Expand Down Expand Up @@ -163,6 +166,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
batch_size = input_ids.shape[0]
mlir_path = f"{target_model_path_prefix}.mlir"

logger.info("Exporting clip text model to MLIR...")
export_clip_text_model_iree_test_data(
reference_model=reference_model,
target_dtype=target_dtype,
Expand All @@ -172,12 +176,14 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
)

iree_module_path = f"{target_model_path_prefix}.vmfb"
logger.info("Compiling MLIR file...")
iree.compiler.compile_file(
mlir_path,
output_file=iree_module_path,
extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
)

logger.info("Invoking reference torch function...")
reference_result_dict = call_torch_module_function(
module=reference_model,
function_name="forward",
Expand All @@ -187,6 +193,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
expected_outputs = flatten_for_iree_signature(reference_result_dict)

iree_devices = get_iree_devices(driver="hip", device_count=1)
logger.info("Loading IREE module...")
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path=iree_module_path,
devices=iree_devices,
Expand All @@ -195,6 +202,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
iree_args = prepare_iree_module_function_args(
args=flatten_for_iree_signature(input_args), devices=iree_devices
)
logger.info("Invoking IREE function...")
iree_result = iree_to_torch(
*run_iree_module_function(
module=iree_module,
Expand All @@ -213,6 +221,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
actual_last_hidden_state = actual_outputs[0]
expected_last_hidden_state = expected_outputs[0]

logger.info("Comparing outputs...")
assert_text_encoder_state_close(
actual_last_hidden_state, expected_last_hidden_state, atol
)
Expand Down
11 changes: 9 additions & 2 deletions sharktank/tests/models/flux/flux_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sharktank.types import Dataset, Theta

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')")

iree_compile_flags = [
Expand Down Expand Up @@ -102,6 +103,7 @@ def runCompareIreeAgainstTorchEager(
parameters_path = self._temp_dir / "parameters.irpa"
batch_size = 1
batch_sizes = [batch_size]
logger.info("Exporting flux transformer to MLIR...")
export_flux_transformer(
target_torch_model,
mlir_output_path=mlir_path,
Expand All @@ -110,9 +112,10 @@ def runCompareIreeAgainstTorchEager(
)

iree_module_path = self._temp_dir / "model.vmfb"
logger.info("Compiling MLIR file...")
iree.compiler.compile_file(
mlir_path,
output_file=iree_module_path,
str(mlir_path),
output_file=str(iree_module_path),
extra_args=iree_compile_flags,
)

Expand All @@ -136,6 +139,7 @@ def runCompareIreeAgainstTorchEager(
for k, t in target_input_kwargs.items()
)

logger.info("Invoking reference torch function...")
reference_result_dict = call_torch_module_function(
module=reference_model,
function_name="forward",
Expand All @@ -145,6 +149,7 @@ def runCompareIreeAgainstTorchEager(
expected_outputs = flatten_for_iree_signature(reference_result_dict)

iree_devices = get_iree_devices(driver="hip", device_count=1)
logger.info("Loading IREE module...")
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path=iree_module_path,
devices=iree_devices,
Expand All @@ -155,6 +160,7 @@ def runCompareIreeAgainstTorchEager(
devices=iree_devices,
)

logger.info("Invoking IREE function...")
iree_result = iree_to_torch(
*run_iree_module_function(
module=iree_module,
Expand All @@ -168,6 +174,7 @@ def runCompareIreeAgainstTorchEager(
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
for i in range(len(expected_outputs))
]
logger.info("Comparing outputs...")
torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)

def runTestCompareDevIreeAgainstHuggingFace(
Expand Down
11 changes: 11 additions & 0 deletions sharktank/tests/models/t5/t5_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional
import os
from collections import OrderedDict
import logging
import pytest
import torch
from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path
Expand Down Expand Up @@ -61,6 +62,8 @@

with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')")

logger = logging.getLogger(__name__)


@pytest.mark.usefixtures("get_model_artifacts")
class T5EncoderEagerTest(TestCase):
Expand Down Expand Up @@ -184,15 +187,18 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
pad_to_multiple_of=config.context_length_padding_block_size,
).input_ids

logger.info("Invoking Torch eager model...")
model = T5Encoder(theta=dataset.root_theta, config=config)
model.eval()

logger.info("Invoking reference HuggingFace model...")
expected_outputs = reference_model(input_ids=input_ids)
actual_outputs = model(input_ids=input_ids)
actual_outputs = tree_map(
lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
)

logger.info("Comparing outputs...")
torch.testing.assert_close(
actual_outputs, expected_outputs, atol=atol, rtol=rtol
)
Expand Down Expand Up @@ -340,18 +346,21 @@ def runTestV1_1CompareIreeAgainstTorchEager(

mlir_path = f"{target_model_path_prefix}.mlir"
if not self.caching or not os.path.exists(mlir_path):
logger.info("Exporting T5 encoder model to MLIR...")
export_encoder_mlir(
parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
)
iree_module_path = f"{target_model_path_prefix}.vmfb"
if not self.caching or not os.path.exists(iree_module_path):
logger.info("Compiling MLIR file...")
iree.compiler.compile_file(
mlir_path,
output_file=iree_module_path,
extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
)

iree_devices = get_iree_devices(driver="hip", device_count=1)
logger.info("Loading IREE module...")
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path=iree_module_path,
devices=iree_devices,
Expand All @@ -360,6 +369,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
iree_args = prepare_iree_module_function_args(
args=flatten_for_iree_signature(input_args), devices=iree_devices
)
logger.info("Invoking IREE function...")
iree_result = iree_to_torch(
*run_iree_module_function(
module=iree_module,
Expand All @@ -375,6 +385,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
for i in range(len(reference_result))
]

logger.info("Comparing outputs...")
torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol)

@with_t5_data
Expand Down

0 comments on commit a40a486

Please sign in to comment.