From 3313cdfa47875de333eb3e4a3476aae6d4267f0b Mon Sep 17 00:00:00 2001 From: jbutch Date: Thu, 4 Dec 2025 21:06:10 -0800 Subject: [PATCH 1/5] Add multiple checkpoint directories optioning and rename env var --- .env | 2 +- README.md | 4 +- models/rfd3/README.md | 2 +- .../inference_engines/checkpoint_registry.py | 64 +++++++++-- src/foundry_cli/download_checkpoints.py | 101 ++++++++++-------- 5 files changed, 112 insertions(+), 61 deletions(-) diff --git a/.env b/.env index 298b652f..1a25be22 100644 --- a/.env +++ b/.env @@ -61,4 +61,4 @@ COLABFOLD_NET_DB_PATH_CPU= # Foundry install dir for checkpoints # Commented out by default since otherwise may be overridden by user export (load_dotenv(override=True) used at the moment) # TODO: Ensure override=False can be used. -# FOUNDRY_CHECKPOINTS_DIR= \ No newline at end of file +FOUNDRY_CHECKPOINT_DIRS='/home/jbutch/Projects/HT25/af3/foundry/checkpoints:/home/jbutch/Projects/HT25/af3/foundry/checkpoints2:/home/jbutch/.foundry/checkpoints' diff --git a/README.md b/README.md index a43eaf96..9986900a 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,12 @@ pip install rc-foundry[all] ``` foundry install base-models --checkpoint-dir ``` -where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Once installed, foundry sets the env `FOUNDRY_CHECKPOINTS_DIR` which it will use during inference or subsequent commands to find the checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`). +where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Foundry always searches `~/.foundry/checkpoints` plus any colon-separated entries in `$FOUNDRY_CHECKPOINT_DIRS` during inference or subsequent commands to find checkpoints. 
`base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`). To list the registry of available checkpoints: ``` foundry list-available ``` -To check what you already have downloaded (defaults to `$FOUNDRY_CHECKPOINTS_DIR` if set): +To check what you already have downloaded (searches `~/.foundry/checkpoints` plus `$FOUNDRY_CHECKPOINT_DIRS` if set): ``` foundry list-installed ``` diff --git a/models/rfd3/README.md b/models/rfd3/README.md index dbd10129..153d9cd9 100644 --- a/models/rfd3/README.md +++ b/models/rfd3/README.md @@ -22,7 +22,7 @@ pip install rc-foundry[rfd3] ```bash foundry install rfd3 --checkpoint-dir ``` -This sets `FOUNDRY_CHECKPOINTS_DIR` and will in future look for checkpoints in that directory, allowing you to run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset. +This records the chosen directory in `FOUNDRY_CHECKPOINT_DIRS` (when a `.env` file is found), so that subsequent commands also search it for checkpoints (alongside the default `~/.foundry/checkpoints` location), allowing you to run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset. ## Running Inference diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 50f69e4d..0654811e 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -3,20 +3,57 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Iterable, List +import dotenv -def get_default_checkpoint_dir() -> Path: - """Get the default checkpoint directory.
+DEFAULT_CHECKPOINT_DIR = Path.home() / ".foundry" / "checkpoints" + + +def _normalize_paths(paths: Iterable[Path]) -> list[Path]: + """Return absolute, deduplicated paths in order.""" + seen = set() + normalized: List[Path] = [] + for path in paths: + resolved = path.expanduser().absolute() + if resolved not in seen: + normalized.append(resolved) + seen.add(resolved) + return normalized + + +def get_default_checkpoint_dirs() -> list[Path]: + """Return checkpoint search paths. - Priority: - 1. FOUNDRY_CHECKPOINTS_DIR environment variable - 2. ~/.foundry/checkpoints + Always ends with the default ~/.foundry/checkpoints directory, preceded + by any additional directories from the colon-separated + FOUNDRY_CHECKPOINT_DIRS environment variable. """ - if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get( - "FOUNDRY_CHECKPOINTS_DIR" - ): - return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute() - return Path.home() / ".foundry" / "checkpoints" + env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS", "") + extra_dirs: list[Path] = [] + if env_dirs: + extra_dirs = [Path(p.strip()) for p in env_dirs.split(":") if p.strip()] + return _normalize_paths([*extra_dirs, DEFAULT_CHECKPOINT_DIR]) + + +def get_default_checkpoint_dir() -> Path: + """Backward-compatible helper returning the primary checkpoint directory.""" + return get_default_checkpoint_dirs()[0] + + +def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> None: + dotenv_path = dotenv.find_dotenv() + if dotenv_path: + checkpoint_dirs = _normalize_paths(checkpoint_dirs) + dotenv.set_key( + dotenv_path=dotenv_path, + key_to_set="FOUNDRY_CHECKPOINT_DIRS", + value_to_set=":".join(str(path) for path in checkpoint_dirs), + export=False, + ) + return True + else: + return False @dataclass @@ -27,7 +64,12 @@ class RegisteredCheckpoint: sha256: None = None # Optional: add checksum for verification def get_default_path(self): - return get_default_checkpoint_dir() / self.filename + checkpoint_dirs = 
get_default_checkpoint_dirs() + for checkpoint_dir in checkpoint_dirs: + candidate = checkpoint_dir / self.filename + if candidate.exists(): + return candidate + return checkpoint_dirs[0] / self.filename REGISTERED_CHECKPOINTS = { diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index 2607fe23..e0b14026 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -6,7 +6,7 @@ from urllib.request import urlopen import typer -from dotenv import find_dotenv, load_dotenv, set_key +from dotenv import load_dotenv from rich.console import Console from rich.progress import ( BarColumn, @@ -20,7 +20,8 @@ from foundry.inference_engines.checkpoint_registry import ( REGISTERED_CHECKPOINTS, - get_default_checkpoint_dir, + append_checkpoint_to_env, + get_default_checkpoint_dirs, ) load_dotenv(override=True) @@ -29,11 +30,25 @@ console = Console() -def _resolve_checkpoint_dir(checkpoint_dir: Optional[Path]) -> Path: - """Return user-specified checkpoint dir or fall back to default.""" - return ( - checkpoint_dir if checkpoint_dir is not None else get_default_checkpoint_dir() - ) +def _resolve_checkpoint_dirs(checkpoint_dir: Optional[Path]) -> list[Path]: + """Return checkpoint search path, with any user-specified directory first.""" + checkpoint_dirs = get_default_checkpoint_dirs() + if checkpoint_dir is not None: + resolved = checkpoint_dir.expanduser().absolute() + if resolved not in checkpoint_dirs: + checkpoint_dirs.insert(0, resolved) + else: + # Move to front + checkpoint_dirs.remove(resolved) + checkpoint_dirs.insert(0, resolved) + + # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.) 
+ if append_checkpoint_to_env(checkpoint_dirs): + console.print( + f"Tracked checkpoint directories: {':'.join(str(path) for path in checkpoint_dirs)}" + ) + + return checkpoint_dirs def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None: @@ -136,7 +151,7 @@ def install( None, "--checkpoint-dir", "-d", - help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)", + help="Directory to save checkpoints (default search path: ~/.foundry/checkpoints plus any $FOUNDRY_CHECKPOINT_DIRS entries)", ), force: bool = typer.Option( False, "--force", "-f", help="Overwrite existing checkpoints" @@ -149,10 +164,10 @@ def install( foundry install proteinmpnn --checkpoint-dir ./checkpoints """ # Determine checkpoint directory - checkpoint_dir = _resolve_checkpoint_dir(checkpoint_dir) + checkpoint_dirs = _resolve_checkpoint_dirs(checkpoint_dir) + primary_checkpoint_dir = checkpoint_dirs[0] - console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}") - console.print() + console.print(f"[bold]Install target:[/bold] {primary_checkpoint_dir}\n") # Expand 'all' to all available models if "all" in models: @@ -164,20 +179,9 @@ def install( # Install each model for model_name in models_to_install: - install_model(model_name, checkpoint_dir, force) + install_model(model_name, primary_checkpoint_dir, force) console.print() - # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.) 
- dotenv_path = find_dotenv() - if dotenv_path: - set_key( - dotenv_path=dotenv_path, - key_to_set="FOUNDRY_CHECKPOINTS_DIR", - value_to_set=str(checkpoint_dir), - export=False, - ) - console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}") - console.print("[bold green]Installation complete![/bold green]") @@ -192,27 +196,28 @@ def list_available(): @app.command(name="list-installed") def list_installed(): """List installed checkpoints and their sizes.""" - checkpoint_dir = _resolve_checkpoint_dir(None) + checkpoint_dirs = _resolve_checkpoint_dirs(None) - if not checkpoint_dir.exists(): - console.print( - f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]" - ) - raise typer.Exit(0) + checkpoint_files: list[tuple[Path, float]] = [] + for checkpoint_dir in checkpoint_dirs: + if not checkpoint_dir.exists(): + continue + ckpts = list(checkpoint_dir.glob("*.ckpt")) + list(checkpoint_dir.glob("*.pt")) + for ckpt in ckpts: + size = ckpt.stat().st_size / (1024**3) # GB + checkpoint_files.append((ckpt, size)) - checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) + list( - checkpoint_dir.glob("*.pt") - ) if not checkpoint_files: - console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") + console.print( + "[yellow]No checkpoint files found in any checkpoint directory[/yellow]" + ) raise typer.Exit(0) - console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n") + console.print("[bold]Installed checkpoints:[/bold]\n") total_size = 0 - for ckpt in sorted(checkpoint_files): - size = ckpt.stat().st_size / (1024**3) # GB + for ckpt, size in sorted(checkpoint_files, key=lambda item: str(item[0])): total_size += size - console.print(f" {ckpt.name:30} {size:8.2f} GB") + console.print(f" {ckpt} {size:8.2f} GB") console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB") @@ -224,24 +229,28 @@ def clean( ), ): """Remove all downloaded checkpoints.""" - checkpoint_dir = _resolve_checkpoint_dir(None) - - if not 
checkpoint_dir.exists(): - console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]") - raise typer.Exit(0) + checkpoint_dirs = _resolve_checkpoint_dirs(None) # List files to delete - checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) + checkpoint_files: list[Path] = [] + for checkpoint_dir in checkpoint_dirs: + if not checkpoint_dir.exists(): + continue + checkpoint_files.extend(checkpoint_dir.glob("*.ckpt")) + checkpoint_files.extend(checkpoint_dir.glob("*.pt")) + if not checkpoint_files: - console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") + console.print( + "[yellow]No checkpoint files found in any checkpoint directory[/yellow]" + ) raise typer.Exit(0) console.print("[bold]Files to delete:[/bold]") total_size = 0 - for ckpt in checkpoint_files: + for ckpt in sorted(checkpoint_files, key=lambda path: str(path)): size = ckpt.stat().st_size / (1024**3) # GB total_size += size - console.print(f" {ckpt.name} ({size:.2f} GB)") + console.print(f" {ckpt} ({size:.2f} GB)") console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB") From 0404d6ae9e4c5794b665bb917b88ca603fc3bca5 Mon Sep 17 00:00:00 2001 From: jbutch Date: Thu, 4 Dec 2025 21:10:03 -0800 Subject: [PATCH 2/5] Rewind .env --- .env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.env b/.env index 1a25be22..ca5beaa0 100644 --- a/.env +++ b/.env @@ -61,4 +61,4 @@ COLABFOLD_NET_DB_PATH_CPU= # Foundry install dir for checkpoints # Commented out by default since otherwise may be overridden by user export (load_dotenv(override=True) used at the moment) # TODO: Ensure override=False can be used. 
-FOUNDRY_CHECKPOINT_DIRS='/home/jbutch/Projects/HT25/af3/foundry/checkpoints:/home/jbutch/Projects/HT25/af3/foundry/checkpoints2:/home/jbutch/.foundry/checkpoints' +FOUNDRY_CHECKPOINT_DIRS= \ No newline at end of file From 5c766cdde8d6a88112bc9bbdb4d855f7e09550dc Mon Sep 17 00:00:00 2001 From: jbutch Date: Thu, 4 Dec 2025 21:23:40 -0800 Subject: [PATCH 3/5] Add backward compatibility --- src/foundry/inference_engines/checkpoint_registry.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 0654811e..85b42f25 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -30,6 +30,11 @@ def get_default_checkpoint_dirs() -> list[Path]: FOUNDRY_CHECKPOINT_DIRS environment variable. """ env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS", "") + + # For backward compatibility, also check FOUNDRY_CHECKPOINTS_DIR + if not env_dirs: + env_dirs = os.environ.get("FOUNDRY_CHECKPOINTS_DIR", "") + extra_dirs: list[Path] = [] if env_dirs: extra_dirs = [Path(p.strip()) for p in env_dirs.split(":") if p.strip()] From b66aa4324a6ce785df577b68a3d25410951860f5 Mon Sep 17 00:00:00 2001 From: Jasper Butcher <66851659+Ubiquinone-dot@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:26:02 -0800 Subject: [PATCH 4/5] Update src/foundry/inference_engines/checkpoint_registry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/foundry/inference_engines/checkpoint_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 85b42f25..2674fe5b 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -46,7 +46,7 @@ def get_default_checkpoint_dir() -> Path: return 
get_default_checkpoint_dirs()[0] -def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> None: +def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> bool: dotenv_path = dotenv.find_dotenv() if dotenv_path: checkpoint_dirs = _normalize_paths(checkpoint_dirs) From 0a8348f42a57be76a9f9592ebc60d8526aa3ae82 Mon Sep 17 00:00:00 2001 From: Jasper Butcher <66851659+Ubiquinone-dot@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:31:22 -0800 Subject: [PATCH 5/5] Update src/foundry_cli/download_checkpoints.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/foundry_cli/download_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index e0b14026..ef367a1b 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -247,7 +247,7 @@ def clean( console.print("[bold]Files to delete:[/bold]") total_size = 0 - for ckpt in sorted(checkpoint_files, key=lambda path: str(path)): + for ckpt in sorted(checkpoint_files, key=str): size = ckpt.stat().st_size / (1024**3) # GB total_size += size console.print(f" {ckpt} ({size:.2f} GB)")