-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for locking kernels (#10)
* PoC: allow users to lock the kernel revisions This change allows Python projects that use kernels to lock the kernel revisions on a project-basis. For this to work, the user only has to include `hf-kernels` as a build dependency. During the build, a lock file is written to the package's pkg-info. During runtime we can read it out and use the corresponding revision. When the kernel is not locked, the revision that is provided as an argument is used. * Generate lock files with `hf-lock-kernels`, copy to egg * Various improvements * Name CLI `hf-kernels`, add `download` subcommand * hf-kernels.lock * Bump version to 0.1.1 * Use setuptools for testing the wheel * Factor out tomllib module selection * Pass through `local_files_only` in `get_metadata` * Do not reuse implementation in `load_kernel` * The tests install hf-kernels from PyPI, should be local * docker: package is in subdirectory
- Loading branch information
Showing
8 changed files
with
256 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,28 @@ | ||
[project] | ||
name = "hf-kernels" | ||
version = "0.1.0" | ||
version = "0.1.1" | ||
description = "Download cuda kernels" | ||
authors = [ | ||
{name = "OlivierDehaene", email = "[email protected]"}, | ||
{name = "Daniel de Kok", email = "[email protected]"}, | ||
{name = "David Holtz", email = "[email protected]"}, | ||
{name = "Nicolas Patry", email = "[email protected]"} | ||
{ name = "OlivierDehaene", email = "[email protected]" }, | ||
{ name = "Daniel de Kok", email = "[email protected]" }, | ||
{ name = "David Holtz", email = "[email protected]" }, | ||
{ name = "Nicolas Patry", email = "[email protected]" }, | ||
] | ||
readme = "README.md" | ||
|
||
[project.scripts] | ||
hf-kernels = "hf_kernels.cli:main" | ||
|
||
[project.entry-points."egg_info.writers"] | ||
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile" | ||
|
||
[dependencies] | ||
python = "^3.9" | ||
huggingface-hub = "^0.26.3" | ||
packaging = "^24.2" | ||
tomli = { version = "^2.0.1", python = "<3.11" } | ||
|
||
[build-system] | ||
requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"] | ||
build-backend = "hf_kernels.build" | ||
backend-path = ["src"] | ||
#[build-system] | ||
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"] | ||
#build-backend = "hf_kernels.build" | ||
#backend-path = ["src"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from hf_kernels.utils import get_kernel, load_kernel, install_kernel | ||
from hf_kernels.utils import get_kernel, install_kernel, load_kernel | ||
|
||
__all__ = ["get_kernel", "load_kernel", "install_kernel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import argparse | ||
import dataclasses | ||
import json | ||
import sys | ||
from pathlib import Path | ||
|
||
from hf_kernels.compat import tomllib | ||
from hf_kernels.lockfile import KernelLock, get_kernel_locks | ||
from hf_kernels.utils import install_kernel | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser( | ||
prog="hf-kernel", description="Manage compute kernels" | ||
) | ||
subparsers = parser.add_subparsers() | ||
|
||
download_parser = subparsers.add_parser("download", help="Download locked kernels") | ||
download_parser.add_argument( | ||
"project_dir", | ||
type=Path, | ||
help="The project directory", | ||
) | ||
download_parser.set_defaults(func=download_kernels) | ||
|
||
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") | ||
lock_parser.add_argument( | ||
"project_dir", | ||
type=Path, | ||
help="The project directory", | ||
) | ||
lock_parser.set_defaults(func=lock_kernels) | ||
|
||
args = parser.parse_args() | ||
args.func(args) | ||
|
||
|
||
def download_kernels(args): | ||
lock_path = args.project_dir / "hf-kernels.lock" | ||
|
||
if not lock_path.exists(): | ||
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr) | ||
sys.exit(1) | ||
|
||
with open(args.project_dir / "hf-kernels.lock", "r") as f: | ||
lock_json = json.load(f) | ||
|
||
for kernel_lock_json in lock_json: | ||
kernel_lock = KernelLock.from_json(kernel_lock_json) | ||
print( | ||
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}", | ||
file=sys.stderr, | ||
) | ||
install_kernel(kernel_lock.repo_id, kernel_lock.sha) | ||
|
||
|
||
def lock_kernels(args): | ||
with open(args.project_dir / "pyproject.toml", "rb") as f: | ||
data = tomllib.load(f) | ||
|
||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None) | ||
|
||
all_locks = [] | ||
for kernel, version in kernel_versions.items(): | ||
all_locks.append(get_kernel_locks(kernel, version)) | ||
|
||
with open(args.project_dir / "hf-kernels.lock", "w") as f: | ||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||
|
||
|
||
class _JSONEncoder(json.JSONEncoder): | ||
def default(self, o): | ||
if dataclasses.is_dataclass(o): | ||
return dataclasses.asdict(o) | ||
return super().default(o) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import sys | ||
|
||
if sys.version_info >= (3, 11): | ||
import tomllib | ||
else: | ||
import tomli as tomllib | ||
|
||
__all__ = ["tomllib"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Dict, List | ||
|
||
from huggingface_hub import HfApi | ||
|
||
from hf_kernels.compat import tomllib | ||
|
||
|
||
@dataclass | ||
class FileLock: | ||
filename: str | ||
blob_id: str | ||
|
||
|
||
@dataclass | ||
class KernelLock: | ||
repo_id: str | ||
sha: str | ||
files: List[FileLock] | ||
|
||
@classmethod | ||
def from_json(cls, o: Dict): | ||
files = [FileLock(**f) for f in o["files"]] | ||
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files) | ||
|
||
|
||
def get_kernel_locks(repo_id: str, revision: str): | ||
r = HfApi().repo_info(repo_id=repo_id, revision=revision, files_metadata=True) | ||
if r.sha is None: | ||
raise ValueError( | ||
f"Cannot get commit SHA for repo {repo_id} at revision {revision}" | ||
) | ||
|
||
if r.siblings is None: | ||
raise ValueError( | ||
f"Cannot get sibling information for {repo_id} at revision {revision}" | ||
) | ||
|
||
file_locks = [] | ||
for sibling in r.siblings: | ||
if sibling.rfilename.startswith("build/torch"): | ||
if sibling.blob_id is None: | ||
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}") | ||
|
||
file_locks.append( | ||
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id) | ||
) | ||
|
||
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks) | ||
|
||
|
||
def write_egg_lockfile(cmd, basename, filename): | ||
import logging | ||
|
||
cwd = Path.cwd() | ||
with open(cwd / "pyproject.toml", "rb") as f: | ||
data = tomllib.load(f) | ||
|
||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None) | ||
if kernel_versions is None: | ||
return | ||
|
||
lock_path = cwd / "hf-kernels.lock" | ||
if not lock_path.exists(): | ||
logging.warning(f"Lock file {lock_path} does not exist") | ||
# Ensure that the file gets deleted in editable installs. | ||
data = None | ||
else: | ||
data = open(lock_path, "r").read() | ||
|
||
cmd.write_or_delete_file(basename, filename, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters