diff --git a/setup.py b/setup.py index aeb0766c474..a42aef69725 100644 --- a/setup.py +++ b/setup.py @@ -96,8 +96,14 @@ def get_dist(pkgname): return None pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch") - if os.getenv("PYTORCH_VERSION"): - pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") + if version_pin := os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + version_pin + elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")): + # This branch and the associated env vars exist to help third-party + # builds like in https://github.com/pytorch/vision/pull/8936. This is + # supported on a best-effort basis, we don't guarantee that this won't + # eventually break (and we don't test it.) + pytorch_dep += f">={version_pin_ge},<{version_pin_lt}" requirements = [ "numpy",