14
14
from huggingface_hub import hf_hub_download , snapshot_download
15
15
from packaging .version import parse
16
16
17
+ from hf_kernels .cache import CACHE_DIR
17
18
from hf_kernels .compat import tomllib
19
+ from hf_kernels .hash import content_hash
18
20
from hf_kernels .lockfile import KernelLock
19
21
20
- CACHE_DIR : Optional [str ] = os .environ .get ("HF_KERNELS_CACHE" , None )
21
-
22
22
23
23
def build_variant ():
24
24
import torch
@@ -47,7 +47,10 @@ def import_from_path(module_name: str, file_path):
47
47
48
48
49
49
def install_kernel (
50
- repo_id : str , revision : str , local_files_only : bool = False
50
+ repo_id : str ,
51
+ revision : str ,
52
+ local_files_only : bool = False ,
53
+ hash : Optional [str ] = None ,
51
54
) -> Tuple [str , str ]:
52
55
"""Download a kernel for the current environment to the cache."""
53
56
package_name = get_metadata (repo_id , revision , local_files_only = local_files_only )[
@@ -62,6 +65,14 @@ def install_kernel(
62
65
)
63
66
64
67
variant_path = f"{ repo_path } /build/{ build_variant ()} "
68
+
69
+ if hash is not None :
70
+ found_hash = content_hash (Path (variant_path ))
71
+ if found_hash != hash :
72
+ raise ValueError (
73
+ f"Expected hash { hash } for path { variant_path } , but got: { found_hash } "
74
+ )
75
+
65
76
module_init_path = f"{ variant_path } /{ package_name } /__init__.py"
66
77
67
78
if not os .path .exists (module_init_path ):
@@ -73,16 +84,37 @@ def install_kernel(
73
84
74
85
75
86
def install_kernel_all_variants (
76
- repo_id : str , revision : str , local_files_only : bool = False
77
- ):
78
- snapshot_download (
79
- repo_id ,
80
- allow_patterns = "build/*" ,
81
- cache_dir = CACHE_DIR ,
82
- revision = revision ,
83
- local_files_only = local_files_only ,
87
+ repo_id : str ,
88
+ revision : str ,
89
+ local_files_only : bool = False ,
90
+ hashes : Optional [dict ] = None ,
91
+ ) -> str :
92
+ repo_path = Path (
93
+ snapshot_download (
94
+ repo_id ,
95
+ allow_patterns = "build/*" ,
96
+ cache_dir = CACHE_DIR ,
97
+ revision = revision ,
98
+ local_files_only = local_files_only ,
99
+ )
84
100
)
85
101
102
+ for entry in (repo_path / "build" ).iterdir ():
103
+ variant = entry .parts [- 1 ]
104
+
105
+ if hashes is not None :
106
+ hash = hashes .get (variant )
107
+ if hash is None :
108
+ raise ValueError (f"No hash found for build variant: { variant } " )
109
+
110
+ found_hash = content_hash (entry )
111
+ if found_hash != hash :
112
+ raise ValueError (
113
+ f"Expected hash { hash } for path { entry } , but got: { found_hash } "
114
+ )
115
+
116
+ return f"{ repo_path } /build"
117
+
86
118
87
119
def get_metadata (repo_id : str , revision : str , local_files_only : bool = False ):
88
120
with open (
@@ -145,7 +177,7 @@ def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
145
177
lock_json = dist .read_text ("hf-kernels.lock" )
146
178
if lock_json is not None :
147
179
for kernel_lock_json in json .loads (lock_json ):
148
- kernel_lock = KernelLock . from_json ( kernel_lock_json )
180
+ kernel_lock = KernelLock ( ** kernel_lock_json )
149
181
if kernel_lock .repo_id == repo_id :
150
182
return kernel_lock .sha
151
183
return None
0 commit comments