|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 | 15 | import pytest
|
| 16 | +from packaging.version import parse |
16 | 17 |
|
17 | 18 | from sagemaker.huggingface import get_huggingface_llm_image_uri
|
18 | 19 | from tests.unit.sagemaker.image_uris import expected_uris, conftest
|
@@ -72,10 +73,31 @@ def test_huggingface_uris(load_config):
|
72 | 73 | VERSIONS = load_config["inference"]["versions"]
|
73 | 74 | device = load_config["inference"]["processors"][0]
|
74 | 75 | backend = "huggingface-neuronx" if device == "inf2" else "huggingface"
|
| 76 | + |
| 77 | + # Fail if device is not in mapping |
| 78 | + if device not in HF_VERSIONS_MAPPING: |
| 79 | + raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING") |
| 80 | + |
| 81 | + # Get highest version for the device |
| 82 | + highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x)) |
| 83 | + |
75 | 84 | for version in VERSIONS:
|
76 | 85 | ACCOUNTS = load_config["inference"]["versions"][version]["registries"]
|
77 | 86 | for region in ACCOUNTS.keys():
|
78 | 87 | uri = get_huggingface_llm_image_uri(backend, region=region, version=version)
|
| 88 | + |
| 89 | + # Skip only if test version is higher than highest known version. |
| 90 | + # There's now automation to add new TGI releases to image_uri_config directory |
| 91 | + # that doesn't involve a human raising a PR. |
| 92 | + if parse(version) > parse(highest_version): |
| 93 | + print( |
| 94 | + f"Skipping version check for {version} as there is " |
| 95 | + "automation that now updates the image_uri_config " |
| 96 | + "without a human raising a PR. Tests will pass for " |
| 97 | + f"versions higher than {highest_version} that are not in HF_VERSIONS_MAPPING." |
| 98 | + ) |
| 99 | + continue |
| 100 | + |
79 | 101 | expected = expected_uris.huggingface_llm_framework_uri(
|
80 | 102 | "huggingface-pytorch-tgi-inference",
|
81 | 103 | ACCOUNTS[region],
|
|
0 commit comments