Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use per-build variant content hashes in the lockfile #26

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/hf_kernels/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Optional
import os

# Optional override for where kernels are cached. When HF_KERNELS_CACHE is
# unset this stays None, and callers passing it as `cache_dir=` to
# huggingface_hub presumably fall back to the default HF cache — confirm.
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
12 changes: 9 additions & 3 deletions src/hf_kernels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock, get_kernel_locks
from hf_kernels.utils import install_kernel, install_kernel_all_variants
from hf_kernels.utils import build_variant, install_kernel, install_kernel_all_variants


def main():
Expand Down Expand Up @@ -59,10 +59,16 @@ def download_kernels(args):
file=sys.stderr,
)
if args.all_variants:
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
install_kernel_all_variants(
kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
)
else:
try:
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
install_kernel(
kernel_lock.repo_id,
kernel_lock.sha,
variant_locks=kernel_lock.variants,
)
except FileNotFoundError as e:
print(e, file=sys.stderr)
all_successful = False
Expand Down
31 changes: 31 additions & 0 deletions src/hf_kernels/hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Tuple
import hashlib
import os
from pathlib import Path


def content_hash(dir: Path) -> str:
    """Return a stable ``sha256-<hexdigest>`` hash of a directory's contents.

    Only files that are symlinks are hashed (in a Hub snapshot these are the
    links into the blob store); regular files — e.g. bytecode that Python
    creates when importing modules from the cache — are ignored so they do
    not perturb the hash.
    """

    # Get the file paths. The first element is a byte-encoded relative path
    # used for sorting (platform-independent, since it is POSIX-style).
    # The second element is the absolute path.
    paths: List[Tuple[bytes, Path]] = []
    # Ideally we'd use Path.walk, but it's only available in Python 3.12.
    for dirpath, _, filenames in os.walk(dir):
        for filename in filenames:
            file_abs = Path(dirpath) / filename

            # Python likes to create files when importing modules from the
            # cache, so only hash files that are symlinked blobs.
            if file_abs.is_symlink():
                paths.append(
                    (file_abs.relative_to(dir).as_posix().encode("utf-8"), file_abs)
                )

    m = hashlib.sha256()
    for rel_path, path in sorted(paths):
        m.update(rel_path)
        # Read in fixed-size chunks: build artifacts (shared objects) can be
        # large, and f.read() would otherwise hold the whole file in memory.
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                m.update(chunk)

    return f"sha256-{m.hexdigest()}"
55 changes: 26 additions & 29 deletions src/hf_kernels/lockfile.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
from typing import Dict

from huggingface_hub import HfApi
from huggingface_hub import HfApi, snapshot_download
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version

from hf_kernels.compat import tomllib
from hf_kernels.cache import CACHE_DIR
from hf_kernels.hash import content_hash


@dataclass
class FileLock:
filename: str
blob_id: str
class VariantLock:
hash: str
hash_type: str = "recursive_file_hash"


@dataclass
class KernelLock:
repo_id: str
sha: str
files: List[FileLock]
variants: Dict[str, VariantLock]

@classmethod
def from_json(cls, o: Dict):
files = [FileLock(**f) for f in o["files"]]
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
variants = {
variant: VariantLock(**lock) for variant, lock in o["variants"].items()
}
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)


def _get_available_versions(repo_id: str):
Expand Down Expand Up @@ -59,30 +63,23 @@ def get_kernel_locks(repo_id: str, version_spec: str):

tag_for_newest = versions[accepted_versions[-1]]

r = HfApi().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
)
if r.sha is None:
raise ValueError(
f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}"
)

if r.siblings is None:
raise ValueError(
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=tag_for_newest.target_commit,
)
)

file_locks = []
for sibling in r.siblings:
if sibling.rfilename.startswith("build/torch"):
if sibling.blob_id is None:
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")

file_locks.append(
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
)
variant_hashes = {}
for entry in (repo_path / "build").iterdir():
variant = entry.parts[-1]
variant_hashes[variant] = VariantLock(hash=content_hash(entry))

return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
return KernelLock(
repo_id=repo_id, sha=tag_for_newest.target_commit, variants=variant_hashes
)


def write_egg_lockfile(cmd, basename, filename):
Expand Down
73 changes: 57 additions & 16 deletions src/hf_kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import sys
from importlib.metadata import Distribution
from types import ModuleType
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

from huggingface_hub import hf_hub_download, snapshot_download
from packaging.version import parse

from hf_kernels.cache import CACHE_DIR
from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock

CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
from hf_kernels.hash import content_hash
from hf_kernels.lockfile import KernelLock, VariantLock


def build_variant():
Expand Down Expand Up @@ -47,42 +47,83 @@ def import_from_path(module_name: str, file_path):


