diff --git a/.env b/.env index 298b652f..ca5beaa0 100644 --- a/.env +++ b/.env @@ -61,4 +61,4 @@ COLABFOLD_NET_DB_PATH_CPU= # Foundry install dir for checkpoints # Commented out by default since otherwise may be overridden by user export (load_dotenv(override=True) used at the moment) # TODO: Ensure override=False can be used. -# FOUNDRY_CHECKPOINTS_DIR= \ No newline at end of file +FOUNDRY_CHECKPOINT_DIRS= \ No newline at end of file diff --git a/README.md b/README.md index a43eaf96..9986900a 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,12 @@ pip install rc-foundry[all] ``` foundry install base-models --checkpoint-dir ``` -where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Once installed, foundry sets the env `FOUNDRY_CHECKPOINTS_DIR` which it will use during inference or subsequent commands to find the checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`). +where `checkpoint-dir` will be `~/.foundry/checkpoints` by default. Foundry always searches `~/.foundry/checkpoints` plus any colon-separated entries in `$FOUNDRY_CHECKPOINT_DIRS` during inference or subsequent commands to find checkpoints. `base-models` installs the latest RFD3, RF3 and MPNN variants - you can also download all of the models supported (including multiple checkpoints of RF3) with `all`, or by listing the models sequentially (e.g. `foundry install rfd3 rf3 ...`). 
To list the registry of available checkpoints: ``` foundry list-available ``` -To check what you already have downloaded (defaults to `$FOUNDRY_CHECKPOINTS_DIR` if set): +To check what you already have downloaded (searches `~/.foundry/checkpoints` plus `$FOUNDRY_CHECKPOINT_DIRS` if set): ``` foundry list-installed ``` diff --git a/models/rfd3/README.md b/models/rfd3/README.md index dbd10129..153d9cd9 100644 --- a/models/rfd3/README.md +++ b/models/rfd3/README.md @@ -22,7 +22,7 @@ pip install rc-foundry[rfd3] ```bash foundry install rfd3 --checkpoint-dir ``` -This sets `FOUNDRY_CHECKPOINTS_DIR` and will in future look for checkpoints in that directory, allowing you to run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset. +This records the directory in `FOUNDRY_CHECKPOINT_DIRS` (written to your `.env` file when one is found), so subsequent commands search it alongside the default `~/.foundry/checkpoints` location and you can run inference without supplying the checkpoint path. The checkpoint directory is optional, defaulting to `~/.foundry/checkpoints` if unset. ## Running Inference diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 50f69e4d..2674fe5b 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -3,20 +3,62 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Iterable, List +import dotenv -def get_default_checkpoint_dir() -> Path: - """Get the default checkpoint directory. 
+DEFAULT_CHECKPOINT_DIR = Path.home() / ".foundry" / "checkpoints" + + +def _normalize_paths(paths: Iterable[Path]) -> list[Path]: + """Return absolute, deduplicated paths in order.""" + seen = set() + normalized: List[Path] = [] + for path in paths: + resolved = path.expanduser().absolute() + if resolved not in seen: + normalized.append(resolved) + seen.add(resolved) + return normalized - Priority: - 1. FOUNDRY_CHECKPOINTS_DIR environment variable - 2. ~/.foundry/checkpoints + +def get_default_checkpoint_dirs() -> list[Path]: + """Return checkpoint search paths. + + Directories from the colon-separated FOUNDRY_CHECKPOINT_DIRS environment + variable (or the legacy FOUNDRY_CHECKPOINTS_DIR when that is unset or empty) + come first, followed by the default ~/.foundry/checkpoints directory. """ - if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get( - "FOUNDRY_CHECKPOINTS_DIR" - ): - return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute() - return Path.home() / ".foundry" / "checkpoints" + env_dirs = os.environ.get("FOUNDRY_CHECKPOINT_DIRS", "") + + # For backward compatibility, also check FOUNDRY_CHECKPOINTS_DIR + if not env_dirs: + env_dirs = os.environ.get("FOUNDRY_CHECKPOINTS_DIR", "") + + extra_dirs: list[Path] = [] + if env_dirs: + extra_dirs = [Path(p.strip()) for p in env_dirs.split(":") if p.strip()] + return _normalize_paths([*extra_dirs, DEFAULT_CHECKPOINT_DIR]) + + +def get_default_checkpoint_dir() -> Path: + """Backward-compatible helper returning the primary checkpoint directory.""" + return get_default_checkpoint_dirs()[0] + + +def append_checkpoint_to_env(checkpoint_dirs: list[Path]) -> bool: + dotenv_path = dotenv.find_dotenv() + if dotenv_path: + checkpoint_dirs = _normalize_paths(checkpoint_dirs) + dotenv.set_key( + dotenv_path=dotenv_path, + key_to_set="FOUNDRY_CHECKPOINT_DIRS", + value_to_set=":".join(str(path) for path in checkpoint_dirs), + export=False, + ) + return True + else: + return False @dataclass @@ -27,7 +69,12 @@ class RegisteredCheckpoint: sha256: None = None 
# Optional: add checksum for verification def get_default_path(self): - return get_default_checkpoint_dir() / self.filename + checkpoint_dirs = get_default_checkpoint_dirs() + for checkpoint_dir in checkpoint_dirs: + candidate = checkpoint_dir / self.filename + if candidate.exists(): + return candidate + return checkpoint_dirs[0] / self.filename REGISTERED_CHECKPOINTS = { diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index 2607fe23..ef367a1b 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -6,7 +6,7 @@ from urllib.request import urlopen import typer -from dotenv import find_dotenv, load_dotenv, set_key +from dotenv import load_dotenv from rich.console import Console from rich.progress import ( BarColumn, @@ -20,7 +20,8 @@ from foundry.inference_engines.checkpoint_registry import ( REGISTERED_CHECKPOINTS, - get_default_checkpoint_dir, + append_checkpoint_to_env, + get_default_checkpoint_dirs, ) load_dotenv(override=True) @@ -29,11 +30,25 @@ console = Console() -def _resolve_checkpoint_dir(checkpoint_dir: Optional[Path]) -> Path: - """Return user-specified checkpoint dir or fall back to default.""" - return ( - checkpoint_dir if checkpoint_dir is not None else get_default_checkpoint_dir() - ) +def _resolve_checkpoint_dirs(checkpoint_dir: Optional[Path]) -> list[Path]: + """Return the checkpoint search path, with any user-specified dir first.""" + checkpoint_dirs = get_default_checkpoint_dirs() + if checkpoint_dir is not None: + resolved = checkpoint_dir.expanduser().absolute() + if resolved not in checkpoint_dirs: + checkpoint_dirs.insert(0, resolved) + else: + # Move to front + checkpoint_dirs.remove(resolved) + checkpoint_dirs.insert(0, resolved) + + # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.) 
+ if append_checkpoint_to_env(checkpoint_dirs): + console.print( + f"Tracked checkpoint directories: {':'.join(str(path) for path in checkpoint_dirs)}" + ) + + return checkpoint_dirs def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None: @@ -136,7 +151,7 @@ def install( None, "--checkpoint-dir", "-d", - help="Directory to save checkpoints (default: $FOUNDRY_CHECKPOINTS_DIR or ~/.foundry/checkpoints)", + help="Directory to save checkpoints (default search path: ~/.foundry/checkpoints plus any $FOUNDRY_CHECKPOINT_DIRS entries)", ), force: bool = typer.Option( False, "--force", "-f", help="Overwrite existing checkpoints" @@ -149,10 +164,10 @@ def install( foundry install proteinmpnn --checkpoint-dir ./checkpoints """ # Determine checkpoint directory - checkpoint_dir = _resolve_checkpoint_dir(checkpoint_dir) + checkpoint_dirs = _resolve_checkpoint_dirs(checkpoint_dir) + primary_checkpoint_dir = checkpoint_dirs[0] - console.print(f"[bold]Checkpoint directory:[/bold] {checkpoint_dir}") - console.print() + console.print(f"[bold]Install target:[/bold] {primary_checkpoint_dir}\n") # Expand 'all' to all available models if "all" in models: @@ -164,20 +179,9 @@ def install( # Install each model for model_name in models_to_install: - install_model(model_name, checkpoint_dir, force) + install_model(model_name, primary_checkpoint_dir, force) console.print() - # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.) 
- dotenv_path = find_dotenv() - if dotenv_path: - set_key( - dotenv_path=dotenv_path, - key_to_set="FOUNDRY_CHECKPOINTS_DIR", - value_to_set=str(checkpoint_dir), - export=False, - ) - console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}") - console.print("[bold green]Installation complete![/bold green]") @@ -192,27 +196,28 @@ def list_available(): @app.command(name="list-installed") def list_installed(): """List installed checkpoints and their sizes.""" - checkpoint_dir = _resolve_checkpoint_dir(None) + checkpoint_dirs = _resolve_checkpoint_dirs(None) - if not checkpoint_dir.exists(): - console.print( - f"[yellow]No checkpoints directory found at {checkpoint_dir}[/yellow]" - ) - raise typer.Exit(0) + checkpoint_files: list[tuple[Path, float]] = [] + for checkpoint_dir in checkpoint_dirs: + if not checkpoint_dir.exists(): + continue + ckpts = list(checkpoint_dir.glob("*.ckpt")) + list(checkpoint_dir.glob("*.pt")) + for ckpt in ckpts: + size = ckpt.stat().st_size / (1024**3) # GB + checkpoint_files.append((ckpt, size)) - checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) + list( - checkpoint_dir.glob("*.pt") - ) if not checkpoint_files: - console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") + console.print( + "[yellow]No checkpoint files found in any checkpoint directory[/yellow]" + ) raise typer.Exit(0) - console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n") + console.print("[bold]Installed checkpoints:[/bold]\n") total_size = 0 - for ckpt in sorted(checkpoint_files): - size = ckpt.stat().st_size / (1024**3) # GB + for ckpt, size in sorted(checkpoint_files, key=lambda item: str(item[0])): total_size += size - console.print(f" {ckpt.name:30} {size:8.2f} GB") + console.print(f" {ckpt} {size:8.2f} GB") console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB") @@ -224,24 +229,28 @@ def clean( ), ): """Remove all downloaded checkpoints.""" - checkpoint_dir = _resolve_checkpoint_dir(None) - - if not 
checkpoint_dir.exists(): - console.print(f"[yellow]No checkpoints found at {checkpoint_dir}[/yellow]") - raise typer.Exit(0) + checkpoint_dirs = _resolve_checkpoint_dirs(None) # List files to delete - checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) + checkpoint_files: list[Path] = [] + for checkpoint_dir in checkpoint_dirs: + if not checkpoint_dir.exists(): + continue + checkpoint_files.extend(checkpoint_dir.glob("*.ckpt")) + checkpoint_files.extend(checkpoint_dir.glob("*.pt")) + if not checkpoint_files: - console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") + console.print( + "[yellow]No checkpoint files found in any checkpoint directory[/yellow]" + ) raise typer.Exit(0) console.print("[bold]Files to delete:[/bold]") total_size = 0 - for ckpt in checkpoint_files: + for ckpt in sorted(checkpoint_files, key=str): size = ckpt.stat().st_size / (1024**3) # GB total_size += size - console.print(f" {ckpt.name} ({size:.2f} GB)") + console.print(f" {ckpt} ({size:.2f} GB)") console.print(f"\n[bold]Total:[/bold] {total_size:.2f} GB")