Skip to content

Commit 569b0d5

Browse files
Merge pull request #58 from RosettaCommons/feat/multiple-checkpoint-directories
Add multiple checkpoint directories optioning
2 parents 798fd9c + 0a8348f commit 569b0d5

File tree

5 files changed

+117
-61
lines changed

5 files changed

+117
-61
lines changed

.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ COLABFOLD_NET_DB_PATH_CPU=
6161
# Foundry install dir for checkpoints
6262
# Commented out by default since otherwise may be overridden by user export (load_dotenv(override=True) used at the moment)
6363
# TODO: Ensure override=False can be used.
64-
# FOUNDRY_CHECKPOINTS_DIR=
64+
FOUNDRY_CHECKPOINT_DIRS=

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ pip install rc-foundry[all]
1515
```
1616
foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
1717
```
18-
where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Once installed, foundry sets the env `FOUNDRY_CHECKPOINTS_DIR` which it will use during inference or subsequent commands to find the checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`).
18+
where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Foundry always searches `~/.foundry/checkpoints` plus any colon-separated entries in `$FOUNDRY_CHECKPOINT_DIRS` during inference or subsequent commands to find checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`).
1919
To list the registry of available checkpoints:
2020
```
2121
foundry list-available
2222
```
23-
To check what you already have downloaded (defaults to `$FOUNDRY_CHECKPOINTS_DIR` if set):
23+
To check what you already have downloaded (searches `~/.foundry/checkpoints` plus `$FOUNDRY_CHECKPOINT_DIRS` if set):
2424
```
2525
foundry list-installed
2626
```

models/rfd3/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pip install rc-foundry[rfd3]
2222
```bash
2323
foundry install rfd3 --checkpoint-dir <path/to/ckpt/dir>
2424
```
25-
This sets `FOUNDRY_CHECKPOINTS_DIR` and will in future look for checkpoints in that directory, allowing you to run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset.
25+
This sets `FOUNDRY_CHECKPOINT_DIRS` and will in future look for checkpoints in that directory (alongside the default `~/.foundry/checkpoints` location), allowing you to run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset.
2626

2727
## Running Inference
2828

src/foundry/inference_engines/checkpoint_registry.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,62 @@
33
import os
44
from dataclasses import dataclass
55
from pathlib import Path
6+
from typing import Iterable, List
67

8+
import dotenv
79

