Skip to content

Commit

Permalink
Add support for locking kernels (#10)
Browse files Browse the repository at this point in the history
* PoC: allow users to lock the kernel revisions

This change allows Python projects that use kernels to lock the
kernel revisions on a project-basis. For this to work, the user
only has to include `hf-kernels` as a build dependency. During
the build, a lock file is written to the package's pkg-info.
During runtime we can read it out and use the corresponding
revision. When the kernel is not locked, the revision that is provided
as an argument is used.

* Generate lock files with `hf-lock-kernels`, copy to egg

* Various improvements

* Name CLI `hf-kernels`, add `download` subcommand

* hf-kernels.lock

* Bump version to 0.1.1

* Use setuptools for testing the wheel

* Factor out tomllib module selection

* Pass through `local_files_only` in `get_metadata`

* Do not reuse implementation in `load_kernel`

* The tests install hf-kernels from PyPI, should be local

* docker: package is in subdirectory
  • Loading branch information
danieldk authored Jan 21, 2025
1 parent 105704b commit 544354c
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ COPY examples ./examples
COPY tests ./tests

# Install the kernel library
RUN uv pip install hf_kernels
RUN uv pip install ./hf_kernels

# Run tests and benchmarks
CMD [".venv/bin/pytest", "tests", "-v"]
24 changes: 15 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
[project]
name = "hf-kernels"
version = "0.1.0"
version = "0.1.1"
description = "Download cuda kernels"
authors = [
{name = "OlivierDehaene", email = "[email protected]"},
{name = "Daniel de Kok", email = "[email protected]"},
{name = "David Holtz", email = "[email protected]"},
{name = "Nicolas Patry", email = "[email protected]"}
{ name = "OlivierDehaene", email = "[email protected]" },
{ name = "Daniel de Kok", email = "[email protected]" },
{ name = "David Holtz", email = "[email protected]" },
{ name = "Nicolas Patry", email = "[email protected]" },
]
readme = "README.md"

[project.scripts]
hf-kernels = "hf_kernels.cli:main"

[project.entry-points."egg_info.writers"]
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"

[dependencies]
python = "^3.9"
huggingface-hub = "^0.26.3"
packaging = "^24.2"
tomli = { version = "^2.0.1", python = "<3.11" }

[build-system]
requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
build-backend = "hf_kernels.build"
backend-path = ["src"]
#[build-system]
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
#build-backend = "hf_kernels.build"
#backend-path = ["src"]
2 changes: 1 addition & 1 deletion src/hf_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from hf_kernels.utils import get_kernel, load_kernel, install_kernel
from hf_kernels.utils import get_kernel, install_kernel, load_kernel

__all__ = ["get_kernel", "load_kernel", "install_kernel"]
7 changes: 1 addition & 6 deletions src/hf_kernels/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,13 @@
them while IDEs and type checker can see through the quotes.
"""

import sys
from hf_kernels.compat import tomllib

TYPE_CHECKING = False
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence # noqa:I001
from typing import Any # noqa:I001

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib


def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
import sys
Expand Down
75 changes: 75 additions & 0 deletions src/hf_kernels/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import argparse
import dataclasses
import json
import sys
from pathlib import Path

from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock, get_kernel_locks
from hf_kernels.utils import install_kernel


def main():
parser = argparse.ArgumentParser(
prog="hf-kernel", description="Manage compute kernels"
)
subparsers = parser.add_subparsers()

download_parser = subparsers.add_parser("download", help="Download locked kernels")
download_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
download_parser.set_defaults(func=download_kernels)

lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument(
"project_dir",
type=Path,
help="The project directory",
)
lock_parser.set_defaults(func=lock_kernels)

args = parser.parse_args()
args.func(args)


def download_kernels(args):
lock_path = args.project_dir / "hf-kernels.lock"

if not lock_path.exists():
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
sys.exit(1)

with open(args.project_dir / "hf-kernels.lock", "r") as f:
lock_json = json.load(f)

for kernel_lock_json in lock_json:
kernel_lock = KernelLock.from_json(kernel_lock_json)
print(
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
file=sys.stderr,
)
install_kernel(kernel_lock.repo_id, kernel_lock.sha)


def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)

kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)

all_locks = []
for kernel, version in kernel_versions.items():
all_locks.append(get_kernel_locks(kernel, version))

with open(args.project_dir / "hf-kernels.lock", "w") as f:
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)


class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
8 changes: 8 additions & 0 deletions src/hf_kernels/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import sys

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib

__all__ = ["tomllib"]
72 changes: 72 additions & 0 deletions src/hf_kernels/lockfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

from huggingface_hub import HfApi

from hf_kernels.compat import tomllib


@dataclass
class FileLock:
filename: str
blob_id: str


@dataclass
class KernelLock:
repo_id: str
sha: str
files: List[FileLock]

@classmethod
def from_json(cls, o: Dict):
files = [FileLock(**f) for f in o["files"]]
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)


def get_kernel_locks(repo_id: str, revision: str):
r = HfApi().repo_info(repo_id=repo_id, revision=revision, files_metadata=True)
if r.sha is None:
raise ValueError(
f"Cannot get commit SHA for repo {repo_id} at revision {revision}"
)

if r.siblings is None:
raise ValueError(
f"Cannot get sibling information for {repo_id} at revision {revision}"
)

file_locks = []
for sibling in r.siblings:
if sibling.rfilename.startswith("build/torch"):
if sibling.blob_id is None:
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")

file_locks.append(
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
)

return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)


def write_egg_lockfile(cmd, basename, filename):
import logging

cwd = Path.cwd()
with open(cwd / "pyproject.toml", "rb") as f:
data = tomllib.load(f)

kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
if kernel_versions is None:
return

lock_path = cwd / "hf-kernels.lock"
if not lock_path.exists():
logging.warning(f"Lock file {lock_path} does not exist")
# Ensure that the file gets deleted in editable installs.
data = None
else:
data = open(lock_path, "r").read()

cmd.write_or_delete_file(basename, filename, data)
95 changes: 83 additions & 12 deletions src/hf_kernels/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import importlib
import importlib.metadata
import inspect
import json
import os
import platform
import sys
import os
from importlib.metadata import Distribution
from types import ModuleType
from typing import List, Optional

import torch
from huggingface_hub import hf_hub_download, snapshot_download
from packaging.version import parse

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock


def build_variant():
Expand All @@ -31,16 +35,26 @@ def import_from_path(module_name: str, file_path):
return module


def install_kernel(repo_id: str, revision: str):
package_name = get_metadata(repo_id)["torch"]["name"]
def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
"torch"
]["name"]
repo_path = snapshot_download(
repo_id, allow_patterns=f"build/{build_variant()}/*", revision=revision
repo_id,
allow_patterns=f"build/{build_variant()}/*",
revision=revision,
local_files_only=local_files_only,
)
return package_name, f"{repo_path}/build/{build_variant()}"


def get_metadata(repo_id: str):
with open(hf_hub_download(repo_id, "build.toml"), "rb") as f:
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
with open(
hf_hub_download(
repo_id, "build.toml", revision=revision, local_files_only=local_files_only
),
"rb",
) as f:
return tomllib.load(f)


Expand All @@ -49,13 +63,70 @@ def get_kernel(repo_id: str, revision: str = "main"):
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def load_kernel(repo_id: str, revision: str = "main"):
def load_kernel(repo_id: str):
"""Get a pre-downloaded, locked kernel."""
locked_sha = _get_caller_locked_kernel(repo_id)

if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")

filename = hf_hub_download(
repo_id, "build.toml", local_files_only=True, revision=revision
repo_id, "build.toml", local_files_only=True, revision=locked_sha
)
with open(filename, "rb") as f:
metadata = tomllib.load(f)
package_name = metadata["torch"]["name"]

repo_path = os.path.dirname(filename)
package_path = f"{repo_path}/build/{build_variant()}"
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def get_locked_kernel(repo_id: str, local_files_only: bool = False):
"""Get a kernel using a lock file."""
locked_sha = _get_caller_locked_kernel(repo_id)

if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")

package_name, package_path = install_kernel(
repo_id, locked_sha, local_files_only=local_files_only
)

return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
for dist in _get_caller_distributions():
lock_json = dist.read_text("hf-kernels.lock")
if lock_json is not None:
for kernel_lock_json in json.loads(lock_json):
kernel_lock = KernelLock.from_json(kernel_lock_json)
if kernel_lock.repo_id == repo_id:
return kernel_lock.sha
return None


def _get_caller_distributions() -> List[Distribution]:
module = _get_caller_module()
if module is None:
return []

# Look up all possible distributions that this module could be from.
package = module.__name__.split(".")[0]
dist_names = importlib.metadata.packages_distributions().get(package)
if dist_names is None:
return []

return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]


def _get_caller_module() -> Optional[ModuleType]:
stack = inspect.stack()
# Get first module in the stack that is not the current module.
first_module = inspect.getmodule(stack[0][0])
for frame in stack[1:]:
module = inspect.getmodule(frame[0])
if module is not None and module != first_module:
return module
return first_module

0 comments on commit 544354c

Please sign in to comment.