Skip to content

Commit

Permalink
install fbgemm via nova
Browse files Browse the repository at this point in the history
  • Loading branch information
henrylhtsang committed Oct 3, 2023
1 parent 64183b3 commit b0254a7
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .github/scripts/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def main():
new_matrix_entries = []

for entry in full_matrix["include"]:
if entry["gpu_arch_version"] != "12.1":
new_matrix_entries.append(entry)
# if entry["gpu_arch_version"] != "12.1":
new_matrix_entries.append(entry)

new_matrix = {"include": new_matrix_entries}
print(json.dumps(new_matrix))
Expand Down
15 changes: 15 additions & 0 deletions .github/scripts/install_fbgemm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

echo "CU_VERSION"
echo "$CU_VERSION"

echo "CHANNEL"
echo "$CHANNEL"


${CONDA_RUN} pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/"$CU_VERSION"
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
test-infra-ref: main
build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
pre-script: ""
post-script: ""
post-script: .github/scripts/install_fbgemm.sh
package-name: torchrec
smoke-test-script: ""
trigger-event: ${{ github.event_name }}
Expand Down
2 changes: 1 addition & 1 deletion install-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
fbgemm-gpu-nightly
fbgemm-gpu
torchmetrics==1.0.3
tqdm
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
black
cmake
fbgemm-gpu-nightly
fbgemm-gpu
hypothesis==6.70.1
iopath
numpy
Expand Down
39 changes: 3 additions & 36 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os
import subprocess
import sys
from datetime import date
from pathlib import Path
from typing import List

Expand Down Expand Up @@ -45,17 +44,6 @@ def _export_version(version, sha):
fileobj.write("git_version = {}\n".format(repr(sha)))


def get_channel():
# Channel typically takes on the following values:
# - NIGHTLY: for nightly published binaries
# - TEST: for binaries build from release candidate branches
return os.getenv("CHANNEL")


def get_cu_version():
return os.getenv("CU_VERSION", "cpu")


def parse_args(argv: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="torchrec setup")
return parser.parse_known_args(argv)
Expand All @@ -64,9 +52,6 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
def main(argv: List[str]) -> None:
args, unknown = parse_args(argv)

# Set up package version
channel = get_channel()

with open(
os.path.join(os.path.dirname(__file__), "README.MD"), encoding="utf8"
) as f:
Expand All @@ -81,25 +66,7 @@ def main(argv: List[str]) -> None:
version, sha = _get_version()
_export_version(version, sha)

# if channel != "nightly":
# if "fbgemm-gpu-nightly" in install_requires:
# install_requires.remove("fbgemm-gpu-nightly")
# install_requires.append("fbgemm-gpu")

cu_version = get_cu_version()
if cu_version == "cpu":
# if "fbgemm-gpu-nightly" in install_requires:
# install_requires.remove("fbgemm-gpu-nightly")
# install_requires.append("fbgemm-gpu-nightly-cpu")
# if "fbgemm-gpu" in install_requires:
install_requires.remove("fbgemm-gpu-nightly")
install_requires.append("fbgemm-gpu-cpu==0.5.0rc3")
else:
install_requires.remove("fbgemm-gpu-nightly")
install_requires.append("fbgemm-gpu==0.5.0rc2")


print(f"-- torchrec building version: {version} CU Version: {cu_version}")
print(f"-- torchrec building version: {version}")

packages = find_packages(
exclude=(
Expand All @@ -126,7 +93,7 @@ def main(argv: List[str]) -> None:
url="https://github.com/pytorch/torchrec",
license="BSD-3",
keywords=["pytorch", "recommendation systems", "sharding"],
python_requires=">=3.7",
python_requires=">=3.8",
install_requires=install_requires,
packages=packages,
zip_safe=False,
Expand All @@ -137,7 +104,7 @@ def main(argv: List[str]) -> None:
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
Expand Down

0 comments on commit b0254a7

Please sign in to comment.