Skip to content

Commit 92efc09

Browse files
evakravibenieric
andauthored
fix: tgi image uri unit tests (#5127)
* fix: tgi image uri unit tests * fix: black-format and flake8 failures * fix: parse * fix: print statement --------- Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 6d52a81 commit 92efc09

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

Diff for: tests/unit/sagemaker/image_uris/test_huggingface_llm.py

+22
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16+
from packaging.version import parse
1617

1718
from sagemaker.huggingface import get_huggingface_llm_image_uri
1819
from tests.unit.sagemaker.image_uris import expected_uris, conftest
@@ -72,10 +73,31 @@ def test_huggingface_uris(load_config):
7273
VERSIONS = load_config["inference"]["versions"]
7374
device = load_config["inference"]["processors"][0]
7475
backend = "huggingface-neuronx" if device == "inf2" else "huggingface"
76+
77+
# Fail if device is not in mapping
78+
if device not in HF_VERSIONS_MAPPING:
79+
raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING")
80+
81+
# Get highest version for the device
82+
highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x))
83+
7584
for version in VERSIONS:
7685
ACCOUNTS = load_config["inference"]["versions"][version]["registries"]
7786
for region in ACCOUNTS.keys():
7887
uri = get_huggingface_llm_image_uri(backend, region=region, version=version)
88+
89+
# Skip only if test version is higher than highest known version.
90+
# There's now automation to add new TGI releases to image_uri_config directory
91+
# that doesn't involve a human raising a PR.
92+
if parse(version) > parse(highest_version):
93+
print(
94+
f"Skipping version check for {version} as there is "
95+
"automation that now updates the image_uri_config "
96+
"without a human raising a PR. Tests will pass for "
97+
f"versions higher than {highest_version} that are not in HF_VERSIONS_MAPPING."
98+
)
99+
continue
100+
79101
expected = expected_uris.huggingface_llm_framework_uri(
80102
"huggingface-pytorch-tgi-inference",
81103
ACCOUNTS[region],

0 commit comments

Comments
 (0)