From 0ada4b5643d2085630d2e48ecdf2bc434acf5839 Mon Sep 17 00:00:00 2001 From: youn17 Date: Mon, 18 Aug 2025 13:31:35 +0900 Subject: [PATCH 1/6] fix torch version detector --- test/test_utils.py | 4 ++-- torchao/utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index c5bbf45a96..6e745a415b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,9 +16,9 @@ class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ - ("2.5.0a0+git9f17037", "2.5.0", True), + ("2.5.0a0+git9f17037", "2.5.0", False), ("2.5.0a0+git9f17037", "2.4.0", True), - ("2.5.0.dev20240708+cu121", "2.5.0", True), + ("2.5.0.dev20240708+cu121", "2.5.0", False), ("2.5.0.dev20240708+cu121", "2.4.0", True), ("2.5.0", "2.4.0", True), ("2.5.0", "2.5.0", True), diff --git a/torchao/utils.py b/torchao/utils.py index a32166d556..df2522f155 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -368,7 +368,14 @@ def is_fbcode(): def torch_version_at_least(min_version): - return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 + from packaging.version import parse as parse_version + + if is_fbcode(): + return True + + # Parser for local identifiers + current_version = re.sub(r"\+.*$", "", torch.__version__) + return parse_version(current_version) >= parse_version(min_version) def _deprecated_torch_version_at_least(version_str: str) -> str: From de7a877e74a8c746db2bb1bcec6c46a8cc9dea77 Mon Sep 17 00:00:00 2001 From: youn17 Date: Wed, 20 Aug 2025 01:14:40 +0900 Subject: [PATCH 2/6] add pre-release parser for torch_version_at_least and remove compare_versions - Co-authored-by: andrewor14 --- test/test_utils.py | 16 ++++++++-------- torchao/utils.py | 33 +++++++++++++++++---------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 6e745a415b..df3290dd93 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,14 +16,14 @@ class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ - ("2.5.0a0+git9f17037", "2.5.0", False), - ("2.5.0a0+git9f17037", "2.4.0", True), - ("2.5.0.dev20240708+cu121", "2.5.0", False), - ("2.5.0.dev20240708+cu121", "2.4.0", True), - ("2.5.0", "2.4.0", True), - ("2.5.0", "2.5.0", True), - ("2.4.0", "2.4.0", True), - ("2.4.0", "2.5.0", False), + ("2.5.0a0", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0a0", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0.dev", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0.dev", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0", "2.4.0", True), # [2, 5, 0] > [2, 4, 0] + ("2.5.0", "2.5.0", True), # [2, 5, 0] >= [2, 5, 0] + ("2.4.0", "2.4.0", True), # [2, 4, 0] >= [2, 4, 0] + ("2.4.0", "2.5.0", False), # [2, 4, 0] < [2, 5, 0] ] for torch_version, compare_version, expected_result in test_cases: diff --git a/torchao/utils.py b/torchao/utils.py index df2522f155..00653a008f 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -348,34 +348,35 @@ def _is_float8_type(dtype: torch.dtype) -> bool: def parse_version(version_string): - # Extract just the X.Y.Z part from the version string - match = re.match(r"(\d+\.\d+\.\d+)", version_string) + """ + Parse version string representing pre-release with -1 + + Examples: "2.5.0" -> [2, 5, 0], "2.5.0.dev" -> [2, 5, -1] + """ + version = re.sub(r"\+.*$", "", version_string) + + # Check for pre-release indicators (including all common patterns) + is_prerelease = bool(re.search(r"(a|b|dev)", version)) + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version) if match: - version = match.group(1) - return [int(x) for x in version.split(".")] + major, minor, patch = map(int, match.groups()) + if is_prerelease: + patch = -1 + return [major, minor, patch] else: raise ValueError(f"Invalid version string format: {version_string}") -def compare_versions(v1, v2): - v1_parts = parse_version(v1) - v2_parts = parse_version(v2) - return (v1_parts > v2_parts) - (v1_parts < v2_parts) - - def is_fbcode(): return not hasattr(torch.version, "git_version") def torch_version_at_least(min_version): - from packaging.version import parse as parse_version - if is_fbcode(): return True # Parser for local identifiers - current_version = re.sub(r"\+.*$", "", torch.__version__) - return parse_version(current_version) >= parse_version(min_version) + return parse_version(torch.__version__) >= parse_version(min_version) def _deprecated_torch_version_at_least(version_str: str) -> str: @@ -990,13 +991,13 @@ def is_sm_at_least_100(): def check_cpu_version(device, version="2.6.0"): if isinstance(device, torch.device): device = device.type - return device == "cpu" and compare_versions(torch.__version__, version) >= 0 + return device == "cpu" and torch_version_at_least(version) def check_xpu_version(device, version="2.8.0"): if isinstance(device, torch.device): device = device.type - return device == "xpu" and compare_versions(torch.__version__, version) >= 0 + return device == "xpu" and torch_version_at_least(version) def ceil_div(a, b): From 3818cbf9ec34f998f39d9d40014f4c31b77dfc0a Mon Sep 17 00:00:00 2001 From: youn17 Date: Wed, 20 Aug 2025 01:14:40 +0900 Subject: [PATCH 3/6] add pre-release parser for torch_version_at_least and remove compare_versions - Co-authored-by: andrewor14 --- test.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 0000000000..c8a2b69081 --- /dev/null +++ b/test.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import re + +import torch + + +def parse_version(version_string): + """ + Parse version string representing pre-release with -1. + + Examples: + - "2.5.0" -> [2, 5, 0] + - "2.5.0.dev" -> [2, 5, -1] + """ + # Remove local version identifier (everything after +) + clean_version = re.sub(r"\+.*$", "", version_string) + + # Check for pre-release indicators (including all common patterns) + is_prerelease = bool(re.search(r"(a|b|dev)", clean_version)) + + match = re.match(r"(\d+)\.(\d+)\.(\d+)", clean_version) + if match: + major, minor, patch = map(int, match.groups()) + if is_prerelease: + patch = -1 + return [major, minor, patch] + else: + raise ValueError(f"Invalid version string format: {version_string}") + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +def torch_version_at_least(min_version): + if is_fbcode(): + return True + + # Parser for local identifiers + return parse_version(torch.__version__) >= parse_version(min_version) + + +# Test cases +if __name__ == "__main__": + test_cases = [ + ("2.5.0+cu126", [2, 5, 0]), + ("2.5.0", [2, 5, 0]), + ("2.5.0a0+git9f17037", [2, 5, -1]), + ("2.5.0.dev20240708+cu121", [2, 5, -1]), + ("2.4.0", [2, 4, 0]), + ("2.2.0beta1", [2, 2, -1]), + ] + + print("Testing parse_version:") + for version_str, expected in test_cases: + result = parse_version(version_str) + status = "✓" if result == expected else "✗" + print(f"{status} {version_str} -> {result} (expected: {expected})") From 747026bf0b634fe48fef80234a7e4a220f651a65 Mon Sep 17 00:00:00 2001 From: namgyu-youn Date: Wed, 20 Aug 2025 01:30:26 +0900 Subject: [PATCH 4/6] remove local test code --- test.py | 63 --------------------------------------------------------- 1 file changed, 63 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index c8a2b69081..0000000000 --- a/test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import re - -import torch - - -def parse_version(version_string): - """ - Parse version string representing pre-release with -1. - - Examples: - - "2.5.0" -> [2, 5, 0] - - "2.5.0.dev" -> [2, 5, -1] - """ - # Remove local version identifier (everything after +) - clean_version = re.sub(r"\+.*$", "", version_string) - - # Check for pre-release indicators (including all common patterns) - is_prerelease = bool(re.search(r"(a|b|dev)", clean_version)) - - match = re.match(r"(\d+)\.(\d+)\.(\d+)", clean_version) - if match: - major, minor, patch = map(int, match.groups()) - if is_prerelease: - patch = -1 - return [major, minor, patch] - else: - raise ValueError(f"Invalid version string format: {version_string}") - - -def is_fbcode(): - return not hasattr(torch.version, "git_version") - - -def torch_version_at_least(min_version): - if is_fbcode(): - return True - - # Parser for local identifiers - return parse_version(torch.__version__) >= parse_version(min_version) - - -# Test cases -if __name__ == "__main__": - test_cases = [ - ("2.5.0+cu126", [2, 5, 0]), - ("2.5.0", [2, 5, 0]), - ("2.5.0a0+git9f17037", [2, 5, -1]), - ("2.5.0.dev20240708+cu121", [2, 5, -1]), - ("2.4.0", [2, 4, 0]), - ("2.2.0beta1", [2, 2, -1]), - ] - - print("Testing parse_version:") - for version_str, expected in test_cases: - result = parse_version(version_str) - status = "✓" if result == expected else "✗" - print(f"{status} {version_str} -> {result} (expected: {expected})") From 37b7089bda6bff50b8c510a894e382385152bf1f Mon Sep 17 00:00:00 2001 From: youn17 Date: Wed, 20 Aug 2025 12:03:37 +0900 Subject: [PATCH 5/6] update PyTorch pre-release version indicator --- test/test_utils.py | 8 ++++---- torchao/utils.py | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index df3290dd93..3bc16c20c0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,10 +16,10 @@ class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ - ("2.5.0a0", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] - ("2.5.0a0", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] - ("2.5.0.dev", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] - ("2.5.0.dev", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0a0+git9f17037", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0a0+git9f17037", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0.dev20240708+cu121", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0.dev20240708+cu121", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] ("2.5.0", "2.4.0", True), # [2, 5, 0] > [2, 4, 0] ("2.5.0", "2.5.0", True), # [2, 5, 0] >= [2, 5, 0] ("2.4.0", "2.4.0", True), # [2, 4, 0] >= [2, 4, 0] diff --git a/torchao/utils.py b/torchao/utils.py index 00653a008f..71e140eb87 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -351,13 +351,11 @@ def parse_version(version_string): """ Parse version string representing pre-release with -1 - Examples: "2.5.0" -> [2, 5, 0], "2.5.0.dev" -> [2, 5, -1] + Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] """ - version = re.sub(r"\+.*$", "", version_string) - - # Check for pre-release indicators (including all common patterns) - is_prerelease = bool(re.search(r"(a|b|dev)", version)) - match = re.match(r"(\d+)\.(\d+)\.(\d+)", version) + # Check for pre-release indicators + is_prerelease = bool(re.search(r"(git|dev|a\d+|b\d+|rc\d+)", version_string)) + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) if match: major, minor, patch = map(int, match.groups()) if is_prerelease: From 05edf521abe70dfab2223fde67f7903af1f0a5ef Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 21 Aug 2025 09:03:21 +0900 Subject: [PATCH 6/6] update pre-release patterns --- torchao/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/utils.py b/torchao/utils.py index 71e140eb87..298a0d176a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -354,7 +354,7 @@ def parse_version(version_string): Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] """ # Check for pre-release indicators - is_prerelease = bool(re.search(r"(git|dev|a\d+|b\d+|rc\d+)", version_string)) + is_prerelease = bool(re.search(r"(git|dev)", version_string)) match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) if match: major, minor, patch = map(int, match.groups())