Use per-build variant content hashes in the lockfile
This makes the lockfile a fair bit shorter than per-file hashes. The
per-variant hashes can also be verified without any knowledge of Git objects.
danieldk committed Feb 19, 2025
1 parent df2c165 commit 50fbb4f
Showing 8 changed files with 148 additions and 51 deletions.
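To make the verification claim concrete, here is a minimal sketch of checking one lockfile entry against a locally downloaded build variant without touching any Git metadata. The hf-kernels.lock path, the build/ directory, and the chosen variant name are illustrative assumptions; note also that the content_hash helper added in this commit only hashes symlinked blob files in the Hugging Face cache, whereas this sketch hashes every regular file under the variant directory.

import hashlib
import json
from pathlib import Path


def variant_hash(variant_dir: Path) -> str:
    # sha256 over (relative path, file contents) pairs, sorted by relative path.
    m = hashlib.sha256()
    entries = sorted(
        (p.relative_to(variant_dir).as_posix().encode("utf-8"), p)
        for p in variant_dir.rglob("*")
        if p.is_file()
    )
    for rel_path, path in entries:
        m.update(rel_path)
        m.update(path.read_bytes())
    return f"sha256-{m.hexdigest()}"


lock_entry = json.loads(Path("hf-kernels.lock").read_text())[0]
variant = "torch26-cxx11-cu126-x86_64-linux"
expected = lock_entry["variants"][variant]
actual = variant_hash(Path("build") / variant)
print("hash OK" if actual == expected else f"hash mismatch: {actual}")
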
4 changes: 4 additions & 0 deletions src/hf_kernels/cache.py
@@ -0,0 +1,4 @@
from typing import Optional
import os

CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
14 changes: 10 additions & 4 deletions src/hf_kernels/cli.py
@@ -6,7 +6,7 @@

from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock, get_kernel_locks
from hf_kernels.utils import install_kernel, install_kernel_all_variants
from hf_kernels.utils import build_variant, install_kernel, install_kernel_all_variants


def main():
@@ -53,16 +53,22 @@ def download_kernels(args):
all_successful = True

for kernel_lock_json in lock_json:
kernel_lock = KernelLock.from_json(kernel_lock_json)
kernel_lock = KernelLock(**kernel_lock_json)
print(
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
file=sys.stderr,
)
if args.all_variants:
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
install_kernel_all_variants(
kernel_lock.repo_id, kernel_lock.sha, hashes=kernel_lock.variants
)
else:
try:
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
install_kernel(
kernel_lock.repo_id,
kernel_lock.sha,
hash=kernel_lock.variants[build_variant()],
)
except FileNotFoundError as e:
print(e, file=sys.stderr)
all_successful = False
31 changes: 31 additions & 0 deletions src/hf_kernels/hash.py
@@ -0,0 +1,31 @@
from typing import List, Tuple
import hashlib
import os
from pathlib import Path


def content_hash(dir: Path) -> str:
"""Get a hash of the contents of a directory."""

# Get the file paths. The first element is a byte-encoded relative path
# used for sorting. The second element is the absolute path.
paths: List[Tuple[bytes, Path]] = []
# Ideally we'd use Path.walk, but it's only available in Python 3.12.
for dirpath, _, filenames in os.walk(dir):
for filename in filenames:
file_abs = Path(dirpath) / filename

# Python can create extra files (e.g. bytecode caches) when importing
# modules from the cache, so only hash files that are symlinked blobs.
if file_abs.is_symlink():
paths.append(
(file_abs.relative_to(dir).as_posix().encode("utf-8"), file_abs)
)

m = hashlib.sha256()
for filename, path in sorted(paths):
m.update(filename)
with open(path, "rb") as f:
m.update(f.read())

return f"sha256-{m.hexdigest()}"
54 changes: 19 additions & 35 deletions src/hf_kernels/lockfile.py
@@ -1,30 +1,21 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
from typing import Dict

from huggingface_hub import HfApi
from huggingface_hub import HfApi, snapshot_download
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version

from hf_kernels.compat import tomllib


@dataclass
class FileLock:
filename: str
blob_id: str
from hf_kernels.cache import CACHE_DIR
from hf_kernels.hash import content_hash


@dataclass
class KernelLock:
repo_id: str
sha: str
files: List[FileLock]

@classmethod
def from_json(cls, o: Dict):
files = [FileLock(**f) for f in o["files"]]
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
variants: Dict[str, str]


def _get_available_versions(repo_id: str):
@@ -59,30 +50,23 @@ def get_kernel_locks(repo_id: str, version_spec: str):

tag_for_newest = versions[accepted_versions[-1]]

r = HfApi().repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
)
if r.sha is None:
raise ValueError(
f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}"
)

if r.siblings is None:
raise ValueError(
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=tag_for_newest.target_commit,
)
)

file_locks = []
for sibling in r.siblings:
if sibling.rfilename.startswith("build/torch"):
if sibling.blob_id is None:
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")

file_locks.append(
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
)
variant_hashes = {}
for entry in (repo_path / "build").iterdir():
variant = entry.parts[-1]
variant_hashes[variant] = content_hash(entry)

return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
return KernelLock(
repo_id=repo_id, sha=tag_for_newest.target_commit, variants=variant_hashes
)


def write_egg_lockfile(cmd, basename, filename):
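For reference, a small sketch of what the simplified KernelLock now carries and how it maps onto the lockfile JSON added under tests/hash_validation/ below. Serializing via dataclasses.asdict is an assumption for illustration, not necessarily how the project writes the file; the repo id, SHA, and hash are copied from the test fixture.

import json
from dataclasses import asdict

from hf_kernels.lockfile import KernelLock

lock = KernelLock(
    repo_id="kernels-community/activation",
    sha="6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
    variants={
        "torch26-cxx11-cu126-x86_64-linux": "sha256-a9cac68d9722bb809267777f481522efc81980e237207e5b998926f1c5c88108"
    },
)

# One entry per locked kernel; compare with tests/hash_validation/hf-kernels.lock.
print(json.dumps([asdict(lock)], indent=2))

# With from_json gone, loading goes through the plain dataclass constructor.
restored = KernelLock(**json.loads(json.dumps([asdict(lock)]))[0])
assert restored == lock
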
56 changes: 44 additions & 12 deletions src/hf_kernels/utils.py
@@ -14,11 +14,11 @@
from huggingface_hub import hf_hub_download, snapshot_download
from packaging.version import parse

from hf_kernels.cache import CACHE_DIR
from hf_kernels.compat import tomllib
from hf_kernels.hash import content_hash
from hf_kernels.lockfile import KernelLock

CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)


def build_variant():
import torch
@@ -47,7 +47,10 @@ def import_from_path(module_name: str, file_path):


def install_kernel(
repo_id: str, revision: str, local_files_only: bool = False
repo_id: str,
revision: str,
local_files_only: bool = False,
hash: Optional[str] = None,
) -> Tuple[str, str]:
"""Download a kernel for the current environment to the cache."""
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
@@ -62,6 +65,14 @@
)

variant_path = f"{repo_path}/build/{build_variant()}"

if hash is not None:
found_hash = content_hash(Path(variant_path))
if found_hash != hash:
raise ValueError(
f"Expected hash {hash} for path {variant_path}, but got: {found_hash}"
)

module_init_path = f"{variant_path}/{package_name}/__init__.py"

if not os.path.exists(module_init_path):
Expand All @@ -73,16 +84,37 @@ def install_kernel(


def install_kernel_all_variants(
repo_id: str, revision: str, local_files_only: bool = False
):
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
repo_id: str,
revision: str,
local_files_only: bool = False,
hashes: Optional[dict] = None,
) -> str:
repo_path = Path(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
)

for entry in (repo_path / "build").iterdir():
variant = entry.parts[-1]

if hashes is not None:
hash = hashes.get(variant)
if hash is None:
raise ValueError(f"No hash found for build variant: {variant}")

found_hash = content_hash(entry)
if found_hash != hash:
raise ValueError(
f"Expected hash {hash} for path {entry}, but got: {found_hash}"
)

return f"{repo_path}/build"


def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
with open(
@@ -145,7 +177,7 @@ def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
lock_json = dist.read_text("hf-kernels.lock")
if lock_json is not None:
for kernel_lock_json in json.loads(lock_json):
kernel_lock = KernelLock.from_json(kernel_lock_json)
kernel_lock = KernelLock(**kernel_lock_json)
if kernel_lock.repo_id == repo_id:
return kernel_lock.sha
return None
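A minimal end-to-end sketch of the new verification path, mirroring the cli.py changes above; the project-local hf-kernels.lock path is an assumption, and a hash mismatch surfaces as the ValueError raised in install_kernel / install_kernel_all_variants:

import json
from pathlib import Path

from hf_kernels.lockfile import KernelLock
from hf_kernels.utils import build_variant, install_kernel, install_kernel_all_variants

kernel_lock = KernelLock(**json.loads(Path("hf-kernels.lock").read_text())[0])

# Download only the variant for the current Torch/CUDA build and verify its hash.
install_kernel(
    kernel_lock.repo_id,
    kernel_lock.sha,
    hash=kernel_lock.variants[build_variant()],
)

# Or download every variant and verify each one against the lockfile.
install_kernel_all_variants(
    kernel_lock.repo_id, kernel_lock.sha, hashes=kernel_lock.variants
)
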
20 changes: 20 additions & 0 deletions tests/hash_validation/hf-kernels.lock
@@ -0,0 +1,20 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"variants": {
"torch25-cxx98-cu124-x86_64-linux": "sha256-fc465a135358badf9670792fff90b171623cad84989e2397d9edd2c3c2b036f0",
"torch25-cxx11-cu121-x86_64-linux": "sha256-c9aae75c623e20c16c19fc0d6bf7b9b67823d03dd502c80e39d9c0534fbf0fc0",
"torch26-cxx11-cu118-x86_64-linux": "sha256-dcc1141a70845c5ae20b953cbc8496c265b158ebed8d31c037119f286ff62e06",
"torch25-cxx98-cu118-x86_64-linux": "sha256-8ab0577ecf1a96ea6b55b60ed95bc3c02595fdefedf67967b68b93a22c5978d5",
"torch25-cxx11-cu124-x86_64-linux": "sha256-d0cb59cb5cda691f212004e59a89168d5756ae75888be2cd5ca7b470402d5927",
"torch26-cxx11-cu124-x86_64-linux": "sha256-e3f4adcb205fa00c6a16e2c07dd21f037d31165c1b92fdcf70b9038852e2a8ad",
"torch26-cxx98-cu126-x86_64-linux": "sha256-6ebc15cccd9c61eef186ba6d288c72839f7d5a51bf1a46fafec5d673417c040f",
"torch25-cxx98-cu121-x86_64-linux": "sha256-075bf0f2e208ffd76d219c4f6cea80ba5e1ee2fd9f172786b5e681cbf4d494f7",
"torch26-cxx98-cu118-x86_64-linux": "sha256-51edd32659cef0640ee128228c0478e1b7068b4182075a41047ee288bcaa05a0",
"torch26-cxx98-cu124-x86_64-linux": "sha256-6dddf50d68b9f94bdf101571080b74fa50d084a1e3ac10f7e1b4ff959cb2685f",
"torch26-cxx11-cu126-x86_64-linux": "sha256-a9cac68d9722bb809267777f481522efc81980e237207e5b998926f1c5c88108",
"torch25-cxx11-cu118-x86_64-linux": "sha256-cdc3b4142939f1ad8ac18865c83ac69fc0450f46d3dcb5507f2a69b6a292fb70"
}
}
]
2 changes: 2 additions & 0 deletions tests/hash_validation/pyproject.toml
@@ -0,0 +1,2 @@
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.2"
18 changes: 18 additions & 0 deletions tests/test_hash_validation.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from pathlib import Path

from hf_kernels.cli import download_kernels


# Mock download arguments class.
@dataclass
class DownloadArgs:
all_variants: bool
project_dir: Path


def test_hash_validation():
project_dir = Path(__file__).parent / "hash_validation"

download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