8-
def get_default_checkpoint_dir() -> Path:
9-
"""Get the default checkpoint directory.
10+
DEFAULT_CHECKPOINT_DIR = Path.home() / ".foundry" / "checkpoints"
11+
12+
13+
def _normalize_paths(paths: Iterable[Path]) -> list[Path]:
14+
"""Return absolute, deduplicated paths in order."""
15+
seen = set()
16+
normalized: List[Path] = []
17+
for path in paths:
18+
resolved = path.expanduser().absolute()
19+
if resolved not in seen:
20+
normalized.append(resolved)
21+
seen.add(resolved)
22+
return normalized
1023

11-
Priority:
12-
1. FOUNDRY_CHECKPOINTS_DIR environment variable
13-
2. ~/.foundry/checkpoints
24+
25+
def get_default_checkpoint_dirs() -> list[Path]:
26+
"""Return checkpoint search paths.
27+
28+
Always starts with the default ~/.foundry/checkpoints directory and then
29+
appends any additional directories from the colon-separated
30+
FOUNDRY_CHECKPOINT_DIRS environment variable.
1431
"""
15-
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
16-
"FOUNDRY_CHECKPOINTS_DIR"
17-
):
18-
return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
19-
return Path.home() / ".foundry" / "checkpoints"
32+
env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS", "")
33+
34+
# For backward compatibility, also check FOUNDRY_CHECKPOINTS_DIR
35+
if not env_dirs:
36+
env_dirs = os.environ.get("FOUNDRY_CHECKPOINTS_DIR", "")
37+
38+
extra_dirs: list[Path] = []
39+
if env_dirs:
40+
extra_dirs = [Path(p.strip()) for p in env_dirs.split(":") if p.strip()]
41+
return _normalize_paths([*extra_dirs, DEFAULT_CHECKPOINT_DIR])
42+
43+
44+
def get_default_checkpoint_dir() -> Path:
45+
"""Backward-compatible helper returning the primary checkpoint directory."""
46+
return get_default_checkpoint_dirs()[0]
47+
48+
49+
def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> bool:
50+
dotenv_path = dotenv.find_dotenv()
51+
if dotenv_path:
52+
checkpoint_dirs = _normalize_paths(checkpoint_dirs)
53+
dotenv.set_key(
54+
dotenv_path=dotenv_path,
55+
key_to_set="FOUNDRY_CHECKPOINT_DIRS",
56+
value_to_set=":".join(str(path) for path in checkpoint_dirs),
57+
export=False,
58+
)
59+
return True
60+
else:
61+
return False
2062

2163

2264
@dataclass
@@ -27,7 +69,12 @@ class RegisteredCheckpoint:
2769
sha256: None = None # Optional: add checksum for verification
2870

2971
def get_default_path(self):
30-
return get_default_checkpoint_dir() / self.filename
72+
checkpoint_dirs = get_default_checkpoint_dirs()
73+
for checkpoint_dir in checkpoint_dirs:
74+
candidate = checkpoint_dir / self.filename
75+
if candidate.exists():
76+
return candidate
77+
return checkpoint_dirs[0] / self.filename
3178

3279

3380
REGISTERED_CHECKPOINTS = {

src/foundry_cli/download_checkpoints.py

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from urllib.request import urlopen
77

88
import typer
9-
from dotenv import find_dotenv, load_dotenv, set_key
9+
from dotenv import load_dotenv
1010
from rich.console import Console
1111
from rich.progress import (
1212
BarColumn,
@@ -20,7 +20,8 @@
2020

2121
from foundry.inference_engines.checkpoint_registry import (
2222
REGISTERED_CHECKPOINTS,
23-
get_default_checkpoint_dir,
23+
append_checkpoint_to_env,
24+
get_default_checkpoint_dirs,
2425
)
2526

2627
load_dotenv(override=True)
@@ -29,11 +30,25 @@
2930
console = Console()
3031

3132

32-
def _resolve_checkpoint_dir(checkpoint_dir: Optional[Path]) -> Path:
33-
"""Return user-specified checkpoint dir or fall back to default."""
34-
return (
35-
checkpoint_dir if checkpoint_dir is not None else get_default_checkpoint_dir()
36-
)
33+
def _resolve_checkpoint_dirs(checkpoint_dir: Optional[Path]) -> list[Path]:
34+
"""Return checkpoint search path with defaults first."""
35+
checkpoint_dirs = get_default_checkpoint_dirs()
36+
if checkpoint_dir is not None:
37+
resolved = checkpoint_dir.expanduser().absolute()
38+
if resolved not in checkpoint_dirs:
39+
checkpoint_dirs.insert(0, resolved)
40+
else:
41+
# Move to front
42+
checkpoint_dirs.remove(resolved)
43+
checkpoint_dirs.insert(0, resolved)
44+
45+
# Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
46+
if append_checkpoint_to_env(checkpoint_dirs):
47+
console.print(
48+
f"Tracked checkpoint directories: {':'.join(str(path) for path in checkpoint_dirs)}"
49+
)
50+
51+
return checkpoint_dirs
3752

3853

3954
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
@@ -136,7 +151,7 @@ def install(
136151
None,
137152
"--checkpoint-dir",
138153
"-d",
139-
help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)",
154+
help="Directory to save checkpoints (default search path: ~/.foundry/checkpoints plus any $FOUNDRY_CHECKPOINT_DIRS entries)",
140155
),
141156
force: bool = typer.Option(
142157
False, "--force", "-f", help="Overwrite existing checkpoints"
@@ -149,10 +164,10 @@ def install(
149164
foundry install proteinmpnn --checkpoint-dir ./checkpoints
150165
"""
151166
# Determine checkpoint directory
152-
checkpoint_dir = _resolve_checkpoint_dir(checkpoint_dir)
167+
checkpoint_dirs = _resolve_checkpoint_dirs(checkpoint_dir)
168+
primary_checkpoint_dir = checkpoint_dirs[0]
153169

