Skip to content

Commit 50fbb4f

Browse files
committed
Use per-build variant content hashes in the lockfile
This makes the lock file a fair bit shorter than per-file hashes. They can also be verified without any knowledge of Git objects, etc.
1 parent df2c165 commit 50fbb4f

File tree

8 files changed

+148
-51
lines changed

8 files changed

+148
-51
lines changed

src/hf_kernels/cache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from typing import Optional
2+
import os
3+
4+
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)

src/hf_kernels/cli.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from hf_kernels.compat import tomllib
88
from hf_kernels.lockfile import KernelLock, get_kernel_locks
9-
from hf_kernels.utils import install_kernel, install_kernel_all_variants
9+
from hf_kernels.utils import build_variant, install_kernel, install_kernel_all_variants
1010

1111

1212
def main():
@@ -53,16 +53,22 @@ def download_kernels(args):
5353
all_successful = True
5454

5555
for kernel_lock_json in lock_json:
56-
kernel_lock = KernelLock.from_json(kernel_lock_json)
56+
kernel_lock = KernelLock(**kernel_lock_json)
5757
print(
5858
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
5959
file=sys.stderr,
6060
)
6161
if args.all_variants:
62-
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
62+
install_kernel_all_variants(
63+
kernel_lock.repo_id, kernel_lock.sha, hashes=kernel_lock.variants
64+
)
6365
else:
6466
try:
65-
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
67+
install_kernel(
68+
kernel_lock.repo_id,
69+
kernel_lock.sha,
70+
hash=kernel_lock.variants[build_variant()],
71+
)
6672
except FileNotFoundError as e:
6773
print(e, file=sys.stderr)
6874
all_successful = False

src/hf_kernels/hash.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import List, Tuple
2+
import hashlib
3+
import os
4+
from pathlib import Path
5+
6+
7+
def content_hash(dir: Path) -> str:
8+
"""Get a hash of the contents of a directory."""
9+
10+
# Get the file paths. The first element is a byte-encoded relative path
11+
# used for sorting. The second element is the absolute path.
12+
paths: List[Tuple[bytes, Path]] = []
13+
# Ideally we'd use Path.walk, but it's only available in Python 3.12.
14+
for dirpath, _, filenames in os.walk(dir):
15+
for filename in filenames:
16+
file_abs = Path(dirpath) / filename
17+
18+
# Python likes to create files when importing modules from the
19+
# cache, only hash files that are symlinked blobs.
20+
if file_abs.is_symlink():
21+
paths.append(
22+
(file_abs.relative_to(dir).as_posix().encode("utf-8"), file_abs)
23+
)
24+
25+
m = hashlib.sha256()
26+
for filename, path in sorted(paths):
27+
m.update(filename)
28+
with open(path, "rb") as f:
29+
m.update(f.read())
30+
31+
return f"sha256-{m.hexdigest()}"

src/hf_kernels/lockfile.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,21 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3-
from typing import Dict, List
3+
from typing import Dict
44

5-
from huggingface_hub import HfApi
5+
from huggingface_hub import HfApi, snapshot_download
66
from packaging.specifiers import SpecifierSet
77
from packaging.version import InvalidVersion, Version
88

99
from hf_kernels.compat import tomllib
10-
11-
12-
@dataclass
13-
class FileLock:
14-
filename: str
15-
blob_id: str
10+
from hf_kernels.cache import CACHE_DIR
11+
from hf_kernels.hash import content_hash
1612

1713

1814
@dataclass
1915
class KernelLock:
2016
repo_id: str
2117
sha: str
22-
files: List[FileLock]
23-
24-
@classmethod
25-
def from_json(cls, o: Dict):
26-
files = [FileLock(**f) for f in o["files"]]
27-
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
18+
variants: Dict[str, str]
2819

2920

3021
def _get_available_versions(repo_id: str):
@@ -59,30 +50,23 @@ def get_kernel_locks(repo_id: str, version_spec: str):
5950

6051
tag_for_newest = versions[accepted_versions[-1]]
6152