def install_kernel(
repo_id: str, revision: str, local_files_only: bool = False
repo_id: str,
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, str]:
"""Download a kernel for the current environment to the cache."""
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
"torch"
]["name"]
variant = build_variant()

repo_path = snapshot_download(
repo_id,
allow_patterns=f"build/{build_variant()}/*",
allow_patterns=f"build/{variant}/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)

variant_path = f"{repo_path}/build/{build_variant()}"
variant_path = f"{repo_path}/build/{variant}"

if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
raise ValueError(f"No lock found for build variant: {variant}")

hash = variant_lock.hash

found_hash = content_hash(Path(variant_path))
if found_hash != hash:
raise ValueError(
f"Expected hash {hash} for path {variant_path}, but got: {found_hash}"
)

module_init_path = f"{variant_path}/{package_name}/__init__.py"

if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {build_variant()}"
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)

return package_name, variant_path


def install_kernel_all_variants(
repo_id: str, revision: str, local_files_only: bool = False
):
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
repo_id: str,
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> str:
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
)

for entry in (repo_path / "build").iterdir():
variant = entry.parts[-1]

if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
raise ValueError(f"No lock found for build variant: {variant}")

hash = variant_lock.hash
found_hash = content_hash(entry)
if found_hash != hash:
raise ValueError(
f"Expected hash {hash} for path {entry}, but got: {found_hash}"
)

return f"{repo_path}/build"


def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
with open(
Expand Down
56 changes: 56 additions & 0 deletions tests/hash_validation/hf-kernels.lock
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"variants": {
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-fc465a135358badf9670792fff90b171623cad84989e2397d9edd2c3c2b036f0",
"hash_type": "recursive_file_hash"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-c9aae75c623e20c16c19fc0d6bf7b9b67823d03dd502c80e39d9c0534fbf0fc0",
"hash_type": "recursive_file_hash"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-dcc1141a70845c5ae20b953cbc8496c265b158ebed8d31c037119f286ff62e06",
"hash_type": "recursive_file_hash"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-8ab0577ecf1a96ea6b55b60ed95bc3c02595fdefedf67967b68b93a22c5978d5",
"hash_type": "recursive_file_hash"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-d0cb59cb5cda691f212004e59a89168d5756ae75888be2cd5ca7b470402d5927",
"hash_type": "recursive_file_hash"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-e3f4adcb205fa00c6a16e2c07dd21f037d31165c1b92fdcf70b9038852e2a8ad",
"hash_type": "recursive_file_hash"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-6ebc15cccd9c61eef186ba6d288c72839f7d5a51bf1a46fafec5d673417c040f",
"hash_type": "recursive_file_hash"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-075bf0f2e208ffd76d219c4f6cea80ba5e1ee2fd9f172786b5e681cbf4d494f7",
"hash_type": "recursive_file_hash"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-51edd32659cef0640ee128228c0478e1b7068b4182075a41047ee288bcaa05a0",
"hash_type": "recursive_file_hash"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-6dddf50d68b9f94bdf101571080b74fa50d084a1e3ac10f7e1b4ff959cb2685f",
"hash_type": "recursive_file_hash"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-a9cac68d9722bb809267777f481522efc81980e237207e5b998926f1c5c88108",
"hash_type": "recursive_file_hash"
},
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-cdc3b4142939f1ad8ac18865c83ac69fc0450f46d3dcb5507f2a69b6a292fb70",
"hash_type": "recursive_file_hash"
}
}
}
]
2 changes: 2 additions & 0 deletions tests/hash_validation/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.2"
21 changes: 21 additions & 0 deletions tests/test_hash_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from pathlib import Path

from hf_kernels.cli import download_kernels


# Mock download arguments class: stands in for the argparse namespace that
# cli.download_kernels receives. The function reads `all_variants`;
# `project_dir` presumably locates pyproject.toml / hf-kernels.lock — confirm
# against cli.download_kernels.
@dataclass
class DownloadArgs:
    # True: fetch every build variant; False: only the current environment's.
    all_variants: bool
    # Directory holding the test project (pyproject.toml + hf-kernels.lock).
    project_dir: Path


def test_download_hash_validation():
    # Downloading the current build variant should succeed against the
    # pinned hashes in tests/hash_validation/hf-kernels.lock.
    args = DownloadArgs(
        all_variants=False,
        project_dir=Path(__file__).parent / "hash_validation",
    )
    download_kernels(args)


def test_download_all_hash_validation():
    # Downloading every build variant should validate each variant's hash
    # against the pinned lockfile entries.
    args = DownloadArgs(
        all_variants=True,
        project_dir=Path(__file__).parent / "hash_validation",
    )
    download_kernels(args)