154-
console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}")
155-
console.print()
170+
console.print(f"[bold]Install target:[/bold] {primary_checkpoint_dir}\n")
156171

157172
# Expand 'all' to all available models
158173
if "all" in models:
@@ -164,20 +179,9 @@ def install(
164179

165180
# Install each model
166181
for model_name in models_to_install:
167-
install_model(model_name, checkpoint_dir, force)
182+
install_model(model_name, primary_checkpoint_dir, force)
168183
console.print()
169184

170-
# Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
171-
dotenv_path = find_dotenv()
172-
if dotenv_path:
173-
set_key(
174-
dotenv_path=dotenv_path,
175-
key_to_set="FOUNDRY_CHECKPOINTS_DIR",
176-
value_to_set=str(checkpoint_dir),
177-
export=False,
178-
)
179-
console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
180-
181185
console.print("[bold green]Installation complete![/bold green]")
182186

183187

@@ -192,27 +196,28 @@ def list_available():
192196
@app.command(name="list-installed")
193197
def list_installed():
194198
"""List installed checkpoints and their sizes."""
195-
checkpoint_dir = _resolve_checkpoint_dir(None)
199+
checkpoint_dirs = _resolve_checkpoint_dirs(None)
196200

197-
if not checkpoint_dir.exists():
198-
console.print(
199-
f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]"
200-
)
201-
raise typer.Exit(0)
201+
checkpoint_files: list[tuple[Path, float]] = []
202+
for checkpoint_dir in checkpoint_dirs:
203+
if not checkpoint_dir.exists():
204+
continue
205+
ckpts = list(checkpoint_dir.glob("*.ckpt")) + list(checkpoint_dir.glob("*.pt"))
206+
for ckpt in ckpts:
207+
size = ckpt.stat().st_size / (1024**3) # GB
208+
checkpoint_files.append((ckpt, size))
202209

203-
checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) + list(
204-
checkpoint_dir.glob("*.pt")
205-
)
206210
if not checkpoint_files:
207-
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
211+
console.print(
212+
"[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
213+
)
208214
raise typer.Exit(0)
209215

210-
console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
216+
console.print("[bold]Installed checkpoints:[/bold]\n")
211217
total_size = 0
212-
for ckpt in sorted(checkpoint_files):
213-
size = ckpt.stat().st_size / (1024**3) # GB
218+
for ckpt, size in sorted(checkpoint_files, key=lambda item: str(item[0])):
214219
total_size += size
215-
console.print(f" {ckpt.name:30} {size:8.2f} GB")
220+
console.print(f" {ckpt} {size:8.2f} GB")
216221

217222
console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
218223

@@ -224,24 +229,28 @@ def clean(
224229
),
225230
):
226231
"""Remove all downloaded checkpoints."""
227-
checkpoint_dir = _resolve_checkpoint_dir(None)
228-
229-
if not checkpoint_dir.exists():
230-
console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]")
231-
raise typer.Exit(0)
232+
checkpoint_dirs = _resolve_checkpoint_dirs(None)
232233

233234
# List files to delete
234-
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
235+
checkpoint_files: list[Path] = []
236+
for checkpoint_dir in checkpoint_dirs:
237+
if not checkpoint_dir.exists():
238+
continue
239+
checkpoint_files.extend(checkpoint_dir.glob("*.ckpt"))
240+
checkpoint_files.extend(checkpoint_dir.glob("*.pt"))
241+
235242
if not checkpoint_files:
236-
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
243+
console.print(
244+
"[yellow]No checkpoint files found in any checkpoint directory[/yellow]"
245+
)
237246
raise typer.Exit(0)
238247

239248
console.print("[bold]Files to delete:[/bold]")
240249
total_size = 0
241-
for ckpt in checkpoint_files:
250+
for ckpt in sorted(checkpoint_files, key=str):
242251
size = ckpt.stat().st_size / (1024**3) # GB
243252
total_size += size
244-
console.print(f" {ckpt.name} ({size:.2f} GB)")
253+
console.print(f" {ckpt} ({size:.2f} GB)")
245254

246255
console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")
247256

0 commit comments

Comments
 (0)