62-
r = HfApi().repo_info(
63-
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
64-
)
65-
if r.sha is None:
66-
raise ValueError(
67-
f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}"
68-
)
69-
70-
if r.siblings is None:
71-
raise ValueError(
72-
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
53+
repo_path = Path(
54+
snapshot_download(
55+
repo_id,
56+
allow_patterns="build/*",
57+
cache_dir=CACHE_DIR,
58+
revision=tag_for_newest.target_commit,
7359
)
60+
)
7461

75-
file_locks = []
76-
for sibling in r.siblings:
77-
if sibling.rfilename.startswith("build/torch"):
78-
if sibling.blob_id is None:
79-
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
80-
81-
file_locks.append(
82-
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
83-
)
62+
variant_hashes = {}
63+
for entry in (repo_path / "build").iterdir():
64+
variant = entry.parts[-1]
65+
variant_hashes[variant] = content_hash(entry)
8466

85-
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
67+
return KernelLock(
68+
repo_id=repo_id, sha=tag_for_newest.target_commit, variants=variant_hashes
69+
)
8670

8771

8872
def write_egg_lockfile(cmd, basename, filename):

src/hf_kernels/utils.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from huggingface_hub import hf_hub_download, snapshot_download
1515
from packaging.version import parse
1616

17+
from hf_kernels.cache import CACHE_DIR
1718
from hf_kernels.compat import tomllib
19+
from hf_kernels.hash import content_hash
1820
from hf_kernels.lockfile import KernelLock
1921

20-
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
21-
2222

2323
def build_variant():
2424
import torch
@@ -47,7 +47,10 @@ def import_from_path(module_name: str, file_path):
4747

4848

4949
def install_kernel(
50-
repo_id: str, revision: str, local_files_only: bool = False
50+
repo_id: str,
51+
revision: str,
52+
local_files_only: bool = False,
53+
hash: Optional[str] = None,
5154
) -> Tuple[str, str]:
5255
"""Download a kernel for the current environment to the cache."""
5356
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
@@ -62,6 +65,14 @@ def install_kernel(
6265
)
6366

6467
variant_path = f"{repo_path}/build/{build_variant()}"
68+
69+
if hash is not None:
70+
found_hash = content_hash(Path(variant_path))
71+
if found_hash != hash:
72+
raise ValueError(
73+
f"Expected hash {hash} for path {variant_path}, but got: {found_hash}"
74+
)
75+
6576
module_init_path = f"{variant_path}/{package_name}/__init__.py"
6677

6778
if not os.path.exists(module_init_path):
@@ -73,16 +84,37 @@ def install_kernel(
7384

7485

7586
def install_kernel_all_variants(
76-
repo_id: str, revision: str, local_files_only: bool = False
77-
):
78-
snapshot_download(
79-
repo_id,
80-
allow_patterns="build/*",
81-
cache_dir=CACHE_DIR,
82-
revision=revision,
83-
local_files_only=local_files_only,
87+
repo_id: str,
88+
revision: str,
89+
local_files_only: bool = False,
90+
hashes: Optional[dict] = None,
91+
) -> str:
92+
repo_path = Path(
93+
snapshot_download(
94+
repo_id,
95+
allow_patterns="build/*",
96+
cache_dir=CACHE_DIR,
97+
revision=revision,
98+
local_files_only=local_files_only,
99+
)
84100
)
85101

102+
for entry in (repo_path / "build").iterdir():
103+
variant = entry.parts[-1]
104+
105+
if hashes is not None:
106+
hash = hashes.get(variant)
107+
if hash is None:
108+
raise ValueError(f"No hash found for build variant: {variant}")
109+
110+
found_hash = content_hash(entry)
111+
if found_hash != hash:
112+
raise ValueError(
113+
f"Expected hash {hash} for path {entry}, but got: {found_hash}"
114+
)
115+
116+
return f"{repo_path}/build"
117+
86118

87119
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
88120
with open(
@@ -145,7 +177,7 @@ def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
145177
lock_json = dist.read_text("hf-kernels.lock")
146178
if lock_json is not None:
147179
for kernel_lock_json in json.loads(lock_json):
148-
kernel_lock = KernelLock.from_json(kernel_lock_json)
180+
kernel_lock = KernelLock(**kernel_lock_json)
149181
if kernel_lock.repo_id == repo_id:
150182
return kernel_lock.sha
151183
return None

tests/hash_validation/hf-kernels.lock

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[
2+
{
3+
"repo_id": "kernels-community/activation",
4+
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
5+
"variants": {
6+
"torch25-cxx98-cu124-x86_64-linux": "sha256-fc465a135358badf9670792fff90b171623cad84989e2397d9edd2c3c2b036f0",
7+
"torch25-cxx11-cu121-x86_64-linux": "sha256-c9aae75c623e20c16c19fc0d6bf7b9b67823d03dd502c80e39d9c0534fbf0fc0",
8+
"torch26-cxx11-cu118-x86_64-linux": "sha256-dcc1141a70845c5ae20b953cbc8496c265b158ebed8d31c037119f286ff62e06",
9+
"torch25-cxx98-cu118-x86_64-linux": "sha256-8ab0577ecf1a96ea6b55b60ed95bc3c02595fdefedf67967b68b93a22c5978d5",
10+
"torch25-cxx11-cu124-x86_64-linux": "sha256-d0cb59cb5cda691f212004e59a89168d5756ae75888be2cd5ca7b470402d5927",
11+
"torch26-cxx11-cu124-x86_64-linux": "sha256-e3f4adcb205fa00c6a16e2c07dd21f037d31165c1b92fdcf70b9038852e2a8ad",
12+
"torch26-cxx98-cu126-x86_64-linux": "sha256-6ebc15cccd9c61eef186ba6d288c72839f7d5a51bf1a46fafec5d673417c040f",
13+
"torch25-cxx98-cu121-x86_64-linux": "sha256-075bf0f2e208ffd76d219c4f6cea80ba5e1ee2fd9f172786b5e681cbf4d494f7",
14+
"torch26-cxx98-cu118-x86_64-linux": "sha256-51edd32659cef0640ee128228c0478e1b7068b4182075a41047ee288bcaa05a0",
15+
"torch26-cxx98-cu124-x86_64-linux": "sha256-6dddf50d68b9f94bdf101571080b74fa50d084a1e3ac10f7e1b4ff959cb2685f",
16+
"torch26-cxx11-cu126-x86_64-linux": "sha256-a9cac68d9722bb809267777f481522efc81980e237207e5b998926f1c5c88108",
17+
"torch25-cxx11-cu118-x86_64-linux": "sha256-cdc3b4142939f1ad8ac18865c83ac69fc0450f46d3dcb5507f2a69b6a292fb70"
18+
}
19+
}
20+
]

tests/hash_validation/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[tool.kernels.dependencies]
2+
"kernels-community/activation" = ">=0.0.2"

tests/test_hash_validation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from dataclasses import dataclass
2+
from pathlib import Path
3+
4+
from hf_kernels.cli import download_kernels
5+
6+
7+
# Mock download arguments class.
8+
@dataclass
9+
class DownloadArgs:
10+
all_variants: bool
11+
project_dir: Path
12+
13+
14+
def test_hash_validation():
15+
project_dir = Path(__file__).parent / "hash_validation"
16+
17+
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
18+
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))

0 commit comments

Comments
 (0)