diff --git a/docs/data2_recovery_survey.md b/docs/data2_recovery_survey.md new file mode 100644 index 00000000..c7f3fb53 --- /dev/null +++ b/docs/data2_recovery_survey.md @@ -0,0 +1,683 @@ +# /data2/ Recovery Survey & Backup Strategy + +> **Date Created:** December 18, 2025 +> **Context:** /data2/ was deleted; backup available from ~2 weeks ago +> **Purpose:** Track what needs to be recreated and establish backup procedures + +--- + +## Table of Contents +1. [Datasets to Recreate](#datasets-to-recreate) +2. [Checkpoints to Retrain](#checkpoints-to-retrain) +3. [Safe Checkpoints (Already on HuggingFace)](#safe-checkpoints-already-on-huggingface) +4. [Historical Run/Log Directories](#historical-runlog-directories) +5. [Recovery Priority Order](#recovery-priority-order) +6. [Backup Strategy Plan](#backup-strategy-plan) + +--- + +## Datasets to Recreate + +### 1. PDB Data (Latent Generator) + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt` | PDB training split | Multiple configs | +| `/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt` | PDB validation split | Multiple configs | +| `/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt` | PDB test split | Multiple configs | +| `/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt` | Sequence clustering file | Multiple configs | + +### 2. AFDB SwissProt Data (Gen-UME) + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/AFDB/train_processed/` | AFDB training set | `structure_afdb_swissprot.yaml` | +| `/data2/lisanzas/AFDB/valid_cameo_processed/` | CAMEO validation set | `structure_afdb_swissprot.yaml`, callbacks | +| `/data2/lisanzas/AFDB/test_multiflow_processed/` | MultiFlow test set | `structure_afdb_swissprot.yaml` | +| `/data2/lisanzas/AFDB/pdb_swissprot_clusters.pt` | SwissProt clustering | `structure_afdb_swissprot.yaml` | + +### 3. ESM Atlas Data (Gen-UME) + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/ume/simplefold_dataset/train_processed/` | Processed ESM Atlas structures | `structure_esm_atlas_afdb_swissprot.yaml` | +| `/data2/ume/simplefold_dataset/esm_atlas/` | Raw ESM Atlas data | Processing scripts | + +### 4. Ligand Datasets (Latent Generator) + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/pdb_bind/train/` | PDBBind train (old) | `structure_ligand_pdb.yaml` | +| `/data2/lisanzas/pdb_bind/val/` | PDBBind val (old) | `structure_ligand_pdb.yaml` | +| `/data2/lisanzas/pdb_bind/test/` | PDBBind test (old) | `structure_ligand_pdb.yaml` | +| `/data2/lisanzas/pdb_bind_12_15_25/train/` | PDBBind train (new, with bond_matrix) | `structure_ligand_pdb_sair_bond.yaml` | +| `/data2/lisanzas/pdb_bind_12_15_25/val/` | PDBBind val (new) | `structure_ligand_pdb_sair_bond.yaml` | +| `/data2/lisanzas/pdb_bind_12_15_25/test/` | PDBBind test (new) | `structure_ligand_pdb_sair_bond.yaml` | +| `/data2/lisanzas/geom_12_15_25/train/` | GEOM ligands (with bond_matrix) | `structure_ligand_pdb_sair_bond.yaml` | +| `/data2/lisanzas/sair_12_15_25/train/` | SAIR protein-ligand (with bond_matrix) | `structure_ligand_pdb_sair_bond.yaml` | +| `/data2/lisanzas/sair_protein_ligand/train/` | SAIR (old) | `structure_ligand_pdb_sair.yaml` | + +### 5. CATH Data + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/CATH_v4_3/processed_structures_pt/train/cath_train.pt` | CATH train | `structure_cath.yaml` | +| `/data2/lisanzas/CATH_v4_3/processed_structures_pt/val/cath_val.pt` | CATH val | `structure_cath.yaml` | +| `/data2/lisanzas/CATH_v4_3/processed_structures_pt/test/cath_test.pt` | CATH test | `structure_cath.yaml` | + +### 6. SAbDab (Antibody) Data + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/sabdab/train_denovo_processed_pt/train_denovo_data.pt` | SAbDab train | `structure_sabdab.yaml` | +| `/data2/lisanzas/sabdab/val_denovo_processed_pt/val_denovo_data.pt` | SAbDab val | `structure_sabdab.yaml` | +| `/data2/lisanzas/sabdab/test_denovo_processed_pt/test_dummy_denovo_data.pt` | SAbDab test | `structure_sabdab.yaml` | + +### 7. ESM-C Embeddings (Latent Generator) + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/latent_generator_files/esm_c_300m_embeddings_iterable_sampler/train/` | ESM-C train embeddings | `structure_pinder_esm.yaml` | +| `/data2/lisanzas/latent_generator_files/esm_c_300m_embeddings_iterable_sampler/val/` | ESM-C val embeddings | `structure_pinder_esm.yaml` | +| `/data2/lisanzas/latent_generator_files/esm_c_300m_embeddings_iterable_sampler/test/` | ESM-C test embeddings | `structure_pinder_esm.yaml` | + +### 8. AFDB Genie Data + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/latent_generator_files/afdb_data/processed_pt/train_afdb_genie2_data.pt` | AFDB Genie train | `structure_afdb_genie.yaml` | + +### 9. MultiFlow Test Set + +| Path | Description | Used In | +|------|-------------|---------| +| `/data2/lisanzas/multi_flow_data/test_set_filtered_pt/` | MultiFlow test set | Generation experiments | + +--- + +## Checkpoints to Retrain + +### Latent Generator Checkpoints (Local paths - NOT on HuggingFace) + +| Path | Model Name | Priority | +|------|------------|----------| +| `/data2/ume/latent_generator_/runs//2025-11-09T14-23-55/last.ckpt` | LG Ligand | Medium | +| `/data2/ume/latent_generator_/runs//2025-11-06T00-40-11/last.ckpt` | LG full attention 2 | Medium | +| `/data2/ume/latent_generator_/runs//2025-12-07T22-38-42/epoch=830-step=88917-val_loss=16.5010.ckpt` | LG Protein Ligand | **HIGH** | +| `/data2/ume/latent_generator_/runs//2025-12-13T16-34-07/epoch=240-step=25787-val_loss=16.4510.ckpt` | LG Protein Ligand fsq 4375 | **HIGH** | +| `/data2/ume/latent_generator_/runs//2025-12-13T14-57-53/epoch=210-step=22577-val_loss=17.2066.ckpt` | LG Protein Ligand fsq 1000 | **HIGH** | +| `/data2/lisanzas/latent_generator/studies/outputs/train/dev/runs/2025-11-09_22-19-12/checkpoints/last.ckpt` | LG full attention 512 PDB Pinder FSQ | Low | + +### Gen-UME Checkpoints + +| Path | Model Size | Priority | +|------|------------|----------| +| `/data2/ume/gen_ume/runs//2025-11-17T20-31-05/last.ckpt` | 750M model | **HIGH** (used in many experiments) | +| `/data2/ume/gen_ume/runs//2025-11-07T13-19-11/last.ckpt` | 450M model | **HIGH** | +| `/data2/lisanzas/gen_ume/runs//2025-12-05T16-48-13/last.ckpt` | ESM Atlas trained | High | +| `/data2/lisanzas/gen_ume/runs//2025-12-17T20-25-52/epoch=28-step=20985-val_loss=5.0925.ckpt` | Latest large resume | **HIGH** | + +--- + +## Safe Checkpoints (Already on HuggingFace) + +These are hosted on HuggingFace and **don't need retraining**: + +- ✅ `LG Ligand 20A` +- ✅ `LG Ligand 20A 512 1024` +- ✅ `LG Ligand 20A 512 1024 element` +- ✅ `LG Ligand 20A continuous` +- ✅ `LG Ligand 20A seq 3di Aux` +- ✅ `LG 20A seq Aux` +- ✅ `LG 20A seq 3di c6d Aux` +- ✅ `LG 20A seq 3di c6d Aux Pinder` +- ✅ `LG 20A seq 3di c6d Aux PDB` +- ✅ `LG 20A seq 3di c6d Aux PDB Pinder` +- ✅ `LG 20A seq 3di c6d Aux PDB Pinder Finetune` +- ✅ `LG 20A` +- ✅ `LG 10A` +- ✅ `LG full attention` + +--- + +## Historical Run/Log Directories + +These are training runs and logs - may be acceptable to lose: + +- `/data2/ume/latent_generator_/slurm/logs/` +- `/data2/ume/gen_ume/slurm/logs/` +- `/data2/ume/latent_generator_/runs/` (except specific checkpoints above) +- `/data2/ume/.cache2/` +- `/data2/lisanzas/.cache/` +- `/data2/lisanzas/gen_ume/tmp/` + +--- + +## Recovery Priority Order + +### Tier 1 - Critical (Blocks current work) + +1. ⬜ PDB training/validation/test splits + cluster file +2. ⬜ AFDB SwissProt processed datasets (train/val/test) +3. ⬜ Gen-UME 750M checkpoint (`2025-11-17T20-31-05`) +4. ⬜ Latest Latent Generator protein-ligand checkpoints + +### Tier 2 - High (Needed for experiments) + +1. ⬜ PDBBind/GEOM/SAIR with bond_matrix (12_15_25 versions) +2. ⬜ ESM Atlas processed structures +3. ⬜ Gen-UME 450M checkpoint +4. ⬜ CAMEO/MultiFlow test sets + +### Tier 3 - Medium (Nice to have) + +1. ⬜ CATH datasets +2. ⬜ SAbDab datasets +3. ⬜ ESM-C embeddings +4. ⬜ AFDB Genie data + +--- + +## Backup Strategy Plan + +### Part 1: Automatic Checkpoint Backup + +#### Option A: S3 Bucket Backup (Recommended) + +**Setup:** +```bash +# S3 bucket structure +s3://prescient-pcluster-data/gen_ume/ +├── checkpoints/ +│ ├── latent_generator/ +│ │ ├── LG_Protein_Ligand_v1.ckpt +│ │ ├── LG_Protein_Ligand_fsq_4375_v1.ckpt +│ │ └── ... +│ └── gen_ume/ +│ ├── gen_ume_750M_v1.ckpt +│ ├── gen_ume_450M_v1.ckpt +│ └── ... +└── datasets/ + ├── pdb/ + ├── afdb/ + ├── ligand/ + └── ... +``` + +**Automatic Upload Callback:** + +Create a new callback that uploads checkpoints to S3 after each save: + +```python +# src/lobster/callbacks/_s3_checkpoint_callback.py +import boto3 +import os +from pathlib import Path +from pytorch_lightning.callbacks import Callback + +class S3CheckpointBackupCallback(Callback): + """Automatically backup checkpoints to S3 after saving.""" + + def __init__( + self, + s3_bucket: str = "prescient-lobster", + s3_prefix: str = "checkpoints", + project_name: str = "latent_generator", + upload_every_n_epochs: int = 10, + upload_best_only: bool = False, + ): + self.s3_bucket = s3_bucket + self.s3_prefix = s3_prefix + self.project_name = project_name + self.upload_every_n_epochs = upload_every_n_epochs + self.upload_best_only = upload_best_only + self.s3_client = boto3.client("s3") + + def _upload_to_s3(self, local_path: str, s3_key: str): + """Upload a file to S3.""" + try: + self.s3_client.upload_file(local_path, self.s3_bucket, s3_key) + print(f"✅ Uploaded {local_path} to s3://{self.s3_bucket}/{s3_key}") + except Exception as e: + print(f"❌ Failed to upload {local_path}: {e}") + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + """Called when a checkpoint is saved.""" + # Get the checkpoint path + ckpt_callback = trainer.checkpoint_callback + if ckpt_callback is None: + return + + # Upload best checkpoint + if self.upload_best_only and ckpt_callback.best_model_path: + best_path = ckpt_callback.best_model_path + if os.path.exists(best_path): + filename = Path(best_path).name + s3_key = f"{self.s3_prefix}/{self.project_name}/best/{filename}" + self._upload_to_s3(best_path, s3_key) + + # Upload periodic checkpoints + if trainer.current_epoch % self.upload_every_n_epochs == 0: + last_path = ckpt_callback.last_model_path + if last_path and os.path.exists(last_path): + filename = Path(last_path).name + s3_key = f"{self.s3_prefix}/{self.project_name}/periodic/{filename}" + self._upload_to_s3(last_path, s3_key) +``` + +**Hydra Config:** +```yaml +# src/lobster/hydra_config/callbacks/s3_backup.yaml +s3_backup: + _target_: lobster.callbacks._s3_checkpoint_callback.S3CheckpointBackupCallback + s3_bucket: "prescient-pcluster-data" + s3_prefix: "gen_ume/checkpoints" + project_name: ${logger.project} + upload_every_n_epochs: 10 + upload_best_only: false +``` + +#### Option B: HuggingFace Hub Backup + +**Automatic Upload Callback:** + +```python +# src/lobster/callbacks/_hf_checkpoint_callback.py +from huggingface_hub import HfApi, upload_file +from pytorch_lightning.callbacks import Callback +import os + +class HuggingFaceCheckpointCallback(Callback): + """Automatically upload checkpoints to HuggingFace Hub.""" + + def __init__( + self, + repo_id: str = "Sidney-Lisanza/latent_generator", + upload_every_n_epochs: int = 50, + upload_best_only: bool = True, + ): + self.repo_id = repo_id + self.upload_every_n_epochs = upload_every_n_epochs + self.upload_best_only = upload_best_only + self.api = HfApi() + + def _upload_to_hf(self, local_path: str, path_in_repo: str): + """Upload a file to HuggingFace Hub.""" + try: + self.api.upload_file( + path_or_fileobj=local_path, + path_in_repo=path_in_repo, + repo_id=self.repo_id, + repo_type="model", + ) + print(f"✅ Uploaded to HuggingFace: {self.repo_id}/{path_in_repo}") + except Exception as e: + print(f"❌ Failed to upload to HuggingFace: {e}") + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + ckpt_callback = trainer.checkpoint_callback + if ckpt_callback is None: + return + + # Upload best checkpoint to HuggingFace + if self.upload_best_only and ckpt_callback.best_model_path: + best_path = ckpt_callback.best_model_path + if os.path.exists(best_path): + # Create descriptive name + model_name = os.environ.get("MODEL_NAME", "model") + path_in_repo = f"checkpoints_for_lg/{model_name}.ckpt" + self._upload_to_hf(best_path, path_in_repo) +``` + +### Part 2: Dataset Backup to S3 + +#### Initial Upload Script + +```bash +#!/bin/bash +# scripts/backup_datasets_to_s3.sh + +S3_BUCKET="s3://prescient-pcluster-data/gen_ume/datasets" + +# PDB Data +echo "Uploading PDB data..." +aws s3 sync /data2/lisanzas/latent_generator_files/pdb_data/ \ + ${S3_BUCKET}/latent_generator/pdb_data/ \ + --exclude "*.log" + +# AFDB SwissProt +echo "Uploading AFDB SwissProt..." +aws s3 sync /data2/lisanzas/AFDB/ \ + ${S3_BUCKET}/afdb/ \ + --exclude "*.log" + +# Ligand datasets +echo "Uploading ligand datasets..." +aws s3 sync /data2/lisanzas/pdb_bind_12_15_25/ \ + ${S3_BUCKET}/ligand/pdb_bind_12_15_25/ + +aws s3 sync /data2/lisanzas/geom_12_15_25/ \ + ${S3_BUCKET}/ligand/geom_12_15_25/ + +aws s3 sync /data2/lisanzas/sair_12_15_25/ \ + ${S3_BUCKET}/ligand/sair_12_15_25/ + +# ESM Atlas +echo "Uploading ESM Atlas..." +aws s3 sync /data2/ume/simplefold_dataset/train_processed/ \ + ${S3_BUCKET}/esm_atlas/train_processed/ + +# CATH +echo "Uploading CATH..." +aws s3 sync /data2/lisanzas/CATH_v4_3/ \ + ${S3_BUCKET}/cath/ + +# SAbDab +echo "Uploading SAbDab..." +aws s3 sync /data2/lisanzas/sabdab/ \ + ${S3_BUCKET}/sabdab/ + +echo "✅ Dataset backup complete!" +``` + +#### Dataset Sync Utility + +```python +# scripts/sync_datasets.py +""" +Utility to sync datasets between local storage and S3. + +Usage: + # Download datasets from S3 + python scripts/sync_datasets.py download --dataset pdb + + # Upload datasets to S3 + python scripts/sync_datasets.py upload --dataset all + + # List available datasets + python scripts/sync_datasets.py list +""" + +import argparse +import subprocess +from pathlib import Path + +DATASETS = { + "pdb": { + "local": "/data2/lisanzas/latent_generator_files/pdb_data/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/latent_generator/pdb_data/", + }, + "afdb": { + "local": "/data2/lisanzas/AFDB/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/afdb/", + }, + "pdb_bind": { + "local": "/data2/lisanzas/pdb_bind_12_15_25/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/ligand/pdb_bind_12_15_25/", + }, + "geom": { + "local": "/data2/lisanzas/geom_12_15_25/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/ligand/geom_12_15_25/", + }, + "sair": { + "local": "/data2/lisanzas/sair_12_15_25/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/ligand/sair_12_15_25/", + }, + "esm_atlas": { + "local": "/data2/ume/simplefold_dataset/train_processed/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/esm_atlas/train_processed/", + }, + "cath": { + "local": "/data2/lisanzas/CATH_v4_3/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/cath/", + }, + "sabdab": { + "local": "/data2/lisanzas/sabdab/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/sabdab/", + }, + "multiflow": { + "local": "/data2/lisanzas/multi_flow_data/", + "s3": "s3://prescient-pcluster-data/gen_ume/datasets/multiflow/", + }, +} + +def sync(source: str, dest: str, dry_run: bool = False): + """Sync files between source and destination.""" + cmd = ["aws", "s3", "sync", source, dest] + if dry_run: + cmd.append("--dryrun") + print(f"Running: {' '.join(cmd)}") + subprocess.run(cmd, check=True) + +def download(dataset: str, dry_run: bool = False): + """Download dataset from S3.""" + if dataset == "all": + for name, paths in DATASETS.items(): + print(f"\n📥 Downloading {name}...") + sync(paths["s3"], paths["local"], dry_run) + else: + paths = DATASETS[dataset] + sync(paths["s3"], paths["local"], dry_run) + +def upload(dataset: str, dry_run: bool = False): + """Upload dataset to S3.""" + if dataset == "all": + for name, paths in DATASETS.items(): + print(f"\n📤 Uploading {name}...") + sync(paths["local"], paths["s3"], dry_run) + else: + paths = DATASETS[dataset] + sync(paths["local"], paths["s3"], dry_run) + +def list_datasets(): + """List available datasets.""" + print("\nAvailable datasets:") + print("-" * 60) + for name, paths in DATASETS.items(): + print(f" {name}:") + print(f" Local: {paths['local']}") + print(f" S3: {paths['s3']}") + print() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sync datasets between local and S3") + subparsers = parser.add_subparsers(dest="command") + + # Download command + dl_parser = subparsers.add_parser("download", help="Download from S3") + dl_parser.add_argument("--dataset", required=True, choices=list(DATASETS.keys()) + ["all"]) + dl_parser.add_argument("--dry-run", action="store_true") + + # Upload command + ul_parser = subparsers.add_parser("upload", help="Upload to S3") + ul_parser.add_argument("--dataset", required=True, choices=list(DATASETS.keys()) + ["all"]) + ul_parser.add_argument("--dry-run", action="store_true") + + # List command + subparsers.add_parser("list", help="List available datasets") + + args = parser.parse_args() + + if args.command == "download": + download(args.dataset, args.dry_run) + elif args.command == "upload": + upload(args.dataset, args.dry_run) + elif args.command == "list": + list_datasets() + else: + parser.print_help() +``` + +### Part 3: Updated Training Scripts + +Update SLURM scripts to use S3 backup: + +```bash +# Add to slurm/scripts/train_*.sh + +# Enable S3 checkpoint backup +export S3_CHECKPOINT_BUCKET="prescient-lobster" +export S3_CHECKPOINT_PREFIX="checkpoints" + +# Add callback to training command +srun ... \ + lobster_train \ + experiment=train_gen_ume \ + +callbacks.s3_backup.s3_bucket=${S3_CHECKPOINT_BUCKET} \ + +callbacks.s3_backup.s3_prefix=${S3_CHECKPOINT_PREFIX} \ + ... +``` + +### Part 4: Recovery Checklist + +After implementing backup: + +- [ ] Create S3 bucket `prescient-lobster` (or verify existing) +- [ ] Run initial dataset backup script +- [ ] Add S3CheckpointBackupCallback to all training configs +- [ ] Update experiment configs to use S3 paths as fallback +- [ ] Set up weekly verification of S3 backups +- [ ] Document recovery procedure + +--- + +## Recovery Status Tracker + +**Last Survey: January 5, 2026** + +### DATASETS + +| Item | Status | Notes | +|------|--------|-------| +| PDB splits (train/val/test) | ✅ Recovered | `/data2/lisanzas/latent_generator_files/pdb_data/split_data/` | +| PDB cluster file | ✅ Recovered | `pdb_seqid40_clusters.pt` | +| AFDB SwissProt (train) | ✅ Recovered | `/data2/lisanzas/AFDB/train_processed/` | +| AFDB SwissProt (val/test) | ✅ Recovered | `valid_cameo_processed/`, `test_multiflow_processed/` | +| AFDB cluster file | ✅ Recovered | `pdb_swissprot_clusters.pt` | +| ESM Atlas | ✅ Recovered | `/data2/ume/simplefold_dataset/train_processed/` | +| PDBBind (old) | ✅ Recovered | `/data2/lisanzas/pdb_bind/` | +| PDBBind 12_15_25 | ✅ **COMPLETE** | `/data2/lisanzas/pdb_bind_12_15_25/` - 27,294 files with bond_matrix | +| GEOM 12_15_25 | ✅ **COMPLETE** | `/data2/lisanzas/geom_12_15_25/` - 246,840 train, 30,953 val, 30,936 test | +| SAIR 12_15_25 | ✅ **REPROCESSED** | `/data2/lisanzas/sair_12_15_25/` - 279,963 train, 38,611 val, 80,343 test | +| SAIR (old) | ⏸️ Not needed | Replaced by SAIR 12_15_25 | +| CATH | ✅ Recovered | `/data2/lisanzas/CATH_v4_3/` | +| SAbDab | ✅ Recovered | `/data2/lisanzas/sabdab/` | +| MultiFlow Test Set | ✅ Recovered | `/data2/lisanzas/multi_flow_data/` | +| AFDB Genie | ❌ Missing | `/data2/lisanzas/latent_generator_files/afdb_data/processed_pt/` - low priority | +| ESM-C Embeddings | ✅ Recovered | `/data2/lisanzas/latent_generator_files/esm_c_300m_embeddings_iterable_sampler/` | + +### CHECKPOINTS + +| Item | Status | Notes | +|------|--------|-------| +| Gen-UME 90M (PDB) | ✅ S3 Backup | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_90M_PDB.ckpt` (1.1 GiB) | +| Gen-UME 750M (2025-11-17) | ✅ S3 Backup | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_2025-11-17_*.ckpt` (8.3 GiB) | +| Gen-UME 450M (2025-11-07) | ✅ S3 Backup | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_450M_2025-11-07_*.ckpt` (5.3 GiB) | +| Gen-UME 750M ESM Atlas (2026-01-04) | ✅ S3 Backup | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_ESM_Atlas_2026-01-04_*.ckpt` (8.3 GiB) | +| Gen-UME Latest Large (2025-12-17) | ❌ Missing | Low priority - experimental training | +| **LG Protein Ligand 4096** (2026-01-05) | ✅ **NEW S3** | `LG_Protein_Ligand_4096_2026-01-05.ckpt` (292.9 MiB) | +| **LG Protein Ligand fsq 4375** (2026-01-05) | ✅ **NEW S3** | `LG_Protein_Ligand_fsq_4375_2026-01-05.ckpt` (295.8 MiB) | +| **LG Protein Ligand fsq 4375/15360** (2026-01-07) | ✅ **NEW S3** | `LG_Protein_Ligand_fsq_4375_15360_2026-01-07.ckpt` (360.2 MiB) | +| LG Protein Ligand (2025-12-07) | ❌ **LOST** | Original 512-token SLQ model - needs retraining | +| LG Protein Ligand fsq 1000 (2025-12-13) | ❌ **LOST** | 1000-token FSQ model - needs retraining | +| LG Ligand (2025-11-09) | ✅ S3 Backup | `LG_Ligand_2025-11-09.ckpt` (250.5 MiB) | +| LG full attention 2 | ✅ S3 Backup | `LG_full_attention_2_2025-11-06.ckpt` (245.3 MiB) | + +### Available Latent Generator runs in `/data2/ume/latent_generator_/runs/`: + +Runs with checkpoints available (Nov 2025): +- 2025-11-30T16-50-54 +- 2025-11-28T17-38-46 +- 2025-11-26T15-51-49 +- 2025-11-25T14-42-33 +- 2025-11-21T23-30-09 +- 2025-11-20T17-28-11 +- (and more from Nov 17-21) + +--- + +## Summary + +### ✅ Recovered/Reprocessed (14 datasets): +- PDB data + clusters +- AFDB SwissProt + clusters +- ESM Atlas +- PDBBind (old) +- **SAIR 12_15_25** ✅ (279,963 train, 38,611 val, 80,343 test) +- **PDBBind 12_15_25** ✅ (21,835 train, 2,729 val, 2,730 test) +- **GEOM 12_15_25** ✅ (246,840 train, 30,953 val, 30,936 test) +- CATH, SAbDab, MultiFlow +- ESM-C Embeddings + +### 🔄 Currently Processing: +None - all datasets ready! + +### ❌ Not Recovered (low priority): +- **AFDB Genie** - can reprocess if needed +- **SAIR (old)** - replaced by SAIR 12_15_25 + +### ✅ Recovered Checkpoints (backed up to S3): + +**Gen-UME Models:** +| Model | Local Path | S3 Path | Size | +|-------|-----------|---------|------| +| Gen-UME 90M (PDB) | (from old S3 bucket) | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_90M_PDB.ckpt` | 1.1 GiB | +| Gen-UME 450M | `/data2/ume/gen_ume/runs/2025-11-07T13-19-11/` | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_450M_2025-11-07_*.ckpt` | 5.3 GiB | +| Gen-UME 750M | `/data2/ume/gen_ume/runs/2025-11-17T20-31-05/` | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_2025-11-17_*.ckpt` | 8.3 GiB | +| Gen-UME 750M ESM Atlas | `/data2/lisanzas/gen_ume/runs/2026-01-04T19-10-12/` | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_ESM_Atlas_2026-01-04_*.ckpt` | 8.3 GiB | + +**Latent Generator Models:** +| Model | Local Path | S3 Path | Size | +|-------|-----------|---------|------| +| LG Ligand | `/data2/ume/latent_generator_/runs/2025-11-09T14-23-55/last.ckpt` | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Ligand_2025-11-09.ckpt` | 250.5 MiB | +| **LG Protein Ligand 4096** | `/data2/ume/latent_generator_/runs/2026-01-05T16-48-02/last.ckpt` | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_4096_2026-01-05.ckpt` | 292.9 MiB | +| **LG Protein Ligand fsq 4375** | `/data2/ume/latent_generator_/runs/2026-01-05T16-13-19/last.ckpt` | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_2026-01-05.ckpt` | 295.8 MiB | +| **LG Protein Ligand fsq 4375/15360** | `/data2/ume/latent_generator_/runs/2026-01-07T02-17-14/last.ckpt` | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_15360_2026-01-07.ckpt` | 360.2 MiB | +| LG full attention 2 | `/data2/ume/latent_generator_/runs/2025-11-06T00-40-11/last.ckpt` | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_full_attention_2_2025-11-06.ckpt` | 245.3 MiB | + +### ❌ Missing Checkpoints (Lost - need retraining): +| Model | Original Path | Status | +|-------|--------------|--------| +| LG Protein Ligand (2025-12-07) | `/data2/ume/latent_generator_/runs/2025-12-07T22-38-42/` | **LOST** - needs retraining | +| LG Protein Ligand fsq 1000 (2025-12-13) | `/data2/ume/latent_generator_/runs/2025-12-13T14-57-53/` | **LOST** - needs retraining | + +--- + +## Processing Commands + +### Check Processing Status +```bash +# Check job status +squeue -u $USER + +# Check output counts +echo "PDBBind 12_15_25:"; find /data2/lisanzas/pdb_bind_12_15_25 -name "*_ligand.pt" 2>/dev/null | wc -l +echo "GEOM 12_15_25:"; find /data2/lisanzas/geom_12_15_25 -name "*.pt" 2>/dev/null | wc -l +echo "SAIR 12_15_25:"; find /data2/lisanzas/sair_12_15_25 -name "*_protein.pt" 2>/dev/null | wc -l +``` + +### Submit Processing Jobs +```bash +# PDBBind (fast, ~1-2 hours) +sbatch slurm/scripts/process_pdbbind_bond_matrix_array.sh + +# GEOM (slower, ~4-8 hours due to S3) +sbatch slurm/scripts/process_geom_bond_matrix_array.sh +``` + +### Train LG Protein-Ligand Models (after datasets ready) +```bash +# SLQ quantization (512 tokens) +sbatch slurm/scripts/train_latent_generator_protein_ligand_sair.sh + +# FSQ quantization (4375 tokens) +sbatch slurm/scripts/train_latent_generator_protein_ligand_fsq_ligand_4375.sh + +# FSQ quantization (1000 tokens) +sbatch slurm/scripts/train_latent_generator_protein_ligand_fsq_ligand_1000.sh +``` + +--- + +*Last updated: January 8, 2026 (added Gen-UME 90M, 750M ESM Atlas, and LG Protein Ligand fsq 4375/15360 checkpoints to S3)* + diff --git a/pyproject.toml b/pyproject.toml index c3bf85e5..798b2111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "polars>=1.33.0", "hydra-submitit-launcher>=1.2.0", "rdkit>=2024.9.4", + "openbabel-wheel>=3.1.1", # For ligand structure minimization ] [build-system] diff --git a/slurm/scripts/train_latent_generator.sh b/slurm/scripts/train_latent_generator.sh new file mode 100644 index 00000000..19b43dcc --- /dev/null +++ b/slurm/scripts/train_latent_generator.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +#SBATCH --partition b200 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 8 +#SBATCH --gpus-per-node 8 +#SBATCH --cpus-per-task 16 +#SBATCH -o /data2/ume/latent_generator_/slurm/logs/train/%J_%x.out +#SBATCH -q preempt +#SBATCH --mem=256G +#SBATCH --job-name=latent_generator +#SBATCH -t 7-00:00:00 + + +nvidia-smi + +source .venv/bin/activate +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" + +export LD_LIBRARY_PATH=/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/amazon/ofi-nccl/lib64 + +export WANDB_INSECURE_DISABLE_SSL=true +export HYDRA_FULL_ERROR=1 +export PYTHONUNBUFFERED=1 +export NCCL_DEBUG=INFO + +export LOBSTER_RUNS_DIR="/data2/ume/latent_generator_/runs/" #"s3://prescient-lobster/ume/runs" # CHANGE TO YOUR S3 BUCKET +export LOBSTER_DATA_DIR="/data2/ume/.cache2/" # CHANGE TO YOUR DATA DIRECTORY +export LOBSTER_USER=$(whoami) # CHANGE TO YOUR WANDB USERNAME IF NOT YOUR UNIXID +export WANDB_BASE_URL=https://genentech.wandb.io + +export TOKENIZERS_PARALLELISM=true + +# Sets default permissions to allow group write +# access for newly created files. Remove if not needed +umask g+w + +srun -u --cpus-per-task $SLURM_CPUS_PER_TASK --cpu-bind=cores,verbose \ + lobster_train \ + experiment=train_latent_generator \ + data.num_workers=8 \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + trainer.num_sanity_val_steps=0 \ + +trainer.strategy=ddp_find_unused_parameters_true \ diff --git a/slurm/scripts/train_latent_generator_ligand.sh b/slurm/scripts/train_latent_generator_ligand.sh new file mode 100644 index 00000000..e5628aab --- /dev/null +++ b/slurm/scripts/train_latent_generator_ligand.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +#SBATCH --partition b200 +#SBATCH --nodes 1 +#SBATCH --ntasks-per-node 7 +#SBATCH --gpus-per-node 7 +#SBATCH --cpus-per-task 16 +#SBATCH -o /data2/ume/latent_generator_/slurm/logs/train/%J_%x.out +#SBATCH -q preempt +#SBATCH --mem=256G +#SBATCH --job-name=latent_generator +#SBATCH -t 7-00:00:00 + + +nvidia-smi + +source .venv/bin/activate +echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" + +export LD_LIBRARY_PATH=/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/amazon/ofi-nccl/lib64 + +export WANDB_INSECURE_DISABLE_SSL=true +export HYDRA_FULL_ERROR=1 +export PYTHONUNBUFFERED=1 +export NCCL_DEBUG=INFO + +export LOBSTER_RUNS_DIR="/data2/ume/latent_generator_/runs/" #"s3://prescient-lobster/ume/runs" # CHANGE TO YOUR S3 BUCKET +export LOBSTER_DATA_DIR="/data2/ume/.cache2/" # CHANGE TO YOUR DATA DIRECTORY +export LOBSTER_USER=$(whoami) # CHANGE TO YOUR WANDB USERNAME IF NOT YOUR UNIXID +export WANDB_BASE_URL=https://genentech.wandb.io + +export TOKENIZERS_PARALLELISM=true + +# Sets default permissions to allow group write +# access for newly created files. Remove if not needed +umask g+w + +srun -u --cpus-per-task $SLURM_CPUS_PER_TASK --cpu-bind=cores,verbose \ + lobster_train \ + experiment=train_latent_generator \ + data=structure_ligand \ + model=latent_generator_ligand \ + model.num_warmup_steps=10000 \ + model.num_training_steps=500000 \ + model.lr_scheduler.num_warmup_steps=10000 \ + model.lr_scheduler.num_training_steps=500000 \ + data.num_workers=8 \ + ++trainer.num_nodes=$SLURM_JOB_NUM_NODES \ + trainer.num_sanity_val_steps=0 \ + +trainer.strategy=ddp_find_unused_parameters_true \ + model.structure_encoder.encode_ligand=true \ + model.structure_encoder.embed_dim=256 \ + model.quantizer.ligand_n_tokens=512 \ + model.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true \ + +model.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=512 \ + +model.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_dim=512 \ + + + diff --git a/src/lobster/callbacks/__init__.py b/src/lobster/callbacks/__init__.py index e9a37b46..2ca9b690 100644 --- a/src/lobster/callbacks/__init__.py +++ b/src/lobster/callbacks/__init__.py @@ -11,6 +11,12 @@ from ._structure_decode import StructureDecodeCallback from ._unconditional_generation import UnconditionalGenerationCallback from ._auxiliary_task_loss_weight_scheduler import AuxiliaryTaskWeightScheduler, MultiTaskWeightScheduler +from ._inverse_folding_callback import InverseFoldingCallback +from ._forward_folding_callback import ForwardFoldingCallback +from ._protein_ligand_decode import ProteinLigandDecodeCallback +from ._protein_ligand_inverse_folding import ProteinLigandInverseFoldingCallback +from ._protein_ligand_forward_folding import ProteinLigandForwardFoldingCallback +from ._s3_checkpoint_callback import S3CheckpointBackupCallback __all__ = [ "MoleculeACELinearProbeCallback", @@ -27,4 +33,12 @@ "UmeGrpoLoggingCallback", "AuxiliaryTaskWeightScheduler", "MultiTaskWeightScheduler", + "StructureDecodeCallback", + "UnconditionalGenerationCallback", + "InverseFoldingCallback", + "ForwardFoldingCallback", + "ProteinLigandDecodeCallback", + "ProteinLigandInverseFoldingCallback", + "ProteinLigandForwardFoldingCallback", + "S3CheckpointBackupCallback", ] diff --git a/src/lobster/callbacks/_forward_folding_callback.py b/src/lobster/callbacks/_forward_folding_callback.py new file mode 100644 index 00000000..4d33845a --- /dev/null +++ b/src/lobster/callbacks/_forward_folding_callback.py @@ -0,0 +1,348 @@ +import lightning +import os +import torch +import glob +from loguru import logger + +from lobster.model.latent_generator.io import writepdb +from lobster.model.latent_generator.utils.residue_constants import ( + convert_lobster_aa_tokenization_to_standard_aa, + restype_order_with_x_inv, +) +from lobster.metrics import align_and_compute_rmsd +from lobster.transforms._structure_transforms import StructureBackboneTransform, AminoAcidTokenizerTransform +from tmtools import tm_align +import tqdm + + +class ForwardFoldingCallback(lightning.Callback): + """Callback for evaluating forward folding (sequence → structure) on CAMEO dataset during training. + + This callback: + - Loads CAMEO validation structures + - Extracts ground truth sequences + - Generates structures from sequences using the model + - Compares generated vs ground truth structures + - Logs TM-score and RMSD metrics to WandB + """ + + def __init__( + self, + structure_path: str = None, + cameo_data_path: str = "/data2/lisanzas/AFDB/valid_cameo_processed", + save_every_n: int = 1000, + num_samples: int = 127, + max_length: int = 512, + nsteps: int = 200, + temperature_seq: float = 0.3610371899835548, + temperature_struc: float = 0.2195534567490864, + stochasticity_seq: int = 1, + stochasticity_struc: int = 20, + cache_dir: str | None = None, + ): + """Initialize forward folding callback. + + Args: + structure_path: Directory to save generated structures + cameo_data_path: Path to CAMEO dataset directory (glob pattern supported) + save_every_n: Evaluate every N training steps + num_samples: Number of samples to evaluate per callback + max_length: Maximum sequence length to process + nsteps: Number of diffusion steps for generation + temperature_seq: Temperature for sequence sampling + temperature_struc: Temperature for structure sampling + stochasticity_seq: Stochasticity parameter for sequence + stochasticity_struc: Stochasticity parameter for structure + cache_dir: Cache directory for datasets + """ + self.structure_path = structure_path + self.cameo_data_path = cameo_data_path + self.save_every_n = save_every_n + self.num_samples = num_samples + self.max_length = max_length + self.nsteps = nsteps + self.temperature_seq = temperature_seq + self.temperature_struc = temperature_struc + self.stochasticity_seq = stochasticity_seq + self.stochasticity_struc = stochasticity_struc + self.cache_dir = cache_dir + self.cameo_structures = None + self.tokenizer_transform = None + self.structure_transform = None + + if self.structure_path and not os.path.exists(f"{self.structure_path}/forward_folding"): + os.makedirs(f"{self.structure_path}/forward_folding", exist_ok=True) + + def _load_cameo_structures(self): + """Load CAMEO .pt files directly. + + Loads preprocessed CAMEO structures from .pt files directly, + similar to how generate.py handles them. + """ + # Use glob to find all .pt files + if "*" in self.cameo_data_path: + dataset_paths = glob.glob(self.cameo_data_path) + else: + dataset_paths = glob.glob(os.path.join(self.cameo_data_path, "*.pt")) + + if not dataset_paths: + raise ValueError(f"No .pt files found in {self.cameo_data_path}") + + logger.info(f"Found {len(dataset_paths)} CAMEO dataset files") + + # Load structures from first few files (limit to avoid memory issues) + max_files = self.num_samples + dataset_paths = sorted(dataset_paths)[:max_files] + + structures = [] + for pt_path in dataset_paths: + logger.info(f"Loading structures from {pt_path}") + structure_data = torch.load(pt_path, map_location="cpu") + + # Apply StructureBackboneTransform to ensure consistent format + structure_data = self.structure_transform(structure_data) + + # Filter by minimum length and valid sequences + if structure_data["coords_res"].shape[0] >= 30: + percent_unknown = (structure_data["sequence"] == 20).sum().float() / structure_data["sequence"].shape[0] + if percent_unknown <= 0.1: # Less than 10% unknown + structures.append(structure_data) + + # Limit total structures for efficiency + if len(structures) >= self.num_samples * 3: # Load 3x more than needed + break + + logger.info(f"Loaded {len(structures)} valid structures from CAMEO dataset") + return structures + + def setup(self, trainer, pl_module, stage): + """Setup method to load CAMEO structures for evaluation.""" + # Only setup on rank 0 (CUDA device 0) in multinode/multi-GPU settings + if trainer.global_rank != 0: + return + + # Initialize transforms + self.structure_transform = StructureBackboneTransform(max_length=self.max_length) + self.tokenizer_transform = AminoAcidTokenizerTransform(max_length=self.max_length) + + # Load CAMEO structures directly + self.cameo_structures = self._load_cameo_structures() + + logger.info(f"Loaded {len(self.cameo_structures)} CAMEO structures for evaluation") + + def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): + """Called at the end of each training batch.""" + current_step = trainer.global_step + + # Check if this is a step where forward folding should run + # Note: cameo_structures is only loaded on rank 0, so we use batch_idx check for all ranks + is_forward_folding_step = batch_idx % self.save_every_n == 0 + + # All ranks must synchronize BEFORE rank 0 does forward folding + # This prevents NCCL timeouts from internal collectives in generate_sample + if is_forward_folding_step and trainer.world_size > 1: + torch.distributed.barrier() + + # Only rank 0 actually runs the callback (it has the loaded structures) + if is_forward_folding_step and trainer.global_rank == 0 and self.cameo_structures is not None: + # Get device from model or from any available tensor in batch + device = next(model.parameters()).device + + # Perform forward folding on CAMEO validation examples + # Use model.module to avoid DDP wrapper during inference + with torch.no_grad(): + unwrapped_model = model.module if hasattr(model, "module") else model + self._perform_forward_folding(trainer, unwrapped_model, device, batch_idx, current_step) + torch.cuda.empty_cache() + + # Final barrier to ensure all ranks are synced before continuing training + if is_forward_folding_step and trainer.world_size > 1: + torch.distributed.barrier() + + def _perform_forward_folding(self, trainer, model, device, batch_idx, current_step): + """Perform forward folding on CAMEO validation examples. + + Process: + 1. Load ground truth sequences and structures from CAMEO + 2. Tokenize sequences and prepare inputs + 3. Generate structures using model.generate_sample(forward_folding=True) + 4. Compare generated structures to ground truth + 5. Calculate TM-score and RMSD + 6. Save structures and log metrics + """ + # Initialize lists to accumulate metrics + all_tm_scores = [] + all_rmsd_scores = [] + + # Process CAMEO structures in batches + batch_size = 10 + num_structures = min(len(self.cameo_structures), self.num_samples) + + logger.info(f"Evaluating forward folding on {num_structures} CAMEO structures") + + for batch_start in tqdm.tqdm( + range(0, num_structures, batch_size), + desc="Forward Folding Evaluation", + ): + batch_end = min(batch_start + batch_size, num_structures) + batch_structures = self.cameo_structures[batch_start:batch_end] + + # Prepare batch tensors + max_length = max(s["coords_res"].shape[0] for s in batch_structures) + B = len(batch_structures) + + sequence = torch.zeros((B, max_length), dtype=torch.long, device=device) + coords_res = torch.zeros((B, max_length, 3, 3), device=device) + mask = torch.zeros((B, max_length), device=device) + indices = torch.zeros((B, max_length), dtype=torch.long, device=device) + + # Fill batch tensors + for i, structure in enumerate(batch_structures): + L = structure["coords_res"].shape[0] + sequence[i, :L] = structure["sequence"].to(device) + coords_res[i, :L] = structure["coords_res"].to(device) + mask[i, :L] = structure["mask"].to(device) + indices[i, :L] = structure["indices"].to(device) + + mask_orig = mask.clone() + + # Handle NaN coordinates + nan_indices = torch.isnan(coords_res).any(dim=-1).any(dim=-1) + mask[nan_indices] = 0 + coords_res[nan_indices] = 0 + + # Tokenize sequences for forward folding + # Note: sequence is already in standard format from dataset + # Apply tokenizer transform to convert to model input format + tokenized_sequences = torch.zeros((B, max_length), device=device, dtype=torch.long) + for i in range(B): + seq_i = sequence[i, mask[i] == 1] + # Apply tokenizer transform + tokenized_data = self.tokenizer_transform({"sequence": seq_i.cpu()}) + tokenized_seq = tokenized_data["sequence"].to(device) + seq_len = min(len(tokenized_seq), max_length) + tokenized_sequences[i, :seq_len] = tokenized_seq[:seq_len] + + # Generate structures from sequences (forward folding) + logger.info(f"Generating structures for batch (batch {batch_start}-{batch_end}, {B} samples)...") + generate_sample = model.generate_sample( + length=max_length, + num_samples=B, + forward_folding=True, + nsteps=self.nsteps, + temperature_seq=self.temperature_seq, + temperature_struc=self.temperature_struc, + stochasticity_seq=self.stochasticity_seq, + stochasticity_struc=self.stochasticity_struc, + input_sequence_tokens=tokenized_sequences, + input_mask=mask, + input_indices=indices, + ) + + # Decode structures + decoded_x = model.decode_structure(generate_sample, mask) + + # Extract coordinates from vit_decoder + x_recon_xyz = None + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + vit_output = decoded_x[decoder_name] + # Handle both tensor output (protein-only) and dict output (protein-ligand) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + else: + x_recon_xyz = vit_output + break + + if x_recon_xyz is None: + logger.error("No vit_decoder found in decoded structures") + continue + + # Extract sequences (should match input, but check) + if generate_sample["sequence_logits"].shape[-1] == 33: + seq = convert_lobster_aa_tokenization_to_standard_aa(generate_sample["sequence_logits"], device=device) + else: + seq = generate_sample["sequence_logits"].argmax(dim=-1) + seq[seq > 21] = 20 + + # Only save structures for the first batch + if batch_start == 0: + # Save generated and ground truth structures + for i in range(min(5, B)): # Save first 5 samples + seq_i = seq[i, mask_orig[i] == 1] + + # Save generated structure + filename_gen = ( + f"{self.structure_path}/forward_folding/struc_{batch_idx}_{current_step}_{i}_generated.pdb" + ) + writepdb(filename_gen, x_recon_xyz[i], seq_i) + logger.info(f"Saved generated: {filename_gen}") + + # Save ground truth structure + filename_gt = ( + f"{self.structure_path}/forward_folding/struc_{batch_idx}_{current_step}_{i}_ground_truth.pdb" + ) + writepdb(filename_gt, coords_res[i], seq_i) + logger.info(f"Saved ground truth: {filename_gt}") + + # Calculate TM-score and RMSD for all samples in batch + logger.info("Calculating structural metrics...") + batch_tm_scores = [] + batch_rmsd_scores = [] + + for i in range(B): + # Get ground truth and generated coordinates (masked) + gt_coords = coords_res[i, mask_orig[i] == 1, :, :] # Ground truth + gen_coords = x_recon_xyz[i, mask_orig[i] == 1, :, :] # Generated + + # Get sequence string for TM-align + seq_i = seq[i, mask_orig[i] == 1] + sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) + + # Calculate TM-Score using TM-align (CA atoms only) + tm_out = tm_align( + gen_coords[:, 1, :].cpu().numpy(), # CA atoms of generated structure + gt_coords[:, 1, :].detach().cpu().numpy(), # CA atoms of ground truth + sequence_str, + sequence_str, + ) + batch_tm_scores.append(tm_out.tm_norm_chain1) + + # Calculate RMSD using Kabsch alignment (all backbone atoms) + rmsd = align_and_compute_rmsd( + coords1=gen_coords, + coords2=gt_coords, + mask=None, # Use all positions + return_aligned=False, + device=device, + ) + batch_rmsd_scores.append(rmsd) + + # Accumulate metrics + all_tm_scores.extend(batch_tm_scores) + all_rmsd_scores.extend(batch_rmsd_scores) + + logger.info( + f"Batch {batch_start}-{batch_end}: Avg TM-score={sum(batch_tm_scores) / len(batch_tm_scores):.3f}, " + f"Avg RMSD={sum(batch_rmsd_scores) / len(batch_rmsd_scores):.2f} Å" + ) + + # Calculate averaged metrics across all validation batches + if not all_tm_scores: + logger.warning("No metrics collected for forward folding callback") + return + + avg_tm_score = sum(all_tm_scores) / len(all_tm_scores) + avg_rmsd = sum(all_rmsd_scores) / len(all_rmsd_scores) + + # Log averaged metrics to WandB + metrics_to_log = { + "forward_folding/tm_score": avg_tm_score, + "forward_folding/rmsd": avg_rmsd, + "forward_folding/num_samples": len(all_tm_scores), + } + model.log_dict(metrics_to_log, batch_size=1) + + logger.info(f"Forward Folding Validation Results (step {current_step}):") + logger.info(f" Average TM-score: {avg_tm_score:.3f} (n={len(all_tm_scores)})") + logger.info(f" Average RMSD: {avg_rmsd:.2f} Å") diff --git a/src/lobster/callbacks/_inverse_folding_callback.py b/src/lobster/callbacks/_inverse_folding_callback.py index f218f290..fe045410 100644 --- a/src/lobster/callbacks/_inverse_folding_callback.py +++ b/src/lobster/callbacks/_inverse_folding_callback.py @@ -1,6 +1,8 @@ import lightning import os import torch +import torch.distributed as dist +import glob from lobster.model.latent_generator.io import writepdb from loguru import logger from lobster.model.latent_generator.utils.residue_constants import ( @@ -9,6 +11,7 @@ ) from lobster.metrics import get_folded_structure_metrics, calculate_percent_identity from lobster.data._coord_structure_datamodule import StructureLightningDataModule +from lobster.transforms._structure_transforms import StructureBackboneTransform import tqdm from lobster.model import LobsterPLMFold from torch.utils.data import DataLoader @@ -20,6 +23,9 @@ def __init__( self, structure_path: str = None, save_every_n: int = 1000, + dataset_name: str = "cath", + dataset_path: str | None = None, + metric_prefix: str | None = None, length: int = 100, num_samples: int = 10, use_plm_fold: bool = True, @@ -29,6 +35,8 @@ def __init__( ): self.structure_path = structure_path self.save_every_n = save_every_n + self.dataset_name = dataset_name.lower() + self.dataset_path = dataset_path self.length = length self.num_samples = num_samples self.use_plm_fold = use_plm_fold @@ -39,6 +47,11 @@ def __init__( self.eval_datamodule = None self.use_hf_datasets = use_hf_datasets self.cache_dir = cache_dir + self.loaded_structures = None + self.structure_transform = None + + # Auto-generate metric prefix if not provided + self.metric_prefix = metric_prefix or f"inverse_folding_{self.dataset_name}" if not os.path.exists(f"{self.structure_path}/inverse_folding"): os.makedirs(f"{self.structure_path}/inverse_folding", exist_ok=True) @@ -74,8 +87,53 @@ def _download_cath_datasets(self) -> list[str]: return downloaded_paths + def _load_structures_from_pt_files(self): + """Load preprocessed structures from .pt files directly. + + Loads preprocessed structures from .pt files, applying necessary + transforms and filtering. + """ + if not self.dataset_path: + raise ValueError("dataset_path must be provided when loading from .pt files") + + # Use glob to find all .pt files + if "*" in self.dataset_path: + dataset_paths = glob.glob(self.dataset_path) + else: + dataset_paths = glob.glob(os.path.join(self.dataset_path, "*.pt")) + + if not dataset_paths: + raise ValueError(f"No .pt files found in {self.dataset_path}") + + logger.info(f"Found {len(dataset_paths)} dataset files at {self.dataset_path}") + + # Load structures from first few files (limit to avoid memory issues) + max_files = self.num_samples + dataset_paths = sorted(dataset_paths)[:max_files] + + structures = [] + for pt_path in dataset_paths: + logger.info(f"Loading structures from {pt_path}") + structure_data = torch.load(pt_path, map_location="cpu") + + # Apply StructureBackboneTransform to ensure consistent format + structure_data = self.structure_transform(structure_data) + + # Filter by minimum length and valid sequences + if structure_data["coords_res"].shape[0] >= 30: + percent_unknown = (structure_data["sequence"] == 20).sum().float() / structure_data["sequence"].shape[0] + if percent_unknown <= 0.1: # Less than 10% unknown + structures.append(structure_data) + + # Limit total structures for efficiency + if len(structures) >= self.num_samples * 3: # Load 3x more than needed + break + + logger.info(f"Loaded {len(structures)} valid structures from dataset") + return structures + def _create_eval_datamodule(self): - """Create a separate datamodule for evaluation.""" + """Create a separate datamodule for evaluation (CATH only).""" if self.use_hf_datasets: # Download from Hugging Face logger.info("Using Hugging Face datasets for CATH data") @@ -116,36 +174,99 @@ def setup(self, trainer, pl_module, stage): self.plm_fold = LobsterPLMFold(model_name="esmfold_v1", max_length=self.max_length) logger.info("Loaded ESMFold model for inverse folding evaluation") - # Create separate evaluation datamodule - self._create_eval_datamodule() - # Get validation dataset directly (avoid trainer dependency in val_dataloader) - self.val_dataset = self.eval_datamodule._val_dataset + # Load data based on whether dataset_path is provided + if self.dataset_path: + # Load structures directly from .pt files + logger.info( + f"Loading structures from .pt files for inverse folding evaluation (dataset: {self.dataset_name})" + ) + self.structure_transform = StructureBackboneTransform(max_length=self.max_length) + self.loaded_structures = self._load_structures_from_pt_files() + logger.info(f"Loaded {len(self.loaded_structures)} structures for evaluation") + else: + # Use CATH datamodule (default behavior) + logger.info(f"Loading CATH dataset for inverse folding evaluation (dataset: {self.dataset_name})") + # Create separate evaluation datamodule + self._create_eval_datamodule() + # Get validation dataset directly (avoid trainer dependency in val_dataloader) + self.val_dataset = self.eval_datamodule._val_dataset + + # Create our own dataloader to avoid trainer state issues + self.val_dataloader = DataLoader( + self.val_dataset, + batch_size=20, + shuffle=False, + num_workers=0, # Use 0 to avoid multiprocessing issues in callbacks + collate_fn=self.eval_datamodule._collate_fn, + ) - # Create our own dataloader to avoid trainer state issues - self.val_dataloader = DataLoader( - self.val_dataset, - batch_size=20, - shuffle=False, - num_workers=0, # Use 0 to avoid multiprocessing issues in callbacks - collate_fn=self.eval_datamodule._collate_fn, - ) + logger.info(f"Created validation dataloader with {len(self.val_dataset)} examples") - logger.info(f"Created validation dataloader with {len(self.val_dataset)} examples") + def _is_distributed(self) -> bool: + """Check if we're running in a distributed setting.""" + return dist.is_available() and dist.is_initialized() - def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): - # Only run on rank 0 (CUDA device 0) in multinode/multi-GPU settings - if trainer.global_rank != 0: - return + def _barrier(self): + """Synchronize all ranks if running in distributed mode.""" + if self._is_distributed(): + dist.barrier() + def _get_unwrapped_model(self, model): + """Get the underlying model from DDP/FSDP wrapper if present. + + This avoids triggering collective operations during callback evaluation. + """ + # Handle PyTorch Lightning's LightningModule wrapping + if hasattr(model, "module"): + return model.module + return model + + def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): current_step = trainer.global_step - device = batch["sequence"].device - if self.use_plm_fold and self.plm_fold is not None: - self.plm_fold.to(device) + # Check if this is a callback step - use deterministic condition based on batch_idx only + # All ranks must make the same decision to avoid deadlock at barriers + is_callback_step = batch_idx % self.save_every_n == 0 - if batch_idx % self.save_every_n == 0 and self.val_dataloader is not None: - # Perform inverse folding on validation examples - self._perform_inverse_folding(trainer, model, device, batch_idx, current_step) + if not is_callback_step: + return + + # All ranks hit barrier before callback starts to ensure sync + # This prevents other ranks from continuing training while rank 0 runs evaluation + self._barrier() + + # Only rank 0 runs the actual evaluation + if trainer.global_rank == 0: + # Check if we have data to evaluate + has_data = (self.val_dataloader is not None) or (self.loaded_structures is not None) + + if has_data: + # Get device from whatever tensor is available in the batch + if "sequence" in batch: + device = batch["sequence"].device + elif "ligand_coords" in batch: + device = batch["ligand_coords"].device + elif "coords" in batch: + device = batch["coords"].device + else: + # Fallback to model device + device = next(model.parameters()).device + + if self.use_plm_fold and self.plm_fold is not None: + self.plm_fold.to(device) + + # Perform inverse folding on validation examples + with torch.no_grad(): + if self.dataset_path: + # Use direct .pt file loading method + self._perform_inverse_folding_pt_files(trainer, model, device, batch_idx, current_step) + else: + # Use CATH datamodule method + self._perform_inverse_folding(trainer, model, device, batch_idx, current_step) + torch.cuda.empty_cache() + + # All ranks hit barrier after callback completes to ensure sync before resuming training + self._barrier() def _perform_inverse_folding(self, trainer, model, device, batch_idx, current_step): """Perform inverse folding on validation examples.""" @@ -179,7 +300,12 @@ def _perform_inverse_folding(self, trainer, model, device, batch_idx, current_st for decoder_name in decoded_x: if "vit_decoder" == decoder_name: - x_recon_xyz = decoded_x[decoder_name] + vit_output = decoded_x[decoder_name] + # Handle both tensor output (protein-only) and dict output (protein-ligand) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + else: + x_recon_xyz = vit_output if generate_sample["sequence_logits"].shape[-1] == 33: seq = convert_lobster_aa_tokenization_to_standard_aa(generate_sample["sequence_logits"], device=device) else: @@ -247,16 +373,131 @@ def _perform_inverse_folding(self, trainer, model, device, batch_idx, current_st values = [metrics[key] for metrics in all_folded_structure_metrics if key in metrics] avg_folded_metrics[key] = sum(values) / len(values) - # Log averaged metrics + # Log averaged metrics with dataset-specific prefix total_loss = 0.0 metrics_to_log = { - "inverse_folding_loss": total_loss, - "sequence_percent_identity": avg_percent_identity, - **avg_folded_metrics, + f"{self.metric_prefix}/loss": total_loss, + f"{self.metric_prefix}/sequence_recovery": avg_percent_identity, } - model.log_dict(metrics_to_log, batch_size=1) # Use batch_size=1 since we're logging aggregated metrics + # Add folded metrics with prefix + for key, value in avg_folded_metrics.items(): + # Remove "inverse_folding_" prefix if it exists to avoid duplication + clean_key = key.replace("inverse_folding_", "") + metrics_to_log[f"{self.metric_prefix}/{clean_key}"] = value + + # Use sync_dist=False and rank_zero_only=True to avoid distributed sync during callback + model.log_dict(metrics_to_log, batch_size=1, sync_dist=False, rank_zero_only=True) logger.info(f"Validation metrics averaged over {len(all_folded_structure_metrics)} batches:") logger.info(f"Average sequence percent identity: {avg_percent_identity:.2f}%") for key, value in avg_folded_metrics.items(): logger.info(f"Average {key}: {value:.4f}") + + def _perform_inverse_folding_pt_files(self, trainer, model, device, batch_idx, current_step): + """Perform inverse folding on structures loaded from .pt files. + + Similar to _perform_inverse_folding but uses structures loaded directly from .pt files. + """ + # Initialize lists to accumulate metrics + all_percent_identities = [] + + # Process loaded structures in batches + batch_size = 20 + num_structures = min(len(self.loaded_structures), self.num_samples) + + logger.info(f"Evaluating inverse folding on {num_structures} structures from {self.dataset_name} dataset") + + for batch_start in tqdm.tqdm( + range(0, num_structures, batch_size), + desc=f"Inverse Folding Evaluation ({self.dataset_name})", + ): + batch_end = min(batch_start + batch_size, num_structures) + batch_structures = self.loaded_structures[batch_start:batch_end] + + # Prepare batch tensors + max_length = max(s["coords_res"].shape[0] for s in batch_structures) + B = len(batch_structures) + + sequence = torch.zeros((B, max_length), dtype=torch.long, device=device) + coords_res = torch.zeros((B, max_length, 3, 3), device=device) + mask = torch.zeros((B, max_length), device=device) + indices = torch.zeros((B, max_length), dtype=torch.long, device=device) + + # Fill batch tensors + for i, structure in enumerate(batch_structures): + L = structure["coords_res"].shape[0] + sequence[i, :L] = structure["sequence"].to(device) + coords_res[i, :L] = structure["coords_res"].to(device) + mask[i, :L] = structure["mask"].to(device) + indices[i, :L] = structure["indices"].to(device) + + mask_orig = mask.clone() + + # Handle NaN coordinates + nan_indices = torch.isnan(coords_res).any(dim=-1).any(dim=-1) + mask[nan_indices] = 0 + coords_res[nan_indices] = 0 + + # Generate sequences (inverse folding) + logger.info(f"Generating sequences for batch {batch_start}-{batch_end} ({B} samples)...") + generate_sample = model.generate_sample( + length=max_length, + num_samples=B, + inverse_folding=True, + nsteps=100, + input_structure_coords=coords_res, + input_mask=mask, + input_indices=indices, + ) + decoded_x = model.decode_structure(generate_sample, mask) + + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + vit_output = decoded_x[decoder_name] + # Handle both tensor output (protein-only) and dict output (protein-ligand) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + else: + x_recon_xyz = vit_output + + if generate_sample["sequence_logits"].shape[-1] == 33: + seq = convert_lobster_aa_tokenization_to_standard_aa(generate_sample["sequence_logits"], device=device) + else: + seq = generate_sample["sequence_logits"].argmax(dim=-1) + seq[seq > 21] = 20 + + # Only save structures for the first batch + if batch_start == 0: + # save the generated structure + for i in range(min(5, B)): # Save first 5 samples + filename = f"{self.structure_path}/inverse_folding/struc_{batch_idx}_{current_step}_{i}_inverse_folding_{self.dataset_name}.pdb" + writepdb(filename, x_recon_xyz[i], seq[i]) + logger.info(f"Saved {filename}") + + # Calculate percent identity between ground truth and generated sequences + percent_identities = calculate_percent_identity(sequence, seq, mask_orig) + + # Accumulate metrics + all_percent_identities.extend(percent_identities.cpu().tolist()) + + logger.info( + f"Batch {batch_start}-{batch_end}: Avg sequence recovery={sum(percent_identities.cpu().tolist()) / len(percent_identities):.2f}%" + ) + + # Calculate averaged metrics across all batches + if not all_percent_identities: + logger.warning(f"No metrics collected for inverse folding callback ({self.dataset_name})") + return + + avg_percent_identity = sum(all_percent_identities) / len(all_percent_identities) + + # Log averaged metrics with dataset-specific prefix + # Use sync_dist=False and rank_zero_only=True to avoid distributed sync during callback + metrics_to_log = { + f"{self.metric_prefix}/sequence_recovery": avg_percent_identity, + f"{self.metric_prefix}/num_samples": float(len(all_percent_identities)), + } + model.log_dict(metrics_to_log, batch_size=1, sync_dist=False, rank_zero_only=True) + + logger.info(f"Inverse Folding Validation Results ({self.dataset_name}, step {current_step}):") + logger.info(f" Average sequence recovery: {avg_percent_identity:.2f}% (n={len(all_percent_identities)})") diff --git a/src/lobster/callbacks/_protein_ligand_decode.py b/src/lobster/callbacks/_protein_ligand_decode.py new file mode 100644 index 00000000..900c4097 --- /dev/null +++ b/src/lobster/callbacks/_protein_ligand_decode.py @@ -0,0 +1,336 @@ +"""Callback for decoding and saving protein-ligand complexes during training.""" + +import os + +import lightning +import torch +from loguru import logger + +from lobster.model.latent_generator.io import writepdb_ligand_complex +from lobster.model.latent_generator.utils import minimize_ligand_structure +from lobster.model.latent_generator.utils.residue_constants import ( + ELEMENT_VOCAB_EXTENDED, + convert_lobster_aa_tokenization_to_standard_aa, +) + + +class ProteinLigandDecodeCallback(lightning.Callback): + """Callback to decode and save protein-ligand complexes during training. + + This callback saves: + 1. Decoded protein structure + 2. Decoded ligand structure (atom positions, types, bonds) + 3. Combined protein-ligand complex PDB + + Parameters + ---------- + structure_path : str + Base path for saving structures + save_every_n : int + Save structures every N batches + save_separate : bool + Whether to save protein and ligand separately in addition to complex + minimize_ligand : bool + Whether to apply geometry correction to ligand structures + minimize_mode : str + Minimization mode: "bonds_only" or "bonds_and_angles" (recommended) + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps + """ + + def __init__( + self, + structure_path: str = None, + save_every_n: int = 1000, + save_separate: bool = True, + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + ): + self.structure_path = structure_path + self.save_every_n = save_every_n + self.save_separate = save_separate + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Create output directories + self.complex_dir = f"{self.structure_path}/complexes" + os.makedirs(self.complex_dir, exist_ok=True) + if save_separate: + os.makedirs(f"{self.structure_path}/proteins", exist_ok=True) + os.makedirs(f"{self.structure_path}/ligands", exist_ok=True) + + if self.minimize_ligand: + logger.info(f"Ligand minimization enabled: mode={minimize_mode}, force_field={force_field}") + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """Save decoded structures at specified intervals.""" + # Only run on rank 0 in distributed setting + if trainer.global_rank != 0: + return + + current_step = trainer.global_step + + if batch_idx % self.save_every_n != 0: + return + + has_protein = outputs.get("has_protein", True) # Default True for backward compat + has_ligand = outputs.get("has_ligand", False) + + # Need at least decoded protein or ligand structure + if "decoded_x" not in outputs and "decoded_ligand_x" not in outputs: + return + + # Get decoded structures from vit_decoder (handles both protein and ligand) + x_recon_xyz = None + ligand_coords_from_decoder = None + seq = None + + if "decoded_x" in outputs: + x_recon = outputs["decoded_x"] + for decoder_name in x_recon: + if "vit_decoder" == decoder_name: + vit_output = x_recon[decoder_name] + # Handle both old format (tensor) and new format (dict with protein_coords/ligand_coords) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + ligand_coords_from_decoder = vit_output.get("ligand_coords") + else: + x_recon_xyz = vit_output + + # Get protein sequence (if protein present) + # Check if model uses 33-token vocab (from AminoAcidTokenizerTransform) + uses_33_token_vocab = False + if has_protein and x_recon_xyz is not None: + if outputs["unmasked_x"]["sequence_logits"].shape[-1] == 33: + uses_33_token_vocab = True + seq = convert_lobster_aa_tokenization_to_standard_aa( + outputs["unmasked_x"]["sequence_logits"], device=self.device + ) + else: + seq = outputs["unmasked_x"]["sequence_logits"].argmax(dim=-1) + seq[seq > 21] = 20 + + # Get timesteps for filename (use ligand timesteps for ligand-only) + if has_protein: + t_seq = outputs["train_timesteps_seq"][0].cpu().numpy() + t_struc = outputs["train_timesteps_struc"][0].cpu().numpy() + else: + # Ligand-only batch + t_seq = 0.0 + t_struc = outputs.get("train_timesteps_ligand", outputs["train_timesteps_struc"])[0].cpu().numpy() + + # === SAVE LIGAND STRUCTURES === + ligand_coords = None + ligand_atom_types = None + + if has_ligand and "decoded_ligand_x" in outputs: + decoded_ligand = outputs["decoded_ligand_x"] + ligand_mask = outputs.get("ligand_mask") + + # Get ligand coordinates from unified vit_decoder + if ligand_coords_from_decoder is not None: + ligand_coords = ligand_coords_from_decoder[0] # First sample (decoded) + else: + logger.warning( + "No decoded ligand coordinates available - vit_decoder may not be returning ligand_coords" + ) + + # Get ligand atom types + if "atom_types" in decoded_ligand: + # Convert token indices to element names + atom_indices = decoded_ligand["atom_types"][0] + if ligand_mask is not None: + valid_mask = ligand_mask[0].bool() + atom_indices = atom_indices[valid_mask] + ligand_atom_types = self._indices_to_atom_names(atom_indices) + elif "ligand_atom_names" in batch: + ligand_atom_types = batch["ligand_atom_names"][0] + + # Save complex PDB (protein + ligand) + if has_protein and has_ligand and x_recon_xyz is not None and ligand_coords is not None: + if ligand_atom_types is not None: + complex_filename = ( + f"{self.complex_dir}/complex_{batch_idx}_{current_step}_tseq_{t_seq:.2f}_tstruc_{t_struc:.2f}.pdb" + ) + + # Get bond matrix if available (key can be "bond_matrix" or "ligand_bond_matrix") + bond_matrix = batch.get("bond_matrix", batch.get("ligand_bond_matrix", None)) + if bond_matrix is not None: + bond_matrix = bond_matrix[0] # First sample + # Apply mask if present + ligand_mask = outputs.get("ligand_mask") + if ligand_mask is not None: + valid_mask = ligand_mask[0].bool() + bond_matrix = bond_matrix[valid_mask][:, valid_mask] + + # Apply minimization if enabled + ligand_coords_to_save = ligand_coords.cpu() if torch.is_tensor(ligand_coords) else ligand_coords + if self.minimize_ligand: + try: + ligand_coords_to_save = minimize_ligand_structure( + ligand_coords_to_save, + ligand_atom_types, + bond_matrix=bond_matrix, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed: {e}") + + try: + writepdb_ligand_complex( + filename=complex_filename, + protein_atoms=x_recon_xyz[0], + protein_seq=seq[0], + protein_chain="A", + ligand_atoms=ligand_coords_to_save, + ligand_atom_names=ligand_atom_types, + ligand_chain="L", + ligand_resname="LIG", + ligand_bond_matrix=bond_matrix, + ) + logger.info(f"Saved complex: {complex_filename}") + except Exception as e: + logger.warning(f"Failed to save complex: {e}") + + # Save ground truth complex + if "coords_res" in batch and "ligand_coords" in batch: + gt_filename = f"{self.complex_dir}/complex_{batch_idx}_{current_step}_gt.pdb" + gt_ligand_names = batch.get("ligand_atom_names", [["C"] * batch["ligand_coords"].shape[1]])[0] + + # Get GT bond matrix if available (key can be "bond_matrix" or "ligand_bond_matrix") + gt_bond_matrix = batch.get("bond_matrix", batch.get("ligand_bond_matrix", None)) + if gt_bond_matrix is not None: + gt_bond_matrix = gt_bond_matrix[0] + + try: + gt_seq = batch["sequence"][0] + # Handle both 33-token (from AminoAcidTokenizerTransform) and 21-token formats + if uses_33_token_vocab: + gt_seq = convert_lobster_aa_tokenization_to_standard_aa( + gt_seq.unsqueeze(0), device=self.device + )[0] + else: + gt_seq = gt_seq.clone() + gt_seq[gt_seq > 21] = 20 + writepdb_ligand_complex( + filename=gt_filename, + protein_atoms=batch["coords_res"][0], + protein_seq=gt_seq, + protein_chain="A", + ligand_atoms=batch["ligand_coords"][0].cpu(), + ligand_atom_names=gt_ligand_names, + ligand_chain="L", + ligand_resname="LIG", + ligand_bond_matrix=gt_bond_matrix, + ) + logger.info(f"Saved ground truth complex: {gt_filename}") + except Exception as e: + logger.warning(f"Failed to save GT complex: {e}") + + # Save ligand-only if no protein (GEOM dataset) + elif not has_protein and has_ligand and ligand_coords is not None and ligand_atom_types is not None: + ligand_filename = f"{self.structure_path}/ligands/ligand_{batch_idx}_{current_step}_tlig_{t_struc:.2f}.pdb" + + # Get bond matrix if available (key can be "bond_matrix" or "ligand_bond_matrix") + bond_matrix = batch.get("bond_matrix", batch.get("ligand_bond_matrix", None)) + if bond_matrix is not None: + bond_matrix = bond_matrix[0] # First sample + ligand_mask = outputs.get("ligand_mask") + if ligand_mask is not None: + valid_mask = ligand_mask[0].bool() + bond_matrix = bond_matrix[valid_mask][:, valid_mask] + + # Apply minimization if enabled + ligand_coords_to_save = ligand_coords.cpu() if torch.is_tensor(ligand_coords) else ligand_coords + if self.minimize_ligand: + try: + ligand_coords_to_save = minimize_ligand_structure( + ligand_coords_to_save, + ligand_atom_types, + bond_matrix=bond_matrix, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed: {e}") + + try: + writepdb_ligand_complex( + filename=ligand_filename, + protein_atoms=None, # No protein + protein_seq=None, + ligand_atoms=ligand_coords_to_save, + ligand_atom_names=ligand_atom_types, + ligand_chain="L", + ligand_resname="LIG", + ligand_bond_matrix=bond_matrix, + ) + logger.info(f"Saved ligand-only: {ligand_filename}") + except Exception as e: + logger.warning(f"Failed to save ligand: {e}") + + # Save ground truth ligand + if "ligand_coords" in batch: + gt_filename = f"{self.structure_path}/ligands/ligand_{batch_idx}_{current_step}_gt.pdb" + gt_ligand_names = batch.get("ligand_atom_names", [["C"] * batch["ligand_coords"].shape[1]])[0] + + # Get GT bond matrix (key can be "bond_matrix" or "ligand_bond_matrix") + gt_bond_matrix = batch.get("bond_matrix", batch.get("ligand_bond_matrix", None)) + if gt_bond_matrix is not None: + gt_bond_matrix = gt_bond_matrix[0] + + try: + writepdb_ligand_complex( + filename=gt_filename, + protein_atoms=None, + protein_seq=None, + ligand_atoms=batch["ligand_coords"][0].cpu(), + ligand_atom_names=gt_ligand_names, + ligand_chain="L", + ligand_resname="LIG", + ligand_bond_matrix=gt_bond_matrix, + ) + logger.info(f"Saved ground truth ligand: {gt_filename}") + except Exception as e: + logger.warning(f"Failed to save GT ligand: {e}") + + # Save protein-only if requested or no ligand + if has_protein and x_recon_xyz is not None and (self.save_separate or not has_ligand): + from lobster.model.latent_generator.io import writepdb + + protein_filename = ( + f"{self.structure_path}/proteins/protein_{batch_idx}_{current_step}_" + f"tseq_{t_seq:.2f}_tstruc_{t_struc:.2f}.pdb" + ) + try: + writepdb(protein_filename, x_recon_xyz[0], seq[0]) + logger.info(f"Saved protein: {protein_filename}") + except Exception as e: + logger.warning(f"Failed to save protein: {e}") + + def _indices_to_atom_names(self, indices: torch.Tensor) -> list[str]: + """Convert element indices to atom names.""" + atom_names = [] + for i, idx in enumerate(indices.cpu().tolist()): + # ELEMENT_VOCAB_EXTENDED is a list, index directly + if 0 <= idx < len(ELEMENT_VOCAB_EXTENDED): + element = ELEMENT_VOCAB_EXTENDED[idx] + # Skip special tokens + if element in ("PAD", "MASK", "UNK"): + element = "C" # Default to carbon + else: + element = "C" # Default to carbon + # Create unique atom name: element + number + atom_names.append(f"{element}{i + 1}") + + return atom_names diff --git a/src/lobster/callbacks/_protein_ligand_forward_folding.py b/src/lobster/callbacks/_protein_ligand_forward_folding.py new file mode 100644 index 00000000..01cdace4 --- /dev/null +++ b/src/lobster/callbacks/_protein_ligand_forward_folding.py @@ -0,0 +1,224 @@ +"""Callback for evaluating ligand-conditioned forward folding on protein-ligand complexes. + +This callback compares forward folding (sequence → structure) with and without +ligand context to determine if ligand information improves structure prediction, +particularly for binding pocket residues. +""" + +import os + +import lightning +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_forward_folding import ProteinLigandForwardFoldingEvaluator + + +class ProteinLigandForwardFoldingCallback(lightning.Callback): + """Callback for evaluating ligand-conditioned forward folding on protein-ligand complexes. + + Compares: + - Forward folding with protein sequence only + - Forward folding with protein sequence + ligand context + + Hypothesis: Ligand context should improve pocket structure prediction (lower RMSD). + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory with *_protein.pt and *_ligand.pt pairs + structure_path : str, optional + Output directory for results and structures + save_every_n : int + Run evaluation every N training steps + num_samples : int + Number of structures to evaluate (subset for speed during training) + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + nsteps : int + Number of diffusion steps for generation + metric_prefix : str + Prefix for logged metrics + minimize_ligand : bool + Whether to apply geometry correction to decoded ligand structures + minimize_mode : str + Minimization mode: "bonds_only", "bonds_and_angles", "local", or "full" + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps + + Example + ------- + Add to callbacks config: + + ```yaml + protein_ligand_forward_folding: + _target_: lobster.callbacks.ProteinLigandForwardFoldingCallback + data_dir: /data2/lisanzas/pdb_bind_12_15_25/test/ + structure_path: ${paths.output_dir}/protein_ligand_eval/ + save_every_n: 5000 + num_samples: 100 + pocket_distance_threshold: 5.0 + ``` + """ + + def __init__( + self, + data_dir: str = "/data2/lisanzas/pdb_bind_12_15_25/test/", + structure_path: str | None = None, + save_every_n: int = 5000, + num_samples: int = 100, + pocket_distance_threshold: float = 5.0, + nsteps: int = 100, + metric_prefix: str = "protein_ligand_forward_folding", + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + ): + self.data_dir = data_dir + self.structure_path = structure_path + self.save_every_n = save_every_n + self.num_samples = num_samples + self.pocket_distance_threshold = pocket_distance_threshold + self.nsteps = nsteps + self.metric_prefix = metric_prefix + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + + self._evaluator = None + self._samples = None + self._evaluated_steps = set() # Track steps we've already evaluated + + # Create output directory + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + def setup(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule, stage: str): + """Load test samples once at setup. + + Only loads on rank 0 in distributed settings. + """ + if trainer.global_rank != 0: + return + + if self._samples is not None: + return # Already loaded + + # Get device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Get max_length from model if available + max_length = 512 # default + if hasattr(pl_module, "encoder") and hasattr(pl_module.encoder, "neobert"): + if hasattr(pl_module.encoder.neobert, "config") and hasattr(pl_module.encoder.neobert.config, "max_length"): + max_length = pl_module.encoder.neobert.config.max_length + logger.info(f"Using model's max_length: {max_length}") + + # Create evaluator + self._evaluator = ProteinLigandForwardFoldingEvaluator( + data_dir=self.data_dir, + pocket_distance_threshold=self.pocket_distance_threshold, + num_samples=self.num_samples, + nsteps=self.nsteps, + device=device, + max_length=max_length, + minimize_ligand=self.minimize_ligand, + minimize_mode=self.minimize_mode, + force_field=self.force_field, + minimize_steps=self.minimize_steps, + ) + + # Load samples + try: + self._samples = self._evaluator.load_test_set() + logger.info(f"Loaded {len(self._samples)} protein-ligand test samples for forward folding evaluation") + except Exception as e: + logger.error(f"Failed to load protein-ligand test samples: {e}") + self._samples = [] + + def on_train_batch_end( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + outputs, + batch, + batch_idx: int, + ): + """Run evaluation at specified intervals.""" + # Only run on rank 0 + if trainer.global_rank != 0: + return + + current_step = trainer.global_step + + # Check if we should run evaluation + if current_step % self.save_every_n != 0: + return + + # Skip if we've already evaluated this step (avoid duplicates from multiple callback invocations) + if current_step in self._evaluated_steps: + return + + if not self._samples: + logger.warning("No protein-ligand samples available for evaluation") + return + + # Mark this step as evaluated + self._evaluated_steps.add(current_step) + logger.info(f"Running protein-ligand forward folding evaluation at step {current_step}") + + # Run evaluation (no try-catch to get full traceback for debugging) + results = self._evaluator.evaluate(pl_module, self._samples) + + # Log metrics to trainer + summary = results["summary"] + metrics_to_log = {} + + for key, value in summary.items(): + if isinstance(value, (int, float)) and not isinstance(value, bool): + metrics_to_log[f"{self.metric_prefix}/{key}"] = value + + pl_module.log_dict(metrics_to_log, batch_size=1) + + # Log key insights + tm_delta = summary.get("mean_tm_score_delta", 0) + rmsd_pocket_delta = summary.get("mean_rmsd_pocket_delta", 0) + rmsd_overall_delta = summary.get("mean_rmsd_overall_delta", 0) + + logger.info( + f"Protein-Ligand Forward Folding (step {current_step}):\n" + f" TM-score: {summary['mean_tm_score_no_ligand']:.3f} → " + f"{summary['mean_tm_score_with_ligand']:.3f} " + f"(Δ={tm_delta:+.3f})\n" + f" Pocket RMSD: {summary['mean_rmsd_pocket_no_ligand']:.2f} Å → " + f"{summary['mean_rmsd_pocket_with_ligand']:.2f} Å " + f"(Δ={rmsd_pocket_delta:+.2f} Å)\n" + f" Overall RMSD: {summary['mean_rmsd_overall_no_ligand']:.2f} Å → " + f"{summary['mean_rmsd_overall_with_ligand']:.2f} Å " + f"(Δ={rmsd_overall_delta:+.2f} Å)" + ) + + # Save CSV results + if self.structure_path: + output_file = os.path.join( + self.structure_path, + f"protein_ligand_forward_folding_step{current_step}.csv", + ) + results["results_df"].to_csv(output_file, index=False) + logger.info(f"Saved results to {output_file}") + + # Clear GPU cache + torch.cuda.empty_cache() + + def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule): + """Optionally run evaluation at validation epoch end. + + This method can be used for periodic evaluation during validation + if preferred over batch-based evaluation. + """ + # Currently using batch-end evaluation; can enable this if needed + pass diff --git a/src/lobster/callbacks/_protein_ligand_inverse_folding.py b/src/lobster/callbacks/_protein_ligand_inverse_folding.py new file mode 100644 index 00000000..b01f9785 --- /dev/null +++ b/src/lobster/callbacks/_protein_ligand_inverse_folding.py @@ -0,0 +1,229 @@ +"""Callback for evaluating ligand-conditioned inverse folding on protein-ligand complexes. + +This callback compares inverse folding with and without ligand context to +determine if ligand information improves sequence recovery, particularly +for binding pocket residues. +""" + +import os + +import lightning +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_inverse_folding import ProteinLigandInverseFoldingEvaluator + + +class ProteinLigandInverseFoldingCallback(lightning.Callback): + """Callback for evaluating ligand-conditioned inverse folding on protein-ligand complexes. + + Compares: + - Inverse folding with protein structure only + - Inverse folding with protein structure + ligand context + + Hypothesis: Ligand context should improve pocket sequence recovery. + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory with *_protein.pt and *_ligand.pt pairs + structure_path : str, optional + Output directory for results and structures + save_every_n : int + Run evaluation every N training steps + num_samples : int + Number of structures to evaluate (subset for speed during training) + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + nsteps : int + Number of diffusion steps for generation + metric_prefix : str + Prefix for logged metrics + minimize_ligand : bool + Whether to apply geometry correction to decoded ligand structures + minimize_mode : str + Minimization mode: "bonds_only", "bonds_and_angles", "local", or "full" + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps + + Example + ------- + Add to callbacks config: + + ```yaml + protein_ligand_inverse_folding: + _target_: lobster.callbacks.ProteinLigandInverseFoldingCallback + data_dir: /data2/lisanzas/pdb_bind_12_15_25/test/ + structure_path: ${paths.output_dir}/protein_ligand_eval/ + save_every_n: 5000 + num_samples: 100 + pocket_distance_threshold: 5.0 + ``` + """ + + def __init__( + self, + data_dir: str = "/data2/lisanzas/pdb_bind_12_15_25/test/", + structure_path: str | None = None, + save_every_n: int = 5000, + num_samples: int = 100, + pocket_distance_threshold: float = 5.0, + nsteps: int = 100, + metric_prefix: str = "protein_ligand_inverse_folding", + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + ): + self.data_dir = data_dir + self.structure_path = structure_path + self.save_every_n = save_every_n + self.num_samples = num_samples + self.pocket_distance_threshold = pocket_distance_threshold + self.nsteps = nsteps + self.metric_prefix = metric_prefix + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + + self._evaluator = None + self._samples = None + self._evaluated_steps = set() # Track steps we've already evaluated + + # Create output directory + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + def setup(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule, stage: str): + """Load test samples once at setup. + + Only loads on rank 0 in distributed settings. + """ + if trainer.global_rank != 0: + return + + if self._samples is not None: + return # Already loaded + + # Get device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Get max_length from model if available + max_length = 512 # default + if hasattr(pl_module, "encoder") and hasattr(pl_module.encoder, "neobert"): + if hasattr(pl_module.encoder.neobert, "config") and hasattr(pl_module.encoder.neobert.config, "max_length"): + max_length = pl_module.encoder.neobert.config.max_length + logger.info(f"Using model's max_length: {max_length}") + + # Create evaluator + self._evaluator = ProteinLigandInverseFoldingEvaluator( + data_dir=self.data_dir, + pocket_distance_threshold=self.pocket_distance_threshold, + num_samples=self.num_samples, + nsteps=self.nsteps, + device=device, + max_length=max_length, + minimize_ligand=self.minimize_ligand, + minimize_mode=self.minimize_mode, + force_field=self.force_field, + minimize_steps=self.minimize_steps, + ) + + # Load samples + try: + self._samples = self._evaluator.load_test_set() + logger.info(f"Loaded {len(self._samples)} protein-ligand test samples for inverse folding evaluation") + except Exception as e: + logger.error(f"Failed to load protein-ligand test samples: {e}") + self._samples = [] + + def on_train_batch_end( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + outputs, + batch, + batch_idx: int, + ): + """Run evaluation at specified intervals.""" + # Only run on rank 0 + if trainer.global_rank != 0: + return + + current_step = trainer.global_step + + # Check if we should run evaluation + if current_step % self.save_every_n != 0: + return + + # Skip if we've already evaluated this step (avoid duplicates from multiple callback invocations) + if current_step in self._evaluated_steps: + return + + if not self._samples: + logger.warning("No protein-ligand samples available for evaluation") + return + + # Mark this step as evaluated + self._evaluated_steps.add(current_step) + logger.info(f"Running protein-ligand inverse folding evaluation at step {current_step}") + + # Run evaluation + try: + results = self._evaluator.evaluate(pl_module, self._samples) + + # Log metrics to trainer + summary = results["summary"] + metrics_to_log = {} + + for key, value in summary.items(): + if isinstance(value, (int, float)) and not isinstance(value, bool): + metrics_to_log[f"{self.metric_prefix}/{key}"] = value + + pl_module.log_dict(metrics_to_log, batch_size=1) + + # Log key insight + pocket_delta = summary.get("mean_aar_pocket_delta", 0) + nonpocket_delta = summary.get("mean_aar_nonpocket_delta", 0) + + logger.info( + f"Protein-Ligand Inverse Folding (step {current_step}):\n" + f" Pocket AAR: {summary['mean_aar_pocket_no_ligand']:.2%} → " + f"{summary['mean_aar_pocket_with_ligand']:.2%} " + f"(Δ={pocket_delta:+.2%})\n" + f" Non-pocket AAR: {summary['mean_aar_nonpocket_no_ligand']:.2%} → " + f"{summary['mean_aar_nonpocket_with_ligand']:.2%} " + f"(Δ={nonpocket_delta:+.2%})\n" + f" Overall AAR: {summary['mean_aar_overall_no_ligand']:.2%} → " + f"{summary['mean_aar_overall_with_ligand']:.2%}" + ) + + # Save CSV results + if self.structure_path: + output_file = os.path.join( + self.structure_path, + f"protein_ligand_inverse_folding_step{current_step}.csv", + ) + results["results_df"].to_csv(output_file, index=False) + logger.info(f"Saved results to {output_file}") + + except Exception as e: + logger.error(f"Protein-ligand inverse folding evaluation failed: {e}") + import traceback + + traceback.print_exc() + + # Clear GPU cache + torch.cuda.empty_cache() + + def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule): + """Optionally run evaluation at validation epoch end. + + This method can be used for periodic evaluation during validation + if preferred over batch-based evaluation. + """ + # Currently using batch-end evaluation; can enable this if needed + pass diff --git a/src/lobster/callbacks/_s3_checkpoint_callback.py b/src/lobster/callbacks/_s3_checkpoint_callback.py new file mode 100644 index 00000000..22a96de2 --- /dev/null +++ b/src/lobster/callbacks/_s3_checkpoint_callback.py @@ -0,0 +1,168 @@ +"""S3 Checkpoint Backup Callback for PyTorch Lightning. + +This callback automatically backs up checkpoints to S3 after they are saved, +providing disaster recovery capability for training runs. + +Usage: + Add to your training config: + ```yaml + callbacks: + s3_backup: + _target_: lobster.callbacks._s3_checkpoint_callback.S3CheckpointBackupCallback + s3_bucket: "prescient-lobster" + s3_prefix: "checkpoints" + project_name: "latent_generator" + ``` +""" + +import logging +import os +from pathlib import Path +from typing import Any + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback + +py_logger = logging.getLogger(__name__) + + +class S3CheckpointBackupCallback(Callback): + """Automatically backup checkpoints to S3 after saving. + + This callback uploads checkpoints to S3 whenever PyTorch Lightning saves + a checkpoint, providing a backup in case of local storage failures. + + Args: + s3_bucket: S3 bucket name for backup storage. + s3_prefix: Prefix path within the bucket (e.g., "checkpoints/latent_generator"). + project_name: Project name for organizing checkpoints. + upload_every_n_epochs: Upload periodic checkpoints every N epochs. + upload_best_only: If True, only upload the best checkpoint. + upload_last: If True, also upload the last checkpoint. + dry_run: If True, log uploads without actually uploading. + """ + + def __init__( + self, + s3_bucket: str = "prescient-pcluster-data", + s3_prefix: str = "gen_ume/checkpoints", + project_name: str | None = None, + upload_every_n_epochs: int = 10, + upload_best_only: bool = False, + upload_last: bool = True, + dry_run: bool = False, + ): + super().__init__() + self.s3_bucket = s3_bucket + self.s3_prefix = s3_prefix + self.project_name = project_name + self.upload_every_n_epochs = upload_every_n_epochs + self.upload_best_only = upload_best_only + self.upload_last = upload_last + self.dry_run = dry_run + self._s3_client = None + + @property + def s3_client(self): + """Lazy initialization of S3 client.""" + if self._s3_client is None: + try: + import boto3 + + self._s3_client = boto3.client("s3") + except ImportError: + py_logger.error("boto3 not installed. Run: pip install boto3") + raise + return self._s3_client + + def _get_s3_key(self, local_path: str, checkpoint_type: str = "periodic") -> str: + """Generate S3 key for a checkpoint file. + + Args: + local_path: Local path to the checkpoint file. + checkpoint_type: Type of checkpoint ("best", "last", or "periodic"). + + Returns: + S3 key string. + """ + filename = Path(local_path).name + project = self.project_name or "unknown" + return f"{self.s3_prefix}/{project}/{checkpoint_type}/{filename}" + + def _upload_to_s3(self, local_path: str, s3_key: str) -> bool: + """Upload a file to S3. + + Args: + local_path: Local path to the file. + s3_key: S3 key (path within bucket). + + Returns: + True if upload succeeded, False otherwise. + """ + if self.dry_run: + py_logger.info(f"[DRY RUN] Would upload {local_path} to s3://{self.s3_bucket}/{s3_key}") + return True + + try: + self.s3_client.upload_file(local_path, self.s3_bucket, s3_key) + py_logger.info(f"✅ Uploaded checkpoint to s3://{self.s3_bucket}/{s3_key}") + return True + except Exception as e: + py_logger.error(f"❌ Failed to upload {local_path} to S3: {e}") + return False + + def on_save_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: dict[str, Any], + ) -> None: + """Called when a checkpoint is saved. + + Uploads the checkpoint to S3 based on configuration. + """ + ckpt_callback = trainer.checkpoint_callback + if ckpt_callback is None: + return + + current_epoch = trainer.current_epoch + + # Upload best checkpoint + if ckpt_callback.best_model_path and os.path.exists(ckpt_callback.best_model_path): + s3_key = self._get_s3_key(ckpt_callback.best_model_path, "best") + self._upload_to_s3(ckpt_callback.best_model_path, s3_key) + + # Skip if upload_best_only is set + if self.upload_best_only: + return + + # Upload last checkpoint + if self.upload_last and ckpt_callback.last_model_path: + if os.path.exists(ckpt_callback.last_model_path): + s3_key = self._get_s3_key(ckpt_callback.last_model_path, "last") + self._upload_to_s3(ckpt_callback.last_model_path, s3_key) + + # Upload periodic checkpoints + if self.upload_every_n_epochs > 0 and current_epoch % self.upload_every_n_epochs == 0: + # Upload any checkpoint saved at this epoch + if ckpt_callback.last_model_path and os.path.exists(ckpt_callback.last_model_path): + s3_key = self._get_s3_key(ckpt_callback.last_model_path, f"epoch_{current_epoch}") + self._upload_to_s3(ckpt_callback.last_model_path, s3_key) + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + """Called when training ends. Upload final checkpoints.""" + ckpt_callback = trainer.checkpoint_callback + if ckpt_callback is None: + return + + # Final upload of best checkpoint + if ckpt_callback.best_model_path and os.path.exists(ckpt_callback.best_model_path): + s3_key = self._get_s3_key(ckpt_callback.best_model_path, "final_best") + self._upload_to_s3(ckpt_callback.best_model_path, s3_key) + + # Final upload of last checkpoint + if ckpt_callback.last_model_path and os.path.exists(ckpt_callback.last_model_path): + s3_key = self._get_s3_key(ckpt_callback.last_model_path, "final_last") + self._upload_to_s3(ckpt_callback.last_model_path, s3_key) + + py_logger.info(f"📦 Training complete. Checkpoints backed up to s3://{self.s3_bucket}/{self.s3_prefix}/") diff --git a/src/lobster/callbacks/_structure_decode.py b/src/lobster/callbacks/_structure_decode.py index 05ec0365..f1e61736 100644 --- a/src/lobster/callbacks/_structure_decode.py +++ b/src/lobster/callbacks/_structure_decode.py @@ -15,52 +15,66 @@ def __init__(self, structure_path: str = None, save_every_n: int = 1000): os.makedirs(f"{self.structure_path}/decode", exist_ok=True) def on_train_batch_end(self, trainer, laten_mlm, outputs, batch, batch_idx): + if batch_idx % self.save_every_n != 0: + return + current_step = trainer.global_step seq = None x_recon_xyz = None - if batch_idx % self.save_every_n == 0: - x_recon = outputs["decoded_x"] - if "train_t" in outputs: - t = outputs["train_t"] - t = t[0].cpu().numpy() - else: - t_seq = outputs["train_timesteps_seq"] - t_struc = outputs["train_timesteps_struc"] - t_seq = t_seq[0].cpu().numpy() - t_struc = t_struc[0].cpu().numpy() - t = None - conditioning = outputs["conditioning"] - - x_recon_xyz = None + x_recon = outputs["decoded_x"] + if "train_t" in outputs: + t = outputs["train_t"] + t = t[0].cpu().numpy() + else: + t_seq = outputs["train_timesteps_seq"] + t_struc = outputs["train_timesteps_struc"] + t_seq = t_seq[0].cpu().numpy() + t_struc = t_struc[0].cpu().numpy() + t = None + conditioning = outputs["conditioning"] - for decoder_name in x_recon: - if "vit_decoder" == decoder_name: - x_recon_xyz = x_recon[decoder_name] + x_recon_xyz = None - # save the pdb file - if x_recon_xyz is not None: - if outputs["unmasked_x"]["sequence_logits"].shape[-1] == 33: - seq = convert_lobster_aa_tokenization_to_standard_aa( - outputs["unmasked_x"]["sequence_logits"], device=self.device - ) - else: - seq = outputs["unmasked_x"]["sequence_logits"].argmax(dim=-1) - seq[seq > 21] = 20 - if t is not None: - filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_t{str(t)}_cond{conditioning}_decode.pdb" + for decoder_name in x_recon: + if "vit_decoder" == decoder_name: + vit_output = x_recon[decoder_name] + # Handle both old format (tensor) and new format (dict with protein_coords/ligand_coords) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") else: - filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_tseq_{str(t_seq)}_tstruc_{str(t_struc)}_cond{conditioning}_decode.pdb" - writepdb(filename, x_recon_xyz[0], seq[0]) - logger.info(f"Saved {filename}") + x_recon_xyz = vit_output - # save batch - if t is not None: - filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_t{str(t)}_cond{conditioning}_gt.pdb" - else: - filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_tseq_{str(t_seq)}_tstruc_{str(t_struc)}_cond{conditioning}_gt.pdb" - seq = batch["sequence"][0] - # if naything >21, set to 20 + # save the pdb file + if x_recon_xyz is not None: + if outputs["unmasked_x"]["sequence_logits"].shape[-1] == 33: + seq = convert_lobster_aa_tokenization_to_standard_aa( + outputs["unmasked_x"]["sequence_logits"], device=self.device + ) + else: + seq = outputs["unmasked_x"]["sequence_logits"].argmax(dim=-1) seq[seq > 21] = 20 - writepdb(filename, batch["coords_res"][0], seq) - logger.info(f"Saved {filename}") + # Skip if seq has incorrect shape (needs to be at least 2D: batch x seq_len) + if seq.dim() < 2: + logger.warning(f"Skipping structure decode save: seq has unexpected shape {seq.shape}") + return + + if t is not None: + filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_t{str(t)}_cond{conditioning}_decode.pdb" + else: + filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_tseq_{str(t_seq)}_tstruc_{str(t_struc)}_cond{conditioning}_decode.pdb" + writepdb(filename, x_recon_xyz[0], seq[0]) + logger.info(f"Saved {filename}") + + # save batch + if t is not None: + filename = ( + f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_t{str(t)}_cond{conditioning}_gt.pdb" + ) + else: + filename = f"{self.structure_path}decode/struc_{batch_idx}_{current_step}_tseq_{str(t_seq)}_tstruc_{str(t_struc)}_cond{conditioning}_gt.pdb" + seq = batch["sequence"][0] + # if anything >21, set to 20 + seq[seq > 21] = 20 + writepdb(filename, batch["coords_res"][0], seq) + logger.info(f"Saved {filename}") diff --git a/src/lobster/callbacks/_unconditional_generation.py b/src/lobster/callbacks/_unconditional_generation.py index ca1062ad..a4cc0df0 100644 --- a/src/lobster/callbacks/_unconditional_generation.py +++ b/src/lobster/callbacks/_unconditional_generation.py @@ -50,7 +50,9 @@ def on_train_batch_end(self, trainer, gen_ume, outputs, batch, batch_idx): if batch_idx % self.save_every_n == 0 and self.plm_fold is not None: # Perform unconditional generation - self._perform_unconditional_generation(trainer, gen_ume, device, batch_idx, current_step) + with torch.no_grad(): + self._perform_unconditional_generation(trainer, gen_ume, device, batch_idx, current_step) + torch.cuda.empty_cache() def _perform_unconditional_generation(self, trainer, gen_ume, device, batch_idx, current_step): """Perform unconditional generation and folding.""" @@ -60,7 +62,12 @@ def _perform_unconditional_generation(self, trainer, gen_ume, device, batch_idx, for decoder_name in decoded_x: if "vit_decoder" == decoder_name: - x_recon_xyz = decoded_x[decoder_name] + vit_output = decoded_x[decoder_name] + # Handle both tensor output (protein-only) and dict output (protein-ligand) + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + else: + x_recon_xyz = vit_output if generate_sample["sequence_logits"].shape[-1] == 33: seq = convert_lobster_aa_tokenization_to_standard_aa(generate_sample["sequence_logits"], device=device) else: diff --git a/src/lobster/cmdline/analyze_external_predictions.py b/src/lobster/cmdline/analyze_external_predictions.py new file mode 100755 index 00000000..ce797b9f --- /dev/null +++ b/src/lobster/cmdline/analyze_external_predictions.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python3 +""" +Analyze external predicted structures against ground truth. + +Compares predicted PDB structures from an external directory to ground truth +structures, using the same metrics (TM-score, RMSD) as the lobster generation +pipeline. +""" + +import torch +import pandas as pd +from pathlib import Path +from loguru import logger +import argparse +from tmtools import tm_align +import numpy as np + +from lobster.metrics._generation_utils import align_and_compute_rmsd +from lobster.model.latent_generator.io import load_pdb + + +def load_pdb_structure(pdb_path: Path) -> tuple[np.ndarray, str]: + """ + Load structure from PDB file using lobster's load_pdb function. + + Args: + pdb_path: Path to PDB file + + Returns: + coords: Coordinates array of shape (L, 3, 3) for N, CA, C atoms + sequence: Amino acid sequence string + """ + try: + # Use lobster's load_pdb function + # Returns a dictionary with keys: 'sequence', 'sequence_str', 'coords_res', 'mask', etc. + structure_data = load_pdb(str(pdb_path), add_batch_dim=False) + + if structure_data is None: + logger.error(f"load_pdb returned None for {pdb_path}") + return None, None + + # Extract sequence string (already in 1-letter code) + seq_str = structure_data["sequence_str"] + + # Extract coordinates (N, 3, 3) - backbone atoms [N, CA, C] + coords = structure_data["coords_res"] + + # Convert coords to numpy if it's a tensor + if isinstance(coords, torch.Tensor): + coords = coords.numpy() + + return coords, seq_str + + except Exception as e: + logger.error(f"Error loading PDB {pdb_path}: {e}") + return None, None + + +def load_pt_structure(pt_path: Path) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Load structure from .pt file. + + Args: + pt_path: Path to .pt file + + Returns: + coords: Coordinates tensor of shape (L, 3, 3) for N, CA, C atoms + sequence: Sequence tensor of shape (L,) + mask: Mask tensor of shape (L,) + """ + try: + data = torch.load(pt_path, map_location="cpu") + + # Extract coordinates (assuming backbone atoms in order N, CA, C) + if "coords_res" in data: + coords = data["coords_res"] # Should be (L, 3, 3) + elif "bb_positions" in data: + # If only CA positions, need to handle differently + logger.warning(f"Only CA positions available in {pt_path}") + coords = data["bb_positions"].unsqueeze(1).repeat(1, 3, 1) + else: + logger.error(f"No coordinate data found in {pt_path}") + return None, None, None + + # Extract sequence - try both 'sequence' and 'seq' keys + if "sequence" in data: + sequence = data["sequence"] + elif "seq" in data: + sequence = data["seq"] + else: + logger.error(f"No sequence data found in {pt_path}. Available keys: {list(data.keys())}") + return None, None, None + + # Extract mask + if "mask" in data: + mask = data["mask"] + else: + mask = torch.ones(sequence.shape[0]) + + return coords, sequence, mask + + except Exception as e: + logger.error(f"Error loading .pt file {pt_path}: {e}") + return None, None, None + + +def extract_structure_id(filename: str) -> str: + """ + Extract structure ID from filename. + + Args: + filename: Filename (e.g., "5S9R_pred.pdb" or "5S9R.pt") + + Returns: + Structure ID (e.g., "5S9R") + """ + # Remove extension + name = Path(filename).stem + + # Remove common suffixes + for suffix in ["_pred", "_predicted", "_folded", "_structure"]: + if name.endswith(suffix): + name = name[: -len(suffix)] + + return name + + +def compare_sequences(seq1: str, seq2: str, struct_id: str) -> tuple[bool, float]: + """ + Compare two sequences and return whether they match. + + Args: + seq1: First sequence + seq2: Second sequence + struct_id: Structure ID for logging + + Returns: + (match, identity): Whether sequences match exactly, and percent identity + """ + if len(seq1) != len(seq2): + logger.warning(f"{struct_id}: Sequence length mismatch - predicted={len(seq1)}, ground_truth={len(seq2)}") + return False, 0.0 + + # Calculate percent identity + matches = sum(1 for a, b in zip(seq1, seq2) if a == b) + percent_identity = (matches / len(seq1)) * 100.0 + + if percent_identity < 100.0: + logger.warning( + f"{struct_id}: Sequence mismatch - {percent_identity:.1f}% identity ({matches}/{len(seq1)} residues match)" + ) + # Show first difference + for i, (a, b) in enumerate(zip(seq1, seq2)): + if a != b: + context_start = max(0, i - 5) + context_end = min(len(seq1), i + 6) + logger.warning( + f" First difference at position {i}: " + f"predicted='{seq1[context_start:context_end]}' " + f"ground_truth='{seq2[context_start:context_end]}'" + ) + break + + return percent_identity == 100.0, percent_identity + + +def calculate_metrics( + pred_coords: torch.Tensor, + gt_coords: torch.Tensor, + gt_sequence: str, + mask: torch.Tensor | None = None, + device: torch.device = torch.device("cpu"), +) -> dict: + """ + Calculate TM-score and RMSD between predicted and ground truth structures. + + Args: + pred_coords: Predicted coordinates, shape (L, 3, 3) + gt_coords: Ground truth coordinates, shape (L, 3, 3) + gt_sequence: Ground truth sequence string + mask: Optional mask, shape (L,) + device: torch device + + Returns: + Dictionary with 'tm_score' and 'rmsd' keys + """ + # Ensure tensors are on the correct device + pred_coords = pred_coords.to(device) + gt_coords = gt_coords.to(device) + + if mask is not None: + mask = mask.to(device) + # Apply mask + pred_coords_masked = pred_coords[mask.bool()] + gt_coords_masked = gt_coords[mask.bool()] + # Filter sequence + gt_sequence_masked = "".join([gt_sequence[i] for i in range(len(gt_sequence)) if mask[i] == 1]) + else: + pred_coords_masked = pred_coords + gt_coords_masked = gt_coords + gt_sequence_masked = gt_sequence + + # Calculate TM-score using tm_align + try: + tm_out = tm_align( + pred_coords_masked[:, 1, :].cpu().numpy(), # CA atoms + gt_coords_masked[:, 1, :].cpu().numpy(), # CA atoms + gt_sequence_masked, + gt_sequence_masked, + ) + tm_score = tm_out.tm_norm_chain1 + except Exception as e: + logger.error(f"Error calculating TM-score: {e}") + tm_score = 0.0 + + # Calculate RMSD using Kabsch alignment + try: + rmsd = align_and_compute_rmsd( + coords1=pred_coords_masked, + coords2=gt_coords_masked, + mask=None, # Already masked + return_aligned=False, + device=device, + ) + except Exception as e: + logger.error(f"Error calculating RMSD: {e}") + rmsd = 0.0 + + return { + "tm_score": float(tm_score), + "rmsd": float(rmsd), + } + + +def analyze_predictions( + pred_dir: str, + gt_dir: str, + output_csv: str = None, + device_str: str = "cpu", + rmsd_threshold: float = 2.0, + skip_sequence_mismatch: bool = False, +): + """ + Analyze predicted structures against ground truth. + + Args: + pred_dir: Directory containing predicted PDB files + gt_dir: Directory containing ground truth .pt files + output_csv: Optional path to save results CSV + device_str: Device to use ('cpu' or 'cuda') + rmsd_threshold: RMSD threshold for reporting pass rate + skip_sequence_mismatch: If True, skip structures with sequence mismatches + """ + pred_path = Path(pred_dir) + gt_path = Path(gt_dir) + + # Set up device + device = torch.device(device_str if torch.cuda.is_available() and device_str == "cuda" else "cpu") + logger.info(f"Using device: {device}") + + # Find all predicted PDB files + pred_files = sorted(list(pred_path.glob("*.pdb"))) + logger.info(f"Found {len(pred_files)} predicted PDB files in {pred_dir}") + + if len(pred_files) == 0: + logger.error("No PDB files found in prediction directory") + return + + # Build mapping from structure IDs to files + pred_map = {} + for pdb_file in pred_files: + struct_id = extract_structure_id(pdb_file.name) + pred_map[struct_id] = pdb_file + + # Find matching ground truth files + results = [] + matched_count = 0 + missing_gt = [] + sequence_mismatches = [] + + for struct_id, pred_file in pred_map.items(): + # Try to find matching .pt file + gt_file = gt_path / f"{struct_id}.pt" + + if not gt_file.exists(): + # Try with different extensions + possible_gt = list(gt_path.glob(f"{struct_id}*.pt")) + if possible_gt: + gt_file = possible_gt[0] + else: + logger.warning(f"No ground truth found for {struct_id}") + missing_gt.append(struct_id) + continue + + # Load predicted structure + logger.info(f"Processing {struct_id}...") + pred_coords, pred_seq = load_pdb_structure(pred_file) + + if pred_coords is None: + logger.error(f"Failed to load predicted structure: {pred_file}") + continue + + # Convert to tensor + pred_coords = torch.from_numpy(pred_coords).float() + + # Load ground truth structure + gt_coords, gt_seq, gt_mask = load_pt_structure(gt_file) + + if gt_coords is None: + logger.error(f"Failed to load ground truth structure: {gt_file}") + continue + + # Check length match + if pred_coords.shape[0] != gt_coords.shape[0]: + logger.warning( + f"Length mismatch for {struct_id}: predicted={pred_coords.shape[0]}, ground_truth={gt_coords.shape[0]}" + ) + # Try to truncate to shorter length + min_len = min(pred_coords.shape[0], gt_coords.shape[0]) + pred_coords = pred_coords[:min_len] + gt_coords = gt_coords[:min_len] + if gt_mask is not None: + gt_mask = gt_mask[:min_len] + + # Convert sequence tensor to string if needed + if isinstance(gt_seq, torch.Tensor): + # Assuming aatype indices (0-19 standard AAs, 20 = X) + restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", + "X", + ] + gt_seq_str = "".join([restypes[i] if i < len(restypes) else "X" for i in gt_seq]) + else: + gt_seq_str = gt_seq + + # If pred_seq is already 1-letter codes, use as is + # Otherwise try to extract from PDB + if len(pred_seq) == pred_coords.shape[0] and all(c in "ACDEFGHIKLMNPQRSTVWYX" for c in pred_seq): + pred_seq_str = pred_seq + else: + # Fallback: use ground truth sequence length + pred_seq_str = gt_seq_str[: pred_coords.shape[0]] + + # Check sequence match + seq_match, seq_identity = compare_sequences(pred_seq_str, gt_seq_str, struct_id) + + if not seq_match: + sequence_mismatches.append((struct_id, seq_identity)) + if skip_sequence_mismatch: + logger.warning(f" Skipping {struct_id} due to sequence mismatch") + continue + + # Calculate metrics + metrics = calculate_metrics( + pred_coords=pred_coords, + gt_coords=gt_coords, + gt_sequence=gt_seq_str, + mask=gt_mask, + device=device, + ) + + # Store results + results.append( + { + "Structure_ID": struct_id, + "Length": pred_coords.shape[0], + "Seq_Identity": seq_identity, + "TM_Score": metrics["tm_score"], + "RMSD": metrics["rmsd"], + "Pred_File": pred_file.name, + "GT_File": gt_file.name, + } + ) + + matched_count += 1 + + # Log individual result + logger.info(f" {struct_id}: TM-score={metrics['tm_score']:.4f}, RMSD={metrics['rmsd']:.4f} Å") + + # Create DataFrame + if not results: + logger.error("No structures were successfully analyzed") + return + + df = pd.DataFrame(results) + + # Sort by TM-score (descending) + df = df.sort_values("TM_Score", ascending=False) + + # Calculate summary statistics + logger.info("\n" + "=" * 80) + logger.info("SUMMARY STATISTICS") + logger.info("=" * 80) + logger.info(f"Total structures analyzed: {len(df)}") + logger.info(f"Structures with ground truth: {matched_count}/{len(pred_map)}") + + if missing_gt: + logger.info( + f"Missing ground truth for: {', '.join(missing_gt[:10])}" + + (f" ... and {len(missing_gt) - 10} more" if len(missing_gt) > 10 else "") + ) + + if sequence_mismatches: + logger.warning(f"\nSequence mismatches found: {len(sequence_mismatches)}") + # Show worst mismatches + worst_mismatches = sorted(sequence_mismatches, key=lambda x: x[1])[:5] + for struct_id, identity in worst_mismatches: + logger.warning(f" {struct_id}: {identity:.1f}% identity") + + # Report sequence identity statistics + logger.info("\nSequence Identity:") + logger.info(f" Mean: {df['Seq_Identity'].mean():.2f}%") + logger.info(f" Min: {df['Seq_Identity'].min():.2f}%") + logger.info(f" Max: {df['Seq_Identity'].max():.2f}%") + exact_matches = len(df[df["Seq_Identity"] == 100.0]) + logger.info(f" Exact matches: {exact_matches}/{len(df)} ({exact_matches / len(df) * 100:.1f}%)") + + logger.info("\nTM-Score:") + logger.info(f" Mean: {df['TM_Score'].mean():.4f}") + logger.info(f" Std: {df['TM_Score'].std():.4f}") + logger.info(f" Min: {df['TM_Score'].min():.4f}") + logger.info(f" Max: {df['TM_Score'].max():.4f}") + logger.info(f" Median: {df['TM_Score'].median():.4f}") + + logger.info("\nRMSD:") + logger.info(f" Mean: {df['RMSD'].mean():.4f} Å") + logger.info(f" Std: {df['RMSD'].std():.4f} Å") + logger.info(f" Min: {df['RMSD'].min():.4f} Å") + logger.info(f" Max: {df['RMSD'].max():.4f} Å") + logger.info(f" Median: {df['RMSD'].median():.4f} Å") + + # Calculate pass rate + passing = len(df[df["RMSD"] < rmsd_threshold]) + pass_rate = (passing / len(df)) * 100 + logger.info(f"\nStructures with RMSD < {rmsd_threshold} Å: {passing}/{len(df)} ({pass_rate:.1f}%)") + + # Save to CSV + if output_csv: + output_path = Path(output_csv) + df.to_csv(output_path, index=False) + logger.info(f"\n✓ Results saved to: {output_path}") + else: + # Save to default location + output_path = Path("external_predictions_analysis.csv") + df.to_csv(output_path, index=False) + logger.info(f"\n✓ Results saved to: {output_path}") + + # Create aggregate summary table (similar to forward folding summary) + logger.info("\n" + "=" * 80) + logger.info("AGGREGATE SUMMARY") + logger.info("=" * 80) + + summary_table = pd.DataFrame( + [ + { + "Total_Structures": len(df), + "Avg_TM_Score": round(df["TM_Score"].mean(), 4), + "Std_TM_Score": round(df["TM_Score"].std(), 4), + "Min_TM_Score": round(df["TM_Score"].min(), 4), + "Max_TM_Score": round(df["TM_Score"].max(), 4), + "Avg_RMSD": round(df["RMSD"].mean(), 4), + "Std_RMSD": round(df["RMSD"].std(), 4), + "Min_RMSD": round(df["RMSD"].min(), 4), + "Max_RMSD": round(df["RMSD"].max(), 4), + f"Structures_RMSD<{rmsd_threshold}": passing, + f"Pct_RMSD<{rmsd_threshold}": round(pass_rate, 2), + } + ] + ) + + logger.info(f"\n{summary_table.to_string(index=False)}") + + # Save summary table + summary_csv = output_path.parent / f"{output_path.stem}_summary.csv" + summary_table.to_csv(summary_csv, index=False) + logger.info(f"\n✓ Summary saved to: {summary_csv}") + + # Display top and bottom performers + logger.info("\n" + "=" * 80) + logger.info("TOP 10 STRUCTURES (by TM-score)") + logger.info("=" * 80) + logger.info(f"\n{df[['Structure_ID', 'Seq_Identity', 'TM_Score', 'RMSD']].head(10).to_string(index=False)}") + + logger.info("\n" + "=" * 80) + logger.info("BOTTOM 10 STRUCTURES (by TM-score)") + logger.info("=" * 80) + logger.info(f"\n{df[['Structure_ID', 'Seq_Identity', 'TM_Score', 'RMSD']].tail(10).to_string(index=False)}") + + logger.info("\n" + "=" * 80) + logger.info("Analysis complete!") + logger.info("=" * 80) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze external predicted structures against ground truth", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Example usage: + python analyze_external_predictions.py \\ + --pred-dir /path/to/predicted/pdbs \\ + --gt-dir /path/to/ground_truth/pt_files \\ + --output analysis_results.csv \\ + --device cuda + + # For DPLM2 predictions: + python analyze_external_predictions.py \ + --pred-dir /homefs/home/lisanzas/scratch/Develop/dplm/generation-results/dplm2_650m/folding/pdb/ \ + --gt-dir /data2/lisanzas/multi_flow_data/test_set_filtered_pt/ \ + --output dplm2_folding_analysis.csv + """, + ) + + parser.add_argument("--pred-dir", type=str, required=True, help="Directory containing predicted PDB files") + + parser.add_argument("--gt-dir", type=str, required=True, help="Directory containing ground truth .pt files") + + parser.add_argument( + "--output", + type=str, + default="external_predictions_analysis.csv", + help="Output CSV file path (default: external_predictions_analysis.csv)", + ) + + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Device to use for computation (default: cpu)", + ) + + parser.add_argument( + "--rmsd-threshold", type=float, default=2.0, help="RMSD threshold for pass rate calculation (default: 2.0 Å)" + ) + + parser.add_argument( + "--skip-sequence-mismatch", + action="store_true", + help="Skip structures with sequence mismatches (default: analyze anyway with warning)", + ) + + args = parser.parse_args() + + analyze_predictions( + pred_dir=args.pred_dir, + gt_dir=args.gt_dir, + output_csv=args.output, + device_str=args.device, + rmsd_threshold=args.rmsd_threshold, + skip_sequence_mismatch=args.skip_sequence_mismatch, + ) + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/distributed_esmfold_baseline.py b/src/lobster/cmdline/distributed_esmfold_baseline.py new file mode 100644 index 00000000..e48da741 --- /dev/null +++ b/src/lobster/cmdline/distributed_esmfold_baseline.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +WandB Distributed ESMFold Baseline Script + +Uses wandb agents as a distributed job queue to parallelize ESMFold baseline evaluation. +Compatible with the existing aggregation script (aggregate_results.py) for forward_folding mode. + +Usage: + # Initialize the job queue + wandb sweep src/lobster/cmdline/distributed_generation/wandb_config_esmfold_baseline.yaml + + # Submit SLURM array to process jobs + # Update submit_slurm.sh with sweep ID + sbatch src/lobster/cmdline/distributed_generation/submit_slurm_esmfold.sh +""" + +import glob +from pathlib import Path +import pandas as pd +from loguru import logger +import wandb +from omegaconf import OmegaConf + +# Import the main ESMFold baseline function +from lobster.cmdline.esmfold_baseline import main as run_esmfold_baseline + + +def main(): + """ + Main distributed ESMFold baseline function. + Each wandb agent runs this and gets assigned a job_id. + """ + # Initialize wandb run - this gets config from the sweep + with wandb.init() as run: + config = run.config + + job_id = config.job_id + + # Load base configuration + base_config_path = config.get("base_config_path", "src/lobster/hydra_config/experiment/esmfold_baseline.yaml") + + logger.info("Starting distributed ESMFold baseline job") + logger.info(f"Job ID: {job_id}") + logger.info(f"Loading base config from: {base_config_path}") + baseline_config = OmegaConf.load(base_config_path) + + # Set common parameters + output_base = baseline_config.get("output_dir", "./examples/esmfold_baseline") + baseline_config.output_dir = f"{output_base}/job_{job_id}" + + base_seed = baseline_config.get("seed", 12345) + baseline_config.seed = base_seed + job_id + + # Setup job configuration for structure-based processing + structures_per_job = config.structures_per_job + total_structures = config.total_structures + + # Get input structure pattern from base config + input_structures_pattern = baseline_config.generation.input_structures + + if not input_structures_pattern: + raise ValueError("input_structures must be set in base config") + + logger.info(f"Input structures pattern: {input_structures_pattern}") + + # Expand glob pattern to get all structure files + if isinstance(input_structures_pattern, str): + if "*" in input_structures_pattern or "?" in input_structures_pattern: + # Glob pattern + all_structure_files = sorted(glob.glob(input_structures_pattern)) + else: + # Single file or directory + path = Path(input_structures_pattern) + if path.is_file(): + all_structure_files = [str(path)] + elif path.is_dir(): + # Find all structure files in directory (PDB, CIF, PT) + all_structure_files = [] + all_structure_files.extend(sorted(glob.glob(str(path / "*.pdb")))) + all_structure_files.extend(sorted(glob.glob(str(path / "*.cif")))) + all_structure_files.extend(sorted(glob.glob(str(path / "*.pt")))) + else: + raise ValueError(f"Input path does not exist: {input_structures_pattern}") + elif isinstance(input_structures_pattern, (list, tuple)): + # Already a list of files + all_structure_files = sorted([str(p) for p in input_structures_pattern if Path(p).is_file()]) + else: + raise ValueError(f"Invalid input_structures format: {type(input_structures_pattern)}") + + if not all_structure_files: + raise ValueError(f"No structure files found matching: {input_structures_pattern}") + + logger.info(f"Found {len(all_structure_files)} total structure files") + + # Calculate this job's slice of structures + start_idx = job_id * structures_per_job + end_idx = min((job_id + 1) * structures_per_job, total_structures) + + # Ensure we don't exceed available files + end_idx = min(end_idx, len(all_structure_files)) + + job_structure_files = all_structure_files[start_idx:end_idx] + num_structures = len(job_structure_files) + + if num_structures == 0: + raise ValueError( + f"Job {job_id} has no structures to process (start_idx={start_idx}, total={len(all_structure_files)})" + ) + + logger.info(f"Structure range: {start_idx}-{end_idx} ({num_structures} structures)") + logger.info(f"First file: {job_structure_files[0]}") + logger.info(f"Last file: {job_structure_files[-1]}") + + # Override config with this job's structure subset + baseline_config.generation.input_structures = job_structure_files + + logger.info("Configuration for this job:") + logger.info(f" Output: {baseline_config.output_dir}") + logger.info(f" Structures to process: {num_structures} (indices {start_idx}-{end_idx})") + logger.info(f" Seed: {baseline_config.seed}") + + # Run ESMFold baseline + logger.info("Starting ESMFold baseline evaluation...") + run_esmfold_baseline(baseline_config) + logger.info("ESMFold baseline evaluation complete!") + + # Collect and log metrics to wandb + metrics = collect_job_metrics(baseline_config.output_dir) + + # Log to wandb + wandb.log({"job_id": job_id, "mode": "esmfold_baseline", **metrics}) + + logger.info(f"Job {job_id} completed successfully") + logger.info(f"Metrics: {metrics}") + + +def collect_job_metrics(output_dir: str) -> dict: + """ + Collect metrics from this job's outputs. + + Args: + output_dir: Path to job output directory + + Returns: + Dictionary of aggregated metrics + """ + + output_path = Path(output_dir) + metrics = {} + + # Find metrics CSV + csv_files = list(output_path.glob("*_metrics_*.csv")) + + if not csv_files: + logger.warning(f"No metrics CSV found in {output_dir}") + return metrics + + # Load most recent CSV + latest_csv = max(csv_files, key=lambda x: x.stat().st_mtime) + df = pd.read_csv(latest_csv) + + logger.info(f"Loaded metrics from {latest_csv}") + logger.info(f"Found {len(df)} samples") + + # Collect key metrics + metric_columns = ["plddt", "tm_score", "rmsd"] + + for metric in metric_columns: + if metric in df.columns: + values = pd.to_numeric(df[metric], errors="coerce").dropna() + if len(values) > 0: + metrics[f"avg_{metric}"] = float(values.mean()) + metrics[f"std_{metric}"] = float(values.std()) + metrics[f"min_{metric}"] = float(values.min()) + metrics[f"max_{metric}"] = float(values.max()) + + return metrics + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/distributed_generate.py b/src/lobster/cmdline/distributed_generate.py new file mode 100755 index 00000000..5f210d7b --- /dev/null +++ b/src/lobster/cmdline/distributed_generate.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +""" +WandB Distributed Generation Script for genUME + +Uses wandb agents as a distributed job queue to parallelize structure generation. +Supports three modes: + 1. Unconditional: Each agent generates a subset of samples (e.g., 100 samples → 20 jobs × 5) + 2. Inverse Folding: Each agent processes a subset of input structures + 3. Forward Folding: Each agent processes a subset of input structures + +Usage: + # Initialize the job queue + wandb sweep src/lobster/cmdline/distributed_generation/wandb_config.yaml + + # Submit SLURM array to process jobs + # Update submit_slurm.sh with sweep ID + sbatch src/lobster/cmdline/distributed_generation/submit_slurm.sh +""" + +import glob +from pathlib import Path +from loguru import logger +import wandb +from omegaconf import OmegaConf + +# Import the original generation function +from lobster.cmdline.generate import generate as run_generation + + +def main(): + """ + Main distributed generation function. + Each wandb agent runs this and gets assigned a job_id. + Supports unconditional, inverse_folding, and forward_folding modes. + """ + # Initialize wandb run - this gets config from the sweep + with wandb.init() as run: + config = run.config + + job_id = config.job_id + + # Load base configuration + base_config_path = config.get( + "base_config_path", "src/lobster/hydra_config/experiment/generate_unconditional.yaml" + ) + + logger.info("Starting distributed generation job") + logger.info(f"Job ID: {job_id}") + logger.info(f"Loading base config from: {base_config_path}") + gen_config = OmegaConf.load(base_config_path) + + # Detect generation mode from config + mode = gen_config.generation.get("mode", "unconditional") + logger.info(f"Generation mode: {mode}") + + # Set common parameters + output_base = gen_config.get("output_dir", "./examples/generated") + gen_config.output_dir = f"{output_base}/job_{job_id}" + + base_seed = gen_config.get("seed", 12345) + gen_config.seed = base_seed + job_id + + # CRITICAL: Disable Foldseek clustering during distributed generation + gen_config.generation.calculate_foldseek_diversity = False + logger.info("Foldseek diversity calculation disabled (will run post-aggregation)") + + # Branch based on mode + if mode == "unconditional": + _setup_unconditional_job(gen_config, config, job_id) + elif mode in ["inverse_folding", "forward_folding"]: + _setup_structure_based_job(gen_config, config, job_id, mode) + else: + raise ValueError(f"Unknown generation mode: {mode}") + + # Run generation + logger.info("Starting generation...") + run_generation(gen_config) + logger.info("Generation complete!") + + # Collect and log metrics to wandb + metrics = collect_job_metrics(gen_config.output_dir) + + # Log to wandb + wandb.log({"job_id": job_id, "mode": mode, **metrics}) + + logger.info(f"Job {job_id} completed successfully") + logger.info(f"Metrics: {metrics}") + + +def _setup_unconditional_job(gen_config, config, job_id): + """ + Setup job configuration for unconditional generation mode. + + Args: + gen_config: OmegaConf config to modify + config: WandB sweep config + job_id: Job ID for this worker + """ + samples_per_job = config.samples_per_job + total_samples = config.total_samples + + start_sample = job_id * samples_per_job + end_sample = min((job_id + 1) * samples_per_job, total_samples) + num_samples = end_sample - start_sample + + logger.info(f"Sample range: {start_sample}-{end_sample} ({num_samples} samples per length)") + + # Set number of samples for this job chunk + # IMPORTANT: num_samples is PER LENGTH in the config + # If config has length: [100, 200, 300], each job generates num_samples × 3 structures + gen_config.generation.num_samples = num_samples + + # Optional: Override any parameters from wandb config + if "length" in config: + gen_config.generation.length = config.length + if "nsteps" in config: + gen_config.generation.nsteps = config.nsteps + + # Calculate actual number of structures + lengths = gen_config.generation.length + if isinstance(lengths, list): + num_lengths = len(lengths) + total_structures = num_samples * num_lengths + else: + num_lengths = 1 + total_structures = num_samples + + logger.info("Configuration for this job:") + logger.info(f" Output: {gen_config.output_dir}") + logger.info(f" Samples per length: {num_samples} (indices {start_sample}-{end_sample})") + logger.info(f" Lengths: {gen_config.generation.length} ({num_lengths} lengths)") + logger.info(f" Total structures this job: {total_structures}") + logger.info(f" Seed: {gen_config.seed}") + logger.info(f" Steps: {gen_config.generation.nsteps}") + + +def _setup_structure_based_job(gen_config, config, job_id, mode): + """ + Setup job configuration for inverse_folding or forward_folding modes. + + Args: + gen_config: OmegaConf config to modify + config: WandB sweep config + job_id: Job ID for this worker + mode: "inverse_folding" or "forward_folding" + """ + structures_per_job = config.structures_per_job + total_structures = config.total_structures + + # Get input structure pattern from base config + input_structures_pattern = gen_config.generation.input_structures + + if not input_structures_pattern: + raise ValueError(f"input_structures must be set in base config for {mode} mode") + + logger.info(f"Input structures pattern: {input_structures_pattern}") + + # Expand glob pattern to get all structure files + if isinstance(input_structures_pattern, str): + if "*" in input_structures_pattern or "?" in input_structures_pattern: + # Glob pattern + all_structure_files = sorted(glob.glob(input_structures_pattern)) + else: + # Single file or directory + path = Path(input_structures_pattern) + if path.is_file(): + all_structure_files = [str(path)] + elif path.is_dir(): + # Find all structure files in directory (PDB, CIF, PT) + all_structure_files = [] + all_structure_files.extend(sorted(glob.glob(str(path / "*.pdb")))) + all_structure_files.extend(sorted(glob.glob(str(path / "*.cif")))) + all_structure_files.extend(sorted(glob.glob(str(path / "*.pt")))) + else: + raise ValueError(f"Input path does not exist: {input_structures_pattern}") + elif isinstance(input_structures_pattern, (list, tuple)): + # Already a list of files + all_structure_files = sorted([str(p) for p in input_structures_pattern if Path(p).is_file()]) + else: + raise ValueError(f"Invalid input_structures format: {type(input_structures_pattern)}") + + if not all_structure_files: + raise ValueError(f"No structure files found matching: {input_structures_pattern}") + + logger.info(f"Found {len(all_structure_files)} total structure files") + + # Calculate this job's slice of structures + start_idx = job_id * structures_per_job + end_idx = min((job_id + 1) * structures_per_job, total_structures) + + # Ensure we don't exceed available files + end_idx = min(end_idx, len(all_structure_files)) + + job_structure_files = all_structure_files[start_idx:end_idx] + num_structures = len(job_structure_files) + + if num_structures == 0: + raise ValueError( + f"Job {job_id} has no structures to process (start_idx={start_idx}, total={len(all_structure_files)})" + ) + + logger.info(f"Structure range: {start_idx}-{end_idx} ({num_structures} structures)") + logger.info(f"First file: {job_structure_files[0]}") + logger.info(f"Last file: {job_structure_files[-1]}") + + # Override config with this job's structure subset + gen_config.generation.input_structures = job_structure_files + + # Optional: Override any parameters from wandb config + if "nsteps" in config: + gen_config.generation.nsteps = config.nsteps + + logger.info("Configuration for this job:") + logger.info(f" Output: {gen_config.output_dir}") + logger.info(f" Mode: {mode}") + logger.info(f" Structures to process: {num_structures} (indices {start_idx}-{end_idx})") + logger.info(f" Seed: {gen_config.seed}") + logger.info(f" Steps: {gen_config.generation.nsteps}") + + +def collect_job_metrics(output_dir: str) -> dict: + """ + Collect metrics from this job's outputs. + + Args: + output_dir: Path to job output directory + + Returns: + Dictionary of aggregated metrics + """ + import pandas as pd + + output_path = Path(output_dir) + metrics = {} + + # Find metrics CSV + csv_files = list(output_path.glob("*_metrics_*.csv")) + + if not csv_files: + logger.warning(f"No metrics CSV found in {output_dir}") + return metrics + + # Load most recent CSV + latest_csv = max(csv_files, key=lambda x: x.stat().st_mtime) + df = pd.read_csv(latest_csv) + + logger.info(f"Loaded metrics from {latest_csv}") + logger.info(f"Found {len(df)} samples") + + # Collect key metrics + metric_columns = ["plddt", "predicted_aligned_error", "tm_score", "rmsd"] + + for metric in metric_columns: + if metric in df.columns: + values = pd.to_numeric(df[metric], errors="coerce").dropna() + if len(values) > 0: + metrics[f"avg_{metric}"] = float(values.mean()) + metrics[f"std_{metric}"] = float(values.std()) + metrics[f"min_{metric}"] = float(values.min()) + metrics[f"max_{metric}"] = float(values.max()) + + return metrics + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/distributed_generation/README.md b/src/lobster/cmdline/distributed_generation/README.md new file mode 100644 index 00000000..b052fa05 --- /dev/null +++ b/src/lobster/cmdline/distributed_generation/README.md @@ -0,0 +1,437 @@ +# Distributed Generation with WandB + +This directory contains scripts for distributed structure generation using wandb agents as a job queue. + +**Supports Three Modes:** +1. **Unconditional**: Generate novel protein structures from scratch +2. **Inverse Folding**: Design sequences for given structures (sequence recovery) +3. **Forward Folding**: Generate structures from given sequences/structures + +## Files + +- **`wandb_config.yaml`** - Job queue configuration (generated by `create_job_config.py`) +- **`submit_slurm.sh`** - SLURM submission script +- **`create_job_config.py`** - Helper to generate job distribution configs for all modes +- **`aggregate_results.py`** - Aggregates results with mode-specific metrics + +## Quick Start + +Choose your mode and follow the corresponding workflow: + +### Mode 1: Unconditional Generation + +Generate novel protein structures from scratch. + +#### 1. Generate Job Configuration + +```bash +cd /homefs/home/lisanzas/scratch/Develop/lobster + +# Generate config for 100 samples (5 per job = 20 jobs) +uv run python src/lobster/cmdline/distributed_generation/create_job_config.py \ + --mode unconditional \ + --total_samples 100 \ + --samples_per_job 5 +``` + +### Mode 2: Inverse Folding + +Design sequences for given structures (sequence recovery task). + +#### 1. Generate Job Configuration + +```bash +cd /homefs/home/lisanzas/scratch/Develop/lobster + +# Distribute 449 structure files across jobs (5 structures per job = 90 jobs) +uv run python src/lobster/cmdline/distributed_generation/create_job_config.py \ + --mode inverse_folding \ + --input_structures "/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" \ + --structures_per_job 5 \ + --base_config src/lobster/hydra_config/experiment/generate_inverse_folding_450M.yaml +``` + +### Mode 3: Forward Folding + +Generate structures from given sequences/structures. + +#### 1. Generate Job Configuration + +```bash +cd /homefs/home/lisanzas/scratch/Develop/lobster + +# Distribute structure files across jobs +uv run python src/lobster/cmdline/distributed_generation/create_job_config.py \ + --mode forward_folding \ + --input_structures "/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" \ + --structures_per_job 5 \ + --base_config src/lobster/hydra_config/experiment/generate_forward_folding_450M.yaml +``` + +## Common Steps (All Modes) + +### 2. Initialize WandB Sweep + +```bash +wandb sweep src/lobster/cmdline/distributed_generation/wandb_config.yaml +``` + +**Note:** The sweep ID format depends on your mode: +- Unconditional: `prescient-design/lobster-distributed-generation/abc123xyz` +- Inverse folding: `prescient-design/lobster-distributed-inverse-folding/abc123xyz` +- Forward folding: `prescient-design/lobster-distributed-forward-folding/abc123xyz` + +### 3. Update SLURM Script + +Edit `submit_slurm.sh` and update the wandb agent line with your sweep ID: + +```bash +# Update this line in submit_slurm.sh: +wandb agent prescient-design// +``` + +### 4. Submit Jobs + +```bash +sbatch src/lobster/cmdline/distributed_generation/submit_slurm.sh +``` + +### 5. Monitor Progress + +- Check SLURM jobs: `squeue -u $USER` +- View in WandB UI: https://genentech.wandb.io/prescient-design/ +- Check logs: `/data2/ume/gen_ume/slurm/logs/distributed_gen/` + +### 6. Aggregate Results + +After all jobs complete, aggregate based on your mode: + +#### Unconditional Generation +```bash +uv run python src/lobster/cmdline/distributed_generation/aggregate_results.py \ + ./examples/generated_unconditional 20 \ + --mode unconditional +``` + +This will: +- Combine all metrics and sequences +- Organize structures by length +- Run Foldseek clustering on ALL structures per length +- Calculate diversity metrics + +#### Inverse Folding +```bash +uv run python src/lobster/cmdline/distributed_generation/aggregate_results.py \ + ./examples/generated_inverse_folding_450M 90 \ + --mode inverse_folding +``` + +This will: +- Combine all metrics from all jobs +- Calculate overall AAR (amino acid recovery) +- Calculate TM-score, RMSD, pLDDT statistics +- Create per-structure summary table +- **No Foldseek** (diversity not applicable for inverse folding) + +#### Forward Folding +```bash +uv run python src/lobster/cmdline/distributed_generation/aggregate_results.py \ + ./examples/generated_forward_folding 90 \ + --mode forward_folding +``` + +This will: +- Combine all metrics from all jobs +- Calculate TM-score, RMSD, pLDDT statistics +- Create per-structure summary table +- **No Foldseek** (diversity not applicable) + +**Auto-detection:** If you don't specify `--mode`, it will auto-detect from the CSV columns. + +## Key Features + +- **Three Generation Modes**: Unconditional, inverse folding, and forward folding +- **Mode Auto-Detection**: Automatically detects mode from CSV columns +- **Smart Foldseek**: Only runs for unconditional generation (diversity analysis) +- **AAR Tracking**: Automatically calculated for inverse folding (from `percent_identity` column) +- **Per-Structure Grouping**: Inverse/forward folding results grouped by input structure +- **Progress Tracking**: All jobs report to WandB +- **Fault Tolerance**: WandB reassigns failed jobs +- **Parallel Execution**: Multiple SLURM array jobs run simultaneously + +## Configuration + +### Base Configuration by Mode + +**Unconditional:** +``` +src/lobster/hydra_config/experiment/generate_unconditional.yaml +``` + +**Inverse Folding:** +``` +src/lobster/hydra_config/experiment/generate_inverse_folding_450M.yaml +``` + +**Forward Folding:** +``` +src/lobster/hydra_config/experiment/generate_forward_folding_450M.yaml +``` + +### Important: Understanding Job Distribution + +#### Unconditional Generation +`total_samples` vs `num_samples`: `num_samples` in the config is **per length**, not total structures. + +- In `generate_unconditional.yaml`: `num_samples: 10` means 10 samples per length +- If `length: [100, 200, 300, 400, 500]` (5 lengths), actual output = 10 × 5 = **50 structures** + +**In Distributed Generation**: +- `--total_samples 1000` means 1000 samples **per length** +- `--samples_per_job 50` means each job generates 50 samples **per length** +- With 5 lengths: each job generates 50 × 5 = **250 structures** +- Total across all jobs: 1000 × 5 = **5000 structures** + +**Recommendation**: For distributed generation, use a **single length** in your config: +```yaml +generation: + length: [500] # Single length +``` +Then `--total_samples 1000` = exactly 1000 structures + +#### Inverse/Forward Folding +For structure-based modes, distribution is by **input files**: + +- `--input_structures "/path/*.pt"` expands to list of files (e.g., 449 files) +- `--structures_per_job 5` means each job processes 5 input files +- Result: 449 files ÷ 5 = 90 jobs (89 jobs × 5 files + 1 job × 4 files) + +### Job Parameters + +#### Unconditional +Each job receives: +- Unique `job_id` +- Unique sample range (`start_sample` to `end_sample`) +- Unique seed (`base_seed + job_id`) +- Independent output directory (`output_dir/job_{job_id}`) +- Overridden `num_samples` (end_sample - start_sample) + +#### Inverse/Forward Folding +Each job receives: +- Unique `job_id` +- Subset of input structure files (e.g., files 0-4, 5-9, etc.) +- Unique seed (`base_seed + job_id`) +- Independent output directory (`output_dir/job_{job_id}`) +- Overridden `input_structures` (list of files for this job) + +## Output Structure + +### Unconditional Generation +``` +./examples/generated_unconditional/ +├── job_0/ +│ ├── *.pdb +│ ├── unconditional_metrics_*.csv +│ └── sequences_*.csv +├── job_1/ +│ └── ... +└── aggregated/ + ├── combined_metrics.csv + ├── combined_sequences.csv + ├── summary_per_length.csv + ├── job_*_*.pdb + └── foldseek_results/ + ├── length_100/ + ├── length_200/ + └── ... +``` + +### Inverse Folding +``` +./examples/generated_inverse_folding_450M/ +├── job_0/ +│ ├── *.pdb # Generated + ESMFold structures +│ ├── inverse_folding_metrics_*.csv # AAR, TM-score, RMSD, pLDDT +│ └── sequences_inverse_folding_*.csv +├── job_1/ +│ └── ... +└── aggregated/ + ├── combined_inverse_folding_metrics.csv # All metrics combined + ├── combined_inverse_folding_sequences.csv + ├── summary_per_structure.csv # AAR + metrics per input structure + └── job_*_*.pdb # All PDB files +``` + +### Forward Folding +``` +./examples/generated_forward_folding/ +├── job_0/ +│ ├── *.pdb # Generated structures +│ └── forward_folding_metrics_*.csv # TM-score, RMSD, pLDDT +├── job_1/ +│ └── ... +└── aggregated/ + ├── combined_forward_folding_metrics.csv # All metrics combined + ├── summary_per_structure.csv # Metrics per input structure + └── job_*_*.pdb # All PDB files +``` + +## Advanced Usage + +### Custom Parameters + +```bash +# Multi-length generation +uv run python src/lobster/cmdline/distributed_generation/create_job_config.py \ + --total_samples 500 \ + --samples_per_job 25 \ + --lengths 100 200 300 400 500 + +# Custom Foldseek parameters +uv run python src/lobster/cmdline/distributed_generation/aggregate_results.py \ + ./examples/generated_unconditional 20 \ + --tmscore-threshold 0.6 \ + --rmsd-threshold 2.5 + +# Skip Foldseek (fast aggregation only) +uv run python src/lobster/cmdline/distributed_generation/aggregate_results.py \ + ./examples/generated_unconditional 20 \ + --no-foldseek +``` + +## Output Files + +### Unconditional Generation + +After aggregation, files in `{output_dir}/aggregated/`: + +- **`combined_metrics.csv`**: All metrics from all jobs combined +- **`combined_sequences.csv`**: All sequences from all jobs combined +- **`summary_per_length.csv`**: Comprehensive summary table with per-length statistics: + - Total structures + - Structures passing RMSD threshold + - Percentage passing RMSD + - Number of Foldseek clusters + - Diversity percentage + - Average TM score, RMSD, and pLDDT +- **`foldseek_results/length_{L}/`**: Foldseek clustering results per length + +Example `summary_per_length.csv`: +``` +Length Total_Structures Structures_RMSD<2.0 Pct_RMSD<2.0 Num_Clusters Diversity_Pct Avg_TM_Score Avg_RMSD Avg_pLDDT +100 196 196 100.0 68 34.69 0.9345 0.5234 0.7845 +200 196 194 98.98 72 37.11 0.9256 0.6123 0.7654 +300 196 195 99.49 75 38.46 0.9123 0.7234 0.7423 +``` + +### Inverse Folding + +After aggregation, files in `{output_dir}/aggregated/`: + +- **`combined_inverse_folding_metrics.csv`**: All metrics from all jobs combined + - Columns: `run_id`, `timestamp`, `mode`, **`percent_identity`** (AAR), `plddt`, `predicted_aligned_error`, `tm_score`, `rmsd`, `sequence_length`, `input_file`, `job_id` +- **`combined_inverse_folding_sequences.csv`**: All designed sequences +- **`summary_per_structure.csv`**: Per-structure summary with: + - Input_Structure (from `input_file` column) + - Num_Designs + - **Avg_AAR**, Min_AAR, Max_AAR (amino acid recovery %) + - Avg_TM_Score, Min_TM_Score, Max_TM_Score + - Avg_RMSD, Min_RMSD, Max_RMSD + - Designs_RMSD<2.0, Pct_RMSD<2.0 + - Avg_pLDDT + +Example `summary_per_structure.csv`: +``` +Input_Structure Num_Designs Avg_AAR Min_AAR Max_AAR Avg_TM_Score Min_TM_Score Max_TM_Score Avg_RMSD Designs_RMSD<2.0 Pct_RMSD<2.0 Avg_pLDDT +batch_000 1 22.79 22.79 22.79 0.8530 0.8530 0.8530 2.026 0 0.0 0.663 +batch_001 1 78.99 78.99 78.99 0.9777 0.9777 0.9777 0.364 1 100.0 0.789 +batch_002 1 43.50 43.50 43.50 0.8807 0.8807 0.8807 1.160 1 100.0 0.769 +``` + +**Key Metric:** AAR (Amino Acid Recovery) = `percent_identity` column = % of positions matching native sequence + +### Forward Folding + +After aggregation, files in `{output_dir}/aggregated/`: + +- **`combined_forward_folding_metrics.csv`**: All metrics from all jobs combined + - Columns: `run_id`, `timestamp`, `mode`, `tm_score`, `rmsd`, `plddt`, `input_file`, `job_id` +- **`summary_per_structure.csv`**: Per-structure summary with: + - Input_Structure + - Num_Structures + - Avg_TM_Score, Min_TM_Score, Max_TM_Score + - Avg_RMSD, Min_RMSD, Max_RMSD + - Structures_RMSD<2.0, Pct_RMSD<2.0 + - Avg_pLDDT + +Example `summary_per_structure.csv`: +``` +Input_Structure Num_Structures Avg_TM_Score Min_TM_Score Max_TM_Score Avg_RMSD Structures_RMSD<2.0 Pct_RMSD<2.0 Avg_pLDDT +batch_000 1 0.7456 0.7456 0.7456 2.234 0 0.0 0.689 +batch_001 1 0.7312 0.7312 0.7312 1.456 1 100.0 0.691 +``` + +## Troubleshooting + +### Jobs not getting assigned +- Check `method: grid` is set in `wandb_config.yaml` +- Verify all parameter lists have equal length +- Check sweep status in WandB UI +- Verify the correct project name in `submit_slurm.sh`: + - Unconditional: `lobster-distributed-generation` + - Inverse folding: `lobster-distributed-inverse-folding` + - Forward folding: `lobster-distributed-forward-folding` + +### Mode-specific issues + +#### Invalid input_structures format error +**Error:** `ValueError: Invalid input_structures format: ` + +**Solution:** This was fixed in the latest version. Make sure you're using the updated `generate.py` that includes `ListConfig` support. + +#### No structures found +**Error:** `ValueError: No structure files found matching: ` + +**Solution:** +- Verify the glob pattern is correct +- Check file permissions +- Use absolute paths instead of relative paths +- Test the glob pattern: `ls /path/to/structures/*.pt | wc -l` + +#### Mode auto-detection wrong +**Error:** Aggregation detects wrong mode + +**Solution:** +- Explicitly specify `--mode` in aggregation command +- Check that CSV has the expected columns: + - Inverse folding: `percent_identity` + `input_file` + - Forward folding: `input_file` (no `percent_identity`) + - Unconditional: `sequence_length` + +### Foldseek errors +- **Only applies to unconditional generation** +- Verify binary path: `/homefs/home/lisanzas/scratch/Develop/lobster/src/lobster/metrics/foldseek/bin/foldseek` +- Check structure files exist in aggregated directory +- Look for timeout errors in aggregation logs +- To skip Foldseek: use `--no-foldseek` flag + +### OOM errors +- Reduce `--array` size in `submit_slurm.sh` +- Increase `--mem` if needed +- Ensure `batch_size: 1` for long sequences +- For inverse/forward folding: reduce `--structures_per_job` + +### Performance Tips + +#### Unconditional Generation +- Use single length for predictable job sizes +- Start with fewer jobs to test timing +- Monitor Foldseek aggregation time (can be slow for many structures) + +#### Inverse/Forward Folding +- Adjust `--structures_per_job` based on structure complexity: + - Small structures (<200 AA): 10-20 per job + - Medium structures (200-400 AA): 5-10 per job + - Large structures (>400 AA): 2-5 per job +- Sort files deterministically (alphabetically) for reproducibility + diff --git a/src/lobster/cmdline/distributed_generation/aggregate_results.py b/src/lobster/cmdline/distributed_generation/aggregate_results.py new file mode 100755 index 00000000..74263c5e --- /dev/null +++ b/src/lobster/cmdline/distributed_generation/aggregate_results.py @@ -0,0 +1,1150 @@ +#!/usr/bin/env python3 +""" +Aggregate results from distributed generation jobs. + +Supports three modes: +1. Unconditional: Groups by length, runs Foldseek diversity analysis +2. Inverse Folding: Groups by input structure, calculates AAR and structural metrics +3. Forward Folding: Groups by input structure, calculates structural metrics +""" + +import pandas as pd +from pathlib import Path +from loguru import logger +import shutil + +from lobster.metrics.cal_foldseek_clusters import run_easy_cluster + + +def detect_generation_mode(metrics_df: pd.DataFrame) -> str: + """ + Detect generation mode from metrics DataFrame. + + CSV columns by mode: + - Inverse folding: has 'percent_identity' and 'input_file' columns, mode='inverse_folding' + - Forward folding: has 'input_file' but no 'percent_identity', mode='forward_folding' + - Unconditional: has 'sequence_length' for grouping by length + + Args: + metrics_df: DataFrame with metrics from a single job + + Returns: + "unconditional", "inverse_folding", or "forward_folding" + """ + # Check for mode column (most reliable) + if "mode" in metrics_df.columns: + mode_value = metrics_df["mode"].iloc[0] + if mode_value in ["inverse_folding", "forward_folding", "unconditional"]: + logger.info(f"Detected mode from 'mode' column: {mode_value}") + return mode_value + + # Fallback: check for mode-specific columns + if "percent_identity" in metrics_df.columns and "input_file" in metrics_df.columns: + logger.info("Detected mode from columns: inverse_folding") + return "inverse_folding" + elif "input_file" in metrics_df.columns: + logger.info("Detected mode from columns: forward_folding") + return "forward_folding" + elif "sequence_length" in metrics_df.columns: + logger.info("Detected mode from columns: unconditional") + return "unconditional" + else: + logger.warning("Could not detect mode from columns, defaulting to unconditional") + return "unconditional" + + +def aggregate_distributed_results( + base_output_dir: str, + num_jobs: int, + mode: str = None, + run_foldseek: bool = None, + foldseek_bin_path: str = None, + foldseek_tmscore_threshold: float = 0.5, + rmsd_threshold: float = 2.0, +): + """ + Aggregate results from multiple distributed generation jobs. + + Supports three modes: + - unconditional: Groups by length, runs Foldseek diversity analysis + - inverse_folding: Groups by input structure, reports AAR and structural metrics + - forward_folding: Groups by input structure, reports structural metrics + + Args: + base_output_dir: Base output directory containing job_* subdirectories + num_jobs: Number of jobs to aggregate + mode: Generation mode ("unconditional", "inverse_folding", "forward_folding") + If None, will auto-detect from metrics CSV + run_foldseek: Whether to run Foldseek clustering + If None, auto-set based on mode (True for unconditional only) + foldseek_bin_path: Path to Foldseek binary directory + foldseek_tmscore_threshold: TM-score threshold for clustering + rmsd_threshold: RMSD threshold for filtering structures + + Returns: + Dictionary with aggregation results + """ + base_path = Path(base_output_dir) + + # Auto-detect mode if not provided + if mode is None: + logger.info("Mode not specified, attempting auto-detection...") + for job_id in range(num_jobs): + job_dir = base_path / f"job_{job_id}" + if not job_dir.exists(): + continue + + metrics_files = list(job_dir.glob("*_metrics_*.csv")) + if metrics_files: + df = pd.read_csv(metrics_files[0]) + mode = detect_generation_mode(df) + break + + if mode is None: + logger.warning("Could not auto-detect mode, defaulting to unconditional") + mode = "unconditional" + + logger.info(f"Aggregating results in {mode} mode") + + # Auto-set Foldseek based on mode if not explicitly provided + if run_foldseek is None: + run_foldseek = mode == "unconditional" + if not run_foldseek: + logger.info(f"Foldseek disabled for {mode} mode (diversity analysis not applicable)") + + # Branch to mode-specific aggregation + if mode == "unconditional": + return aggregate_unconditional( + base_path=base_path, + num_jobs=num_jobs, + run_foldseek=run_foldseek, + foldseek_bin_path=foldseek_bin_path, + foldseek_tmscore_threshold=foldseek_tmscore_threshold, + rmsd_threshold=rmsd_threshold, + ) + elif mode == "inverse_folding": + return aggregate_inverse_folding( + base_path=base_path, + num_jobs=num_jobs, + rmsd_threshold=rmsd_threshold, + ) + elif mode == "forward_folding": + return aggregate_forward_folding( + base_path=base_path, + num_jobs=num_jobs, + rmsd_threshold=rmsd_threshold, + ) + else: + raise ValueError(f"Unknown mode: {mode}") + + +def aggregate_unconditional( + base_path: Path, + num_jobs: int, + run_foldseek: bool, + foldseek_bin_path: str, + foldseek_tmscore_threshold: float, + rmsd_threshold: float, +) -> dict: + """ + Aggregate unconditional generation results. + Groups by length, runs Foldseek diversity analysis. + + This is the original aggregation logic for unconditional generation. + """ + + # Create aggregated output directory + agg_dir = base_path / "aggregated" + agg_dir.mkdir(exist_ok=True) + + all_metrics = [] + all_sequences = [] + structures_by_length = {} # Track structures by length for Foldseek + + logger.info(f"Aggregating {num_jobs} jobs from {base_path}") + + # Collect from each job + for job_id in range(num_jobs): + job_dir = base_path / f"job_{job_id}" + + if not job_dir.exists(): + logger.warning(f"Job {job_id} directory not found: {job_dir}") + continue + + logger.info(f"Processing job {job_id}") + + # Find metrics CSV + metrics_files = list(job_dir.glob("*_metrics_*.csv")) + if metrics_files: + df = pd.read_csv(metrics_files[0]) + df["job_id"] = job_id + all_metrics.append(df) + + # Find sequences CSV + seq_files = list(job_dir.glob("sequences_*.csv")) + if seq_files: + df_seq = pd.read_csv(seq_files[0]) + df_seq["job_id"] = job_id + all_sequences.append(df_seq) + + # Collect PDB files organized by length + # ONLY collect ESMFold structures for Foldseek analysis + pdb_files = list(job_dir.glob("*.pdb")) + esmfold_count = 0 + for pdb in pdb_files: + # Filter: Only include ESMFold structures (contain "_esmfold_" in filename) + if "_esmfold_" not in pdb.name: + continue + esmfold_count += 1 + + # Extract length from filename if possible + # Assuming format like: unconditional_length_500_sample_0_esmfold_000.pdb + try: + parts = pdb.stem.split("_") + if "length" in parts: + length_idx = parts.index("length") + 1 + length = int(parts[length_idx]) + else: + # Fallback: read PDB to get length + length = get_pdb_length(pdb) + + if length not in structures_by_length: + structures_by_length[length] = [] + + # Copy with unique name + new_name = f"job_{job_id}_{pdb.name}" + dest = agg_dir / new_name + shutil.copy2(pdb, dest) + structures_by_length[length].append(dest) + + except Exception as e: + logger.warning(f"Could not determine length for {pdb.name}: {e}") + # Copy anyway + new_name = f"job_{job_id}_{pdb.name}" + shutil.copy2(pdb, agg_dir / new_name) + + if esmfold_count > 0: + logger.info(f" Found {esmfold_count} ESMFold structures (out of {len(pdb_files)} total)") + + # Combine metrics + diversity_results = {} + if all_metrics: + combined_metrics = pd.concat(all_metrics, ignore_index=True) + output_metrics = agg_dir / "combined_metrics.csv" + combined_metrics.to_csv(output_metrics, index=False) + logger.info(f"Saved combined metrics: {output_metrics}") + logger.info(f"Total samples: {len(combined_metrics)}") + + # Print summary statistics + logger.info("\n=== Summary Statistics ===") + for col in ["plddt", "tm_score", "rmsd", "predicted_aligned_error"]: + if col in combined_metrics.columns: + logger.info(f"{col}:") + logger.info(f" Mean: {combined_metrics[col].mean():.3f}") + logger.info(f" Std: {combined_metrics[col].std():.3f}") + logger.info(f" Min: {combined_metrics[col].min():.3f}") + logger.info(f" Max: {combined_metrics[col].max():.3f}") + + # Run Foldseek clustering per length + if run_foldseek and structures_by_length: + total_esmfold = sum(len(pdbs) for pdbs in structures_by_length.values()) + logger.info("\n=== Running Foldseek Clustering ===") + logger.info(f"Found {total_esmfold} ESMFold structures at {len(structures_by_length)} different lengths") + logger.info("Note: Only ESMFold-validated structures are used for diversity analysis") + + diversity_results = run_foldseek_clustering( + structures_by_length=structures_by_length, + output_dir=agg_dir, + combined_metrics=combined_metrics, + foldseek_bin_path=foldseek_bin_path, + tmscore_threshold=foldseek_tmscore_threshold, + rmsd_threshold=rmsd_threshold, + ) + + # Log diversity results + logger.info("\n=== Diversity Results ===") + for length, results in diversity_results.items(): + logger.info(f"Length {length}:") + logger.info(f" Total structures: {results['total_structures']}") + logger.info(f" Structures passing RMSD < {rmsd_threshold}: {results['structures_passing_rmsd']}") + logger.info(f" Number of clusters: {results['num_clusters']}") + logger.info(f" Diversity: {results['diversity_percentage']:.1f}%") + + # Create comprehensive summary table per length + logger.info("\n=== Creating Summary Table ===") + summary_table = create_summary_table( + combined_metrics=combined_metrics, + diversity_results=diversity_results, + structures_by_length=structures_by_length, + rmsd_threshold=rmsd_threshold, + ) + + if summary_table is not None: + # Save to CSV + summary_csv = agg_dir / "summary_per_length.csv" + summary_table.to_csv(summary_csv, index=False) + logger.info(f"Saved summary table: {summary_csv}") + + # Print table + logger.info("\n=== Summary Per Length ===") + logger.info(f"\n{summary_table.to_string(index=False)}") + else: + logger.warning("Could not create summary table") + + # Combine sequences + if all_sequences: + combined_sequences = pd.concat(all_sequences, ignore_index=True) + output_sequences = agg_dir / "combined_sequences.csv" + combined_sequences.to_csv(output_sequences, index=False) + logger.info(f"Saved combined sequences: {output_sequences}") + + logger.info(f"\nAggregation complete! Results in: {agg_dir}") + + return { + "mode": "unconditional", + "aggregated_dir": str(agg_dir), + "total_samples": len(combined_metrics) if all_metrics else 0, + "diversity_results": diversity_results, + } + + +def create_summary_table( + combined_metrics: pd.DataFrame, diversity_results: dict, structures_by_length: dict, rmsd_threshold: float = 2.0 +) -> pd.DataFrame: + """ + Create a comprehensive summary table with metrics per length. + + Args: + combined_metrics: DataFrame with all metrics + diversity_results: Dictionary with diversity results per length + structures_by_length: Dict mapping length -> list of PDB files + rmsd_threshold: RMSD threshold used for filtering + + Returns: + DataFrame with summary statistics per length + """ + if combined_metrics is None or len(combined_metrics) == 0: + logger.warning("No metrics available for summary table") + return None + + summary_rows = [] + + # Get unique lengths from metrics + if "sequence_length" in combined_metrics.columns: + lengths = sorted(combined_metrics["sequence_length"].unique()) + else: + lengths = sorted(structures_by_length.keys()) + + for length in lengths: + # Filter metrics for this length + length_metrics = combined_metrics[combined_metrics["sequence_length"] == length] + # Filter rows where rmsd is not found + length_metrics = length_metrics[length_metrics["rmsd"].notna()] + + if len(length_metrics) == 0: + continue + + # Calculate basic metrics + total_structures = len(length_metrics) + structures_passing_rmsd = len(length_metrics[length_metrics["rmsd"] < rmsd_threshold]) + pct_passing_rmsd = (structures_passing_rmsd / total_structures * 100) if total_structures > 0 else 0 + + # Get diversity metrics if available + num_clusters = 0 + diversity_pct = 0.0 + if length in diversity_results: + num_clusters = diversity_results[length]["num_clusters"] + diversity_pct = diversity_results[length]["diversity_percentage"] + + # Calculate average metrics + avg_tm = length_metrics["tm_score"].mean() if "tm_score" in length_metrics.columns else 0 + avg_rmsd = length_metrics["rmsd"].mean() if "rmsd" in length_metrics.columns else 0 + avg_plddt = length_metrics["plddt"].mean() if "plddt" in length_metrics.columns else 0 + + summary_rows.append( + { + "Length": int(length), + "Total_Structures": total_structures, + f"Structures_RMSD<{rmsd_threshold}": structures_passing_rmsd, + f"Pct_RMSD<{rmsd_threshold}": round(pct_passing_rmsd, 2), + "Num_Clusters": num_clusters, + "Diversity_Pct": round(diversity_pct, 2), + "Avg_TM_Score": round(avg_tm, 4), + "Avg_RMSD": round(avg_rmsd, 4), + "Avg_pLDDT": round(avg_plddt, 4), + } + ) + + if not summary_rows: + logger.warning("No data to create summary table") + return None + + return pd.DataFrame(summary_rows) + + +def get_pdb_length(pdb_path: Path) -> int: + """ + Get sequence length from PDB file by counting CA atoms. + + Args: + pdb_path: Path to PDB file + + Returns: + Number of residues + """ + try: + import biotite.structure.io.pdb as pdb + + structure = pdb.PDBFile.read(str(pdb_path)) + atom_array = structure.get_structure()[0] + + # Count CA atoms + ca_mask = atom_array.atom_name == "CA" + return ca_mask.sum() + except Exception as e: + logger.warning(f"Failed to get length from {pdb_path}: {e}") + return 0 + + +def aggregate_inverse_folding( + base_path: Path, + num_jobs: int, + rmsd_threshold: float = 2.0, +) -> dict: + """ + Aggregate inverse folding results. + Groups by input structure, calculates AAR and structural metrics. + + Key metrics (already in CSV): + - AAR (Amino Acid Recovery): from 'percent_identity' column + - TM-score: from 'tm_score' column + - RMSD: from 'rmsd' column + - pLDDT: from 'plddt' column + """ + agg_dir = base_path / "aggregated" + agg_dir.mkdir(exist_ok=True) + + all_metrics = [] + all_sequences = [] + + logger.info(f"Aggregating {num_jobs} jobs from {base_path}") + + # Collect from each job + for job_id in range(num_jobs): + job_dir = base_path / f"job_{job_id}" + if not job_dir.exists(): + logger.warning(f"Job {job_id} directory not found") + continue + + logger.info(f"Processing job {job_id}") + + # Find inverse folding metrics CSV + metrics_files = list(job_dir.glob("*inverse_folding*metrics*.csv")) + if not metrics_files: + metrics_files = list(job_dir.glob("*_metrics_*.csv")) + + if metrics_files: + df = pd.read_csv(metrics_files[0]) + df["job_id"] = job_id + all_metrics.append(df) + logger.info(f" Found {len(df)} designs") + + # Find sequences CSV + seq_files = list(job_dir.glob("sequences_inverse_folding*.csv")) + if not seq_files: + seq_files = list(job_dir.glob("sequences_*.csv")) + + if seq_files: + df_seq = pd.read_csv(seq_files[0]) + df_seq["job_id"] = job_id + all_sequences.append(df_seq) + + # Copy PDB files (both generated and ESMFold) + pdb_files = list(job_dir.glob("*.pdb")) + for pdb in pdb_files: + new_name = f"job_{job_id}_{pdb.name}" + shutil.copy2(pdb, agg_dir / new_name) + + if pdb_files: + logger.info(f" Copied {len(pdb_files)} PDB files") + + # Combine sequences + if all_sequences: + combined_sequences = pd.concat(all_sequences, ignore_index=True) + output_sequences = agg_dir / "combined_inverse_folding_sequences.csv" + combined_sequences.to_csv(output_sequences, index=False) + logger.info(f"Saved combined sequences: {output_sequences}") + + # Combine metrics + if all_metrics: + combined_metrics = pd.concat(all_metrics, ignore_index=True) + + # Merge with sequences to get actual structure names + if all_sequences and "input_structure" in combined_sequences.columns: + # Merge on BOTH run_id AND job_id to avoid many-to-many joins + # (each job restarts batch numbering from batch_000) + structure_map = combined_sequences[["run_id", "job_id", "input_structure"]].drop_duplicates() + combined_metrics = combined_metrics.merge(structure_map, on=["run_id", "job_id"], how="left") + # Replace input_file with actual structure names where available + if "input_structure" in combined_metrics.columns: + combined_metrics["input_file"] = combined_metrics["input_structure"].fillna( + combined_metrics["input_file"] + ) + combined_metrics = combined_metrics.drop(columns=["input_structure"]) + logger.info("Replaced generic batch identifiers with actual structure names from sequences CSV") + + output_metrics = agg_dir / "combined_inverse_folding_metrics.csv" + combined_metrics.to_csv(output_metrics, index=False) + + logger.info(f"Saved combined metrics: {output_metrics}") + logger.info(f"Total designs: {len(combined_metrics)}") + + # Print summary statistics + logger.info("\n=== Inverse Folding Summary ===") + + # AAR (already calculated as percent_identity in the CSV) + if "percent_identity" in combined_metrics.columns: + aar = combined_metrics["percent_identity"].mean() + logger.info(f"Average AAR (Amino Acid Recovery): {aar:.2f}%") + logger.info(f" Min: {combined_metrics['percent_identity'].min():.2f}%") + logger.info(f" Max: {combined_metrics['percent_identity'].max():.2f}%") + logger.info(f" Median: {combined_metrics['percent_identity'].median():.2f}%") + + # TM-score + if "tm_score" in combined_metrics.columns: + tm = combined_metrics["tm_score"].mean() + logger.info(f"Average TM-score: {tm:.3f}") + logger.info(f" Min: {combined_metrics['tm_score'].min():.3f}") + logger.info(f" Max: {combined_metrics['tm_score'].max():.3f}") + + # RMSD + if "rmsd" in combined_metrics.columns: + rmsd = combined_metrics["rmsd"].mean() + logger.info(f"Average RMSD: {rmsd:.3f} Å") + logger.info(f" Min: {combined_metrics['rmsd'].min():.3f}") + logger.info(f" Max: {combined_metrics['rmsd'].max():.3f}") + + # Count structures passing RMSD threshold + passing = len(combined_metrics[combined_metrics["rmsd"] < rmsd_threshold]) + pct = (passing / len(combined_metrics)) * 100 + logger.info(f"Designs with RMSD < {rmsd_threshold}: {passing}/{len(combined_metrics)} ({pct:.1f}%)") + + # pLDDT + if "plddt" in combined_metrics.columns: + plddt = combined_metrics["plddt"].mean() + logger.info(f"Average pLDDT: {plddt:.3f}") + + # Create per-structure summary table + logger.info("\n=== Creating Per-Structure Summary ===") + summary_table = create_inverse_folding_summary(combined_metrics, rmsd_threshold) + if summary_table is not None: + summary_csv = agg_dir / "summary_per_structure.csv" + summary_table.to_csv(summary_csv, index=False) + logger.info(f"Saved per-structure summary: {summary_csv}") + + # Print table (limit to first 20 rows for readability) + logger.info("\n=== Summary Per Structure ===") + if len(summary_table) > 20: + logger.info(f"\n{summary_table.head(20).to_string(index=False)}") + logger.info(f"... and {len(summary_table) - 20} more structures") + else: + logger.info(f"\n{summary_table.to_string(index=False)}") + else: + logger.warning("Could not create per-structure summary table") + + # Create overall summary table (single row with aggregate stats) + logger.info("\n=== Creating Overall Summary ===") + overall_summary = create_overall_summary(combined_metrics, rmsd_threshold) + if overall_summary is not None: + overall_csv = agg_dir / "overall_summary.csv" + overall_summary.to_csv(overall_csv, index=False) + logger.info(f"Saved overall summary: {overall_csv}") + logger.info(f"\n{overall_summary.to_string(index=False)}") + else: + logger.warning("Could not create overall summary table") + + logger.info(f"\nAggregation complete! Results in: {agg_dir}") + + return { + "mode": "inverse_folding", + "aggregated_dir": str(agg_dir), + "total_designs": len(combined_metrics) if all_metrics else 0, + "average_aar": combined_metrics["percent_identity"].mean() + if all_metrics and "percent_identity" in combined_metrics.columns + else None, + "average_tm_score": combined_metrics["tm_score"].mean() + if all_metrics and "tm_score" in combined_metrics.columns + else None, + } + + +def create_inverse_folding_summary(combined_metrics: pd.DataFrame, rmsd_threshold: float = 2.0) -> pd.DataFrame: + """ + Create summary table for inverse folding results, grouped by input structure. + + CSV has 'input_file' column with values like 'batch_000', 'batch_001', etc. + + Returns: + DataFrame with columns: + - Input_Structure: from 'input_file' column + - Num_Designs + - Avg_AAR: from 'percent_identity' column + - Avg_TM_Score: from 'tm_score' column + - Avg_RMSD: from 'rmsd' column + - Avg_pLDDT: from 'plddt' column + - Designs_Passing_RMSD (count) + - Pct_Passing_RMSD + """ + if "input_file" not in combined_metrics.columns: + logger.warning("No 'input_file' column found in metrics") + return None + + summary_rows = [] + + # Group by input structure (input_file column) + for structure_file, group in combined_metrics.groupby("input_file"): + num_designs = len(group) + + row = { + "Input_Structure": structure_file, + "Num_Designs": num_designs, + } + + # AAR + if "percent_identity" in group.columns: + row["Avg_AAR"] = round(group["percent_identity"].mean(), 2) + row["Min_AAR"] = round(group["percent_identity"].min(), 2) + row["Max_AAR"] = round(group["percent_identity"].max(), 2) + + # TM-score + if "tm_score" in group.columns: + row["Avg_TM_Score"] = round(group["tm_score"].mean(), 4) + row["Min_TM_Score"] = round(group["tm_score"].min(), 4) + row["Max_TM_Score"] = round(group["tm_score"].max(), 4) + + # RMSD + if "rmsd" in group.columns: + row["Avg_RMSD"] = round(group["rmsd"].mean(), 4) + row["Min_RMSD"] = round(group["rmsd"].min(), 4) + row["Max_RMSD"] = round(group["rmsd"].max(), 4) + + passing = len(group[group["rmsd"] < rmsd_threshold]) + row[f"Designs_RMSD<{rmsd_threshold}"] = passing + row[f"Pct_RMSD<{rmsd_threshold}"] = round((passing / num_designs) * 100, 2) + + # pLDDT + if "plddt" in group.columns: + row["Avg_pLDDT"] = round(group["plddt"].mean(), 4) + + summary_rows.append(row) + + if not summary_rows: + logger.warning("No data to create summary table") + return None + + return pd.DataFrame(summary_rows) + + +def create_overall_summary( + combined_metrics: pd.DataFrame, + rmsd_threshold: float = 2.0, +) -> pd.DataFrame: + """ + Create a single-row overall summary table with aggregate statistics across all structures. + + Columns: + - Total_Structures: Number of unique structures + - Avg_TM_Score, Std_TM_Score, Min_TM_Score, Max_TM_Score + - Avg_RMSD, Std_RMSD, Min_RMSD, Max_RMSD + - Structures_RMSD<{threshold}: Count of structures with RMSD below threshold + - Pct_RMSD<{threshold}: Percentage of structures passing threshold + """ + if combined_metrics.empty: + logger.warning("Empty metrics DataFrame, cannot create overall summary") + return None + + # Count unique structures + total_structures = ( + combined_metrics["input_file"].nunique() if "input_file" in combined_metrics.columns else len(combined_metrics) + ) + + summary = { + "Total_Structures": total_structures, + } + + # TM-score statistics + if "tm_score" in combined_metrics.columns: + summary["Avg_TM_Score"] = round(combined_metrics["tm_score"].mean(), 4) + summary["Std_TM_Score"] = round(combined_metrics["tm_score"].std(), 4) + summary["Min_TM_Score"] = round(combined_metrics["tm_score"].min(), 4) + summary["Max_TM_Score"] = round(combined_metrics["tm_score"].max(), 4) + + # RMSD statistics + if "rmsd" in combined_metrics.columns: + summary["Avg_RMSD"] = round(combined_metrics["rmsd"].mean(), 4) + summary["Std_RMSD"] = round(combined_metrics["rmsd"].std(), 4) + summary["Min_RMSD"] = round(combined_metrics["rmsd"].min(), 4) + summary["Max_RMSD"] = round(combined_metrics["rmsd"].max(), 4) + + # Count structures passing RMSD threshold + passing_count = len(combined_metrics[combined_metrics["rmsd"] < rmsd_threshold]) + summary[f"Structures_RMSD<{rmsd_threshold}"] = passing_count + summary[f"Pct_RMSD<{rmsd_threshold}"] = round((passing_count / len(combined_metrics)) * 100, 2) + + # AAR statistics (for inverse folding) + if "percent_identity" in combined_metrics.columns: + summary["Avg_AAR"] = round(combined_metrics["percent_identity"].mean(), 2) + summary["Std_AAR"] = round(combined_metrics["percent_identity"].std(), 2) + summary["Min_AAR"] = round(combined_metrics["percent_identity"].min(), 2) + summary["Max_AAR"] = round(combined_metrics["percent_identity"].max(), 2) + + # pLDDT statistics (if available) + if "plddt" in combined_metrics.columns: + summary["Avg_pLDDT"] = round(combined_metrics["plddt"].mean(), 4) + summary["Std_pLDDT"] = round(combined_metrics["plddt"].std(), 4) + + return pd.DataFrame([summary]) + + +def aggregate_forward_folding( + base_path: Path, + num_jobs: int, + rmsd_threshold: float = 2.0, +) -> dict: + """ + Aggregate forward folding results. + Groups by input structure, calculates structural metrics. + + Key metrics (already in CSV): + - TM-score: from 'tm_score' column + - RMSD: from 'rmsd' column + - pLDDT: from 'plddt' column (if ESMFold validation used) + """ + # Very similar to inverse_folding, but no AAR metric + agg_dir = base_path / "aggregated" + agg_dir.mkdir(exist_ok=True) + + all_metrics = [] + all_sequences = [] + + logger.info(f"Aggregating {num_jobs} jobs from {base_path}") + + # Collect from each job + for job_id in range(num_jobs): + job_dir = base_path / f"job_{job_id}" + if not job_dir.exists(): + logger.warning(f"Job {job_id} directory not found") + continue + + logger.info(f"Processing job {job_id}") + + # Find forward folding metrics CSV + metrics_files = list(job_dir.glob("*forward_folding*metrics*.csv")) + if not metrics_files: + metrics_files = list(job_dir.glob("*_metrics_*.csv")) + + if metrics_files: + df = pd.read_csv(metrics_files[0]) + df["job_id"] = job_id + all_metrics.append(df) + logger.info(f" Found {len(df)} structures") + + # Find sequences CSV (for actual structure names) + seq_files = list(job_dir.glob("sequences_forward_folding*.csv")) + if not seq_files: + seq_files = list(job_dir.glob("sequences_*.csv")) + + if seq_files: + df_seq = pd.read_csv(seq_files[0]) + df_seq["job_id"] = job_id + all_sequences.append(df_seq) + + # Copy PDB files + pdb_files = list(job_dir.glob("*.pdb")) + for pdb in pdb_files: + new_name = f"job_{job_id}_{pdb.name}" + shutil.copy2(pdb, agg_dir / new_name) + + if pdb_files: + logger.info(f" Copied {len(pdb_files)} PDB files") + + # Combine sequences + if all_sequences: + combined_sequences = pd.concat(all_sequences, ignore_index=True) + output_sequences = agg_dir / "combined_forward_folding_sequences.csv" + combined_sequences.to_csv(output_sequences, index=False) + logger.info(f"Saved combined sequences: {output_sequences}") + + # Combine metrics + if all_metrics: + combined_metrics = pd.concat(all_metrics, ignore_index=True) + + # Merge with sequences to get actual structure names + if all_sequences and "input_structure" in combined_sequences.columns: + # Merge on BOTH run_id AND job_id to avoid many-to-many joins + # (each job restarts batch numbering from batch_000) + structure_map = combined_sequences[["run_id", "job_id", "input_structure"]].drop_duplicates() + combined_metrics = combined_metrics.merge(structure_map, on=["run_id", "job_id"], how="left") + # Replace input_file with actual structure names where available + if "input_structure" in combined_metrics.columns: + combined_metrics["input_file"] = combined_metrics["input_structure"].fillna( + combined_metrics["input_file"] + ) + combined_metrics = combined_metrics.drop(columns=["input_structure"]) + logger.info("Replaced generic batch identifiers with actual structure names from sequences CSV") + + output_metrics = agg_dir / "combined_forward_folding_metrics.csv" + combined_metrics.to_csv(output_metrics, index=False) + + logger.info(f"Saved combined metrics: {output_metrics}") + logger.info(f"Total structures: {len(combined_metrics)}") + + # Print summary statistics + logger.info("\n=== Forward Folding Summary ===") + + # TM-score + if "tm_score" in combined_metrics.columns: + tm = combined_metrics["tm_score"].mean() + logger.info(f"Average TM-score: {tm:.3f}") + logger.info(f" Min: {combined_metrics['tm_score'].min():.3f}") + logger.info(f" Max: {combined_metrics['tm_score'].max():.3f}") + + # RMSD + if "rmsd" in combined_metrics.columns: + rmsd = combined_metrics["rmsd"].mean() + logger.info(f"Average RMSD: {rmsd:.3f} Å") + logger.info(f" Min: {combined_metrics['rmsd'].min():.3f}") + logger.info(f" Max: {combined_metrics['rmsd'].max():.3f}") + + passing = len(combined_metrics[combined_metrics["rmsd"] < rmsd_threshold]) + pct = (passing / len(combined_metrics)) * 100 + logger.info(f"Structures with RMSD < {rmsd_threshold}: {passing}/{len(combined_metrics)} ({pct:.1f}%)") + + # pLDDT + if "plddt" in combined_metrics.columns: + plddt = combined_metrics["plddt"].mean() + logger.info(f"Average pLDDT: {plddt:.3f}") + + # Create per-structure summary if input_file column exists + if "input_file" in combined_metrics.columns: + logger.info("\n=== Creating Per-Structure Summary ===") + summary_table = create_forward_folding_summary(combined_metrics, rmsd_threshold) + if summary_table is not None: + summary_csv = agg_dir / "summary_per_structure.csv" + summary_table.to_csv(summary_csv, index=False) + logger.info(f"Saved per-structure summary: {summary_csv}") + + logger.info("\n=== Summary Per Structure ===") + if len(summary_table) > 20: + logger.info(f"\n{summary_table.head(20).to_string(index=False)}") + logger.info(f"... and {len(summary_table) - 20} more structures") + else: + logger.info(f"\n{summary_table.to_string(index=False)}") + + # Create overall summary table (single row with aggregate stats) + logger.info("\n=== Creating Overall Summary ===") + overall_summary = create_overall_summary(combined_metrics, rmsd_threshold) + if overall_summary is not None: + overall_csv = agg_dir / "overall_summary.csv" + overall_summary.to_csv(overall_csv, index=False) + logger.info(f"Saved overall summary: {overall_csv}") + logger.info(f"\n{overall_summary.to_string(index=False)}") + else: + logger.warning("Could not create overall summary table") + + logger.info(f"\nAggregation complete! Results in: {agg_dir}") + + return { + "mode": "forward_folding", + "aggregated_dir": str(agg_dir), + "total_structures": len(combined_metrics) if all_metrics else 0, + "average_tm_score": combined_metrics["tm_score"].mean() + if all_metrics and "tm_score" in combined_metrics.columns + else None, + } + + +def create_forward_folding_summary(combined_metrics: pd.DataFrame, rmsd_threshold: float = 2.0) -> pd.DataFrame: + """ + Create summary table for forward folding results, grouped by input structure. + + Returns: + DataFrame with columns similar to inverse folding but without AAR + """ + if "input_file" not in combined_metrics.columns: + logger.warning("No 'input_file' column found in metrics") + return None + + summary_rows = [] + + for structure_file, group in combined_metrics.groupby("input_file"): + num_structures = len(group) + + row = { + "Input_Structure": structure_file, + "Num_Structures": num_structures, + } + + # TM-score + if "tm_score" in group.columns: + row["Avg_TM_Score"] = round(group["tm_score"].mean(), 4) + row["Min_TM_Score"] = round(group["tm_score"].min(), 4) + row["Max_TM_Score"] = round(group["tm_score"].max(), 4) + + # RMSD + if "rmsd" in group.columns: + row["Avg_RMSD"] = round(group["rmsd"].mean(), 4) + row["Min_RMSD"] = round(group["rmsd"].min(), 4) + row["Max_RMSD"] = round(group["rmsd"].max(), 4) + + passing = len(group[group["rmsd"] < rmsd_threshold]) + row[f"Structures_RMSD<{rmsd_threshold}"] = passing + row[f"Pct_RMSD<{rmsd_threshold}"] = round((passing / num_structures) * 100, 2) + + # pLDDT + if "plddt" in group.columns: + row["Avg_pLDDT"] = round(group["plddt"].mean(), 4) + + summary_rows.append(row) + + if not summary_rows: + return None + + return pd.DataFrame(summary_rows) + + +def run_foldseek_clustering( + structures_by_length: dict, + output_dir: Path, + combined_metrics: pd.DataFrame, + foldseek_bin_path: str = None, + tmscore_threshold: float = 0.5, + rmsd_threshold: float = 2.0, +) -> dict: + """ + Run Foldseek clustering on aggregated structures, organized by length. + + Args: + structures_by_length: Dict mapping length -> list of PDB file paths + output_dir: Output directory for Foldseek results + combined_metrics: DataFrame with all metrics (used for RMSD filtering) + foldseek_bin_path: Path to Foldseek binary directory + tmscore_threshold: TM-score threshold for clustering + rmsd_threshold: RMSD threshold for filtering + + Returns: + Dictionary with diversity metrics per length + """ + if foldseek_bin_path is None: + foldseek_bin_path = "/homefs/home/lisanzas/scratch/Develop/lobster/src/lobster/metrics/foldseek/bin" + + diversity_results = {} + + for length, pdb_files in structures_by_length.items(): + logger.info(f"\nProcessing length {length}: {len(pdb_files)} structures") + + # Create length-specific directory + length_dir = output_dir / "foldseek_results" / f"length_{length}" + length_dir.mkdir(parents=True, exist_ok=True) + + # Filter structures by RMSD threshold using metrics DataFrame + # Only include structures that have valid RMSD < threshold in combined_metrics + filtered_pdbs = [] + + # Get metrics for this length with valid RMSD + length_metrics = combined_metrics[ + (combined_metrics["sequence_length"] == length) + & (combined_metrics["rmsd"].notna()) + & (combined_metrics["rmsd"] < rmsd_threshold) + ] + + # Determine which column to use for matching filenames + id_column = None + if "structure_file" in combined_metrics.columns: + id_column = "structure_file" + elif "run_id" in combined_metrics.columns: + id_column = "run_id" + else: + logger.warning(" No 'structure_file' or 'run_id' column in metrics, cannot filter structures") + continue + + # Build mapping from (job_id, sample_idx) pairs that passed RMSD threshold + # We need BOTH job_id and sample number to uniquely identify structures + # run_id format: unconditional_length_100_iter_000 + # filename format: job_0_generated_structure_length_100_000_esmfold_000.pdb + passing_job_sample_pairs = set() + + for idx, row in length_metrics.iterrows(): + identifier = row[id_column] + job_id = row.get("job_id", None) + + if pd.notna(identifier) and pd.notna(job_id): + # Extract sample number from run_id + # Pattern: unconditional_length_XXX_iter_YYY -> YYY is the sample number + try: + parts = str(identifier).split("_") + if "iter" in parts: + iter_idx = parts.index("iter") + 1 + sample_num = int(parts[iter_idx]) + passing_job_sample_pairs.add((int(job_id), sample_num)) + except (ValueError, IndexError): + pass + + logger.info(f" Found {len(passing_job_sample_pairs)} (job_id, sample) pairs passing RMSD threshold") + + # Filter PDB files to only include those with passing (job_id, sample_num) pairs + for pdb_path in pdb_files: + pdb_name = pdb_path.stem # Get filename without extension + + # Extract job_id and sample number from filename + # Pattern: job_0_generated_structure_length_100_000_esmfold_000 + try: + parts = pdb_name.split("_") + + # Extract job_id (first part after "job") + job_id = None + if "job" in parts: + job_idx = parts.index("job") + 1 + if job_idx < len(parts): + job_id = int(parts[job_idx]) + + # Extract sample number (comes after "length_XXX") + sample_num = None + if "length" in parts: + length_idx = parts.index("length") + 1 + # Skip the length value, next number is the sample index + if length_idx + 1 < len(parts): + sample_num = int(parts[length_idx + 1]) + + # Check if this (job_id, sample_num) pair passed RMSD threshold + if job_id is not None and sample_num is not None: + if (job_id, sample_num) in passing_job_sample_pairs: + filtered_pdbs.append(pdb_path) + + except (ValueError, IndexError) as e: + logger.debug(f"Could not parse filename {pdb_name}: {e}") + + structures_passing_rmsd = len(filtered_pdbs) + logger.info(f" Structures passing RMSD < {rmsd_threshold}: {structures_passing_rmsd}") + + if structures_passing_rmsd == 0: + logger.warning(f" No structures passed RMSD filter for length {length}") + continue + + # Create temp directory with filtered PDBs for Foldseek + temp_dir = length_dir / "foldseek_temp" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Copy filtered PDB files to temp directory + for pdb in filtered_pdbs: + shutil.copy2(pdb, temp_dir / pdb.name) + + logger.info(f" Copied {structures_passing_rmsd} structures to temp directory") + + # Run Foldseek clustering using the existing function from lobster.metrics + try: + logger.info(f" Running Foldseek with TM-score threshold {tmscore_threshold}") + + num_clusters, total_proteins = run_easy_cluster( + designable_dir=temp_dir, + output_dir=length_dir, + tmscore_threshold=tmscore_threshold, + foldseek_bin_path=foldseek_bin_path, + ) + + if num_clusters is not None: + diversity_pct = (num_clusters / total_proteins) * 100 if total_proteins > 0 else 0 + + diversity_results[length] = { + "total_structures": len(pdb_files), + "structures_passing_rmsd": structures_passing_rmsd, + "num_clusters": num_clusters, + "diversity_percentage": diversity_pct, + } + + logger.info(f" ✓ Clustering complete: {num_clusters} clusters ({diversity_pct:.1f}% diversity)") + else: + logger.error(f" Foldseek clustering failed for length {length}") + + except Exception as e: + logger.error(f" Foldseek clustering failed: {e}") + + return diversity_results + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Aggregate distributed generation results", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Auto-detect mode + python aggregate_results.py ./examples/output 90 + + # Unconditional generation (with Foldseek) + python aggregate_results.py ./examples/generated_unconditional 20 --mode unconditional + + # Inverse folding (no Foldseek needed) + python aggregate_results.py ./examples/generated_inverse_folding_450M 90 --mode inverse_folding + + # Forward folding (no Foldseek needed) + python aggregate_results.py ./examples/generated_forward_folding 90 --mode forward_folding + """, + ) + + parser.add_argument("base_output_dir", help="Base output directory with job_* subdirectories") + parser.add_argument("num_jobs", type=int, help="Number of jobs to aggregate") + + parser.add_argument( + "--mode", + choices=["unconditional", "inverse_folding", "forward_folding"], + help="Generation mode (auto-detected if not provided)", + ) + + parser.add_argument("--no-foldseek", action="store_true", help="Skip Foldseek clustering") + parser.add_argument("--foldseek-bin", help="Path to Foldseek binary directory") + parser.add_argument("--tmscore-threshold", type=float, default=0.5, help="TM-score threshold for clustering") + parser.add_argument("--rmsd-threshold", type=float, default=2.0, help="RMSD threshold for filtering") + + args = parser.parse_args() + + # Handle run_foldseek logic + run_foldseek = None if not args.no_foldseek else False + + results = aggregate_distributed_results( + base_output_dir=args.base_output_dir, + num_jobs=args.num_jobs, + mode=args.mode, + run_foldseek=run_foldseek, + foldseek_bin_path=args.foldseek_bin, + foldseek_tmscore_threshold=args.tmscore_threshold, + rmsd_threshold=args.rmsd_threshold, + ) + + print("\n=== Aggregation Complete ===") + print(f"Mode: {results.get('mode', 'unknown')}") + print(f"Results saved to: {results['aggregated_dir']}") + + # Mode-specific summary + if results.get("mode") == "inverse_folding": + print(f"Total designs: {results.get('total_designs', 0)}") + if results.get("average_aar") is not None: + print(f"Average AAR: {results['average_aar']:.2f}%") + if results.get("average_tm_score") is not None: + print(f"Average TM-score: {results['average_tm_score']:.3f}") + + elif results.get("mode") == "forward_folding": + print(f"Total structures: {results.get('total_structures', 0)}") + if results.get("average_tm_score") is not None: + print(f"Average TM-score: {results['average_tm_score']:.3f}") + + elif results.get("mode") == "unconditional": + print(f"Total samples: {results.get('total_samples', 0)}") + if results.get("diversity_results"): + print("\nDiversity Summary:") + for length, metrics in results["diversity_results"].items(): + print( + f" Length {length}: {metrics['num_clusters']} clusters ({metrics['diversity_percentage']:.1f}% diversity)" + ) diff --git a/src/lobster/cmdline/distributed_generation/create_job_config.py b/src/lobster/cmdline/distributed_generation/create_job_config.py new file mode 100755 index 00000000..779fbe3e --- /dev/null +++ b/src/lobster/cmdline/distributed_generation/create_job_config.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +Helper script to generate wandb_config.yaml for distributed generation. +Creates job distribution configurations for different sample counts and parallelization strategies. + +Supports two modes: +1. Unconditional: Distribute samples across jobs (e.g., 100 samples → 20 jobs × 5) +2. Structure-based (inverse/forward folding): Distribute input structure files across jobs +""" + +import glob +import yaml +from pathlib import Path + + +def create_job_config( + total_samples: int, + samples_per_job: int, + base_config_path: str = "src/lobster/hydra_config/experiment/generate_unconditional.yaml", + output_file: str = "src/lobster/cmdline/distributed_generation/wandb_config.yaml", + lengths: list[int] | None = None, +): + """ + Generate job distribution config. + + Args: + total_samples: Total number of samples to generate + samples_per_job: Number of samples per job + base_config_path: Path to base config file + output_file: Output file name + lengths: Optional list of lengths for multi-length generation + """ + num_jobs = (total_samples + samples_per_job - 1) // samples_per_job + + print(f"Creating config for {total_samples} samples") + print(f"Samples per job: {samples_per_job}") + print(f"Number of jobs: {num_jobs}") + + job_ids = list(range(num_jobs)) + + config = { + "program": "src/lobster/cmdline/distributed_generate.py", + "method": "grid", + "project": "lobster-distributed-generation", + "entity": "prescient-design", + "metric": {"name": "job_completed", "goal": "maximize"}, + "parameters": { + "base_config_path": {"value": base_config_path}, + "job_id": {"values": job_ids}, + "samples_per_job": {"value": samples_per_job}, + "total_samples": {"value": total_samples}, + }, + "command": ["${env}", "python", "${program}"], + } + + # Add lengths if specified + if lengths: + config["parameters"]["length"] = {"value": lengths} + + with open(output_file, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + print(f"\nConfig saved to: {output_file}") + print("\nNext steps:") + print("1. Review the config file") + print(f"2. Initialize: wandb sweep {output_file}") + print(f"3. Update submit_slurm.sh with sweep ID and --array=1-{num_jobs}") + print("4. Submit: sbatch src/lobster/cmdline/distributed_generation/submit_slurm.sh") + + +def create_structure_based_job_config( + input_structures: str, + structures_per_job: int, + base_config_path: str, + output_file: str = "src/lobster/cmdline/distributed_generation/wandb_config.yaml", + mode: str = "inverse_folding", +): + """ + Generate job distribution config for structure-based modes (inverse_folding, forward_folding). + + Args: + input_structures: Glob pattern or path to structure files + structures_per_job: Number of structures each job processes + base_config_path: Path to base config (inverse_folding or forward_folding yaml) + output_file: Output wandb config file + mode: Generation mode ("inverse_folding" or "forward_folding") + """ + # Expand glob to count files + if "*" in input_structures or "?" in input_structures: + structure_files = sorted(glob.glob(input_structures)) + else: + path = Path(input_structures) + if path.is_file(): + structure_files = [str(path)] + elif path.is_dir(): + structure_files = [] + structure_files.extend(sorted(glob.glob(str(path / "*.pdb")))) + structure_files.extend(sorted(glob.glob(str(path / "*.cif")))) + structure_files.extend(sorted(glob.glob(str(path / "*.pt")))) + else: + raise ValueError(f"Input path does not exist: {input_structures}") + + total_structures = len(structure_files) + + if total_structures == 0: + raise ValueError(f"No structure files found matching: {input_structures}") + + # Calculate number of jobs needed + num_jobs = (total_structures + structures_per_job - 1) // structures_per_job + + print(f"Creating config for {mode} mode") + print(f"Input pattern: {input_structures}") + print(f"Total structures found: {total_structures}") + print(f"Structures per job: {structures_per_job}") + print(f"Number of jobs: {num_jobs}") + + job_ids = list(range(num_jobs)) + + config = { + "program": "src/lobster/cmdline/distributed_generate.py", + "method": "grid", + "project": f"lobster-distributed-{mode.replace('_', '-')}", + "entity": "prescient-design", + "metric": {"name": "job_completed", "goal": "maximize"}, + "parameters": { + "base_config_path": {"value": base_config_path}, + "job_id": {"values": job_ids}, + "structures_per_job": {"value": structures_per_job}, + "total_structures": {"value": total_structures}, + "mode": {"value": mode}, + }, + "command": ["${env}", "python", "${program}"], + } + + # Save config + with open(output_file, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + print(f"\nConfig saved to: {output_file}") + print("\nStructure distribution:") + print(f" Jobs 0-{num_jobs - 2}: {structures_per_job} structures each") + last_job_count = total_structures - (num_jobs - 1) * structures_per_job + print(f" Job {num_jobs - 1}: {last_job_count} structures") + print("\nNext steps:") + print("1. Review the config file") + print(f"2. Initialize: wandb sweep {output_file}") + print(f"3. Update submit_slurm.sh with sweep ID and --array=1-{num_jobs}") + print("4. Submit: sbatch src/lobster/cmdline/distributed_generation/submit_slurm.sh") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Generate wandb distributed generation config", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Unconditional generation + python create_job_config.py --mode unconditional --total_samples 100 --samples_per_job 5 + + # Inverse folding + python src/lobster/cmdline/distributed_generation/create_job_config.py --mode inverse_folding \ + --input_structures "/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" \ + --structures_per_job 5 \ + --base_config src/lobster/hydra_config/experiment/generate_inverse_folding_450M.yaml + + # Forward folding + python src/lobster/cmdline/distributed_generation/create_job_config.py --mode forward_folding \ + --input_structures "/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" \ + --structures_per_job 5 \ + --base_config src/lobster/hydra_config/experiment/generate_forward_folding_450M.yaml + """, + ) + + parser.add_argument( + "--mode", + choices=["unconditional", "inverse_folding", "forward_folding"], + required=True, + help="Generation mode", + ) + parser.add_argument( + "--output", default="src/lobster/cmdline/distributed_generation/wandb_config.yaml", help="Output config file" + ) + parser.add_argument( + "--base_config", + help="Path to base config file (auto-detected if not provided)", + ) + + # Unconditional mode arguments + unconditional_group = parser.add_argument_group("unconditional mode arguments") + unconditional_group.add_argument("--total_samples", type=int, help="Total samples to generate") + unconditional_group.add_argument("--samples_per_job", type=int, default=50, help="Samples per job") + unconditional_group.add_argument("--lengths", type=int, nargs="+", help="Optional: lengths to generate") + + # Structure-based mode arguments (inverse_folding, forward_folding) + structure_group = parser.add_argument_group("inverse_folding/forward_folding mode arguments") + structure_group.add_argument("--input_structures", help="Glob pattern or path to structure files") + structure_group.add_argument("--structures_per_job", type=int, default=5, help="Structures per job") + + args = parser.parse_args() + + # Auto-detect base config if not provided + if not args.base_config: + if args.mode == "unconditional": + args.base_config = "src/lobster/hydra_config/experiment/generate_unconditional.yaml" + elif args.mode == "inverse_folding": + args.base_config = "src/lobster/hydra_config/experiment/generate_inverse_folding_450M.yaml" + elif args.mode == "forward_folding": + args.base_config = "src/lobster/hydra_config/experiment/generate_forward_folding_450M.yaml" + + # Call appropriate function based on mode + if args.mode == "unconditional": + if not args.total_samples: + parser.error("--total_samples is required for unconditional mode") + create_job_config( + total_samples=args.total_samples, + samples_per_job=args.samples_per_job, + base_config_path=args.base_config, + output_file=args.output, + lengths=args.lengths, + ) + elif args.mode in ["inverse_folding", "forward_folding"]: + if not args.input_structures: + parser.error("--input_structures is required for inverse_folding/forward_folding mode") + create_structure_based_job_config( + input_structures=args.input_structures, + structures_per_job=args.structures_per_job, + base_config_path=args.base_config, + output_file=args.output, + mode=args.mode, + ) diff --git a/src/lobster/cmdline/esmfold_baseline.py b/src/lobster/cmdline/esmfold_baseline.py new file mode 100644 index 00000000..ee1ad708 --- /dev/null +++ b/src/lobster/cmdline/esmfold_baseline.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python3 +""" +ESMFold Baseline for Forward Folding Comparison + +This script runs ESMFold as a baseline for forward folding tasks. +It takes input structures, extracts sequences, predicts structures using ESMFold, +and compares predictions to ground truth structures. + +Outputs the same CSV format as forward_folding mode for easy comparison. + +Usage: + uv run python -m lobster.cmdline.esmfold_baseline \\ + --config-path "../hydra_config/experiment" \\ + --config-name esmfold_baseline +""" + +import glob +from pathlib import Path +from datetime import datetime +import torch +from loguru import logger +import hydra +from omegaconf import DictConfig, ListConfig +import csv + +from lobster.model._lobster_fold import LobsterPLMFold +from lobster.transforms._structure_transforms import StructureBackboneTransform +from lobster.model.latent_generator.io import writepdb, load_pdb +from lobster.model.latent_generator.utils.residue_constants import restype_order_with_x_inv +from lobster.metrics import align_and_compute_rmsd +from tmtools import tm_align + + +@hydra.main(version_base=None, config_path="../hydra_config/experiment", config_name="esmfold_baseline") +def main(cfg: DictConfig) -> None: + """ + Run ESMFold baseline for forward folding comparison. + + Args: + cfg: Hydra configuration + """ + logger.info("=" * 80) + logger.info("ESMFold Baseline for Forward Folding") + logger.info("=" * 80) + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Set random seed + seed = cfg.get("seed", 12345) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + logger.info(f"Random seed: {seed}") + + # Create output directory + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Output directory: {output_dir}") + + # Load ESMFold model + logger.info("Loading ESMFold model...") + max_length = cfg.generation.get("max_length", 512) + plm_fold = LobsterPLMFold(model_name="esmfold_v1", max_length=max_length) + plm_fold.to(device) + plm_fold.eval() + logger.info("✓ ESMFold loaded successfully") + + # Get input structure paths + input_structures = cfg.generation.input_structures + if not input_structures: + raise ValueError("input_structures must be provided") + + # Handle different input formats + structure_paths = [] + if isinstance(input_structures, str): + # Single path or glob pattern + if "*" in input_structures or "?" in input_structures: + # Glob pattern + structure_paths = sorted(glob.glob(input_structures)) + else: + # Single file or directory + path = Path(input_structures) + if path.is_file(): + structure_paths = [str(path)] + elif path.is_dir(): + # Find all structure files in directory (PDB, CIF, PT) + structure_paths = sorted(list(glob.glob(str(path / "*.pdb")))) + structure_paths.extend(sorted(glob.glob(str(path / "*.cif")))) + structure_paths.extend(sorted(glob.glob(str(path / "*.pt")))) + else: + raise ValueError(f"Input path does not exist: {input_structures}") + elif isinstance(input_structures, (list, tuple, ListConfig)): + # List of paths (includes OmegaConf ListConfig) + for path_str in input_structures: + path = Path(path_str) + if path.is_file(): + structure_paths.append(str(path)) + else: + logger.warning(f"Skipping non-existent file: {path_str}") + else: + raise ValueError(f"Invalid input_structures format: {type(input_structures)}") + + if not structure_paths: + raise ValueError("No valid structure files found in input_structures") + + logger.info(f"Found {len(structure_paths)} structure files to process") + + # Initialize structure transform + structure_transform = StructureBackboneTransform() + + # Initialize CSV writer + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_path = output_dir / f"esmfold_baseline_metrics_{timestamp}.csv" + sequences_csv_path = output_dir / f"sequences_esmfold_baseline_{timestamp}.csv" + + # Write CSV headers + with open(csv_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["run_id", "timestamp", "mode", "plddt", "tm_score", "rmsd", "sequence_length", "input_file"]) + + with open(sequences_csv_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["run_id", "sample_idx", "sequence", "original_sequence", "length", "input_structure"]) + + logger.info(f"Initialized CSV metrics file: {csv_path}") + logger.info(f"Initialized sequences CSV file: {sequences_csv_path}") + + # Process structures + batch_size = cfg.generation.get("batch_size", 1) + all_tm_scores = [] + all_rmsd_scores = [] + all_plddt_scores = [] + + with torch.no_grad(): + # Process structure files in batches + for batch_start in range(0, len(structure_paths), batch_size): + batch_end = min(batch_start + batch_size, len(structure_paths)) + batch_paths = structure_paths[batch_start:batch_end] + batch_idx = batch_start // batch_size + + logger.info(f"\nProcessing batch {batch_idx + 1}/{(len(structure_paths) + batch_size - 1) // batch_size}") + + # Load structures from files + batch_data = [] + valid_indices = [] + + for i, structure_path in enumerate(batch_paths): + logger.info(f"Loading {structure_path}") + + # Check file extension to determine loading method + if structure_path.endswith(".pt"): + # Load .pt file directly + try: + structure_data = torch.load(structure_path, map_location="cpu", weights_only=False) + if structure_data is not None: + # Apply StructureBackboneTransform + structure_data = structure_transform(structure_data) + batch_data.append(structure_data) + valid_indices.append(i) + else: + logger.warning(f"Failed to load structure from {structure_path} - data is None") + except Exception as e: + logger.warning(f"Failed to load .pt file {structure_path}: {e}") + else: + # Load PDB/CIF file + try: + structure_data = load_pdb(structure_path, add_batch_dim=False) + if structure_data is not None: + # Apply StructureBackboneTransform + structure_data = structure_transform(structure_data) + batch_data.append(structure_data) + valid_indices.append(i) + else: + logger.warning(f"Failed to load structure from {structure_path}") + except Exception as e: + logger.warning(f"Failed to load structure {structure_path}: {e}") + + if not batch_data: + logger.warning(f"No valid structures in batch {batch_idx + 1}, skipping") + continue + + # Filter structures by minimum length (30 residues) and check sequence quality + filtered_batch_data = [] + filtered_valid_indices = [] + for i, data in enumerate(batch_data): + if data["coords_res"].shape[0] >= 30: + percent_20s = (data["sequence"] == 20).sum() / data["sequence"].shape[0] + if percent_20s > 0.1: + logger.info( + f"Skipping structure {batch_paths[valid_indices[i]]} - sequence contains more than 10% unknown residues" + ) + continue + filtered_batch_data.append(data) + filtered_valid_indices.append(valid_indices[i]) + else: + logger.info( + f"Skipping structure {batch_paths[valid_indices[i]]} - too short ({data['coords_res'].shape[0]} residues, minimum 30)" + ) + + if not filtered_batch_data: + logger.warning(f"No structures with sufficient length in batch {batch_idx + 1}, skipping") + continue + + # Process each structure in the batch + for i, (data, valid_idx) in enumerate(zip(filtered_batch_data, filtered_valid_indices)): + original_path = batch_paths[valid_idx] + original_name = Path(original_path).stem + + # Extract sequence from structure + seq_tensor = data["sequence"] + if seq_tensor.dim() > 1: + seq_tensor = seq_tensor.squeeze() + + # Convert sequence to string + sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_tensor]) + seq_length = len(sequence_str) + + logger.info(f"\nStructure {batch_idx * batch_size + i + 1}: {original_name}") + logger.info(f" Sequence length: {seq_length}") + logger.info(f" Sequence: {sequence_str[:50]}{'...' if len(sequence_str) > 50 else ''}") + + # Get ground truth coordinates and move to device + ground_truth_coords = data["coords_res"].to(device) # Shape: [L, 3, 3] + + # Tokenize sequence for ESMFold + try: + tokenized_input = plm_fold.tokenizer( + [sequence_str], + padding=True, + truncation=True, + max_length=max_length, + add_special_tokens=False, + return_tensors="pt", + )["input_ids"].to(device) + + # Predict structure with ESMFold + logger.info(" Running ESMFold prediction...") + outputs = plm_fold.model(tokenized_input) + + # Extract predicted coordinates + pred_coords = outputs["positions"][-1] # Shape: [B, L, 14, 3], last recycle + pred_coords_ca = pred_coords[0, :, 1, :] # CA atoms, Shape: [L, 3] + pred_coords_backbone = pred_coords[0, :, [0, 1, 2], :] # N, CA, C atoms, Shape: [L, 3, 3] + + # Extract pLDDT scores + plddt = outputs["plddt"][0] # Shape: [L] + mean_plddt = plddt.mean().item() + logger.info(f" Mean pLDDT: {mean_plddt:.2f}") + + # Get ground truth CA coordinates for TM-align + ground_truth_ca = ground_truth_coords[:, 1, :] # CA atoms, Shape: [L, 3] + + # Calculate TM-score using TM-align (CA atoms only) + tm_out = tm_align( + pred_coords_ca.cpu().numpy(), + ground_truth_ca.cpu().numpy(), + sequence_str, + sequence_str, + ) + + tm_score = tm_out.tm_norm_chain1 + + # Calculate RMSD using Kabsch alignment (all backbone atoms: N, CA, C) + # This matches the approach in the base generation script + # Ensure both tensors are on the same device + rmsd = align_and_compute_rmsd( + coords1=pred_coords_backbone.to(device), # ESMFold prediction [L, 3, 3] + coords2=ground_truth_coords.to(device), # Ground truth [L, 3, 3] + mask=None, # Use all positions + return_aligned=False, + device=device, + ) + + logger.info(f" TM-score: {tm_score:.3f}") + logger.info(f" RMSD (Kabsch): {rmsd:.2f} Å") + + # Collect metrics + all_tm_scores.append(tm_score) + all_rmsd_scores.append(rmsd) + all_plddt_scores.append(mean_plddt) + + # Save ESMFold predicted structure + esmfold_filename = output_dir / f"esmfold_baseline_{original_name}_predicted.pdb" + writepdb(str(esmfold_filename), pred_coords_backbone.cpu(), seq_tensor) + logger.info(f" Saved ESMFold prediction: {esmfold_filename}") + + # Save ground truth structure + ground_truth_filename = output_dir / f"esmfold_baseline_{original_name}_ground_truth.pdb" + writepdb(str(ground_truth_filename), ground_truth_coords, seq_tensor) + logger.info(f" Saved ground truth: {ground_truth_filename}") + + # Write metrics to CSV + run_id = f"esmfold_baseline_batch_{batch_idx:03d}_{i}" + current_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + with open(csv_path, "a", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow( + [ + run_id, + current_timestamp, + "esmfold_baseline", + round(mean_plddt, 4), + round(tm_score, 4), + round(rmsd, 4), + seq_length, + original_name, + ] + ) + + # Write sequences to CSV + with open(sequences_csv_path, "a", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow( + [ + run_id, + i, + sequence_str, + sequence_str, # For ESMFold baseline, input and output sequences are the same + seq_length, + original_name, + ] + ) + + except Exception as e: + logger.error(f" Error processing structure {original_name}: {e}") + import traceback + + traceback.print_exc() + continue + + # Calculate and report aggregate statistics + logger.info("\n" + "=" * 80) + logger.info("ESMFOLD BASELINE AGGREGATE STATISTICS") + logger.info("=" * 80) + + if all_tm_scores: + avg_tm_score = sum(all_tm_scores) / len(all_tm_scores) + min_tm_score = min(all_tm_scores) + max_tm_score = max(all_tm_scores) + logger.info(f"TM-Score Statistics (n={len(all_tm_scores)}):") + logger.info(f" Average: {avg_tm_score:.3f}") + logger.info(f" Min: {min_tm_score:.3f}") + logger.info(f" Max: {max_tm_score:.3f}") + else: + logger.warning("No TM-Score data collected") + + if all_rmsd_scores: + # Filter out infinite RMSD values + valid_rmsd = [r for r in all_rmsd_scores if r != float("inf")] + if valid_rmsd: + avg_rmsd = sum(valid_rmsd) / len(valid_rmsd) + min_rmsd = min(valid_rmsd) + max_rmsd = max(valid_rmsd) + logger.info(f"\nRMSD Statistics (n={len(valid_rmsd)}):") + logger.info(f" Average: {avg_rmsd:.2f} Å") + logger.info(f" Min: {min_rmsd:.2f} Å") + logger.info(f" Max: {max_rmsd:.2f} Å") + + # Calculate RMSD pass rate (< 2.0Å threshold) + rmsd_threshold = 2.0 + pass_count = sum(1 for rmsd in valid_rmsd if rmsd < rmsd_threshold) + total_count = len(valid_rmsd) + pass_rate = (pass_count / total_count * 100) if total_count > 0 else 0.0 + logger.info(f" RMSD Pass Rate (< {rmsd_threshold:.1f}Å): {pass_count}/{total_count} ({pass_rate:.1f}%)") + else: + logger.warning("No valid RMSD data collected") + else: + logger.warning("No RMSD data collected") + + if all_plddt_scores: + avg_plddt = sum(all_plddt_scores) / len(all_plddt_scores) + min_plddt = min(all_plddt_scores) + max_plddt = max(all_plddt_scores) + logger.info(f"\npLDDT Statistics (n={len(all_plddt_scores)}):") + logger.info(f" Average: {avg_plddt:.2f}") + logger.info(f" Min: {min_plddt:.2f}") + logger.info(f" Max: {max_plddt:.2f}") + else: + logger.warning("No pLDDT data collected") + + logger.info("=" * 80) + logger.info("\n✓ ESMFold baseline completed successfully!") + logger.info(f" Results saved to: {output_dir}") + logger.info(f" Metrics CSV: {csv_path}") + logger.info(f" Sequences CSV: {sequences_csv_path}") + logger.info(f" Total structures processed: {len(all_tm_scores)}") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/evaluate_inverse_folding.py b/src/lobster/cmdline/evaluate_inverse_folding.py new file mode 100644 index 00000000..30c52729 --- /dev/null +++ b/src/lobster/cmdline/evaluate_inverse_folding.py @@ -0,0 +1,341 @@ +"""Evaluate inverse folding on protein-only structures (e.g., CAMEO). + +This script evaluates sequence recovery for inverse folding without ligand context. +Useful for comparing protein-only inverse folding performance. + +Usage: + uv run python -m lobster.cmdline.evaluate_inverse_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir "/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed/*.pt" \ + --num_samples 127 \ + --nsteps 100 \ + --device cuda +""" + +import argparse +import glob +import os + +import torch +from loguru import logger +from tqdm import tqdm + +from lobster.model.gen_ume import ProteinLigandEncoderLightningModule +from lobster.model.latent_generator.io import writepdb +from lobster.model.latent_generator.utils.residue_constants import ( + convert_lobster_aa_tokenization_to_standard_aa, +) +from lobster.transforms._structure_transforms import StructureBackboneTransform + + +def load_structures(data_dir: str, num_samples: int, max_length: int = 512) -> list[dict]: + """Load structures from .pt files.""" + if "*" in data_dir: + pt_files = sorted(glob.glob(data_dir)) + else: + pt_files = sorted(glob.glob(os.path.join(data_dir, "*.pt"))) + + if not pt_files: + raise ValueError(f"No .pt files found at {data_dir}") + + logger.info(f"Found {len(pt_files)} .pt files") + + transform = StructureBackboneTransform(max_length=max_length) + structures = [] + + for pt_path in tqdm(pt_files[: num_samples * 3], desc="Loading structures"): + try: + data = torch.load(pt_path, map_location="cpu", weights_only=False) + data = transform(data) + + # Filter by length and unknown residues + if data["coords_res"].shape[0] >= 30: + percent_unknown = (data["sequence"] == 20).sum().float() / data["sequence"].shape[0] + if percent_unknown <= 0.1: + structures.append(data) + + if len(structures) >= num_samples: + break + except Exception as e: + logger.warning(f"Failed to load {pt_path}: {e}") + + logger.info(f"Loaded {len(structures)} valid structures") + return structures + + +def create_dummy_ligand(protein_coords: torch.Tensor, mask: torch.Tensor, num_atoms: int = 10, device: str = "cuda"): + """Create a dummy ligand near the protein centroid. + + Parameters + ---------- + protein_coords : Tensor + Protein coordinates [B, L, 3, 3] (N, CA, C per residue) + mask : Tensor + Valid residue mask [B, L] + num_atoms : int + Number of dummy ligand atoms + device : str + Device to create tensors on + + Returns + ------- + dict with ligand_coords, ligand_atom_types, ligand_mask + """ + B, L = mask.shape + + # Get protein centroid (using CA atoms) + ca_coords = protein_coords[:, :, 1, :] # [B, L, 3] + + # Compute centroid per sample + centroids = [] + for i in range(B): + valid = mask[i].bool() + if valid.sum() > 0: + centroid = ca_coords[i, valid].mean(dim=0) + else: + centroid = torch.zeros(3, device=device) + centroids.append(centroid) + centroids = torch.stack(centroids) # [B, 3] + + # Create dummy ligand atoms around centroid (random positions within 5Å) + ligand_coords = centroids.unsqueeze(1) + torch.randn(B, num_atoms, 3, device=device) * 2.0 # [B, num_atoms, 3] + + # Use carbon atoms (index 3 in ELEMENT_VOCAB_EXTENDED) + ligand_atom_types = torch.full((B, num_atoms), 3, dtype=torch.long, device=device) # All carbons + + # All atoms valid + ligand_mask = torch.ones(B, num_atoms, device=device) + + # Simple linear bond matrix (chain of atoms) + bond_matrix = torch.zeros(B, num_atoms, num_atoms, dtype=torch.long, device=device) + for i in range(num_atoms - 1): + bond_matrix[:, i, i + 1] = 1 # Single bond + bond_matrix[:, i + 1, i] = 1 # Symmetric + + return { + "ligand_coords": ligand_coords, + "ligand_atom_types": ligand_atom_types, + "ligand_mask": ligand_mask, + "bond_matrix": bond_matrix, + } + + +def evaluate_inverse_folding( + model, + structures: list[dict], + device: str, + nsteps: int = 100, + output_dir: str | None = None, + use_dummy_ligand: bool = False, + dummy_ligand_atoms: int = 10, + batch_size: int = 10, +) -> dict: + """Run inverse folding evaluation. + + Parameters + ---------- + use_dummy_ligand : bool + If True, add a dummy ligand (random atoms near centroid) to test + if the model needs ligand context to function properly. + dummy_ligand_atoms : int + Number of atoms in the dummy ligand. + batch_size : int + Batch size for evaluation. + """ + model.eval() + model.to(device) + + all_recoveries = [] + per_sample_results = [] + + for batch_start in tqdm(range(0, len(structures), batch_size), desc="Evaluating"): + batch_end = min(batch_start + batch_size, len(structures)) + batch_structures = structures[batch_start:batch_end] + + # Prepare batch + max_len = max(s["coords_res"].shape[0] for s in batch_structures) + B = len(batch_structures) + + sequence = torch.zeros((B, max_len), dtype=torch.long, device=device) + coords_res = torch.zeros((B, max_len, 3, 3), device=device) + mask = torch.zeros((B, max_len), device=device) + indices = torch.zeros((B, max_len), dtype=torch.long, device=device) + + for i, s in enumerate(batch_structures): + L = s["coords_res"].shape[0] + sequence[i, :L] = s["sequence"].to(device) + coords_res[i, :L] = s["coords_res"].to(device) + mask[i, :L] = s["mask"].to(device) + indices[i, :L] = s["indices"].to(device) + + # Handle NaNs + nan_mask = torch.isnan(coords_res).any(dim=-1).any(dim=-1) + mask[nan_mask] = 0 + coords_res[nan_mask] = 0 + + # Generate sequences + with torch.no_grad(): + if use_dummy_ligand: + # Create dummy ligand for this batch + dummy = create_dummy_ligand(coords_res, mask, dummy_ligand_atoms, device) + + # Encode dummy ligand structure to tokens + encode_result = model.encode_ligand_structure( + dummy["ligand_coords"], + dummy["ligand_mask"], + torch.arange(dummy_ligand_atoms, device=device).unsqueeze(0).expand(B, -1), + return_continuous=True, + ) + ligand_structure_tokens, _, ligand_structure_embeddings = encode_result + + result = model.generate_sample( + length=max_len, + num_samples=B, + inverse_folding=True, + nsteps=nsteps, + input_structure_coords=coords_res, + input_mask=mask, + input_indices=indices, + # Dummy ligand as context + generate_ligand=True, + num_atoms=dummy_ligand_atoms, + input_ligand_atom_tokens=dummy["ligand_atom_types"], + input_ligand_structure_tokens=ligand_structure_tokens, + input_ligand_structure_embeddings=ligand_structure_embeddings, + input_bond_matrix=dummy["bond_matrix"], + ligand_is_context=True, + ) + else: + # Pure protein-only inverse folding + result = model.generate_sample( + length=max_len, + num_samples=B, + inverse_folding=True, + nsteps=nsteps, + input_structure_coords=coords_res, + input_mask=mask, + input_indices=indices, + generate_ligand=False, + ) + + # Get predicted sequences + seq_logits = result["sequence_logits"] + if seq_logits.shape[-1] == 33: + pred_seq = convert_lobster_aa_tokenization_to_standard_aa(seq_logits, device=device) + else: + pred_seq = seq_logits.argmax(dim=-1) + pred_seq[pred_seq > 20] = 20 + + # Compute per-sample recovery + for i in range(B): + valid_mask = mask[i].bool() + gt = sequence[i][valid_mask] + pred = pred_seq[i][valid_mask] + recovery = (gt == pred).float().mean().item() * 100 + all_recoveries.append(recovery) + per_sample_results.append( + { + "sample_idx": batch_start + i, + "length": int(valid_mask.sum().item()), + "recovery": recovery, + } + ) + + # Save first few structures + if batch_start == 0 and output_dir: + os.makedirs(output_dir, exist_ok=True) + decoded_x = model.decode_structure(result, mask) + vit_output = decoded_x.get("vit_decoder") + if isinstance(vit_output, dict): + x_recon_xyz = vit_output.get("protein_coords") + else: + x_recon_xyz = vit_output + + for i in range(min(5, B)): + pdb_path = os.path.join(output_dir, f"sample_{i}_recovery_{all_recoveries[i]:.1f}.pdb") + writepdb(pdb_path, x_recon_xyz[i], pred_seq[i]) + logger.info(f"Saved {pdb_path}") + + # Summary statistics + import numpy as np + + recoveries = np.array(all_recoveries) + + summary = { + "mean_recovery": float(recoveries.mean()), + "std_recovery": float(recoveries.std()), + "median_recovery": float(np.median(recoveries)), + "min_recovery": float(recoveries.min()), + "max_recovery": float(recoveries.max()), + "num_samples": len(recoveries), + } + + return {"summary": summary, "per_sample": per_sample_results} + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate inverse folding on protein-only structures") + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") + parser.add_argument("--data_dir", type=str, required=True, help="Path to .pt files (supports glob patterns)") + parser.add_argument("--num_samples", type=int, default=127, help="Number of samples to evaluate") + parser.add_argument("--nsteps", type=int, default=100, help="Number of generation steps") + parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length") + parser.add_argument("--device", type=str, default="cuda", help="Device to use") + parser.add_argument("--output_dir", type=str, default=None, help="Directory to save structures") + parser.add_argument("--dummy_ligand", action="store_true", help="Add dummy ligand to test model bias") + parser.add_argument("--dummy_ligand_atoms", type=int, default=10, help="Number of atoms in dummy ligand") + parser.add_argument("--batch_size", type=int, default=10, help="Batch size for evaluation") + args = parser.parse_args() + + # Load model + logger.info(f"Loading checkpoint: {args.checkpoint}") + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + args.checkpoint, + map_location=args.device, + strict=False, + ) + logger.info("Model loaded successfully") + + # Load structures + structures = load_structures(args.data_dir, args.num_samples, args.max_length) + + # Run evaluation + results = evaluate_inverse_folding( + model, + structures, + args.device, + args.nsteps, + args.output_dir, + use_dummy_ligand=args.dummy_ligand, + dummy_ligand_atoms=args.dummy_ligand_atoms, + batch_size=args.batch_size, + ) + + # Print results + print("\n" + "=" * 70) + if args.dummy_ligand: + print(f"Inverse Folding with DUMMY LIGAND ({args.dummy_ligand_atoms} atoms)") + else: + print("Protein-Only Inverse Folding Evaluation Results") + print("=" * 70) + print(f"\nSamples evaluated: {results['summary']['num_samples']}") + print("\n--- Sequence Recovery ---") + print(f" Mean: {results['summary']['mean_recovery']:.2f}%") + print(f" Std: {results['summary']['std_recovery']:.2f}%") + print(f" Median: {results['summary']['median_recovery']:.2f}%") + print(f" Min: {results['summary']['min_recovery']:.2f}%") + print(f" Max: {results['summary']['max_recovery']:.2f}%") + print("=" * 70) + + # Show distribution + import numpy as np + + recoveries = np.array([r["recovery"] for r in results["per_sample"]]) + print("\n--- Recovery Distribution ---") + for threshold in [10, 20, 30, 40, 50, 60]: + pct = (recoveries >= threshold).mean() * 100 + print(f" >= {threshold}%: {pct:.1f}% of samples") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/evaluate_ligand_conditioned_protein_generation.py b/src/lobster/cmdline/evaluate_ligand_conditioned_protein_generation.py new file mode 100644 index 00000000..6fe1e3e7 --- /dev/null +++ b/src/lobster/cmdline/evaluate_ligand_conditioned_protein_generation.py @@ -0,0 +1,426 @@ +"""Standalone evaluation of ligand-conditioned protein generation. + +Evaluates whether the model can generate self-consistent proteins conditioned +on a ligand. The model generates both sequence and structure from scratch; +the sequence is then folded with ESMFold, and the self-consistency between +the model-decoded structure and the ESMFold prediction is measured. + +Usage: + uv run python -m lobster.cmdline.evaluate_ligand_conditioned_protein_generation \ + --output results.csv \ + --structure_path ./output/ \ + --length 100 \ + --num_samples 10 + +Example (full test set): + uv run python -m lobster.cmdline.evaluate_ligand_conditioned_protein_generation \ + --output ligand_cond_protein_gen_results.csv \ + --structure_path ./ligand_cond_eval/ \ + --length 100 \ + --save_structures \ + --num_samples -1 +""" + +import argparse +import os +import random +import sys +from pathlib import Path + +import numpy as np +import torch +from loguru import logger + +from lobster.metrics.ligand_conditioned_protein_generation import ( + LigandConditionedProteinGenerationEvaluator, +) +from lobster.model.gen_ume import ProteinLigandEncoderLightningModule + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Evaluate ligand-conditioned protein generation via " + "self-consistency (decoded structure vs ESMFold prediction)" + ) + ) + parser.add_argument( + "--checkpoint", + type=str, + default="/cv/scratch/u/lisanzas/gen_ume_protein_ligand_medium/runs/2026-02-11T19-45-30/epoch=278-step=40057-val_loss=1.6365.ckpt", + help="Path to model checkpoint (.ckpt file)", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/cv/home/lisanzas/lobster/data/posebusters/processed/posebusters_benchmark_no_overlap/", + help="Path to directory with *_ligand.pt files", + ) + parser.add_argument( + "--output", + type=str, + default="ligand_cond_protein_gen_results.csv", + help="Output CSV file for results", + ) + parser.add_argument( + "--structure_path", + type=str, + default=None, + help="Output directory for generated structures (PDB/FASTA files)", + ) + parser.add_argument( + "--length", + type=int, + default=100, + help="Length of protein to generate (number of residues, default: 100)", + ) + parser.add_argument( + "--pocket_threshold", + type=float, + default=5.0, + help="Distance threshold (angstrom) for defining binding pocket on decoded structure", + ) + parser.add_argument( + "--num_samples", + type=int, + default=100, + help="Number of ligands to evaluate (-1 for all)", + ) + parser.add_argument( + "--num_designs", + type=int, + default=10, + help="Number of designs to generate per ligand (best by scTM is reported)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=200, + help="Number of diffusion steps for generation", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for computation (cuda/cpu)", + ) + parser.add_argument( + "--temperature_seq", + type=float, + default=0.15279667854390633, + help="Temperature for sequence sampling", + ) + parser.add_argument( + "--temperature_struc", + type=float, + default=0.18605909386731256, + help="Temperature for structure sampling", + ) + parser.add_argument( + "--stochasticity_seq", + type=int, + default=10, + help="Stochasticity parameter for sequence sampling", + ) + parser.add_argument( + "--stochasticity_struc", + type=int, + default=10, + help="Stochasticity parameter for structure sampling", + ) + parser.add_argument( + "--temperature_ligand", + type=float, + default=0.5819150856331732, + help="Temperature for ligand structure sampling", + ) + parser.add_argument( + "--stochasticity_ligand", + type=int, + default=20, + help="Stochasticity parameter for ligand structure sampling", + ) + parser.add_argument( + "--ligand_context_mode", + type=str, + default="atom_bond_only", + choices=["structure_tokens", "atom_bond_only"], + help="How to provide ligand context: 'atom_bond_only' or 'structure_tokens'", + ) + parser.add_argument( + "--inference_schedule_seq", + type=str, + default="LinearInferenceSchedule", + choices=[ + "LinearInferenceSchedule", + "LogInferenceSchedule", + "PowerInferenceSchedule", + ], + help="Inference schedule for sequence generation", + ) + parser.add_argument( + "--inference_schedule_struc", + type=str, + default="PowerInferenceSchedule", + choices=[ + "LinearInferenceSchedule", + "LogInferenceSchedule", + "PowerInferenceSchedule", + ], + help="Inference schedule for structure generation", + ) + parser.add_argument( + "--inference_schedule_ligand_atom", + type=str, + default="PowerInferenceSchedule", + choices=[ + "LinearInferenceSchedule", + "LogInferenceSchedule", + "PowerInferenceSchedule", + ], + help="Inference schedule for ligand atom token generation", + ) + parser.add_argument( + "--inference_schedule_ligand_struc", + type=str, + default="LinearInferenceSchedule", + choices=[ + "LinearInferenceSchedule", + "LogInferenceSchedule", + "PowerInferenceSchedule", + ], + help="Inference schedule for ligand structure token generation", + ) + parser.add_argument( + "--save_structures", + action="store_true", + help="Save decoded and ESMFold structures as PDB files", + ) + parser.add_argument( + "--minimize_ligand", + action="store_true", + help="Apply force-field minimization to decoded ligand geometry", + ) + parser.add_argument( + "--minimize_mode", + type=str, + default="bonds_and_angles", + choices=["bonds_only", "bonds_and_angles", "local", "full"], + help="Minimization mode for ligand geometry correction", + ) + parser.add_argument( + "--force_field", + type=str, + default="MMFF94", + choices=["MMFF94", "MMFF94s", "UFF"], + help="Force field for ligand minimization", + ) + parser.add_argument( + "--minimize_steps", + type=int, + default=500, + help="Maximum number of ligand minimization steps", + ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for reproducibility", + ) + + args = parser.parse_args() + + # Set random seeds + logger.info(f"Setting random seed: {args.seed}") + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + # Validate paths + if not Path(args.checkpoint).exists(): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + if not Path(args.data_dir).exists(): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + if args.structure_path: + os.makedirs(args.structure_path, exist_ok=True) + logger.info(f"Output directory: {args.structure_path}") + + # Set device + if args.device == "cuda" and not torch.cuda.is_available(): + logger.warning("CUDA not available, falling back to CPU") + args.device = "cpu" + + # Load model + logger.info(f"Loading model from {args.checkpoint}") + try: + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + args.checkpoint, + map_location=args.device, + ) + model.eval() + model.to(args.device) + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Failed to load model: {e}") + sys.exit(1) + + # Load ESMFold (always required for self-consistency) + from lobster.model import LobsterPLMFold + + logger.info("Loading ESMFold for self-consistency evaluation...") + plm_fold = LobsterPLMFold(model_name="esmfold_v1", max_length=512) + plm_fold.to(args.device) + logger.info("ESMFold loaded successfully") + + # Get max_length from model if available + num_samples = None if args.num_samples == -1 else args.num_samples + max_length = 512 + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + logger.info(f"Using model's max_length: {max_length}") + + # Create evaluator + evaluator = LigandConditionedProteinGenerationEvaluator( + data_dir=args.data_dir, + length=args.length, + pocket_distance_threshold=args.pocket_threshold, + num_samples=num_samples, + num_designs=args.num_designs, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + temperature_seq=args.temperature_seq, + temperature_struc=args.temperature_struc, + stochasticity_seq=args.stochasticity_seq, + stochasticity_struc=args.stochasticity_struc, + temperature_ligand=args.temperature_ligand, + stochasticity_ligand=args.stochasticity_ligand, + ligand_context_mode=args.ligand_context_mode, + inference_schedule_seq=args.inference_schedule_seq, + inference_schedule_struc=args.inference_schedule_struc, + inference_schedule_ligand_atom=args.inference_schedule_ligand_atom, + inference_schedule_ligand_struc=args.inference_schedule_ligand_struc, + save_structures=args.save_structures, + minimize_ligand=args.minimize_ligand, + minimize_mode=args.minimize_mode, + force_field=args.force_field, + minimize_steps=args.minimize_steps, + plm_fold=plm_fold, + ) + + # Load samples + logger.info("Loading test samples...") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Running evaluation...") + logger.info(f" Protein length: {args.length}") + logger.info(f" Designs per ligand: {args.num_designs}") + logger.info(f" Ligand context mode: {args.ligand_context_mode}") + logger.info(f" Diffusion steps: {args.nsteps}") + if args.minimize_ligand: + logger.info(f" Ligand minimization: {args.minimize_mode} ({args.force_field}, {args.minimize_steps} steps)") + results = evaluator.evaluate(model, samples, structure_path=args.structure_path) + + # Save results CSV + output_path = args.output + if args.structure_path: + output_path = os.path.join(args.structure_path, os.path.basename(args.output)) + results["results_df"].to_csv(output_path, index=False) + logger.info(f"Results saved to {output_path}") + + # Print summary + summary = results["summary"] + _print_summary(args, summary) + + +def _print_summary(args, summary): + """Print evaluation summary to stdout.""" + print("\n" + "=" * 70) + print("Ligand-Conditioned Protein Generation: Self-Consistency Results") + print("=" * 70) + + print(f"\nLigands evaluated: {summary['n_ligands']}") + print(f"Designs per ligand: {summary['num_designs']}") + print(f"Total designs: {summary.get('n_total_designs', 'N/A')}") + print(f"Generated protein len: {summary['protein_length']}") + print(f"Ligand context mode: {args.ligand_context_mode}") + print(f"Pocket threshold: {args.pocket_threshold} A") + print(f"Avg pocket size: {summary['mean_pocket_size']:.1f} residues") + print("(Metrics below are over the best design per ligand)") + + print("\n--- Protein-Ligand Contacts ---") + print(f" Contacts (CA<4.5A): {summary['mean_n_contacts']:.1f} (+/-{summary['std_n_contacts']:.1f})") + print( + f" Residues in contact: {summary['mean_n_residues_in_contact']:.1f} " + f"({summary['mean_frac_residues_in_contact']:.1%})" + ) + print(f" Ligand atoms contacted: {summary['mean_frac_ligand_atoms_in_contact']:.1%}") + print( + f" Min distance (A): {summary['mean_min_protein_ligand_dist']:.2f} " + f"(+/-{summary['std_min_protein_ligand_dist']:.2f})" + ) + + print("\n--- Self-Consistency (Decoded vs ESMFold) ---") + print( + f" scTM: {summary['mean_scTM']:.4f} " + f"(+/-{summary['std_scTM']:.4f}, " + f"median {summary['median_scTM']:.4f})" + ) + print( + f" scRMSD (A): {summary['mean_scRMSD']:.2f} " + f"(+/-{summary['std_scRMSD']:.2f}, " + f"median {summary['median_scRMSD']:.2f})" + ) + + print("\n--- Pocket Self-Consistency ---") + print(f" pocket scTM: {summary['mean_pocket_scTM']:.4f} (+/-{summary['std_pocket_scTM']:.4f})") + print(f" pocket scRMSD: {summary['mean_pocket_scRMSD']:.2f} (+/-{summary['std_pocket_scRMSD']:.2f})") + + print("\n--- ESMFold Confidence ---") + print(f" pLDDT: {summary['mean_plddt']:.2f} (+/-{summary['std_plddt']:.2f})") + print(f" PAE: {summary['mean_pae']:.2f} (+/-{summary['std_pae']:.2f})") + + print("\n" + "=" * 70) + + # Key insights + sc_tm = summary["mean_scTM"] + if sc_tm > 0.5: + print(f"High self-consistency (scTM={sc_tm:.3f} > 0.5): generated sequences fold into the predicted structure") + elif sc_tm > 0.3: + print(f"Moderate self-consistency (scTM={sc_tm:.3f}): partial agreement between decoded and folded structures") + else: + print(f"Low self-consistency (scTM={sc_tm:.3f} < 0.3): decoded and folded structures diverge significantly") + + plddt = summary["mean_plddt"] + if plddt > 70: + print(f"Good ESMFold confidence (pLDDT={plddt:.1f} > 70)") + elif plddt > 50: + print(f"Moderate ESMFold confidence (pLDDT={plddt:.1f})") + else: + print(f"Low ESMFold confidence (pLDDT={plddt:.1f} < 50)") + + mean_contacts = summary["mean_n_contacts"] + min_dist = summary["mean_min_protein_ligand_dist"] + if mean_contacts < 1: + print( + f"WARNING: No protein-ligand contacts (min dist={min_dist:.1f}A). Protein and ligand are not interacting." + ) + elif mean_contacts < 5: + print(f"Few protein-ligand contacts ({mean_contacts:.0f}). Weak interaction.") + else: + print(f"Protein-ligand contacts: {mean_contacts:.0f} (min dist={min_dist:.1f}A)") + + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/evaluate_protein_ligand_forward_folding.py b/src/lobster/cmdline/evaluate_protein_ligand_forward_folding.py new file mode 100644 index 00000000..e09bd166 --- /dev/null +++ b/src/lobster/cmdline/evaluate_protein_ligand_forward_folding.py @@ -0,0 +1,406 @@ +"""Standalone evaluation of forward folding on protein-ligand complexes. + +Evaluates whether ligand context improves forward folding (structure prediction) +performance, particularly for binding pocket residues. + +Usage: + uv run python -m lobster.cmdline.evaluate_protein_ligand_forward_folding \ + --checkpoint path/to/model.ckpt \ + --data_dir /data2/lisanzas/pdb_bind_12_15_25/test/ \ + --output results.csv \ + --structure_path ./output/ \ + --pocket_threshold 5.0 \ + --num_samples 100 + +Example (full test set): + uv run python -m lobster.cmdline.evaluate_protein_ligand_forward_folding \ + --checkpoint /data2/ume/gen_ume_protein_ligand/best.ckpt \ + --data_dir /data2/lisanzas/pdb_bind_12_15_25/test/ \ + --output protein_ligand_forward_folding_results.csv \ + --structure_path ./protein_ligand_eval/ \ + --save_structures \ + --num_samples -1 +""" + +import argparse +import os +import random +import sys +from pathlib import Path + +import numpy as np +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_forward_folding import ProteinLigandForwardFoldingEvaluator +from lobster.model.gen_ume import ProteinLigandEncoderLightningModule + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate forward folding on protein-ligand complexes with/without ligand context" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.ckpt file)", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to protein-ligand test directory", + ) + parser.add_argument( + "--output", + type=str, + default="protein_ligand_forward_folding_results.csv", + help="Output CSV file for results", + ) + parser.add_argument( + "--structure_path", + type=str, + default=None, + help="Output directory for predicted structures (PDB files)", + ) + parser.add_argument( + "--pocket_threshold", + type=float, + default=5.0, + help="Distance threshold (Å) for defining binding pocket", + ) + parser.add_argument( + "--num_samples", + type=int, + default=100, + help="Number of samples to evaluate (-1 for all)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=100, + help="Number of diffusion steps for generation", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for computation (cuda/cpu)", + ) + parser.add_argument( + "--temperature_seq", + type=float, + default=0.5, + help="Temperature for sequence sampling", + ) + parser.add_argument( + "--temperature_struc", + type=float, + default=0.5, + help="Temperature for structure sampling", + ) + parser.add_argument( + "--save_structures", + action="store_true", + help="Save predicted structures as PDB files", + ) + parser.add_argument( + "--save_gt_structure", + action="store_true", + help="Save ground truth structures as PDB files", + ) + parser.add_argument( + "--minimize_ligand", + action="store_true", + help="Apply geometry correction to decoded ligand structures", + ) + parser.add_argument( + "--minimize_mode", + type=str, + default="bonds_and_angles", + choices=["bonds_only", "bonds_and_angles", "local", "full"], + help="Minimization mode", + ) + parser.add_argument( + "--force_field", + type=str, + default="MMFF94", + help="Force field for minimization (MMFF94, UFF, etc.)", + ) + parser.add_argument( + "--minimize_steps", + type=int, + default=500, + help="Maximum number of minimization steps", + ) + # Additional generation hyperparameters + parser.add_argument( + "--stochasticity_seq", + type=int, + default=20, + help="Stochasticity parameter for sequence sampling", + ) + parser.add_argument( + "--stochasticity_struc", + type=int, + default=20, + help="Stochasticity parameter for structure sampling", + ) + parser.add_argument( + "--temperature_ligand", + type=float, + default=0.5, + help="Temperature for ligand structure sampling", + ) + parser.add_argument( + "--stochasticity_ligand", + type=int, + default=20, + help="Stochasticity parameter for ligand structure sampling", + ) + parser.add_argument( + "--ligand_context_mode", + type=str, + default="structure_tokens", + choices=["structure_tokens", "atom_bond_only"], + help="How to provide ligand context: 'structure_tokens' or 'atom_bond_only'", + ) + parser.add_argument( + "--inference_schedule_seq", + type=str, + default="LogInferenceSchedule", + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for sequence generation", + ) + parser.add_argument( + "--inference_schedule_struc", + type=str, + default="LinearInferenceSchedule", + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for structure generation", + ) + parser.add_argument( + "--inference_schedule_ligand_atom", + type=str, + default=None, + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for ligand atom token generation (default: use sequence schedule)", + ) + parser.add_argument( + "--inference_schedule_ligand_struc", + type=str, + default=None, + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for ligand structure token generation (default: use structure schedule)", + ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for reproducibility (sets torch, numpy, and python random seeds)", + ) + parser.add_argument( + "--num_predictions", + type=int, + default=1, + help="Number of predictions per sample for best-of-N evaluation (default: 1)", + ) + parser.add_argument( + "--best_of_n_metric", + type=str, + default="rmsd", + choices=["rmsd", "tm_score"], + help="Metric to use for best-of-N selection: 'rmsd' (lower is better) or 'tm_score' (higher is better)", + ) + parser.add_argument( + "--save_all_predictions", + action="store_true", + help="Save all N predicted structures (not just the best). Requires --save_structures and --num_predictions > 1", + ) + parser.add_argument( + "--try_reflection", + action="store_true", + help="Try both original and reflected (mirror image) coordinates, selecting the one with higher TM-score. " + "Useful if the model outputs mirror images of structures.", + ) + parser.add_argument( + "--max_protein_length", + type=int, + default=512, + help="Maximum protein-only length. Samples exceeding this are skipped (default: 512)", + ) + + args = parser.parse_args() + + # Set random seeds for reproducibility + logger.info(f"Setting random seed: {args.seed}") + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + # Validate checkpoint exists + if not Path(args.checkpoint).exists(): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + # Validate data directory exists + if not Path(args.data_dir).exists(): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Create structure_path directory if specified + if args.structure_path: + os.makedirs(args.structure_path, exist_ok=True) + logger.info(f"Output directory: {args.structure_path}") + + # Set device + if args.device == "cuda" and not torch.cuda.is_available(): + logger.warning("CUDA not available, falling back to CPU") + args.device = "cpu" + + # Load model + logger.info(f"Loading model from {args.checkpoint}") + try: + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + args.checkpoint, + map_location=args.device, + ) + model.eval() + model.to(args.device) + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Failed to load model: {e}") + sys.exit(1) + + # Create evaluator + num_samples = None if args.num_samples == -1 else args.num_samples + + # Get max_length from model if available + max_length = 512 # default + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + logger.info(f"Using model's max_length: {max_length}") + + evaluator = ProteinLigandForwardFoldingEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + max_protein_length=args.max_protein_length, + temperature_seq=args.temperature_seq, + temperature_struc=args.temperature_struc, + save_structures=args.save_structures, + save_gt_structure=args.save_gt_structure, + minimize_ligand=args.minimize_ligand, + minimize_mode=args.minimize_mode, + force_field=args.force_field, + minimize_steps=args.minimize_steps, + # Additional generation hyperparameters + stochasticity_seq=args.stochasticity_seq, + stochasticity_struc=args.stochasticity_struc, + temperature_ligand=args.temperature_ligand, + stochasticity_ligand=args.stochasticity_ligand, + ligand_context_mode=args.ligand_context_mode, + inference_schedule_seq=args.inference_schedule_seq, + inference_schedule_struc=args.inference_schedule_struc, + inference_schedule_ligand_atom=args.inference_schedule_ligand_atom, + inference_schedule_ligand_struc=args.inference_schedule_ligand_struc, + # Best-of-N parameters + num_predictions=args.num_predictions, + best_of_n_metric=args.best_of_n_metric, + save_all_predictions=args.save_all_predictions, + # Mirror image handling + try_reflection=args.try_reflection, + ) + + # Load samples + logger.info("Loading test samples...") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Running evaluation...") + results = evaluator.evaluate(model, samples, structure_path=args.structure_path) + + # Save results CSV + output_path = args.output + if args.structure_path: + output_path = os.path.join(args.structure_path, os.path.basename(args.output)) + results["results_df"].to_csv(output_path, index=False) + logger.info(f"Results saved to {output_path}") + + # Print summary + summary = results["summary"] + print("\n" + "=" * 70) + print("Protein-Ligand Forward Folding Evaluation Results") + print("=" * 70) + + print(f"\nSamples evaluated: {summary['n_samples']}") + print(f"Average pocket size: {summary['mean_pocket_size']:.1f} residues") + print(f"Pocket distance threshold: {args.pocket_threshold} Å") + if args.num_predictions > 1: + print(f"Best-of-N: {args.num_predictions} predictions (selecting by {args.best_of_n_metric})") + if args.try_reflection: + print("Mirror image handling: enabled") + if "reflection_rate_no_ligand" in summary: + print( + f" Reflected (no ligand): {summary['n_reflected_no_ligand']}/{summary['n_samples']} " + f"({summary['reflection_rate_no_ligand']:.1%})" + ) + print( + f" Reflected (with ligand): {summary['n_reflected_with_ligand']}/{summary['n_samples']} " + f"({summary['reflection_rate_with_ligand']:.1%})" + ) + + print("\n--- TM-Score (Overall Structure Quality) ---") + print(f" Without ligand: {summary['mean_tm_score_no_ligand']:.4f}") + print(f" With ligand: {summary['mean_tm_score_with_ligand']:.4f}") + print(f" Delta: {summary['mean_tm_score_delta']:+.4f} (±{summary['std_tm_score_delta']:.4f})") + + print("\n--- Overall RMSD (Å) ---") + print(f" Without ligand: {summary['mean_rmsd_overall_no_ligand']:.2f}") + print(f" With ligand: {summary['mean_rmsd_overall_with_ligand']:.2f}") + print(f" Delta: {summary['mean_rmsd_overall_delta']:+.2f} (±{summary['std_rmsd_overall_delta']:.2f})") + + print("\n--- Binding Pocket RMSD (Å) ---") + print(f" Without ligand: {summary['mean_rmsd_pocket_no_ligand']:.2f}") + print(f" With ligand: {summary['mean_rmsd_pocket_with_ligand']:.2f}") + print(f" Delta: {summary['mean_rmsd_pocket_delta']:+.2f} (±{summary['std_rmsd_pocket_delta']:.2f})") + + print("\n--- Non-Pocket RMSD (Å) ---") + print(f" Without ligand: {summary['mean_rmsd_nonpocket_no_ligand']:.2f}") + print(f" With ligand: {summary['mean_rmsd_nonpocket_with_ligand']:.2f}") + print(f" Delta: {summary['mean_rmsd_nonpocket_delta']:+.2f} (±{summary['std_rmsd_nonpocket_delta']:.2f})") + + print("\n" + "=" * 70) + + # Key insights + tm_delta = summary["mean_tm_score_delta"] + pocket_rmsd_delta = summary["mean_rmsd_pocket_delta"] + + if tm_delta > 0.01: + print(f"🎯 Ligand context IMPROVES TM-score by {tm_delta:.4f}!") + elif tm_delta < -0.01: + print(f"⚠️ Ligand context DECREASES TM-score by {abs(tm_delta):.4f}") + else: + print("📊 Ligand context has minimal effect on TM-score") + + # For RMSD, negative delta means improvement (lower RMSD is better) + if pocket_rmsd_delta < -0.1: + print(f"🎯 Ligand context IMPROVES pocket RMSD by {abs(pocket_rmsd_delta):.2f} Å!") + elif pocket_rmsd_delta > 0.1: + print(f"⚠️ Ligand context INCREASES pocket RMSD by {pocket_rmsd_delta:.2f} Å") + else: + print("📊 Ligand context has minimal effect on pocket RMSD") + + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/evaluate_protein_ligand_inverse_folding.py b/src/lobster/cmdline/evaluate_protein_ligand_inverse_folding.py new file mode 100644 index 00000000..bce3245f --- /dev/null +++ b/src/lobster/cmdline/evaluate_protein_ligand_inverse_folding.py @@ -0,0 +1,417 @@ +"""Standalone evaluation of inverse folding on protein-ligand complexes. + +Evaluates whether ligand context improves inverse folding performance, +particularly for binding pocket residues. + +Usage: + uv run python -m lobster.cmdline.evaluate_protein_ligand_inverse_folding \ + --checkpoint path/to/model.ckpt \ + --data_dir /data2/lisanzas/pdb_bind_12_15_25/test/ \ + --output results.csv \ + --structure_path ./output/ \ + --pocket_threshold 5.0 \ + --num_samples 100 + +Example (full test set): + uv run python -m lobster.cmdline.evaluate_protein_ligand_inverse_folding \ + --checkpoint /data2/ume/gen_ume_protein_ligand/best.ckpt \ + --data_dir /data2/lisanzas/pdb_bind_12_15_25/test/ \ + --output protein_ligand_inverse_folding_results.csv \ + --structure_path ./protein_ligand_eval/ \ + --num_samples -1 +""" + +import argparse +import os +import random +import sys +from pathlib import Path + +import numpy as np +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_inverse_folding import ProteinLigandInverseFoldingEvaluator +from lobster.model.gen_ume import ProteinLigandEncoderLightningModule + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate inverse folding on protein-ligand complexes with/without ligand context" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.ckpt file)", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to protein-ligand test directory", + ) + parser.add_argument( + "--output", + type=str, + default="protein_ligand_inverse_folding_results.csv", + help="Output CSV file for results", + ) + parser.add_argument( + "--structure_path", + type=str, + default=None, + help="Output directory for designed sequences (FASTA files)", + ) + parser.add_argument( + "--pocket_threshold", + type=float, + default=5.0, + help="Distance threshold (Å) for defining binding pocket", + ) + parser.add_argument( + "--num_samples", + type=int, + default=100, + help="Number of samples to evaluate (-1 for all)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=100, + help="Number of diffusion steps for generation", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for computation (cuda/cpu)", + ) + parser.add_argument( + "--temperature_seq", + type=float, + default=0.5, + help="Temperature for sequence sampling", + ) + parser.add_argument( + "--temperature_struc", + type=float, + default=0.5, + help="Temperature for structure sampling", + ) + parser.add_argument( + "--decode_structure", + action="store_true", + help="Decode and save predicted structures as PDB files", + ) + parser.add_argument( + "--save_gt_structure", + action="store_true", + help="Save ground truth structures as PDB files", + ) + parser.add_argument( + "--minimize_ligand", + action="store_true", + help="Apply geometry correction to decoded ligand structures", + ) + parser.add_argument( + "--minimize_mode", + type=str, + default="bonds_and_angles", + choices=["bonds_only", "bonds_and_angles", "local", "full"], + help="Minimization mode", + ) + parser.add_argument( + "--force_field", + type=str, + default="MMFF94", + help="Force field for minimization (MMFF94, UFF, etc.)", + ) + parser.add_argument( + "--minimize_steps", + type=int, + default=500, + help="Maximum number of minimization steps", + ) + parser.add_argument( + "--stochasticity_seq", + type=int, + default=20, + help="Stochasticity parameter for sequence sampling", + ) + parser.add_argument( + "--stochasticity_struc", + type=int, + default=20, + help="Stochasticity parameter for structure sampling", + ) + parser.add_argument( + "--temperature_ligand", + type=float, + default=0.5, + help="Temperature for ligand structure sampling", + ) + parser.add_argument( + "--stochasticity_ligand", + type=int, + default=20, + help="Stochasticity parameter for ligand structure sampling", + ) + parser.add_argument( + "--inference_schedule_seq", + type=str, + default="LogInferenceSchedule", + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for sequence generation", + ) + parser.add_argument( + "--inference_schedule_struc", + type=str, + default="LinearInferenceSchedule", + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for structure generation", + ) + parser.add_argument( + "--inference_schedule_ligand_atom", + type=str, + default=None, + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for ligand atom token generation (default: use sequence schedule)", + ) + parser.add_argument( + "--inference_schedule_ligand_struc", + type=str, + default=None, + choices=["LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule"], + help="Inference schedule for ligand structure token generation (default: use structure schedule)", + ) + parser.add_argument( + "--save_reconstructed_input", + action="store_true", + help="Save reconstructed input structures (encode then decode) to verify token fidelity", + ) + parser.add_argument( + "--use_se3_augmentation", + action="store_true", + help="Apply random SE3 augmentation (rotation + translation) to input structures before encoding", + ) + parser.add_argument( + "--se3_translation_scale", + type=float, + default=1.0, + help="Scale factor for random translation when SE3 augmentation is enabled", + ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for reproducibility (sets torch, numpy, and python random seeds)", + ) + parser.add_argument( + "--use_esmfold", + action="store_true", + help="Validate designed sequences with ESMFold (fold and compare to GT structure)", + ) + parser.add_argument( + "--max_protein_length", + type=int, + default=512, + help="Maximum protein-only length. Samples exceeding this are skipped. Also used as ESMFold max length (default: 512)", + ) + + args = parser.parse_args() + + # Set random seeds for reproducibility + logger.info(f"Setting random seed: {args.seed}") + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + # Validate checkpoint exists + if not Path(args.checkpoint).exists(): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + # Validate data directory exists + if not Path(args.data_dir).exists(): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Create structure_path directory if specified + if args.structure_path: + os.makedirs(args.structure_path, exist_ok=True) + logger.info(f"Output directory: {args.structure_path}") + + # Set device + if args.device == "cuda" and not torch.cuda.is_available(): + logger.warning("CUDA not available, falling back to CPU") + args.device = "cpu" + + # Load model + logger.info(f"Loading model from {args.checkpoint}") + try: + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + args.checkpoint, + map_location=args.device, + ) + model.eval() + model.to(args.device) + logger.info("Model loaded successfully") + except Exception as e: + logger.error(f"Failed to load model: {e}") + sys.exit(1) + + # Initialize ESMFold if requested + plm_fold = None + if args.use_esmfold: + from lobster.model import LobsterPLMFold + + logger.info("Loading ESMFold for structure validation...") + plm_fold = LobsterPLMFold(model_name="esmfold_v1", max_length=512) + plm_fold.to(args.device) + logger.info("ESMFold loaded successfully") + + # Create evaluator + num_samples = None if args.num_samples == -1 else args.num_samples + + # Get max_length from model if available + max_length = 512 # default + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + logger.info(f"Using model's max_length: {max_length}") + + evaluator = ProteinLigandInverseFoldingEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + decode_structure=args.decode_structure, + save_gt_structure=args.save_gt_structure, + minimize_ligand=args.minimize_ligand, + minimize_mode=args.minimize_mode, + force_field=args.force_field, + minimize_steps=args.minimize_steps, + save_reconstructed_input=args.save_reconstructed_input, + use_se3_augmentation=args.use_se3_augmentation, + se3_translation_scale=args.se3_translation_scale, + # Generation hyperparameters + temperature_seq=args.temperature_seq, + temperature_struc=args.temperature_struc, + stochasticity_seq=args.stochasticity_seq, + stochasticity_struc=args.stochasticity_struc, + temperature_ligand=args.temperature_ligand, + stochasticity_ligand=args.stochasticity_ligand, + inference_schedule_seq=args.inference_schedule_seq, + inference_schedule_struc=args.inference_schedule_struc, + inference_schedule_ligand_atom=args.inference_schedule_ligand_atom, + inference_schedule_ligand_struc=args.inference_schedule_ligand_struc, + use_esmfold=args.use_esmfold, + plm_fold=plm_fold, + max_protein_length=args.max_protein_length, + ) + + # Log SE3 augmentation status + if args.use_se3_augmentation: + logger.info(f"SE3 augmentation ENABLED (translation_scale={args.se3_translation_scale})") + else: + logger.info("SE3 augmentation DISABLED (deterministic encoding)") + + # Load samples + logger.info("Loading test samples...") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation with sequence saving + logger.info("Running evaluation...") + results = evaluator.evaluate(model, samples, structure_path=args.structure_path) + + # Save results CSV + output_path = args.output + if args.structure_path: + output_path = os.path.join(args.structure_path, os.path.basename(args.output)) + results["results_df"].to_csv(output_path, index=False) + logger.info(f"Results saved to {output_path}") + + # Print summary + summary = results["summary"] + print("\n" + "=" * 70) + print("Protein-Ligand Inverse Folding Evaluation Results") + print("=" * 70) + + print(f"\nSamples evaluated: {summary['n_samples']}") + print(f"Average pocket size: {summary['mean_pocket_size']:.1f} residues") + print(f"Pocket distance threshold: {args.pocket_threshold} Å") + + print("\n--- Overall Amino Acid Recovery ---") + print(f" Without ligand: {summary['mean_aar_overall_no_ligand']:.2%}") + print(f" With ligand: {summary['mean_aar_overall_with_ligand']:.2%}") + print(f" Delta: {summary['mean_aar_overall_delta']:+.2%}") + + print("\n--- Binding Pocket Amino Acid Recovery ---") + print(f" Without ligand: {summary['mean_aar_pocket_no_ligand']:.2%}") + print(f" With ligand: {summary['mean_aar_pocket_with_ligand']:.2%}") + print(f" Delta: {summary['mean_aar_pocket_delta']:+.2%} (±{summary['std_aar_pocket_delta']:.2%})") + + print("\n--- Non-Pocket Amino Acid Recovery ---") + print(f" Without ligand: {summary['mean_aar_nonpocket_no_ligand']:.2%}") + print(f" With ligand: {summary['mean_aar_nonpocket_with_ligand']:.2%}") + print(f" Delta: {summary['mean_aar_nonpocket_delta']:+.2%} (±{summary['std_aar_nonpocket_delta']:.2%})") + + # ESMFold validation results + if args.use_esmfold and "mean_esmfold_tm_no_ligand" in summary: + print("\n--- ESMFold Designability Validation ---") + print(f" {'Condition':<20} {'TM-score':<12} {'RMSD (Å)':<12} {'Pocket RMSD':<14} {'pLDDT':<12} {'PAE':<12}") + print(" " + "-" * 82) + print( + f" {'GT sequence':<20} " + f"{summary['mean_esmfold_tm_gt']:<12.3f} " + f"{summary['mean_esmfold_rmsd_gt']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_gt']:<14.2f} " + f"{summary['mean_esmfold_plddt_gt']:<12.2f} " + f"{summary['mean_esmfold_pae_gt']:<12.2f}" + ) + print( + f" {'No ligand':<20} " + f"{summary['mean_esmfold_tm_no_ligand']:<12.3f} " + f"{summary['mean_esmfold_rmsd_no_ligand']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_no_ligand']:<14.2f} " + f"{summary['mean_esmfold_plddt_no_ligand']:<12.2f} " + f"{summary['mean_esmfold_pae_no_ligand']:<12.2f}" + ) + print( + f" {'With ligand':<20} " + f"{summary['mean_esmfold_tm_with_ligand']:<12.3f} " + f"{summary['mean_esmfold_rmsd_with_ligand']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_with_ligand']:<14.2f} " + f"{summary['mean_esmfold_plddt_with_ligand']:<12.2f} " + f"{summary['mean_esmfold_pae_with_ligand']:<12.2f}" + ) + print( + f" {'Delta (ligand)':<20} " + f"{summary['mean_esmfold_tm_delta']:+<12.3f} " + f"{summary['mean_esmfold_rmsd_delta']:+<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_delta']:+<14.2f} " + f"{summary['mean_esmfold_plddt_delta']:+<12.2f}" + ) + + print("\n" + "=" * 70) + + # Key insight + pocket_delta = summary["mean_aar_pocket_delta"] + if pocket_delta > 0.01: + print(f"Ligand context IMPROVES pocket recovery by {pocket_delta * 100:.1f}%!") + elif pocket_delta < -0.01: + print(f"Ligand context DECREASES pocket recovery by {abs(pocket_delta) * 100:.1f}%") + else: + print("Ligand context has minimal effect on pocket recovery") + + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/src/lobster/cmdline/extract_pt_to_fasta.py b/src/lobster/cmdline/extract_pt_to_fasta.py new file mode 100755 index 00000000..2eb43dbe --- /dev/null +++ b/src/lobster/cmdline/extract_pt_to_fasta.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +""" +Extract sequences from .pt files and save to FASTA format. +""" + +import torch +from pathlib import Path +import glob +from tqdm import tqdm + +# Standard amino acid mapping (same as lobster) +RESTYPES = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X"] + + +def extract_sequence_from_pt(pt_file): + """Extract amino acid sequence from .pt file.""" + try: + # Load .pt file + data = torch.load(pt_file, map_location="cpu") + + # Try to find sequence - check multiple possible keys + sequence_tensor = None + if "sequence" in data: + sequence_tensor = data["sequence"] + elif "seq" in data: + sequence_tensor = data["seq"] + elif "aatype" in data: + sequence_tensor = data["aatype"] + else: + print(f"No sequence found in {pt_file}. Available keys: {list(data.keys())}") + return None + + # Convert tensor to numpy if needed + if isinstance(sequence_tensor, torch.Tensor): + sequence_indices = sequence_tensor.cpu().numpy() + else: + sequence_indices = sequence_tensor + + # Flatten if multidimensional + if sequence_indices.ndim > 1: + sequence_indices = sequence_indices.flatten() + + # Convert integer codes to amino acids + sequence = "".join([RESTYPES[int(i)] if 0 <= int(i) < len(RESTYPES) else "X" for i in sequence_indices]) + + return sequence + + except Exception as e: + print(f"Error loading {pt_file}: {e}") + return None + + +def extract_all_sequences_to_fasta(input_dir, output_fasta, truncate_at_x=False): + """Extract all sequences from .pt files to FASTA. + + Args: + input_dir: Directory containing .pt files + output_fasta: Output FASTA file path + truncate_at_x: If True, truncate sequences at first X (unknown residue) + """ + + input_path = Path(input_dir) + + # Find all .pt files + pt_files = sorted(glob.glob(str(input_path / "*.pt"))) + + if not pt_files: + print(f"No .pt files found in {input_dir}") + return + + print(f"Found {len(pt_files)} .pt files") + print(f"Output FASTA: {output_fasta}") + if truncate_at_x: + print("Mode: Truncate at first X (unknown residue)") + print() + + sequences_written = 0 + errors = 0 + total_residues = 0 + truncated_count = 0 + + with open(output_fasta, "w") as fasta_out: + for pt_file in tqdm(pt_files, desc="Extracting sequences"): + # Get structure name from filename (without .pt) + structure_name = Path(pt_file).stem + + # Remove common suffixes + for suffix in ["_processed", "_cleaned"]: + if structure_name.endswith(suffix): + structure_name = structure_name[: -len(suffix)] + + # Extract sequence + sequence = extract_sequence_from_pt(pt_file) + + if sequence: + # Optionally truncate at first X + if truncate_at_x and "X" in sequence: + first_x = sequence.index("X") + if first_x > 0: # Only truncate if there's something before X + sequence = sequence[:first_x] + truncated_count += 1 + + # Skip empty sequences + if len(sequence) == 0: + errors += 1 + continue + + # Write to FASTA format + fasta_out.write(f">{structure_name}\n") + + # Write sequence in 80 character lines (standard FASTA) + for i in range(0, len(sequence), 80): + fasta_out.write(sequence[i : i + 80] + "\n") + + sequences_written += 1 + total_residues += len(sequence) + else: + errors += 1 + + # Print summary + print(f"\n{'=' * 60}") + print("SUMMARY") + print("=" * 60) + print(f"✓ Successfully extracted: {sequences_written} sequences") + print(f"✗ Failed to process: {errors} files") + if truncate_at_x and truncated_count > 0: + print(f" Sequences truncated: {truncated_count}") + print(f" Total residues: {total_residues:,}") + + if sequences_written > 0: + print(f" Average length: {total_residues / sequences_written:.1f} residues") + + print(f"\nOutput saved to: {output_fasta}") + print("=" * 60) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Extract sequences from .pt files to FASTA format", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Extract all sequences (default) + python extract_pt_to_fasta.py + + # Specify input and output + python extract_pt_to_fasta.py --input-dir /path/to/pt/files --output sequences.fasta + + # Truncate at first X (unknown residue) + python extract_pt_to_fasta.py --truncate-at-x + """, + ) + + parser.add_argument( + "--input-dir", + type=str, + default="/data2/lisanzas/multi_flow_data/test_set_filtered_pt", + help="Directory containing .pt files (default: /data2/lisanzas/multi_flow_data/test_set_filtered_pt)", + ) + + parser.add_argument( + "--output", + type=str, + default="test_set_filtered_sequences.fasta", + help="Output FASTA file path (default: test_set_filtered_sequences.fasta)", + ) + + parser.add_argument("--truncate-at-x", action="store_true", help="Truncate sequences at first X (unknown residue)") + + args = parser.parse_args() + + print("=" * 60) + print("PT to FASTA Extractor") + print("=" * 60) + print(f"Input directory: {args.input_dir}") + print(f"Output FASTA: {args.output}\n") + + extract_all_sequences_to_fasta(args.input_dir, args.output, args.truncate_at_x) diff --git a/src/lobster/cmdline/generate.py b/src/lobster/cmdline/generate.py index 4f739e54..92a62f58 100644 --- a/src/lobster/cmdline/generate.py +++ b/src/lobster/cmdline/generate.py @@ -1,10 +1,11 @@ +import csv import logging from pathlib import Path import glob import hydra import torch -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf from loguru import logger from lobster.model.latent_generator.io import writepdb, load_pdb @@ -29,18 +30,52 @@ from lobster.transforms._structure_transforms import StructureBackboneTransform, AminoAcidTokenizerTransform from tmtools import tm_align from lobster.model import LobsterPLMFold +from lobster.model.gen_ume.binder_utils import ( + get_target_chain_info, + initialize_binder_at_origin, + get_next_chain_index, + create_binder_inpainting_masks, +) +from bionemo.moco.schedules.inference_time_schedules import ( + LinearInferenceSchedule, + LogInferenceSchedule, + PowerInferenceSchedule, +) # Set up logging logging.basicConfig(level=logging.INFO) -@hydra.main(version_base=None, config_path="../hydra_config", config_name="generate") +def _get_inference_schedule_class(schedule_name: str): + """Convert schedule name string to schedule class. + + Args: + schedule_name: String name of schedule ("LinearInferenceSchedule", "LogInferenceSchedule", "PowerInferenceSchedule") + + Returns: + Schedule class (callable) + """ + schedule_map = { + "LinearInferenceSchedule": LinearInferenceSchedule, + "LogInferenceSchedule": LogInferenceSchedule, + "PowerInferenceSchedule": PowerInferenceSchedule, + } + + if schedule_name not in schedule_map: + raise ValueError(f"Unknown schedule name: {schedule_name}. Available options: {list(schedule_map.keys())}") + + return schedule_map[schedule_name] + + +@hydra.main(version_base=None, config_path="../hydra_config/experiment", config_name="generate_unconditional") def generate(cfg: DictConfig) -> None: """Generate protein structures using genUME model. This command-line interface supports: - Unconditional generation: Generate novel protein structures from scratch - Inverse folding: Generate sequences for given protein structures + - Forward folding: Generate structures for given sequences + - Inpainting: Generate structures for given sequences and structures - Optional ESMFold validation of generated structures """ logger.info("Starting genUME structure generation") @@ -90,7 +125,7 @@ def generate(cfg: DictConfig) -> None: plotter = None if cfg.generation.get("save_csv_metrics", True): generation_mode = cfg.generation.mode - csv_writer = MetricsCSVWriter(output_dir, generation_mode) + csv_writer = MetricsCSVWriter(output_dir, generation_mode, resume=cfg.generation.get("resume", False)) logger.info(f"CSV metrics logging enabled for {generation_mode} mode") # Initialize plotter if plotting is enabled @@ -110,6 +145,8 @@ def generate(cfg: DictConfig) -> None: _generate_forward_folding(model, cfg, device, output_dir, plm_fold, csv_writer, plotter) elif generation_mode == "inpainting": _generate_inpainting(model, cfg, device, output_dir, plm_fold, csv_writer, plotter) + elif generation_mode == "binder_design": + _generate_binders(model, cfg, device, output_dir, plm_fold, csv_writer, plotter) else: raise ValueError(f"Unknown generation mode: {generation_mode}") @@ -127,6 +164,9 @@ def _check_sequence_tokens( - Tokens > 20 (mask/special tokens) - Negative values + Also checks for low-complexity sequences: + - Sequences where one amino acid accounts for > 50% of the total + Args: sequences: Sequence tensor (B, L) with amino acid token indices mask: Validity mask (B, L) indicating which positions are valid @@ -159,6 +199,25 @@ def _check_sequence_tokens( if num_negative > 0: return False, f"Sample {i} in {stage_name} contains negative token values" + # Check for low-complexity sequences (one amino acid > 50%) + # Only check valid amino acids (0-19) + valid_aa_mask = (seq_i >= 0) & (seq_i < 20) + valid_aa_seq = seq_i[valid_aa_mask] + + if len(valid_aa_seq) > 0: + # Count frequency of each amino acid + aa_counts = torch.bincount(valid_aa_seq, minlength=20) + max_count = aa_counts.max().item() + max_percentage = (max_count / len(valid_aa_seq)) * 100 + + if max_percentage > 50.0: + max_aa_idx = aa_counts.argmax().item() + return False, ( + f"Sample {i} in {stage_name} is low-complexity: " + f"amino acid {max_aa_idx} accounts for {max_percentage:.1f}% " + f"({max_count}/{len(valid_aa_seq)}) of the sequence" + ) + return True, "" @@ -275,6 +334,14 @@ def _execute_self_reflection_pipeline( forward_params = _get_self_reflection_params(cfg, "forward_folding") logger.info(f" Forward folding parameters: {forward_params}") + # Get inference schedule classes from config (use same as main generation) + inference_schedule_seq = gen_cfg.get("inference_schedule_seq", "LogInferenceSchedule") + inference_schedule_struc = gen_cfg.get("inference_schedule_struc", "LinearInferenceSchedule") + if isinstance(inference_schedule_seq, str): + inference_schedule_seq = _get_inference_schedule_class(inference_schedule_seq) + if isinstance(inference_schedule_struc, str): + inference_schedule_struc = _get_inference_schedule_class(inference_schedule_struc) + forward_sample = model.generate_sample( length=current_length, num_samples=batch_size, @@ -287,6 +354,8 @@ def _execute_self_reflection_pipeline( temperature_struc=forward_params["temperature_struc"], stochasticity_seq=forward_params["stochasticity_seq"], stochasticity_struc=forward_params["stochasticity_struc"], + inference_schedule_seq=inference_schedule_seq, + inference_schedule_struc=inference_schedule_struc, asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), ) @@ -383,6 +452,11 @@ def _execute_self_reflection_pipeline( inverse_params = _get_self_reflection_params(cfg, "inverse_folding") logger.info(f" Inverse folding parameters: {inverse_params}") + # Get inference schedule classes from inverse folding parameters + inference_schedule_seq = inverse_params.get("inference_schedule_seq", "LogInferenceSchedule") + if isinstance(inference_schedule_seq, str): + inference_schedule_seq = _get_inference_schedule_class(inference_schedule_seq) + inverse_sample = model.generate_sample( length=current_length, num_samples=batch_size, @@ -393,6 +467,7 @@ def _execute_self_reflection_pipeline( nsteps=inverse_params["nsteps"], temperature_seq=inverse_params["temperature_seq"], stochasticity_seq=inverse_params["stochasticity_seq"], + inference_schedule_seq=inference_schedule_seq, asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), ) @@ -881,28 +956,104 @@ def _generate_unconditional( f"Generating {num_samples} structures of length {current_length} with {nsteps} steps, will run with batch size {batch_size} for {n_iterations} iterations" ) - # Initialize metrics collection for this length + # Resume support: build set of already-completed iterations (by PDB existence) + # This handles gaps from skipped iterations (e.g. max_retries exceeded) + completed_iterations = set() + is_resuming = gen_cfg.get("resume", False) + if is_resuming: + for check_iter in range(n_iterations): + all_exist = True + for check_i in range(batch_size): + check_file = ( + output_dir + / f"generated_structure_length_{current_length}_{check_iter * batch_size + check_i:03d}.pdb" + ) + if not check_file.exists(): + all_exist = False + break + if all_exist: + completed_iterations.add(check_iter) + if completed_iterations: + logger.info( + f"Resuming: {len(completed_iterations)}/{n_iterations} iterations already complete " + f"for length {current_length} (will skip them individually)" + ) + if len(completed_iterations) >= n_iterations: + logger.info(f"All {n_iterations} iterations already complete for length {current_length}, skipping") + continue + + # Initialize metrics collection for this length, pre-loading from CSV on resume. + # Deduplicate by run_id keeping only the latest entry (by row order / timestamp). all_metrics = [] + if is_resuming and completed_iterations: + existing_csvs = sorted(output_dir.glob("*_metrics_*.csv"), key=lambda x: x.stat().st_mtime) + if existing_csvs: + csv_path = existing_csvs[-1] + logger.info(f"Loading prior metrics from {csv_path} for length {current_length}") + try: + csv_col_to_internal_key = { + "plddt": "_plddt", + "predicted_aligned_error": "_predicted_aligned_error", + "tm_score": "_tm_score", + "rmsd": "_rmsd", + } + metrics_by_run_id = {} + with open(csv_path, newline="") as f: + reader = csv.DictReader(f) + for row in reader: + if row.get("sequence_length") and int(float(row["sequence_length"])) == current_length: + run_id = row.get("run_id", "") + metrics_dict = {} + for key, value in row.items(): + if key in ("run_id", "timestamp", "mode", "sequence_length", "num_samples"): + continue + if value is not None and value != "": + try: + internal_key = csv_col_to_internal_key.get(key, key) + metrics_dict[internal_key] = float(value) + except (ValueError, TypeError): + pass + if metrics_dict: + metrics_by_run_id[run_id] = metrics_dict + all_metrics = list(metrics_by_run_id.values()) + logger.info(f"Pre-loaded {len(all_metrics)} unique metric entries for length {current_length}") + except Exception as e: + logger.warning(f"Failed to load prior metrics from CSV: {e}") + all_metrics = [] # Get quality control config for retry logic qc_config = {} if hasattr(gen_cfg, "self_reflection") and hasattr(gen_cfg.self_reflection, "quality_control"): qc_config = gen_cfg.self_reflection.quality_control - # Enable retries if any QC threshold is enabled - qc_enabled = ( + # Check for independent sequence token check (not tied to self-reflection) + enable_sequence_token_check = gen_cfg.get("enable_sequence_token_check", True) + sequence_token_check_retries = gen_cfg.get("sequence_token_check_retries", 10) + + # Enable retries if any QC threshold is enabled (from self-reflection) + self_reflection_qc_enabled = ( qc_config.get("enable_tm_threshold", False) or qc_config.get("enable_min_percent_identity_threshold", False) or qc_config.get("enable_max_percent_identity_threshold", False) or qc_config.get("enable_sequence_token_check", True) # Token check enabled by default ) - max_retries = qc_config.get("max_retries", 3) if qc_enabled else 0 + + # Determine max_retries based on what's enabled + if self_reflection_qc_enabled: + max_retries = qc_config.get("max_retries", 3) + if enable_sequence_token_check and not gen_cfg.get("enable_self_reflection", False): + max_retries = sequence_token_check_retries + else: + max_retries = 0 # Track retry statistics total_retries = 0 max_retries_exceeded = 0 for n_iter in range(n_iterations): + if n_iter in completed_iterations: + logger.debug(f"Skipping already-completed iteration {n_iter + 1}/{n_iterations}") + continue logger.info(f"Iteration {n_iter + 1}/{n_iterations}") # Retry loop for quality control @@ -915,6 +1066,42 @@ def _generate_unconditional( total_retries += 1 with torch.no_grad(): + # Get inference schedule classes from config + inference_schedule_seq = gen_cfg.get("inference_schedule_seq", "LogInferenceSchedule") + inference_schedule_struc = gen_cfg.get("inference_schedule_struc", "LinearInferenceSchedule") + + # Convert string names to classes if needed + if isinstance(inference_schedule_seq, str): + inference_schedule_seq = _get_inference_schedule_class(inference_schedule_seq) + if isinstance(inference_schedule_struc, str): + inference_schedule_struc = _get_inference_schedule_class(inference_schedule_struc) + + # Build sequence anchor tensors if enabled + anchor_tokens = None + anchor_mask = None + anchor_cfg = gen_cfg.get("sequence_anchor_fraction", 0.0) + if isinstance(anchor_cfg, (dict, DictConfig)): + anchor_fraction = float( + anchor_cfg.get(current_length, anchor_cfg.get(str(current_length), 0.0)) + ) + else: + anchor_fraction = float(anchor_cfg) + if anchor_fraction > 0.0: + num_anchors = max(1, int(current_length * anchor_fraction)) + anchor_positions = torch.randperm(current_length, device=device)[:num_anchors] + # Sample from 19 amino acids excluding Cysteine (index 4) + allowed_aa = torch.tensor( + [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], device=device + ) + rand_indices = torch.randint(0, len(allowed_aa), (batch_size, current_length), device=device) + anchor_tokens = allowed_aa[rand_indices] + anchor_mask = torch.ones((batch_size, current_length), device=device) + anchor_mask[:, anchor_positions] = 0 # 0 = keep anchored + logger.info( + f"Sequence anchors enabled: fixing {num_anchors}/{current_length} " + f"positions ({anchor_fraction * 100:.0f}%)" + ) + # Generate samples generate_sample = model.generate_sample( length=current_length, @@ -924,7 +1111,11 @@ def _generate_unconditional( temperature_struc=gen_cfg.get("temperature_struc", 1.0), stochasticity_seq=gen_cfg.get("stochasticity_seq", 20), stochasticity_struc=gen_cfg.get("stochasticity_struc", 20), + inference_schedule_seq=inference_schedule_seq, + inference_schedule_struc=inference_schedule_struc, asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), + sequence_anchor_tokens=anchor_tokens, + sequence_anchor_mask=anchor_mask, ) # Create mask for decoding @@ -965,15 +1156,43 @@ def _generate_unconditional( # Quality control failed, will retry retry_count += 1 if retry_count > max_retries: - logger.error( + logger.warning( f" Max retries ({max_retries}) exceeded for iteration {n_iter + 1}. " + f"Using current sequences despite quality control failure. " f"Skipping self-reflection for this iteration." ) max_retries_exceeded += 1 iteration_success = True continue + elif enable_sequence_token_check: + # Extract sequences for validation + if generate_sample["sequence_logits"].shape[-1] == 33: + check_seq = convert_lobster_aa_tokenization_to_standard_aa( + generate_sample["sequence_logits"], device=device + ) + else: + check_seq = generate_sample["sequence_logits"].argmax(dim=-1) + check_seq[check_seq > 21] = 20 + + # Run sequence token check + is_valid, error_msg = _check_sequence_tokens(check_seq, mask, "unconditional generation") + if not is_valid: + logger.warning(f" Sequence token check FAILED: {error_msg}") + logger.warning(" Iteration will be retried (invalid sequence tokens)") + retry_count += 1 + if retry_count > max_retries: + logger.warning( + f" Max retries ({max_retries}) exceeded for iteration {n_iter + 1}. " + f"Using current sequences despite quality control failure." + ) + max_retries_exceeded += 1 + iteration_success = True + continue + else: + logger.info(" Sequence token check PASSED: All sequences contain valid amino acids") + iteration_success = True else: - # Self-reflection disabled, no quality control + # No quality control at all iteration_success = True # Only proceed with normal flow if iteration succeeded or max retries exceeded @@ -1002,22 +1221,32 @@ def _generate_unconditional( seq = generate_sample["sequence_logits"].argmax(dim=-1) seq[seq > 21] = 20 + # Extract structure tokens (argmax) + structure_tokens = generate_sample["structure_logits"].argmax(dim=-1) # Shape: [batch_size, length] + # Write sequences to CSV # Note: For self-reflection mode, we only store initial unconditional sequences (not forward/inverse intermediates) if csv_writer is not None: # Convert sequences to strings sequence_strs = [] + structure_token_strs = [] for i in range(batch_size): seq_i = seq[i, mask[i] == 1] sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) sequence_strs.append(sequence_str) + # Convert structure tokens to comma-separated string + tokens_i = structure_tokens[i, mask[i] == 1] + tokens_str = ",".join([str(t.item()) for t in tokens_i]) + structure_token_strs.append(tokens_str) + # Write to sequences CSV csv_writer.write_sequences( sequences=sequence_strs, run_id=f"unconditional_length_{current_length}_iter_{n_iter:03d}", iteration=n_iter, sequence_type="unconditional", + latent_generator_tokens=structure_token_strs, ) # Save generated structures @@ -1044,10 +1273,11 @@ def _generate_unconditional( ) # Log metrics for unconditional generation - if batch_metrics: + if batch_metrics and not batch_metrics.get("_skipped", False): logger.info("ESMFold validation metrics for unconditional generation:") for key, value in batch_metrics.items(): - logger.info(f" {key}: {value:.4f}") + if isinstance(value, (int, float)): + logger.info(f" {key}: {value:.4f}") # Store metrics for CSV logging if csv_writer is not None: @@ -1175,7 +1405,7 @@ def _generate_unconditional( foldseek_bin_path = cfg.generation.get( "foldseek_bin_path", - "/homefs/home/lisanzas/scratch/Develop/lobster/src/lobster/metrics/foldseek/bin", + str(Path(__file__).resolve().parent.parent / "metrics" / "foldseek" / "bin"), ) try: @@ -1264,8 +1494,8 @@ def _generate_inverse_folding( structure_paths.extend(glob.glob(str(path / "*.pt"))) else: raise ValueError(f"Input path does not exist: {input_structures}") - elif isinstance(input_structures, (list, tuple)): - # List of paths + elif isinstance(input_structures, (list, tuple, ListConfig)): + # List of paths (includes OmegaConf ListConfig) for path_str in input_structures: path = Path(path_str) if path.is_file(): @@ -1289,6 +1519,68 @@ def _generate_inverse_folding( logger.info(f"Processing structures with {nsteps} generation steps, batch size {batch_size}, n_trials {n_trials}") logger.info(f"Generating {n_designs_per_structure} sequence design(s) per structure") + # Build ligand file mapping for pocket AAR computation (optional) + ligand_structures_cfg = gen_cfg.get("ligand_structures", None) + pocket_distance_threshold = gen_cfg.get("pocket_distance_threshold", 5.0) + ligand_file_map = {} # Maps protein path -> ligand path + + if ligand_structures_cfg is not None: + # Resolve ligand file paths + ligand_paths = [] + if isinstance(ligand_structures_cfg, str): + if "*" in ligand_structures_cfg or "?" in ligand_structures_cfg: + ligand_paths = glob.glob(ligand_structures_cfg) + else: + lpath = Path(ligand_structures_cfg) + if lpath.is_file(): + ligand_paths = [str(lpath)] + elif lpath.is_dir(): + ligand_paths = list(glob.glob(str(lpath / "*ligand.pt"))) + + # Build mapping: for each ligand file, find matching protein file + # Convention: *_ligand.pt <-> *_protein.pt (same prefix) + ligand_by_prefix = {} + for lp in ligand_paths: + prefix = Path(lp).stem.replace("_ligand", "") + ligand_by_prefix[prefix] = lp + + for sp in structure_paths: + prefix = Path(sp).stem.replace("_protein", "") + if prefix in ligand_by_prefix: + ligand_file_map[sp] = ligand_by_prefix[prefix] + + logger.info( + f"Pocket AAR enabled: found {len(ligand_file_map)}/{len(structure_paths)} " + f"matching ligand files (threshold={pocket_distance_threshold} Å)" + ) + else: + logger.info("Pocket AAR disabled (no ligand_structures configured)") + + def _compute_pocket_mask(protein_coords, ligand_coords, protein_mask=None, threshold=5.0): + """Compute pocket mask: residues with CA within threshold of any ligand atom.""" + if protein_coords.dim() == 3: + ca_coords = protein_coords[:, 1, :] # CA atoms (index 1) + else: + ca_coords = protein_coords + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + min_distances = distances.min(dim=1).values + pocket_mask = min_distances < threshold + if protein_mask is not None: + pocket_mask = pocket_mask & protein_mask.bool() + return pocket_mask + + def _compute_aar(predicted_seq, ground_truth_seq, aar_mask=None): + """Compute amino acid recovery rate (0-1) with optional mask.""" + if aar_mask is not None: + aar_mask = aar_mask.bool() + if aar_mask.sum() == 0: + return float("nan") + predicted_seq = predicted_seq[aar_mask] + ground_truth_seq = ground_truth_seq[aar_mask] + if len(predicted_seq) == 0: + return float("nan") + return (predicted_seq == ground_truth_seq).float().mean().item() + # Initialize StructureBackboneTransform structure_transform = StructureBackboneTransform(max_length=cfg.generation.get("max_length", 512)) @@ -1299,6 +1591,12 @@ def _generate_inverse_folding( all_tm_scores = [] all_rmsd_scores = [] + # Pocket AAR aggregate statistics + all_aar_overall = [] + all_aar_pocket = [] + all_aar_nonpocket = [] + all_n_pocket_residues = [] + with torch.no_grad(): # Process structure files in batches for batch_start in range(0, len(structure_paths), batch_size): @@ -1311,6 +1609,7 @@ def _generate_inverse_folding( # Load structures from files batch_data = [] valid_indices = [] + max_len = cfg.generation.get("max_length", 512) for i, structure_path in enumerate(batch_paths): logger.info(f"Loading {structure_path}") @@ -1321,6 +1620,14 @@ def _generate_inverse_folding( try: structure_data = torch.load(structure_path, map_location="cpu") if structure_data is not None: + # Skip structures that exceed max_length before cropping + raw_length = structure_data["coords_res"].shape[0] + if raw_length > max_len: + logger.info( + f"Skipping structure {structure_path} - too long " + f"({raw_length} residues, maximum {max_len})" + ) + continue # Apply StructureBackboneTransform structure_data = structure_transform(structure_data) batch_data.append(structure_data) @@ -1333,6 +1640,14 @@ def _generate_inverse_folding( # Load PDB/CIF file using existing method structure_data = load_pdb(structure_path, add_batch_dim=False) if structure_data is not None: + # Skip structures that exceed max_length before cropping + raw_length = structure_data["coords_res"].shape[0] + if raw_length > max_len: + logger.info( + f"Skipping structure {structure_path} - too long " + f"({raw_length} residues, maximum {max_len})" + ) + continue # Apply StructureBackboneTransform structure_data = structure_transform(structure_data) batch_data.append(structure_data) @@ -1389,6 +1704,28 @@ def _generate_inverse_folding( logger.info(f"Batch {batch_idx + 1}: {B} structures, max length {max_length}") + # Load ligand coordinates for pocket AAR (if configured) + batch_ligand_coords = [] + for fvi in filtered_valid_indices: + protein_path = batch_paths[fvi] + ligand_path = ligand_file_map.get(protein_path) + if ligand_path is not None: + try: + ligand_data_loaded = torch.load(ligand_path, weights_only=False, map_location="cpu") + lig_coords = ligand_data_loaded.get( + "atom_coords", ligand_data_loaded.get("coords", ligand_data_loaded.get("ligand_coords")) + ) + if lig_coords is not None: + batch_ligand_coords.append(lig_coords.to(device)) + else: + logger.warning(f"No ligand coordinates found in {ligand_path}") + batch_ligand_coords.append(None) + except Exception as e: + logger.warning(f"Failed to load ligand file {ligand_path}: {e}") + batch_ligand_coords.append(None) + else: + batch_ligand_coords.append(None) + # Loop over designs - generate multiple independent designs per structure for design_idx in range(n_designs_per_structure): if n_designs_per_structure > 1: @@ -1404,38 +1741,118 @@ def _generate_inverse_folding( f"Trial {trial + 1}/{n_trials} for batch {batch_idx + 1}, design {design_idx + 1}/{n_designs_per_structure}" ) - # Generate sequences - generate_sample = model.generate_sample( - length=max_length, - num_samples=B, - inverse_folding=True, - nsteps=nsteps, - input_structure_coords=coords_res, - input_mask=mask, - input_indices=indices, - temperature_seq=gen_cfg.get("temperature_seq", 0.5), - stochasticity_seq=gen_cfg.get("stochasticity_seq", 20), - asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), - ) + # Retry loop for quality control (like unconditional generation) + if gen_cfg.get("enable_sequence_token_check", True): + max_retries = gen_cfg.get("sequence_token_check_retries", 10) + else: + max_retries = 0 + retry_count = 0 + valid_sequences_generated = False + + while retry_count <= max_retries and not valid_sequences_generated: + if retry_count > 0: + logger.info(f" Retry attempt {retry_count}/{max_retries}") + + # Generate sequences + generate_sample = model.generate_sample( + length=max_length, + num_samples=B, + inverse_folding=True, + nsteps=nsteps, + input_structure_coords=coords_res, + input_mask=mask, + input_indices=indices, + temperature_seq=gen_cfg.get("temperature_seq", 0.5), + stochasticity_seq=gen_cfg.get("stochasticity_seq", 20), + asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), + ) - # Decode structures - decoded_x = model.decode_structure(generate_sample, mask) + # Decode structures + decoded_x = model.decode_structure(generate_sample, mask) - # Extract coordinates - x_recon_xyz = None - for decoder_name in decoded_x: - if "vit_decoder" == decoder_name: - x_recon_xyz = decoded_x[decoder_name] - break + # Extract coordinates + x_recon_xyz = None + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + x_recon_xyz = decoded_x[decoder_name] + break - # Extract sequences - if generate_sample["sequence_logits"].shape[-1] == 33: - seq = convert_lobster_aa_tokenization_to_standard_aa( - generate_sample["sequence_logits"], device=device - ) - else: - seq = generate_sample["sequence_logits"].argmax(dim=-1) - seq[seq > 21] = 20 + # Extract sequences + if generate_sample["sequence_logits"].shape[-1] == 33: + seq = convert_lobster_aa_tokenization_to_standard_aa( + generate_sample["sequence_logits"], device=device + ) + else: + seq = generate_sample["sequence_logits"].argmax(dim=-1) + seq[seq > 21] = 20 + + # Quality control check for invalid tokens + is_valid, error_msg = _check_sequence_tokens(seq, mask, "inverse folding") + if not is_valid: + logger.warning(f" Quality control FAILED: {error_msg}") + retry_count += 1 + if retry_count > max_retries: + logger.warning( + f" Max retries ({max_retries}) exceeded for trial {trial + 1}. " + f"Using argmax without 'X' token as fallback." + ) + + # Store original sequence (with X tokens) for comparison + seq_with_x = seq.clone() + + # Re-extract sequences with X-masked logits to avoid unknown tokens + if generate_sample["sequence_logits"].shape[-1] == 33: + # Mask token 24 (X) in 33-token scheme + masked_logits = generate_sample["sequence_logits"].clone() + masked_logits[..., 24] = float("-inf") + seq = convert_lobster_aa_tokenization_to_standard_aa(masked_logits, device=device) + else: + # Mask token 20 (X) in standard scheme + masked_logits = generate_sample["sequence_logits"].clone() + masked_logits[..., 20] = float("-inf") + seq = masked_logits.argmax(dim=-1) + seq[seq > 21] = 20 + + # Log sequences for visual inspection + logger.info(" Sequence comparison (original with X vs. X-masked):") + for i in range(seq.shape[0]): + # Convert token indices to amino acid strings + valid_positions = mask[i] == 1 + seq_with_x_str = "".join( + [ + restype_order_with_x_inv.get(int(t), "?") + for t in seq_with_x[i, valid_positions].cpu().numpy() + ] + ) + seq_masked_str = "".join( + [ + restype_order_with_x_inv.get(int(t), "?") + for t in seq[i, valid_positions].cpu().numpy() + ] + ) + + # Count X tokens + num_x_before = seq_with_x_str.count("X") + num_x_after = seq_masked_str.count("X") + + logger.info(f" Sample {i}:") + logger.info( + f" Before (X count={num_x_before}): {seq_with_x_str[:100]}{'...' if len(seq_with_x_str) > 100 else ''}" + ) + logger.info( + f" After (X count={num_x_after}): {seq_masked_str[:100]}{'...' if len(seq_masked_str) > 100 else ''}" + ) + + valid_sequences_generated = True + break + logger.warning(f" Regenerating sequences (retry {retry_count}/{max_retries})") + continue + else: + logger.info(" Quality control PASSED: All sequences contain valid amino acids") + valid_sequences_generated = True + + # Extract structure tokens (argmax) + structure_tokens = generate_sample["structure_logits"].argmax(dim=-1) # Shape: [batch_size, length] # Calculate TM-scores for this trial trial_tm_scores = [] @@ -1506,6 +1923,13 @@ def _generate_inverse_folding( chain_group=chain_group, # Specify which chains to predict ) + # Skip if sequence too long for ESMFold + if result is None: + logger.warning( + f"Chain group {chain_group} exceeds ESMFold max length, skipping ESMFold validation" + ) + continue + chain_group_results.append(result) logger.info( @@ -1558,12 +1982,22 @@ def _generate_inverse_folding( restype_order_inv=restype_order_with_x_inv, ) - trial_tm_scores.append(result["folded_structure_metrics"]["_tm_score"]) - outputs = result["esmfold_outputs"] - pred_coords = result["pred_coords"] - trial_folded_structure_metrics = result["folded_structure_metrics"] - - logger.info(f"TM-score: {result['folded_structure_metrics']['_tm_score']:.3f}") + # Skip if sequence too long for ESMFold + if result is None: + logger.warning( + f"Structure exceeds ESMFold max length ({len(seq_i)} residues with linkers), " + "skipping ESMFold validation for this batch" + ) + trial_tm_scores.append(float("nan")) + outputs = None + pred_coords = None + trial_folded_structure_metrics = {"_skipped": True, "_reason": "sequence_too_long"} + else: + trial_tm_scores.append(result["folded_structure_metrics"]["_tm_score"]) + outputs = result["esmfold_outputs"] + pred_coords = result["pred_coords"] + trial_folded_structure_metrics = result["folded_structure_metrics"] + logger.info(f"TM-score: {result['folded_structure_metrics']['_tm_score']:.3f}") else: # If ESMFold is not available, use generated structure as fallback @@ -1659,15 +2093,79 @@ def _generate_inverse_folding( all_percent_identities.extend(batch_percent_identities) + # Compute pocket AAR if ligand data is available + if ligand_file_map and original_sequences: + batch_aar_overall = [] + batch_aar_pocket = [] + batch_aar_nonpocket = [] + batch_n_pocket = [] + + for i, (orig_seq, gen_seq) in enumerate(zip(original_sequences, seq)): + orig_len = len(orig_seq) + gen_len = len(gen_seq) + min_len = min(orig_len, gen_len) + + if min_len > 0: + orig_seq_dev = orig_seq[:min_len].to(device) + gen_seq_dev = gen_seq[:min_len].to(device) + + # Overall AAR (0-1 scale) + aar_overall = _compute_aar(gen_seq_dev, orig_seq_dev) + batch_aar_overall.append(aar_overall) + + # Pocket / non-pocket AAR + lig_coords_i = batch_ligand_coords[i] if i < len(batch_ligand_coords) else None + if lig_coords_i is not None: + # Get protein coords (unpadded) + orig_coords_i = filtered_batch_data[i]["coords_res"][:min_len].to(device) + pocket_mask_i = _compute_pocket_mask( + orig_coords_i, lig_coords_i, threshold=pocket_distance_threshold + ) + non_pocket_mask_i = ~pocket_mask_i + + n_pocket = int(pocket_mask_i.sum().item()) + batch_n_pocket.append(n_pocket) + + aar_pocket = _compute_aar(gen_seq_dev, orig_seq_dev, pocket_mask_i) + aar_nonpocket = _compute_aar(gen_seq_dev, orig_seq_dev, non_pocket_mask_i) + batch_aar_pocket.append(aar_pocket) + batch_aar_nonpocket.append(aar_nonpocket) + + logger.info( + f" AAR overall: {aar_overall:.3f}, " + f"pocket: {aar_pocket:.3f} ({n_pocket} residues), " + f"non-pocket: {aar_nonpocket:.3f}" + ) + else: + batch_aar_pocket.append(float("nan")) + batch_aar_nonpocket.append(float("nan")) + batch_n_pocket.append(0) + else: + batch_aar_overall.append(float("nan")) + batch_aar_pocket.append(float("nan")) + batch_aar_nonpocket.append(float("nan")) + batch_n_pocket.append(0) + + all_aar_overall.extend(batch_aar_overall) + all_aar_pocket.extend(batch_aar_pocket) + all_aar_nonpocket.extend(batch_aar_nonpocket) + all_n_pocket_residues.extend(batch_n_pocket) + # Write sequences to CSV if csv_writer is not None: # Convert generated sequences to strings generated_sequence_strs = [] + structure_token_strs = [] for i in range(B): seq_i = seq[i, mask[i] == 1] sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) generated_sequence_strs.append(sequence_str) + # Convert structure tokens to comma-separated string + tokens_i = structure_tokens[i, mask[i] == 1] + tokens_str = ",".join([str(t.item()) for t in tokens_i]) + structure_token_strs.append(tokens_str) + # Convert original sequences to strings original_sequence_strs = [] for orig_seq in original_sequences: @@ -1688,6 +2186,7 @@ def _generate_inverse_folding( input_structure=[Path(batch_paths[i]).stem for i in filtered_valid_indices], trial_number=best_trial["trial"] + 1, percent_identities=batch_percent_identities, + latent_generator_tokens=structure_token_strs, ) # Save results @@ -1710,8 +2209,14 @@ def _generate_inverse_folding( if plm_fold is not None: logger.info(f"Validating batch {batch_idx + 1} with ESMFold (reusing trial results)...") + # Check if ESMFold was skipped due to sequence length + if best_trial["folded_structure_metrics"].get("_skipped", False): + logger.info( + f"Skipping ESMFold validation (reason: {best_trial['folded_structure_metrics'].get('_reason', 'unknown')})" + ) + batch_metrics = best_trial["folded_structure_metrics"] # Reuse ESMFold results from the best trial - if ( + elif ( best_trial["folded_structure_metrics"] is not None and best_trial["esmfold_pred_coords"] is not None ): @@ -1795,8 +2300,8 @@ def _generate_inverse_folding( max_length=max_length, ) - # Collect metrics for aggregate statistics - if batch_metrics: + # Collect metrics for aggregate statistics (skip if ESMFold was skipped) + if batch_metrics and not batch_metrics.get("_skipped", False): all_plddt_scores.append(batch_metrics["_plddt"]) all_predicted_aligned_errors.append(batch_metrics["_predicted_aligned_error"]) all_tm_scores.append(batch_metrics["_tm_score"]) @@ -1864,6 +2369,36 @@ def _generate_inverse_folding( logger.info("=" * 80) + # Report pocket AAR statistics if available + if all_aar_overall: + import math + + valid_aar_overall = [x for x in all_aar_overall if not math.isnan(x)] + valid_aar_pocket = [x for x in all_aar_pocket if not math.isnan(x)] + valid_aar_nonpocket = [x for x in all_aar_nonpocket if not math.isnan(x)] + valid_n_pocket = [x for x in all_n_pocket_residues if x > 0] + + logger.info("") + logger.info("--- Pocket Amino Acid Recovery (AAR) ---") + if valid_aar_overall: + avg_aar = sum(valid_aar_overall) / len(valid_aar_overall) + logger.info(f" Overall AAR: {avg_aar:.4f} ({avg_aar * 100:.2f}%) (n={len(valid_aar_overall)})") + if valid_aar_pocket: + avg_pocket = sum(valid_aar_pocket) / len(valid_aar_pocket) + logger.info(f" Pocket AAR: {avg_pocket:.4f} ({avg_pocket * 100:.2f}%) (n={len(valid_aar_pocket)})") + if valid_aar_nonpocket: + avg_nonpocket = sum(valid_aar_nonpocket) / len(valid_aar_nonpocket) + logger.info( + f" Non-pocket AAR: {avg_nonpocket:.4f} ({avg_nonpocket * 100:.2f}%) (n={len(valid_aar_nonpocket)})" + ) + if valid_aar_pocket and valid_aar_nonpocket: + delta = avg_pocket - avg_nonpocket + logger.info(f" Delta (pocket - non-pocket): {delta:+.4f} ({delta * 100:+.2f}%)") + if valid_n_pocket: + avg_pocket_size = sum(valid_n_pocket) / len(valid_n_pocket) + logger.info(f" Average pocket size: {avg_pocket_size:.1f} residues") + logger.info("=" * 80) + # Write aggregate statistics to CSV if csv_writer is not None: logger.info("Writing inverse folding aggregate statistics to CSV...") @@ -1877,6 +2412,13 @@ def _generate_inverse_folding( "rmsd": all_rmsd_scores, } + # Add pocket AAR metrics if available + if all_aar_overall: + metric_lists["aar_overall"] = all_aar_overall + metric_lists["aar_pocket"] = all_aar_pocket + metric_lists["aar_nonpocket"] = all_aar_nonpocket + metric_lists["n_pocket_residues"] = [float(x) for x in all_n_pocket_residues] + # Calculate aggregate statistics aggregate_stats = calculate_aggregate_stats(metric_lists) @@ -1933,8 +2475,8 @@ def _generate_forward_folding( structure_paths.extend(glob.glob(str(path / "*.pt"))) else: raise ValueError(f"Input path does not exist: {input_structures}") - elif isinstance(input_structures, (list, tuple)): - # List of paths + elif isinstance(input_structures, (list, tuple, ListConfig)): + # List of paths (includes OmegaConf ListConfig) for path_str in input_structures: path = Path(path_str) if path.is_file(): @@ -2113,8 +2655,12 @@ def _generate_forward_folding( seq = generate_sample["sequence_logits"].argmax(dim=-1) seq[seq > 21] = 20 - # Calculate TM-scores for this trial + # Extract structure tokens (argmax) + structure_tokens = generate_sample["structure_logits"].argmax(dim=-1) # Shape: [batch_size, length] + + # Calculate TM-scores and RMSDs for this trial trial_tm_scores = [] + trial_rmsd_scores = [] for i in range(B): # Get original and generated coordinates orig_coords = coords_res[i, mask[i] == 1, :, :] # Original structure @@ -2125,7 +2671,6 @@ def _generate_forward_folding( sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) # Calculate TM-Score using TM-align - tm_out = tm_align( gen_coords[:, 1, :].cpu().numpy(), # CA atoms of generated structure orig_coords[:, 1, :].detach().cpu().numpy(), # CA atoms of original structure @@ -2133,14 +2678,26 @@ def _generate_forward_folding( sequence_str, ) trial_tm_scores.append(tm_out.tm_norm_chain1) - logger.info(f"TM-Score: {tm_out.tm_norm_chain1:.3f}, RMSD: {tm_out.rmsd:.2f} Å") + + # Calculate RMSD using Kabsch alignment (all backbone atoms) + rmsd = align_and_compute_rmsd( + coords1=gen_coords, + coords2=orig_coords, + mask=None, # Use all positions + return_aligned=False, + device=device, + ) + trial_rmsd_scores.append(rmsd) + logger.info(f"TM-Score: {tm_out.tm_norm_chain1:.3f}, RMSD: {rmsd:.2f} Å") # Store trial results best_trial_results.append( { "trial": trial, "tm_scores": trial_tm_scores, + "rmsd_scores": trial_rmsd_scores, "avg_tm_score": sum(trial_tm_scores) / len(trial_tm_scores), + "avg_rmsd": sum(trial_rmsd_scores) / len(trial_rmsd_scores), "generate_sample": generate_sample, "x_recon_xyz": x_recon_xyz, "seq": seq, @@ -2158,15 +2715,24 @@ def _generate_forward_folding( x_recon_xyz = best_trial["x_recon_xyz"] seq = best_trial["seq"] + # Extract structure tokens from best trial (argmax) + structure_tokens = generate_sample["structure_logits"].argmax(dim=-1) # Shape: [batch_size, length] + # Write sequences to CSV if csv_writer is not None: # Convert generated sequences to strings generated_sequence_strs = [] + structure_token_strs = [] for i in range(B): seq_i = seq[i, mask[i] == 1] sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) generated_sequence_strs.append(sequence_str) + # Convert structure tokens to comma-separated string + tokens_i = structure_tokens[i, mask[i] == 1] + tokens_str = ",".join([str(t.item()) for t in tokens_i]) + structure_token_strs.append(tokens_str) + # Convert original sequences to strings (from input structures) original_sequence_strs = [] for i, data in enumerate(filtered_batch_data): @@ -2183,6 +2749,7 @@ def _generate_forward_folding( run_id=f"forward_folding_batch_{batch_idx:03d}", input_structure=[Path(batch_paths[i]).stem for i in filtered_valid_indices], trial_number=best_trial["trial"] + 1, + latent_generator_tokens=structure_token_strs, ) # Save generated and original structures @@ -2220,18 +2787,27 @@ def _generate_forward_folding( seq_i = seq[i, mask[i] == 1] sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) - # Calculate TM-Score and RMSD using TM-align - + # Calculate TM-Score using TM-align tm_out = tm_align( gen_coords[:, 1, :].cpu().numpy(), # CA atoms of generated structure orig_coords[:, 1, :].detach().cpu().numpy(), # CA atoms of original structure sequence_str, sequence_str, ) - logger.info(f"Sequence: {sequence_str}") - logger.info(f"TM-Score: {tm_out.tm_norm_chain1:.3f}, RMSD: {tm_out.rmsd:.2f} Å") batch_tm_scores.append(tm_out.tm_norm_chain1) - batch_rmsd_scores.append(tm_out.rmsd) + + # Calculate RMSD using Kabsch alignment (all backbone atoms) + rmsd = align_and_compute_rmsd( + coords1=gen_coords, + coords2=orig_coords, + mask=None, # Use all positions + return_aligned=False, + device=device, + ) + batch_rmsd_scores.append(rmsd) + + logger.info(f"Sequence: {sequence_str}") + logger.info(f"TM-Score: {tm_out.tm_norm_chain1:.3f}, RMSD: {rmsd:.2f} Å") # Collect metrics for aggregate statistics all_tm_scores.extend(batch_tm_scores) @@ -2345,8 +2921,8 @@ def _generate_inpainting( structure_paths.extend(glob.glob(str(path / "*.pt"))) else: raise ValueError(f"Input path does not exist: {input_structures}") - elif isinstance(input_structures, (list, tuple)): - # List of paths + elif isinstance(input_structures, (list, tuple, ListConfig)): + # List of paths (includes OmegaConf ListConfig) for path_str in input_structures: path = Path(path_str) if path.is_file(): @@ -2627,6 +3203,9 @@ def _generate_inpainting( seq = generate_sample["sequence_logits"].argmax(dim=-1) seq[seq > 21] = 20 + # Extract structure tokens (argmax) + structure_tokens = generate_sample["structure_logits"].argmax(dim=-1) # Shape: [batch_size, length] + # Calculate TM-scores for this trial trial_tm_scores = [] trial_rmsd_inpainted = [] @@ -2717,6 +3296,13 @@ def _generate_inpainting( chain_group=chain_group, # Specify which chains to predict ) + # Skip if sequence too long for ESMFold + if result is None: + logger.warning( + f"Chain group {chain_group} exceeds ESMFold max length, skipping ESMFold validation" + ) + continue + chain_group_results.append(result) logger.info( @@ -2784,19 +3370,31 @@ def _generate_inpainting( inpainting_mask_struc_i=inpaint_mask_struc_i, ) - # Update coordinates with aligned version - x_recon_xyz[i, mask[i] == 1] = result["gen_coords_aligned"] - - trial_tm_scores.append(result["folded_structure_metrics"]["_tm_score"]) - trial_rmsd_inpainted.append(result["rmsd_inpainted"]) - outputs = result["esmfold_outputs"] - pred_coords = result["pred_coords"] - trial_folded_structure_metrics = result["folded_structure_metrics"] - - logger.info( - f"TM-score: {result['folded_structure_metrics']['_tm_score']:.3f}, " - f"Inpainted RMSD: {result['rmsd_inpainted']:.3f} Å" - ) + # Skip if sequence too long for ESMFold + if result is None: + logger.warning( + f"Structure exceeds ESMFold max length ({len(seq_i)} residues with linkers), " + "skipping ESMFold validation for this batch" + ) + trial_tm_scores.append(float("nan")) + trial_rmsd_inpainted.append(float("nan")) + outputs = None + pred_coords = None + trial_folded_structure_metrics = {"_skipped": True, "_reason": "sequence_too_long"} + else: + # Update coordinates with aligned version + x_recon_xyz[i, mask[i] == 1] = result["gen_coords_aligned"] + + trial_tm_scores.append(result["folded_structure_metrics"]["_tm_score"]) + trial_rmsd_inpainted.append(result["rmsd_inpainted"]) + outputs = result["esmfold_outputs"] + pred_coords = result["pred_coords"] + trial_folded_structure_metrics = result["folded_structure_metrics"] + + logger.info( + f"TM-score: {result['folded_structure_metrics']['_tm_score']:.3f}, " + f"Inpainted RMSD: {result['rmsd_inpainted']:.3f} Å" + ) else: # Calculate TM-Score using TM-align @@ -2892,11 +3490,17 @@ def _generate_inpainting( if csv_writer is not None: # Convert full generated sequences to strings generated_sequence_strs = [] + structure_token_strs = [] for i in range(B): seq_i = seq[i, mask[i] == 1] sequence_str = "".join([restype_order_with_x_inv[j.item()] for j in seq_i]) generated_sequence_strs.append(sequence_str) + # Convert structure tokens to comma-separated string + tokens_i = structure_tokens[i, mask[i] == 1] + tokens_str = ",".join([str(t.item()) for t in tokens_i]) + structure_token_strs.append(tokens_str) + # Convert full original sequences to strings original_sequence_strs = [] for orig_seq in original_sequences: @@ -2986,6 +3590,7 @@ def _generate_inpainting( trial_number=best_trial["trial"] + 1, percent_identities=batch_percent_identities_masked, masked_positions=masked_positions_per_seq, + latent_generator_tokens=structure_token_strs, ) # Save results @@ -3091,8 +3696,8 @@ def _generate_inpainting( max_length=max_length, ) - # Collect metrics for aggregate statistics - if batch_metrics: + # Collect metrics for aggregate statistics (skip if ESMFold was skipped) + if batch_metrics and not batch_metrics.get("_skipped", False): all_plddt_scores.append(batch_metrics["_plddt"]) all_predicted_aligned_errors.append(batch_metrics["_predicted_aligned_error"]) all_tm_scores.append(batch_metrics["_tm_score"]) @@ -3256,9 +3861,374 @@ def _generate_inpainting( logger.debug(f"Correlation plots not applicable: {e}") -def _generate_binders(model, cfg: DictConfig, device: torch.device, output_dir: Path, plm_fold=None) -> None: - """Generate binders.""" - raise NotImplementedError("Binder generation is not implemented") +def _generate_binders( + model, cfg: DictConfig, device: torch.device, output_dir: Path, plm_fold=None, csv_writer=None, plotter=None +) -> None: + """Generate binders for target protein structures.""" + logger.info("Starting binder design generation...") + + # Get input structure paths + input_structures = cfg.generation.input_structures + if not input_structures: + raise ValueError("input_structures must be provided for binder_design mode") + + # Handle different input formats (same as inpainting mode) + structure_paths = [] + if isinstance(input_structures, str): + if "*" in input_structures or "?" in input_structures: + # Glob pattern + structure_paths = glob.glob(input_structures) + else: + # Single file or directory + path = Path(input_structures) + if path.is_file(): + structure_paths = [str(path)] + elif path.is_dir(): + # Find all structure files in directory + structure_paths = list(glob.glob(str(path / "*.pdb"))) + structure_paths.extend(glob.glob(str(path / "*.cif"))) + else: + raise ValueError(f"Input path does not exist: {input_structures}") + elif isinstance(input_structures, (list, tuple, ListConfig)): + # List of paths + for path_str in input_structures: + path = Path(path_str) + if path.is_file(): + structure_paths.append(str(path)) + else: + logger.warning(f"Skipping non-existent file: {path_str}") + else: + raise ValueError(f"Invalid input_structures format: {type(input_structures)}") + + if not structure_paths: + raise ValueError("No valid structure files found in input_structures") + + logger.info(f"Found {len(structure_paths)} structure(s) to process") + + # Get configuration parameters + gen_cfg = cfg.generation + target_chain = gen_cfg.get("target_chain") + binder_length = gen_cfg.get("binder_length") + epitope_indices = gen_cfg.get("epitope_indices", None) + nsteps = gen_cfg.get("nsteps", 200) + n_trials = gen_cfg.get("n_trials", 1) + n_designs_per_structure = gen_cfg.get("n_designs_per_structure", 1) + + if not target_chain: + raise ValueError("target_chain must be specified for binder_design mode") + if not binder_length: + raise ValueError("binder_length must be specified for binder_design mode") + + logger.info(f"Target chain: {target_chain}") + logger.info(f"Binder length: {binder_length}") + if epitope_indices: + logger.info(f"Epitope indices: {epitope_indices}") + logger.info(f"Generation steps: {nsteps}") + logger.info(f"Designs per structure: {n_designs_per_structure}") + + # Initialize transforms + structure_transform = StructureBackboneTransform(max_length=gen_cfg.get("max_length", 512)) + tokenizer_transform = AminoAcidTokenizerTransform(max_length=gen_cfg.get("max_length", 512)) + + # Process each structure + with torch.no_grad(): + for structure_idx, structure_path in enumerate(structure_paths): + logger.info(f"\n{'=' * 70}") + logger.info(f"Processing structure {structure_idx + 1}/{len(structure_paths)}") + logger.info(f"Input: {structure_path}") + logger.info(f"{'=' * 70}") + + # Load target structure + logger.info(f"Loading target structure from {structure_path}") + target_data = load_pdb(structure_path, add_batch_dim=False) + + if target_data is None: + logger.warning(f"Failed to load structure from {structure_path}, skipping") + continue + + # Apply transforms + target_data = structure_transform(target_data) + + # Check minimum length + if target_data["coords_res"].shape[0] < 30: + logger.warning(f"Structure too short ({target_data['coords_res'].shape[0]} residues), skipping") + continue + + # Identify target chain + try: + target_chain_idx, target_start, target_end = get_target_chain_info(target_data, target_chain) + logger.info(f"Target chain '{target_chain}' found:") + logger.info(f" Chain index: {target_chain_idx}") + logger.info(f" Residue range: {target_start}-{target_end}") + logger.info(f" Length: {target_end - target_start} residues") + except ValueError as e: + logger.error(str(e)) + continue + + # Extract only target chain from structure + # Note: StructureBackboneTransform renames 'chains_ids' to 'chains' + chains_key = "chains" if "chains" in target_data else "chains_ids" + target_chain_mask = target_data[chains_key] == target_chain_idx + target_data_filtered = { + "coords_res": target_data["coords_res"][target_chain_mask], + "sequence": target_data["sequence"][target_chain_mask], + chains_key: target_data[chains_key][target_chain_mask], + "real_chains": target_data["real_chains"][target_chain_mask], + "indices": target_data["indices"][target_chain_mask], + "mask": target_data["mask"][target_chain_mask], + } + + # Initialize binder position + if epitope_indices: + logger.info(f"Initializing binder with length {binder_length} near epitope") + logger.info(f" Epitope residue indices: {epitope_indices}") + logger.info(" Ball center: 5Å from epitope, radius: 12Å, min target distance: 5Å") + else: + logger.info(f"Initializing binder with length {binder_length} around target center of mass") + logger.info(" Ball radius: 12Å, min target distance: 5Å") + + binder_data = initialize_binder_at_origin( + binder_length, + device="cpu", + target_coords=target_data_filtered["coords_res"], + epitope_indices=epitope_indices, + ) + + # Get next chain index for binder + binder_chain_idx = get_next_chain_index(target_data_filtered) + logger.info(f"Binder will be assigned chain index: {binder_chain_idx}") + + # Create composite structure (target + binder) + logger.info("Creating composite structure (target + binder)") + + L_target = target_data_filtered["coords_res"].shape[0] + L_binder = binder_data["coords_res"].shape[0] + L_total = L_target + L_binder + + # Check max length + max_length = gen_cfg.get("max_length", 512) + if L_total > max_length: + logger.warning( + f"Total length {L_total} (target: {L_target}, binder: {L_binder}) " + f"exceeds max_length {max_length}. Skipping structure." + ) + continue + + # Concatenate all tensors + coords_res_combined = torch.cat([target_data_filtered["coords_res"], binder_data["coords_res"]], dim=0) + + sequence_combined = torch.cat([target_data_filtered["sequence"], binder_data["sequence"]], dim=0) + + mask_combined = torch.cat([target_data_filtered["mask"], binder_data["mask"]], dim=0) + + # Create chain IDs for binder + binder_chain_ids = torch.full((L_binder,), binder_chain_idx, dtype=target_data_filtered[chains_key].dtype) + chains_ids_combined = torch.cat([target_data_filtered[chains_key], binder_chain_ids], dim=0) + + # Create indices for binder + binder_indices = torch.arange( + binder_chain_idx, binder_chain_idx + L_binder, dtype=target_data_filtered["indices"].dtype + ) + indices_combined = torch.cat([target_data_filtered["indices"], binder_indices], dim=0) + + logger.info("Composite structure created:") + logger.info(f" Total length: {L_total} ({L_target} target + {L_binder} binder)") + logger.info(f" Target chain index: {target_chain_idx}") + logger.info(f" Binder chain index: {binder_chain_idx}") + + # Save initial structure (before generation) + structure_name = Path(structure_path).stem + initial_structure_path = output_dir / f"{structure_name}_initial_structure.pdb" + writepdb(str(initial_structure_path), coords_res_combined, sequence_combined) + logger.info(f"Saved initial structure: {initial_structure_path}") + + # Add batch dimension and move to device + coords_res = coords_res_combined.unsqueeze(0).to(device) + sequence = sequence_combined.unsqueeze(0).to(device) + mask = mask_combined.unsqueeze(0).to(device) + chains_ids = chains_ids_combined.unsqueeze(0).to(device) + indices = indices_combined.unsqueeze(0).to(device) + + # Apply tokenizer to sequence + tokenized_data = tokenizer_transform({"sequence": sequence.squeeze(0).cpu()}) + sequence_tokenized = tokenized_data["sequence"].unsqueeze(0).to(device) + + # Create inpainting masks + # Note: First binder residue is kept fixed to preserve chain break token + logger.info("Creating inpainting masks (target=fixed, first binder token=fixed, rest of binder=generate)") + + mask_sequence, mask_structure = create_binder_inpainting_masks( + chains_ids, target_chain_idx, binder_chain_idx, device + ) + + # Verify masks + num_fixed = (mask_sequence == 0).sum().item() + num_generate = (mask_sequence == 1).sum().item() + logger.info(f" Fixed residues: {num_fixed} (target + 1 binder chain-break token)") + logger.info(f" Generate residues: {num_generate} (binder minus first token)") + + # Generate binder designs + for design_idx in range(n_designs_per_structure): + if n_designs_per_structure > 1: + logger.info(f"\n--- Design {design_idx + 1}/{n_designs_per_structure} ---") + + best_result = None + + for trial in range(n_trials): + if n_trials > 1: + logger.info(f"Trial {trial + 1}/{n_trials}") + + # Generate with inpainting + generate_sample = model.generate_sample( + length=L_total, + num_samples=1, + nsteps=nsteps, + temperature_seq=gen_cfg.get("temperature_seq", 0.5), + temperature_struc=gen_cfg.get("temperature_struc", 1.0), + stochasticity_seq=gen_cfg.get("stochasticity_seq", 20), + stochasticity_struc=gen_cfg.get("stochasticity_struc", 20), + inpainting=True, + input_structure_coords=coords_res, + input_sequence_tokens=sequence_tokenized, + input_mask=mask, + input_indices=indices, + inpainting_mask_sequence=mask_sequence, + inpainting_mask_structure=mask_structure, + asynchronous_sampling=gen_cfg.get("asynchronous_sampling", False), + ) + + # Decode structures + decoded_x = model.decode_structure(generate_sample, mask) + + # Extract coordinates + x_recon_xyz = None + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + x_recon_xyz = decoded_x[decoder_name] + break + + if x_recon_xyz is None: + logger.error("No vit_decoder output found, skipping this trial") + continue + + # Extract coordinates (B, L, 3, 3) - N, CA, C atoms + gen_coords = x_recon_xyz[:, :, [0, 1, 2], :] + + # Extract sequences + if generate_sample["sequence_logits"].shape[-1] == 33: + gen_sequence = convert_lobster_aa_tokenization_to_standard_aa( + generate_sample["sequence_logits"], device=device + ) + else: + gen_sequence = generate_sample["sequence_logits"].argmax(dim=-1) + gen_sequence[gen_sequence > 21] = 20 + + # Store result + result = { + "coords": gen_coords, + "sequence": gen_sequence, + "mask": mask, + "chains_ids": chains_ids, + "indices": indices, + } + + # For now, just keep the first/only trial + best_result = result + if n_trials == 1: + break + + # Save outputs + structure_name = Path(structure_path).stem + prefix = f"{structure_name}_design{design_idx:03d}" + + gen_coords = best_result["coords"] + gen_sequence = best_result["sequence"] + + # Save complete complex + complex_path = output_dir / f"{prefix}_complex.pdb" + writepdb(str(complex_path), gen_coords[0], gen_sequence[0]) + logger.info(f"Saved complex: {complex_path}") + + # Save binder alone + binder_mask = chains_ids[0] == binder_chain_idx + binder_coords = gen_coords[0, binder_mask] + binder_sequence = gen_sequence[0, binder_mask] + binder_path = output_dir / f"{prefix}_binder.pdb" + writepdb(str(binder_path), binder_coords, binder_sequence) + logger.info(f"Saved binder: {binder_path}") + + # Save target alone (for reference) + target_mask = chains_ids[0] == target_chain_idx + target_coords = gen_coords[0, target_mask] + target_sequence = gen_sequence[0, target_mask] + target_path = output_dir / f"{prefix}_target.pdb" + writepdb(str(target_path), target_coords, target_sequence) + logger.info(f"Saved target: {target_path}") + + # Validate with ESMFold if enabled + if gen_cfg.get("use_esmfold", False) and plm_fold is not None: + logger.info("Validating with ESMFold...") + + # Validate the complex (target + binder together) + try: + # Get chain groups for validation + esmfold_chain_groups = gen_cfg.get("esmfold_chain_groups", None) + if esmfold_chain_groups is None: + # Default: validate target + binder together + esmfold_chain_groups = [[target_chain_idx, binder_chain_idx]] + + # Call ESMFold validation + result = predict_structure_with_esmfold( + plm_fold=plm_fold, + seq_i=gen_sequence[0], + chains_i=chains_ids[0], + orig_coords=coords_res[0], # Original composite structure + gen_coords=gen_coords[0], # Generated composite structure + mask_i=mask[0], + cfg=cfg, + device=device, + restype_order_inv=restype_order_with_x_inv, + inpainting_mask_seq_i=mask_sequence[0], + inpainting_mask_struc_i=mask_structure[0], + chain_group=esmfold_chain_groups[0] if esmfold_chain_groups else None, + ) + + # Skip if sequence too long for ESMFold + if result is None: + logger.warning("Structure exceeds ESMFold max length, skipping ESMFold validation") + continue + + logger.info("ESMFold validation metrics:") + if "folded_structure_metrics" in result: + for key, value in result["folded_structure_metrics"].items(): + if isinstance(value, (int, float)): + logger.info(f" {key}: {value:.4f}") + else: + logger.info(f" {key}: {value}") + + # Save ESMFold predicted structure + if "pred_coords" in result and result["pred_coords"] is not None: + pred_coords = result["pred_coords"] + # pred_coords shape is (1, L, 3, 3) or (L, 3, 3) + if pred_coords.dim() == 4: + pred_coords = pred_coords.squeeze(0) # Remove batch dim + + # Save ESMFold predicted complex + esmfold_path = output_dir / f"{prefix}_esmfold.pdb" + writepdb(str(esmfold_path), pred_coords, gen_sequence[0]) + logger.info(f"Saved ESMFold prediction: {esmfold_path}") + + # Save ESMFold predicted binder only + if pred_coords.shape[0] == gen_sequence[0].shape[0]: + esmfold_binder_coords = pred_coords[binder_mask] + esmfold_binder_path = output_dir / f"{prefix}_esmfold_binder.pdb" + writepdb(str(esmfold_binder_path), esmfold_binder_coords, binder_sequence) + logger.info(f"Saved ESMFold binder prediction: {esmfold_binder_path}") + + except Exception as e: + logger.warning(f"ESMFold validation failed: {e}") + + logger.info("\nBinder design generation completed!") def _validate_with_esmfold( diff --git a/src/lobster/data/_collate_structure.py b/src/lobster/data/_collate_structure.py index 7c9f5da2..bbe7ccd4 100644 --- a/src/lobster/data/_collate_structure.py +++ b/src/lobster/data/_collate_structure.py @@ -5,22 +5,92 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: - """Collate fn for batching protein backbone data.""" - if "protein" and "ligand" in batch[0]: - ligand_batch = [bb_dict["ligand"] for bb_dict in batch] - batch = [bb_dict["protein"] for bb_dict in batch] - # make sure batch is not list of None - if batch[0] is not None: - protein_present = True - batch = collate_fn_backbone(batch) + """Collate fn for batching protein backbone data. + + Handles three cases: + 1. Protein-only samples (from StructureDataset): {"coords_res": ..., "mask": ..., ...} + 2. Protein+ligand samples (from LigandDataset): {"protein": {...}, "ligand": {...}} + 3. Ligand-only samples (from LigandDataset with no protein): {"protein": None, "ligand": {...}} + + When mixing different sample types in a batch, creates validity masks to track + which samples have valid protein/ligand data. Missing data is padded with zeros + to maintain consistent batch dimensions. + """ + # Check if any sample has the nested protein/ligand structure + has_nested_structure = any("ligand" in item or "protein" in item for item in batch) + + if has_nested_structure: + # Normalize all samples to the nested structure format + normalized_batch = [] + for item in batch: + if "ligand" in item or "protein" in item: + # Already in nested format + normalized_batch.append(item) + else: + # Protein-only sample from StructureDataset - wrap it + normalized_batch.append({"protein": item, "ligand": None}) + + # Now process the normalized batch + ligand_list = [bb_dict.get("ligand") for bb_dict in normalized_batch] + protein_list = [bb_dict.get("protein") for bb_dict in normalized_batch] + + # Create validity masks for mixed batches + protein_valid_mask = torch.tensor([p is not None for p in protein_list], dtype=torch.bool) + ligand_valid_mask = torch.tensor([l is not None for l in ligand_list], dtype=torch.bool) + + # Get valid samples for determining padding dimensions + valid_proteins = [p for p in protein_list if p is not None] + valid_ligands = [l for l in ligand_list if l is not None] + + # Create padded protein batch + if valid_proteins: + # Get max length from valid proteins + max_protein_len = max(p["coords_res"].shape[0] for p in valid_proteins) + + # Create dummy protein data for samples without protein + dummy_protein = { + "coords_res": torch.zeros(max_protein_len, 3, 3), + "mask": torch.zeros(max_protein_len), + "indices": torch.full((max_protein_len,), -1, dtype=torch.long), + "sequence": torch.zeros(max_protein_len, dtype=torch.long), + "chains": torch.full((max_protein_len,), -1, dtype=torch.long), + } + + # Replace None with dummy data + padded_protein_list = [p if p is not None else dummy_protein for p in protein_list] + protein_batch = collate_fn_backbone(padded_protein_list) else: - protein_present = False - ligand_batch = collate_fn_ligand(ligand_batch) - if protein_present: - # combine batch and ligand_batch - batch = {**batch, **ligand_batch} + protein_batch = {} + + # Create padded ligand batch + if valid_ligands: + # Get max length from valid ligands + max_ligand_len = max(l["atom_coords"].shape[0] for l in valid_ligands) + + # Create dummy ligand data for samples without ligand + dummy_ligand = { + "atom_coords": torch.zeros(max_ligand_len, 3), + "mask": torch.zeros(max_ligand_len), + "atom_indices": torch.full((max_ligand_len,), -1, dtype=torch.long), + } + # Add optional fields if present in valid ligands + if "element_indices" in valid_ligands[0]: + dummy_ligand["element_indices"] = torch.zeros(max_ligand_len, dtype=torch.long) + if "bond_matrix" in valid_ligands[0]: + dummy_ligand["bond_matrix"] = torch.zeros(max_ligand_len, max_ligand_len, dtype=torch.long) + + # Replace None with dummy data + padded_ligand_list = [l if l is not None else dummy_ligand for l in ligand_list] + ligand_batch = collate_fn_ligand(padded_ligand_list) else: - batch = ligand_batch + ligand_batch = {} + + # Combine batches + batch = {**protein_batch, **ligand_batch} + + # Add validity masks + batch["protein_valid_mask"] = protein_valid_mask + batch["ligand_valid_mask"] = ligand_valid_mask return batch max_length = max(bb_dict["coords_res"].shape[0] for bb_dict in batch) @@ -263,6 +333,7 @@ def collate_fn_ligand(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.T padded_ligand_mask = [] padded_ligand_indices = [] padded_element_indices = [] + padded_bond_matrices = [] max_length = max(atom_dict["atom_coords"].shape[0] for atom_dict in batch) for atom_dict in batch: @@ -311,6 +382,15 @@ def collate_fn_ligand(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.T ) ) + # Handle bond_matrix if present (shape: [N_atoms, N_atoms]) + if "bond_matrix" in atom_dict: + bond_matrix = atom_dict["bond_matrix"] + n_atoms = bond_matrix.shape[0] + # Pad bond_matrix to [max_length, max_length] + padded_bond = torch.zeros(max_length, max_length, dtype=bond_matrix.dtype) + padded_bond[:n_atoms, :n_atoms] = bond_matrix + padded_bond_matrices.append(padded_bond) + out = { "ligand_coords": torch.stack(padded_ligand_coords, dim=0), "ligand_mask": torch.stack(padded_ligand_mask, dim=0), @@ -320,6 +400,10 @@ def collate_fn_ligand(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.T if padded_element_indices: out["ligand_element_indices"] = torch.stack(padded_element_indices, dim=0) + # Add bond_matrix to output if present + if padded_bond_matrices: + out["bond_matrix"] = torch.stack(padded_bond_matrices, dim=0) + # Handle additional properties like radius_of_gyration if "radius_of_gyration" in batch[0]: out["radius_of_gyration"] = torch.tensor( diff --git a/src/lobster/data/_coord_structure_datamodule.py b/src/lobster/data/_coord_structure_datamodule.py index a0c04ff3..6c97aa6f 100644 --- a/src/lobster/data/_coord_structure_datamodule.py +++ b/src/lobster/data/_coord_structure_datamodule.py @@ -36,7 +36,7 @@ def __init__( datasets: Sequence[str] = None, *, transform_fn: Iterable[Callable] = None, - ligand_transforms: Iterable[Callable] = None, + ligand_transform_fn: Iterable[Callable] = None, lengths: Sequence[float] | None = (0.9, 0.05, 0.05), generator: Generator | None = None, seed: int = 0xDEADBEEF, @@ -59,7 +59,9 @@ def __init__( files_to_keep_list: list[str] | None = None, use_shards: bool = False, use_ligand_dataset: bool = False, + dataset_types: list[str] | None = None, buffer_size: int = 5, + stat_workers: int | None = None, ) -> None: """:param path_to_datasets: path to data set directories @@ -116,6 +118,14 @@ def __init__( :param is_relative_model: If ``True``, assumes training between two sequences and calls a relative representation data loader + :param use_ligand_dataset: If ``True``, use LigandDataset for all datasets + (default: ``False``). Deprecated - use dataset_types instead for mixed datasets. + + :param dataset_types: List of dataset types corresponding to each path in + path_to_datasets. Valid types: 'structure' (protein-only) or 'ligand' + (ligand-only or protein-ligand pairs). If None, uses use_ligand_dataset + for backwards compatibility (default: ``None``). + """ transforms = transform_fn super().__init__() @@ -150,6 +160,7 @@ def __init__( self._sampler = sampler if self._sampler is not None: self._shuffle = False + self._cluster_file = cluster_file self._cluster_file_list = cluster_file_list self._files_to_keep = files_to_keep @@ -164,6 +175,7 @@ def __init__( self._mlm = mlm self.repeat_count = repeat_count self.testing = testing + self.stat_workers = stat_workers if self.testing and not use_shards: self._path_to_datasets = [ "/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/train_dummy.pt", @@ -175,6 +187,25 @@ def __init__( self.use_shards = use_shards self.buffer_size = buffer_size self.use_ligand_dataset = use_ligand_dataset + + # Handle dataset_types with backwards compatibility + if dataset_types is not None: + if len(dataset_types) != len(self._path_to_datasets): + raise ValueError( + f"Length of dataset_types ({len(dataset_types)}) must match " + f"length of path_to_datasets ({len(self._path_to_datasets)})" + ) + self._dataset_types = dataset_types + logger.info(f"Using per-dataset types: {self._dataset_types}") + else: + # Backwards compatibility: use use_ligand_dataset for all datasets + if use_ligand_dataset: + self._dataset_types = ["ligand"] * len(self._path_to_datasets) + logger.info("Using ligand dataset for all paths (backwards compatibility mode)") + else: + self._dataset_types = ["structure"] * len(self._path_to_datasets) + logger.info("Using structure dataset for all paths (backwards compatibility mode)") + if transforms is None: logger.info("No transform function provided. Using default transform function: StructureBackboneTransform") self._transform_fn = StructureBackboneTransform(max_length=max_length) @@ -183,29 +214,54 @@ def __init__( transforms = list(transforms.values()) self._transform_fn = self.compose_transforms(transforms) - if ligand_transforms is None: + if ligand_transform_fn is None: logger.info( "No ligand transform function provided. Using default transform function: StructureLigandTransform" ) self._ligand_transform_fn = StructureLigandTransform(max_length=max_length) else: logger.info("Using custom ligand transform function.") - ligand_transforms = list(ligand_transforms.values()) - self._ligand_transform_fn = self.compose_transforms(ligand_transforms) + ligand_transform_fn = list(ligand_transform_fn.values()) + self._ligand_transform_fn = self.compose_transforms(ligand_transform_fn) logger.info( f"SequenceLightningDataModule: path_to_datasets={path_to_datasets}, root={root}, lengths={lengths}, seed={seed}, batch_size={batch_size}, max_length={max_length}, shuffle={shuffle}, sampler={sampler}, batch_sampler={batch_sampler}, num_workers={num_workers}, collate_fn={collate_fn}, use_shards={use_shards}" ) def _create_dataset( - self, path: str, is_train: bool = False, cluster_file: str | None = None, files_to_keep: str | None = None - ) -> StructureDataset | ShardedStructureDataset: - """Create either a regular or sharded dataset based on configuration.""" + self, + path: str, + is_train: bool = False, + cluster_file: str | None = None, + files_to_keep: str | None = None, + dataset_type: str | None = None, + ) -> StructureDataset | ShardedStructureDataset | LigandDataset: + """Create either a regular or sharded dataset based on configuration. + + Args: + path: Path to the dataset + is_train: Whether this is a training dataset + cluster_file: Optional cluster file for deduplication + files_to_keep: Optional list of files to keep + dataset_type: Type of dataset ('structure' or 'ligand'). If None, uses use_ligand_dataset for backwards compatibility. + + Returns: + Dataset instance (StructureDataset, LigandDataset, or ShardedStructureDataset) + """ if cluster_file is None: cluster_file = self._cluster_file if files_to_keep is None: files_to_keep = self._files_to_keep - logger.info(f"Creating dataset from {path} with cluster_file {cluster_file} and files_to_keep {files_to_keep}") + + # Backwards compatibility: if dataset_type not provided, use use_ligand_dataset + if dataset_type is None: + dataset_type = "ligand" if self.use_ligand_dataset else "structure" + + logger.info( + f"Creating dataset from {path} with cluster_file {cluster_file}, " + f"files_to_keep {files_to_keep}, dataset_type={dataset_type}" + ) + if self.use_shards: logger.info(f"Creating sharded dataset from {path}") return ShardedStructureDataset( @@ -217,14 +273,15 @@ def _create_dataset( ) else: logger.info(f"Creating regular dataset from {path}") - if self.use_ligand_dataset: + if dataset_type == "ligand": logger.info(f"Creating ligand dataset from {path}") return LigandDataset( root=path, + cluster_file=cluster_file if is_train else None, transform_protein=self._transform_fn, transform_ligand=self._ligand_transform_fn, ) - else: + elif dataset_type == "structure": logger.info(f"Creating structure dataset from {path}") return StructureDataset( root=path, @@ -232,7 +289,10 @@ def _create_dataset( testing=self.testing, cluster_file=cluster_file if is_train else None, files_to_keep=files_to_keep, + stat_workers=self.stat_workers, ) + else: + raise ValueError(f"Unknown dataset_type: {dataset_type}. Must be 'structure' or 'ligand'.") def setup(self, stage: str = "fit") -> None: if stage == "fit": @@ -268,13 +328,13 @@ def setup(self, stage: str = "fit") -> None: # For regular datasets, use ConcatDataset if self._cluster_file_list is not None: self._train_dataset = torch.utils.data.ConcatDataset( - # [self._create_dataset(p, is_train=True) for p in self._path_to_datasets if "train" in p] [ self._create_dataset( self._path_to_datasets[j], is_train=True, cluster_file=self._cluster_file_list[j], files_to_keep=self._files_to_keep_list[j], + dataset_type=self._dataset_types[j], ) for j in range(len(self._path_to_datasets)) if "train" in self._path_to_datasets[j] @@ -282,20 +342,46 @@ def setup(self, stage: str = "fit") -> None: ) else: self._train_dataset = torch.utils.data.ConcatDataset( - [self._create_dataset(p, is_train=True) for p in self._path_to_datasets if "train" in p] + [ + self._create_dataset( + self._path_to_datasets[j], + is_train=True, + dataset_type=self._dataset_types[j], + ) + for j in range(len(self._path_to_datasets)) + if "train" in self._path_to_datasets[j] + ] ) self._val_dataset = torch.utils.data.ConcatDataset( - [self._create_dataset(p) for p in self._path_to_datasets if "val" in p] + [ + self._create_dataset( + self._path_to_datasets[j], + dataset_type=self._dataset_types[j], + ) + for j in range(len(self._path_to_datasets)) + if "val" in self._path_to_datasets[j] + ] ) self._test_dataset = torch.utils.data.ConcatDataset( - [self._create_dataset(p) for p in self._path_to_datasets if "test" in p] + [ + self._create_dataset( + self._path_to_datasets[j], + dataset_type=self._dataset_types[j], + ) + for j in range(len(self._path_to_datasets)) + if "test" in self._path_to_datasets[j] + ] ) else: # iid split logger.info("Using iid splits.") if self.use_shards: # For sharded datasets, use the first path - dataset = self._create_dataset(self._path_to_datasets[0], is_train=True) + dataset = self._create_dataset( + self._path_to_datasets[0], + is_train=True, + dataset_type=self._dataset_types[0], + ) # Calculate split sizes total_size = len(dataset) train_size = int(total_size * self._lengths[0]) @@ -309,7 +395,14 @@ def setup(self, stage: str = "fit") -> None: ) else: # For regular datasets, use ConcatDataset - datasets = [self._create_dataset(p, is_train=True) for p in self._path_to_datasets] + datasets = [ + self._create_dataset( + self._path_to_datasets[j], + is_train=True, + dataset_type=self._dataset_types[j], + ) + for j in range(len(self._path_to_datasets)) + ] dataset = torch.utils.data.ConcatDataset(datasets) ( self._train_dataset, @@ -366,25 +459,13 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - # Only log if we're actually in a validation step - if hasattr(self, "trainer") and self.trainer.state.stage == "validate": - if not self.use_shards and isinstance(self._sampler, (functools.partial, RandomizedMinorityUpsampler)): - group_indices = [] - for dataset in self._val_dataset.datasets: - group_indices.extend(dataset.get_cluster_dict) - if isinstance(self._sampler, functools.partial): - self._sampler = self._sampler(group_indices) - else: - self._sampler = RandomizedMinorityUpsampler(group_indices) - logger.info(f"Val dataloader using RandomizedMinorityUpsampler with {len(group_indices)} clusters") - else: - logger.info("Using standard sampling strategy") - + # Validation uses standard sequential sampling (no custom sampler) + # to ensure reproducible evaluation and compatibility with DDP return DataLoader( self._val_dataset, batch_size=self._batch_size, shuffle=False, - sampler=self._sampler if not self.use_shards else None, + sampler=None, num_workers=self._num_workers, collate_fn=self._collate_fn, pin_memory=self._pin_memory, @@ -392,11 +473,13 @@ def val_dataloader(self) -> DataLoader: ) def test_dataloader(self) -> DataLoader: + # Test uses standard sequential sampling (no custom sampler) + # to ensure reproducible evaluation and compatibility with DDP return DataLoader( self._test_dataset, batch_size=self._batch_size, shuffle=False, - sampler=self._sampler if not self.use_shards else None, + sampler=None, num_workers=self._num_workers, collate_fn=self._collate_fn, pin_memory=self._pin_memory, @@ -404,11 +487,12 @@ def test_dataloader(self) -> DataLoader: ) def predict_dataloader(self) -> DataLoader: + # Predict uses standard sequential sampling (no custom sampler) return DataLoader( self._predict_dataset, batch_size=self._batch_size, shuffle=False, - sampler=self._sampler if not self.use_shards else None, + sampler=None, num_workers=self._num_workers, collate_fn=self._collate_fn, pin_memory=self._pin_memory, diff --git a/src/lobster/datasets/_ligand_dataset.py b/src/lobster/datasets/_ligand_dataset.py index 62af3e09..135e552a 100644 --- a/src/lobster/datasets/_ligand_dataset.py +++ b/src/lobster/datasets/_ligand_dataset.py @@ -18,11 +18,31 @@ class LigandDataset(Dataset): """Dataset class for ligand atom coordinates. Expects .pt files with a 'coords' key for atom coordinates. + + Parameters + ---------- + root : str | os.PathLike + Root directory containing ligand/protein .pt files. + cluster_file : str | os.PathLike, optional + Path to cluster file (.pt) mapping sample IDs to cluster IDs. + If provided, samples are grouped by cluster for balanced sampling. + If None, each sample is treated as its own cluster. + transform_protein : Callable, optional + Transform to apply to protein data. + transform_ligand : Callable, optional + Transform to apply to ligand data. + pre_transform : Callable, optional + Pre-transform to apply. + min_len : int + Minimum length filter (default: 1). + testing : bool + If True, limit dataset size for testing (default: False). """ def __init__( self, root: str | os.PathLike, + cluster_file: str | os.PathLike | None = None, transform_protein: Callable | None = None, transform_ligand: Callable | None = None, pre_transform: Callable | None = None, @@ -34,12 +54,14 @@ def __init__( lobster.ensure_package("torch_geometric", group="struct-gpu (or --extra struct-cpu)") self.root = pathlib.Path(root) + self.cluster_file = cluster_file self.transform_protein = transform_protein self.transform_ligand = transform_ligand self.pre_transform = pre_transform self.min_len = min_len self.testing = testing self._load_data() + self._build_cluster_dict() logger.info("Loaded ligand data points.") super().__init__(root, transform_protein, transform_ligand, pre_transform) @@ -90,20 +112,76 @@ def _load_data(self): def len(self) -> int: return len(self.dataset_filenames) - def __getitem__(self, idx: int): - if isinstance(self.dataset_filenames[idx], tuple): - x_ligand = torch.load(self.dataset_filenames[idx][0]) - x_protein = torch.load(self.dataset_filenames[idx][1]) - if self.transform_protein: - x_protein = self.transform_protein(x_protein) + def _get_sample_id(self, idx: int) -> str: + """Extract sample ID from filename for cluster lookup.""" + filename = self.dataset_filenames[idx] + if isinstance(filename, tuple): + # For protein-ligand pairs, use ligand file's ID + filename = filename[0] + # Extract ID: typically the first part before underscore + return pathlib.Path(filename).stem.split("_")[0] + + def _build_cluster_dict(self): + """Build cluster dictionary from cluster file or default to individual clusters.""" + if self.cluster_file is not None: + # Load cluster file: expects dict mapping sample_id -> cluster_id + cluster_mapping = torch.load(self.cluster_file) + logger.info(f"Loaded cluster file {self.cluster_file} with {len(cluster_mapping)} entries.") + + # Build cluster_dict as list of lists (indices grouped by cluster) + cluster_to_indices = {} + for idx in range(len(self.dataset_filenames)): + sample_id = self._get_sample_id(idx) + cluster_id = cluster_mapping.get(sample_id) + if cluster_id is not None: + if cluster_id not in cluster_to_indices: + cluster_to_indices[cluster_id] = [] + cluster_to_indices[cluster_id].append(idx) + + self.cluster_dict = list(cluster_to_indices.values()) + logger.info(f"Built {len(self.cluster_dict)} clusters from cluster file.") else: - x_protein = None - x_ligand = torch.load(self.dataset_filenames[idx]) - # pick a random 'conformer' in 'conformers' list - if "conformers" in x_ligand: - x_ligand = x_ligand["conformers"][np.random.randint(0, len(x_ligand["conformers"]))] - - if self.transform_ligand: - x_ligand = self.transform_ligand(x_ligand) - - return {"protein": x_protein, "ligand": x_ligand} + # No cluster file: each sample is its own cluster + self.cluster_dict = [[i] for i in range(len(self.dataset_filenames))] + logger.info(f"No cluster file: {len(self.cluster_dict)} samples (each as own cluster).") + + @property + def get_cluster_dict(self): + """Return cluster dict for compatibility with RandomizedMinorityUpsampler.""" + return self.cluster_dict + + def __getitem__(self, idx: int, _retry_count: int = 0): + max_retries = 5 + try: + if isinstance(self.dataset_filenames[idx], tuple): + x_ligand = torch.load(self.dataset_filenames[idx][0]) + x_protein = torch.load(self.dataset_filenames[idx][1]) + if self.transform_protein: + x_protein = self.transform_protein(x_protein) + else: + x_protein = None + x_ligand = torch.load(self.dataset_filenames[idx]) + # pick a random 'conformer' in 'conformers' list + if "conformers" in x_ligand: + x_ligand = x_ligand["conformers"][np.random.randint(0, len(x_ligand["conformers"]))] + + if self.transform_ligand: + x_ligand = self.transform_ligand(x_ligand) + + return {"protein": x_protein, "ligand": x_ligand} + except (EOFError, RuntimeError, Exception) as e: + # Handle corrupted files by trying a different random sample + filename = ( + self.dataset_filenames[idx] + if not isinstance(self.dataset_filenames[idx], tuple) + else self.dataset_filenames[idx][0] + ) + logger.warning(f"Failed to load file {filename}: {e}. Trying another sample.") + if _retry_count < max_retries: + # Pick a random different index + new_idx = np.random.randint(0, len(self.dataset_filenames)) + while new_idx == idx: + new_idx = np.random.randint(0, len(self.dataset_filenames)) + return self.__getitem__(new_idx, _retry_count + 1) + else: + raise RuntimeError(f"Failed to load data after {max_retries} retries. Last error: {e}") from e diff --git a/src/lobster/datasets/_structure_dataset.py b/src/lobster/datasets/_structure_dataset.py index 56ce9393..c4b026d1 100644 --- a/src/lobster/datasets/_structure_dataset.py +++ b/src/lobster/datasets/_structure_dataset.py @@ -6,6 +6,7 @@ import os import pathlib import pickle +import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path @@ -26,6 +27,97 @@ logger = logging.getLogger(__name__) +def _collect_file_paths_recursive(root_path: Path, exclude_patterns: list[str], handle_errors: bool = True): + """ + Recursively collect file paths using os.scandir (fast, single-threaded I/O). + + This only collects paths, deferring stat calls for parallel processing. + + Parameters + ---------- + root_path : Path + Directory to scan + exclude_patterns : list[str] + Patterns to exclude from filenames + handle_errors : bool + If True, log errors and continue. If False, let errors propagate. + + Yields + ------ + str + File path that matches criteria + """ + dir_excludes = {"__pycache__", "cache"} + + try: + with os.scandir(root_path) as entries: + for entry in entries: + try: + # Check if it's a directory (don't follow symlinks) + if entry.is_dir(follow_symlinks=False): + if not entry.name.startswith(".") and entry.name not in dir_excludes: + # Recursively scan subdirectory + yield from _collect_file_paths_recursive(Path(entry.path), exclude_patterns, handle_errors) + continue + + # Check if it's a file + if not entry.is_file(follow_symlinks=False): + continue + + # Filter by extension + if not entry.name.endswith(".pt"): + continue + + # Exclude certain patterns + if any(pattern in entry.name for pattern in exclude_patterns): + continue + + # Just yield the path, defer stat() call for parallel processing + yield entry.path + + except OSError as e: + if handle_errors: + logger.warning(f"Error accessing {entry.path if hasattr(entry, 'path') else root_path}: {e}") + continue + else: + raise + + except OSError as e: + if handle_errors: + logger.warning(f"Error scanning directory {root_path}: {e}") + else: + raise + + +def _get_file_metadata(file_path: str) -> dict | None: + """ + Worker function to get metadata for a single file path. + + This is designed to be called in parallel via ThreadPoolExecutor. + + Parameters + ---------- + file_path : str + Path to the file + + Returns + ------- + dict | None + File metadata or None if stat fails + """ + try: + stat_info = os.stat(file_path) + return { + "path": file_path, + "size_bytes": stat_info.st_size, + "mtime": stat_info.st_mtime, + "stem": Path(file_path).stem, + } + except OSError as e: + logger.warning(f"Could not stat {file_path}: {e}") + return None + + def merge_small_lists(list_of_lists, min_size=100): # Identify lists with less than min_size entries small_lists = [sublist for sublist in list_of_lists if len(sublist) < min_size] @@ -53,11 +145,12 @@ def make_struc_dict(cluster_file, processed_dir): def process_file(file_info): """Process a single file and return relevant information.""" - file_path, files_to_keep, cluster_dict = file_info + file_path, files_to_keep, cluster_dict, file_metadata = file_info - # Quick filter for .pt files - if not file_path.endswith(".pt") or any(x in file_path for x in ["cluster", "filter", "transform"]): - return None, None + # Quick filter for .pt files (skip if metadata provided, as it's pre-filtered) + if file_metadata is None: + if not file_path.endswith(".pt") or any(x in file_path for x in ["cluster", "filter", "transform"]): + return None, None fname = Path(file_path).stem @@ -65,17 +158,25 @@ def process_file(file_info): if files_to_keep is not None and fname not in files_to_keep: return file_path, None - # Check file size - try: - if Path(file_path).stat().st_size == 0: + # Check file size - use metadata if available to avoid stat call + if file_metadata is not None: + if "size_bytes" in file_metadata and file_metadata["size_bytes"] is not None: + if file_metadata["size_bytes"] == 0: + return file_path, None + else: + try: + if Path(file_path).stat().st_size == 0: + return file_path, None + except OSError: return file_path, None - except OSError: - return file_path, None # Get cluster info if needed + # Always return cluster_info tuple (fname, cluster_id) when cluster_dict is provided + # This allows us to distinguish between "file not checked" vs "file not in cluster" cluster_info = None if cluster_dict is not None: - cluster_info = (fname, cluster_dict.get(fname)) + cluster_id = cluster_dict.get(fname) # Can be None if not in cluster + cluster_info = (fname, cluster_id) return file_path, cluster_info @@ -118,6 +219,29 @@ class StructureDataset(Dataset): Whether to use memory mapping for loading large datasets and cluster files. This can significantly reduce memory usage for large datasets by loading data on-demand rather than all at once, by default False. + + cache_file : str | os.PathLike, optional + Path to cache file for storing file listings. If None, auto-generates path + as {processed_dir}/.cache/file_listing_cache.parquet, by default None. + + use_cache : bool, optional + Whether to use file listing cache to speed up initialization, by default True. + + rebuild_cache : bool, optional + Whether to force rebuild of cache file, by default False. + + cache_max_age_hours : float, optional + Maximum age of cache in hours before auto-rebuild. If None, cache never + expires based on age, by default None. + + skip_stat : bool, optional + Whether to skip stat calls during cache building (assumes all files exist + and are non-zero). Dramatically speeds up cache building on slow filesystems. + Files are validated on first access instead, by default False. + + stat_workers : int, optional + Number of workers for parallel stat operations. If None, uses cpu_count() * 4. + Reduce this for network filesystems (try 8-32), by default None. """ def __init__( @@ -132,6 +256,12 @@ def __init__( testing: bool = False, files_to_keep: str | os.PathLike = None, use_mmap: bool = False, + cache_file: str | os.PathLike = None, + use_cache: bool = True, + rebuild_cache: bool = False, + cache_max_age_hours: float = None, + skip_stat: bool = True, + stat_workers: int = None, ): import lobster @@ -157,11 +287,33 @@ def __init__( self.min_len = min_len self.testing = testing self.use_mmap = use_mmap + self.cache_file = cache_file + self.use_cache = use_cache + self.rebuild_cache = rebuild_cache + self.cache_max_age_hours = cache_max_age_hours + self.skip_stat = skip_stat + self.stat_workers = stat_workers logger.info(f"Loading data from {self.root}") self._load_data() logger.info("Loaded data points.") - # breakpoint() - super().__init__(root, transform, pre_transform) + + # For large datasets, skip PyG's expensive __init__ operations + if len(self.dataset_filenames) > 100000: + logger.info( + f"Large dataset detected ({len(self.dataset_filenames)} files), using lightweight initialization" + ) + # Call object.__init__ directly to skip PyG's validation/processing logic + # This bypasses all the expensive PyG checks for huge datasets + object.__init__(self) + self._transform = transform + self._pre_transform = pre_transform + # Set PyG internal attributes that are normally set in Dataset.__init__ + self._indices = None + self.__dict__["root"] = str(root) # Ensure root is set + logger.info("Initialization complete (bypassed PyG overhead)") + else: + # Normal PyG initialization for small datasets + super().__init__(root, transform, pre_transform) @property def raw_dir(self) -> str: @@ -181,6 +333,23 @@ def processed_dir(self): def processed_dir(self, value): self._processed_dir = value + @property + def processed_paths(self): + """Override PyG's processed_paths to return actual file paths from cache.""" + # For large datasets loaded from cache, dataset_filenames already contains full paths + # Don't let PyG construct paths by joining processed_dir + filename + return self.dataset_filenames + + @property + def transform(self): + """Handle transform for both PyG and lightweight init.""" + return getattr(self, "_transform", None) + + @transform.setter + def transform(self, value): + """Handle transform setter for both PyG and lightweight init.""" + self._transform = value + @property def get_cluster_dict(self): return self.cluster_dict @@ -188,6 +357,8 @@ def get_cluster_dict(self): @property def processed_file_names(self) -> list[str]: """Return list of processed files (ending with `.pt`).""" + if len(self.dataset_filenames) > 100000: # Large dataset threshold + return [] # PyG won't try to check files # use both dataset_filenames and identifiers to create processed file names assums .cif or .pdb ending for strucs return [f"{self.dataset_filenames[i]}" for i, f in enumerate(self.dataset_filenames)] @@ -195,19 +366,203 @@ def len(self) -> int: """Return the number of examples in the dataset.""" return len(self.dataset_filenames) + def __len__(self) -> int: + """Return the number of examples in the dataset. Required for PyTorch DataLoader.""" + return len(self.dataset_filenames) + + def indices(self): + """Return indices for the dataset. Required for PyG compatibility.""" + # Handle both PyG and lightweight initialization + _indices = getattr(self, "_indices", None) + if _indices is None: + return range(self.len()) + return _indices + def process(self): # Process datasets into pt files - if self.load_to_disk: - return - for idx, dataset_file in enumerate(self.dataset_filenames): - logger.info(f"Processing {dataset_file}...") - file_exists = os.path.exists(pathlib.Path(self.processed_dir) / self.processed_file_names[idx]) - if not file_exists or self.overwrite: - raise NotImplementedError - else: - logger.info(f"Skipping {dataset_file} as it already exists.") - - logger.info("Finished processing datasets.") + return + + def _get_cache_path(self) -> Path: + """Determine cache file path.""" + if self.cache_file is not None: + return Path(self.cache_file) + + # Auto-generate cache path + cache_dir = Path(self.processed_dir) / ".cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / "file_listing_cache.parquet" + + def _is_cache_valid(self, cache_path: Path) -> bool: + """Check if cache is valid and should be used.""" + # Force rebuild if requested + if self.rebuild_cache: + logger.info("Cache rebuild requested, will rebuild cache.") + return False + + # Check if cache exists + if not cache_path.exists(): + logger.info("Cache file does not exist.") + return False + + # Check age if max_age is set + if self.cache_max_age_hours is not None: + cache_age_hours = (time.time() - cache_path.stat().st_mtime) / 3600 + if cache_age_hours > self.cache_max_age_hours: + logger.info(f"Cache is {cache_age_hours:.1f} hours old (max: {self.cache_max_age_hours}), rebuilding.") + return False + + # Validate cache contents + try: + cache_data = pd.read_parquet(cache_path) + if "metadata" not in cache_data.columns: + logger.warning("Cache file missing metadata, rebuilding.") + return False + + metadata = cache_data["metadata"].iloc[0] + if metadata.get("processed_dir") != str(self.processed_dir): + logger.warning( + f"Cache processed_dir mismatch (cached: {metadata.get('processed_dir')}, current: {self.processed_dir}), rebuilding." + ) + return False + + logger.info(f"Cache is valid with {metadata.get('file_count', 0)} files.") + return True + except Exception as e: + logger.warning(f"Error reading cache file: {e}, rebuilding.") + return False + + def _scan_files_from_disk(self) -> list[dict]: + """ + Scan filesystem and return file metadata using parallel processing. + + Two-stage approach: + 1. Collect file paths (single-threaded, I/O bound) + 2. Get file metadata with parallel stat calls (multi-threaded) OR skip if skip_stat=True + """ + logger.info(f"Scanning files from disk in {self.processed_dir}...") + start_time = time.time() + + exclude_patterns = ["cluster", "filter", "transform"] + + # Stage 1: Collect file paths (fast directory traversal) + logger.info("Stage 1: Discovering file paths...") + stage1_start = time.time() + + # Use Python-based scanning + file_path_generator = _collect_file_paths_recursive( + Path(self.processed_dir), exclude_patterns, handle_errors=True + ) + + # Collect all paths into a list (needed for parallel processing) + file_paths = [] + with tqdm(desc="Discovering paths", unit=" paths", mininterval=0.5) as pbar: + for path in file_path_generator: + file_paths.append(path) + pbar.update(1) + + stage1_duration = time.time() - stage1_start + logger.info(f"Stage 1 complete: Found {len(file_paths)} file paths in {stage1_duration:.2f}s") + + # Stage 2: Get file metadata + if self.skip_stat: + # Fast path: Skip stat calls, assume files exist and are valid + logger.info("Stage 2: Skipping stat calls (skip_stat=True)") + logger.warning("Files will be validated on first access. Invalid files may cause errors later.") + files = [ + { + "path": path, + "size_bytes": None, # Unknown, will be checked on access + "mtime": None, # Unknown + "stem": Path(path).stem, + } + for path in file_paths + ] + stage2_duration = 0.0 + else: + # Normal path: Parallel stat calls to get file metadata + max_workers = self.stat_workers if self.stat_workers is not None else min(mp.cpu_count() * 4, 128) + logger.info(f"Stage 2: Getting file metadata with {max_workers} workers...") + + if max_workers <= 32: + logger.info(f"Using reduced worker count ({max_workers}) - optimized for network filesystems") + + stage2_start = time.time() + + files = [] + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Use chunksize for better performance + chunksize = max(1, len(file_paths) // (max_workers * 10)) + + # Map file paths to metadata + results = executor.map(_get_file_metadata, file_paths, chunksize=chunksize) + + # Collect results with progress bar + with tqdm( + results, total=len(file_paths), desc="Processing metadata", unit=" files", mininterval=0.5 + ) as pbar: + for metadata in pbar: + if metadata is not None: # Skip files that failed stat + files.append(metadata) + + stage2_duration = time.time() - stage2_start + logger.info(f"Stage 2 complete: Processed {len(files)} files in {stage2_duration:.2f}s") + + total_duration = time.time() - start_time + logger.info(f"Total scan time: {total_duration:.2f}s ({len(file_paths) / total_duration:.0f} files/sec)") + return files + + def _save_cache(self, file_data: list[dict], cache_path: Path): + """Save file listing to cache.""" + logger.info(f"Saving cache to {cache_path}...") + start_time = time.time() + + try: + # Create metadata + metadata = { + "created_at": time.time(), + "processed_dir": str(self.processed_dir), + "file_count": len(file_data), + "total_size_bytes": sum(f["size_bytes"] for f in file_data if f["size_bytes"] is not None), + "scan_duration_seconds": time.time() - start_time, + } + + # Convert to DataFrame + df = pd.DataFrame(file_data) + # Add metadata as a column (store as dict in first row) + df["metadata"] = None + df.at[0, "metadata"] = metadata + + # Save to parquet + df.to_parquet(cache_path, engine="pyarrow", compression="snappy") + + duration = time.time() - start_time + logger.info(f"Cache saved with {len(file_data)} files in {duration:.2f}s") + except Exception as e: + logger.warning(f"Failed to save cache: {e}") + + def _load_cache(self, cache_path: Path) -> list[dict]: + """Load file listing from cache.""" + logger.info(f"Loading from cache: {cache_path}") + start_time = time.time() + + try: + df = pd.read_parquet(cache_path, engine="pyarrow") + # Remove metadata column + metadata = df["metadata"].iloc[0] if "metadata" in df.columns else {} + df = df.drop(columns=["metadata"], errors="ignore") + + # Convert to list of dicts + file_data = df.to_dict("records") + + duration = time.time() - start_time + logger.info( + f"Cache loaded: {len(file_data)} files in {duration:.2f}s (original scan took {metadata.get('scan_duration_seconds', 'unknown')}s)" + ) + return file_data + except Exception as e: + logger.error(f"Failed to load cache: {e}") + raise def _load_data(self): """Load the dataset from the processed files.""" @@ -216,20 +571,57 @@ def _load_data(self): self.cluster_dict = torch.load(self.cluster_file) logger.info(f"Loaded cluster file {self.cluster_file} with {len(self.cluster_dict)} clusters.") - # Load files to keep + # Load files to keep (convert to set for O(1) lookup) files_to_keep = None if self.files_to_keep is not None: with open(self.files_to_keep, "rb") as f: - files_to_keep = pickle.load(f) + files_to_keep_list = pickle.load(f) + files_to_keep = set(files_to_keep_list) if isinstance(files_to_keep_list, list) else files_to_keep_list logger.info(f"Using files_to_keep with currently {len(files_to_keep)} files to keep") - # Get all .pt files recursively using glob - all_files = glob.glob(str(Path(self.processed_dir) / "**/*.pt"), recursive=True) + # Get file listings - use cache if enabled and not loading to disk + if not self.load_to_disk and self.use_cache: + cache_path = self._get_cache_path() + + if self._is_cache_valid(cache_path): + # Load from cache + file_data = self._load_cache(cache_path) + else: + # Scan from disk and save to cache + file_data = self._scan_files_from_disk() + self._save_cache(file_data, cache_path) + + # Convert file_data list of dicts to list of paths for compatibility + all_files = [f["path"] for f in file_data] + # Store file metadata for potential future use + self._file_metadata = {f["path"]: f for f in file_data} + elif not self.load_to_disk: + # Cache disabled, use traditional glob method + logger.info("Cache disabled, using glob to find files...") + all_files = glob.glob(str(Path(self.processed_dir) / "**/*.pt"), recursive=True) + else: + # For load_to_disk mode, we'll handle file loading separately + all_files = [] # Prepare arguments for parallel processing - process_args = [ - (f, files_to_keep, self.cluster_dict if self.cluster_file is not None else None) for f in all_files - ] + # Include file metadata if available (from cache) + if hasattr(self, "_file_metadata"): + logger.info(f"Using file metadata from cache with {len(self._file_metadata)} files") + process_args = [ + ( + f, + files_to_keep, + self.cluster_dict if self.cluster_file is not None else None, + self._file_metadata.get(f), + ) + for f in all_files + ] + else: + logger.info("No file metadata from cache, using None") + process_args = [ + (f, files_to_keep, self.cluster_dict if self.cluster_file is not None else None, None) + for f in all_files + ] # Process files in parallel processed_files = [] @@ -238,9 +630,24 @@ def _load_data(self): if not self.load_to_disk: # Use ThreadPoolExecutor for I/O bound operations - with ThreadPoolExecutor(max_workers=min(32, mp.cpu_count() * 2)) as executor: + # Use stat_workers if set, otherwise use conservative defaults + if self.stat_workers is not None: + max_workers = self.stat_workers + else: + # Conservative defaults to avoid OOM on multi-node training + max_workers = min(128, mp.cpu_count() * 4) if len(process_args) > 10000 else min(32, mp.cpu_count() * 2) + logger.info(f"Using {max_workers} workers for parallel processing (available CPUs: {mp.cpu_count()})") + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Use chunksize for better performance with large datasets + chunksize = max(1, len(process_args) // (max_workers * 10)) results = list( - tqdm(executor.map(process_file, process_args), total=len(process_args), desc="Processing files") + tqdm( + executor.map(process_file, process_args, chunksize=chunksize), + total=len(process_args), + desc="Processing files", + mininterval=1.0, # Reduce progress bar update frequency + ) ) # Process results @@ -248,13 +655,17 @@ def _load_data(self): if file_path is None: continue - if cluster_info is None and self.cluster_file is not None: + # Only skip files that were checked against cluster_dict but not found in it + # cluster_info is None means: file was filtered for other reasons (files_to_keep, size, etc) + # cluster_info[1] is None means: file was checked but not in cluster_dict + if cluster_info is not None and cluster_info[1] is None and self.cluster_file is not None: skip_files.append(file_path) continue processed_files.append(file_path) - if self.cluster_file is not None and cluster_info[1] is not None: # If we have cluster info + # Add to cluster if we have cluster info and the file is in a cluster + if self.cluster_file is not None and cluster_info is not None and cluster_info[1] is not None: cluster_id = cluster_info[1] if cluster_id not in cluster_dict: cluster_dict[cluster_id] = [] @@ -311,12 +722,17 @@ def _load_data(self): def __getitem__(self, idx: int) -> tuple: """Return the dataset at the given index.""" if not self.load_to_disk: + # Use dataset_filenames directly (already full paths from cache) + # instead of processed_paths which PyG constructs incorrectly try: - x = torch.load(self.processed_paths[idx]) + file_path = self.dataset_filenames[idx] + x = torch.load(file_path) except Exception as e: - # ic(f"Error loading {self.processed_paths[idx]}: {e}") + logger.error( + f"Error loading {self.dataset_filenames[idx] if idx < len(self.dataset_filenames) else 'unknown'}: {e}" + ) # load the next file if it exists - if idx + 1 < len(self.processed_paths): + if idx + 1 < len(self.dataset_filenames): return self.__getitem__(idx + 1) elif idx - 1 >= 0: return self.__getitem__(idx - 1) @@ -328,6 +744,7 @@ def __getitem__(self, idx: int) -> tuple: if self.use_mmap and hasattr(x, "to"): x = x.to("cpu") + # Handle transform (works with both PyG's and our lightweight init) if self.transform: x = self.transform(x) diff --git a/src/lobster/hydra_config/callbacks/backbone_reconstruction.yaml b/src/lobster/hydra_config/callbacks/backbone_reconstruction.yaml new file mode 100644 index 00000000..3dd06939 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/backbone_reconstruction.yaml @@ -0,0 +1,11 @@ +backbone_reconstruction: + _target_: lobster.model.latent_generator.callbacks.BackboneReconstruction + structure_path: "${paths.output_dir}/structures/" + save_every_n: 10000 + max_total_files: 1000 # Set to a number (e.g., 100) to limit total PDB files saved + use_extended_element_vocab: false + # Ligand minimization options + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only" or "bonds_and_angles" + force_field: "MMFF94" + minimize_steps: 500 diff --git a/src/lobster/hydra_config/callbacks/forward_folding.yaml b/src/lobster/hydra_config/callbacks/forward_folding.yaml new file mode 100644 index 00000000..27b36907 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/forward_folding.yaml @@ -0,0 +1,14 @@ +ForwardFoldingCallback: + _target_: lobster.callbacks._forward_folding_callback.ForwardFoldingCallback + structure_path: ${paths.output_dir}/structures/ + cameo_data_path: "/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed/*.pt" + save_every_n: 1000 + num_samples: 127 + max_length: 512 + nsteps: 200 + temperature_seq: 0.3610371899835548 + temperature_struc: 0.2195534567490864 + stochasticity_seq: 1 + stochasticity_struc: 20 + cache_dir: null + diff --git a/src/lobster/hydra_config/callbacks/gen_ume_full_eval.yaml b/src/lobster/hydra_config/callbacks/gen_ume_full_eval.yaml new file mode 100644 index 00000000..c72b99d2 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/gen_ume_full_eval.yaml @@ -0,0 +1,19 @@ +defaults: + - model_checkpoint.yaml + - lr_monitor.yaml + - progress_bar.yaml + - structure_decode.yaml + - inverse_folding.yaml + - inverse_folding_cameo.yaml + - forward_folding.yaml + - _self_ + + +model_checkpoint: + dirpath: ${paths.output_dir} + filename: "{epoch}-{step}-{val_loss:.4f}" + monitor: val_loss + +early_stopping: + monitor: val_loss + diff --git a/src/lobster/hydra_config/callbacks/gen_ume_fwd.yaml b/src/lobster/hydra_config/callbacks/gen_ume_fwd.yaml new file mode 100644 index 00000000..52e9bf53 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/gen_ume_fwd.yaml @@ -0,0 +1,18 @@ +defaults: + - model_checkpoint.yaml + - lr_monitor.yaml + - progress_bar.yaml + - structure_decode.yaml + - inverse_folding.yaml + - forward_folding.yaml + - _self_ + + +model_checkpoint: + dirpath: ${paths.output_dir} + filename: "{epoch}-{step}-{val_loss:.4f}" + monitor: val_loss + +early_stopping: + monitor: val_loss + diff --git a/src/lobster/hydra_config/callbacks/gen_ume_inv.yaml b/src/lobster/hydra_config/callbacks/gen_ume_inv.yaml index 043014f1..76609960 100644 --- a/src/lobster/hydra_config/callbacks/gen_ume_inv.yaml +++ b/src/lobster/hydra_config/callbacks/gen_ume_inv.yaml @@ -4,7 +4,6 @@ defaults: - progress_bar.yaml - structure_decode.yaml - inverse_folding.yaml - - unconditional_generation.yaml - _self_ diff --git a/src/lobster/hydra_config/callbacks/gen_ume_protein_ligand.yaml b/src/lobster/hydra_config/callbacks/gen_ume_protein_ligand.yaml new file mode 100644 index 00000000..2e1536c8 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/gen_ume_protein_ligand.yaml @@ -0,0 +1,51 @@ +# Callbacks for Gen-UME Protein-Ligand training +# Includes structure decoding, inverse folding, and protein-ligand complex visualization + +defaults: + - model_checkpoint.yaml + - lr_monitor.yaml + - progress_bar.yaml + - structure_decode.yaml + - protein_ligand_decode.yaml + - inverse_folding.yaml + - inverse_folding_cameo.yaml + - protein_ligand_inverse_folding.yaml + - protein_ligand_forward_folding.yaml + - forward_folding.yaml + - _self_ + + +# Best checkpoints by validation loss +model_checkpoint: + dirpath: ${paths.output_dir} + filename: "{epoch}-{step}-{val_loss:.4f}" + monitor: val_loss + save_top_k: 3 + mode: min + verbose: true + save_last: false # Disabled - using last_checkpoint callback instead + +# Keep last N checkpoints regardless of val_loss (crash recovery) +last_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir} + filename: "last-{epoch}-{step}" + monitor: step + mode: max + save_top_k: 3 + verbose: true + every_n_train_steps: 10000 # Save every 1000 steps + +structure_decode: + structure_path: "${paths.output_dir}/structures/" + save_every_n: 10000 + +protein_ligand_decode: + structure_path: "${paths.output_dir}/structures/" + save_every_n: 10000 + save_separate: false + # Ligand minimization options + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only" or "bonds_and_angles" + force_field: "MMFF94" + minimize_steps: 500 diff --git a/src/lobster/hydra_config/callbacks/inverse_folding.yaml b/src/lobster/hydra_config/callbacks/inverse_folding.yaml index 1a11c27b..d3d46d74 100644 --- a/src/lobster/hydra_config/callbacks/inverse_folding.yaml +++ b/src/lobster/hydra_config/callbacks/inverse_folding.yaml @@ -2,6 +2,8 @@ InverseFoldingCallback: _target_: lobster.callbacks._inverse_folding_callback.InverseFoldingCallback structure_path: ${paths.output_dir}/structures/ save_every_n: 1000 + dataset_name: cath + metric_prefix: inverse_folding_cath length: 100 num_samples: 10 - use_plm_fold: true + use_plm_fold: false diff --git a/src/lobster/hydra_config/callbacks/inverse_folding_cameo.yaml b/src/lobster/hydra_config/callbacks/inverse_folding_cameo.yaml new file mode 100644 index 00000000..4e39bd8a --- /dev/null +++ b/src/lobster/hydra_config/callbacks/inverse_folding_cameo.yaml @@ -0,0 +1,11 @@ +InverseFoldingCAMEOCallback: + _target_: lobster.callbacks._inverse_folding_callback.InverseFoldingCallback + structure_path: ${paths.output_dir}/structures/ + save_every_n: 1000 + dataset_name: cameo + dataset_path: "/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed/*.pt" + metric_prefix: inverse_folding_cameo + num_samples: 127 + use_plm_fold: false + max_length: 512 + diff --git a/src/lobster/hydra_config/callbacks/latent_generator_defaults.yaml b/src/lobster/hydra_config/callbacks/latent_generator_defaults.yaml new file mode 100644 index 00000000..1c84359b --- /dev/null +++ b/src/lobster/hydra_config/callbacks/latent_generator_defaults.yaml @@ -0,0 +1,13 @@ +defaults: + - model_checkpoint.yaml + - lr_monitor.yaml + - backbone_reconstruction.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir} + filename: "{epoch}-{step}-{val_loss:.4f}" + monitor: val_loss + +early_stopping: + monitor: val_loss \ No newline at end of file diff --git a/src/lobster/hydra_config/callbacks/protein_ligand_decode.yaml b/src/lobster/hydra_config/callbacks/protein_ligand_decode.yaml new file mode 100644 index 00000000..7af1db54 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/protein_ligand_decode.yaml @@ -0,0 +1,11 @@ +protein_ligand_decode: + _target_: lobster.callbacks._protein_ligand_decode.ProteinLigandDecodeCallback + structure_path: "${paths.output_dir}/structures/" + save_every_n: 1000 + save_separate: true + # Ligand minimization options + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only" or "bonds_and_angles" + force_field: "MMFF94" + minimize_steps: 500 + diff --git a/src/lobster/hydra_config/callbacks/protein_ligand_forward_folding.yaml b/src/lobster/hydra_config/callbacks/protein_ligand_forward_folding.yaml new file mode 100644 index 00000000..8e9f32fd --- /dev/null +++ b/src/lobster/hydra_config/callbacks/protein_ligand_forward_folding.yaml @@ -0,0 +1,35 @@ +# Protein-Ligand Forward Folding Callback +# Evaluates ligand-conditioned vs unconditioned forward folding (sequence → structure) +# +# Key Question: Does providing ligand context improve structure prediction +# for binding pocket residues? +# +# Tracks: +# - tm_score_*: Overall TM-score (structure similarity) +# - rmsd_overall_*: Overall backbone RMSD +# - rmsd_pocket_*: Pocket-only RMSD (residues within threshold of ligand) +# - rmsd_nonpocket_*: Non-pocket RMSD +# - *_delta: Improvement from providing ligand context +# +# Usage in experiment config: +# defaults: +# - override /callbacks: gen_ume_protein_ligand +# callbacks: +# protein_ligand_forward_folding: +# save_every_n: 5000 + +protein_ligand_forward_folding: + _target_: lobster.callbacks.ProteinLigandForwardFoldingCallback + data_dir: /cv/home/lisanzas/lobster/data/posebusters/processed/posebusters_benchmark_no_overlap/ + structure_path: ${paths.output_dir}/protein_ligand_eval/ + save_every_n: 1000 + num_samples: 206 # All PoseBusters benchmark complexes (no overlap with training) + pocket_distance_threshold: 5.0 # Å - residues within this distance of ligand are "pocket" + nsteps: 100 # Diffusion steps for generation + metric_prefix: protein_ligand_forward_folding + # Ligand minimization options (for decoded ligand structures) + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only", "bonds_and_angles", "local", or "full" + force_field: "MMFF94" + minimize_steps: 500 + diff --git a/src/lobster/hydra_config/callbacks/protein_ligand_inverse_folding.yaml b/src/lobster/hydra_config/callbacks/protein_ligand_inverse_folding.yaml new file mode 100644 index 00000000..afbad4e5 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/protein_ligand_inverse_folding.yaml @@ -0,0 +1,33 @@ +# Protein-Ligand Inverse Folding Callback +# Evaluates ligand-conditioned vs unconditioned inverse folding +# +# Key Question: Does providing ligand context improve sequence recovery +# for binding pocket residues? +# +# Tracks: +# - aar_overall_*: Overall amino acid recovery +# - aar_pocket_*: Pocket-only recovery (residues within threshold of ligand) +# - aar_nonpocket_*: Non-pocket recovery +# - *_delta: Improvement from providing ligand context +# +# Usage in experiment config: +# defaults: +# - override /callbacks: gen_ume_protein_ligand +# callbacks: +# protein_ligand_inverse_folding: +# save_every_n: 5000 + +protein_ligand_inverse_folding: + _target_: lobster.callbacks.ProteinLigandInverseFoldingCallback + data_dir: /cv/home/lisanzas/lobster/data/posebusters/processed/posebusters_benchmark_no_overlap/ + structure_path: ${paths.output_dir}/protein_ligand_eval/ + save_every_n: 1000 + num_samples: 206 # All PoseBusters benchmark complexes (no overlap with training) + pocket_distance_threshold: 5.0 # Å - residues within this distance of ligand are "pocket" + nsteps: 100 # Diffusion steps for generation + metric_prefix: protein_ligand_inverse_folding + # Ligand minimization options (for decoded ligand structures) + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only", "bonds_and_angles", "local", or "full" + force_field: "MMFF94" + minimize_steps: 500 diff --git a/src/lobster/hydra_config/callbacks/s3_backup.yaml b/src/lobster/hydra_config/callbacks/s3_backup.yaml new file mode 100644 index 00000000..4454ed42 --- /dev/null +++ b/src/lobster/hydra_config/callbacks/s3_backup.yaml @@ -0,0 +1,20 @@ +# S3 Checkpoint Backup Callback Configuration +# +# Add to your training config to automatically backup checkpoints to S3: +# callbacks: +# defaults: +# - s3_backup +# +# Or override settings: +# +callbacks.s3_backup.upload_every_n_epochs=5 +# +s3_backup: + _target_: lobster.callbacks._s3_checkpoint_callback.S3CheckpointBackupCallback + s3_bucket: "prescient-pcluster-data" + s3_prefix: "gen_ume/checkpoints" + project_name: ${logger.project} # Uses the WandB project name + upload_every_n_epochs: 10 + upload_best_only: false + upload_last: true + dry_run: false + diff --git a/src/lobster/hydra_config/data/structure_afdb.yaml b/src/lobster/hydra_config/data/structure_afdb.yaml index 23168b3f..d65a672f 100644 --- a/src/lobster/hydra_config/data/structure_afdb.yaml +++ b/src/lobster/hydra_config/data/structure_afdb.yaml @@ -3,8 +3,8 @@ defaults: - transform_fn: structure_backbone_transform.yaml # Dataset transforms _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/bucket/lisanza/structures/afdb_rep_v4_processed/ # this doesnt do anything -path_to_datasets: ["/data/bucket/lisanza/structures/afdb_rep_v4_processed/train_shards/", "/data/bucket/lisanza/structures/afdb_rep_v4_processed/val_shards/", "/data/bucket/lisanza/structures/afdb_rep_v4_processed/test_shards/"] +root: /cv/data/ai4dd/data/bucket/lisanza/structures/afdb_rep_v4_processed/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data/bucket/lisanza/structures/afdb_rep_v4_processed/train_shards/", "/cv/data/ai4dd/data/bucket/lisanza/structures/afdb_rep_v4_processed/val_shards/", "/cv/data/ai4dd/data/bucket/lisanza/structures/afdb_rep_v4_processed/test_shards/"] max_length: 512 # Maximum length of the sequence; not used use_shards: true buffer_size: 10 diff --git a/src/lobster/hydra_config/data/structure_afdb_genie.yaml b/src/lobster/hydra_config/data/structure_afdb_genie.yaml index 43ced35a..7f9b972f 100644 --- a/src/lobster/hydra_config/data/structure_afdb_genie.yaml +++ b/src/lobster/hydra_config/data/structure_afdb_genie.yaml @@ -4,8 +4,8 @@ defaults: - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything -path_to_datasets: ["/data2/lisanzas/latent_generator_files/afdb_data/processed_pt/train_afdb_genie2_data.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/afdb_data/processed_pt/train_afdb_genie2_data.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] max_length: 512 # Maximum length of the sequence; not used cluster_file: null files_to_keep: null # this doesnt do anything diff --git a/src/lobster/hydra_config/data/structure_afdb_swissprot.yaml b/src/lobster/hydra_config/data/structure_afdb_swissprot.yaml new file mode 100644 index 00000000..583dbdaf --- /dev/null +++ b/src/lobster/hydra_config/data/structure_afdb_swissprot.yaml @@ -0,0 +1,21 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Dataloader batch sampler + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms + - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data2/lisanzas/AFDB/ # Base directory for AFDB dataset +# Note: Datamodule checks for "train", "val", or "test" in path names +# "valid_cameo_processed" contains "val" so it should be detected as validation set +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/AFDB/train_processed", "/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed", "/cv/data/ai4dd/data2/lisanzas/AFDB/test_multiflow_processed"] +max_length: 512 # Maximum length of the sequence; not used +cluster_file: /cv/data/ai4dd/data2/lisanzas/AFDB/pdb_swissprot_clusters.pt +files_to_keep: null # this doesnt do anything +testing: false +use_shards: false +datasets: afdb_swissprot + +# Dataloader Params +batch_size: 40 # Batch size for the dataloader +num_workers: 12 # Number of workers for the dataloader + diff --git a/src/lobster/hydra_config/data/structure_cath.yaml b/src/lobster/hydra_config/data/structure_cath.yaml index e2a70739..a2e177d7 100644 --- a/src/lobster/hydra_config/data/structure_cath.yaml +++ b/src/lobster/hydra_config/data/structure_cath.yaml @@ -4,8 +4,8 @@ defaults: - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data2/lisanzas/CATH_v4_3/temp/ # this doesnt do anything -path_to_datasets: ["/data2/lisanzas/CATH_v4_3/processed_structures_pt/train/cath_train.pt", "/data2/lisanzas/CATH_v4_3/processed_structures_pt/val/cath_val.pt", "/data2/lisanzas/CATH_v4_3/processed_structures_pt/test/cath_test.pt"] +root: /cv/data/ai4dd/data2/lisanzas/CATH_v4_3/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/CATH_v4_3/processed_structures_pt/train/cath_train.pt", "/cv/data/ai4dd/data2/lisanzas/CATH_v4_3/processed_structures_pt/val/cath_val.pt", "/cv/data/ai4dd/data2/lisanzas/CATH_v4_3/processed_structures_pt/test/cath_test.pt"] max_length: 512 # Maximum length of the sequence; not used cluster_file: null files_to_keep: null # this doesnt do anything diff --git a/src/lobster/hydra_config/data/structure_esm_atlas_afdb_swissprot.yaml b/src/lobster/hydra_config/data/structure_esm_atlas_afdb_swissprot.yaml new file mode 100644 index 00000000..549f9f60 --- /dev/null +++ b/src/lobster/hydra_config/data/structure_esm_atlas_afdb_swissprot.yaml @@ -0,0 +1,19 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Dataloader batch sampler + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms + - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data2/ume/simplefold_dataset/train_processed/ # Root directory (not used when path_to_datasets is specified) +path_to_datasets: ["/cv/data/ai4dd/data2/ume/simplefold_dataset/train_processed","/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed", "/cv/data/ai4dd/data2/lisanzas/AFDB/test_multiflow_processed"] +max_length: 512 # Maximum length of the sequence +cluster_file_list: [null, null, null] # Optional: path to cluster file for deduplication +files_to_keep_list: [null, null, null] # Optional: path to file containing list of files to keep +testing: false +use_shards: false +datasets: esm_atlas_afdb_swissprot + +# Dataloader Params +batch_size: 40 # Batch size for the dataloader +num_workers: 12 # Number of workers for the dataloader + diff --git a/src/lobster/hydra_config/data/structure_ligand.yaml b/src/lobster/hydra_config/data/structure_ligand.yaml index 7ccd6e3e..10c6415f 100644 --- a/src/lobster/hydra_config/data/structure_ligand.yaml +++ b/src/lobster/hydra_config/data/structure_ligand.yaml @@ -3,14 +3,14 @@ defaults: - ligand_transform_fn: structure_ligand_transform.yaml _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything -path_to_datasets: ["/data/bucket/lisanza/structures/pdb_bind/processed_2/"] +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data/bucket/lisanza/structures/GEOM/processed/train/", "/cv/data/ai4dd/data/bucket/lisanza/structures/GEOM/processed/val/", "/cv/data/ai4dd/data/bucket/lisanza/structures/GEOM/processed/test/"] max_length: 512 # Maximum length of the sequence; not used cluster_file: null files_to_keep: null use_ligand_dataset: true testing: false -datasets: ligand_pdb_bind +datasets: ligand_geom use_shards: false # Dataloader Params diff --git a/src/lobster/hydra_config/data/structure_ligand_pdb.yaml b/src/lobster/hydra_config/data/structure_ligand_pdb.yaml new file mode 100644 index 00000000..ca246c28 --- /dev/null +++ b/src/lobster/hydra_config/data/structure_ligand_pdb.yaml @@ -0,0 +1,25 @@ +defaults: + - transform_fn: structure_backbone_transform.yaml # Dataset transforms + - ligand_transform_fn: structure_ligand_transform.yaml + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", "/cv/data/ai4dd/data/bucket/lisanza/structures/GEOM/processed/train/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/train/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/val/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/test/"] +max_length: 512 # Maximum length of the sequence; not used + +# Per-dataset type specification +# 'structure' = protein-only (uses StructureDataset) +# 'ligand' = ligand-only or protein-ligand pairs (uses LigandDataset) +dataset_types: ["structure", "ligand", "ligand", "ligand", "ligand"] # PDB, GEOM, PDBBind train/val/test + +cluster_file_list: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", null, null, null, null] +files_to_keep: null # this doesnt do anything +files_to_keep_list: [null, null, null, null, null] # No filtering for any dataset +use_ligand_dataset: true # Kept for backwards compatibility +testing: false +datasets: ligand_pdb_geom_pbind +use_shards: false + +# Dataloader Params +batch_size: 40 # Batch size for the dataloader +num_workers: 12 # Number of workers for the dataloader diff --git a/src/lobster/hydra_config/data/structure_ligand_pdb_afdb_sair_bond.yaml b/src/lobster/hydra_config/data/structure_ligand_pdb_afdb_sair_bond.yaml new file mode 100644 index 00000000..f863381b --- /dev/null +++ b/src/lobster/hydra_config/data/structure_ligand_pdb_afdb_sair_bond.yaml @@ -0,0 +1,47 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Balanced sampling across datasets + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms (includes AminoAcidTokenizerTransform) + - ligand_transform_fn: structure_ligand_transform.yaml + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # not used + +# Datasets with bond matrices (preprocessed 12/15/25) + AFDB SwissProt: +# - PDB: protein-only, 278k structures (clustered at 40% seq identity) +# - AFDB SwissProt: protein-only, 198k structures from AlphaFold DB +# - GEOM: ligand conformers with bond_matrix, 247k +# - PDBBind: protein-ligand complexes with bond_matrix, 44k train / 5.5k val / 5.5k test +# - SAIR: protein-ligand complexes with bond_matrix, 560k +path_to_datasets: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", # PDB train (278k) + "/cv/data/ai4dd/data2/lisanzas/AFDB/train_processed", # AFDB SwissProt train (198k) + "/cv/data/ai4dd/data2/lisanzas/geom_12_15_25/train/", # GEOM train (247k) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/train/", # PDBBind train (44k) + "/cv/data/ai4dd/data2/lisanzas/sair_12_15_25/train/", # SAIR train (560k) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/val/", # PDBBind val (5.5k) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/" # PDBBind test (5.5k) +] +max_length: 512 + +# Per-dataset type specification +# 'structure' = protein-only (uses StructureDataset) +# 'ligand' = ligand-only or protein-ligand pairs (uses LigandDataset) +dataset_types: ["structure", "structure", "ligand", "ligand", "ligand", "ligand", "ligand"] + +# Cluster files: PDB cluster, AFDB SwissProt cluster, null for ligand datasets +cluster_file_list: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", + "/cv/data/ai4dd/data2/lisanzas/AFDB/pdb_swissprot_clusters.pt", + null, null, null, null, null +] + +files_to_keep: null +files_to_keep_list: [null, null, null, null, null, null, null] +use_ligand_dataset: true +testing: false +datasets: ligand_pdb_afdb_geom_pbind_sair_bond +use_shards: false + +# Dataloader Params +batch_size: 8 +num_workers: 12 diff --git a/src/lobster/hydra_config/data/structure_ligand_pdb_sair.yaml b/src/lobster/hydra_config/data/structure_ligand_pdb_sair.yaml new file mode 100644 index 00000000..d3ddb415 --- /dev/null +++ b/src/lobster/hydra_config/data/structure_ligand_pdb_sair.yaml @@ -0,0 +1,25 @@ +defaults: + - transform_fn: structure_backbone_transform.yaml # Dataset transforms + - ligand_transform_fn: structure_ligand_transform.yaml + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", "/cv/data/ai4dd/data/bucket/lisanza/structures/GEOM/processed/train/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/train/","/cv/data/ai4dd/data2/lisanzas/sair_protein_ligand/train/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/val/", "/cv/data/ai4dd/data2/lisanzas/pdb_bind/test/"] +max_length: 512 # Maximum length of the sequence; not used + +# Per-dataset type specification +# 'structure' = protein-only (uses StructureDataset) +# 'ligand' = ligand-only or protein-ligand pairs (uses LigandDataset) +dataset_types: ["structure", "ligand", "ligand", "ligand", "ligand", "ligand"] # PDB, GEOM, PDBBind train/val/test + +cluster_file_list: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", null, null, null, null, null] +files_to_keep: null # this doesnt do anything +files_to_keep_list: [null, null, null, null, null, null] # No filtering for any dataset +use_ligand_dataset: true # Kept for backwards compatibility +testing: false +datasets: ligand_pdb_geom_pbind_sair +use_shards: false + +# Dataloader Params +batch_size: 40 # Batch size for the dataloader +num_workers: 12 # Number of workers for the dataloader diff --git a/src/lobster/hydra_config/data/structure_ligand_pdb_sair_bond.yaml b/src/lobster/hydra_config/data/structure_ligand_pdb_sair_bond.yaml new file mode 100644 index 00000000..c408c984 --- /dev/null +++ b/src/lobster/hydra_config/data/structure_ligand_pdb_sair_bond.yaml @@ -0,0 +1,46 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Balanced sampling across datasets + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms (includes AminoAcidTokenizerTransform) + - ligand_transform_fn: structure_ligand_transform.yaml + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # not used + +# Datasets with bond matrices (preprocessed 12/15/25): +# - PDB: protein-only (clustered at 40% seq identity) +# - GEOM: ligand conformers with bond_matrix from SDF files +# - PDBBind: protein-ligand pairs with bond_matrix from SDF files +# - SAIR: protein-ligand pairs with bond_matrix from SMILES +path_to_datasets: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", # PDB train + "/cv/data/ai4dd/data2/lisanzas/geom_12_15_25/train/", # GEOM train (with bond_matrix) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/train/", # PDBBind train (with bond_matrix) + "/cv/data/ai4dd/data2/lisanzas/sair_12_15_25/train/", # SAIR train (with bond_matrix) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/val/", # PDBBind val + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/" # PDBBind test +] +max_length: 512 # Maximum length of the sequence; not used + +# Per-dataset type specification +# 'structure' = protein-only (uses StructureDataset) +# 'ligand' = ligand-only or protein-ligand pairs (uses LigandDataset) +dataset_types: ["structure", "ligand", "ligand", "ligand", "ligand", "ligand"] + +# Cluster file for PDB dataset (sequence clustering) +cluster_file_list: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", + null, null, null, null, null +] + +files_to_keep: null +files_to_keep_list: [null, null, null, null, null, null] +use_ligand_dataset: true +testing: false +datasets: ligand_pdb_geom_pbind_sair_bond # New name to distinguish from old config +use_shards: false + +# Dataloader Params +# Note: batch_size reduced from 40 to 8 for memory efficiency with bond prediction +batch_size: 8 +num_workers: 12 + diff --git a/src/lobster/hydra_config/data/structure_ligand_pdb_sair_no_geom.yaml b/src/lobster/hydra_config/data/structure_ligand_pdb_sair_no_geom.yaml new file mode 100644 index 00000000..512feb3e --- /dev/null +++ b/src/lobster/hydra_config/data/structure_ligand_pdb_sair_no_geom.yaml @@ -0,0 +1,47 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Balanced sampling across datasets + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms (includes AminoAcidTokenizerTransform) + - ligand_transform_fn: structure_ligand_transform.yaml + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # not used + +# Datasets WITHOUT GEOM (no ligand-only): +# - PDB: protein-only (clustered at 40% seq identity) +# - PDBBind: protein-ligand pairs with bond_matrix from SDF files +# - SAIR: protein-ligand pairs with bond_matrix from SMILES +# +# Effective proportions (round-robin): +# - Protein-only: ~33% +# - Protein-ligand: ~67% +path_to_datasets: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", # PDB train + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/train/", # PDBBind train (with bond_matrix) + "/cv/data/ai4dd/data2/lisanzas/sair_12_15_25/train/", # SAIR train (with bond_matrix) + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/val/", # PDBBind val + "/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/" # PDBBind test +] +max_length: 512 # Maximum length of the sequence; not used + +# Per-dataset type specification +# 'structure' = protein-only (uses StructureDataset) +# 'ligand' = ligand-only or protein-ligand pairs (uses LigandDataset) +dataset_types: ["structure", "ligand", "ligand", "ligand", "ligand"] + +# Cluster file for PDB dataset (sequence clustering) +cluster_file_list: [ + "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", + null, null, null, null +] + +files_to_keep: null +files_to_keep_list: [null, null, null, null, null] +use_ligand_dataset: true +testing: false +datasets: pdb_pbind_sair_no_geom # New name for this config +use_shards: false + +# Dataloader Params +batch_size: 8 +num_workers: 12 + diff --git a/src/lobster/hydra_config/data/structure_pdb.yaml b/src/lobster/hydra_config/data/structure_pdb.yaml index a261ae6c..00e62005 100644 --- a/src/lobster/hydra_config/data/structure_pdb.yaml +++ b/src/lobster/hydra_config/data/structure_pdb.yaml @@ -4,10 +4,10 @@ defaults: - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything -path_to_datasets: ["/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] max_length: 512 # Maximum length of the sequence; not used -cluster_file: /data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt +cluster_file: "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt" files_to_keep: null # this doesnt do anything testing: false use_shards: false diff --git a/src/lobster/hydra_config/data/structure_pdb_afdb_swissprot.yaml b/src/lobster/hydra_config/data/structure_pdb_afdb_swissprot.yaml new file mode 100644 index 00000000..f939aa8d --- /dev/null +++ b/src/lobster/hydra_config/data/structure_pdb_afdb_swissprot.yaml @@ -0,0 +1,21 @@ +defaults: + - sampler: randomized_minority_upsampler.yaml # Dataloader batch sampler + - transform_fn: structure_backbone_aa_tokenizer_transform.yaml # Dataset transforms + - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader + +_target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule +root: /cv/data/ai4dd/data2/lisanzas/ # Base directory +# Training: PDB train + AFDB SwissProt train, Validation: AFDB SwissProt val, Test: AFDB SwissProt test +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt", "/cv/data/ai4dd/data2/lisanzas/AFDB/train_processed", "/cv/data/ai4dd/data2/lisanzas/AFDB/valid_cameo_processed", "/cv/data/ai4dd/data2/lisanzas/AFDB/test_multiflow_processed"] +max_length: 512 # Maximum length of the sequence; not used +# Cluster files: one for PDB, one for AFDB SwissProt +cluster_file_list: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", "/cv/data/ai4dd/data2/lisanzas/AFDB/pdb_swissprot_clusters.pt"] +files_to_keep: null # this doesnt do anything +files_to_keep_list: [null, null] # No filtering for either dataset +testing: false +use_shards: false +datasets: pdb_afdb_swissprot + +# Dataloader Params +batch_size: 100 # Batch size for the dataloader +num_workers: 12 # Number of workers for the dataloader diff --git a/src/lobster/hydra_config/data/structure_pdb_pinder.yaml b/src/lobster/hydra_config/data/structure_pdb_pinder.yaml index 6c5c72b8..9483de6e 100644 --- a/src/lobster/hydra_config/data/structure_pdb_pinder.yaml +++ b/src/lobster/hydra_config/data/structure_pdb_pinder.yaml @@ -3,12 +3,12 @@ defaults: - transform_fn: structure_backbone_transform.yaml # Dataset transforms _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything -path_to_datasets: ["/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt","/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/train.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/train.pt","/cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/train.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/validation.pt", "/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/split_data/test.pt"] max_length: 512 # Maximum length of the sequence; not used -cluster_file_list: ["/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", "/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/cluster_dict.pt"] +cluster_file_list: ["/cv/data/ai4dd/data2/lisanzas/latent_generator_files/pdb_data/pdb_seqid40_clusters.pt", "/cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/cluster_dict.pt"] files_to_keep: null # this doesnt do anything -files_to_keep_list: [null, "/data/lisanzas/structure_tokenizer/studies/data/pinder/pinder_to_keep.pkl"] +files_to_keep_list: [null, "/cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/pinder_to_keep.pkl"] testing: false use_shards: false datasets: pdb_pinder diff --git a/src/lobster/hydra_config/data/structure_pinder.yaml b/src/lobster/hydra_config/data/structure_pinder.yaml index 085929b6..b02a3b2e 100644 --- a/src/lobster/hydra_config/data/structure_pinder.yaml +++ b/src/lobster/hydra_config/data/structure_pinder.yaml @@ -4,11 +4,11 @@ defaults: - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything -path_to_datasets: ["/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/train.pt", "/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/val.pt", /data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/test.pt"] +root: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/train.pt", "/cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/val.pt", /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder_raw_pdbs_bb_coords/test.pt"] max_length: 512 -cluster_file: /data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/cluster_dict.pt -files_to_keep: /data/lisanzas/structure_tokenizer/studies/data/pinder/pinder_to_keep.pkl +cluster_file: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/processed_atom/pdb/all/cluster_dict.pt +files_to_keep: /cv/data/ai4dd/data/lisanzas/structure_tokenizer/studies/data/pinder/pinder_to_keep.pkl testing: false datasets: pinder use_shards: false diff --git a/src/lobster/hydra_config/data/structure_sabdab.yaml b/src/lobster/hydra_config/data/structure_sabdab.yaml index 88372d11..1bcacfce 100644 --- a/src/lobster/hydra_config/data/structure_sabdab.yaml +++ b/src/lobster/hydra_config/data/structure_sabdab.yaml @@ -4,8 +4,8 @@ defaults: - collate_fn: collate_fn_backbone_binder_target.yaml # Collate function for the dataloader _target_: lobster.data._coord_structure_datamodule.StructureLightningDataModule -root: /data2/lisanzas/sabdab/temp/ # this doesnt do anything -path_to_datasets: ["/data2/lisanzas/sabdab/train_denovo_processed_pt/train_denovo_data.pt", "/data2/lisanzas/sabdab/val_denovo_processed_pt/val_denovo_data.pt", "/data2/lisanzas/sabdab/test_denovo_processed_pt/test_dummy_denovo_data.pt"] +root: /cv/data/ai4dd/data2/lisanzas/sabdab/temp/ # this doesnt do anything +path_to_datasets: ["/cv/data/ai4dd/data2/lisanzas/sabdab/train_denovo_processed_pt/train_denovo_data.pt", "/cv/data/ai4dd/data2/lisanzas/sabdab/val_denovo_processed_pt/val_denovo_data.pt", "/cv/data/ai4dd/data2/lisanzas/sabdab/test_denovo_processed_pt/test_dummy_denovo_data.pt"] max_length: 512 # Maximum length of the sequence; not used cluster_file: null files_to_keep: null # this doesnt do anything diff --git a/src/lobster/hydra_config/experiment/generate_forward_folding.yaml b/src/lobster/hydra_config/experiment/generate_forward_folding.yaml index 9b14ac9d..92a86ea6 100644 --- a/src/lobster/hydra_config/experiment/generate_forward_folding.yaml +++ b/src/lobster/hydra_config/experiment/generate_forward_folding.yaml @@ -33,7 +33,7 @@ generation: # Directory: input_structures: "/path/to/structure/directory/" # Glob pattern: input_structures: "/path/to/structures/*.pdb" # List of files: input_structures: ["/path/to/file1.pdb", "/path/to/file2.pdb"] - input_structures: "test_data/inv_folding/9jl9.pdb" + input_structures: "/cv/data/ai4dd/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" # Maximum length for structures max_length: 512 diff --git a/src/lobster/hydra_config/experiment/generate_inverse_folding.yaml b/src/lobster/hydra_config/experiment/generate_inverse_folding.yaml index b3b78a61..33706384 100644 --- a/src/lobster/hydra_config/experiment/generate_inverse_folding.yaml +++ b/src/lobster/hydra_config/experiment/generate_inverse_folding.yaml @@ -17,7 +17,7 @@ generation: mode: inverse_folding nsteps: 200 # Fewer steps for inverse folding batch_size: 1 # Process 1 structures at a time - n_trials: 3 # Number of trials for best output selection + n_trials: 1 # Number of trials for best output selection #Generation parameters temperature_seq: 0.16423763902324678 @@ -25,14 +25,17 @@ generation: stochasticity_seq: 20 stochasticity_struc: 10 - n_designs_per_structure: 10 # Number of designs to generate per structure + enable_sequence_token_check: true + sequence_token_check_retries: 500 + + n_designs_per_structure: 1 # Number of designs to generate per structure # Input structures - update these paths to your PDB files # Examples of different input formats: # Single file: input_structures: "/path/to/structure.pdb" # Directory: input_structures: "/path/to/pdb/directory/" # Glob pattern: input_structures: "/path/to/structures/*.pdb" # List of files: input_structures: ["/path/to/file1.pdb", "/path/to/file2.pdb"] - input_structures: "test_data/inv_folding/9jl9.pdb" + input_structures: "/cv/data/ai4dd/data2/lisanzas/multi_flow_data/test_set_filtered_pt/*.pt" # Enable ESMFold validation use_esmfold: true diff --git a/src/lobster/hydra_config/experiment/generate_unconditional.yaml b/src/lobster/hydra_config/experiment/generate_unconditional.yaml index 685de064..d94bf8e7 100644 --- a/src/lobster/hydra_config/experiment/generate_unconditional.yaml +++ b/src/lobster/hydra_config/experiment/generate_unconditional.yaml @@ -3,7 +3,7 @@ # Usage: uv run python -m lobster.cmdline.generate --config-path "../hydra_config/experiment" --config-name generate_unconditional # Output directory -output_dir: "./examples/generated_unconditional" +output_dir: "./examples/generated_unconditional_test" # Random seed for reproducibility seed: 12345 @@ -16,7 +16,7 @@ model: # Generation parameters generation: mode: unconditional - length: [500] + length: [100, 200, 300, 400, 500] num_samples: 10 nsteps: 1000 batch_size: 1 @@ -62,7 +62,7 @@ generation: enable_max_percent_identity_threshold: true # Enable maximum percent identity quality control max_percent_identity: 100 # Maximum percent identity - if too high, structure is too unique (%) enable_sequence_token_check: true # Enable check for invalid sequence tokens (mask/unknown amino acids) - max_retries: 30 # Maximum retry attempts per iteration + max_retries: 100 # Maximum retry attempts per iteration # Enable ESMFold validation (required to measure improvement) use_esmfold: true @@ -70,7 +70,7 @@ generation: # Foldseek Diversity Analysis (structural clustering) calculate_foldseek_diversity: true # Enable/disable diversity calculation - foldseek_bin_path: "/homefs/home/lisanzas/scratch/Develop/lobster/src/lobster/metrics/foldseek/bin" + foldseek_bin_path: "/cv/home/lisanzas/from_pcluster/scratch/Develop/lobster/src/lobster/metrics/foldseek/bin" foldseek_tmscore_threshold: 0.5 # TM-score threshold for Foldseek clustering rmsd_threshold_for_diversity: 2.0 # Only cluster structures with RMSD < this threshold diff --git a/src/lobster/hydra_config/experiment/train_latent_generator.yaml b/src/lobster/hydra_config/experiment/train_latent_generator.yaml new file mode 100644 index 00000000..763dcd21 --- /dev/null +++ b/src/lobster/hydra_config/experiment/train_latent_generator.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - override /data: structure_pdb + - override /model: latent_generator + - override /callbacks: latent_generator_defaults.yaml + +compile: false + +logger: + name: latent_generator-tokens_${model.quantizer.n_tokens}-enc_${model.structure_encoder.embed_dim_hidden}-dec_${model.decoder_factory.decoder_mapping.vit_decoder.struc_token_dim}_${data.datasets}_${paths.timestamp} + project: lobster_latent_generator + entity: null + save_dir: ${oc.env:LOBSTER_RUNS_DIR} + + +paths: + timestamp: ${now:%Y-%m-%d}T${now:%H-%M-%S} + output_dir: ${paths.root_dir}/${paths.timestamp} + root_dir: ${oc.env:LOBSTER_RUNS_DIR} + + +trainer: + num_nodes: 1 + max_epochs: -1 + gradient_clip_val: 0.5 + max_time: null + max_steps: -1 + val_check_interval: null + limit_val_batches: 50_000 + precision: bf16-mixed + accumulate_grad_batches: 16 + devices: auto + num_sanity_val_steps: 0 + + diff --git a/src/lobster/hydra_config/generate.yaml b/src/lobster/hydra_config/generate.yaml new file mode 100644 index 00000000..67dc61de --- /dev/null +++ b/src/lobster/hydra_config/generate.yaml @@ -0,0 +1,46 @@ +# @package _global_ +defaults: + - model: gen_ume + - generation: unconditional + - _self_ + +# Output directory for generated structures +output_dir: "./generated_structures" + +# Random seed for reproducibility +seed: 42 + +# Model configuration +model: + _target_: lobster.model.gen_ume.UMESequenceStructureEncoderLightningModule + ckpt_path: null # Path to model checkpoint + mask_token_id: 32 + pad_token_id: 1 + vocab_size: 33 + encoder_kwargs: + max_length: 512 + model_size: mini + +# Generation parameters +generation: + mode: unconditional # "unconditional" or "inverse_folding" + length: 100 # Length of sequences to generate + num_samples: 10 # Number of samples to generate + nsteps: 1000 # Number of generation steps + + # Temperature parameters + temperature_seq: 0.4579796403264936 # Temperature for sequence generation + temperature_struc: 0.35751879409731435 # Temperature for structure generation + + # Stochasticity parameters + stochasticity_seq: 30 # Stochasticity for sequence generation + stochasticity_struc: 70 # Stochasticity for structure generation + + # ESMFold validation + use_esmfold: false # Whether to validate with ESMFold + max_length: 512 # Max length for ESMFold + + # Inverse folding specific + input_structures: null # Path to input structures for inverse folding + +# Note: Datamodule no longer needed - inverse folding now loads PDB files directly diff --git a/src/lobster/hydra_config/model/gen_ume_protein_ligand.yaml b/src/lobster/hydra_config/model/gen_ume_protein_ligand.yaml new file mode 100644 index 00000000..3e364ea5 --- /dev/null +++ b/src/lobster/hydra_config/model/gen_ume_protein_ligand.yaml @@ -0,0 +1,57 @@ +# Model config for Gen-UME Protein-Ligand +# Extends Gen-UME to support protein-only AND protein-ligand tasks +_target_: lobster.model.gen_ume.ProteinLigandEncoderLightningModule + +# Protein sequence tokens +mask_token_id: 32 +pad_token_id: 1 +vocab_size: 33 + +# Ligand atom tokens (from ELEMENT_VOCAB_EXTENDED) +ligand_atom_vocab_size: 25 +ligand_mask_token_id: 1 # MASK token +ligand_pad_token_id: 0 # PAD token + +# Bond types (0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=masked) +num_bond_types: 6 + +# Training params +lr: 1e-4 +beta1: 0.9 +beta2: 0.98 +eps: 1e-12 +weight_decay: 0.01 +scheduler: cosine +scheduler_kwargs: + num_warmup_steps: 5000 + num_training_steps: 100000 +num_warmup_steps: 5000 +num_training_steps: 100000 + +# Loss weights +bond_loss_weight: 1.0 +ligand_atom_loss_weight: 1.0 +ligand_struct_loss_weight: 1.0 + +# SE(3) augmentation (random rotation + translation during training) +use_se3_augmentation: true +se3_translation_scale: 1.0 + +# LatentGenerator for structure encoding/decoding +# Uses FSQ (4375 tokens) for both protein and ligand structures +decode_tokens_during_training: true +latent_generator_model_name: "LG Protein Ligand fsq 4375" + +# Flow matching config +use_masked_prior: true +inverse_folding: false # Set to true for inverse folding training + +# Encoder architecture kwargs (passed to NeoBERT via ProteinLigandEncoderModule) +# Note: vocab sizes are computed dynamically from the LatentGenerator quantizer +ckpt_path: null +encoder_kwargs: + # max_length must accommodate protein (up to 512) + ligand (up to 200 atoms) + # This controls the rotary position embedding precomputation + max_length: 768 + model_size: small # Options: mini, small, medium, large + diff --git a/src/lobster/hydra_config/model/gen_ume_protein_ligand_diffusion.yaml b/src/lobster/hydra_config/model/gen_ume_protein_ligand_diffusion.yaml new file mode 100644 index 00000000..07493499 --- /dev/null +++ b/src/lobster/hydra_config/model/gen_ume_protein_ligand_diffusion.yaml @@ -0,0 +1,68 @@ +# Model config for Gen-UME Protein-Ligand with Diffusion Loss +# Uses continuous structure embeddings instead of discrete tokens +# The DiffusionLoss module (from MAR paper) models the per-token +# probability distribution in continuous space. +_target_: lobster.model.gen_ume.ProteinLigandEncoderLightningModule + +# Protein sequence tokens +mask_token_id: 32 +pad_token_id: 1 +vocab_size: 33 + +# Ligand atom tokens (from ELEMENT_VOCAB_EXTENDED) +ligand_atom_vocab_size: 25 +ligand_mask_token_id: 1 # MASK token +ligand_pad_token_id: 0 # PAD token + +# Bond types (0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=masked) +num_bond_types: 6 + +# Training params +lr: 1e-4 +beta1: 0.9 +beta2: 0.98 +eps: 1e-12 +weight_decay: 0.01 +scheduler: cosine +scheduler_kwargs: + num_warmup_steps: 5000 + num_training_steps: 100000 +num_warmup_steps: 5000 +num_training_steps: 100000 + +# Loss weights +bond_loss_weight: 1.0 +ligand_atom_loss_weight: 1.0 +ligand_struct_loss_weight: 1.0 + +# LatentGenerator for structure encoding/decoding +# Note: For diffusion loss, you need a continuous LatentGenerator (quantizer=null) +# "LG Protein Ligand cont" has 256-dim continuous embeddings +decode_tokens_during_training: true +latent_generator_model_name: "LG Protein Ligand cont" + +# Flow matching config (still used for position/mask selection) +use_masked_prior: true +inverse_folding: false # Set to true for inverse folding training + +# === DIFFUSION LOSS CONFIG (NEW) === +# Option A: Hybrid - discrete flow matching for position selection, +# DiffusionLoss for continuous embedding prediction +use_diffusion_loss_structure: true +diffusion_target_dim: 256 # Must match LatentGenerator structure_encoder.embed_dim +diffusion_z_dim: null # Auto-detect from encoder hidden_size (typically 768) +diffusion_depth: 3 # MLP depth (from MAR: diffloss_d) +diffusion_width: 1024 # MLP width (from MAR: diffloss_w) +diffusion_num_sampling_steps: "100" # Steps for generation +diffusion_noise_schedule: cosine # "linear" or "cosine" +diffusion_loss_weight: 1.0 # Weight for structure diffusion loss + +# Encoder architecture kwargs (passed to NeoBERT via ProteinLigandEncoderModule) +# Note: vocab sizes are computed dynamically from the LatentGenerator quantizer +ckpt_path: null +encoder_kwargs: + # max_length must accommodate protein (up to 512) + ligand (up to 200 atoms) + # This controls the rotary position embedding precomputation + max_length: 768 + model_size: small # Options: mini, small, base, large + diff --git a/src/lobster/hydra_config/model/latent_generator.yaml b/src/lobster/hydra_config/model/latent_generator.yaml new file mode 100644 index 00000000..be246f28 --- /dev/null +++ b/src/lobster/hydra_config/model/latent_generator.yaml @@ -0,0 +1,75 @@ + +_target_: lobster.model.latent_generator.tokenizer.TokenizerMulti + +ckpt_path: null + +structure_encoder: + _target_: lobster.model.latent_generator.structure_encoder.ViTEncoder + data_fixed_size: 512 + n_atoms: 3 + model_n_channel: 20 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + embed_dim: 4 + embed_dim_hidden: 256 + backbone_noise: 0.30 + concat_sine_pw: true + encode_ligand: false + ligand_atom_embedding: false + +quantizer: + _target_: lobster.model.latent_generator.quantizer.SimpleLinearQuantizer + n_tokens: 256 + gumbel: true + tau: 0.5 + embed_dim: 4 + +decoder_factory: + _target_: lobster.model.latent_generator.structure_decoder.DecoderFactory.from_mapping + decoder2loss_dict: + vit_decoder: [l2_loss, pairwise_l2_loss] + decoder_mapping: + vit_decoder: + _target_: lobster.model.latent_generator.structure_decoder.ViTDecoder + data_fixed_size: 512 + n_atoms: 3 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + indexed: false + struc_token_dim: 512 + struc_token_codebook_size: 256 + encode_ligand: false + +loss_factory: + _target_: lobster.model.latent_generator.tokenizer.LossFactory.from_mapping + weight_dict: + pairwise_l2_loss: 1.0 + l2_loss: 0.01 + loss_mapping: + pairwise_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.PairWiseL2Loss + l2_loss: + _target_: lobster.model.latent_generator.tokenizer.L2Loss + +optim: + _target_: torch.optim.Adam + _partial_: true + lr: 1e-4 + weight_decay: 0.0 + +num_warmup_steps: 5000 +num_training_steps: 50000 + +lr_scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: 5000 + num_training_steps: 50000 + + + + diff --git a/src/lobster/hydra_config/model/latent_generator_ligand.yaml b/src/lobster/hydra_config/model/latent_generator_ligand.yaml new file mode 100644 index 00000000..a016d4b3 --- /dev/null +++ b/src/lobster/hydra_config/model/latent_generator_ligand.yaml @@ -0,0 +1,90 @@ + +_target_: lobster.model.latent_generator.tokenizer.TokenizerMulti + +ckpt_path: null + +structure_encoder: + _target_: lobster.model.latent_generator.structure_encoder.ViTEncoder + data_fixed_size: 512 + n_atoms: 3 + model_n_channel: 20 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + embed_dim: 4 + embed_dim_hidden: 256 + backbone_noise: 0.00 #0.30 + concat_sine_pw: true + encode_ligand: true + ligand_atom_embedding: false + +quantizer: + _target_: lobster.model.latent_generator.quantizer.LigandTokenizer + n_tokens: 256 + gumbel: true + tau: 0.5 + embed_dim: 4 + ligand_n_tokens: 512 + ligand_embed_dim: 4 + ligand_softmax: false + ligand_emb_noise: 0.0 + ligand_gumbel: true + ligand_use_gumbel_noise: true + ligand_tau: 0.5 + +decoder_factory: + _target_: lobster.model.latent_generator.structure_decoder.DecoderFactory.from_mapping + decoder2loss_dict: + vit_decoder: [l2_loss, pairwise_l2_loss, ligand_l2_loss, ligand_pairwise_l2_loss] + decoder_mapping: + vit_decoder: + _target_: lobster.model.latent_generator.structure_decoder.ViTDecoder + data_fixed_size: 512 + n_atoms: 3 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + indexed: false + struc_token_dim: 512 + struc_token_codebook_size: 256 + encode_ligand: true + ligand_struc_token_codebook_size: 512 + ligand_struc_token_dim: 512 + +loss_factory: + _target_: lobster.model.latent_generator.tokenizer.LossFactory.from_mapping + weight_dict: + pairwise_l2_loss: 1.0 + l2_loss: 0.01 + ligand_pairwise_l2_loss: 1.0 + ligand_l2_loss: 0.01 + loss_mapping: + pairwise_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.PairWiseL2Loss + l2_loss: + _target_: lobster.model.latent_generator.tokenizer.L2Loss + ligand_pairwise_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.LigandPairWiseL2Loss + ligand_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.LigandL2Loss + +optim: + _target_: torch.optim.Adam + _partial_: true + lr: 1e-4 + weight_decay: 0.0 + +num_warmup_steps: 5000 +num_training_steps: 50000 + +lr_scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: 5000 + num_training_steps: 50000 + + + + diff --git a/src/lobster/hydra_config/model/latent_generator_ligand_fsq.yaml b/src/lobster/hydra_config/model/latent_generator_ligand_fsq.yaml new file mode 100644 index 00000000..a3dfc37d --- /dev/null +++ b/src/lobster/hydra_config/model/latent_generator_ligand_fsq.yaml @@ -0,0 +1,80 @@ + +_target_: lobster.model.latent_generator.tokenizer.TokenizerMulti + +ckpt_path: null + +structure_encoder: + _target_: lobster.model.latent_generator.structure_encoder.ViTEncoder + data_fixed_size: 512 + n_atoms: 3 + model_n_channel: 20 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + embed_dim: 3 # FSQ uses levels to determine output dim (3 for [8,6,5]) + embed_dim_hidden: 256 + backbone_noise: 0.00 + concat_sine_pw: true + encode_ligand: true + ligand_atom_embedding: false + +quantizer: + _target_: lobster.model.latent_generator.quantizer.FSQLigandTokenizer + protein_levels: [8, 6, 5] # 240 tokens for protein + ligand_levels: [8, 6, 5] # 240 tokens for ligand + return_oh_like: true + n_tokens: 240 # For logger interpolation (product of levels: 8*6*5) + +decoder_factory: + _target_: lobster.model.latent_generator.structure_decoder.DecoderFactory.from_mapping + decoder2loss_dict: + vit_decoder: [l2_loss, pairwise_l2_loss, ligand_l2_loss, ligand_pairwise_l2_loss] + decoder_mapping: + vit_decoder: + _target_: lobster.model.latent_generator.structure_decoder.ViTDecoder + data_fixed_size: 512 + n_atoms: 3 + uvit_n_layers: 6 + uvit_n_heads: 8 + uvit_dim_head: 32 + uvit_position_embedding_type: 'rotary' + indexed: false + struc_token_dim: 512 + struc_token_codebook_size: 240 # Product of protein_levels [8*6*5] + encode_ligand: true + ligand_struc_token_codebook_size: 240 # Product of ligand_levels [8*6*5] + ligand_struc_token_dim: 512 + +loss_factory: + _target_: lobster.model.latent_generator.tokenizer.LossFactory.from_mapping + weight_dict: + pairwise_l2_loss: 1.0 + l2_loss: 0.01 + ligand_pairwise_l2_loss: 1.0 + ligand_l2_loss: 0.01 + loss_mapping: + pairwise_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.PairWiseL2Loss + l2_loss: + _target_: lobster.model.latent_generator.tokenizer.L2Loss + ligand_pairwise_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.LigandPairWiseL2Loss + ligand_l2_loss: + _target_: lobster.model.latent_generator.tokenizer.LigandL2Loss + +optim: + _target_: torch.optim.Adam + _partial_: true + lr: 1e-4 + weight_decay: 0.0 + +num_warmup_steps: 5000 +num_training_steps: 50000 + +lr_scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: 5000 + num_training_steps: 50000 + diff --git a/src/lobster/metrics/_alphafold2_scores.py b/src/lobster/metrics/_alphafold2_scores.py index 22fd6507..2729adc0 100644 --- a/src/lobster/metrics/_alphafold2_scores.py +++ b/src/lobster/metrics/_alphafold2_scores.py @@ -98,8 +98,6 @@ def alphafold2_complex_scores( num_recycles=num_recycles, data_dir=alphafold_weights_dir, use_multimer=use_multimer, - use_initial_guess=False, - use_initial_atom_pos=False, ) complex_model.prep_inputs( @@ -201,8 +199,6 @@ def alphafold2_binder_scores( binder_model = mk_afdesign_model( protocol="hallucination", use_templates=False, - initial_guess=False, - use_initial_atom_pos=False, num_recycles=num_recycles, data_dir=alphafold_weights_dir, use_multimer=use_multimer, diff --git a/src/lobster/metrics/_generation_utils.py b/src/lobster/metrics/_generation_utils.py index 58a61d21..aff2dff0 100644 --- a/src/lobster/metrics/_generation_utils.py +++ b/src/lobster/metrics/_generation_utils.py @@ -633,21 +633,25 @@ def predict_structure_with_esmfold( # 3. Add linkers between chains for ESMFold sequence_str, position_ids, linker_mask = add_linker_to_sequence(sequence_str) - # 4. Tokenize the sequence + # 4. Check if sequence exceeds max length (ESMFold limit) + max_length = cfg.generation.get("max_length", 512) + if len(sequence_str) > max_length: + return None # Skip ESMFold for sequences that are too long + + # 5. Tokenize the sequence tokenized_input = plm_fold.tokenizer.encode_plus( sequence_str, padding=True, - truncation=True, - max_length=cfg.generation.get("max_length", 512), + truncation=False, add_special_tokens=False, return_tensors="pt", )["input_ids"].to(device) - # 5. Fold with ESMFold + # 6. Fold with ESMFold with torch.no_grad(): outputs = plm_fold.model(tokenized_input, position_ids=position_ids.unsqueeze(0).to(device)) - # 6. Remove linkers from outputs + # 7. Remove linkers from outputs outputs["positions"] = outputs["positions"][:, :, linker_mask == 1, :, :] outputs["plddt"] = outputs["plddt"][:, linker_mask == 1] outputs["predicted_aligned_error"] = outputs["predicted_aligned_error"][:, linker_mask == 1] @@ -656,12 +660,12 @@ def predict_structure_with_esmfold( sequence_list = list(sequence_str) sequence_str = "".join([seq_char for seq_char, mask_val in zip(sequence_list, linker_mask) if mask_val == 1]) - # 7. Get folded structure metrics (TM-score, etc.) + # 8. Get folded structure metrics (TM-score, etc.) folded_structure_metrics, pred_coords = get_folded_structure_metrics( outputs, orig_coords[None], [sequence_str], mask=mask_i[None], device=device ) - # 8. Prepare return dictionary with common results + # 9. Prepare return dictionary with common results result = { "folded_structure_metrics": folded_structure_metrics, "pred_coords": pred_coords, @@ -672,7 +676,7 @@ def predict_structure_with_esmfold( "num_chains": len(chains_i.unique()), } - # 9. OPTIONAL: Align generated coords to prediction (inpainting mode only) + # 10. OPTIONAL: Align generated coords to prediction (inpainting mode only) if gen_coords is not None: gen_coords_aligned, rmsd_inpainted = align_and_compute_rmsd_inpainted( gen_coords=gen_coords, @@ -1154,26 +1158,41 @@ def _create_percent_identity_correlation_plot(self, df: pd.DataFrame): class MetricsCSVWriter: """Helper class to write metrics to CSV files.""" - def __init__(self, output_dir: Path, mode: str): + def __init__(self, output_dir: Path, mode: str, resume: bool = False): """Initialize CSV writer for a specific generation mode. Args: output_dir: Directory to save CSV files mode: Generation mode (unconditional, inverse_folding, forward_folding, inpainting) + resume: If True, append to existing CSV files instead of creating new ones """ self.output_dir = output_dir self.mode = mode self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - # Create CSV file path for metrics - self.csv_path = output_dir / f"{mode}_metrics_{self.timestamp}.csv" - - # Create CSV file path for sequences - self.sequences_csv_path = output_dir / f"sequences_{mode}_{self.timestamp}.csv" + if resume: + existing_metrics = sorted(Path(output_dir).glob(f"{mode}_metrics_*.csv"), key=lambda x: x.stat().st_mtime) + existing_sequences = sorted( + Path(output_dir).glob(f"sequences_{mode}_*.csv"), key=lambda x: x.stat().st_mtime + ) + if existing_metrics: + self.csv_path = existing_metrics[-1] + logger.info(f"Resume mode: appending to existing metrics CSV: {self.csv_path}") + else: + self.csv_path = output_dir / f"{mode}_metrics_{self.timestamp}.csv" + self._initialize_csv() - # Initialize CSV files with headers - self._initialize_csv() - self._initialize_sequences_csv() + if existing_sequences: + self.sequences_csv_path = existing_sequences[-1] + logger.info(f"Resume mode: appending to existing sequences CSV: {self.sequences_csv_path}") + else: + self.sequences_csv_path = output_dir / f"sequences_{mode}_{self.timestamp}.csv" + self._initialize_sequences_csv() + else: + self.csv_path = output_dir / f"{mode}_metrics_{self.timestamp}.csv" + self.sequences_csv_path = output_dir / f"sequences_{mode}_{self.timestamp}.csv" + self._initialize_csv() + self._initialize_sequences_csv() def _initialize_csv(self): """Initialize CSV file with appropriate headers based on mode.""" @@ -1275,6 +1294,7 @@ def _initialize_sequences_csv(self): "percent_identity_original", "masked_positions", "sequence_type", + "latent_generator_tokens", "timestamp", ] @@ -1508,6 +1528,7 @@ def write_sequences( sequence_type: str | None = None, percent_identities: list[float] | None = None, masked_positions: list[list[int]] | None = None, + latent_generator_tokens: list[str] | None = None, ): """Write generated sequences to CSV for diversity analysis. @@ -1569,6 +1590,11 @@ def write_sequences( if masked_positions and sample_idx < len(masked_positions): masked_pos = ",".join(map(str, masked_positions[sample_idx])) + # Handle latent_generator_tokens + tokens_str = "" + if latent_generator_tokens and sample_idx < len(latent_generator_tokens): + tokens_str = latent_generator_tokens[sample_idx] + writer.writerow( [ run_id, @@ -1587,6 +1613,7 @@ def write_sequences( percent_id, masked_pos, sequence_type or "", + tokens_str, timestamp, ] ) diff --git a/src/lobster/metrics/cal_foldseek_clusters.py b/src/lobster/metrics/cal_foldseek_clusters.py index 9a20c896..bf4b58c1 100644 --- a/src/lobster/metrics/cal_foldseek_clusters.py +++ b/src/lobster/metrics/cal_foldseek_clusters.py @@ -147,6 +147,10 @@ def copy_structures_by_rmsd(output_dir, length, rmsd_threshold=2.0): return None, 0 df_length = df[df["sequence_length"] == length].copy() + if "run_id" in df_length.columns: + df_length = df_length.drop_duplicates(subset="run_id", keep="last") + if "rmsd" in df_length.columns: + df_length = df_length[df_length["rmsd"].notna() & (df_length["rmsd"] != "")] logger.info(f"Found {len(df_length)} total structures for length {length}") if len(df_length) == 0: diff --git a/src/lobster/metrics/evaluate_protein_forward_folding_baseline.py b/src/lobster/metrics/evaluate_protein_forward_folding_baseline.py new file mode 100644 index 00000000..3c3d6184 --- /dev/null +++ b/src/lobster/metrics/evaluate_protein_forward_folding_baseline.py @@ -0,0 +1,560 @@ +#!/usr/bin/env python +"""Standalone baseline evaluation script for protein-only forward folding (structure prediction). + +Evaluates structure prediction (TM-score, RMSD) on proteins using a protein-only Gen-UME model. +This serves as a baseline comparison for protein-ligand models. + +Usage: + # Evaluate a Gen-UME protein-only checkpoint + uv run python -m lobster.metrics.evaluate_protein_forward_folding_baseline \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv + + # With structure saving + uv run python -m lobster.metrics.evaluate_protein_forward_folding_baseline \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --structure_path ./structures/ \ + --save_structures \ + --save_gt_structure +""" + +import argparse +import json +import os +import sys +from glob import glob + +import pandas as pd +import torch +from loguru import logger +from tmtools import tm_align +from torch import Tensor +from tqdm import tqdm + +from lobster.metrics import align_and_compute_rmsd +from lobster.model.latent_generator.io import writepdb +from lobster.model.latent_generator.utils.residue_constants import restype_order_with_x_inv +from lobster.transforms._structure_transforms import AminoAcidTokenizerTransform + + +def load_model(checkpoint_path: str, device: str = "cuda"): + """Load a Gen-UME protein-only model from checkpoint. + + Parameters + ---------- + checkpoint_path : str + Path to the model checkpoint (.ckpt file) + device : str + Device to load model on + + Returns + ------- + model : LightningModule + The loaded model + """ + from lobster.model.gen_ume import UMESequenceStructureEncoderLightningModule + + logger.info(f"Loading protein-only model from {checkpoint_path}") + + # Load checkpoint + model = UMESequenceStructureEncoderLightningModule.load_from_checkpoint( + checkpoint_path, + map_location=device, + strict=False, + ) + model.eval() + model.to(device) + + # Get max_length from encoder config + max_length = None + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + model.max_length = max_length + + logger.info(f"Model loaded successfully. Max length: {max_length}") + return model + + +class ProteinForwardFoldingBaselineEvaluator: + """Evaluates forward folding on proteins using a protein-only model. + + This evaluator serves as a baseline for protein-ligand forward folding evaluation. + It measures structure prediction quality without any ligand context. + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory containing *_protein.pt and *_ligand.pt pairs + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + num_samples : int, optional + Limit number of samples to evaluate (None = all) + nsteps : int + Number of diffusion steps for generation + device : str + Device for computation + max_length : int + Maximum protein sequence length to process (default: 512). + temperature_seq : float + Temperature for sequence sampling + temperature_struc : float + Temperature for structure sampling + save_structures : bool + Whether to save predicted structures as PDB files (default: False). + save_gt_structure : bool + Whether to save ground truth structures as PDB files (default: False). + """ + + def __init__( + self, + data_dir: str, + pocket_distance_threshold: float = 5.0, + num_samples: int | None = None, + nsteps: int = 100, + device: str = "cuda", + max_length: int = 512, + temperature_seq: float = 0.5, + temperature_struc: float = 0.5, + save_structures: bool = False, + save_gt_structure: bool = False, + ): + self.data_dir = data_dir + self.pocket_distance_threshold = pocket_distance_threshold + self.num_samples = num_samples + self.nsteps = nsteps + self.device = device + self.max_length = max_length + self.temperature_seq = temperature_seq + self.temperature_struc = temperature_struc + self.save_structures = save_structures + self.save_gt_structure = save_gt_structure + + # Initialize tokenizer transform for sequence conversion + self.tokenizer_transform = AminoAcidTokenizerTransform(max_length=max_length) + + # Element vocabulary for ligand (used for pocket computation) + self.element_to_idx = { + "PAD": 0, + "MASK": 1, + "UNK": 2, + "C": 3, + "N": 4, + "O": 5, + "S": 6, + "P": 7, + "H": 8, + "F": 9, + "Cl": 10, + "Br": 11, + "I": 12, + "Fe": 13, + "Zn": 14, + "Mg": 15, + "Ca": 16, + "Mn": 17, + "Cu": 18, + "B": 19, + "Si": 20, + "Se": 21, + "Co": 22, + "Ni": 23, + "Bi": 24, + } + + def load_test_set(self) -> list[dict]: + """Load PDBBind test protein-ligand pairs. + + Returns list of dicts with protein and ligand data (ligand used only for pocket definition). + """ + protein_files = sorted(glob(os.path.join(self.data_dir, "*_protein.pt"))) + + if not protein_files: + raise ValueError(f"No protein files found in {self.data_dir}") + + if self.num_samples is not None: + protein_files = protein_files[: self.num_samples] + + logger.info(f"Loading {len(protein_files)} protein-ligand pairs from {self.data_dir}") + + samples = [] + for pf in tqdm(protein_files, desc="Loading samples"): + pdb_id = os.path.basename(pf).replace("_protein.pt", "") + ligand_file = pf.replace("_protein.pt", "_ligand.pt") + + if not os.path.exists(ligand_file): + logger.warning(f"Missing ligand file for {pdb_id}, skipping") + continue + + protein_data = torch.load(pf, weights_only=False, map_location=self.device) + ligand_data = torch.load(ligand_file, weights_only=False, map_location=self.device) + + protein_coords = protein_data.get("coords_res", protein_data.get("coords")) + protein_sequence = protein_data.get("sequence") + + if protein_coords is None or protein_sequence is None: + logger.warning(f"Missing protein data for {pdb_id}, skipping") + continue + + protein_mask = protein_data.get("mask", torch.ones(protein_coords.shape[0], device=self.device)) + protein_indices = protein_data.get("indices", torch.arange(protein_coords.shape[0], device=self.device)) + + # Load ligand coords for pocket computation + ligand_coords = ligand_data.get("atom_coords", ligand_data.get("coords", ligand_data.get("ligand_coords"))) + if ligand_coords is None: + logger.warning(f"Missing ligand coordinates for {pdb_id}, skipping") + continue + + samples.append( + { + "pdb_id": pdb_id, + "protein_coords": protein_coords, + "protein_sequence": protein_sequence, + "protein_mask": protein_mask, + "protein_indices": protein_indices, + "ligand_coords": ligand_coords, # Only for pocket computation + } + ) + + logger.info(f"Loaded {len(samples)} valid samples") + return samples + + def compute_binding_pocket( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + protein_mask: Tensor | None = None, + ) -> Tensor: + """Compute pocket mask based on distance to ligand.""" + if protein_coords.dim() == 3: + ca_coords = protein_coords[:, 1, :] + else: + ca_coords = protein_coords + + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + min_distances = distances.min(dim=1).values + pocket_mask = min_distances < self.pocket_distance_threshold + + if protein_mask is not None: + pocket_mask = pocket_mask & protein_mask.bool() + + return pocket_mask + + def forward_fold(self, model, sample: dict) -> dict: + """Run forward folding on a protein sample. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-only model + sample : dict + Sample dictionary from load_test_set() + + Returns + ------- + dict with: + - predicted_coords: Tensor [L, 3, 3] (N, CA, C backbone) + - structure_tokens: Tensor [L] + """ + protein_mask = sample["protein_mask"].unsqueeze(0).float() + protein_indices = sample["protein_indices"].unsqueeze(0).long() + length = int(protein_mask.sum().item()) + + # Tokenize sequence for forward folding + gt_seq = sample["protein_sequence"] + tokenized_data = self.tokenizer_transform({"sequence": gt_seq.cpu()}) + tokenized_seq = tokenized_data["sequence"].to(self.device).unsqueeze(0) + + # Generate sample (forward folding mode) + with torch.no_grad(): + result = model.generate_sample( + length=length, + num_samples=1, + forward_folding=True, + nsteps=self.nsteps, + temperature_seq=self.temperature_seq, + temperature_struc=self.temperature_struc, + input_sequence_tokens=tokenized_seq, + input_mask=protein_mask, + input_indices=protein_indices, + ) + + # Decode structure + decoded_x = model.decode_structure(result, protein_mask) + + # Extract coordinates + predicted_coords = None + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + vit_output = decoded_x[decoder_name] + if isinstance(vit_output, dict): + predicted_coords = vit_output.get("protein_coords", vit_output.get("coords")) + else: + predicted_coords = vit_output + break + + if predicted_coords is None: + raise RuntimeError("No vit_decoder found in decoded structures") + + structure_tokens = result.get("generated_struc_tokens") + + return { + "predicted_coords": predicted_coords.squeeze(0), + "structure_tokens": structure_tokens.squeeze(0) if structure_tokens is not None else None, + } + + def compute_tm_score( + self, + pred_coords: Tensor, + gt_coords: Tensor, + sequence: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute TM-score between predicted and ground truth structures.""" + if mask is not None: + mask = mask.bool() + pred_coords = pred_coords[mask] + gt_coords = gt_coords[mask] + sequence = sequence[mask] + + if len(pred_coords) == 0: + return float("nan") + + sequence_str = "".join([restype_order_with_x_inv.get(int(s), "X") for s in sequence.cpu().tolist()]) + pred_ca = pred_coords[:, 1, :].detach().cpu().numpy() + gt_ca = gt_coords[:, 1, :].detach().cpu().numpy() + + tm_out = tm_align(pred_ca, gt_ca, sequence_str, sequence_str) + return tm_out.tm_norm_chain1 + + def compute_rmsd( + self, + pred_coords: Tensor, + gt_coords: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute RMSD between predicted and ground truth structures.""" + if mask is not None: + mask = mask.bool() + pred_coords = pred_coords[mask] + gt_coords = gt_coords[mask] + + if len(pred_coords) == 0: + return float("nan") + + rmsd = align_and_compute_rmsd( + coords1=pred_coords.detach(), + coords2=gt_coords.detach(), + mask=None, + return_aligned=False, + device=pred_coords.device, + ) + return float(rmsd) + + def evaluate(self, model, samples: list[dict] | None = None, structure_path: str | None = None) -> dict: + """Run full evaluation on PDBBind test set.""" + model.eval() + model.to(self.device) + + if samples is None: + samples = self.load_test_set() + + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + results = [] + skipped_samples = [] + + for sample in tqdm(samples, desc="Evaluating forward folding (baseline)"): + pdb_id = sample["pdb_id"] + gt_seq = sample["protein_sequence"] + gt_coords = sample["protein_coords"] + protein_mask = sample["protein_mask"] + + protein_length = len(gt_seq) + + if protein_length > self.max_length: + logger.warning( + f"Skipping {pdb_id}: protein length {protein_length} exceeds max_length {self.max_length}" + ) + skipped_samples.append({"pdb_id": pdb_id, "protein_length": protein_length}) + continue + + # Compute binding pocket (using ligand coords) + pocket_mask = self.compute_binding_pocket(gt_coords, sample["ligand_coords"], protein_mask) + non_pocket_mask = protein_mask.bool() & ~pocket_mask + + # Run forward folding + pred_result = self.forward_fold(model, sample) + pred_coords = pred_result["predicted_coords"] + + # Save structures if requested + if structure_path: + if self.save_gt_structure: + gt_pdb_path = os.path.join(structure_path, f"{pdb_id}_gt_protein.pdb") + writepdb(gt_pdb_path, gt_coords, gt_seq) + + if self.save_structures: + pred_pdb_path = os.path.join(structure_path, f"{pdb_id}_pred_baseline.pdb") + writepdb(pred_pdb_path, pred_coords.detach(), gt_seq) + + # Compute metrics + result = { + "pdb_id": pdb_id, + "length": len(gt_seq), + "n_pocket_residues": int(pocket_mask.sum().item()), + "n_nonpocket_residues": int(non_pocket_mask.sum().item()), + "tm_score": self.compute_tm_score(pred_coords, gt_coords, gt_seq, protein_mask), + "rmsd_overall": self.compute_rmsd(pred_coords, gt_coords, protein_mask), + "rmsd_pocket": self.compute_rmsd(pred_coords, gt_coords, pocket_mask), + "rmsd_nonpocket": self.compute_rmsd(pred_coords, gt_coords, non_pocket_mask), + } + results.append(result) + + if skipped_samples: + logger.info(f"Skipped {len(skipped_samples)} samples due to length > {self.max_length}") + + results_df = pd.DataFrame(results) + + if len(results_df) == 0: + logger.warning("No samples were successfully evaluated") + return {"results_df": results_df, "summary": {}} + + summary = { + "mean_tm_score": results_df["tm_score"].mean(), + "std_tm_score": results_df["tm_score"].std(), + "mean_rmsd_overall": results_df["rmsd_overall"].mean(), + "std_rmsd_overall": results_df["rmsd_overall"].std(), + "mean_rmsd_pocket": results_df["rmsd_pocket"].mean(), + "std_rmsd_pocket": results_df["rmsd_pocket"].std(), + "mean_rmsd_nonpocket": results_df["rmsd_nonpocket"].mean(), + "std_rmsd_nonpocket": results_df["rmsd_nonpocket"].std(), + "n_samples": len(results_df), + "mean_pocket_size": results_df["n_pocket_residues"].mean(), + } + + return {"results_df": results_df, "summary": summary} + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate protein-only forward folding (baseline for protein-ligand comparison)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt file)") + parser.add_argument( + "--data_dir", + type=str, + default="/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to PDBBind test directory", + ) + parser.add_argument( + "--output", type=str, default="protein_forward_folding_baseline_results.csv", help="Output CSV file" + ) + parser.add_argument("--output_json", type=str, default=None, help="Output JSON file for summary statistics") + parser.add_argument("--structure_path", type=str, default=None, help="Directory to save structures (PDB)") + parser.add_argument("--pocket_threshold", type=float, default=5.0, help="Distance threshold (Å) for binding pocket") + parser.add_argument("--num_samples", type=int, default=None, help="Number of samples to evaluate") + parser.add_argument("--nsteps", type=int, default=100, help="Number of diffusion steps") + parser.add_argument("--max_length", type=int, default=768, help="Maximum protein sequence length") + parser.add_argument("--temperature_seq", type=float, default=0.5, help="Temperature for sequence sampling") + parser.add_argument("--temperature_struc", type=float, default=0.5, help="Temperature for structure sampling") + parser.add_argument("--save_structures", action="store_true", help="Save predicted structures as PDB files") + parser.add_argument("--save_gt_structure", action="store_true", help="Save ground truth structures as PDB files") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device") + + args = parser.parse_args() + + # Skip existence check for S3 paths (they're handled by the model loader) + if not args.checkpoint.startswith("s3://") and not os.path.exists(args.checkpoint): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + if not os.path.exists(args.data_dir): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Load model + model = load_model(args.checkpoint, args.device) + + max_length = args.max_length + if hasattr(model, "max_length") and model.max_length is not None: + max_length = min(max_length, model.max_length) + logger.info(f"Using max_length: {max_length}") + + # Create evaluator + evaluator = ProteinForwardFoldingBaselineEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=args.num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + temperature_seq=args.temperature_seq, + temperature_struc=args.temperature_struc, + save_structures=args.save_structures, + save_gt_structure=args.save_gt_structure, + ) + + # Load test set + logger.info(f"Loading test set from {args.data_dir}") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Starting evaluation...") + results = evaluator.evaluate(model=model, samples=samples, structure_path=args.structure_path) + + results_df = results["results_df"] + summary = results["summary"] + + results_df.to_csv(args.output, index=False) + logger.info(f"Saved per-structure results to {args.output}") + + # Print summary + print("\n" + "=" * 70) + print("PROTEIN FORWARD FOLDING BASELINE RESULTS") + print("=" * 70) + print(f"\nDataset: {args.data_dir}") + print(f"Checkpoint: {args.checkpoint}") + print(f"Pocket threshold: {args.pocket_threshold} Å") + print(f"Samples evaluated: {summary['n_samples']}") + print(f"Mean pocket size: {summary['mean_pocket_size']:.1f} residues") + + print("\n--- Results (Protein-Only Model, No Ligand Context) ---") + print(f"\n{'Metric':<25} {'Value':<20}") + print("-" * 45) + print(f"{'TM-Score':<25} {summary['mean_tm_score']:.3f} ± {summary['std_tm_score']:.3f}") + print(f"{'RMSD Overall (Å)':<25} {summary['mean_rmsd_overall']:.2f} ± {summary['std_rmsd_overall']:.2f}") + print(f"{'RMSD Pocket (Å)':<25} {summary['mean_rmsd_pocket']:.2f} ± {summary['std_rmsd_pocket']:.2f}") + print(f"{'RMSD Non-pocket (Å)':<25} {summary['mean_rmsd_nonpocket']:.2f} ± {summary['std_rmsd_nonpocket']:.2f}") + print("=" * 70) + + print("\nNote: This is a protein-only baseline. Compare with protein-ligand model") + print(" to see if ligand context improves structure prediction.") + + if args.output_json: + summary_json = {k: float(v) if hasattr(v, "item") else v for k, v in summary.items()} + summary_json["checkpoint"] = args.checkpoint + summary_json["data_dir"] = args.data_dir + summary_json["pocket_threshold"] = args.pocket_threshold + summary_json["nsteps"] = args.nsteps + summary_json["max_length"] = max_length + summary_json["temperature_seq"] = args.temperature_seq + summary_json["temperature_struc"] = args.temperature_struc + summary_json["model_type"] = "protein_only_baseline" + + with open(args.output_json, "w") as f: + json.dump(summary_json, f, indent=2) + logger.info(f"Saved summary to {args.output_json}") + + logger.info("Evaluation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/metrics/evaluate_protein_inverse_folding_baseline.py b/src/lobster/metrics/evaluate_protein_inverse_folding_baseline.py new file mode 100644 index 00000000..23008a64 --- /dev/null +++ b/src/lobster/metrics/evaluate_protein_inverse_folding_baseline.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python +"""Standalone baseline evaluation script for protein-only inverse folding (sequence recovery). + +Evaluates sequence recovery on proteins using a protein-only Gen-UME model. +This serves as a baseline comparison for protein-ligand models. + +Usage: + # Evaluate a Gen-UME protein-only checkpoint + uv run python -m lobster.metrics.evaluate_protein_inverse_folding_baseline \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv + + # With structure decoding + uv run python -m lobster.metrics.evaluate_protein_inverse_folding_baseline \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --structure_path ./structures/ \ + --decode_structure \ + --save_gt_structure +""" + +import argparse +import json +import os +import sys +from glob import glob + +import pandas as pd +import torch +from loguru import logger +from torch import Tensor +from tqdm import tqdm + +from lobster.model.latent_generator.io import writepdb +from lobster.model.latent_generator.utils.residue_constants import ( + convert_lobster_aa_tokenization_to_standard_aa, +) + + +def load_model(checkpoint_path: str, device: str = "cuda"): + """Load a Gen-UME protein-only model from checkpoint. + + Parameters + ---------- + checkpoint_path : str + Path to the model checkpoint (.ckpt file) + device : str + Device to load model on + + Returns + ------- + model : LightningModule + The loaded model + """ + from lobster.model.gen_ume import UMESequenceStructureEncoderLightningModule + + logger.info(f"Loading protein-only model from {checkpoint_path}") + + # Load checkpoint + model = UMESequenceStructureEncoderLightningModule.load_from_checkpoint( + checkpoint_path, + map_location=device, + strict=False, + ) + model.eval() + model.to(device) + + # Get max_length from encoder config + max_length = None + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + model.max_length = max_length + + logger.info(f"Model loaded successfully. Max length: {max_length}") + return model + + +class ProteinInverseFoldingBaselineEvaluator: + """Evaluates inverse folding on proteins using a protein-only model. + + This evaluator serves as a baseline for protein-ligand inverse folding evaluation. + It measures sequence recovery without any ligand context. + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory containing *_protein.pt and *_ligand.pt pairs + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + num_samples : int, optional + Limit number of samples to evaluate (None = all) + nsteps : int + Number of diffusion steps for generation + device : str + Device for computation + max_length : int + Maximum protein sequence length to process (default: 512). + decode_structure : bool + Whether to decode and save predicted structures as PDB files (default: False). + save_gt_structure : bool + Whether to save ground truth structures as PDB files (default: False). + """ + + def __init__( + self, + data_dir: str, + pocket_distance_threshold: float = 5.0, + num_samples: int | None = None, + nsteps: int = 100, + device: str = "cuda", + max_length: int = 512, + decode_structure: bool = False, + save_gt_structure: bool = False, + ): + self.data_dir = data_dir + self.pocket_distance_threshold = pocket_distance_threshold + self.num_samples = num_samples + self.nsteps = nsteps + self.device = device + self.max_length = max_length + self.decode_structure = decode_structure + self.save_gt_structure = save_gt_structure + + # Standard amino acid mapping (alphabetical order) + self.standard_aa_map = { + 0: "A", + 1: "R", + 2: "N", + 3: "D", + 4: "C", + 5: "Q", + 6: "E", + 7: "G", + 8: "H", + 9: "I", + 10: "L", + 11: "K", + 12: "M", + 13: "F", + 14: "P", + 15: "S", + 16: "T", + 17: "W", + 18: "Y", + 19: "V", + 20: "X", + } + + # Lobster amino acid mapping (for 21-token vocab model outputs) + self.lobster_aa_map = { + 0: "L", + 1: "A", + 2: "G", + 3: "V", + 4: "S", + 5: "E", + 6: "R", + 7: "T", + 8: "I", + 9: "D", + 10: "P", + 11: "K", + 12: "Q", + 13: "F", + 14: "N", + 15: "Y", + 16: "M", + 17: "H", + 18: "W", + 19: "C", + 20: "X", + } + + # Mapping from lobster tokenization to standard (alphabetical) tokenization + self.lobster_to_standard = torch.tensor( + [ + 10, + 0, + 7, + 19, + 15, + 6, + 1, + 16, + 9, + 3, + 14, + 11, + 5, + 13, + 2, + 18, + 12, + 8, + 17, + 4, + 20, + ], + dtype=torch.long, + device=device, + ) + + # Element vocabulary for pocket computation + self.element_to_idx = { + "PAD": 0, + "MASK": 1, + "UNK": 2, + "C": 3, + "N": 4, + "O": 5, + "S": 6, + "P": 7, + "H": 8, + "F": 9, + "Cl": 10, + "Br": 11, + "I": 12, + "Fe": 13, + "Zn": 14, + "Mg": 15, + "Ca": 16, + "Mn": 17, + "Cu": 18, + "B": 19, + "Si": 20, + "Se": 21, + "Co": 22, + "Ni": 23, + "Bi": 24, + } + + def load_test_set(self) -> list[dict]: + """Load PDBBind test protein-ligand pairs. + + Returns list of dicts with protein and ligand data (ligand used only for pocket definition). + """ + protein_files = sorted(glob(os.path.join(self.data_dir, "*_protein.pt"))) + + if not protein_files: + raise ValueError(f"No protein files found in {self.data_dir}") + + if self.num_samples is not None: + protein_files = protein_files[: self.num_samples] + + logger.info(f"Loading {len(protein_files)} protein-ligand pairs from {self.data_dir}") + + samples = [] + for pf in tqdm(protein_files, desc="Loading samples"): + pdb_id = os.path.basename(pf).replace("_protein.pt", "") + ligand_file = pf.replace("_protein.pt", "_ligand.pt") + + if not os.path.exists(ligand_file): + logger.warning(f"Missing ligand file for {pdb_id}, skipping") + continue + + protein_data = torch.load(pf, weights_only=False, map_location=self.device) + ligand_data = torch.load(ligand_file, weights_only=False, map_location=self.device) + + protein_coords = protein_data.get("coords_res", protein_data.get("coords")) + protein_sequence = protein_data.get("sequence") + + if protein_coords is None or protein_sequence is None: + logger.warning(f"Missing protein data for {pdb_id}, skipping") + continue + + protein_mask = protein_data.get("mask", torch.ones(protein_coords.shape[0], device=self.device)) + protein_indices = protein_data.get("indices", torch.arange(protein_coords.shape[0], device=self.device)) + + # Load ligand coords for pocket computation + ligand_coords = ligand_data.get("atom_coords", ligand_data.get("coords", ligand_data.get("ligand_coords"))) + if ligand_coords is None: + logger.warning(f"Missing ligand coordinates for {pdb_id}, skipping") + continue + + samples.append( + { + "pdb_id": pdb_id, + "protein_coords": protein_coords, + "protein_sequence": protein_sequence, + "protein_mask": protein_mask, + "protein_indices": protein_indices, + "ligand_coords": ligand_coords, # Only for pocket computation + } + ) + + logger.info(f"Loaded {len(samples)} valid samples") + return samples + + def compute_binding_pocket( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + protein_mask: Tensor | None = None, + ) -> Tensor: + """Compute pocket mask based on distance to ligand.""" + if protein_coords.dim() == 3: + ca_coords = protein_coords[:, 1, :] + else: + ca_coords = protein_coords + + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + min_distances = distances.min(dim=1).values + pocket_mask = min_distances < self.pocket_distance_threshold + + if protein_mask is not None: + pocket_mask = pocket_mask & protein_mask.bool() + + return pocket_mask + + def inverse_fold(self, model, sample: dict) -> dict: + """Run inverse folding on a protein sample. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-only model + sample : dict + Sample dictionary from load_test_set() + + Returns + ------- + dict with: + - predicted_sequence: Tensor [L] + - sequence_logits: Tensor [L, vocab_size] + - decoded_coords: Tensor [L, 3, 3] (if decode_structure=True) + """ + protein_coords = sample["protein_coords"].unsqueeze(0).float() + protein_mask = sample["protein_mask"].unsqueeze(0).float() + protein_indices = sample["protein_indices"].unsqueeze(0).long() + length = protein_coords.shape[1] + + # Generate sample (inverse folding mode) + with torch.no_grad(): + result = model.generate_sample( + length=length, + num_samples=1, + inverse_folding=True, + nsteps=self.nsteps, + input_structure_coords=protein_coords, + input_mask=protein_mask, + input_indices=protein_indices, + ) + + # Decode structure to coordinates (optional) + decoded_coords = None + if self.decode_structure: + decoded_x = model.decode_structure(result, protein_mask) + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + vit_output = decoded_x[decoder_name] + if isinstance(vit_output, dict): + decoded_coords = vit_output.get("protein_coords", vit_output.get("coords")) + else: + decoded_coords = vit_output + break + + # Get predicted sequence + sequence_logits = result["sequence_logits"] # [1, L, vocab_size] + uses_33_token_vocab = sequence_logits.shape[-1] == 33 + + # Handle both 33-token and 21-token vocab formats + if uses_33_token_vocab: + predicted_sequence = convert_lobster_aa_tokenization_to_standard_aa( + sequence_logits, device=sequence_logits.device + ).squeeze(0) + else: + predicted_sequence = sequence_logits.argmax(dim=-1).squeeze(0) + predicted_sequence[predicted_sequence > 20] = 20 + predicted_sequence = self.lobster_to_standard[predicted_sequence.long()] + + return { + "predicted_sequence": predicted_sequence, + "sequence_logits": sequence_logits.squeeze(0), + "decoded_coords": decoded_coords.squeeze(0) if decoded_coords is not None else None, + } + + def compute_aar( + self, + predicted_seq: Tensor, + ground_truth_seq: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute amino acid recovery rate.""" + if mask is not None: + mask = mask.bool() + if mask.sum() == 0: + return float("nan") + predicted_seq = predicted_seq[mask] + ground_truth_seq = ground_truth_seq[mask] + + if len(predicted_seq) == 0: + return float("nan") + + return (predicted_seq == ground_truth_seq).float().mean().item() + + def evaluate(self, model, samples: list[dict] | None = None, structure_path: str | None = None) -> dict: + """Run full evaluation on PDBBind test set.""" + model.eval() + model.to(self.device) + + if samples is None: + samples = self.load_test_set() + + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + results = [] + skipped_samples = [] + + for sample in tqdm(samples, desc="Evaluating inverse folding (baseline)"): + pdb_id = sample["pdb_id"] + gt_seq = sample["protein_sequence"] + gt_coords = sample["protein_coords"] + protein_mask = sample["protein_mask"] + + protein_length = len(gt_seq) + + if protein_length > self.max_length: + logger.warning( + f"Skipping {pdb_id}: protein length {protein_length} exceeds max_length {self.max_length}" + ) + skipped_samples.append({"pdb_id": pdb_id, "protein_length": protein_length}) + continue + + # Compute binding pocket (using ligand coords) + pocket_mask = self.compute_binding_pocket(gt_coords, sample["ligand_coords"], protein_mask) + non_pocket_mask = protein_mask.bool() & ~pocket_mask + + # Run inverse folding + pred_result = self.inverse_fold(model, sample) + pred_seq = pred_result["predicted_sequence"] + + # Save structures and sequences if requested + if structure_path: + # Save sequences as FASTA + gt_seq_str = self.sequence_to_string(gt_seq) + pred_seq_str = self.sequence_to_string(pred_seq) + + fasta_path = os.path.join(structure_path, f"{pdb_id}_sequences.fasta") + with open(fasta_path, "w") as f: + f.write(f">{pdb_id}_ground_truth\n{gt_seq_str}\n") + f.write(f">{pdb_id}_predicted_baseline\n{pred_seq_str}\n") + + if self.save_gt_structure: + gt_pdb_path = os.path.join(structure_path, f"{pdb_id}_gt_protein.pdb") + writepdb(gt_pdb_path, gt_coords, gt_seq) + + if self.decode_structure and pred_result["decoded_coords"] is not None: + pred_pdb_path = os.path.join(structure_path, f"{pdb_id}_decoded_baseline.pdb") + writepdb(pred_pdb_path, pred_result["decoded_coords"], pred_seq) + + # Compute metrics + result = { + "pdb_id": pdb_id, + "length": len(gt_seq), + "n_pocket_residues": int(pocket_mask.sum().item()), + "n_nonpocket_residues": int(non_pocket_mask.sum().item()), + "aar_overall": self.compute_aar(pred_seq, gt_seq, protein_mask), + "aar_pocket": self.compute_aar(pred_seq, gt_seq, pocket_mask), + "aar_nonpocket": self.compute_aar(pred_seq, gt_seq, non_pocket_mask), + } + results.append(result) + + if skipped_samples: + logger.info(f"Skipped {len(skipped_samples)} samples due to length > {self.max_length}") + + results_df = pd.DataFrame(results) + + if len(results_df) == 0: + logger.warning("No samples were successfully evaluated") + return {"results_df": results_df, "summary": {}} + + summary = { + "mean_aar_overall": results_df["aar_overall"].mean(), + "std_aar_overall": results_df["aar_overall"].std(), + "mean_aar_pocket": results_df["aar_pocket"].mean(), + "std_aar_pocket": results_df["aar_pocket"].std(), + "mean_aar_nonpocket": results_df["aar_nonpocket"].mean(), + "std_aar_nonpocket": results_df["aar_nonpocket"].std(), + "n_samples": len(results_df), + "mean_pocket_size": results_df["n_pocket_residues"].mean(), + } + + return {"results_df": results_df, "summary": summary} + + def sequence_to_string(self, seq_tensor: Tensor) -> str: + """Convert sequence tensor (in standard format) to string.""" + return "".join([self.standard_aa_map.get(int(s), "X") for s in seq_tensor.cpu().tolist()]) + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate protein-only inverse folding (baseline for protein-ligand comparison)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt file)") + parser.add_argument( + "--data_dir", + type=str, + default="/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to PDBBind test directory", + ) + parser.add_argument( + "--output", type=str, default="protein_inverse_folding_baseline_results.csv", help="Output CSV file" + ) + parser.add_argument("--output_json", type=str, default=None, help="Output JSON file for summary statistics") + parser.add_argument("--structure_path", type=str, default=None, help="Directory to save sequences and structures") + parser.add_argument("--pocket_threshold", type=float, default=5.0, help="Distance threshold (Å) for binding pocket") + parser.add_argument("--num_samples", type=int, default=None, help="Number of samples to evaluate") + parser.add_argument("--nsteps", type=int, default=100, help="Number of diffusion steps") + parser.add_argument("--max_length", type=int, default=768, help="Maximum protein sequence length") + parser.add_argument( + "--decode_structure", action="store_true", help="Decode and save predicted structures as PDB files" + ) + parser.add_argument("--save_gt_structure", action="store_true", help="Save ground truth structures as PDB files") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device") + + args = parser.parse_args() + + # Skip existence check for S3 paths (they're handled by the model loader) + if not args.checkpoint.startswith("s3://") and not os.path.exists(args.checkpoint): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + if not os.path.exists(args.data_dir): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Load model + model = load_model(args.checkpoint, args.device) + + max_length = args.max_length + if hasattr(model, "max_length") and model.max_length is not None: + max_length = min(max_length, model.max_length) + logger.info(f"Using max_length: {max_length}") + + # Create evaluator + evaluator = ProteinInverseFoldingBaselineEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=args.num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + decode_structure=args.decode_structure, + save_gt_structure=args.save_gt_structure, + ) + + # Load test set + logger.info(f"Loading test set from {args.data_dir}") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Starting evaluation...") + results = evaluator.evaluate(model=model, samples=samples, structure_path=args.structure_path) + + results_df = results["results_df"] + summary = results["summary"] + + results_df.to_csv(args.output, index=False) + logger.info(f"Saved per-structure results to {args.output}") + + # Print summary + print("\n" + "=" * 70) + print("PROTEIN INVERSE FOLDING BASELINE RESULTS") + print("=" * 70) + print(f"\nDataset: {args.data_dir}") + print(f"Checkpoint: {args.checkpoint}") + print(f"Pocket threshold: {args.pocket_threshold} Å") + print(f"Samples evaluated: {summary['n_samples']}") + print(f"Mean pocket size: {summary['mean_pocket_size']:.1f} residues") + + print("\n--- Sequence Recovery (AAR) - Protein-Only Model, No Ligand Context ---") + print(f"\n{'Region':<25} {'AAR':<20}") + print("-" * 45) + print(f"{'Overall':<25} {summary['mean_aar_overall']:.2%} ± {summary['std_aar_overall']:.2%}") + print(f"{'Pocket':<25} {summary['mean_aar_pocket']:.2%} ± {summary['std_aar_pocket']:.2%}") + print(f"{'Non-pocket':<25} {summary['mean_aar_nonpocket']:.2%} ± {summary['std_aar_nonpocket']:.2%}") + print("=" * 70) + + print("\nNote: This is a protein-only baseline. Compare with protein-ligand model") + print(" to see if ligand context improves sequence recovery in the pocket.") + + if args.output_json: + summary_json = {k: float(v) if hasattr(v, "item") else v for k, v in summary.items()} + summary_json["checkpoint"] = args.checkpoint + summary_json["data_dir"] = args.data_dir + summary_json["pocket_threshold"] = args.pocket_threshold + summary_json["nsteps"] = args.nsteps + summary_json["max_length"] = max_length + summary_json["model_type"] = "protein_only_baseline" + + with open(args.output_json, "w") as f: + json.dump(summary_json, f, indent=2) + logger.info(f"Saved summary to {args.output_json}") + + logger.info("Evaluation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/metrics/evaluate_protein_ligand_forward_folding.py b/src/lobster/metrics/evaluate_protein_ligand_forward_folding.py new file mode 100644 index 00000000..00ede3da --- /dev/null +++ b/src/lobster/metrics/evaluate_protein_ligand_forward_folding.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python +"""Standalone evaluation script for protein-ligand forward folding (structure prediction). + +Evaluates structure prediction (TM-score, RMSD) on protein-ligand complexes. +Compares forward folding with and without ligand context. + +Usage: + # Evaluate a Gen-UME protein-ligand checkpoint + uv run python -m lobster.metrics.evaluate_protein_ligand_forward_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv + + # With structure saving + uv run python -m lobster.metrics.evaluate_protein_ligand_forward_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --structure_path ./structures/ \ + --save_structures \ + --save_gt_structure + + # Customize pocket threshold and number of samples + uv run python -m lobster.metrics.evaluate_protein_ligand_forward_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --pocket_threshold 6.0 \ + --num_samples 500 +""" + +import argparse +import json +import os +import sys + +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_forward_folding import ProteinLigandForwardFoldingEvaluator + + +def load_model(checkpoint_path: str, device: str = "cuda"): + """Load a Gen-UME protein-ligand model from checkpoint. + + Parameters + ---------- + checkpoint_path : str + Path to the model checkpoint (.ckpt file) + device : str + Device to load model on + + Returns + ------- + model : LightningModule + The loaded model + """ + from lobster.model.gen_ume import ProteinLigandEncoderLightningModule + + logger.info(f"Loading model from {checkpoint_path}") + + # Load checkpoint + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + checkpoint_path, + map_location=device, + strict=False, + ) + model.eval() + model.to(device) + + # Get max_length from encoder config + max_length = None + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + model.max_length = max_length # Store for later use + + logger.info(f"Model loaded successfully. Max length: {max_length}") + return model + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate protein-ligand forward folding (structure prediction quality)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Required arguments + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.ckpt file)", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to PDBBind test directory with *_protein.pt and *_ligand.pt pairs", + ) + + # Output options + parser.add_argument( + "--output", + type=str, + default="protein_ligand_forward_folding_results.csv", + help="Output CSV file for per-structure results", + ) + parser.add_argument( + "--output_json", + type=str, + default=None, + help="Output JSON file for summary statistics (optional)", + ) + parser.add_argument( + "--structure_path", + type=str, + default=None, + help="Directory to save predicted structures (PDB)", + ) + + # Evaluation parameters + parser.add_argument( + "--pocket_threshold", + type=float, + default=5.0, + help="Distance threshold (Å) for defining binding pocket residues", + ) + parser.add_argument( + "--num_samples", + type=int, + default=None, + help="Number of samples to evaluate (None = all available)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=100, + help="Number of diffusion steps for generation", + ) + parser.add_argument( + "--max_length", + type=int, + default=768, + help="Maximum combined sequence length (protein + ligand) to process", + ) + + # Temperature parameters + parser.add_argument( + "--temperature_seq", + type=float, + default=0.5, + help="Temperature for sequence sampling", + ) + parser.add_argument( + "--temperature_struc", + type=float, + default=0.5, + help="Temperature for structure sampling", + ) + + # Structure saving options + parser.add_argument( + "--save_structures", + action="store_true", + help="Save predicted structures as PDB files", + ) + parser.add_argument( + "--save_gt_structure", + action="store_true", + help="Save ground truth structures as PDB files", + ) + + # Device + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device for computation (default: cuda if available)", + ) + + args = parser.parse_args() + + # Validate inputs + # Skip existence check for S3 paths (they're handled by the model loader) + if not args.checkpoint.startswith("s3://") and not os.path.exists(args.checkpoint): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + if not os.path.exists(args.data_dir): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Load model + model = load_model(args.checkpoint, args.device) + + # Use model's max_length if not overridden + max_length = args.max_length + if hasattr(model, "max_length") and model.max_length is not None: + max_length = min(max_length, model.max_length) + logger.info(f"Using max_length: {max_length}") + + # Create evaluator + evaluator = ProteinLigandForwardFoldingEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=args.num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + temperature_seq=args.temperature_seq, + temperature_struc=args.temperature_struc, + save_structures=args.save_structures, + save_gt_structure=args.save_gt_structure, + ) + + # Load test set + logger.info(f"Loading test set from {args.data_dir}") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Starting evaluation...") + results = evaluator.evaluate( + model=model, + samples=samples, + structure_path=args.structure_path, + ) + + # Save results + results_df = results["results_df"] + summary = results["summary"] + + # Save per-structure results to CSV + results_df.to_csv(args.output, index=False) + logger.info(f"Saved per-structure results to {args.output}") + + # Print summary + print("\n" + "=" * 80) + print("PROTEIN-LIGAND FORWARD FOLDING RESULTS") + print("=" * 80) + print(f"\nDataset: {args.data_dir}") + print(f"Checkpoint: {args.checkpoint}") + print(f"Pocket threshold: {args.pocket_threshold} Å") + print(f"Samples evaluated: {summary['n_samples']}") + print(f"Mean pocket size: {summary['mean_pocket_size']:.1f} residues") + + print("\n--- TM-Score (higher is better) ---") + print(f"\n{'Metric':<25} {'No Ligand':<15} {'With Ligand':<15} {'Delta':<15}") + print("-" * 70) + print( + f"{'TM-Score':<25} " + f"{summary['mean_tm_score_no_ligand']:<15.3f} " + f"{summary['mean_tm_score_with_ligand']:<15.3f} " + f"{summary['mean_tm_score_delta']:+.3f} ± {summary['std_tm_score_delta']:.3f}" + ) + + print("\n--- RMSD (Å, lower is better) ---") + print(f"\n{'Region':<25} {'No Ligand':<15} {'With Ligand':<15} {'Delta':<15}") + print("-" * 70) + print( + f"{'Overall':<25} " + f"{summary['mean_rmsd_overall_no_ligand']:<15.2f} " + f"{summary['mean_rmsd_overall_with_ligand']:<15.2f} " + f"{summary['mean_rmsd_overall_delta']:+.2f} ± {summary['std_rmsd_overall_delta']:.2f}" + ) + print( + f"{'Pocket':<25} " + f"{summary['mean_rmsd_pocket_no_ligand']:<15.2f} " + f"{summary['mean_rmsd_pocket_with_ligand']:<15.2f} " + f"{summary['mean_rmsd_pocket_delta']:+.2f} ± {summary['std_rmsd_pocket_delta']:.2f}" + ) + print( + f"{'Non-pocket':<25} " + f"{summary['mean_rmsd_nonpocket_no_ligand']:<15.2f} " + f"{summary['mean_rmsd_nonpocket_with_ligand']:<15.2f} " + f"{summary['mean_rmsd_nonpocket_delta']:+.2f} ± {summary['std_rmsd_nonpocket_delta']:.2f}" + ) + print("=" * 80) + + # Interpretation + tm_delta = summary["mean_tm_score_delta"] + rmsd_pocket_delta = summary["mean_rmsd_pocket_delta"] + + if tm_delta > 0.01: + print(f"\n✓ Ligand context IMPROVES TM-score by {tm_delta:+.3f}") + elif tm_delta < -0.01: + print(f"\n✗ Ligand context HURTS TM-score by {tm_delta:+.3f}") + else: + print(f"\n○ Ligand context has minimal effect on TM-score ({tm_delta:+.3f})") + + if rmsd_pocket_delta < -0.1: + print(f"✓ Ligand context IMPROVES pocket RMSD by {-rmsd_pocket_delta:.2f} Å") + elif rmsd_pocket_delta > 0.1: + print(f"✗ Ligand context HURTS pocket RMSD by {rmsd_pocket_delta:.2f} Å") + else: + print(f"○ Ligand context has minimal effect on pocket RMSD ({rmsd_pocket_delta:+.2f} Å)") + + # Save summary to JSON if requested + if args.output_json: + # Convert numpy/torch types to Python types for JSON serialization + summary_json = {k: float(v) if hasattr(v, "item") else v for k, v in summary.items()} + summary_json["checkpoint"] = args.checkpoint + summary_json["data_dir"] = args.data_dir + summary_json["pocket_threshold"] = args.pocket_threshold + summary_json["nsteps"] = args.nsteps + summary_json["max_length"] = max_length + summary_json["temperature_seq"] = args.temperature_seq + summary_json["temperature_struc"] = args.temperature_struc + + with open(args.output_json, "w") as f: + json.dump(summary_json, f, indent=2) + logger.info(f"Saved summary to {args.output_json}") + + logger.info("Evaluation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/metrics/evaluate_protein_ligand_inverse_folding.py b/src/lobster/metrics/evaluate_protein_ligand_inverse_folding.py new file mode 100644 index 00000000..aea20c46 --- /dev/null +++ b/src/lobster/metrics/evaluate_protein_ligand_inverse_folding.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +"""Standalone evaluation script for protein-ligand inverse folding (sequence recovery). + +Evaluates sequence recovery around ligand binding pockets on protein-ligand complexes. +Compares inverse folding with and without ligand context. + +Usage: + # Evaluate a Gen-UME protein-ligand checkpoint + uv run python -m lobster.metrics.evaluate_protein_ligand_inverse_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv + + # With structure decoding and ground truth saving + uv run python -m lobster.metrics.evaluate_protein_ligand_inverse_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --structure_path ./structures/ \ + --decode_structure \ + --save_gt_structure + + # Customize pocket threshold and number of samples + uv run python -m lobster.metrics.evaluate_protein_ligand_inverse_folding \ + --checkpoint /path/to/checkpoint.ckpt \ + --data_dir /path/to/pdbind/test/ \ + --output results.csv \ + --pocket_threshold 6.0 \ + --num_samples 500 +""" + +import argparse +import json +import os +import sys + +import torch +from loguru import logger + +from lobster.metrics.protein_ligand_inverse_folding import ProteinLigandInverseFoldingEvaluator + + +def load_model(checkpoint_path: str, device: str = "cuda"): + """Load a Gen-UME protein-ligand model from checkpoint. + + Parameters + ---------- + checkpoint_path : str + Path to the model checkpoint (.ckpt file) + device : str + Device to load model on + + Returns + ------- + model : LightningModule + The loaded model + """ + from lobster.model.gen_ume import ProteinLigandEncoderLightningModule + + logger.info(f"Loading model from {checkpoint_path}") + + # Load checkpoint + model = ProteinLigandEncoderLightningModule.load_from_checkpoint( + checkpoint_path, + map_location=device, + strict=False, + ) + model.eval() + model.to(device) + + # Get max_length from encoder config + max_length = None + if hasattr(model, "encoder") and hasattr(model.encoder, "neobert"): + if hasattr(model.encoder.neobert, "config") and hasattr(model.encoder.neobert.config, "max_length"): + max_length = model.encoder.neobert.config.max_length + model.max_length = max_length # Store for later use + + logger.info(f"Model loaded successfully. Max length: {max_length}") + return model + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate protein-ligand inverse folding (sequence recovery around binding pocket)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Required arguments + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint (.ckpt file)", + ) + parser.add_argument( + "--data_dir", + type=str, + default="/cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/", + help="Path to PDBBind test directory with *_protein.pt and *_ligand.pt pairs", + ) + + # Output options + parser.add_argument( + "--output", + type=str, + default="protein_ligand_inverse_folding_results.csv", + help="Output CSV file for per-structure results", + ) + parser.add_argument( + "--output_json", + type=str, + default=None, + help="Output JSON file for summary statistics (optional)", + ) + parser.add_argument( + "--structure_path", + type=str, + default=None, + help="Directory to save sequences (FASTA) and decoded structures (PDB)", + ) + + # Evaluation parameters + parser.add_argument( + "--pocket_threshold", + type=float, + default=5.0, + help="Distance threshold (Å) for defining binding pocket residues", + ) + parser.add_argument( + "--num_samples", + type=int, + default=None, + help="Number of samples to evaluate (None = all available)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=100, + help="Number of diffusion steps for generation", + ) + parser.add_argument( + "--max_length", + type=int, + default=768, + help="Maximum combined sequence length (protein + ligand) to process", + ) + + # Structure decoding options + parser.add_argument( + "--decode_structure", + action="store_true", + help="Decode and save predicted structures as PDB files", + ) + parser.add_argument( + "--save_gt_structure", + action="store_true", + help="Save ground truth structures as PDB files", + ) + + # Device + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device for computation (default: cuda if available)", + ) + + # ESMFold validation + parser.add_argument( + "--use_esmfold", + action="store_true", + help="Validate designed sequences with ESMFold (fold and compare to GT structure)", + ) + parser.add_argument( + "--max_protein_length", + type=int, + default=512, + help="Maximum protein-only length. Samples exceeding this are skipped. Also used as ESMFold max length (default: 512)", + ) + + args = parser.parse_args() + + # Validate inputs + # Skip existence check for S3 paths (they're handled by the model loader) + if not args.checkpoint.startswith("s3://") and not os.path.exists(args.checkpoint): + logger.error(f"Checkpoint not found: {args.checkpoint}") + sys.exit(1) + + if not os.path.exists(args.data_dir): + logger.error(f"Data directory not found: {args.data_dir}") + sys.exit(1) + + # Load model + model = load_model(args.checkpoint, args.device) + + # Initialize ESMFold if requested + plm_fold = None + if args.use_esmfold: + from lobster.model import LobsterPLMFold + + logger.info("Loading ESMFold for structure validation...") + plm_fold = LobsterPLMFold(model_name="esmfold_v1", max_length=512) + plm_fold.to(args.device) + logger.info("ESMFold loaded successfully") + + # Use model's max_length if not overridden + max_length = args.max_length + if hasattr(model, "max_length") and model.max_length is not None: + max_length = min(max_length, model.max_length) + logger.info(f"Using max_length: {max_length}") + + # Create evaluator + evaluator = ProteinLigandInverseFoldingEvaluator( + data_dir=args.data_dir, + pocket_distance_threshold=args.pocket_threshold, + num_samples=args.num_samples, + nsteps=args.nsteps, + device=args.device, + max_length=max_length, + decode_structure=args.decode_structure, + save_gt_structure=args.save_gt_structure, + use_esmfold=args.use_esmfold, + plm_fold=plm_fold, + max_protein_length=args.max_protein_length, + ) + + # Load test set + logger.info(f"Loading test set from {args.data_dir}") + samples = evaluator.load_test_set() + logger.info(f"Loaded {len(samples)} samples") + + # Run evaluation + logger.info("Starting evaluation...") + results = evaluator.evaluate( + model=model, + samples=samples, + structure_path=args.structure_path, + ) + + # Save results + results_df = results["results_df"] + summary = results["summary"] + + # Save per-structure results to CSV + results_df.to_csv(args.output, index=False) + logger.info(f"Saved per-structure results to {args.output}") + + # Print summary + print("\n" + "=" * 70) + print("PROTEIN-LIGAND INVERSE FOLDING RESULTS") + print("=" * 70) + print(f"\nDataset: {args.data_dir}") + print(f"Checkpoint: {args.checkpoint}") + print(f"Pocket threshold: {args.pocket_threshold} Å") + print(f"Samples evaluated: {summary['n_samples']}") + print(f"Mean pocket size: {summary['mean_pocket_size']:.1f} residues") + + print("\n--- Sequence Recovery (AAR) ---") + print(f"\n{'Region':<20} {'No Ligand':<15} {'With Ligand':<15} {'Delta':<15}") + print("-" * 65) + print( + f"{'Overall':<20} " + f"{summary['mean_aar_overall_no_ligand']:<15.2%} " + f"{summary['mean_aar_overall_with_ligand']:<15.2%} " + f"{summary['mean_aar_overall_delta']:+.2%}" + ) + print( + f"{'Pocket':<20} " + f"{summary['mean_aar_pocket_no_ligand']:<15.2%} " + f"{summary['mean_aar_pocket_with_ligand']:<15.2%} " + f"{summary['mean_aar_pocket_delta']:+.2%} ± {summary['std_aar_pocket_delta']:.2%}" + ) + print( + f"{'Non-pocket':<20} " + f"{summary['mean_aar_nonpocket_no_ligand']:<15.2%} " + f"{summary['mean_aar_nonpocket_with_ligand']:<15.2%} " + f"{summary['mean_aar_nonpocket_delta']:+.2%} ± {summary['std_aar_nonpocket_delta']:.2%}" + ) + # ESMFold validation results + if args.use_esmfold and "mean_esmfold_tm_no_ligand" in summary: + print("\n--- ESMFold Designability Validation ---") + print(f" {'Condition':<20} {'TM-score':<12} {'RMSD (Å)':<12} {'Pocket RMSD':<14} {'pLDDT':<12} {'PAE':<12}") + print(" " + "-" * 82) + print( + f" {'GT sequence':<20} " + f"{summary['mean_esmfold_tm_gt']:<12.3f} " + f"{summary['mean_esmfold_rmsd_gt']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_gt']:<14.2f} " + f"{summary['mean_esmfold_plddt_gt']:<12.2f} " + f"{summary['mean_esmfold_pae_gt']:<12.2f}" + ) + print( + f" {'No ligand':<20} " + f"{summary['mean_esmfold_tm_no_ligand']:<12.3f} " + f"{summary['mean_esmfold_rmsd_no_ligand']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_no_ligand']:<14.2f} " + f"{summary['mean_esmfold_plddt_no_ligand']:<12.2f} " + f"{summary['mean_esmfold_pae_no_ligand']:<12.2f}" + ) + print( + f" {'With ligand':<20} " + f"{summary['mean_esmfold_tm_with_ligand']:<12.3f} " + f"{summary['mean_esmfold_rmsd_with_ligand']:<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_with_ligand']:<14.2f} " + f"{summary['mean_esmfold_plddt_with_ligand']:<12.2f} " + f"{summary['mean_esmfold_pae_with_ligand']:<12.2f}" + ) + print( + f" {'Delta (ligand)':<20} " + f"{summary['mean_esmfold_tm_delta']:+<12.3f} " + f"{summary['mean_esmfold_rmsd_delta']:+<12.2f} " + f"{summary['mean_esmfold_rmsd_pocket_delta']:+<14.2f} " + f"{summary['mean_esmfold_plddt_delta']:+<12.2f}" + ) + + print("=" * 70) + + # Interpretation + pocket_delta = summary["mean_aar_pocket_delta"] + if pocket_delta > 0.01: + print(f"\nLigand context IMPROVES pocket sequence recovery by {pocket_delta:+.2%}") + elif pocket_delta < -0.01: + print(f"\nLigand context HURTS pocket sequence recovery by {pocket_delta:+.2%}") + else: + print(f"\nLigand context has minimal effect on pocket sequence recovery ({pocket_delta:+.2%})") + + # Save summary to JSON if requested + if args.output_json: + # Convert numpy/torch types to Python types for JSON serialization + summary_json = {k: float(v) if hasattr(v, "item") else v for k, v in summary.items()} + summary_json["checkpoint"] = args.checkpoint + summary_json["data_dir"] = args.data_dir + summary_json["pocket_threshold"] = args.pocket_threshold + summary_json["nsteps"] = args.nsteps + summary_json["max_length"] = max_length + + with open(args.output_json, "w") as f: + json.dump(summary_json, f, indent=2) + logger.info(f"Saved summary to {args.output_json}") + + logger.info("Evaluation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/metrics/evaluate_reconstruction.py b/src/lobster/metrics/evaluate_reconstruction.py new file mode 100644 index 00000000..1559ecc2 --- /dev/null +++ b/src/lobster/metrics/evaluate_reconstruction.py @@ -0,0 +1,1446 @@ +#!/usr/bin/env python3 +""" +Script to evaluate reconstruction quality of LatentGenerator models. +Takes structures from a directory, encodes and reconstructs them with multiple models, +and reports average aligned RMSD between original and reconstructed structures. +""" + +import os +import glob +import torch +import numpy as np +import argparse +from loguru import logger +from tqdm import tqdm +import json + +# Import from latent_generator +from lobster.model.latent_generator.cmdline import load_model, encode, decode, methods +from lobster.model.latent_generator.cmdline import minimize_ligand_structure, get_ligand_energy +from lobster.model.latent_generator.io import load_pdb, load_ligand, writepdb_ligand_complex, writepdb +from lobster.model.latent_generator.utils._utils import kabsch_torch_batched +from lobster.transforms._structure_transforms import StructureLigandTransform + +ligand_transform = StructureLigandTransform(max_length=512) + + +def check_file_not_empty(file_path: str) -> bool: + """Check if a file exists and is not empty.""" + if not os.path.exists(file_path): + return False + + # Check file size + if os.path.getsize(file_path) == 0: + return False + + # For PDB files, check if they contain actual structure data + if file_path.endswith(".pdb"): + with open(file_path) as f: + lines = f.readlines() + # Check if file contains ATOM or HETATM records + has_atoms = any(line.startswith(("ATOM", "HETATM")) for line in lines) + return has_atoms + + # For SDF files, check if they contain structure data + elif file_path.endswith(".sdf"): + with open(file_path) as f: + content = f.read() + # Check if file contains atom count and coordinates + return "$$$$" in content and any(line.strip().isdigit() for line in content.split("\n")[:10]) + + # For PT files, check if they can be loaded + elif file_path.endswith(".pt"): + # Try to load the file to check if it's valid + data = torch.load(file_path, map_location="cpu") + return data is not None + + return True + + +def load_structure_data(file_path: str) -> dict: + """Load structure data from file (PDB, SDF, or processed PT files).""" + if file_path.endswith(".pt"): + # Load already processed data + structure_data = torch.load(file_path, map_location="cpu") + if isinstance(structure_data, dict): + return structure_data + else: + # If it's a tensor, wrap it in a dict + return {"protein_coords": structure_data} + + elif file_path.endswith(".pdb"): + return load_pdb(file_path) + elif file_path.endswith(".sdf"): + return load_ligand(file_path) + else: + raise ValueError(f"Unsupported file format: {file_path}") + + +def find_paired_protein_ligand_files(data_dir: str) -> list[tuple[str, str, str]]: + """Find paired protein and ligand files in a directory. + + Looks for files matching patterns like: + - {name}_protein.pt and {name}_ligand.pt + - {name}.protein.pt and {name}.ligand.pt + + Args: + data_dir: Directory containing protein and ligand files + + Returns: + List of tuples: (base_name, protein_path, ligand_path) + """ + paired_files = [] + + # Pattern 1: {name}_protein.pt and {name}_ligand.pt + protein_files = glob.glob(os.path.join(data_dir, "*_protein.pt")) + for protein_path in protein_files: + base_name = os.path.basename(protein_path).replace("_protein.pt", "") + ligand_path = protein_path.replace("_protein.pt", "_ligand.pt") + if os.path.exists(ligand_path): + paired_files.append((base_name, protein_path, ligand_path)) + + # Pattern 2: {name}.protein.pt and {name}.ligand.pt (if not found with pattern 1) + if not paired_files: + protein_files = glob.glob(os.path.join(data_dir, "*.protein.pt")) + for protein_path in protein_files: + base_name = os.path.basename(protein_path).replace(".protein.pt", "") + ligand_path = protein_path.replace(".protein.pt", ".ligand.pt") + if os.path.exists(ligand_path): + paired_files.append((base_name, protein_path, ligand_path)) + + return paired_files + + +def load_paired_protein_ligand_data(protein_path: str, ligand_path: str) -> dict: + """Load and merge paired protein and ligand data into a single dict. + + The encoder expects specific key names: + - coords_res: protein coordinates (B, L, n_atoms, 3) + - mask: protein mask (B, L) + - indices: residue indices (B, L) + - sequence: amino acid sequence (B, L) + - ligand_coords: ligand coordinates (B, num_atoms, 3) + - ligand_mask: ligand mask (B, num_atoms) + - ligand_residue_index: ligand atom indices (B, num_atoms) + + Args: + protein_path: Path to protein .pt file + ligand_path: Path to ligand .pt file + + Returns: + Combined structure data dict with both protein and ligand information + """ + # Load protein data + protein_data = torch.load(protein_path, map_location="cpu") + if not isinstance(protein_data, dict): + protein_data = {"coords_res": protein_data} + + # Load ligand data + ligand_data = torch.load(ligand_path, map_location="cpu") + if not isinstance(ligand_data, dict): + ligand_data = {"ligand_coords": ligand_data} + + # Start with combined dict + combined = {} + + # Map protein coordinates - encoder expects 'coords_res' + # Expected shape: (B, L, n_atoms, 3) + if "coords_res" in protein_data: + coords = protein_data["coords_res"] + elif "protein_coords" in protein_data: + coords = protein_data["protein_coords"] + elif "coords" in protein_data: + coords = protein_data["coords"] + else: + coords = None + + if coords is not None: + # Add batch dimension if missing (shape is L, n_atoms, 3) + if coords.dim() == 3: + coords = coords.unsqueeze(0) # (1, L, n_atoms, 3) + combined["coords_res"] = coords # Key expected by encoder + combined["protein_coords"] = coords # Also keep for RMSD computation + + # Map protein mask - encoder expects 'mask' + if "mask" in protein_data: + mask = protein_data["mask"] + elif "protein_mask" in protein_data: + mask = protein_data["protein_mask"] + else: + mask = None + + if mask is not None: + if mask.dim() == 1: + mask = mask.unsqueeze(0) # (1, L) + combined["mask"] = mask # Key expected by encoder + combined["protein_mask"] = mask # Also keep for compatibility + + # Map sequence - encoder expects 'sequence' + if "sequence" in protein_data: + seq = protein_data["sequence"] + if seq.dim() == 1: + seq = seq.unsqueeze(0) # (1, L) + combined["sequence"] = seq # Key expected by encoder + elif "seq" in protein_data: + seq = protein_data["seq"] + if seq.dim() == 1: + seq = seq.unsqueeze(0) + combined["sequence"] = seq # Key expected by encoder + + # Map indices - encoder expects 'indices' with batch dim + if "indices" in protein_data: + indices = protein_data["indices"] + if indices.dim() == 1: + indices = indices.unsqueeze(0) # (1, L) + combined["indices"] = indices + + # Copy other protein fields + for key in ["sequence_str", "chains_ids", "real_chains", "pdb_path"]: + if key in protein_data: + combined[key] = protein_data[key] + + # Add ligand data with proper key mapping + # Handle different possible key names in ligand files + # Expected ligand_coords shape: (B, num_atoms, 3) + if "ligand_coords" in ligand_data: + lig_coords = ligand_data["ligand_coords"] + elif "atom_coords" in ligand_data: + lig_coords = ligand_data["atom_coords"] + elif "coords" in ligand_data: + lig_coords = ligand_data["coords"] + else: + lig_coords = None + + if lig_coords is not None: + # Add batch dimension if missing (shape is num_atoms, 3) + if lig_coords.dim() == 2: + lig_coords = lig_coords.unsqueeze(0) # (1, num_atoms, 3) + combined["ligand_coords"] = lig_coords + + # Map ligand mask and add batch dim if needed + if "ligand_mask" in ligand_data: + lig_mask = ligand_data["ligand_mask"] + elif "mask" in ligand_data: + lig_mask = ligand_data["mask"] + else: + lig_mask = None + + if lig_mask is not None: + if lig_mask.dim() == 1: + lig_mask = lig_mask.unsqueeze(0) # (1, num_atoms) + combined["ligand_mask"] = lig_mask + + # Map ligand indices and add batch dim if needed + if "ligand_residue_index" in ligand_data: + lig_idx = ligand_data["ligand_residue_index"] + elif "atom_indices" in ligand_data: + lig_idx = ligand_data["atom_indices"] + elif "residue_index" in ligand_data: + lig_idx = ligand_data["residue_index"] + else: + lig_idx = None + + if lig_idx is not None: + if lig_idx.dim() == 1: + lig_idx = lig_idx.unsqueeze(0) # (1, num_atoms) + combined["ligand_residue_index"] = lig_idx + combined["ligand_indices"] = lig_idx + + # Ligand atom names (list, no batch dim needed) + if "ligand_atom_names" in ligand_data: + combined["ligand_atom_names"] = ligand_data["ligand_atom_names"] + elif "atom_names" in ligand_data: + combined["ligand_atom_names"] = ligand_data["atom_names"] + + # Bond matrix for connectivity (important for geometry idealization) + if "bond_matrix" in ligand_data: + combined["bond_matrix"] = ligand_data["bond_matrix"] + elif "ligand_bond_matrix" in ligand_data: + combined["bond_matrix"] = ligand_data["ligand_bond_matrix"] + + return combined + + +def compute_complex_rmsd( + original_protein_coords: torch.Tensor, + reconstructed_protein_coords: torch.Tensor, + original_ligand_coords: torch.Tensor, + reconstructed_ligand_coords: torch.Tensor, + protein_mask: torch.Tensor = None, + ligand_mask: torch.Tensor = None, +) -> tuple[float, float, float]: + """Compute aligned RMSD for protein-ligand complex (aligned together). + + This aligns the entire complex (protein + ligand) together, preserving + relative positioning between protein and ligand. + + Args: + original_protein_coords: Original protein coordinates (B, L, n_atoms, 3) + reconstructed_protein_coords: Reconstructed protein coordinates (B, L, n_atoms, 3) + original_ligand_coords: Original ligand coordinates (B, num_atoms, 3) + reconstructed_ligand_coords: Reconstructed ligand coordinates (B, num_atoms, 3) + protein_mask: Protein mask (B, L) + ligand_mask: Ligand mask (B, num_atoms) + + Returns: + Tuple of (complex_rmsd, protein_rmsd_after_complex_align, ligand_rmsd_after_complex_align) + """ + device = reconstructed_protein_coords.device + original_protein_coords = original_protein_coords.to(device) + original_ligand_coords = original_ligand_coords.to(device) + + B, L, n_atoms, _ = original_protein_coords.shape + B_lig, num_lig_atoms, _ = original_ligand_coords.shape + + # Create masks if not provided + if protein_mask is None: + protein_mask = torch.ones(B, L, device=device) + else: + protein_mask = protein_mask.to(device) + + if ligand_mask is None: + ligand_mask = torch.ones(B_lig, num_lig_atoms, device=device) + else: + ligand_mask = ligand_mask.to(device) + + # Flatten protein coordinates (B, L*n_atoms, 3) + gt_protein_flat = original_protein_coords.reshape(B, -1, 3) + pred_protein_flat = reconstructed_protein_coords.reshape(B, -1, 3) + protein_mask_flat = protein_mask.unsqueeze(-1).repeat(1, 1, n_atoms).reshape(B, -1) + + # Concatenate protein + ligand for joint alignment + gt_complex = torch.cat([gt_protein_flat, original_ligand_coords], dim=1) + pred_complex = torch.cat([pred_protein_flat, reconstructed_ligand_coords], dim=1) + mask_complex = torch.cat([protein_mask_flat, ligand_mask], dim=1) + + # Align the ENTIRE complex together using Kabsch + aligned_complex = kabsch_torch_batched(pred_complex, gt_complex, mask_complex) + + # Split back into protein and ligand + n_protein_atoms = gt_protein_flat.shape[1] + aligned_protein_flat = aligned_complex[:, :n_protein_atoms, :] + aligned_ligand = aligned_complex[:, n_protein_atoms:, :] + + # Compute complex RMSD (all atoms) + squared_diff_complex = torch.sum((gt_complex - aligned_complex) ** 2, dim=-1) + masked_squared_diff = squared_diff_complex * mask_complex + total_atoms = torch.sum(mask_complex) + if total_atoms > 0: + complex_rmsd = torch.sqrt(torch.sum(masked_squared_diff) / total_atoms).item() + else: + complex_rmsd = float("inf") + + # Compute protein RMSD after complex alignment (CA atoms only) + aligned_protein = aligned_protein_flat.reshape(B, L, n_atoms, 3) + gt_protein_ca = original_protein_coords[:, :, 1, :] # CA atoms + aligned_protein_ca = aligned_protein[:, :, 1, :] + squared_diff_protein = torch.sum((gt_protein_ca - aligned_protein_ca) ** 2, dim=-1) + masked_squared_diff_protein = squared_diff_protein * protein_mask + total_protein_atoms = torch.sum(protein_mask) + if total_protein_atoms > 0: + protein_rmsd = torch.sqrt(torch.sum(masked_squared_diff_protein) / total_protein_atoms).item() + else: + protein_rmsd = float("inf") + + # Compute ligand RMSD after complex alignment + squared_diff_ligand = torch.sum((original_ligand_coords - aligned_ligand) ** 2, dim=-1) + masked_squared_diff_ligand = squared_diff_ligand * ligand_mask + total_ligand_atoms = torch.sum(ligand_mask) + if total_ligand_atoms > 0: + ligand_rmsd = torch.sqrt(torch.sum(masked_squared_diff_ligand) / total_ligand_atoms).item() + else: + ligand_rmsd = float("inf") + + return complex_rmsd, protein_rmsd, ligand_rmsd + + +def compute_aligned_rmsd( + original_coords: torch.Tensor, + reconstructed_coords: torch.Tensor, + mask: torch.Tensor = None, + structure_type: str = "protein", +) -> float: + """Compute aligned RMSD between original and reconstructed coordinates. + + Args: + original_coords: Original coordinates tensor + reconstructed_coords: Reconstructed coordinates tensor + mask: Mask indicating valid positions + structure_type: "protein" or "ligand" - determines alignment strategy + """ + if mask is None: + mask = torch.ones(original_coords.shape[:2], device=original_coords.device) + + B, L = original_coords.shape[:2] + if structure_type == "protein": + mask_expanded = mask.unsqueeze(-1).repeat(1, 1, 3).to(reconstructed_coords.device) + else: + mask_expanded = mask.to(reconstructed_coords.device) + original_coords = original_coords.to(reconstructed_coords.device) + + # Align structures using Kabsch algorithm + aligned_coords = kabsch_torch_batched( + reconstructed_coords.reshape(B, -1, 3), original_coords.reshape(B, -1, 3), mask_expanded.reshape(B, -1) + ) + + if structure_type == "protein": + aligned_coords = aligned_coords.reshape(B, L, 3, 3) + # For proteins, use CA atoms (index 1) for RMSD calculation + atoms_for_rmsd_original = original_coords[:, :, 1, :] # CA atoms + atoms_for_rmsd_aligned = aligned_coords[:, :, 1, :] # CA atoms + mask_expanded = mask[:, :].to(original_coords.device) + elif structure_type == "ligand": + # For ligands, use all atoms (average across atom dimension) + atoms_for_rmsd_original = original_coords + atoms_for_rmsd_aligned = aligned_coords + else: # not tested yet + # Default to all atoms + atoms_for_rmsd_original = original_coords.reshape(B, L * original_coords.shape[2], 3) + atoms_for_rmsd_aligned = aligned_coords.reshape(B, L * aligned_coords.shape[2], 3) + # Adjust mask for flattened coordinates + mask_expanded = mask.unsqueeze(-1).repeat(1, 1, original_coords.shape[2]).reshape(B, -1) + + # Compute RMSD + squared_diff = torch.sum((atoms_for_rmsd_original - atoms_for_rmsd_aligned) ** 2, dim=-1) + masked_squared_diff = squared_diff * mask_expanded + total_atoms = torch.sum(mask_expanded) + + if total_atoms > 0: + rmsd = torch.sqrt(torch.sum(masked_squared_diff) / total_atoms) + return rmsd.item(), original_coords, aligned_coords + else: + return float("inf"), None, None + + +def evaluate_model_on_structure( + model_name: str, + structure_data: dict, + structure_path: str, + save_structures: bool = False, + output_dir: str = None, + use_canonical_pose: bool = False, + num_steps: int = None, + minimize_ligand: bool = False, + minimize_steps: int = 500, + force_field: str = "MMFF94", + minimize_mode: str = "bonds_and_angles", +) -> dict: + """Evaluate a single model on a single structure. + + Parameters + ---------- + minimize_ligand : bool, default=False + If True, minimize ligand structure after decoding using Open Babel. + minimize_steps : int, default=500 + Maximum number of minimization steps. + force_field : str, default="MMFF94" + Force field for minimization. + minimize_mode : str, default="bonds_and_angles" + Minimization mode: "bonds_only" (ideal bond lengths) or "bonds_and_angles" (recommended). + Force field for minimization: "MMFF94", "MMFF94s", "UFF", "GAFF", "Ghemical". + """ + results = { + "model_name": model_name, + "structure_path": structure_path, + "num_steps": num_steps, + "rmsd": float("inf"), + "success": False, + "error": None, + "minimize_ligand": minimize_ligand, + } + + # Check protein length and skip if > 512 + protein_length = None + if "protein_coords" in structure_data: + protein_length = structure_data["protein_coords"].shape[1] + elif "coords_res" in structure_data: + protein_length = structure_data["coords_res"].shape[1] + + if protein_length is not None and protein_length > 512: + results["error"] = f"Protein length ({protein_length}) exceeds maximum allowed length (512)" + logger.warning(f"Skipping {structure_path} - protein length {protein_length} > 512") + return results + + # Handle different model types + if "Ligand" in model_name: + structure_data = ligand_transform(structure_data) + if "conformers" in structure_data: + structure_data = structure_data["conformers"][np.random.randint(0, len(structure_data["conformers"]))] + if "atom_coords" in structure_data: + structure_data["ligand_coords"] = structure_data["atom_coords"][None] + structure_data["ligand_mask"] = structure_data["mask"][None] + structure_data["ligand_residue_index"] = structure_data["atom_indices"][None] + structure_data["ligand_atom_names"] = structure_data["atom_names"] + structure_data["ligand_indices"] = structure_data["atom_indices"][None] + + # Encode and decode + if use_canonical_pose: + frame_type = "mol_frame" + else: + frame_type = None + + latents, embeddings = encode(structure_data, return_embeddings=True, frame_type=frame_type) + + # Pass num_steps to decode if provided + if num_steps is not None: + decoded_outputs, sequence_outputs = decode(latents, x_emb=embeddings, num_steps=num_steps) + else: + decoded_outputs, sequence_outputs = decode(latents, x_emb=embeddings) + + # Extract coordinates from decoded output and determine what was reconstructed + reconstructed_protein_coords = None + reconstructed_ligand_coords = None + refined_reconstructed_coords = None + + if isinstance(decoded_outputs, dict): + if "protein_coords" in decoded_outputs and decoded_outputs["protein_coords"] is not None: + reconstructed_protein_coords = decoded_outputs["protein_coords"] + if "protein_coords_refinement" in decoded_outputs: + refined_reconstructed_coords = decoded_outputs["protein_coords_refinement"] + if "ligand_coords" in decoded_outputs and decoded_outputs["ligand_coords"] is not None: + reconstructed_ligand_coords = decoded_outputs["ligand_coords"] + else: + reconstructed_protein_coords = decoded_outputs + + # Apply ligand minimization if requested + reconstructed_ligand_coords_minimized = None + if minimize_ligand and reconstructed_ligand_coords is not None: + try: + # Get atom types from structure data + ligand_atom_types = structure_data.get("ligand_atom_names", None) + if ligand_atom_types is None: + ligand_atom_types = structure_data.get("atom_names", None) + if ligand_atom_types is None: + # Default to carbon for unknown atoms + num_atoms = reconstructed_ligand_coords.shape[1] + ligand_atom_types = ["C"] * num_atoms + + # Get bond matrix if available + bond_matrix = None + if isinstance(decoded_outputs, dict): + bond_matrix = decoded_outputs.get("ligand_bond_matrix", None) + if bond_matrix is None: + bond_matrix = structure_data.get("ligand_bond_matrix", None) + if bond_matrix is None: + bond_matrix = structure_data.get("bond_matrix", None) + + # Calculate energy before minimization + energy_before = get_ligand_energy( + reconstructed_ligand_coords[0], ligand_atom_types, bond_matrix, force_field + ) + results["ligand_energy_before"] = energy_before + + # Minimize ligand structure + minimized_coords = minimize_ligand_structure( + reconstructed_ligand_coords[0], + ligand_atom_types, + bond_matrix=bond_matrix, + steps=minimize_steps, + force_field=force_field, + method="cg", + mode=minimize_mode, + ) + reconstructed_ligand_coords_minimized = minimized_coords.unsqueeze(0) + + # Calculate energy after minimization + energy_after = get_ligand_energy(minimized_coords, ligand_atom_types, bond_matrix, force_field) + results["ligand_energy_after"] = energy_after + results["ligand_energy_reduction"] = energy_before - energy_after + + except Exception as e: + logger.warning(f"Ligand minimization failed for {structure_path}: {e}") + results["minimize_error"] = str(e) + + # Compute RMSD for both protein and ligand when both are reconstructed + protein_rmsd = None + ligand_rmsd = None + gt_coords = None + aligned_pred_coords = None + + # Compute protein RMSD if protein was reconstructed + if reconstructed_protein_coords is not None: + # Get original protein coordinates + if "protein_coords" in structure_data and structure_data["protein_coords"] is not None: + original_protein_coords = structure_data["protein_coords"] + elif "coords_res" in structure_data: + original_protein_coords = structure_data["coords_res"] + # Add batch dim if needed + if original_protein_coords.dim() == 3: + original_protein_coords = original_protein_coords.unsqueeze(0) + else: + original_protein_coords = None + + if original_protein_coords is not None: + # Get protein mask + protein_mask = structure_data.get("mask") + if protein_mask is None: + protein_mask = structure_data.get("protein_mask") + + protein_rmsd, gt_coords, aligned_pred_coords = compute_aligned_rmsd( + original_protein_coords, reconstructed_protein_coords, protein_mask, "protein" + ) + + # Compute ligand RMSD if ligand was reconstructed + ligand_rmsd_minimized = None + if reconstructed_ligand_coords is not None: + # Get original ligand coordinates + if "ligand_coords" in structure_data: + original_ligand_coords = structure_data["ligand_coords"] + elif "atom_coords" in structure_data: + original_ligand_coords = structure_data["atom_coords"] + # Add batch dim if needed + if original_ligand_coords.dim() == 2: + original_ligand_coords = original_ligand_coords.unsqueeze(0) + else: + original_ligand_coords = None + + if original_ligand_coords is not None: + # Get ligand mask + ligand_mask = structure_data.get("ligand_mask") + + ligand_rmsd, _, _ = compute_aligned_rmsd( + original_ligand_coords, reconstructed_ligand_coords, ligand_mask, "ligand" + ) + + # Also compute RMSD for minimized ligand if available + if reconstructed_ligand_coords_minimized is not None: + ligand_rmsd_minimized, _, _ = compute_aligned_rmsd( + original_ligand_coords, reconstructed_ligand_coords_minimized, ligand_mask, "ligand" + ) + + # Compute complex RMSD when both protein and ligand are available + complex_rmsd = None + complex_protein_rmsd = None + complex_ligand_rmsd = None + complex_ligand_rmsd_minimized = None + if ( + reconstructed_protein_coords is not None + and reconstructed_ligand_coords is not None + and original_protein_coords is not None + and original_ligand_coords is not None + ): + complex_rmsd, complex_protein_rmsd, complex_ligand_rmsd = compute_complex_rmsd( + original_protein_coords, + reconstructed_protein_coords, + original_ligand_coords, + reconstructed_ligand_coords, + protein_mask, + ligand_mask, + ) + + # Also compute complex RMSD with minimized ligand if available + if reconstructed_ligand_coords_minimized is not None: + ( + complex_rmsd_minimized, + complex_protein_rmsd_minimized, + complex_ligand_rmsd_minimized, + ) = compute_complex_rmsd( + original_protein_coords, + reconstructed_protein_coords, + original_ligand_coords, + reconstructed_ligand_coords_minimized, + protein_mask, + ligand_mask, + ) + results["complex_rmsd_minimized"] = complex_rmsd_minimized + results["complex_protein_rmsd_minimized"] = complex_protein_rmsd_minimized + results["complex_ligand_rmsd_minimized"] = complex_ligand_rmsd_minimized + + # Handle refined reconstruction (protein only for now) + if refined_reconstructed_coords is not None and original_protein_coords is not None: + rmsd_refined, gt_coords_refined, aligned_pred_coords_refined = compute_aligned_rmsd( + original_protein_coords, refined_reconstructed_coords, protein_mask, "protein" + ) + else: + rmsd_refined = None + + # Store results - use protein RMSD as primary if available, else ligand + if protein_rmsd is not None and protein_rmsd != float("inf"): + results["rmsd"] = protein_rmsd + results["protein_rmsd"] = protein_rmsd + results["success"] = True + if ligand_rmsd is not None and ligand_rmsd != float("inf"): + results["ligand_rmsd"] = ligand_rmsd + if results["rmsd"] == float("inf"): # Only set primary if protein wasn't available + results["rmsd"] = ligand_rmsd + results["success"] = True + # Store minimized ligand RMSD if available + if ligand_rmsd_minimized is not None and ligand_rmsd_minimized != float("inf"): + results["ligand_rmsd_minimized"] = ligand_rmsd_minimized + # Calculate improvement from minimization + if ligand_rmsd is not None and ligand_rmsd != float("inf"): + results["ligand_rmsd_improvement"] = ligand_rmsd - ligand_rmsd_minimized + # Store complex RMSD results + if complex_rmsd is not None and complex_rmsd != float("inf"): + results["complex_rmsd"] = complex_rmsd + results["complex_protein_rmsd"] = complex_protein_rmsd + results["complex_ligand_rmsd"] = complex_ligand_rmsd + if rmsd_refined is not None and rmsd_refined != float("inf"): + results["rmsd_refined"] = rmsd_refined + results["success_refined"] = True + + # Check if we got any valid RMSD + if protein_rmsd is None and ligand_rmsd is None: + results["error"] = "No reconstructed coordinates in model output" + return results + + # Determine structure type for saving + has_protein = reconstructed_protein_coords is not None and original_protein_coords is not None + has_ligand = reconstructed_ligand_coords is not None and original_ligand_coords is not None + + # Save aligned structures if requested + if save_structures and output_dir is not None: + try: + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Generate base filename from structure path + structure_name = os.path.splitext(os.path.basename(structure_path))[0] + model_safe_name = model_name.replace(" ", "_").replace("/", "_") + + # Get sequence for protein + seq = None + if has_protein: + seq = structure_data.get("sequence") + if seq is None: + seq = structure_data.get("seq") + if seq is None: + # Create default sequence (UNK) for each residue + seq = torch.full((original_protein_coords.shape[1],), 20, dtype=torch.long) + + # Save ground truth structure + gt_filename = os.path.join(output_dir, f"{structure_name}_{model_safe_name}_gt.pdb") + if has_protein and has_ligand: + # Save protein-ligand complex together + writepdb_ligand_complex( + gt_filename, + protein_atoms=original_protein_coords[0], # Remove batch dim + protein_seq=seq[0] if seq.dim() > 1 else seq, + ligand_atoms=original_ligand_coords[0], # Remove batch dim + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + elif has_protein: + writepdb(gt_filename, original_protein_coords, seq) + elif has_ligand: + writepdb_ligand_complex( + gt_filename, + ligand_atoms=original_ligand_coords[0], + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + + # Save aligned prediction structure + pred_filename = os.path.join( + output_dir, + f"{structure_name}_{model_safe_name}_{num_steps if num_steps is not None else 'default'}_pred.pdb", + ) + if has_protein and has_ligand: + # Save protein-ligand complex together + writepdb_ligand_complex( + pred_filename, + protein_atoms=reconstructed_protein_coords[0], # Remove batch dim + protein_seq=seq[0] if seq.dim() > 1 else seq, + ligand_atoms=reconstructed_ligand_coords[0], # Remove batch dim + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + elif has_protein: + writepdb(pred_filename, reconstructed_protein_coords, seq) + elif has_ligand: + writepdb_ligand_complex( + pred_filename, + ligand_atoms=reconstructed_ligand_coords[0], + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + + results["saved_structures"] = {"ground_truth": gt_filename, "prediction": pred_filename} + + # Save minimized structure if available + if reconstructed_ligand_coords_minimized is not None: + minimized_filename = os.path.join( + output_dir, + f"{structure_name}_{model_safe_name}_{num_steps if num_steps is not None else 'default'}_minimized.pdb", + ) + if has_protein and has_ligand: + writepdb_ligand_complex( + minimized_filename, + protein_atoms=reconstructed_protein_coords[0], + protein_seq=seq[0] if seq.dim() > 1 else seq, + ligand_atoms=reconstructed_ligand_coords_minimized[0], + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + elif has_ligand: + writepdb_ligand_complex( + minimized_filename, + ligand_atoms=reconstructed_ligand_coords_minimized[0], + ligand_atom_names=structure_data.get("ligand_atom_names", None), + ligand_resname="LIG", + ligand_bond_matrix=structure_data.get("bond_matrix", None), + ) + results["saved_structures"]["minimized"] = minimized_filename + + except Exception as e: + logger.warning(f"Failed to save structures for {structure_path}: {e}") + results["saved_structures"] = None + + return results + + +def evaluate_models_on_directory( + models: list[str], + data_dir: str, + output_file: str = "reconstruction_evaluation.json", + save_structures: bool = False, + structures_output_dir: str = None, + use_canonical_pose: bool = False, + num_steps_list: list[int] = None, + max_save_structures: int = None, + minimize_ligand: bool = False, + minimize_steps: int = 500, + force_field: str = "MMFF94", + minimize_mode: str = "bonds_and_angles", +) -> dict: + """Evaluate multiple models on all structures in a directory. + + Args: + max_save_structures: Maximum number of structures to save. If None, saves all. + minimize_ligand: If True, minimize ligand structure after decoding. + minimize_steps: Maximum number of minimization steps. + force_field: Force field for minimization. + """ + + if num_steps_list is None: + num_steps_list = [None] # Default to no num_steps specified + + # Track how many structures have been saved + structures_saved_count = 0 + + # Check if any model requires protein-ligand pairs + requires_protein_ligand = any("Ligand" in model and "Protein" in model for model in models) + + # First, check for paired protein-ligand files + paired_files = find_paired_protein_ligand_files(data_dir) + use_paired_mode = len(paired_files) > 0 and requires_protein_ligand + + if use_paired_mode: + logger.info(f"Found {len(paired_files)} paired protein-ligand complexes") + # Validate paired files + valid_paired_files = [] + for base_name, protein_path, ligand_path in tqdm(paired_files, desc="Checking paired files"): + if check_file_not_empty(protein_path) and check_file_not_empty(ligand_path): + valid_paired_files.append((base_name, protein_path, ligand_path)) + else: + logger.warning(f"Skipping invalid pair: {base_name}") + logger.info(f"Total: {len(valid_paired_files)} valid protein-ligand pairs") + valid_files = valid_paired_files # List of tuples for paired mode + else: + # Find all structure files (including processed PT files) + pdb_files = glob.glob(os.path.join(data_dir, "*.pdb")) + sdf_files = glob.glob(os.path.join(data_dir, "*.sdf")) + pt_files = glob.glob(os.path.join(data_dir, "*.pt")) + structure_files = pdb_files + sdf_files + pt_files + + # Filter out empty files + valid_files = [] + for file_path in tqdm(structure_files, desc="Checking files"): + if check_file_not_empty(file_path): + valid_files.append(file_path) + else: + logger.warning(f"Skipping empty or invalid file: {file_path}") + + logger.info(f"Found {len(pdb_files)} PDB files, {len(sdf_files)} SDF files, and {len(pt_files)} PT files") + logger.info(f"Total: {len(valid_files)} valid structure files") + + if len(valid_files) == 0: + logger.error("No valid structure files found!") + return {} + + # Results storage - now includes step information + all_results = {"models": {}, "structures": {}, "summary": {}} + + # Initialize model results for each step count + for model_name in models: + all_results["models"][model_name] = {} + for num_steps in num_steps_list: + step_key = f"steps_{num_steps}" if num_steps is not None else "default" + all_results["models"][model_name][step_key] = { + "rmsd_values": [], + "successful_reconstructions": 0, + "failed_reconstructions": 0, + "average_rmsd": float("inf"), + "std_rmsd": float("inf"), + # Protein-specific RMSD tracking + "protein_rmsd_values": [], + "average_protein_rmsd": float("inf"), + "std_protein_rmsd": float("inf"), + # Ligand-specific RMSD tracking + "ligand_rmsd_values": [], + "average_ligand_rmsd": float("inf"), + "std_ligand_rmsd": float("inf"), + # Minimized ligand RMSD tracking + "ligand_rmsd_minimized_values": [], + "average_ligand_rmsd_minimized": float("inf"), + "std_ligand_rmsd_minimized": float("inf"), + # Complex RMSD tracking (protein+ligand aligned together) + "complex_rmsd_values": [], + "average_complex_rmsd": float("inf"), + "std_complex_rmsd": float("inf"), + "complex_protein_rmsd_values": [], + "average_complex_protein_rmsd": float("inf"), + "complex_ligand_rmsd_values": [], + "average_complex_ligand_rmsd": float("inf"), + # Minimized complex RMSD tracking + "complex_rmsd_minimized_values": [], + "average_complex_rmsd_minimized": float("inf"), + "std_complex_rmsd_minimized": float("inf"), + "complex_protein_rmsd_minimized_values": [], + "average_complex_protein_rmsd_minimized": float("inf"), + "complex_ligand_rmsd_minimized_values": [], + "average_complex_ligand_rmsd_minimized": float("inf"), + # Refined reconstruction tracking + "rmsd_values_refined": [], + "successful_reconstructions_refined": 0, + "failed_reconstructions_refined": 0, + "average_rmsd_refined": float("inf"), + "std_rmsd_refined": float("inf"), + } + + # Initialize structures dictionary + if use_paired_mode: + for base_name, _, _ in valid_files: + all_results["structures"][base_name] = {} + else: + for structure_path in valid_files: + structure_name = os.path.basename(structure_path) + all_results["structures"][structure_name] = {} + + # Process one model at a time + for model_name in tqdm(models, desc="Processing models"): + logger.info(f"\n{'=' * 60}") + logger.info(f"Loading and evaluating model: {model_name}") + logger.info(f"{'=' * 60}") + + # Load the current model + logger.info(f"Loading model: {model_name}") + load_model( + methods[model_name].model_config.checkpoint, + methods[model_name].model_config.config_path, + methods[model_name].model_config.config_name, + overrides=methods[model_name].model_config.overrides, + ) + logger.info(f"Model {model_name} loaded successfully") + + # Evaluate this model on all structures with all step counts + if use_paired_mode: + # Paired protein-ligand mode + for base_name, protein_path, ligand_path in tqdm(valid_files, desc=f"Evaluating {model_name}"): + structure_name = base_name + + if structure_name not in all_results["structures"]: + all_results["structures"][structure_name] = {} + if model_name not in all_results["structures"][structure_name]: + all_results["structures"][structure_name][model_name] = {} + + # Load paired protein-ligand data + structure_data = load_paired_protein_ligand_data(protein_path, ligand_path) + structure_path = protein_path # Use protein path for reference + + # Determine if we should save this structure + should_save = save_structures and ( + max_save_structures is None or structures_saved_count < max_save_structures + ) + + # Test each num_steps value + for num_steps in num_steps_list: + step_key = f"steps_{num_steps}" if num_steps is not None else "default" + + # Evaluate model on this structure with this num_steps + result = evaluate_model_on_structure( + model_name, + structure_data, + structure_path, + should_save, + structures_output_dir, + use_canonical_pose, + num_steps, + minimize_ligand=minimize_ligand, + minimize_steps=minimize_steps, + force_field=force_field, + minimize_mode=minimize_mode, + ) + + # Increment saved counter if structure was saved + if should_save and result.get("saved_structures"): + structures_saved_count += 1 + + all_results["structures"][structure_name][model_name][step_key] = result + + # Update model statistics + if result["success"]: + all_results["models"][model_name][step_key]["rmsd_values"].append(result["rmsd"]) + all_results["models"][model_name][step_key]["successful_reconstructions"] += 1 + # Track protein RMSD if available + if "protein_rmsd" in result: + all_results["models"][model_name][step_key]["protein_rmsd_values"].append( + result["protein_rmsd"] + ) + # Track ligand RMSD if available + if "ligand_rmsd" in result: + all_results["models"][model_name][step_key]["ligand_rmsd_values"].append( + result["ligand_rmsd"] + ) + # Track minimized ligand RMSD if available + if "ligand_rmsd_minimized" in result: + all_results["models"][model_name][step_key]["ligand_rmsd_minimized_values"].append( + result["ligand_rmsd_minimized"] + ) + # Track complex RMSD if available + if "complex_rmsd" in result: + all_results["models"][model_name][step_key]["complex_rmsd_values"].append( + result["complex_rmsd"] + ) + all_results["models"][model_name][step_key]["complex_protein_rmsd_values"].append( + result["complex_protein_rmsd"] + ) + all_results["models"][model_name][step_key]["complex_ligand_rmsd_values"].append( + result["complex_ligand_rmsd"] + ) + # Track minimized complex RMSD if available + if "complex_rmsd_minimized" in result: + all_results["models"][model_name][step_key]["complex_rmsd_minimized_values"].append( + result["complex_rmsd_minimized"] + ) + all_results["models"][model_name][step_key]["complex_protein_rmsd_minimized_values"].append( + result["complex_protein_rmsd_minimized"] + ) + all_results["models"][model_name][step_key]["complex_ligand_rmsd_minimized_values"].append( + result["complex_ligand_rmsd_minimized"] + ) + if "success_refined" in result and result["success_refined"]: + all_results["models"][model_name][step_key]["rmsd_values_refined"].append( + result["rmsd_refined"] + ) + all_results["models"][model_name][step_key]["successful_reconstructions_refined"] += 1 + else: + all_results["models"][model_name][step_key]["failed_reconstructions"] += 1 + all_results["models"][model_name][step_key]["failed_reconstructions_refined"] += 1 + else: + # Single file mode (original behavior) + for structure_path in tqdm(valid_files, desc=f"Evaluating {model_name}"): + structure_name = os.path.basename(structure_path) + + if structure_name not in all_results["structures"]: + all_results["structures"][structure_name] = {} + if model_name not in all_results["structures"][structure_name]: + all_results["structures"][structure_name][model_name] = {} + + # Load structure data once + structure_data = load_structure_data(structure_path) + + # Determine if we should save this structure + should_save = save_structures and ( + max_save_structures is None or structures_saved_count < max_save_structures + ) + + # Test each num_steps value + for num_steps in num_steps_list: + step_key = f"steps_{num_steps}" if num_steps is not None else "default" + + # Evaluate model on this structure with this num_steps + result = evaluate_model_on_structure( + model_name, + structure_data, + structure_path, + should_save, + structures_output_dir, + use_canonical_pose, + num_steps, + minimize_ligand=minimize_ligand, + minimize_steps=minimize_steps, + force_field=force_field, + minimize_mode=minimize_mode, + ) + + # Increment saved counter if structure was saved + if should_save and result.get("saved_structures"): + structures_saved_count += 1 + + all_results["structures"][structure_name][model_name][step_key] = result + + # Update model statistics + if result["success"]: + all_results["models"][model_name][step_key]["rmsd_values"].append(result["rmsd"]) + all_results["models"][model_name][step_key]["successful_reconstructions"] += 1 + # Track protein RMSD if available + if "protein_rmsd" in result: + all_results["models"][model_name][step_key]["protein_rmsd_values"].append( + result["protein_rmsd"] + ) + # Track ligand RMSD if available + if "ligand_rmsd" in result: + all_results["models"][model_name][step_key]["ligand_rmsd_values"].append( + result["ligand_rmsd"] + ) + # Track minimized ligand RMSD if available + if "ligand_rmsd_minimized" in result: + all_results["models"][model_name][step_key]["ligand_rmsd_minimized_values"].append( + result["ligand_rmsd_minimized"] + ) + # Track complex RMSD if available + if "complex_rmsd" in result: + all_results["models"][model_name][step_key]["complex_rmsd_values"].append( + result["complex_rmsd"] + ) + all_results["models"][model_name][step_key]["complex_protein_rmsd_values"].append( + result["complex_protein_rmsd"] + ) + all_results["models"][model_name][step_key]["complex_ligand_rmsd_values"].append( + result["complex_ligand_rmsd"] + ) + # Track minimized complex RMSD if available + if "complex_rmsd_minimized" in result: + all_results["models"][model_name][step_key]["complex_rmsd_minimized_values"].append( + result["complex_rmsd_minimized"] + ) + all_results["models"][model_name][step_key]["complex_protein_rmsd_minimized_values"].append( + result["complex_protein_rmsd_minimized"] + ) + all_results["models"][model_name][step_key]["complex_ligand_rmsd_minimized_values"].append( + result["complex_ligand_rmsd_minimized"] + ) + if "success_refined" in result and result["success_refined"]: + all_results["models"][model_name][step_key]["rmsd_values_refined"].append( + result["rmsd_refined"] + ) + all_results["models"][model_name][step_key]["successful_reconstructions_refined"] += 1 + else: + all_results["models"][model_name][step_key]["failed_reconstructions"] += 1 + all_results["models"][model_name][step_key]["failed_reconstructions_refined"] += 1 + + # Clear model from memory after evaluation + logger.info(f"Clearing model {model_name} from memory") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Compute summary statistics + for model_name in models: + for num_steps in num_steps_list: + step_key = f"steps_{num_steps}" if num_steps is not None else "default" + model_results = all_results["models"][model_name][step_key] + if model_results["rmsd_values"]: + model_results["average_rmsd"] = np.mean(model_results["rmsd_values"]) + model_results["std_rmsd"] = np.std(model_results["rmsd_values"]) + model_results["min_rmsd"] = np.min(model_results["rmsd_values"]) + model_results["max_rmsd"] = np.max(model_results["rmsd_values"]) + # Compute protein RMSD statistics + if model_results["protein_rmsd_values"]: + model_results["average_protein_rmsd"] = np.mean(model_results["protein_rmsd_values"]) + model_results["std_protein_rmsd"] = np.std(model_results["protein_rmsd_values"]) + model_results["min_protein_rmsd"] = np.min(model_results["protein_rmsd_values"]) + model_results["max_protein_rmsd"] = np.max(model_results["protein_rmsd_values"]) + # Compute ligand RMSD statistics + if model_results["ligand_rmsd_values"]: + model_results["average_ligand_rmsd"] = np.mean(model_results["ligand_rmsd_values"]) + model_results["std_ligand_rmsd"] = np.std(model_results["ligand_rmsd_values"]) + model_results["min_ligand_rmsd"] = np.min(model_results["ligand_rmsd_values"]) + model_results["max_ligand_rmsd"] = np.max(model_results["ligand_rmsd_values"]) + # Compute minimized ligand RMSD statistics + if model_results["ligand_rmsd_minimized_values"]: + model_results["average_ligand_rmsd_minimized"] = np.mean( + model_results["ligand_rmsd_minimized_values"] + ) + model_results["std_ligand_rmsd_minimized"] = np.std(model_results["ligand_rmsd_minimized_values"]) + model_results["min_ligand_rmsd_minimized"] = np.min(model_results["ligand_rmsd_minimized_values"]) + model_results["max_ligand_rmsd_minimized"] = np.max(model_results["ligand_rmsd_minimized_values"]) + # Compute improvement from minimization + if model_results["ligand_rmsd_values"]: + model_results["average_ligand_rmsd_improvement"] = ( + model_results["average_ligand_rmsd"] - model_results["average_ligand_rmsd_minimized"] + ) + # Compute complex RMSD statistics + if model_results["complex_rmsd_values"]: + model_results["average_complex_rmsd"] = np.mean(model_results["complex_rmsd_values"]) + model_results["std_complex_rmsd"] = np.std(model_results["complex_rmsd_values"]) + model_results["min_complex_rmsd"] = np.min(model_results["complex_rmsd_values"]) + model_results["max_complex_rmsd"] = np.max(model_results["complex_rmsd_values"]) + model_results["average_complex_protein_rmsd"] = np.mean( + model_results["complex_protein_rmsd_values"] + ) + model_results["average_complex_ligand_rmsd"] = np.mean(model_results["complex_ligand_rmsd_values"]) + # Compute minimized complex RMSD statistics + if model_results.get("complex_rmsd_minimized_values"): + model_results["average_complex_rmsd_minimized"] = np.mean( + model_results["complex_rmsd_minimized_values"] + ) + model_results["std_complex_rmsd_minimized"] = np.std(model_results["complex_rmsd_minimized_values"]) + model_results["min_complex_rmsd_minimized"] = np.min(model_results["complex_rmsd_minimized_values"]) + model_results["max_complex_rmsd_minimized"] = np.max(model_results["complex_rmsd_minimized_values"]) + model_results["average_complex_protein_rmsd_minimized"] = np.mean( + model_results["complex_protein_rmsd_minimized_values"] + ) + model_results["average_complex_ligand_rmsd_minimized"] = np.mean( + model_results["complex_ligand_rmsd_minimized_values"] + ) + if model_results["rmsd_values_refined"]: + model_results["average_rmsd_refined"] = np.mean(model_results["rmsd_values_refined"]) + model_results["std_rmsd_refined"] = np.std(model_results["rmsd_values_refined"]) + model_results["min_rmsd_refined"] = np.min(model_results["rmsd_values_refined"]) + model_results["max_rmsd_refined"] = np.max(model_results["rmsd_values_refined"]) + else: + model_results["average_rmsd"] = float("inf") + model_results["std_rmsd"] = float("inf") + model_results["min_rmsd"] = float("inf") + model_results["max_rmsd"] = float("inf") + # filler values + model_results["average_rmsd_refined"] = float("inf") + model_results["std_rmsd_refined"] = float("inf") + model_results["min_rmsd_refined"] = float("inf") + model_results["max_rmsd_refined"] = float("inf") + + # Print summary + logger.info("\n" + "=" * 80) + logger.info("RECONSTRUCTION EVALUATION SUMMARY") + logger.info("=" * 80) + + for model_name in models: + for num_steps in num_steps_list: + step_key = f"steps_{num_steps}" if num_steps is not None else "default" + model_results = all_results["models"][model_name][step_key] + logger.info(f"\nModel: {model_name} (Steps: {num_steps if num_steps is not None else 'Default'})") + logger.info(f" Successful reconstructions: {model_results['successful_reconstructions']}") + logger.info(f" Failed reconstructions: {model_results['failed_reconstructions']}") + if model_results["average_rmsd"] != float("inf"): + logger.info(f" Average RMSD: {model_results['average_rmsd']:.3f} ± {model_results['std_rmsd']:.3f} Å") + logger.info(f" Min RMSD: {model_results['min_rmsd']:.3f} Å") + logger.info(f" Max RMSD: {model_results['max_rmsd']:.3f} Å") + else: + logger.info(" Average RMSD: N/A (no successful reconstructions)") + # Print protein-specific RMSD if available + if model_results.get("protein_rmsd_values"): + logger.info( + f" Average Protein RMSD: {model_results['average_protein_rmsd']:.3f} ± {model_results['std_protein_rmsd']:.3f} Å" + ) + logger.info(f" Min Protein RMSD: {model_results['min_protein_rmsd']:.3f} Å") + logger.info(f" Max Protein RMSD: {model_results['max_protein_rmsd']:.3f} Å") + # Print ligand-specific RMSD if available + if model_results.get("ligand_rmsd_values"): + logger.info( + f" Average Ligand RMSD: {model_results['average_ligand_rmsd']:.3f} ± {model_results['std_ligand_rmsd']:.3f} Å" + ) + logger.info(f" Min Ligand RMSD: {model_results['min_ligand_rmsd']:.3f} Å") + logger.info(f" Max Ligand RMSD: {model_results['max_ligand_rmsd']:.3f} Å") + # Print minimized ligand RMSD if available + if model_results.get("ligand_rmsd_minimized_values"): + logger.info( + f" Average Ligand RMSD (minimized): {model_results['average_ligand_rmsd_minimized']:.3f} ± {model_results['std_ligand_rmsd_minimized']:.3f} Å" + ) + logger.info(f" Min Ligand RMSD (minimized): {model_results['min_ligand_rmsd_minimized']:.3f} Å") + logger.info(f" Max Ligand RMSD (minimized): {model_results['max_ligand_rmsd_minimized']:.3f} Å") + if "average_ligand_rmsd_improvement" in model_results: + improvement = model_results["average_ligand_rmsd_improvement"] + logger.info(f" Ligand RMSD improvement from minimization: {improvement:+.3f} Å") + # Print complex RMSD if available (protein+ligand aligned together) + if model_results.get("complex_rmsd_values"): + logger.info( + f" Average Complex RMSD: {model_results['average_complex_rmsd']:.3f} ± {model_results['std_complex_rmsd']:.3f} Å" + ) + logger.info(f" Min Complex RMSD: {model_results['min_complex_rmsd']:.3f} Å") + logger.info(f" Max Complex RMSD: {model_results['max_complex_rmsd']:.3f} Å") + logger.info( + f" (Complex alignment) Protein RMSD: {model_results['average_complex_protein_rmsd']:.3f} Å" + ) + logger.info( + f" (Complex alignment) Ligand RMSD: {model_results['average_complex_ligand_rmsd']:.3f} Å" + ) + # Print minimized complex RMSD if available + if model_results.get("complex_rmsd_minimized_values"): + logger.info( + f" Average Complex RMSD (minimized): {model_results['average_complex_rmsd_minimized']:.3f} ± {model_results['std_complex_rmsd_minimized']:.3f} Å" + ) + logger.info(f" Min Complex RMSD (minimized): {model_results['min_complex_rmsd_minimized']:.3f} Å") + logger.info(f" Max Complex RMSD (minimized): {model_results['max_complex_rmsd_minimized']:.3f} Å") + logger.info( + f" (Complex alignment minimized) Protein RMSD: {model_results['average_complex_protein_rmsd_minimized']:.3f} Å" + ) + logger.info( + f" (Complex alignment minimized) Ligand RMSD: {model_results['average_complex_ligand_rmsd_minimized']:.3f} Å" + ) + if model_results["rmsd_values_refined"]: + logger.info( + f" Successful refined reconstructions: {model_results['successful_reconstructions_refined']}" + ) + logger.info(f" Failed refined reconstructions: {model_results['failed_reconstructions_refined']}") + if model_results["average_rmsd_refined"] != float("inf"): + logger.info( + f" Average refined RMSD: {model_results['average_rmsd_refined']:.3f} ± {model_results['std_rmsd_refined']:.3f} Å" + ) + logger.info(f" Min refined RMSD: {model_results['min_rmsd_refined']:.3f} Å") + logger.info(f" Max refined RMSD: {model_results['max_rmsd_refined']:.3f} Å") + else: + logger.info(" Average refined RMSD: N/A (no successful refined reconstructions)") + # Save results + with open(output_file, "w") as f: + json.dump(all_results, f, indent=2, default=str) + + logger.info(f"\nDetailed results saved to: {output_file}") + + return all_results + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate reconstruction quality of LatentGenerator models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +Available Models: +{chr(10).join([f" - {name}" for name in methods.keys()])} + +Example usage Protein-only model: + python src/lobster/metrics/evaluate_reconstruction.py \ + --models "LG full attention 2" \ + --data_dir /cv/data/ai4dd/data2/lisanzas/latent_generator_files/casp_recon/CASP15_merged/ \ + --output_file reconstruction_results_protein_only.json \ + --save_structures \ + --structures_output_dir aligned_structures_protein_only + +Example usage Ligand-only model: + python src/lobster/metrics/evaluate_reconstruction.py \ + --models "LG Ligand 20A" \ + --data_dir /cv/data/ai4dd/data2/lisanzas/geom_12_15_25/test/ \ + --output_file reconstruction_results_ligand_only.json + +Example usage Protein-Ligand model (with paired *_protein.pt and *_ligand.pt files): + python src/lobster/metrics/evaluate_reconstruction.py \ + --models "LG Protein Ligand" \ + --data_dir /cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/ \ + --output_file reconstruction_results_protein_ligand.json \ + --save_structures \ + --structures_output_dir aligned_structures_protein_ligand + """, + ) + + parser.add_argument("--models", nargs="+", required=True, help="List of model names to evaluate") + + parser.add_argument( + "--data_dir", + type=str, + default="/cv/data/ai4dd/data2/lisanzas/latent_generator_files/casp_recon/CASP15_merged/", + help="Directory containing structure files to evaluate", + ) + + parser.add_argument( + "--output_file", type=str, default="reconstruction_evaluation.json", help="Output file to save results" + ) + + parser.add_argument( + "--save_structures", + action="store_true", + help="Save aligned ground truth and prediction structures to PDB files", + ) + + parser.add_argument( + "--structures_output_dir", + type=str, + default="aligned_structures", + help="Directory to save aligned structures (used with --save_structures)", + ) + + parser.add_argument( + "--max_save_structures", + type=int, + default=None, + help="Maximum number of structures to save (default: save all). Only saves first N structures.", + ) + + parser.add_argument( + "--use_canonical_pose", action="store_true", help="Make model invariant with mol_frame to get canonical pose" + ) + + parser.add_argument( + "--num_steps", nargs="+", type=int, help="List of num_steps values to test for each model (e.g., 1 5 10 20)" + ) + + # Minimization options + parser.add_argument( + "--minimize_ligand", + action="store_true", + help="Minimize ligand structure after decoding using Open Babel force field", + ) + parser.add_argument( + "--minimize_steps", + type=int, + default=500, + help="Maximum number of minimization steps (default: 500)", + ) + parser.add_argument( + "--force_field", + type=str, + default="MMFF94", + choices=["MMFF94", "MMFF94s", "UFF", "GAFF", "Ghemical"], + help="Force field for minimization (default: MMFF94)", + ) + parser.add_argument( + "--minimize_mode", + type=str, + default="bonds_and_angles", + choices=["bonds_only", "bonds_and_angles"], + help="Minimization mode: 'bonds_only' (ideal bond lengths), 'bonds_and_angles' (ideal bonds + angles via constrained minimization, recommended)", + ) + + args = parser.parse_args() + + # Validate models + invalid_models = [model for model in args.models if model not in methods] + if invalid_models: + logger.error(f"Invalid model names: {invalid_models}") + logger.error(f"Available models: {list(methods.keys())}") + return + + # Validate data directory + if not os.path.exists(args.data_dir): + logger.error(f"Data directory does not exist: {args.data_dir}") + return + + logger.info(f"Evaluating models: {args.models}") + logger.info(f"Data directory: {args.data_dir}") + logger.info(f"Output file: {args.output_file}") + if args.save_structures: + logger.info(f"Saving aligned structures to: {args.structures_output_dir}") + if args.max_save_structures: + logger.info(f"Limiting to first {args.max_save_structures} structures") + if args.minimize_ligand: + logger.info( + f"Ligand minimization enabled: force_field={args.force_field}, steps={args.minimize_steps}, mode={args.minimize_mode}" + ) + + # Run evaluation + results = evaluate_models_on_directory( + args.models, + args.data_dir, + args.output_file, + args.save_structures, + args.structures_output_dir, + args.use_canonical_pose, + args.num_steps, + args.max_save_structures, + minimize_ligand=args.minimize_ligand, + minimize_steps=args.minimize_steps, + force_field=args.force_field, + minimize_mode=args.minimize_mode, + ) + if results: + logger.info("Evaluation completed successfully!") + else: + logger.error("Evaluation failed!") + + +if __name__ == "__main__": + main() diff --git a/src/lobster/metrics/ligand_conditioned_protein_generation.py b/src/lobster/metrics/ligand_conditioned_protein_generation.py new file mode 100644 index 00000000..8fd5a619 --- /dev/null +++ b/src/lobster/metrics/ligand_conditioned_protein_generation.py @@ -0,0 +1,971 @@ +"""Ligand-Conditioned Protein Generation Evaluator. + +Evaluates the model's ability to generate proteins that bind to a given ligand, +starting from scratch (no protein structure or sequence input). + +The core metric is **self-consistency**: the model generates both a sequence and +a structure; we then fold the generated sequence with ESMFold and measure how +well the ESMFold-predicted structure agrees with the model-decoded structure. + +Metrics: +- scTM (self-consistency TM-score): TM-score between decoded and ESMFold structures +- scRMSD: RMSD between decoded and ESMFold structures +- Pocket scTM / scRMSD: same, restricted to residues near the decoded ligand +- pLDDT: ESMFold confidence score +- PAE: ESMFold predicted aligned error +""" + +import os +from glob import glob +from typing import TYPE_CHECKING + +import pandas as pd +import torch +from loguru import logger +from tmtools import tm_align +from torch import Tensor +from tqdm import tqdm + +from bionemo.moco.schedules.inference_time_schedules import ( + LinearInferenceSchedule, + LogInferenceSchedule, + PowerInferenceSchedule, +) + +from lobster.metrics import align_and_compute_rmsd +from lobster.model.latent_generator.io import writepdb, writepdb_ligand_complex +from lobster.model.latent_generator.utils import minimize_ligand_structure +from lobster.model.latent_generator.utils.residue_constants import ( + convert_lobster_aa_tokenization_to_standard_aa, + restype_order_with_x_inv, +) + +INFERENCE_SCHEDULE_MAP = { + "LinearInferenceSchedule": LinearInferenceSchedule, + "LogInferenceSchedule": LogInferenceSchedule, + "PowerInferenceSchedule": PowerInferenceSchedule, +} + + +def _get_inference_schedule_class(schedule_name: str): + """Convert string schedule name to class.""" + if schedule_name not in INFERENCE_SCHEDULE_MAP: + raise ValueError(f"Unknown inference schedule: {schedule_name}. Options: {list(INFERENCE_SCHEDULE_MAP.keys())}") + return INFERENCE_SCHEDULE_MAP[schedule_name] + + +if TYPE_CHECKING: + from lightning import LightningModule + + +class LigandConditionedProteinGenerationEvaluator: + """Evaluates ligand-conditioned protein generation via self-consistency. + + Given only a ligand (atom types + bond matrix), the model generates a protein + (both sequence and structure) from scratch. The generated sequence is then + folded with ESMFold, and the self-consistency between the model-decoded + structure and the ESMFold-predicted structure is measured. + + Parameters + ---------- + data_dir : str + Path to directory containing *_ligand.pt files. + length : int + Length of the protein to generate (number of residues). + pocket_distance_threshold : float + Distance threshold (angstrom) for defining binding pocket residues + on the decoded structure relative to decoded ligand coordinates. + num_samples : int, optional + Limit number of samples to evaluate (None = all). + nsteps : int + Number of diffusion steps for generation. + device : str + Device for computation. + max_length : int + Maximum combined sequence length (protein + ligand) to process. + temperature_seq : float + Temperature for sequence sampling. + temperature_struc : float + Temperature for structure sampling. + stochasticity_seq : int + Stochasticity parameter for sequence sampling. + stochasticity_struc : int + Stochasticity parameter for structure sampling. + temperature_ligand : float + Temperature for ligand structure sampling. + stochasticity_ligand : int + Stochasticity parameter for ligand structure sampling. + ligand_context_mode : str + How to provide ligand context: "atom_bond_only" or "structure_tokens". + inference_schedule_seq : str + Inference schedule for sequence generation. + inference_schedule_struc : str + Inference schedule for structure generation. + inference_schedule_ligand_atom : str, optional + Inference schedule for ligand atom token generation. + inference_schedule_ligand_struc : str, optional + Inference schedule for ligand structure token generation. + save_structures : bool + Whether to save generated structures as PDB files. + num_designs : int + Number of designs to generate per ligand. The best design (by scTM) + is selected for reporting. + minimize_ligand : bool + Whether to apply geometry correction to decoded ligand structures. + minimize_mode : str + Minimization mode: "bonds_only", "bonds_and_angles", "local", or "full". + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps. + plm_fold : object + Pre-loaded LobsterPLMFold model instance for ESMFold prediction. + """ + + def __init__( + self, + data_dir: str, + length: int = 100, + pocket_distance_threshold: float = 5.0, + num_samples: int | None = None, + num_designs: int = 10, + nsteps: int = 100, + device: str = "cuda", + max_length: int = 512, + temperature_seq: float = 0.5, + temperature_struc: float = 0.5, + stochasticity_seq: int = 20, + stochasticity_struc: int = 20, + temperature_ligand: float = 0.5, + stochasticity_ligand: int = 20, + ligand_context_mode: str = "atom_bond_only", + inference_schedule_seq: str = "LogInferenceSchedule", + inference_schedule_struc: str = "LinearInferenceSchedule", + inference_schedule_ligand_atom: str | None = None, + inference_schedule_ligand_struc: str | None = None, + save_structures: bool = False, + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + plm_fold: object = None, + ): + self.data_dir = data_dir + self.length = length + self.pocket_distance_threshold = pocket_distance_threshold + self.num_samples = num_samples + self.num_designs = num_designs + self.nsteps = nsteps + self.device = device + self.max_length = max_length + self.temperature_seq = temperature_seq + self.temperature_struc = temperature_struc + self.stochasticity_seq = stochasticity_seq + self.stochasticity_struc = stochasticity_struc + self.temperature_ligand = temperature_ligand + self.stochasticity_ligand = stochasticity_ligand + self.ligand_context_mode = ligand_context_mode + self.inference_schedule_seq = inference_schedule_seq + self.inference_schedule_struc = inference_schedule_struc + self.inference_schedule_ligand_atom = inference_schedule_ligand_atom + self.inference_schedule_ligand_struc = inference_schedule_ligand_struc + self.save_structures = save_structures + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + self.plm_fold = plm_fold + + if plm_fold is None: + raise ValueError( + "plm_fold is required for self-consistency evaluation. " + "Load with: LobsterPLMFold(model_name='esmfold_v1', max_length=512)" + ) + + self.standard_aa_map = { + 0: "A", + 1: "R", + 2: "N", + 3: "D", + 4: "C", + 5: "Q", + 6: "E", + 7: "G", + 8: "H", + 9: "I", + 10: "L", + 11: "K", + 12: "M", + 13: "F", + 14: "P", + 15: "S", + 16: "T", + 17: "W", + 18: "Y", + 19: "V", + 20: "X", + } + + self.lobster_to_standard = torch.tensor( + [ + 10, + 0, + 7, + 19, + 15, + 6, + 1, + 16, + 9, + 3, + 14, + 11, + 5, + 13, + 2, + 18, + 12, + 8, + 17, + 4, + 20, + ], + dtype=torch.long, + device=device, + ) + + self.element_to_idx = { + "PAD": 0, + "MASK": 1, + "UNK": 2, + "C": 3, + "N": 4, + "O": 5, + "S": 6, + "P": 7, + "H": 8, + "F": 9, + "Cl": 10, + "Br": 11, + "I": 12, + "Fe": 13, + "Zn": 14, + "Mg": 15, + "Ca": 16, + "Mn": 17, + "Cu": 18, + "B": 19, + "Si": 20, + "Se": 21, + "Co": 22, + "Ni": 23, + "Bi": 24, + } + + def _atom_names_to_indices(self, atom_names: list) -> Tensor: + """Convert atom names (e.g., ['C1', 'N2', 'O3']) to element indices.""" + indices = [] + for name in atom_names: + if len(name) >= 2 and name[:2] in self.element_to_idx: + elem = name[:2] + elif name[0] in self.element_to_idx: + elem = name[0] + else: + elem = name[0].upper() + idx = self.element_to_idx.get(elem, 2) + indices.append(idx) + return torch.tensor(indices, dtype=torch.long, device=self.device) + + def load_test_set(self) -> list[dict]: + """Load ligand data from the test directory. + + Only ligand .pt files are required. Protein .pt files are used only to + derive sample IDs (via the *_ligand.pt naming convention). + + Returns list of dicts with: + - ligand_id: str + - ligand_coords: Tensor [N_atoms, 3] + - ligand_atom_types: Tensor [N_atoms] + - ligand_atom_names: list[str] or None + - ligand_mask: Tensor [N_atoms] + - ligand_indices: Tensor [N_atoms] + - bond_matrix: Tensor [N_atoms, N_atoms] (if available) + """ + ligand_files = sorted(glob(os.path.join(self.data_dir, "*_ligand.pt"))) + + if not ligand_files: + raise ValueError(f"No ligand files found in {self.data_dir}") + + if self.num_samples is not None: + ligand_files = ligand_files[: self.num_samples] + + logger.info(f"Loading {len(ligand_files)} ligand samples from {self.data_dir}") + + samples = [] + for lf in tqdm(ligand_files, desc="Loading samples"): + ligand_id = os.path.basename(lf).replace("_ligand.pt", "") + ligand_data = torch.load(lf, weights_only=False, map_location=self.device) + + ligand_coords = ligand_data.get( + "atom_coords", + ligand_data.get("coords", ligand_data.get("ligand_coords")), + ) + + if ligand_coords is None: + logger.warning(f"Missing ligand coordinates for {ligand_id}, skipping") + continue + + atom_names = ligand_data.get("atom_names") + if atom_names is not None and isinstance(atom_names, list): + ligand_atom_types = self._atom_names_to_indices(atom_names) + else: + ligand_atom_types = ligand_data.get( + "element_indices", + ligand_data.get( + "ligand_element_indices", + torch.full( + (ligand_coords.shape[0],), + 3, + dtype=torch.long, + device=self.device, + ), + ), + ) + + ligand_mask = ligand_data.get( + "mask", + ligand_data.get( + "ligand_mask", + torch.ones(ligand_coords.shape[0], device=self.device), + ), + ) + ligand_indices = ligand_data.get( + "atom_indices", + ligand_data.get( + "indices", + ligand_data.get( + "ligand_indices", + torch.arange(ligand_coords.shape[0], device=self.device), + ), + ), + ) + bond_matrix = ligand_data.get("bond_matrix") + + samples.append( + { + "ligand_id": ligand_id, + "ligand_coords": ligand_coords, + "ligand_atom_types": ligand_atom_types, + "ligand_atom_names": atom_names, + "ligand_mask": ligand_mask, + "ligand_indices": ligand_indices, + "bond_matrix": bond_matrix, + } + ) + + logger.info(f"Loaded {len(samples)} valid samples") + return samples + + def compute_binding_pocket( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + ) -> Tensor: + """Compute pocket mask on the decoded structure. + + A residue is in the pocket if its CA atom is within + pocket_distance_threshold of any decoded ligand atom. + + Returns boolean mask [L] where True indicates pocket residues. + """ + if protein_coords.dim() == 3: + ca_coords = protein_coords[:, 1, :] + else: + ca_coords = protein_coords + + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + min_distances = distances.min(dim=1).values + return min_distances < self.pocket_distance_threshold + + def generate_protein( + self, + model: "LightningModule", + sample: dict, + ) -> dict: + """Generate protein sequence and structure conditioned on ligand. + + No protein information is provided. The model generates both sequence + and structure from noise, conditioned on the ligand. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model. + sample : dict + Sample dictionary from load_test_set(). + + Returns + ------- + dict with: + - predicted_sequence: Tensor [L] (in standard AA format) + - sequence_logits: Tensor [L, vocab_size] + - decoded_coords: Tensor [L, 3, 3] + - decoded_ligand_coords: Tensor [N, 3] or None + """ + length = self.length + + ligand_atom_tokens = sample["ligand_atom_types"].unsqueeze(0).long() + num_atoms = len(sample["ligand_atom_types"]) + bond_matrix = sample.get("bond_matrix") + if bond_matrix is not None: + bond_matrix = bond_matrix.unsqueeze(0).long() + + ligand_structure_tokens = None + ligand_structure_embeddings = None + + if self.ligand_context_mode == "structure_tokens": + ligand_coords = sample["ligand_coords"].float() + ligand_mask_t = sample["ligand_mask"].float() + ligand_indices = sample["ligand_indices"].long() + + # Center ligand at the origin so the generated protein + # (which starts from noise around the origin) is spatially + # close to the ligand. + valid_mask = ligand_mask_t.bool() + if valid_mask.any(): + centroid = ligand_coords[valid_mask].mean(dim=0, keepdim=True) + ligand_coords = ligand_coords - centroid + + ligand_coords = ligand_coords.unsqueeze(0) + ligand_mask_t = ligand_mask_t.unsqueeze(0) + ligand_indices = ligand_indices.unsqueeze(0) + + with torch.no_grad(): + encode_result = model.encode_ligand_structure( + ligand_coords, + ligand_mask_t, + ligand_indices, + return_continuous=True, + ) + ligand_structure_tokens = encode_result[0] + ligand_structure_embeddings = encode_result[2] + + inference_schedule_seq_class = _get_inference_schedule_class(self.inference_schedule_seq) + inference_schedule_struc_class = _get_inference_schedule_class(self.inference_schedule_struc) + inference_schedule_ligand_atom_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_atom) + if self.inference_schedule_ligand_atom + else None + ) + inference_schedule_ligand_struc_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_struc) + if self.inference_schedule_ligand_struc + else None + ) + + with torch.no_grad(): + result = model.generate_sample( + length=length, + num_samples=1, + inverse_folding=False, + forward_folding=False, + nsteps=self.nsteps, + inference_schedule_seq=inference_schedule_seq_class, + inference_schedule_struc=inference_schedule_struc_class, + inference_schedule_ligand_atom=inference_schedule_ligand_atom_class, + inference_schedule_ligand_struc=inference_schedule_ligand_struc_class, + temperature_seq=self.temperature_seq, + temperature_struc=self.temperature_struc, + stochasticity_seq=self.stochasticity_seq, + stochasticity_struc=self.stochasticity_struc, + temperature_ligand=self.temperature_ligand, + stochasticity_ligand=self.stochasticity_ligand, + generate_ligand=True, + num_atoms=num_atoms, + input_ligand_atom_tokens=ligand_atom_tokens, + input_ligand_structure_tokens=ligand_structure_tokens, + input_ligand_structure_embeddings=ligand_structure_embeddings, + input_bond_matrix=bond_matrix, + ligand_is_context=(self.ligand_context_mode == "structure_tokens"), + ) + + protein_mask_batch = torch.ones((1, length), device=self.device) + ligand_mask_batch = torch.ones((1, num_atoms), device=self.device) + decoded_x = model.decode_structure( + result, + protein_mask_batch, + ligand_mask=ligand_mask_batch, + ) + decoded_coords = None + decoded_ligand_coords = None + vit_output = decoded_x.get("vit_decoder") + if isinstance(vit_output, dict): + decoded_coords = vit_output.get("protein_coords") + decoded_ligand_coords = vit_output.get("ligand_coords") + else: + decoded_coords = vit_output + + sequence_logits = result["sequence_logits"] + uses_33_token_vocab = sequence_logits.shape[-1] == 33 + + if uses_33_token_vocab: + predicted_sequence = convert_lobster_aa_tokenization_to_standard_aa( + sequence_logits, device=sequence_logits.device + ).squeeze(0) + else: + predicted_sequence = sequence_logits.argmax(dim=-1).squeeze(0) + predicted_sequence[predicted_sequence > 20] = 20 + predicted_sequence = self.lobster_to_standard[predicted_sequence.long()] + + return { + "predicted_sequence": predicted_sequence, + "sequence_logits": sequence_logits.squeeze(0), + "decoded_coords": (decoded_coords.squeeze(0) if decoded_coords is not None else None), + "decoded_ligand_coords": (decoded_ligand_coords.squeeze(0) if decoded_ligand_coords is not None else None), + } + + def fold_with_esmfold(self, sequence_str: str) -> dict: + """Fold a sequence with ESMFold and return predicted coords + confidence. + + Returns + ------- + dict with: + - esmfold_coords: Tensor [L, 3, 3] (N, CA, C backbone) + - plddt: float (mean pLDDT) + - pae: float (mean predicted aligned error) + """ + tokenized_input = self.plm_fold.tokenizer.encode_plus( + sequence_str, + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + return_tensors="pt", + )["input_ids"].to(self.device) + + with torch.no_grad(): + outputs = self.plm_fold.model(tokenized_input) + + esmfold_coords = outputs["positions"][-1][0, :, :3, :] # [L, 3, 3] + plddt = outputs["plddt"].mean().item() + pae = outputs["predicted_aligned_error"].mean().item() + + return { + "esmfold_coords": esmfold_coords, + "plddt": plddt, + "pae": pae, + } + + def compute_tm_score( + self, + coords1: Tensor, + coords2: Tensor, + sequence: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute TM-score between two structures.""" + if mask is not None: + mask = mask.bool() + coords1 = coords1[mask] + coords2 = coords2[mask] + sequence = sequence[mask] + + if len(coords1) < 3: + return float("nan") + + sequence_str = "".join([restype_order_with_x_inv.get(int(s), "X") for s in sequence.cpu().tolist()]) + + ca1 = coords1[:, 1, :].detach().cpu().numpy() + ca2 = coords2[:, 1, :].detach().cpu().numpy() + + tm_out = tm_align(ca1, ca2, sequence_str, sequence_str) + return tm_out.tm_norm_chain1 + + def compute_rmsd( + self, + coords1: Tensor, + coords2: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute RMSD between two structures (Kabsch-aligned).""" + if mask is not None: + mask = mask.bool() + coords1 = coords1[mask] + coords2 = coords2[mask] + + if len(coords1) == 0: + return float("nan") + + rmsd = align_and_compute_rmsd( + coords1=coords1.detach(), + coords2=coords2.detach(), + mask=None, + return_aligned=False, + device=coords1.device, + ) + return float(rmsd) + + def compute_contact_metrics( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + contact_threshold: float = 4.5, + ) -> dict: + """Compute protein-ligand contact statistics. + + Parameters + ---------- + protein_coords : Tensor [L, 3, 3] + Backbone coords (N, CA, C) per residue. + ligand_coords : Tensor [N, 3] + Ligand atom coords. + contact_threshold : float + Distance cutoff in angstrom for defining a contact. + + Returns + ------- + dict with contact metrics. + """ + # Use CA atoms for residue-level contacts + if protein_coords.dim() == 3: + ca_coords = protein_coords[:, 1, :] + else: + ca_coords = protein_coords + + # Pairwise distances: [n_residues, n_ligand_atoms] + dists = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + min_dist_per_residue = dists.min(dim=1).values + min_dist_per_ligand_atom = dists.min(dim=0).values + + n_residues = ca_coords.shape[0] + n_ligand_atoms = ligand_coords.shape[0] + + residues_in_contact = (min_dist_per_residue < contact_threshold).sum().item() + ligand_atoms_in_contact = (min_dist_per_ligand_atom < contact_threshold).sum().item() + n_contacts = (dists < contact_threshold).sum().item() + + return { + "n_contacts": n_contacts, + "n_residues_in_contact": int(residues_in_contact), + "frac_residues_in_contact": residues_in_contact / n_residues, + "n_ligand_atoms_in_contact": int(ligand_atoms_in_contact), + "frac_ligand_atoms_in_contact": ligand_atoms_in_contact / n_ligand_atoms, + "min_protein_ligand_dist": float(dists.min().item()), + "mean_min_dist_per_residue": float(min_dist_per_residue.mean().item()), + } + + def sequence_to_string(self, seq_tensor: Tensor) -> str: + """Convert sequence tensor (standard format) to string.""" + return "".join([self.standard_aa_map.get(int(s), "X") for s in seq_tensor.cpu().tolist()]) + + def _evaluate_single_design( + self, + model: "LightningModule", + sample: dict, + ligand_id: str, + ligand_length: int, + ) -> dict | None: + """Generate one design for a ligand and compute its metrics. + + Returns a dict of metrics, or None if generation failed. + """ + gen_result = self.generate_protein(model, sample) + pred_seq = gen_result["predicted_sequence"] + decoded_coords = gen_result["decoded_coords"] + decoded_ligand_coords = gen_result["decoded_ligand_coords"] + + if decoded_coords is None: + return None + + # Minimize decoded ligand geometry if enabled + if self.minimize_ligand and decoded_ligand_coords is not None: + atom_names = sample.get("ligand_atom_names") + if atom_names is None: + idx_to_element = {v: k for k, v in self.element_to_idx.items()} + ligand_types = sample["ligand_atom_types"] + atom_names = [ + f"{idx_to_element.get(int(t), 'C')}{i + 1}" for i, t in enumerate(ligand_types.cpu().tolist()) + ] + try: + decoded_ligand_coords = minimize_ligand_structure( + decoded_ligand_coords.cpu(), + atom_names, + bond_matrix=sample.get("bond_matrix"), + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ).to(self.device) + except Exception as e: + logger.warning(f"Ligand minimization failed for {ligand_id}: {e}") + + # Fold generated sequence with ESMFold + seq_str = self.sequence_to_string(pred_seq) + esm_result = self.fold_with_esmfold(seq_str) + esmfold_coords = esm_result["esmfold_coords"] + + # Compute pocket mask on the decoded structure + pocket_mask = None + n_pocket = 0 + if decoded_ligand_coords is not None: + pocket_mask = self.compute_binding_pocket(decoded_coords, decoded_ligand_coords) + n_pocket = int(pocket_mask.sum().item()) + + # Contact metrics between protein and ligand + contact_metrics = {} + if decoded_ligand_coords is not None: + contact_metrics = self.compute_contact_metrics(decoded_coords, decoded_ligand_coords) + + result = { + "ligand_id": ligand_id, + "protein_length": self.length, + "ligand_length": ligand_length, + "n_pocket_residues": n_pocket, + "n_contacts": contact_metrics.get("n_contacts", 0), + "n_residues_in_contact": contact_metrics.get("n_residues_in_contact", 0), + "frac_residues_in_contact": contact_metrics.get("frac_residues_in_contact", 0.0), + "n_ligand_atoms_in_contact": contact_metrics.get("n_ligand_atoms_in_contact", 0), + "frac_ligand_atoms_in_contact": contact_metrics.get("frac_ligand_atoms_in_contact", 0.0), + "min_protein_ligand_dist": contact_metrics.get("min_protein_ligand_dist", float("nan")), + "scTM": self.compute_tm_score(decoded_coords, esmfold_coords, pred_seq), + "scRMSD": self.compute_rmsd(decoded_coords, esmfold_coords), + "plddt": esm_result["plddt"], + "pae": esm_result["pae"], + "sequence": seq_str, + } + + if pocket_mask is not None and n_pocket > 0: + result["pocket_scTM"] = self.compute_tm_score(decoded_coords, esmfold_coords, pred_seq, pocket_mask) + result["pocket_scRMSD"] = self.compute_rmsd(decoded_coords, esmfold_coords, pocket_mask) + else: + result["pocket_scTM"] = float("nan") + result["pocket_scRMSD"] = float("nan") + + # Attach tensors for optional structure saving (not serialized to CSV) + result["_pred_seq"] = pred_seq + result["_decoded_coords"] = decoded_coords + result["_decoded_ligand_coords"] = decoded_ligand_coords + result["_esmfold_coords"] = esmfold_coords + + return result + + def evaluate( + self, + model: "LightningModule", + samples: list[dict] | None = None, + structure_path: str | None = None, + ) -> dict: + """Run full self-consistency evaluation on the test set. + + For each ligand, ``num_designs`` proteins are generated. The design + with the highest scTM is selected as the representative for that + ligand. All per-design results are kept in the returned DataFrame + (with a ``design_idx`` column); summary statistics are computed over + best-per-ligand rows only. + + Returns + ------- + dict with: + - results_df: DataFrame with per-sample, per-design results + - summary: dict with aggregated metrics (best design per ligand) + """ + model.eval() + model.to(self.device) + + if samples is None: + samples = self.load_test_set() + + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + all_results = [] + skipped_samples = [] + + for sample in tqdm(samples, desc="Evaluating ligand-conditioned generation"): + ligand_id = sample["ligand_id"] + ligand_length = len(sample["ligand_coords"]) + total_length = self.length + ligand_length + + if total_length > self.max_length: + logger.warning( + f"Skipping {ligand_id}: total length {total_length} " + f"(protein: {self.length}, ligand: {ligand_length}) " + f"exceeds max_length {self.max_length}" + ) + skipped_samples.append( + { + "ligand_id": ligand_id, + "ligand_length": ligand_length, + "total_length": total_length, + "reason": "max_length", + } + ) + continue + + # Generate num_designs proteins and evaluate each + design_results = [] + for design_idx in range(self.num_designs): + result = self._evaluate_single_design( + model, + sample, + ligand_id, + ligand_length, + ) + if result is None: + logger.warning(f"No decoded coordinates for {ligand_id} design {design_idx}, skipping design") + continue + result["design_idx"] = design_idx + design_results.append(result) + + if not design_results: + logger.warning(f"All {self.num_designs} designs failed for {ligand_id}") + continue + + # Select best design by scTM (highest) + best = max(design_results, key=lambda r: r["scTM"]) + for r in design_results: + r["is_best"] = r is best + + # Save all designs' structures + if structure_path and self.save_structures: + for r in design_results: + design_suffix = f"_d{r['design_idx']}" + self._save_outputs( + structure_path, + f"{ligand_id}{design_suffix}", + sample, + r["_pred_seq"], + r["_decoded_coords"], + r["_decoded_ligand_coords"], + r["_esmfold_coords"], + ) + + # Strip tensor fields before collecting results + for r in design_results: + for key in ( + "_pred_seq", + "_decoded_coords", + "_decoded_ligand_coords", + "_esmfold_coords", + ): + r.pop(key, None) + + all_results.extend(design_results) + + if skipped_samples: + logger.info(f"Skipped {len(skipped_samples)} samples due to total length > {self.max_length}") + + results_df = pd.DataFrame(all_results) + + if len(results_df) == 0: + logger.warning("No samples were successfully evaluated") + return {"results_df": results_df, "summary": self._empty_summary()} + + summary = self._compute_summary(results_df) + return {"results_df": results_df, "summary": summary} + + def _save_outputs( + self, + structure_path: str, + ligand_id: str, + sample: dict, + pred_seq: Tensor, + decoded_coords: Tensor, + decoded_ligand_coords: Tensor | None, + esmfold_coords: Tensor, + ): + """Save generated structures to disk.""" + atom_names = sample.get("ligand_atom_names") + if atom_names is None: + idx_to_element = {v: k for k, v in self.element_to_idx.items()} + ligand_types = sample["ligand_atom_types"] + atom_names = [ + f"{idx_to_element.get(int(t), 'C')}{i + 1}" for i, t in enumerate(ligand_types.cpu().tolist()) + ] + bond_matrix = sample.get("bond_matrix") + + # Save FASTA + seq_str = self.sequence_to_string(pred_seq) + fasta_path = os.path.join(structure_path, f"{ligand_id}_generated.fasta") + with open(fasta_path, "w") as f: + f.write(f">{ligand_id}_generated\n{seq_str}\n") + + # Save model-decoded structure + if decoded_ligand_coords is not None: + writepdb_ligand_complex( + os.path.join(structure_path, f"{ligand_id}_decoded.pdb"), + protein_atoms=decoded_coords, + protein_seq=pred_seq, + ligand_atoms=decoded_ligand_coords, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix, + ) + else: + writepdb( + os.path.join(structure_path, f"{ligand_id}_decoded.pdb"), + decoded_coords, + pred_seq, + ) + + # Save ESMFold-predicted structure + writepdb( + os.path.join(structure_path, f"{ligand_id}_esmfold.pdb"), + esmfold_coords, + pred_seq, + ) + + def _empty_summary(self) -> dict: + """Return summary dict with NaN values for empty results.""" + return { + "n_ligands": 0, + "num_designs": self.num_designs, + "protein_length": self.length, + "mean_n_contacts": float("nan"), + "mean_frac_residues_in_contact": float("nan"), + "mean_frac_ligand_atoms_in_contact": float("nan"), + "mean_min_protein_ligand_dist": float("nan"), + "mean_scTM": float("nan"), + "mean_scRMSD": float("nan"), + "mean_pocket_scTM": float("nan"), + "mean_pocket_scRMSD": float("nan"), + "mean_plddt": float("nan"), + "mean_pae": float("nan"), + } + + def _compute_summary(self, results_df: pd.DataFrame) -> dict: + """Compute aggregated summary statistics from results DataFrame. + + Summary metrics are computed over the best design per ligand only + (selected by highest scTM). + """ + best_df = results_df[results_df["is_best"]].copy() + n_total_designs = len(results_df) + n_ligands = len(best_df) + + return { + "n_ligands": n_ligands, + "n_total_designs": n_total_designs, + "num_designs": self.num_designs, + "protein_length": self.length, + "mean_n_contacts": best_df["n_contacts"].mean(), + "std_n_contacts": best_df["n_contacts"].std(), + "mean_n_residues_in_contact": best_df["n_residues_in_contact"].mean(), + "mean_frac_residues_in_contact": best_df["frac_residues_in_contact"].mean(), + "mean_frac_ligand_atoms_in_contact": best_df["frac_ligand_atoms_in_contact"].mean(), + "mean_min_protein_ligand_dist": best_df["min_protein_ligand_dist"].mean(), + "std_min_protein_ligand_dist": best_df["min_protein_ligand_dist"].std(), + "mean_scTM": best_df["scTM"].mean(), + "std_scTM": best_df["scTM"].std(), + "median_scTM": best_df["scTM"].median(), + "mean_scRMSD": best_df["scRMSD"].mean(), + "std_scRMSD": best_df["scRMSD"].std(), + "median_scRMSD": best_df["scRMSD"].median(), + "mean_pocket_scTM": best_df["pocket_scTM"].mean(), + "std_pocket_scTM": best_df["pocket_scTM"].std(), + "mean_pocket_scRMSD": best_df["pocket_scRMSD"].mean(), + "std_pocket_scRMSD": best_df["pocket_scRMSD"].std(), + "mean_plddt": best_df["plddt"].mean(), + "std_plddt": best_df["plddt"].std(), + "mean_pae": best_df["pae"].mean(), + "std_pae": best_df["pae"].std(), + "mean_pocket_size": best_df["n_pocket_residues"].mean(), + } diff --git a/src/lobster/metrics/protein_ligand_forward_folding.py b/src/lobster/metrics/protein_ligand_forward_folding.py new file mode 100644 index 00000000..20e1009e --- /dev/null +++ b/src/lobster/metrics/protein_ligand_forward_folding.py @@ -0,0 +1,1191 @@ +"""Protein-Ligand Forward Folding Evaluator. + +Evaluates forward folding (sequence → structure) on protein-ligand complexes +with and without ligand context. + +Key Question: Does providing ligand context improve structure prediction quality +(TM-score, RMSD) for the protein, particularly for binding pocket residues? +""" + +import os +from glob import glob +from typing import TYPE_CHECKING + +import pandas as pd +import torch +from loguru import logger +from tmtools import tm_align +from torch import Tensor +from tqdm import tqdm + +from bionemo.moco.schedules.inference_time_schedules import ( + LinearInferenceSchedule, + LogInferenceSchedule, + PowerInferenceSchedule, +) + +from lobster.metrics import align_and_compute_rmsd +from lobster.model.latent_generator.io import writepdb, writepdb_ligand_complex +from lobster.model.latent_generator.utils import minimize_ligand_structure +from lobster.model.latent_generator.utils.residue_constants import restype_order_with_x_inv +from lobster.transforms._structure_transforms import AminoAcidTokenizerTransform + +# Mapping from string names to inference schedule classes +INFERENCE_SCHEDULE_MAP = { + "LinearInferenceSchedule": LinearInferenceSchedule, + "LogInferenceSchedule": LogInferenceSchedule, + "PowerInferenceSchedule": PowerInferenceSchedule, +} + + +def _get_inference_schedule_class(schedule_name: str): + """Convert string schedule name to class.""" + if schedule_name not in INFERENCE_SCHEDULE_MAP: + raise ValueError(f"Unknown inference schedule: {schedule_name}. Options: {list(INFERENCE_SCHEDULE_MAP.keys())}") + return INFERENCE_SCHEDULE_MAP[schedule_name] + + +if TYPE_CHECKING: + from lightning import LightningModule + + +class ProteinLigandForwardFoldingEvaluator: + """Evaluates forward folding on protein-ligand complexes with/without ligand context. + + This evaluator compares two modes: + 1. Protein-only: Provide only protein sequence, predict structure + 2. Protein+Ligand: Provide protein sequence + ligand, predict structure + + Tracks metrics: + - Overall TM-score and RMSD + - Binding pocket RMSD (residues within distance threshold of ligand) + - Non-pocket RMSD + + Can be used: + - As standalone evaluation script + - Within callback during training + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory containing *_protein.pt and *_ligand.pt pairs + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + num_samples : int, optional + Limit number of samples to evaluate (None = all) + nsteps : int + Number of diffusion steps for generation + device : str + Device for computation + max_length : int + Maximum combined sequence length (protein + ligand) to process (default: 512). + Samples exceeding this length will be skipped. + max_protein_length : int + Maximum protein-only sequence length (default: 512). Samples with protein length + exceeding this will be skipped entirely. + temperature_seq : float + Temperature for sequence sampling + temperature_struc : float + Temperature for structure sampling + save_structures : bool + Whether to save predicted structures as PDB files (default: False). + save_gt_structure : bool + Whether to save ground truth structures as PDB files (default: False). + minimize_ligand : bool + Whether to apply geometry correction to decoded ligand structures (default: False). + minimize_mode : str + Minimization mode: "bonds_only", "bonds_and_angles", "local", or "full". + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps. + stochasticity_seq : int + Stochasticity parameter for sequence sampling (default: 20). + stochasticity_struc : int + Stochasticity parameter for structure sampling (default: 20). + temperature_ligand : float + Temperature for ligand structure sampling (default: 0.5). + stochasticity_ligand : int + Stochasticity parameter for ligand structure sampling (default: 20). + ligand_context_mode : str + How to provide ligand context. Options: + - "structure_tokens": Encode GT ligand structure and provide tokens as fixed context + - "atom_bond_only": Only provide atom types + bond matrix, model generates ligand structure + inference_schedule_seq : str + Inference schedule for sequence generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule" (default: "LogInferenceSchedule"). + inference_schedule_struc : str + Inference schedule for structure generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule" (default: "LinearInferenceSchedule"). + inference_schedule_ligand_atom : str + Inference schedule for ligand atom token generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule", or None to use sequence schedule + (default: None). + inference_schedule_ligand_struc : str + Inference schedule for ligand structure token generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule", or None to use structure schedule + (default: None). + num_predictions : int + Number of predictions per sample for best-of-N evaluation (default: 1). + When > 1, generates multiple predictions and selects the best one. + best_of_n_metric : str + Metric to use for best-of-N selection: "rmsd" (lower is better) or "tm_score" + (higher is better). Default: "rmsd". + save_all_predictions : bool + Whether to save all N predicted structures (not just the best). Only applies + when save_structures=True and num_predictions > 1. Default: False. + try_reflection : bool + Whether to try both original and reflected (mirror image) coordinates and + select the one with higher TM-score. This is useful if the model might + output mirror images of structures. Default: False. + """ + + def __init__( + self, + data_dir: str, + pocket_distance_threshold: float = 5.0, + num_samples: int | None = None, + nsteps: int = 100, + device: str = "cuda", + max_length: int = 512, + max_protein_length: int = 512, + temperature_seq: float = 0.5, + temperature_struc: float = 0.5, + save_structures: bool = False, + save_gt_structure: bool = False, + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + # Additional generation hyperparameters + stochasticity_seq: int = 20, + stochasticity_struc: int = 20, + temperature_ligand: float = 0.5, + stochasticity_ligand: int = 20, + ligand_context_mode: str = "structure_tokens", + inference_schedule_seq: str = "LogInferenceSchedule", + inference_schedule_struc: str = "LinearInferenceSchedule", + inference_schedule_ligand_atom: str | None = None, + inference_schedule_ligand_struc: str | None = None, + # Best-of-N parameters + num_predictions: int = 1, + best_of_n_metric: str = "rmsd", + save_all_predictions: bool = False, + # Mirror image handling + try_reflection: bool = False, + ): + self.data_dir = data_dir + self.pocket_distance_threshold = pocket_distance_threshold + self.num_samples = num_samples + self.nsteps = nsteps + self.device = device + self.max_length = max_length + self.max_protein_length = max_protein_length + self.temperature_seq = temperature_seq + self.temperature_struc = temperature_struc + self.save_structures = save_structures + self.save_gt_structure = save_gt_structure + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + # Additional generation hyperparameters + self.stochasticity_seq = stochasticity_seq + self.stochasticity_struc = stochasticity_struc + self.temperature_ligand = temperature_ligand + self.stochasticity_ligand = stochasticity_ligand + self.ligand_context_mode = ligand_context_mode + self.inference_schedule_seq = inference_schedule_seq + self.inference_schedule_struc = inference_schedule_struc + self.inference_schedule_ligand_atom = inference_schedule_ligand_atom + self.inference_schedule_ligand_struc = inference_schedule_ligand_struc + # Best-of-N parameters + self.num_predictions = num_predictions + self.best_of_n_metric = best_of_n_metric + self.save_all_predictions = save_all_predictions + if best_of_n_metric not in ("rmsd", "tm_score"): + raise ValueError(f"best_of_n_metric must be 'rmsd' or 'tm_score', got {best_of_n_metric}") + # Mirror image handling + self.try_reflection = try_reflection + + # Initialize tokenizer transform for sequence conversion + self.tokenizer_transform = AminoAcidTokenizerTransform(max_length=max_length) + + # Amino acid mapping (standard 21 tokens) + self.aa_map = { + 0: "L", + 1: "A", + 2: "G", + 3: "V", + 4: "S", + 5: "E", + 6: "R", + 7: "T", + 8: "I", + 9: "D", + 10: "P", + 11: "K", + 12: "Q", + 13: "F", + 14: "N", + 15: "Y", + 16: "M", + 17: "H", + 18: "W", + 19: "C", + 20: "X", + } + + # Element vocabulary (ELEMENT_VOCAB_EXTENDED from residue_constants) + self.element_to_idx = { + "PAD": 0, + "MASK": 1, + "UNK": 2, + "C": 3, + "N": 4, + "O": 5, + "S": 6, + "P": 7, + "H": 8, + "F": 9, + "Cl": 10, + "Br": 11, + "I": 12, + "Fe": 13, + "Zn": 14, + "Mg": 15, + "Ca": 16, + "Mn": 17, + "Cu": 18, + "B": 19, + "Si": 20, + "Se": 21, + "Co": 22, + "Ni": 23, + "Bi": 24, + } + + def _atom_names_to_indices(self, atom_names: list) -> Tensor: + """Convert atom names (e.g., ['C1', 'N2', 'O3']) to element indices.""" + indices = [] + for name in atom_names: + # Extract element symbol (first 1-2 characters, handling cases like 'Cl', 'Br') + if len(name) >= 2 and name[:2] in self.element_to_idx: + elem = name[:2] + elif name[0] in self.element_to_idx: + elem = name[0] + else: + # Try just the first character uppercase + elem = name[0].upper() + + idx = self.element_to_idx.get(elem, 2) # 2 = UNK + indices.append(idx) + + return torch.tensor(indices, dtype=torch.long, device=self.device) + + def _reflect_coords(self, coords: Tensor) -> Tensor: + """Create a mirror image of coordinates by negating the x-axis. + + Parameters + ---------- + coords : Tensor + Coordinates tensor of shape [..., 3] where the last dimension is (x, y, z) + + Returns + ------- + Tensor + Reflected coordinates with x-axis negated + """ + reflected = coords.clone() + reflected[..., 0] = -reflected[..., 0] + return reflected + + def _select_best_orientation( + self, + pred_coords: Tensor, + gt_coords: Tensor, + sequence: Tensor, + mask: Tensor | None = None, + decoded_ligand_coords: Tensor | None = None, + ) -> tuple[Tensor, Tensor | None, bool]: + """Select best orientation (original or reflected) based on TM-score. + + Parameters + ---------- + pred_coords : Tensor + [L, 3, 3] predicted protein backbone coordinates + gt_coords : Tensor + [L, 3, 3] ground truth backbone coordinates + sequence : Tensor + [L] sequence tokens for TM-align + mask : Tensor, optional + [L] boolean mask for positions to include + decoded_ligand_coords : Tensor, optional + [N_atoms, 3] predicted ligand coordinates (will be reflected too if needed) + + Returns + ------- + tuple + (best_pred_coords, best_ligand_coords, was_reflected) + """ + # Compute TM-score for original orientation + tm_original = self.compute_tm_score(pred_coords, gt_coords, sequence, mask) + + # Compute TM-score for reflected orientation + reflected_coords = self._reflect_coords(pred_coords) + tm_reflected = self.compute_tm_score(reflected_coords, gt_coords, sequence, mask) + + # Select the orientation with higher TM-score + if tm_reflected > tm_original: + # Use reflected coordinates + reflected_ligand = None + if decoded_ligand_coords is not None: + reflected_ligand = self._reflect_coords(decoded_ligand_coords) + return reflected_coords, reflected_ligand, True + else: + # Use original coordinates + return pred_coords, decoded_ligand_coords, False + + def load_test_set(self) -> list[dict]: + """Load PDBBind test protein-ligand pairs. + + Returns list of dicts with: + - pdb_id: str + - protein_coords: Tensor [L, 3, 3] # N, CA, C backbone + - protein_sequence: Tensor [L] + - protein_mask: Tensor [L] + - protein_indices: Tensor [L] + - ligand_coords: Tensor [N_atoms, 3] + - ligand_atom_types: Tensor [N_atoms] + - ligand_mask: Tensor [N_atoms] + - ligand_indices: Tensor [N_atoms] + - bond_matrix: Tensor [N_atoms, N_atoms] (if available) + """ + # Find protein-ligand pairs + protein_files = sorted(glob(os.path.join(self.data_dir, "*_protein.pt"))) + + if not protein_files: + raise ValueError(f"No protein files found in {self.data_dir}") + + # Limit samples if specified + if self.num_samples is not None: + protein_files = protein_files[: self.num_samples] + + logger.info(f"Loading {len(protein_files)} protein-ligand pairs from {self.data_dir}") + + samples = [] + for pf in tqdm(protein_files, desc="Loading samples"): + pdb_id = os.path.basename(pf).replace("_protein.pt", "") + ligand_file = pf.replace("_protein.pt", "_ligand.pt") + + if not os.path.exists(ligand_file): + logger.warning(f"Missing ligand file for {pdb_id}, skipping") + continue + + protein_data = torch.load(pf, weights_only=False, map_location=self.device) + ligand_data = torch.load(ligand_file, weights_only=False, map_location=self.device) + + # Extract protein data + protein_coords = protein_data.get("coords_res", protein_data.get("coords")) + protein_sequence = protein_data.get("sequence") + + if protein_coords is None: + logger.warning(f"Missing protein coordinates for {pdb_id}, skipping") + continue + + protein_mask = protein_data.get("mask", torch.ones(protein_coords.shape[0], device=self.device)) + protein_indices = protein_data.get("indices", torch.arange(protein_coords.shape[0], device=self.device)) + + # Extract ligand data - handle different key names + ligand_coords = ligand_data.get("atom_coords", ligand_data.get("coords", ligand_data.get("ligand_coords"))) + + if ligand_coords is None: + logger.warning(f"Missing ligand coordinates for {pdb_id}, skipping") + continue + + # Handle atom types - may be a list of names or tensor of indices + atom_names = ligand_data.get("atom_names") + if atom_names is not None and isinstance(atom_names, list): + # Convert atom names to element indices + ligand_atom_types = self._atom_names_to_indices(atom_names) + else: + ligand_atom_types = ligand_data.get( + "element_indices", + ligand_data.get( + "ligand_element_indices", + torch.full((ligand_coords.shape[0],), 3, dtype=torch.long, device=self.device), + ), # Default to carbon (3) + ) + + ligand_mask = ligand_data.get( + "mask", ligand_data.get("ligand_mask", torch.ones(ligand_coords.shape[0], device=self.device)) + ) + ligand_indices = ligand_data.get( + "atom_indices", + ligand_data.get( + "indices", + ligand_data.get("ligand_indices", torch.arange(ligand_coords.shape[0], device=self.device)), + ), + ) + bond_matrix = ligand_data.get("bond_matrix") + + if protein_sequence is None: + logger.warning(f"Missing sequence for {pdb_id}, skipping") + continue + + samples.append( + { + "pdb_id": pdb_id, + "protein_coords": protein_coords, + "protein_sequence": protein_sequence, + "protein_mask": protein_mask, + "protein_indices": protein_indices, + "ligand_coords": ligand_coords, + "ligand_atom_types": ligand_atom_types, + "ligand_atom_names": atom_names, # Keep original atom names for PDB writing + "ligand_mask": ligand_mask, + "ligand_indices": ligand_indices, + "bond_matrix": bond_matrix, + } + ) + + logger.info(f"Loaded {len(samples)} valid samples") + return samples + + def compute_binding_pocket( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + protein_mask: Tensor | None = None, + ) -> Tensor: + """Compute pocket mask based on distance to ligand. + + A residue is considered part of the binding pocket if any of its + backbone atoms (N, CA, C) are within the threshold distance of + any ligand heavy atom. + + Parameters + ---------- + protein_coords : Tensor + [L, 3, 3] or [L, 3] backbone coordinates + ligand_coords : Tensor + [N_atoms, 3] ligand atom coordinates + protein_mask : Tensor, optional + [L] valid residue mask + + Returns + ------- + pocket_mask : Tensor + [L] boolean mask, True for pocket residues + """ + # Handle different coordinate formats + if protein_coords.dim() == 3: + # [L, 3, 3] - use CA atoms (index 1) + ca_coords = protein_coords[:, 1, :] # [L, 3] + else: + # [L, 3] - already CA-like + ca_coords = protein_coords + + # Compute pairwise distances between CA atoms and ligand atoms + # ca_coords: [L, 3], ligand_coords: [N_atoms, 3] + # distances: [L, N_atoms] + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + + # Min distance from each residue to any ligand atom + min_distances = distances.min(dim=1).values # [L] + + # Pocket mask: residues within threshold + pocket_mask = min_distances < self.pocket_distance_threshold + + # Apply valid mask if provided + if protein_mask is not None: + pocket_mask = pocket_mask & protein_mask.bool() + + return pocket_mask + + def forward_fold( + self, + model: "LightningModule", + sample: dict, + include_ligand: bool, + ) -> dict: + """Run forward folding with or without ligand context. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + sample : dict + Sample dictionary from load_test_set() + include_ligand : bool + Whether to include ligand context + + Returns + ------- + dict with: + - predicted_coords: Tensor [L, 3, 3] (N, CA, C backbone) + - structure_tokens: Tensor [L] + """ + # Prepare protein inputs + protein_mask = sample["protein_mask"].unsqueeze(0).float() + protein_indices = sample["protein_indices"].unsqueeze(0).long() + length = int(protein_mask.sum().item()) + + # Tokenize sequence for forward folding + gt_seq = sample["protein_sequence"] + tokenized_data = self.tokenizer_transform({"sequence": gt_seq.cpu()}) + tokenized_seq = tokenized_data["sequence"].to(self.device).unsqueeze(0) # [1, L] + + # Prepare ligand inputs if needed + ligand_mask = None + ligand_atom_tokens = None + ligand_structure_tokens = None + ligand_structure_embeddings = None + bond_matrix = None + num_atoms = 0 + + if include_ligand: + ligand_coords = sample["ligand_coords"].unsqueeze(0).float() + ligand_mask = sample["ligand_mask"].unsqueeze(0).float() + ligand_indices = sample["ligand_indices"].unsqueeze(0).long() + ligand_atom_tokens = sample["ligand_atom_types"].unsqueeze(0).long() + num_atoms = ligand_coords.shape[1] + + # Conditionally encode ligand structure based on ligand_context_mode + if self.ligand_context_mode == "structure_tokens": + # Encode GT ligand structure and provide tokens as fixed context + with torch.no_grad(): + encode_result = model.encode_ligand_structure( + ligand_coords, ligand_mask, ligand_indices, return_continuous=True + ) + ligand_structure_tokens, _, ligand_structure_embeddings = encode_result + else: + # atom_bond_only mode: don't provide structure tokens, model generates ligand structure + ligand_structure_tokens = None + ligand_structure_embeddings = None + + bond_matrix = sample.get("bond_matrix") + if bond_matrix is not None: + bond_matrix = bond_matrix.unsqueeze(0).long() + + # Get inference schedule classes + inference_schedule_seq_class = _get_inference_schedule_class(self.inference_schedule_seq) + inference_schedule_struc_class = _get_inference_schedule_class(self.inference_schedule_struc) + # Ligand schedules (None to fall back to protein schedules) + inference_schedule_ligand_atom_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_atom) + if self.inference_schedule_ligand_atom + else None + ) + inference_schedule_ligand_struc_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_struc) + if self.inference_schedule_ligand_struc + else None + ) + + # Generate sample (forward folding mode) + with torch.no_grad(): + result = model.generate_sample( + length=length, + num_samples=1, + forward_folding=True, + nsteps=self.nsteps, + inference_schedule_seq=inference_schedule_seq_class, + inference_schedule_struc=inference_schedule_struc_class, + inference_schedule_ligand_atom=inference_schedule_ligand_atom_class, + inference_schedule_ligand_struc=inference_schedule_ligand_struc_class, + temperature_seq=self.temperature_seq, + temperature_struc=self.temperature_struc, + stochasticity_seq=self.stochasticity_seq, + stochasticity_struc=self.stochasticity_struc, + temperature_ligand=self.temperature_ligand, + stochasticity_ligand=self.stochasticity_ligand, + input_sequence_tokens=tokenized_seq, + input_mask=protein_mask, + input_indices=protein_indices, + # Ligand context (fixed conditioning if structure_tokens mode, otherwise model generates) + generate_ligand=include_ligand, + num_atoms=num_atoms if include_ligand else 0, + input_ligand_atom_tokens=ligand_atom_tokens, + input_ligand_structure_tokens=ligand_structure_tokens, + input_ligand_structure_embeddings=ligand_structure_embeddings, + input_bond_matrix=bond_matrix, + ligand_is_context=include_ligand and self.ligand_context_mode == "structure_tokens", + ) + + # Decode structure + decoded_x = model.decode_structure(result, protein_mask, ligand_mask=ligand_mask) + + # Extract coordinates from vit_decoder + predicted_coords = None + decoded_ligand_coords = None + for decoder_name in decoded_x: + if "vit_decoder" == decoder_name: + vit_output = decoded_x[decoder_name] + # Handle both tensor output (protein-only) and dict output (protein-ligand) + if isinstance(vit_output, dict): + predicted_coords = vit_output.get("protein_coords") + decoded_ligand_coords = vit_output.get("ligand_coords") + else: + predicted_coords = vit_output + break + + if predicted_coords is None: + raise RuntimeError("No vit_decoder found in decoded structures") + + # Handle both discrete (structure_tokens) and continuous (structure_embeddings) modes + structure_tokens = result.get("generated_struc_tokens") + structure_embeddings = result.get("generated_structure_embeddings") + + # Get predicted bond matrix if available + predicted_bond_matrix = result.get("predicted_bond_matrix") + + return { + "predicted_coords": predicted_coords.squeeze(0), # [L, 3, 3] + "structure_tokens": structure_tokens.squeeze(0) if structure_tokens is not None else None, + "structure_embeddings": structure_embeddings.squeeze(0) if structure_embeddings is not None else None, + "decoded_ligand_coords": decoded_ligand_coords.squeeze(0) if decoded_ligand_coords is not None else None, + "predicted_bond_matrix": predicted_bond_matrix.squeeze(0) if predicted_bond_matrix is not None else None, + } + + def forward_fold_best_of_n( + self, + model: "LightningModule", + sample: dict, + include_ligand: bool, + n_predictions: int, + ) -> dict: + """Run forward folding N times and return the best prediction. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + sample : dict + Sample dictionary from load_test_set() + include_ligand : bool + Whether to include ligand context + n_predictions : int + Number of predictions to generate + + Returns + ------- + dict with: + - predicted_coords: Tensor [L, 3, 3] (N, CA, C backbone) - best prediction + - structure_tokens: Tensor [L] - from best prediction + - all_predictions: list of dicts - all N predictions with their scores + - best_idx: int - index of the best prediction + - best_score: float - score of the best prediction + """ + if n_predictions == 1: + pred = self.forward_fold(model, sample, include_ligand) + return { + **pred, + "all_predictions": [pred], + "best_idx": 0, + "best_score": None, + } + + gt_coords = sample["protein_coords"] + gt_seq = sample["protein_sequence"] + protein_mask = sample["protein_mask"] + + all_predictions = [] + scores = [] + + for i in range(n_predictions): + pred = self.forward_fold(model, sample, include_ligand) + all_predictions.append(pred) + + # Compute score for selection + if self.best_of_n_metric == "rmsd": + score = self.compute_rmsd(pred["predicted_coords"], gt_coords, protein_mask) + else: # tm_score + score = self.compute_tm_score(pred["predicted_coords"], gt_coords, gt_seq, protein_mask) + scores.append(score) + + # Select best prediction + if self.best_of_n_metric == "rmsd": + # Lower RMSD is better + best_idx = min(range(len(scores)), key=lambda i: scores[i]) + else: + # Higher TM-score is better + best_idx = max(range(len(scores)), key=lambda i: scores[i]) + + best_pred = all_predictions[best_idx] + return { + **best_pred, + "all_predictions": all_predictions, + "best_idx": best_idx, + "best_score": scores[best_idx], + "all_scores": scores, + } + + def compute_tm_score( + self, + pred_coords: Tensor, + gt_coords: Tensor, + sequence: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute TM-score between predicted and ground truth structures. + + Parameters + ---------- + pred_coords : Tensor + [L, 3, 3] predicted backbone coordinates (N, CA, C) + gt_coords : Tensor + [L, 3, 3] ground truth backbone coordinates + sequence : Tensor + [L] sequence tokens for alignment + mask : Tensor, optional + [L] boolean mask for positions to include + + Returns + ------- + float + TM-score (0-1, higher is better) + """ + # Apply mask if provided + if mask is not None: + mask = mask.bool() + pred_coords = pred_coords[mask] + gt_coords = gt_coords[mask] + sequence = sequence[mask] + + if len(pred_coords) == 0: + return float("nan") + + # Get sequence string for TM-align + sequence_str = "".join([restype_order_with_x_inv.get(int(s), "X") for s in sequence.cpu().tolist()]) + + # Use CA atoms (index 1) for TM-align + pred_ca = pred_coords[:, 1, :].detach().cpu().numpy() + gt_ca = gt_coords[:, 1, :].detach().cpu().numpy() + + # Calculate TM-Score using TM-align + tm_out = tm_align(pred_ca, gt_ca, sequence_str, sequence_str) + + return tm_out.tm_norm_chain1 + + def compute_rmsd( + self, + pred_coords: Tensor, + gt_coords: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute RMSD between predicted and ground truth structures. + + Parameters + ---------- + pred_coords : Tensor + [L, 3, 3] predicted backbone coordinates (N, CA, C) + gt_coords : Tensor + [L, 3, 3] ground truth backbone coordinates + mask : Tensor, optional + [L] boolean mask for positions to include + + Returns + ------- + float + RMSD in Angstroms (lower is better) + """ + # Apply mask if provided + if mask is not None: + mask = mask.bool() + pred_coords = pred_coords[mask] + gt_coords = gt_coords[mask] + + if len(pred_coords) == 0: + return float("nan") + + # Calculate RMSD using Kabsch alignment (detach to avoid gradient issues) + rmsd = align_and_compute_rmsd( + coords1=pred_coords.detach(), + coords2=gt_coords.detach(), + mask=None, # Already masked + return_aligned=False, + device=pred_coords.device, + ) + + return float(rmsd) + + def evaluate( + self, + model: "LightningModule", + samples: list[dict] | None = None, + structure_path: str | None = None, + ) -> dict: + """Run full evaluation on PDBBind test set. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + samples : list[dict], optional + Pre-loaded samples (will load if not provided) + structure_path : str, optional + Directory to save predicted structures as PDB files + + Returns + ------- + dict with: + - results_df: DataFrame with per-structure results + - summary: dict with aggregated metrics + """ + model.eval() + model.to(self.device) + + if samples is None: + samples = self.load_test_set() + + # Log best-of-N settings + if self.num_predictions > 1: + logger.info(f"Using best-of-{self.num_predictions} evaluation (selecting by {self.best_of_n_metric})") + + # Log reflection settings + if self.try_reflection: + logger.info( + "Mirror image handling enabled: will try both original and reflected " + "coordinates and select based on TM-score" + ) + + # Create output directory if specified + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + results = [] + skipped_samples = [] + + for sample in tqdm(samples, desc="Evaluating forward folding"): + pdb_id = sample["pdb_id"] + gt_seq = sample["protein_sequence"] + gt_coords = sample["protein_coords"] + protein_mask = sample["protein_mask"] + + # Check protein and combined lengths + protein_length = len(gt_seq) + ligand_length = len(sample["ligand_coords"]) + total_length = protein_length + ligand_length + + # Skip samples exceeding max protein length + if protein_length > self.max_protein_length: + logger.warning( + f"Skipping {pdb_id}: protein length {protein_length} " + f"exceeds max_protein_length {self.max_protein_length}" + ) + skipped_samples.append( + { + "pdb_id": pdb_id, + "protein_length": protein_length, + "ligand_length": ligand_length, + "total_length": total_length, + "reason": "max_protein_length", + } + ) + continue + + if total_length > self.max_length: + logger.warning( + f"Skipping {pdb_id}: total length {total_length} " + f"(protein: {protein_length}, ligand: {ligand_length}) exceeds max_length {self.max_length}" + ) + skipped_samples.append( + { + "pdb_id": pdb_id, + "protein_length": protein_length, + "ligand_length": ligand_length, + "total_length": total_length, + "reason": "max_length", + } + ) + continue + + # Compute binding pocket + pocket_mask = self.compute_binding_pocket( + gt_coords, + sample["ligand_coords"], + protein_mask, + ) + non_pocket_mask = protein_mask.bool() & ~pocket_mask + + # Mode 1: Protein only (no ligand context) + # NOTE: try-catch removed for debugging - will crash on first error to get full traceback + pred_no_ligand = self.forward_fold_best_of_n( + model, sample, include_ligand=False, n_predictions=self.num_predictions + ) + pred_coords_no_ligand = pred_no_ligand["predicted_coords"] + + # Mode 2: Protein + Ligand context + pred_with_ligand = self.forward_fold_best_of_n( + model, sample, include_ligand=True, n_predictions=self.num_predictions + ) + pred_coords_with_ligand = pred_with_ligand["predicted_coords"] + decoded_ligand_coords_with_ligand = pred_with_ligand.get("decoded_ligand_coords") + + # Try reflection if enabled - select best orientation based on TM-score + reflected_no_ligand = False + reflected_with_ligand = False + if self.try_reflection: + pred_coords_no_ligand, _, reflected_no_ligand = self._select_best_orientation( + pred_coords_no_ligand, gt_coords, gt_seq, protein_mask + ) + pred_coords_with_ligand, decoded_ligand_coords_with_ligand, reflected_with_ligand = ( + self._select_best_orientation( + pred_coords_with_ligand, + gt_coords, + gt_seq, + protein_mask, + decoded_ligand_coords_with_ligand, + ) + ) + # Update the prediction dict with potentially reflected coordinates + pred_no_ligand["predicted_coords"] = pred_coords_no_ligand + pred_with_ligand["predicted_coords"] = pred_coords_with_ligand + if decoded_ligand_coords_with_ligand is not None: + pred_with_ligand["decoded_ligand_coords"] = decoded_ligand_coords_with_ligand + + # Save structures if requested + if structure_path and (self.save_structures or self.save_gt_structure): + ligand_coords = sample["ligand_coords"] + # Get atom names from original data or generate default names + atom_names = sample.get("ligand_atom_names") + if atom_names is None: + # Generate default atom names from element indices + idx_to_element = {v: k for k, v in self.element_to_idx.items()} + ligand_types = sample["ligand_atom_types"] + atom_names = [ + f"{idx_to_element.get(int(t), 'C')}{i + 1}" for i, t in enumerate(ligand_types.cpu().tolist()) + ] + + # Get bond matrix for CONECT records + bond_matrix = sample.get("bond_matrix") + + # Save ground truth structures + if self.save_gt_structure: + # Save ground truth protein structure + gt_pdb_path = os.path.join(structure_path, f"{pdb_id}_gt_protein.pdb") + writepdb(gt_pdb_path, gt_coords, gt_seq) + + # Save ground truth protein-ligand complex + gt_complex_path = os.path.join(structure_path, f"{pdb_id}_gt_complex.pdb") + writepdb_ligand_complex( + gt_complex_path, + protein_atoms=gt_coords, + protein_seq=gt_seq, + ligand_atoms=ligand_coords, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix, + ) + + # Save predicted structures + if self.save_structures: + # Determine which predictions to save + if self.save_all_predictions and self.num_predictions > 1: + # Save all N predictions + all_preds_no_ligand = pred_no_ligand.get("all_predictions", [pred_no_ligand]) + all_preds_with_ligand = pred_with_ligand.get("all_predictions", [pred_with_ligand]) + best_idx_no_ligand = pred_no_ligand.get("best_idx", 0) + best_idx_with_ligand = pred_with_ligand.get("best_idx", 0) + + # Save all predictions without ligand context + for i, pred in enumerate(all_preds_no_ligand): + suffix = f"_pred{i}" if self.num_predictions > 1 else "" + best_marker = "_best" if i == best_idx_no_ligand else "" + pred_no_lig_path = os.path.join( + structure_path, f"{pdb_id}_pred_no_ligand{suffix}{best_marker}.pdb" + ) + writepdb(pred_no_lig_path, pred["predicted_coords"].detach(), gt_seq) + + # Save all predictions with ligand context + for i, pred in enumerate(all_preds_with_ligand): + suffix = f"_pred{i}" if self.num_predictions > 1 else "" + best_marker = "_best" if i == best_idx_with_ligand else "" + pred_with_lig_path = os.path.join( + structure_path, f"{pdb_id}_pred_with_ligand{suffix}{best_marker}.pdb" + ) + decoded_ligand = pred.get("decoded_ligand_coords") + if decoded_ligand is not None: + pred_bond = pred.get("predicted_bond_matrix") + bond_matrix_for_pred = pred_bond if pred_bond is not None else bond_matrix + ligand_coords_to_save = decoded_ligand.detach().cpu() + if self.minimize_ligand: + try: + ligand_coords_to_save = minimize_ligand_structure( + ligand_coords_to_save, + atom_names, + bond_matrix=bond_matrix_for_pred, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed for {pdb_id} pred{i}: {e}") + writepdb_ligand_complex( + pred_with_lig_path, + protein_atoms=pred["predicted_coords"].detach(), + protein_seq=gt_seq, + ligand_atoms=ligand_coords_to_save, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix_for_pred, + ) + else: + writepdb(pred_with_lig_path, pred["predicted_coords"].detach(), gt_seq) + else: + # Save only the best prediction (default behavior) + pred_no_lig_path = os.path.join(structure_path, f"{pdb_id}_pred_no_ligand.pdb") + writepdb(pred_no_lig_path, pred_coords_no_ligand.detach(), gt_seq) + + # Save predicted structure with ligand context (use decoded ligand) + pred_with_lig_path = os.path.join(structure_path, f"{pdb_id}_pred_with_ligand.pdb") + decoded_ligand_coords = pred_with_ligand.get("decoded_ligand_coords") + if decoded_ligand_coords is not None: + # Use predicted bond matrix if available, otherwise fall back to GT + pred_bond_matrix = pred_with_ligand.get("predicted_bond_matrix") + bond_matrix_for_pred = pred_bond_matrix if pred_bond_matrix is not None else bond_matrix + + # Apply minimization if enabled + ligand_coords_to_save = decoded_ligand_coords.detach().cpu() + if self.minimize_ligand: + try: + ligand_coords_to_save = minimize_ligand_structure( + ligand_coords_to_save, + atom_names, + bond_matrix=bond_matrix_for_pred, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed for {pdb_id}: {e}") + writepdb_ligand_complex( + pred_with_lig_path, + protein_atoms=pred_coords_with_ligand.detach(), + protein_seq=gt_seq, + ligand_atoms=ligand_coords_to_save, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix_for_pred, + ) + else: + # Fallback to protein-only if no decoded ligand available + logger.warning(f"No decoded ligand coords for {pdb_id}, saving protein only") + writepdb(pred_with_lig_path, pred_coords_with_ligand.detach(), gt_seq) + + # Compute metrics + result = { + "pdb_id": pdb_id, + "length": len(gt_seq), + "n_pocket_residues": int(pocket_mask.sum().item()), + "n_nonpocket_residues": int(non_pocket_mask.sum().item()), + # Protein-only metrics + "tm_score_no_ligand": self.compute_tm_score(pred_coords_no_ligand, gt_coords, gt_seq, protein_mask), + "rmsd_overall_no_ligand": self.compute_rmsd(pred_coords_no_ligand, gt_coords, protein_mask), + "rmsd_pocket_no_ligand": self.compute_rmsd(pred_coords_no_ligand, gt_coords, pocket_mask), + "rmsd_nonpocket_no_ligand": self.compute_rmsd(pred_coords_no_ligand, gt_coords, non_pocket_mask), + # With-ligand metrics + "tm_score_with_ligand": self.compute_tm_score(pred_coords_with_ligand, gt_coords, gt_seq, protein_mask), + "rmsd_overall_with_ligand": self.compute_rmsd(pred_coords_with_ligand, gt_coords, protein_mask), + "rmsd_pocket_with_ligand": self.compute_rmsd(pred_coords_with_ligand, gt_coords, pocket_mask), + "rmsd_nonpocket_with_ligand": self.compute_rmsd(pred_coords_with_ligand, gt_coords, non_pocket_mask), + } + + # Add best-of-N info if applicable + if self.num_predictions > 1: + result["best_idx_no_ligand"] = pred_no_ligand.get("best_idx") + result["best_idx_with_ligand"] = pred_with_ligand.get("best_idx") + if pred_no_ligand.get("all_scores"): + result["all_scores_no_ligand"] = str(pred_no_ligand["all_scores"]) + if pred_with_ligand.get("all_scores"): + result["all_scores_with_ligand"] = str(pred_with_ligand["all_scores"]) + + # Add reflection info if try_reflection is enabled + if self.try_reflection: + result["reflected_no_ligand"] = reflected_no_ligand + result["reflected_with_ligand"] = reflected_with_ligand + + results.append(result) + + # Log skipped samples + if skipped_samples: + n_protein = sum(1 for s in skipped_samples if s.get("reason") == "max_protein_length") + n_total = sum(1 for s in skipped_samples if s.get("reason") == "max_length") + skip_reasons = [] + if n_protein: + skip_reasons.append(f"{n_protein} due to protein length > {self.max_protein_length}") + if n_total: + skip_reasons.append(f"{n_total} due to total length > {self.max_length}") + logger.info(f"Skipped {len(skipped_samples)} samples: {', '.join(skip_reasons)}") + + # Create results DataFrame + results_df = pd.DataFrame(results) + + # Handle empty results + if len(results_df) == 0: + logger.warning("No samples were successfully evaluated") + summary = { + "mean_tm_score_no_ligand": float("nan"), + "mean_tm_score_with_ligand": float("nan"), + "mean_tm_score_delta": float("nan"), + "mean_rmsd_overall_no_ligand": float("nan"), + "mean_rmsd_overall_with_ligand": float("nan"), + "mean_rmsd_overall_delta": float("nan"), + "mean_rmsd_pocket_no_ligand": float("nan"), + "mean_rmsd_pocket_with_ligand": float("nan"), + "mean_rmsd_pocket_delta": float("nan"), + "mean_rmsd_nonpocket_no_ligand": float("nan"), + "mean_rmsd_nonpocket_with_ligand": float("nan"), + "mean_rmsd_nonpocket_delta": float("nan"), + "n_samples": 0, + "mean_pocket_size": float("nan"), + } + return {"results_df": results_df, "summary": summary} + + # Compute delta metrics (improvement from ligand) + # For TM-score: higher is better, so positive delta = improvement + results_df["tm_score_delta"] = results_df["tm_score_with_ligand"] - results_df["tm_score_no_ligand"] + # For RMSD: lower is better, so negative delta = improvement + results_df["rmsd_overall_delta"] = results_df["rmsd_overall_with_ligand"] - results_df["rmsd_overall_no_ligand"] + results_df["rmsd_pocket_delta"] = results_df["rmsd_pocket_with_ligand"] - results_df["rmsd_pocket_no_ligand"] + results_df["rmsd_nonpocket_delta"] = ( + results_df["rmsd_nonpocket_with_ligand"] - results_df["rmsd_nonpocket_no_ligand"] + ) + + # Compute summary statistics + summary = { + # TM-score (overall only) + "mean_tm_score_no_ligand": results_df["tm_score_no_ligand"].mean(), + "mean_tm_score_with_ligand": results_df["tm_score_with_ligand"].mean(), + "mean_tm_score_delta": results_df["tm_score_delta"].mean(), + "std_tm_score_delta": results_df["tm_score_delta"].std(), + # Overall RMSD + "mean_rmsd_overall_no_ligand": results_df["rmsd_overall_no_ligand"].mean(), + "mean_rmsd_overall_with_ligand": results_df["rmsd_overall_with_ligand"].mean(), + "mean_rmsd_overall_delta": results_df["rmsd_overall_delta"].mean(), + "std_rmsd_overall_delta": results_df["rmsd_overall_delta"].std(), + # Pocket RMSD + "mean_rmsd_pocket_no_ligand": results_df["rmsd_pocket_no_ligand"].mean(), + "mean_rmsd_pocket_with_ligand": results_df["rmsd_pocket_with_ligand"].mean(), + "mean_rmsd_pocket_delta": results_df["rmsd_pocket_delta"].mean(), + "std_rmsd_pocket_delta": results_df["rmsd_pocket_delta"].std(), + # Non-pocket RMSD + "mean_rmsd_nonpocket_no_ligand": results_df["rmsd_nonpocket_no_ligand"].mean(), + "mean_rmsd_nonpocket_with_ligand": results_df["rmsd_nonpocket_with_ligand"].mean(), + "mean_rmsd_nonpocket_delta": results_df["rmsd_nonpocket_delta"].mean(), + "std_rmsd_nonpocket_delta": results_df["rmsd_nonpocket_delta"].std(), + # Sample counts + "n_samples": len(results_df), + "mean_pocket_size": results_df["n_pocket_residues"].mean(), + } + + # Add reflection statistics if try_reflection is enabled + if self.try_reflection and "reflected_no_ligand" in results_df.columns: + summary["reflection_rate_no_ligand"] = results_df["reflected_no_ligand"].mean() + summary["reflection_rate_with_ligand"] = results_df["reflected_with_ligand"].mean() + summary["n_reflected_no_ligand"] = int(results_df["reflected_no_ligand"].sum()) + summary["n_reflected_with_ligand"] = int(results_df["reflected_with_ligand"].sum()) + + return {"results_df": results_df, "summary": summary} + + def sequence_to_string(self, seq_tensor: Tensor) -> str: + """Convert sequence tensor to string.""" + return "".join([self.aa_map.get(int(s), "X") for s in seq_tensor.cpu().tolist()]) diff --git a/src/lobster/metrics/protein_ligand_inverse_folding.py b/src/lobster/metrics/protein_ligand_inverse_folding.py new file mode 100644 index 00000000..65ade7de --- /dev/null +++ b/src/lobster/metrics/protein_ligand_inverse_folding.py @@ -0,0 +1,1370 @@ +"""Protein-Ligand Inverse Folding Evaluator. + +Evaluates inverse folding on protein-ligand complexes with and without ligand context. +Can be used as a standalone evaluator or within a callback during training. + +Key Question: Does providing ligand context improve sequence recovery for binding pocket residues? + +Optional ESMFold validation: Fold designed sequences with ESMFold to check if the predicted +structure matches the ground truth (designability metric). +""" + +import os +from glob import glob +from typing import TYPE_CHECKING + +import pandas as pd +import torch +from loguru import logger +from torch import Tensor +from tqdm import tqdm + +from bionemo.moco.schedules.inference_time_schedules import ( + LinearInferenceSchedule, + LogInferenceSchedule, + PowerInferenceSchedule, +) + +from lobster.metrics import align_and_compute_rmsd +from lobster.model.latent_generator.io import writepdb, writepdb_ligand_complex +from lobster.model.latent_generator.utils import apply_se3_augmentation_protein_ligand, minimize_ligand_structure +from lobster.model.latent_generator.utils.residue_constants import ( + convert_lobster_aa_tokenization_to_standard_aa, +) + +# Mapping from string names to inference schedule classes +INFERENCE_SCHEDULE_MAP = { + "LinearInferenceSchedule": LinearInferenceSchedule, + "LogInferenceSchedule": LogInferenceSchedule, + "PowerInferenceSchedule": PowerInferenceSchedule, +} + + +def _get_inference_schedule_class(schedule_name: str): + """Convert string schedule name to class.""" + if schedule_name not in INFERENCE_SCHEDULE_MAP: + raise ValueError(f"Unknown inference schedule: {schedule_name}. Options: {list(INFERENCE_SCHEDULE_MAP.keys())}") + return INFERENCE_SCHEDULE_MAP[schedule_name] + + +if TYPE_CHECKING: + from lightning import LightningModule + + +class ProteinLigandInverseFoldingEvaluator: + """Evaluates inverse folding on protein-ligand complexes with/without ligand context. + + This evaluator compares two modes: + 1. Protein-only: Provide only protein structure, predict sequence + 2. Protein+Ligand: Provide protein structure + ligand, predict sequence + + Tracks metrics separately for: + - Overall sequence recovery + - Binding pocket residues (within distance threshold of ligand) + - Non-pocket residues + + Can be used: + - As standalone evaluation script + - Within callback during training + + Parameters + ---------- + data_dir : str + Path to PDBBind test directory containing *_protein.pt and *_ligand.pt pairs + pocket_distance_threshold : float + Distance threshold (Å) for defining binding pocket residues + num_samples : int, optional + Limit number of samples to evaluate (None = all) + num_designs : int + Number of designs per structure + nsteps : int + Number of diffusion steps for generation + device : str + Device for computation + max_length : int + Maximum combined sequence length (protein + ligand) to process (default: 512). + Samples exceeding this length will be skipped. + max_protein_length : int + Maximum protein-only sequence length (default: 512). Samples with protein length + exceeding this will be skipped entirely. Also used as the ESMFold max length when + ESMFold is enabled. + decode_structure : bool + Whether to decode and save predicted structures as PDB files (default: False). + When True, saves decoded structures for both with/without ligand conditions. + save_gt_structure : bool + Whether to save ground truth structures as PDB files (default: False). + minimize_ligand : bool + Whether to apply geometry correction to decoded ligand structures (default: False). + minimize_mode : str + Minimization mode: "bonds_only", "bonds_and_angles", "local", or "full". + force_field : str + Force field for minimization: "MMFF94", "MMFF94s", "UFF", etc. + minimize_steps : int + Maximum number of minimization steps. + save_reconstructed_input : bool + Whether to save the reconstructed input structures (encode then decode the input + before generation) to verify token encoding fidelity (default: False). + use_se3_augmentation : bool + Whether to apply random SE3 augmentation (rotation + translation) to input + structures before encoding (default: False). This matches training behavior. + se3_translation_scale : float + Scale factor for random translation when SE3 augmentation is enabled (default: 1.0). + temperature_seq : float + Temperature for sequence sampling (default: 0.5). + temperature_struc : float + Temperature for structure sampling (default: 0.5). + stochasticity_seq : int + Stochasticity parameter for sequence sampling (default: 20). + stochasticity_struc : int + Stochasticity parameter for structure sampling (default: 20). + temperature_ligand : float + Temperature for ligand structure sampling (default: 0.5). + stochasticity_ligand : int + Stochasticity parameter for ligand structure sampling (default: 20). + inference_schedule_seq : str + Inference schedule for sequence generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule" (default: "LogInferenceSchedule"). + inference_schedule_struc : str + Inference schedule for structure generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule" (default: "LinearInferenceSchedule"). + inference_schedule_ligand_atom : str + Inference schedule for ligand atom token generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule", or None to use sequence schedule + (default: None). + inference_schedule_ligand_struc : str + Inference schedule for ligand structure token generation. Options: "LinearInferenceSchedule", + "LogInferenceSchedule", "PowerInferenceSchedule", or None to use structure schedule + (default: None). + use_esmfold : bool + Whether to validate designed sequences with ESMFold. When enabled, folds designed + sequences and computes TM-score, RMSD, pLDDT vs ground truth structure (default: False). + plm_fold : object, optional + Pre-loaded LobsterPLMFold model instance. Required if use_esmfold=True. + Load with: LobsterPLMFold(model_name="esmfold_v1", max_length=512). + """ + + def __init__( + self, + data_dir: str, + pocket_distance_threshold: float = 5.0, + num_samples: int | None = None, + num_designs: int = 1, + nsteps: int = 100, + device: str = "cuda", + max_length: int = 512, + max_protein_length: int = 512, + decode_structure: bool = False, + save_gt_structure: bool = False, + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + save_reconstructed_input: bool = False, + use_se3_augmentation: bool = False, + se3_translation_scale: float = 1.0, + # Generation hyperparameters + temperature_seq: float = 0.5, + temperature_struc: float = 0.5, + stochasticity_seq: int = 20, + stochasticity_struc: int = 20, + temperature_ligand: float = 0.5, + stochasticity_ligand: int = 20, + inference_schedule_seq: str = "LogInferenceSchedule", + inference_schedule_struc: str = "LinearInferenceSchedule", + inference_schedule_ligand_atom: str | None = None, + inference_schedule_ligand_struc: str | None = None, + # ESMFold validation + use_esmfold: bool = False, + plm_fold: object | None = None, + ): + self.data_dir = data_dir + self.pocket_distance_threshold = pocket_distance_threshold + self.num_samples = num_samples + self.num_designs = num_designs + self.nsteps = nsteps + self.device = device + self.max_length = max_length + self.max_protein_length = max_protein_length + self.decode_structure = decode_structure + self.save_gt_structure = save_gt_structure + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps + self.save_reconstructed_input = save_reconstructed_input + self.use_se3_augmentation = use_se3_augmentation + self.se3_translation_scale = se3_translation_scale + # Generation hyperparameters + self.temperature_seq = temperature_seq + self.temperature_struc = temperature_struc + self.stochasticity_seq = stochasticity_seq + self.stochasticity_struc = stochasticity_struc + self.temperature_ligand = temperature_ligand + self.stochasticity_ligand = stochasticity_ligand + self.inference_schedule_seq = inference_schedule_seq + self.inference_schedule_struc = inference_schedule_struc + self.inference_schedule_ligand_atom = inference_schedule_ligand_atom + self.inference_schedule_ligand_struc = inference_schedule_ligand_struc + # ESMFold validation + self.use_esmfold = use_esmfold + self.plm_fold = plm_fold + if use_esmfold and plm_fold is None: + raise ValueError( + "plm_fold must be provided when use_esmfold=True. " + "Load with: LobsterPLMFold(model_name='esmfold_v1', max_length=512)" + ) + + # Standard amino acid mapping (alphabetical order, matching writepdb num2aa) + # The .pt files store sequences in this STANDARD format + self.standard_aa_map = { + 0: "A", # ALA + 1: "R", # ARG + 2: "N", # ASN + 3: "D", # ASP + 4: "C", # CYS + 5: "Q", # GLN + 6: "E", # GLU + 7: "G", # GLY + 8: "H", # HIS + 9: "I", # ILE + 10: "L", # LEU + 11: "K", # LYS + 12: "M", # MET + 13: "F", # PHE + 14: "P", # PRO + 15: "S", # SER + 16: "T", # THR + 17: "W", # TRP + 18: "Y", # TYR + 19: "V", # VAL + 20: "X", # UNK + } + + # Lobster amino acid mapping (for 21-token vocab model outputs) + self.lobster_aa_map = { + 0: "L", + 1: "A", + 2: "G", + 3: "V", + 4: "S", + 5: "E", + 6: "R", + 7: "T", + 8: "I", + 9: "D", + 10: "P", + 11: "K", + 12: "Q", + 13: "F", + 14: "N", + 15: "Y", + 16: "M", + 17: "H", + 18: "W", + 19: "C", + 20: "X", + } + + # Mapping from lobster tokenization to standard (alphabetical) tokenization + # Used to convert 21-token vocab model outputs to standard format + self.lobster_to_standard = torch.tensor( + [ + 10, # 0: L -> LEU (10) + 0, # 1: A -> ALA (0) + 7, # 2: G -> GLY (7) + 19, # 3: V -> VAL (19) + 15, # 4: S -> SER (15) + 6, # 5: E -> GLU (6) + 1, # 6: R -> ARG (1) + 16, # 7: T -> THR (16) + 9, # 8: I -> ILE (9) + 3, # 9: D -> ASP (3) + 14, # 10: P -> PRO (14) + 11, # 11: K -> LYS (11) + 5, # 12: Q -> GLN (5) + 13, # 13: F -> PHE (13) + 2, # 14: N -> ASN (2) + 18, # 15: Y -> TYR (18) + 12, # 16: M -> MET (12) + 8, # 17: H -> HIS (8) + 17, # 18: W -> TRP (17) + 4, # 19: C -> CYS (4) + 20, # 20: X -> UNK (20) + ], + dtype=torch.long, + device=device, + ) + + # Element vocabulary (ELEMENT_VOCAB_EXTENDED from residue_constants) + self.element_to_idx = { + "PAD": 0, + "MASK": 1, + "UNK": 2, + "C": 3, + "N": 4, + "O": 5, + "S": 6, + "P": 7, + "H": 8, + "F": 9, + "Cl": 10, + "Br": 11, + "I": 12, + "Fe": 13, + "Zn": 14, + "Mg": 15, + "Ca": 16, + "Mn": 17, + "Cu": 18, + "B": 19, + "Si": 20, + "Se": 21, + "Co": 22, + "Ni": 23, + "Bi": 24, + } + + def _atom_names_to_indices(self, atom_names: list) -> Tensor: + """Convert atom names (e.g., ['C1', 'N2', 'O3']) to element indices.""" + indices = [] + for name in atom_names: + # Extract element symbol (first 1-2 characters, handling cases like 'Cl', 'Br') + if len(name) >= 2 and name[:2] in self.element_to_idx: + elem = name[:2] + elif name[0] in self.element_to_idx: + elem = name[0] + else: + # Try just the first character uppercase + elem = name[0].upper() + + idx = self.element_to_idx.get(elem, 2) # 2 = UNK + indices.append(idx) + + return torch.tensor(indices, dtype=torch.long, device=self.device) + + def load_test_set(self) -> list[dict]: + """Load PDBBind test protein-ligand pairs. + + Returns list of dicts with: + - pdb_id: str + - protein_coords: Tensor [L, 3, 3] # N, CA, C backbone + - protein_sequence: Tensor [L] + - protein_mask: Tensor [L] + - protein_indices: Tensor [L] + - ligand_coords: Tensor [N_atoms, 3] + - ligand_atom_types: Tensor [N_atoms] + - ligand_mask: Tensor [N_atoms] + - ligand_indices: Tensor [N_atoms] + - bond_matrix: Tensor [N_atoms, N_atoms] (if available) + """ + # Find protein-ligand pairs + protein_files = sorted(glob(os.path.join(self.data_dir, "*_protein.pt"))) + + if not protein_files: + raise ValueError(f"No protein files found in {self.data_dir}") + + # Limit samples if specified + if self.num_samples is not None: + protein_files = protein_files[: self.num_samples] + + logger.info(f"Loading {len(protein_files)} protein-ligand pairs from {self.data_dir}") + + samples = [] + for pf in tqdm(protein_files, desc="Loading samples"): + pdb_id = os.path.basename(pf).replace("_protein.pt", "") + ligand_file = pf.replace("_protein.pt", "_ligand.pt") + + if not os.path.exists(ligand_file): + logger.warning(f"Missing ligand file for {pdb_id}, skipping") + continue + + protein_data = torch.load(pf, weights_only=False, map_location=self.device) + ligand_data = torch.load(ligand_file, weights_only=False, map_location=self.device) + + # Extract protein data + protein_coords = protein_data.get("coords_res", protein_data.get("coords")) + protein_sequence = protein_data.get("sequence") + + if protein_coords is None: + logger.warning(f"Missing protein coordinates for {pdb_id}, skipping") + continue + + protein_mask = protein_data.get("mask", torch.ones(protein_coords.shape[0], device=self.device)) + protein_indices = protein_data.get("indices", torch.arange(protein_coords.shape[0], device=self.device)) + + # Extract ligand data - handle different key names + # PDBBind uses: atom_coords, atom_names, atom_indices + ligand_coords = ligand_data.get("atom_coords", ligand_data.get("coords", ligand_data.get("ligand_coords"))) + + if ligand_coords is None: + logger.warning(f"Missing ligand coordinates for {pdb_id}, skipping") + continue + + # Handle atom types - may be a list of names or tensor of indices + atom_names = ligand_data.get("atom_names") + if atom_names is not None and isinstance(atom_names, list): + # Convert atom names to element indices + ligand_atom_types = self._atom_names_to_indices(atom_names) + else: + ligand_atom_types = ligand_data.get( + "element_indices", + ligand_data.get( + "ligand_element_indices", + torch.full((ligand_coords.shape[0],), 3, dtype=torch.long, device=self.device), + ), # Default to carbon (3) + ) + + ligand_mask = ligand_data.get( + "mask", ligand_data.get("ligand_mask", torch.ones(ligand_coords.shape[0], device=self.device)) + ) + ligand_indices = ligand_data.get( + "atom_indices", + ligand_data.get( + "indices", + ligand_data.get("ligand_indices", torch.arange(ligand_coords.shape[0], device=self.device)), + ), + ) + bond_matrix = ligand_data.get("bond_matrix") + + if protein_sequence is None: + logger.warning(f"Missing sequence for {pdb_id}, skipping") + continue + + samples.append( + { + "pdb_id": pdb_id, + "protein_coords": protein_coords, + "protein_sequence": protein_sequence, + "protein_mask": protein_mask, + "protein_indices": protein_indices, + "ligand_coords": ligand_coords, + "ligand_atom_types": ligand_atom_types, + "ligand_atom_names": atom_names, # Keep original atom names for PDB writing + "ligand_mask": ligand_mask, + "ligand_indices": ligand_indices, + "bond_matrix": bond_matrix, + } + ) + + logger.info(f"Loaded {len(samples)} valid samples") + return samples + + def compute_binding_pocket( + self, + protein_coords: Tensor, + ligand_coords: Tensor, + protein_mask: Tensor | None = None, + ) -> Tensor: + """Compute pocket mask based on distance to ligand. + + A residue is considered part of the binding pocket if any of its + backbone atoms (N, CA, C) are within the threshold distance of + any ligand heavy atom. + + Parameters + ---------- + protein_coords : Tensor + [L, 3, 3] or [L, 3] backbone coordinates + ligand_coords : Tensor + [N_atoms, 3] ligand atom coordinates + protein_mask : Tensor, optional + [L] valid residue mask + + Returns + ------- + pocket_mask : Tensor + [L] boolean mask, True for pocket residues + """ + # Handle different coordinate formats + if protein_coords.dim() == 3: + # [L, 3, 3] - use CA atoms (index 1) + ca_coords = protein_coords[:, 1, :] # [L, 3] + else: + # [L, 3] - already CA-like + ca_coords = protein_coords + + # Compute pairwise distances between CA atoms and ligand atoms + # ca_coords: [L, 3], ligand_coords: [N_atoms, 3] + # distances: [L, N_atoms] + distances = torch.cdist(ca_coords.unsqueeze(0), ligand_coords.unsqueeze(0)).squeeze(0) + + # Min distance from each residue to any ligand atom + min_distances = distances.min(dim=1).values # [L] + + # Pocket mask: residues within threshold + pocket_mask = min_distances < self.pocket_distance_threshold + + # Apply valid mask if provided + if protein_mask is not None: + pocket_mask = pocket_mask & protein_mask.bool() + + return pocket_mask + + def decode_input_tokens( + self, + model: "LightningModule", + inverse_fold_result: dict, + ) -> dict: + """Decode the input structure tokens from inverse folding back to coordinates. + + This uses the EXACT same tokens that were used for inverse folding, + ensuring consistent comparison of input vs output structures. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + inverse_fold_result : dict + Result dictionary from inverse_fold() containing: + - input_protein_structure_logits: Tensor [L, n_tokens] + - input_ligand_structure_tokens: Tensor [N] (optional) + - protein_mask: Tensor [L] + - ligand_mask: Tensor [N] (optional) + + Returns + ------- + dict with: + - reconstructed_protein_coords: Tensor [L, 3, 3] - decoded protein coordinates + - reconstructed_ligand_coords: Tensor [N, 3] - decoded ligand coordinates (if available) + """ + # Get the input tokens/logits from inverse_fold result + protein_structure_logits = inverse_fold_result.get("input_protein_structure_logits") + ligand_structure_tokens = inverse_fold_result.get("input_ligand_structure_tokens") + protein_mask = inverse_fold_result.get("protein_mask") + ligand_mask = inverse_fold_result.get("ligand_mask") + + if protein_structure_logits is None: + logger.warning("No input_protein_structure_logits in inverse_fold result") + return { + "reconstructed_protein_coords": None, + "reconstructed_ligand_coords": None, + } + + # Add batch dimension + protein_structure_logits = protein_structure_logits.unsqueeze(0) # [1, L, n_tokens] + protein_mask = protein_mask.unsqueeze(0) # [1, L] + + # Create decode input dict + decode_input = { + "structure_logits": protein_structure_logits, + "sequence_logits": torch.zeros( + 1, protein_structure_logits.shape[1], 33, device=protein_structure_logits.device + ), + } + + # Handle ligand if present + ligand_mask_batched = None + if ligand_structure_tokens is not None and ligand_mask is not None: + ligand_structure_tokens = ligand_structure_tokens.unsqueeze(0) # [1, N] + ligand_mask_batched = ligand_mask.unsqueeze(0) # [1, N] + + # Convert ligand tokens to one-hot logits + n_tokens = model.quantizer.n_tokens if model.quantizer is not None else 4375 + ligand_structure_logits = torch.zeros( + 1, ligand_structure_tokens.shape[1], n_tokens, device=ligand_structure_tokens.device + ) + ligand_structure_logits.scatter_(2, ligand_structure_tokens.unsqueeze(-1).long(), 1.0) + decode_input["ligand_structure_logits"] = ligand_structure_logits + + # Decode to coordinates + with torch.no_grad(): + decoded_x = model.decode_structure( + decode_input, + protein_mask, + ligand_mask=ligand_mask_batched, + ) + + vit_output = decoded_x.get("vit_decoder") + if isinstance(vit_output, dict): + reconstructed_protein_coords = vit_output.get("protein_coords") + reconstructed_ligand_coords = vit_output.get("ligand_coords") + else: + reconstructed_protein_coords = vit_output + reconstructed_ligand_coords = None + + return { + "reconstructed_protein_coords": reconstructed_protein_coords.squeeze(0) + if reconstructed_protein_coords is not None + else None, + "reconstructed_ligand_coords": reconstructed_ligand_coords.squeeze(0) + if reconstructed_ligand_coords is not None + else None, + } + + def inverse_fold( + self, + model: "LightningModule", + sample: dict, + include_ligand: bool, + ) -> dict: + """Run inverse folding with or without ligand context. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + sample : dict + Sample dictionary from load_test_set() + include_ligand : bool + Whether to include ligand context + + Returns + ------- + dict with: + - predicted_sequence: Tensor [L] + - sequence_logits: Tensor [L, vocab_size] + - decoded_coords: Tensor [L, 3, 3] (decoded protein structure) + - decoded_ligand_coords: Tensor [N, 3] (decoded ligand structure, if include_ligand=True) + - input_protein_structure_tokens: Tensor [L] (input protein structure tokens used) + - input_protein_structure_logits: Tensor [L, n_tokens] (input protein structure logits) + - input_ligand_structure_tokens: Tensor [N] (input ligand structure tokens, if include_ligand) + - protein_mask: Tensor [L] (protein mask used) + - ligand_mask: Tensor [N] (ligand mask used, if include_ligand) + """ + # Prepare protein inputs - ensure proper dtype + protein_coords = sample["protein_coords"].unsqueeze(0).float() + protein_mask = sample["protein_mask"].unsqueeze(0).float() + # Indices must be long (int64) for indexing operations + protein_indices = sample["protein_indices"].unsqueeze(0).long() + length = protein_coords.shape[1] + + # Prepare ligand inputs if needed + ligand_coords = None + ligand_mask = None + ligand_indices = None + ligand_atom_tokens = None + ligand_structure_tokens = None + ligand_structure_embeddings = None + bond_matrix = None + num_atoms = 0 + + if include_ligand: + ligand_coords = sample["ligand_coords"].unsqueeze(0).float() + ligand_mask = sample["ligand_mask"].unsqueeze(0).float() + ligand_indices = sample["ligand_indices"].unsqueeze(0).long() + ligand_atom_tokens = sample["ligand_atom_types"].unsqueeze(0).long() + num_atoms = ligand_coords.shape[1] + bond_matrix = sample.get("bond_matrix") + if bond_matrix is not None: + bond_matrix = bond_matrix.unsqueeze(0).long() + + # Apply SE3 augmentation if enabled + # Uses standalone function that applies SAME SE3 transform to both protein and ligand + if self.use_se3_augmentation: + augmented = apply_se3_augmentation_protein_ligand( + protein_coords=protein_coords, + protein_mask=protein_mask, + ligand_coords=ligand_coords, + ligand_mask=ligand_mask, + random_se3=True, + translation_scale=self.se3_translation_scale, + backbone_noise=0.0, + ) + protein_coords = augmented.protein_coords + if ligand_coords is not None: + ligand_coords = augmented.ligand_coords + + # Encode protein and ligand structure TOGETHER (joint encoding) + # This allows protein-ligand interactions during encoding + input_protein_structure_tokens = None + input_protein_structure_logits = None + protein_structure_embeddings = None + ligand_structure_tokens = None + ligand_structure_embeddings = None + + with torch.no_grad(): + if include_ligand: + # Joint encoding using the model's encode_protein_ligand_structure method + encoded = model.encode_protein_ligand_structure( + protein_coords=protein_coords, + protein_mask=protein_mask, + protein_indices=protein_indices, + ligand_coords=ligand_coords, + ligand_mask=ligand_mask, + ligand_indices=ligand_indices, + ligand_atom_types=ligand_atom_tokens.squeeze(0) if ligand_atom_tokens is not None else None, + bond_matrix=bond_matrix, + ) + + input_protein_structure_tokens = encoded["protein_tokens"] + protein_structure_embeddings = encoded["protein_embeddings"] + ligand_structure_tokens = encoded["ligand_tokens"] + ligand_structure_embeddings = encoded["ligand_embeddings"] + + # For discrete mode decoding, convert tokens to one-hot logits + if not model.use_continuous_structure: + n_tokens = model.quantizer.n_tokens if model.quantizer is not None else model.num_struc_classes + input_protein_structure_logits = torch.zeros( + *input_protein_structure_tokens.shape, n_tokens, device=input_protein_structure_tokens.device + ) + input_protein_structure_logits.scatter_( + -1, input_protein_structure_tokens.unsqueeze(-1).long(), 1.0 + ) + else: + # Protein-only encoding + protein_structure_logits, _, _ = model.encode_structure(protein_coords, protein_mask, protein_indices) + input_protein_structure_logits = protein_structure_logits + input_protein_structure_tokens = protein_structure_logits.argmax(dim=-1) + + # Get inference schedule classes + inference_schedule_seq_class = _get_inference_schedule_class(self.inference_schedule_seq) + inference_schedule_struc_class = _get_inference_schedule_class(self.inference_schedule_struc) + # Ligand schedules (None to fall back to protein schedules) + inference_schedule_ligand_atom_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_atom) + if self.inference_schedule_ligand_atom + else None + ) + inference_schedule_ligand_struc_class = ( + _get_inference_schedule_class(self.inference_schedule_ligand_struc) + if self.inference_schedule_ligand_struc + else None + ) + + # Generate sample (inverse folding mode) + with torch.no_grad(): + result = model.generate_sample( + length=length, + num_samples=1, + inverse_folding=True, + nsteps=self.nsteps, + inference_schedule_seq=inference_schedule_seq_class, + inference_schedule_struc=inference_schedule_struc_class, + inference_schedule_ligand_atom=inference_schedule_ligand_atom_class, + inference_schedule_ligand_struc=inference_schedule_ligand_struc_class, + temperature_seq=self.temperature_seq, + temperature_struc=self.temperature_struc, + stochasticity_seq=self.stochasticity_seq, + stochasticity_struc=self.stochasticity_struc, + temperature_ligand=self.temperature_ligand, + stochasticity_ligand=self.stochasticity_ligand, + input_structure_coords=protein_coords, + input_mask=protein_mask, + input_indices=protein_indices, + # Ligand context (fixed conditioning, not to be generated) + generate_ligand=include_ligand, + num_atoms=num_atoms if include_ligand else 0, + input_ligand_atom_tokens=ligand_atom_tokens, + input_ligand_structure_tokens=ligand_structure_tokens, + input_ligand_structure_embeddings=ligand_structure_embeddings if include_ligand else None, + input_bond_matrix=bond_matrix, + ligand_is_context=include_ligand, + ) + + # Decode structure to coordinates (optional) + decoded_coords = None + decoded_ligand_coords = None + if self.decode_structure: + decoded_x = model.decode_structure( + result, + protein_mask, + ligand_mask=ligand_mask if include_ligand else None, + ) + vit_output = decoded_x.get("vit_decoder") + if isinstance(vit_output, dict): + decoded_coords = vit_output.get("protein_coords") + decoded_ligand_coords = vit_output.get("ligand_coords") + else: + decoded_coords = vit_output + + # Get predicted sequence + sequence_logits = result["sequence_logits"] # [1, L, vocab_size] + uses_33_token_vocab = sequence_logits.shape[-1] == 33 + + # Handle both 33-token and 21-token vocab formats + # Always convert to standard format for consistency with ground truth + if uses_33_token_vocab: + # 33-token vocab: convert to standard (alphabetical) format + predicted_sequence = convert_lobster_aa_tokenization_to_standard_aa( + sequence_logits, device=sequence_logits.device + ).squeeze(0) # [L] in standard format + else: + # 21-token vocab: output is in lobster format, convert to standard + predicted_sequence = sequence_logits.argmax(dim=-1).squeeze(0) # [L] in lobster format + predicted_sequence[predicted_sequence > 20] = 20 # Clamp to valid range + predicted_sequence = self.lobster_to_standard[predicted_sequence.long()] # Convert to standard + + # Get predicted bond matrix if available + predicted_bond_matrix = result.get("predicted_bond_matrix") + + # Get output structure tokens for comparison with input + output_protein_structure_tokens = result.get("structure_tokens") + output_ligand_structure_tokens = result.get("ligand_structure_tokens") + + return { + "predicted_sequence": predicted_sequence, # Always in standard format + "sequence_logits": sequence_logits.squeeze(0), + "decoded_coords": decoded_coords.squeeze(0) if decoded_coords is not None else None, + "decoded_ligand_coords": decoded_ligand_coords.squeeze(0) if decoded_ligand_coords is not None else None, + "predicted_bond_matrix": predicted_bond_matrix.squeeze(0) if predicted_bond_matrix is not None else None, + "output_protein_structure_tokens": output_protein_structure_tokens.squeeze(0) + if output_protein_structure_tokens is not None + else None, + "output_ligand_structure_tokens": output_ligand_structure_tokens.squeeze(0) + if output_ligand_structure_tokens is not None + else None, + # Input tokens/embeddings for reconstruction (exact same used for inverse folding) + "input_protein_structure_tokens": input_protein_structure_tokens.squeeze(0) + if input_protein_structure_tokens is not None + else None, + "input_protein_structure_logits": input_protein_structure_logits.squeeze(0) + if input_protein_structure_logits is not None + else None, + "input_protein_structure_embeddings": protein_structure_embeddings.squeeze(0) + if protein_structure_embeddings is not None + else None, + "input_ligand_structure_tokens": ligand_structure_tokens.squeeze(0) + if ligand_structure_tokens is not None + else None, + "input_ligand_structure_embeddings": ligand_structure_embeddings.squeeze(0) + if ligand_structure_embeddings is not None + else None, + "protein_mask": protein_mask.squeeze(0), + "ligand_mask": ligand_mask.squeeze(0) if ligand_mask is not None else None, + } + + def compute_aar( + self, + predicted_seq: Tensor, + ground_truth_seq: Tensor, + mask: Tensor | None = None, + ) -> float: + """Compute amino acid recovery rate. + + Parameters + ---------- + predicted_seq : Tensor + [L] predicted sequence tokens + ground_truth_seq : Tensor + [L] ground truth sequence tokens + mask : Tensor, optional + [L] boolean mask for positions to include + + Returns + ------- + float + Amino acid recovery rate (0-1) + """ + if mask is not None: + mask = mask.bool() + if mask.sum() == 0: + return float("nan") + predicted_seq = predicted_seq[mask] + ground_truth_seq = ground_truth_seq[mask] + + if len(predicted_seq) == 0: + return float("nan") + + return (predicted_seq == ground_truth_seq).float().mean().item() + + def fold_with_esmfold( + self, + sequence_str: str, + gt_coords: Tensor, + protein_mask: Tensor, + ) -> dict: + """Fold a designed sequence with ESMFold and compute metrics vs ground truth. + + Parameters + ---------- + sequence_str : str + Amino acid sequence string to fold + gt_coords : Tensor + [L, 3, 3] ground truth backbone coordinates (N, CA, C) + protein_mask : Tensor + [L] valid residue mask + + Returns + ------- + dict with: + - esmfold_tm_score: TM-score of ESMFold prediction vs GT + - esmfold_rmsd: RMSD of ESMFold prediction vs GT + - esmfold_plddt: mean pLDDT score + - esmfold_pae: mean predicted aligned error + - esmfold_coords: Tensor [L_valid, 3, 3] ESMFold predicted coordinates + """ + from lobster.metrics import get_folded_structure_metrics + + if self.plm_fold is None: + return {} + + # Tokenize sequence for ESMFold + tokenized_input = self.plm_fold.tokenizer.encode_plus( + sequence_str, + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + return_tensors="pt", + )["input_ids"].to(self.device) + + # Fold with ESMFold + with torch.no_grad(): + esmfold_outputs = self.plm_fold.model(tokenized_input) + + # Get reference structure (only valid residues) + mask_bool = protein_mask.bool() + ref_coords = gt_coords[mask_bool].unsqueeze(0) # [1, L_valid, 3, 3] + + # Calculate metrics (ref_coords already filtered, so mask=None) + folded_metrics, folded_coords = get_folded_structure_metrics( + esmfold_outputs, ref_coords, [sequence_str], mask=None, device=self.device + ) + + def _to_float(v): + """Convert tensor/numpy scalar to Python float.""" + if hasattr(v, "item"): + return v.item() + return float(v) + + return { + "esmfold_tm_score": _to_float(folded_metrics["_tm_score"]), + "esmfold_rmsd": _to_float(folded_metrics["_rmsd"]), + "esmfold_plddt": _to_float(folded_metrics["_plddt"]), + "esmfold_pae": _to_float(folded_metrics["_predicted_aligned_error"]), + "esmfold_coords": folded_coords[0], # [L_valid, 3, 3] + } + + def evaluate( + self, + model: "LightningModule", + samples: list[dict] | None = None, + structure_path: str | None = None, + ) -> dict: + """Run full evaluation on PDBBind test set. + + Parameters + ---------- + model : LightningModule + The Gen-UME protein-ligand model + samples : list[dict], optional + Pre-loaded samples (will load if not provided) + structure_path : str, optional + Directory to save designed sequences as FASTA files + + Returns + ------- + dict with: + - results_df: DataFrame with per-structure results + - summary: dict with aggregated metrics + """ + model.eval() + model.to(self.device) + + if samples is None: + samples = self.load_test_set() + + # Create output directory if specified + if structure_path: + os.makedirs(structure_path, exist_ok=True) + + results = [] + skipped_samples = [] + + for sample in tqdm(samples, desc="Evaluating inverse folding"): + pdb_id = sample["pdb_id"] + gt_seq = sample["protein_sequence"] + protein_mask = sample["protein_mask"] + + # Check combined protein + ligand length (they are concatenated in the model) + protein_length = len(gt_seq) + ligand_length = len(sample["ligand_coords"]) + total_length = protein_length + ligand_length + + # Skip samples exceeding max protein length + if protein_length > self.max_protein_length: + logger.warning( + f"Skipping {pdb_id}: protein length {protein_length} " + f"exceeds max_protein_length {self.max_protein_length}" + ) + skipped_samples.append( + { + "pdb_id": pdb_id, + "protein_length": protein_length, + "ligand_length": ligand_length, + "total_length": total_length, + "reason": "max_protein_length", + } + ) + continue + + if total_length > self.max_length: + logger.warning( + f"Skipping {pdb_id}: total length {total_length} " + f"(protein: {protein_length}, ligand: {ligand_length}) exceeds max_length {self.max_length}" + ) + skipped_samples.append( + { + "pdb_id": pdb_id, + "protein_length": protein_length, + "ligand_length": ligand_length, + "total_length": total_length, + "reason": "max_length", + } + ) + continue + + # Compute binding pocket + pocket_mask = self.compute_binding_pocket( + sample["protein_coords"], + sample["ligand_coords"], + protein_mask, + ) + non_pocket_mask = protein_mask.bool() & ~pocket_mask + + # Mode 1: Protein only (no ligand context) + pred_no_ligand = self.inverse_fold(model, sample, include_ligand=False) + pred_seq_no_ligand = pred_no_ligand["predicted_sequence"] + + # Mode 2: Protein + Ligand context + pred_with_ligand = self.inverse_fold(model, sample, include_ligand=True) + pred_seq_with_ligand = pred_with_ligand["predicted_sequence"] + + # Save sequences and structures if structure_path provided + if structure_path: + gt_seq_str = self.sequence_to_string(gt_seq) + no_ligand_seq_str = self.sequence_to_string(pred_seq_no_ligand) + with_ligand_seq_str = self.sequence_to_string(pred_seq_with_ligand) + + # Save sequences as FASTA + fasta_path = os.path.join(structure_path, f"{pdb_id}_sequences.fasta") + with open(fasta_path, "w") as f: + f.write(f">{pdb_id}_ground_truth\n{gt_seq_str}\n") + f.write(f">{pdb_id}_no_ligand\n{no_ligand_seq_str}\n") + f.write(f">{pdb_id}_with_ligand\n{with_ligand_seq_str}\n") + + # Get coordinates and atom names + protein_coords = sample["protein_coords"] + ligand_coords = sample["ligand_coords"] + # Get atom names from original data or generate default names + atom_names = sample.get("ligand_atom_names") + if atom_names is None: + # Generate default atom names from element indices + idx_to_element = {v: k for k, v in self.element_to_idx.items()} + ligand_types = sample["ligand_atom_types"] + atom_names = [ + f"{idx_to_element.get(int(t), 'C')}{i + 1}" for i, t in enumerate(ligand_types.cpu().tolist()) + ] + + # Get bond matrix for CONECT records + bond_matrix = sample.get("bond_matrix") + + # Save ground truth structures (optional) + # NOTE: gt_seq from .pt files is already in standard tokenization format + if self.save_gt_structure: + # Save ground truth protein structure as PDB + pdb_path = os.path.join(structure_path, f"{pdb_id}_protein.pdb") + writepdb(pdb_path, protein_coords, gt_seq) + + # Save protein-ligand complex as PDB (ground truth) + complex_path = os.path.join(structure_path, f"{pdb_id}_complex.pdb") + writepdb_ligand_complex( + complex_path, + protein_atoms=protein_coords, + protein_seq=gt_seq, + ligand_atoms=ligand_coords, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix, + ) + + # Save reconstructed input structures (decode the SAME tokens used for inverse folding) + # This verifies token encoding/decoding fidelity using exact same tokens + if self.save_reconstructed_input: + # Use decode_input_tokens with tokens from pred_with_ligand (same tokens used for IF) + recon_result = self.decode_input_tokens(model, pred_with_ligand) + + # Save reconstructed protein structure + recon_protein_coords = recon_result.get("reconstructed_protein_coords") + if recon_protein_coords is not None: + recon_pdb_path = os.path.join(structure_path, f"{pdb_id}_reconstructed_input_protein.pdb") + writepdb(recon_pdb_path, recon_protein_coords, gt_seq) + + # Save reconstructed protein-ligand complex + recon_ligand_coords = recon_result.get("reconstructed_ligand_coords") + if recon_protein_coords is not None and recon_ligand_coords is not None: + # Apply minimization to reconstructed ligand if enabled + recon_ligand_coords_to_save = recon_ligand_coords.cpu() + if self.minimize_ligand: + try: + recon_ligand_coords_to_save = minimize_ligand_structure( + recon_ligand_coords_to_save, + atom_names, + bond_matrix=bond_matrix, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Reconstructed ligand minimization failed for {pdb_id}: {e}") + + recon_complex_path = os.path.join(structure_path, f"{pdb_id}_reconstructed_input_complex.pdb") + writepdb_ligand_complex( + recon_complex_path, + protein_atoms=recon_protein_coords, + protein_seq=gt_seq, + ligand_atoms=recon_ligand_coords_to_save, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix, + ) + + # Log token info - tokens come directly from inverse_fold result + input_protein_tokens = pred_with_ligand.get("input_protein_structure_tokens") + input_ligand_tokens = pred_with_ligand.get("input_ligand_structure_tokens") + output_protein_tokens = pred_with_ligand.get("output_protein_structure_tokens") + output_ligand_tokens = pred_with_ligand.get("output_ligand_structure_tokens") + + # Compute token preservation rate + if input_protein_tokens is not None and output_protein_tokens is not None: + protein_token_match = (input_protein_tokens == output_protein_tokens).float().mean().item() + logger.info(f"{pdb_id}: Protein structure token preservation: {protein_token_match * 100:.1f}%") + if input_ligand_tokens is not None and output_ligand_tokens is not None: + ligand_token_match = (input_ligand_tokens == output_ligand_tokens).float().mean().item() + logger.info(f"{pdb_id}: Ligand structure token preservation: {ligand_token_match * 100:.1f}%") + + # Save decoded protein structure (no ligand) as PDB + # NOTE: pred_seq_no_ligand is already in standard format (converted in inverse_fold) + decoded_coords_no_ligand = pred_no_ligand.get("decoded_coords") + if decoded_coords_no_ligand is not None: + decoded_pdb_path = os.path.join(structure_path, f"{pdb_id}_decoded_no_ligand.pdb") + writepdb(decoded_pdb_path, decoded_coords_no_ligand, pred_seq_no_ligand) + + # Save decoded protein structure (with ligand) as PDB - include decoded ligand + # NOTE: pred_seq_with_ligand is already in standard format (converted in inverse_fold) + decoded_coords_with_ligand = pred_with_ligand.get("decoded_coords") + decoded_ligand_coords = pred_with_ligand.get("decoded_ligand_coords") + if decoded_coords_with_ligand is not None: + if decoded_ligand_coords is None: + raise ValueError( + f"Model did not output decoded ligand coordinates for {pdb_id}. " + "Check that the model supports ligand structure decoding." + ) + # Use predicted bond matrix if available, otherwise fall back to GT + pred_bond_matrix = pred_with_ligand.get("predicted_bond_matrix") + bond_matrix_for_pred = pred_bond_matrix if pred_bond_matrix is not None else bond_matrix + + # Apply minimization if enabled + ligand_coords_to_save = decoded_ligand_coords.cpu() + if self.minimize_ligand: + try: + ligand_coords_to_save = minimize_ligand_structure( + ligand_coords_to_save, + atom_names, + bond_matrix=bond_matrix_for_pred, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed for {pdb_id}: {e}") + decoded_pdb_path = os.path.join(structure_path, f"{pdb_id}_decoded_with_ligand.pdb") + writepdb_ligand_complex( + decoded_pdb_path, + protein_atoms=decoded_coords_with_ligand, + protein_seq=pred_seq_with_ligand, + ligand_atoms=ligand_coords_to_save, + ligand_atom_names=atom_names, + ligand_bond_matrix=bond_matrix_for_pred, + ) + + # Compute metrics + result = { + "pdb_id": pdb_id, + "length": len(gt_seq), + "n_pocket_residues": int(pocket_mask.sum().item()), + "n_nonpocket_residues": int(non_pocket_mask.sum().item()), + # Protein-only metrics + "aar_overall_no_ligand": self.compute_aar(pred_seq_no_ligand, gt_seq, protein_mask), + "aar_pocket_no_ligand": self.compute_aar(pred_seq_no_ligand, gt_seq, pocket_mask), + "aar_nonpocket_no_ligand": self.compute_aar(pred_seq_no_ligand, gt_seq, non_pocket_mask), + # With-ligand metrics + "aar_overall_with_ligand": self.compute_aar(pred_seq_with_ligand, gt_seq, protein_mask), + "aar_pocket_with_ligand": self.compute_aar(pred_seq_with_ligand, gt_seq, pocket_mask), + "aar_nonpocket_with_ligand": self.compute_aar(pred_seq_with_ligand, gt_seq, non_pocket_mask), + } + + # ESMFold validation (fold designed sequences and compare to GT structure) + if self.use_esmfold and self.plm_fold is not None: + gt_coords = sample["protein_coords"] + gt_seq_masked = gt_seq[protein_mask.bool()] + + # Build pocket mask relative to valid residues only (for pocket RMSD) + # pocket_mask is [L] over all residues; we need it relative to valid residues + pocket_mask_valid = pocket_mask[protein_mask.bool()] # [L_valid] + gt_coords_valid = gt_coords[protein_mask.bool()] # [L_valid, 3, 3] + + def _compute_pocket_rmsd(esmfold_coords, gt_coords_v, pocket_mask_v): + """Compute pocket RMSD between ESMFold prediction and GT.""" + if pocket_mask_v.sum() == 0: + return float("nan") + pred_pocket = esmfold_coords[pocket_mask_v] + gt_pocket = gt_coords_v[pocket_mask_v] + rmsd = align_and_compute_rmsd( + coords1=pred_pocket.detach(), + coords2=gt_pocket.detach(), + mask=None, + return_aligned=False, + device=self.device, + ) + return float(rmsd) + + # Fold "no ligand" designed sequence + no_ligand_seq_str = self.sequence_to_string(pred_seq_no_ligand[protein_mask.bool()]) + esmfold_no_ligand = self.fold_with_esmfold(no_ligand_seq_str, gt_coords, protein_mask) + result["esmfold_tm_no_ligand"] = esmfold_no_ligand["esmfold_tm_score"] + result["esmfold_rmsd_no_ligand"] = esmfold_no_ligand["esmfold_rmsd"] + result["esmfold_plddt_no_ligand"] = esmfold_no_ligand["esmfold_plddt"] + result["esmfold_pae_no_ligand"] = esmfold_no_ligand["esmfold_pae"] + result["esmfold_rmsd_pocket_no_ligand"] = _compute_pocket_rmsd( + esmfold_no_ligand["esmfold_coords"], gt_coords_valid, pocket_mask_valid + ) + + if structure_path and esmfold_no_ligand.get("esmfold_coords") is not None: + esmfold_pdb_path = os.path.join(structure_path, f"{pdb_id}_esmfold_no_ligand.pdb") + writepdb(esmfold_pdb_path, esmfold_no_ligand["esmfold_coords"], gt_seq_masked) + + # Fold "with ligand" designed sequence + with_ligand_seq_str = self.sequence_to_string(pred_seq_with_ligand[protein_mask.bool()]) + esmfold_with_ligand = self.fold_with_esmfold(with_ligand_seq_str, gt_coords, protein_mask) + result["esmfold_tm_with_ligand"] = esmfold_with_ligand["esmfold_tm_score"] + result["esmfold_rmsd_with_ligand"] = esmfold_with_ligand["esmfold_rmsd"] + result["esmfold_plddt_with_ligand"] = esmfold_with_ligand["esmfold_plddt"] + result["esmfold_pae_with_ligand"] = esmfold_with_ligand["esmfold_pae"] + result["esmfold_rmsd_pocket_with_ligand"] = _compute_pocket_rmsd( + esmfold_with_ligand["esmfold_coords"], gt_coords_valid, pocket_mask_valid + ) + + if structure_path and esmfold_with_ligand.get("esmfold_coords") is not None: + esmfold_pdb_path = os.path.join(structure_path, f"{pdb_id}_esmfold_with_ligand.pdb") + writepdb(esmfold_pdb_path, esmfold_with_ligand["esmfold_coords"], gt_seq_masked) + + # Fold GT sequence for baseline comparison + gt_seq_str = self.sequence_to_string(gt_seq_masked) + esmfold_gt = self.fold_with_esmfold(gt_seq_str, gt_coords, protein_mask) + result["esmfold_tm_gt"] = esmfold_gt["esmfold_tm_score"] + result["esmfold_rmsd_gt"] = esmfold_gt["esmfold_rmsd"] + result["esmfold_plddt_gt"] = esmfold_gt["esmfold_plddt"] + result["esmfold_pae_gt"] = esmfold_gt["esmfold_pae"] + result["esmfold_rmsd_pocket_gt"] = _compute_pocket_rmsd( + esmfold_gt["esmfold_coords"], gt_coords_valid, pocket_mask_valid + ) + + if structure_path and esmfold_gt.get("esmfold_coords") is not None: + esmfold_pdb_path = os.path.join(structure_path, f"{pdb_id}_esmfold_gt.pdb") + writepdb(esmfold_pdb_path, esmfold_gt["esmfold_coords"], gt_seq_masked) + + results.append(result) + + # Log skipped samples + if skipped_samples: + n_protein = sum(1 for s in skipped_samples if s.get("reason") == "max_protein_length") + n_total = sum(1 for s in skipped_samples if s.get("reason") == "max_length") + skip_reasons = [] + if n_protein: + skip_reasons.append(f"{n_protein} due to protein length > {self.max_protein_length}") + if n_total: + skip_reasons.append(f"{n_total} due to total length > {self.max_length}") + logger.info(f"Skipped {len(skipped_samples)} samples: {', '.join(skip_reasons)}") + logger.debug(f"Skipped samples: {skipped_samples}") + + # Create results DataFrame + results_df = pd.DataFrame(results) + + # Handle empty results + if len(results_df) == 0: + logger.warning("No samples were successfully evaluated") + summary = { + "mean_aar_overall_no_ligand": float("nan"), + "mean_aar_overall_with_ligand": float("nan"), + "mean_aar_pocket_no_ligand": float("nan"), + "mean_aar_pocket_with_ligand": float("nan"), + "mean_aar_nonpocket_no_ligand": float("nan"), + "mean_aar_nonpocket_with_ligand": float("nan"), + "mean_aar_overall_delta": float("nan"), + "mean_aar_pocket_delta": float("nan"), + "mean_aar_nonpocket_delta": float("nan"), + "std_aar_pocket_delta": float("nan"), + "std_aar_nonpocket_delta": float("nan"), + "n_samples": 0, + "mean_pocket_size": float("nan"), + } + return {"results_df": results_df, "summary": summary} + + # Compute delta metrics (improvement from ligand) + results_df["aar_overall_delta"] = results_df["aar_overall_with_ligand"] - results_df["aar_overall_no_ligand"] + results_df["aar_pocket_delta"] = results_df["aar_pocket_with_ligand"] - results_df["aar_pocket_no_ligand"] + results_df["aar_nonpocket_delta"] = ( + results_df["aar_nonpocket_with_ligand"] - results_df["aar_nonpocket_no_ligand"] + ) + + # Compute summary statistics + summary = { + # Overall averages (excluding NaN) + "mean_aar_overall_no_ligand": results_df["aar_overall_no_ligand"].mean(), + "mean_aar_overall_with_ligand": results_df["aar_overall_with_ligand"].mean(), + "mean_aar_pocket_no_ligand": results_df["aar_pocket_no_ligand"].mean(), + "mean_aar_pocket_with_ligand": results_df["aar_pocket_with_ligand"].mean(), + "mean_aar_nonpocket_no_ligand": results_df["aar_nonpocket_no_ligand"].mean(), + "mean_aar_nonpocket_with_ligand": results_df["aar_nonpocket_with_ligand"].mean(), + # Delta (improvement from ligand) + "mean_aar_overall_delta": results_df["aar_overall_delta"].mean(), + "mean_aar_pocket_delta": results_df["aar_pocket_delta"].mean(), + "mean_aar_nonpocket_delta": results_df["aar_nonpocket_delta"].mean(), + # Standard deviations + "std_aar_pocket_delta": results_df["aar_pocket_delta"].std(), + "std_aar_nonpocket_delta": results_df["aar_nonpocket_delta"].std(), + # Sample counts + "n_samples": len(results_df), + "mean_pocket_size": results_df["n_pocket_residues"].mean(), + } + + # Add ESMFold summary metrics if available + if self.use_esmfold and "esmfold_tm_no_ligand" in results_df.columns: + # ESMFold deltas + results_df["esmfold_tm_delta"] = results_df["esmfold_tm_with_ligand"] - results_df["esmfold_tm_no_ligand"] + results_df["esmfold_rmsd_delta"] = ( + results_df["esmfold_rmsd_no_ligand"] - results_df["esmfold_rmsd_with_ligand"] + ) + results_df["esmfold_rmsd_pocket_delta"] = ( + results_df["esmfold_rmsd_pocket_no_ligand"] - results_df["esmfold_rmsd_pocket_with_ligand"] + ) + results_df["esmfold_plddt_delta"] = ( + results_df["esmfold_plddt_with_ligand"] - results_df["esmfold_plddt_no_ligand"] + ) + + summary.update( + { + # ESMFold: no ligand + "mean_esmfold_tm_no_ligand": results_df["esmfold_tm_no_ligand"].mean(), + "mean_esmfold_rmsd_no_ligand": results_df["esmfold_rmsd_no_ligand"].mean(), + "mean_esmfold_rmsd_pocket_no_ligand": results_df["esmfold_rmsd_pocket_no_ligand"].mean(), + "mean_esmfold_plddt_no_ligand": results_df["esmfold_plddt_no_ligand"].mean(), + "mean_esmfold_pae_no_ligand": results_df["esmfold_pae_no_ligand"].mean(), + # ESMFold: with ligand + "mean_esmfold_tm_with_ligand": results_df["esmfold_tm_with_ligand"].mean(), + "mean_esmfold_rmsd_with_ligand": results_df["esmfold_rmsd_with_ligand"].mean(), + "mean_esmfold_rmsd_pocket_with_ligand": results_df["esmfold_rmsd_pocket_with_ligand"].mean(), + "mean_esmfold_plddt_with_ligand": results_df["esmfold_plddt_with_ligand"].mean(), + "mean_esmfold_pae_with_ligand": results_df["esmfold_pae_with_ligand"].mean(), + # ESMFold: GT sequence baseline + "mean_esmfold_tm_gt": results_df["esmfold_tm_gt"].mean(), + "mean_esmfold_rmsd_gt": results_df["esmfold_rmsd_gt"].mean(), + "mean_esmfold_rmsd_pocket_gt": results_df["esmfold_rmsd_pocket_gt"].mean(), + "mean_esmfold_plddt_gt": results_df["esmfold_plddt_gt"].mean(), + "mean_esmfold_pae_gt": results_df["esmfold_pae_gt"].mean(), + # ESMFold: deltas (improvement from ligand context) + "mean_esmfold_tm_delta": results_df["esmfold_tm_delta"].mean(), + "mean_esmfold_rmsd_delta": results_df["esmfold_rmsd_delta"].mean(), + "mean_esmfold_rmsd_pocket_delta": results_df["esmfold_rmsd_pocket_delta"].mean(), + "mean_esmfold_plddt_delta": results_df["esmfold_plddt_delta"].mean(), + "std_esmfold_tm_delta": results_df["esmfold_tm_delta"].std(), + } + ) + + return {"results_df": results_df, "summary": summary} + + def sequence_to_string(self, seq_tensor: Tensor) -> str: + """Convert sequence tensor (in standard format) to string. + + All sequences (ground truth and predictions) are in standard tokenization format. + """ + return "".join([self.standard_aa_map.get(int(s), "X") for s in seq_tensor.cpu().tolist()]) diff --git a/src/lobster/model/gen_ume/README.md b/src/lobster/model/gen_ume/README.md new file mode 100644 index 00000000..ea271604 --- /dev/null +++ b/src/lobster/model/gen_ume/README.md @@ -0,0 +1,735 @@ +# Gen-UME: Generative Unified Molecular Encoder + +Gen-UME is a generative model for protein structure and sequence design based on discrete flow matching. It supports three generation modes: **unconditional generation**, **inverse folding**, and **forward folding**. + +## Quick Start + +```bash +# Unconditional: Generate novel proteins from scratch +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional + +# Inverse Folding: Design sequences for structures +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_inverse_folding \ + generation.input_structures="path/to/structures/*.pdb" + +# Forward Folding: Predict structures from sequences +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_forward_folding \ + generation.input_structures="path/to/structures/*.pdb" +``` + +## Table of Contents + +- [Quick Start](#quick-start) +- [Overview](#overview) +- [Model Checkpoints](#model-checkpoints) +- [Installation](#installation) +- [Generation Modes](#generation-modes) + - [Unconditional Generation](#1-unconditional-generation) + - [Inverse Folding](#2-inverse-folding) + - [Forward Folding](#3-forward-folding) +- [Benchmark Results](#benchmark-results) +- [Key Parameters](#key-parameters) +- [Advanced Features](#advanced-features) +- [Tips and Best Practices](#tips-and-best-practices) +- [Protein-Ligand](#protein-ligand) + - [Protein-Ligand Inverse Folding](#protein-ligand-inverse-folding) + - [Protein-Ligand Forward Folding](#protein-ligand-forward-folding) + +## Overview + +Gen-UME generates protein structures and sequences using discrete flow matching, a unified generative modeling approach that operates on both modalities simultaneously. The model can: + +- **Generate novel proteins** from scratch (unconditional) +- **Design sequences** for given structures (inverse folding) +- **Predict structures** from sequences (forward folding) + +### Technical Approach + +Gen-UME employs **discrete flow matching**, which models the generation process as a continuous-time flow on discrete state spaces (sequences) and continuous state spaces (structures). The model uses **tokenized structure representations** to encode protein backbone geometry, enabling efficient joint generation of sequence and structure. + +## Model Checkpoints + +For a complete list of all available checkpoints with detailed descriptions, see **[CHECKPOINTS.md](./CHECKPOINTS.md)**. + +### Quick Reference + +| Model | Size | S3 Path | Description | +|-------|------|---------|-------------| +| **Gen-UME 90M** | 1.1 GiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_90M_PDB.ckpt` | Smallest model, good for testing | +| **Gen-UME 450M** | 5.3 GiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_450M_2025-11-07_*.ckpt` | Medium model, balanced performance | +| **Gen-UME 750M** | 8.3 GiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_2025-11-17_*.ckpt` | **Primary production model** | +| **Gen-UME 750M ESM Atlas** | 8.3 GiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_ESM_Atlas_2026-01-04_*.ckpt` | Extended training data | + +### Download Checkpoints + +```bash +# Download Gen-UME 750M (recommended) +aws s3 cp s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/gen_ume_750M_2025-11-17_last.ckpt ./ + +# Download all Gen-UME checkpoints +aws s3 sync s3://prescient-pcluster-data/gen_ume/checkpoints/gen_ume/ ./checkpoints/ +``` + +### Latent Generator Checkpoints + +The Latent Generator provides the structure tokenization backbone: + +| Model | Codebook | Size | S3 Path | +|-------|----------|------|---------| +| **LG PL FSQ 4375** | 4375 tokens (FSQ) | 295.8 MiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_2026-01-05.ckpt` | +| **LG PL FSQ 4375/15360** | 4375/15360 tokens (asymmetric FSQ) | 360.2 MiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_15360_2026-01-07.ckpt` | +| **LG PL 4096** | 4096 tokens (SLQ) | 292.9 MiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_4096_2026-01-05.ckpt` | +| **LG Ligand** | 512 tokens (SLQ) | 250.5 MiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Ligand_2025-11-09.ckpt` | +| **LG Full Attention 2** | 256 tokens (SLQ) | 245.3 MiB | `s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_full_attention_2_2025-11-06.ckpt` | + +**Additional checkpoints on HuggingFace:** LG Ligand 20A, LG 20A seq Aux, LG 20A seq 3di c6d Aux, LG 20A seq 3di c6d Aux PDB Pinder, and more. See [CHECKPOINTS.md](./CHECKPOINTS.md) for the complete list. + +## Installation + +Ensure you have the lobster package installed: + +```bash +cd /path/to/lobster +uv pip install -e . +``` + +## Generation Modes + +### 1. Unconditional Generation + +Generate novel protein structures and sequences from scratch. + +#### Basic Usage + +Use the provided configuration file: + +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional +``` + +#### Configuration File + +The example configuration is located at `src/lobster/hydra_config/experiment/generate_unconditional.yaml`: + +```yaml +# Generation parameters +generation: + mode: unconditional + length: [100, 200, 300, 400, 500] # Sequence lengths to generate + num_samples: 10 # Samples per length + nsteps: 1000 # Diffusion steps + batch_size: 1 + + # Temperature and stochasticity control + temperature_seq: 0.4579796403264936 + temperature_struc: 0.35751879409731435 + stochasticity_seq: 30 + stochasticity_struc: 70 + + # ESMFold validation + use_esmfold: true + max_length: 512 +``` + +#### Override Parameters + +You can override any parameter from the command line: + +```bash +# Change output directory +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional \ + output_dir="./my_generation" + +# Generate different lengths +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional \ + generation.length="[50,100,150]" \ + generation.num_samples=20 +``` + +#### Self-Reflection (Enabled by Default) + +The provided config already has self-reflection enabled to improve structure-sequence consistency. The self-reflection pipeline refines unconditionally generated structures through forward and inverse folding steps: + +```yaml +generation: + enable_self_reflection: true # Already enabled in default config + + self_reflection: + forward_folding: + nsteps: 100 + temperature_seq: 0.2967457760634187 + temperature_struc: 0.1102551183666233 + stochasticity_seq: 10 + stochasticity_struc: 30 + + inverse_folding: + nsteps: 200 + temperature_seq: 0.16423763902324678 + temperature_struc: 1.0 + stochasticity_seq: 20 + stochasticity_struc: 10 + + quality_control: + enable_tm_threshold: true + min_tm_score_forward: 0.8334123066155882 + min_percent_identity: 50 + max_percent_identity: 100 + max_retries: 30 +``` + +To disable self-reflection: + +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional \ + generation.enable_self_reflection=false +``` + +### 2. Inverse Folding + +Generate sequences for given protein structures (sequence design). + +#### Basic Usage + +Use the provided configuration file and specify your input structures: + +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_inverse_folding \ + generation.input_structures="path/to/structures/*.pdb" +``` + +#### Configuration File + +The example configuration is located at `src/lobster/hydra_config/experiment/generate_inverse_folding.yaml`: + +```yaml +# Generation settings +generation: + mode: inverse_folding + nsteps: 200 + batch_size: 1 + n_trials: 3 # Generate multiple designs and select best + + # Temperature parameters (optimized for inverse folding) + temperature_seq: 0.16423763902324678 + temperature_struc: 1.0 + stochasticity_seq: 20 + stochasticity_struc: 10 + + n_designs_per_structure: 10 # Number of sequences per structure + + # Input structures - update via command line or edit config + input_structures: "test_data/inv_folding/9jl9.pdb" + + # ESMFold validation (recommended) + use_esmfold: true + max_length: 512 +``` + +#### Input Structure Formats + +Multiple input formats are supported: + +```bash +# Single file +generation.input_structures="/path/to/structure.pdb" + +# Directory (finds all PDB/CIF files) +generation.input_structures="/path/to/pdb/directory/" + +# Glob pattern +generation.input_structures="/path/to/structures/*.pdb" + +# Multiple files (use quotes) +generation.input_structures="[/path/to/file1.pdb,/path/to/file2.pdb]" +``` + +#### Multi-Chain Support + +For multi-chain structures, specify which chains to predict: + +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_inverse_folding \ + generation.input_structures="path/to/complex.pdb" \ + generation.esmfold_chain_groups="[[A,B],[C]]" +``` + +If not specified, all chains will be predicted together. + +### 3. Forward Folding + +Generate structures from sequences (structure prediction). + +#### Basic Usage + +Use the provided configuration file and specify your input structures: + +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_forward_folding \ + generation.input_structures="path/to/structures/*.pdb" +``` + +**Note:** Despite the name `input_structures`, forward folding extracts sequences from these structures to generate new structures. + +#### Configuration File + +The example configuration is located at `src/lobster/hydra_config/experiment/generate_forward_folding.yaml`: + +```yaml +# Generation settings +generation: + mode: forward_folding + nsteps: 100 + batch_size: 1 + n_trials: 1 + + # Temperature parameters (optimized for forward folding) + temperature_seq: 0.2967457760634187 + temperature_struc: 0.1102551183666233 + stochasticity_seq: 10 + stochasticity_struc: 30 + + # Input structures - sequences will be extracted from these + input_structures: "test_data/inv_folding/9jl9.pdb" + + max_length: 512 +``` + +#### Override Examples + +```bash +# Generate multiple trials for better results +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_forward_folding \ + generation.n_trials=5 + +# Change number of diffusion steps +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_forward_folding \ + generation.nsteps=200 +``` + +## Benchmark Results + +### Unconditional Generation + +Results from large-scale unconditional generation with self-reflection (100 samples per length): + +| Model | Length | Total Structures | RMSD<2.0 | % Pass | Clusters | Diversity % | Avg TM | Avg RMSD | Avg pLDDT | +|-------|--------|-----------------|----------|--------|----------|-------------|--------|----------|-----------| +| genUME 90M | 100 | 100 | 85 | 85.0% | 25 | 25.0% | 0.8203 | 1.963 | 0.7111 | +| genUME 90M | 200 | 100 | 63 | 63.0% | 19 | 19.0% | 0.8043 | 2.467 | 0.6639 | +| genUME 90M | 300 | 100 | 62 | 62.0% | 23 | 23.0% | 0.8447 | 2.015 | 0.6851 | +| genUME 90M | 400 | 100 | 56 | 56.0% | 13 | 13.0% | 0.8505 | 2.191 | 0.7177 | +| genUME 90M | 500 | 91 | 31 | 34.1% | 10 | 11.0% | 0.8344 | 2.730 | 0.7283 | + +**Metrics Explanation:** +- **RMSD<2.0**: Number of structures with RMSD < 2.0 Å between gen-UME and ESMFold predictions +- **% Pass**: Percentage of structures passing RMSD threshold +- **Clusters**: Number of unique structural clusters (Foldseek, TM-score threshold 0.5) +- **Diversity %**: Percentage of unique structures (clusters/total) +- **Avg TM**: Average TM-score between gen-UME structure and ESMFold prediction +- **Avg RMSD**: Average RMSD between gen-UME structure and ESMFold prediction +- **Avg pLDDT**: Average pLDDT (confidence score) from ESMFold prediction + +**Key Observations:** +- Shorter sequences (100-200 AA) show better consistency with ESMFold +- Self-reflection improves structure quality across all lengths +- Diversity remains high (10-25%) indicating generation of distinct structures +- High TM-scores (>0.8) indicate good structural quality + +### Inverse Folding + +Performance on sequence design for given structures: + +| Task | Model | AAR | TM-Score | +|------|-------|-----|----------| +| Inverse Folding | genUME 90M | 50.67% | 0.83 | + +**Metrics Explanation:** +- **AAR (Amino Acid Recovery)**: Percentage of positions where the designed sequence matches the native sequence +- **TM-Score**: Structural similarity between input structure and structure predicted from designed sequence + +**Key Observations:** +- AAR of 50.67% demonstrates strong sequence recovery capability +- TM-score of 0.83 indicates excellent structural preservation +- Model successfully designs sequences that fold back to target structures + +**Dataset:** Benchmarked on the dataset from [Generative Flows on Discrete State-Spaces](https://arxiv.org/abs/2402.04997) (Campbell et al., ICML 2024) + +### Forward Folding + +Performance on structure prediction from sequences: + +| Task | Model | TM-Score | +|------|-------|----------| +| Forward Folding | genUME 90M | 0.70 | + +**Metrics Explanation:** +- **TM-Score**: Structural similarity between generated structure and reference structure + +**Key Observations:** +- TM-score of 0.70 indicates good structure prediction capability +- Model generates plausible structures from sequence inputs + +**Dataset:** Benchmarked on the dataset from [Generative Flows on Discrete State-Spaces](https://arxiv.org/abs/2402.04997) (Campbell et al., ICML 2024) + +## Key Parameters + +### Temperature and Stochasticity + +Control the randomness and exploration of the generation process: + +| Parameter | Range | Effect | Recommended Values | +|-----------|-------|--------|-------------------| +| `temperature_seq` | 0.1-1.0 | Sequence randomness | Unconditional: 0.45, Inverse: 0.16, Forward: 0.30 | +| `temperature_struc` | 0.1-1.0 | Structure randomness | Unconditional: 0.35, Inverse: 1.0, Forward: 0.11 | +| `stochasticity_seq` | 0-100 | Sequence noise steps | Unconditional: 30, Inverse: 20, Forward: 10 | +| `stochasticity_struc` | 0-100 | Structure noise steps | Unconditional: 70, Inverse: 10, Forward: 30 | + +**Tips:** +- **Lower temperature** = more deterministic, conservative outputs +- **Higher temperature** = more diverse, exploratory outputs +- **Higher stochasticity** = more diffusion steps with noise injection + +### Generation Steps + +| Mode | Recommended nsteps | Notes | +|------|-------------------|-------| +| Unconditional | 1000 | Higher steps for de novo generation | +| Unconditional + Self-Reflection | Forward: 100, Inverse: 200 | Refinement needs fewer steps | +| Inverse Folding | 200 | Structure constrains generation | +| Forward Folding | 100 | Sequence constrains generation | + +## Advanced Features + +### Distributed Generation + +For large-scale generation, use the distributed generation system with WandB: + +```bash +# See distributed generation README +cd src/lobster/cmdline/distributed_generation +python create_job_config.py --total_samples 100 --samples_per_job 5 +``` + +See [Distributed Generation README](../../cmdline/distributed_generation/README.md) for details. + +### Foldseek Diversity Analysis + +Automatically cluster generated structures by structural similarity: + +```yaml +generation: + calculate_foldseek_diversity: true + foldseek_bin_path: "/path/to/foldseek/bin" + foldseek_tmscore_threshold: 0.5 # TM-score cutoff for clustering + rmsd_threshold_for_diversity: 2.0 # Only cluster high-quality structures +``` + +### Asynchronous Sampling + +Enable asynchronous sequence and structure sampling for faster generation: + +```yaml +generation: + asynchronous_sampling: true # Default: false +``` + +**Note:** This can significantly speed up generation but may affect reproducibility. + +## Tips and Best Practices + +### 1. Start Small +Begin with small test runs to validate your setup: +```bash +uv run python -m lobster.cmdline.generate \ + --config-path "../hydra_config/experiment" \ + --config-name generate_unconditional \ + generation.length="[100]" \ + generation.num_samples=2 \ + generation.nsteps=100 +``` + +### 2. Use Self-Reflection for Quality +Self-reflection is enabled by default in `generate_unconditional.yaml` to improve ESMFold metrics. To disable it: +```bash +generation.enable_self_reflection=false +``` + +### 3. ESMFold Validation +ESMFold is enabled by default in the provided configs and provides crucial quality metrics. Adjust `max_length` based on your sequences: +```bash +generation.max_length=1024 # For longer sequences +``` + +### 4. Batch Size Selection +- **GPU Memory Limited**: Use `batch_size: 1` +- **Long sequences (>400)**: Use `batch_size: 1` +- **Short sequences (<200)**: Can use `batch_size: 2-4` + +### 5. Output Organization +Always use descriptive output directories for tracking experiments: +```bash +output_dir="./examples/generation_20251104_my_experiment" +``` + +### 6. Reproducibility +Set a seed for reproducible results: +```bash +seed=12345 +``` + +### 7. Monitor Progress +CSV metrics and plots are enabled by default in the provided configs. To disable: +```bash +generation.save_csv_metrics=false +generation.create_plots=false +``` + +### 8. Multi-Chain Design +For inverse folding of multi-chain complexes: +```bash +generation.esmfold_chain_groups="[[A,B],[C]]" # Design chains A+B together, C separately +``` + +### 9. Quality Control +Quality control is enabled by default in unconditional generation with self-reflection. Adjust thresholds if needed: +```bash +generation.self_reflection.quality_control.min_tm_score_forward=0.9 +generation.self_reflection.quality_control.max_retries=50 +``` + +### 10. Structure File Formats +Supported input formats: +- **PDB files** (`.pdb`) +- **mmCIF files** (`.cif`) +- **PyTorch tensors** (`.pt`) + +## Output Files + +After generation, you'll find: + +``` +output_dir/ +├── generated_structure_length_XXX_YYY.pdb # Generated structures +├── generated_structure_length_XXX_YYY_esmfold_000.pdb # ESMFold predictions +├── unconditional_metrics_TIMESTAMP.csv # Metrics for all samples +├── unconditional_sequences_TIMESTAMP.csv # Generated sequences +├── unconditional_combined_boxplots_TIMESTAMP.png # Visualizations +└── foldseek_results/ # Diversity analysis (if enabled) + └── length_XXX/ + ├── res_rep_seq.fasta # Cluster representatives + └── res_cluster.tsv # Cluster assignments +``` + +## Citation + +If you use Gen-UME in your research, please cite: + +``` +[Citation to be added] +``` + +### Benchmark Dataset + +The inverse folding and forward folding benchmarks use the dataset from: + +```bibtex +@inproceedings{campbell2024generative, + title={Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design}, + author={Campbell, Andrew and Yim, Jason and Barzilay, Regina and Rainforth, Tom and Jaakkola, Tommi}, + booktitle={International Conference on Machine Learning (ICML)}, + year={2024}, + url={https://arxiv.org/abs/2402.04997} +} +``` + +## Protein-Ligand + +Gen-UME Protein-Ligand extends the model to handle protein-ligand complexes. The key insight is that **providing ligand context can improve both sequence design and structure prediction**, particularly for binding pocket residues. + +### Protein-Ligand Inverse Folding + +Design protein sequences conditioned on both protein structure **and** ligand context. + +**Command Line:** + +```bash +# Evaluate inverse folding with/without ligand context +uv run python -m lobster.cmdline.evaluate_protein_ligand_inverse_folding \ + --checkpoint /cv/scratch/u/lisanzas/gen_ume_protein_ligand/runs//2026-01-27T16-05-28/epoch=114-step=37186-val_loss=3.1698.ckpt \ + --data_dir /cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/ \ + --structure_path ./protein_ligand_eval_inverse_folding/ \ + --output protein_ligand_inverse_folding_results.csv \ + --pocket_threshold 5.0 \ + --num_samples 100 \ + --nsteps 100 \ + --device cuda \ + --decode_structure \ + --save_gt_structure \ + --minimize_ligand +``` + +**Options:** +- `--checkpoint`: Path to protein-ligand model checkpoint (required) +- `--data_dir`: Directory with `*_protein.pt` and `*_ligand.pt` pairs (default: `/data2/lisanzas/pdb_bind_12_15_25/test/`) +- `--structure_path`: Output directory for designed sequences (FASTA files), decoded structures (PDB files), and results +- `--output`: Output CSV file for per-structure results (default: `protein_ligand_inverse_folding_results.csv`) +- `--pocket_threshold`: Distance (Å) to define binding pocket residues (default: 5.0) +- `--num_samples`: Number of samples to evaluate (-1 for all, default: 100) +- `--nsteps`: Diffusion steps for generation (default: 100) +- `--device`: Device for computation, `cuda` or `cpu` (default: `cuda`) +- `--decode_structure`: Flag to decode and save predicted protein structures as PDB files (includes decoded ligand when ligand context is used) +- `--save_gt_structure`: Flag to save ground truth protein and protein-ligand complex structures as PDB files +- `--minimize_ligand`: Flag to apply Open Babel geometry correction to decoded ligand structures +- `--minimize_mode`: Minimization mode: `bonds_only`, `bonds_and_angles` (default), `local`, or `full` +- `--force_field`: Force field for minimization: `MMFF94` (default), `UFF`, `MMFF94s`, etc. +- `--minimize_steps`: Maximum number of minimization steps (default: 500) + +**Output Files** (in `--structure_path`): +``` +protein_ligand_eval/ +├── protein_ligand_inverse_folding_results.csv # Per-structure metrics +├── {pdb_id}_sequences.fasta # Ground truth + designed sequences +├── {pdb_id}_protein.pdb # Ground truth protein structure (--save_gt_structure) +├── {pdb_id}_complex.pdb # Ground truth protein-ligand complex (--save_gt_structure) +├── {pdb_id}_decoded_no_ligand.pdb # Decoded structure without ligand context (--decode_structure) +├── {pdb_id}_decoded_with_ligand.pdb # Decoded protein-ligand complex with ligand context (--decode_structure) +└── ... +``` + +**Tracked Metrics:** +- `aar_overall_*`: Overall amino acid recovery +- `aar_pocket_*`: Pocket-only recovery (residues within threshold of ligand) +- `aar_nonpocket_*`: Non-pocket recovery +- `*_delta`: Improvement from providing ligand context + +### Protein-Ligand Forward Folding + +Predict protein structure from sequence **with ligand context**. + +**Command Line:** + +```bash +# Evaluate forward folding with/without ligand context +uv run python -m lobster.cmdline.evaluate_protein_ligand_forward_folding \ + --checkpoint /cv/scratch/u/lisanzas/gen_ume_protein_ligand/runs//2026-01-27T16-05-28/epoch=114-step=37186-val_loss=3.1698.ckpt \ + --data_dir /cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/ \ + --structure_path ./protein_ligand_eval/ \ + --output protein_ligand_forward_folding_results.csv \ + --pocket_threshold 5.0 \ + --num_samples 100 \ + --nsteps 100 \ + --device cuda \ + --save_structures \ + --save_gt_structure \ + --minimize_ligand +``` + +**Options:** +- `--checkpoint`: Path to protein-ligand model checkpoint (required) +- `--data_dir`: Directory with `*_protein.pt` and `*_ligand.pt` pairs (default: `/data2/lisanzas/pdb_bind_12_15_25/test/`) +- `--structure_path`: Output directory for predicted structures (PDB files) and results +- `--output`: Output CSV file for per-structure results (default: `protein_ligand_forward_folding_results.csv`) +- `--pocket_threshold`: Distance (Å) to define binding pocket residues (default: 5.0) +- `--num_samples`: Number of samples to evaluate (-1 for all, default: 100) +- `--nsteps`: Diffusion steps for generation (default: 100) +- `--device`: Device for computation, `cuda` or `cpu` (default: `cuda`) +- `--temperature_seq`: Temperature for sequence sampling (default: 0.5) +- `--temperature_struc`: Temperature for structure sampling (default: 0.5) +- `--save_structures`: Flag to save predicted protein structures as PDB files +- `--save_gt_structure`: Flag to save ground truth protein and protein-ligand complex structures as PDB files +- `--minimize_ligand`: Flag to apply Open Babel geometry correction to decoded ligand structures +- `--minimize_mode`: Minimization mode: `bonds_only`, `bonds_and_angles` (default), `local`, or `full` +- `--force_field`: Force field for minimization: `MMFF94` (default), `UFF`, `MMFF94s`, etc. +- `--minimize_steps`: Maximum number of minimization steps (default: 500) + +**Output Files** (in `--structure_path`): +``` +protein_ligand_eval/ +├── protein_ligand_forward_folding_results.csv # Per-structure metrics +├── {pdb_id}_gt_protein.pdb # Ground truth protein structure (--save_gt_structure) +├── {pdb_id}_gt_complex.pdb # Ground truth protein-ligand complex (--save_gt_structure) +├── {pdb_id}_pred_no_ligand.pdb # Predicted structure without ligand context (--save_structures) +├── {pdb_id}_pred_with_ligand.pdb # Predicted protein + decoded ligand structure (--save_structures) +└── ... +``` + +**Callback Configuration** (for training-time evaluation): + +```yaml +protein_ligand_forward_folding: + _target_: lobster.callbacks.ProteinLigandForwardFoldingCallback + data_dir: /cv/data/ai4dd/data2/lisanzas/pdb_bind_12_15_25/test/ + structure_path: ${paths.output_dir}/protein_ligand_eval/ + save_every_n: 1000 + num_samples: 100 + pocket_distance_threshold: 5.0 # Å + nsteps: 100 + # Ligand minimization options (for decoded ligand structures) + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only", "bonds_and_angles", "local", or "full" + force_field: "MMFF94" + minimize_steps: 500 +``` + +**Tracked Metrics:** +- `tm_score_*`: Overall TM-score (structure similarity) +- `rmsd_overall_*`: Overall backbone RMSD +- `rmsd_pocket_*`: Pocket-only RMSD (residues within threshold of ligand) +- `rmsd_nonpocket_*`: Non-pocket RMSD +- `*_delta`: Improvement from providing ligand context + +### Protein-Ligand Checkpoints + +**Latest checkpoint** (`/cv/scratch/u/lisanzas/gen_ume_protein_ligand/runs/2026-01-22T03-49-10/`): + +| Checkpoint | Validation Loss | +|------------|-----------------| +| `epoch=49-step=16126-val_loss=3.5112.ckpt` | **3.5112** (best) | +| `epoch=57-step=18718-val_loss=3.5529.ckpt` | 3.5529 | +| `epoch=47-step=15303-val_loss=3.5601.ckpt` | 3.5601 | + +### Data Format + +The protein-ligand evaluators expect paired `.pt` files: + +``` +data_dir/ +├── pdb_id_protein.pt # Contains: coords_res, sequence, mask, indices +└── pdb_id_ligand.pt # Contains: atom_coords, atom_names/element_indices, mask, bond_matrix +``` + +## Support + +For issues and questions: +- **GitHub Issues**: [prescient-design/lobster](https://github.com/prescient-design/lobster) +- **Documentation**: See `src/lobster/cmdline/generate.py` for implementation details +- **Examples**: Check `src/lobster/hydra_config/experiment/` for example configurations + +--- + +**Last Updated**: January 2026 + diff --git a/src/lobster/model/gen_ume/__init__.py b/src/lobster/model/gen_ume/__init__.py index e523ff9b..2b121b2b 100644 --- a/src/lobster/model/gen_ume/__init__.py +++ b/src/lobster/model/gen_ume/__init__.py @@ -5,9 +5,21 @@ ) from ._gen_ume_sequence_structure_encoder_lightning_module import UMESequenceStructureEncoderLightningModule +# Protein-Ligand support +from ._gen_ume_protein_ligand_encoder import ProteinLigandEncoderModule +from ._gen_ume_protein_ligand_lightning import ProteinLigandEncoderLightningModule +from ._bond_embedding import BondMatrixEmbedding +from ._bond_prediction import BondMatrixPredictionHead, BondMatrixLoss + __all__ = [ "UMESequenceStructureEncoderModule", "UMESequenceStructureEncoderLightningModule", "AuxiliaryTask", "AuxiliaryRegressionTaskHead", + # Protein-Ligand support + "ProteinLigandEncoderModule", + "ProteinLigandEncoderLightningModule", + "BondMatrixEmbedding", + "BondMatrixPredictionHead", + "BondMatrixLoss", ] diff --git a/src/lobster/model/gen_ume/_bond_embedding.py b/src/lobster/model/gen_ume/_bond_embedding.py new file mode 100644 index 00000000..3f4ae04f --- /dev/null +++ b/src/lobster/model/gen_ume/_bond_embedding.py @@ -0,0 +1,114 @@ +"""Bond matrix embedding module for Gen-UME protein-ligand modeling. + +This module embeds bond matrix information into atom features using an +encoder-agnostic design. Bond information is added to input features +rather than modifying attention, allowing any encoder to be used. +""" + +import torch +import torch.nn as nn + +from lobster.model.latent_generator.utils.residue_constants import NUM_BOND_TYPES + + +class BondMatrixEmbedding(nn.Module): + """Embed bond matrix information into atom features. + + For each atom, this module: + 1. Looks at bonded neighbors in the bond matrix + 2. Embeds the bond types (single, double, triple, aromatic) + 3. Aggregates into a single vector and adds to atom embedding + + The transformer's attention handles longer-range topology naturally + (layer 1 sees neighbors, layer 2 sees neighbors-of-neighbors, etc.) + + Parameters + ---------- + hidden_size : int + Dimension of atom embeddings. + num_bond_types : int, optional + Number of bond types (default: 6). + 0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=other. + + Examples + -------- + >>> embed = BondMatrixEmbedding(hidden_size=64) + >>> atom_embeddings = torch.randn(2, 10, 64) + >>> bond_matrix = torch.randint(0, 5, (2, 10, 10)) + >>> atom_mask = torch.ones(2, 10) + >>> enriched = embed(atom_embeddings, bond_matrix, atom_mask) + >>> enriched.shape + torch.Size([2, 10, 64]) + + Notes + ----- + Design decision: We use SUM (not MEAN) for aggregation because atom degree + is chemically informative - a terminal -CH3 with 1 bond behaves differently + from a ring carbon with 3 bonds. LayerNorm in the transformer handles + magnitude differences. + """ + + def __init__( + self, + hidden_size: int, + num_bond_types: int = NUM_BOND_TYPES, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_bond_types = num_bond_types + + # Bond type embeddings + self.bond_type_embedding = nn.Embedding(num_bond_types, hidden_size) + + # Project aggregated bond info + self.bond_proj = nn.Linear(hidden_size, hidden_size) + + # LayerNorm for stable training + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward( + self, + atom_embeddings: torch.Tensor, + bond_matrix: torch.Tensor, + atom_mask: torch.Tensor, + ) -> torch.Tensor: + """Enrich atom embeddings with direct bond information. + + Parameters + ---------- + atom_embeddings : torch.Tensor + Base atom embeddings with shape [B, N_atoms, H]. + bond_matrix : torch.Tensor + Bond type matrix with shape [B, N_atoms, N_atoms]. + Values: 0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=other. + atom_mask : torch.Tensor + Valid atom mask with shape [B, N_atoms]. + + Returns + ------- + torch.Tensor + Enriched atom embeddings with shape [B, N_atoms, H]. + """ + # Embed all bonds: [B, N, N] -> [B, N, N, H] + bond_embeds = self.bond_type_embedding(bond_matrix) + + # Mask out padding atoms: create 2D mask for atom pairs + mask_2d = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2) # [B, N, N] + bond_embeds = bond_embeds * mask_2d.unsqueeze(-1) # [B, N, N, H] + + # Sum over neighbors where bond exists (bond_type > 0) + # [B, N, N, H] -> [B, N, H] + # NOTE: We use SUM not MEAN because atom degree is chemically informative + bond_exists = (bond_matrix > 0).float().unsqueeze(-1) # [B, N, N, 1] + neighbor_bonds = (bond_embeds * bond_exists).sum(dim=2) # [B, N, H] + + # Project bond context + bond_context = self.bond_proj(neighbor_bonds) # [B, N, H] + + # Add to atom embeddings (residual connection) + enriched = atom_embeddings + bond_context + + # Normalize + enriched = self.layer_norm(enriched) + + return enriched diff --git a/src/lobster/model/gen_ume/_bond_prediction.py b/src/lobster/model/gen_ume/_bond_prediction.py new file mode 100644 index 00000000..2abc8e6d --- /dev/null +++ b/src/lobster/model/gen_ume/_bond_prediction.py @@ -0,0 +1,186 @@ +"""Bond matrix prediction head for Gen-UME protein-ligand modeling. + +This module predicts bond types between atom pairs from encoder output +features. Used for SMILES reconstruction from generated atom types. +""" + +import torch +import torch.nn as nn + +from lobster.model.latent_generator.utils.residue_constants import NUM_BOND_TYPES + + +class BondMatrixPredictionHead(nn.Module): + """Predict bond matrix from atom features. + + Given atom features from the encoder, predicts bond types between + all atom pairs. Output can be used with cross-entropy loss against + ground truth bond matrices. + + This implementation uses a memory-efficient outer product approach + rather than explicit pairwise tensor construction to reduce GPU memory. + + Parameters + ---------- + hidden_size : int + Dimension of input atom features. + num_bond_types : int, optional + Number of bond type classes (default: 6). + 0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=other. + symmetric : bool, optional + If True, enforce symmetric predictions (default: True). + Bonds are inherently symmetric (A-B = B-A). + + Examples + -------- + >>> head = BondMatrixPredictionHead(hidden_size=64) + >>> atom_features = torch.randn(2, 10, 64) + >>> logits = head(atom_features) + >>> logits.shape + torch.Size([2, 10, 10, 6]) + + Notes + ----- + Uses outer product of projected features for memory efficiency. + """ + + def __init__( + self, + hidden_size: int, + num_bond_types: int = NUM_BOND_TYPES, + symmetric: bool = True, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_bond_types = num_bond_types + self.symmetric = symmetric + + # Project atoms to smaller dimension for efficiency + proj_dim = min(hidden_size, 64) + self.proj_dim = proj_dim + + # Separate projections for source and destination atoms + self.proj_src = nn.Linear(hidden_size, proj_dim * num_bond_types) + self.proj_dst = nn.Linear(hidden_size, proj_dim * num_bond_types) + + def forward( + self, + atom_features: torch.Tensor, + atom_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Predict bond types between all atom pairs. + + Parameters + ---------- + atom_features : torch.Tensor + Atom features from encoder with shape [B, N_atoms, H]. + atom_mask : torch.Tensor, optional + Valid atom mask with shape [B, N_atoms]. + + Returns + ------- + torch.Tensor + Bond type logits with shape [B, N_atoms, N_atoms, num_bond_types]. + """ + batch_size, num_atoms, _ = atom_features.shape + + # Project to [B, N, proj_dim * num_bond_types] + src = self.proj_src(atom_features) + dst = self.proj_dst(atom_features) + + # Reshape to [B, N, num_bond_types, proj_dim] + src = src.view(batch_size, num_atoms, self.num_bond_types, self.proj_dim) + dst = dst.view(batch_size, num_atoms, self.num_bond_types, self.proj_dim) + + # Compute outer product via einsum: [B, N_i, K, D] x [B, N_j, K, D] -> [B, N_i, N_j, K] + # This is equivalent to sum over D of src[i,k,:] * dst[j,k,:] + logits = torch.einsum("biku,bjku->bijk", src, dst) + + # Enforce symmetry if requested + if self.symmetric: + # Average logits[i,j] and logits[j,i] + logits = (logits + logits.transpose(1, 2)) / 2 + + # Mask diagonal (no self-bonds) by setting to large negative for "no bond" + diag_mask = torch.eye(num_atoms, device=logits.device, dtype=torch.bool) + diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, -1, self.num_bond_types) + + # Set diagonal to favor "no bond" (index 0) + no_bond_logits = torch.zeros(self.num_bond_types, device=logits.device, dtype=logits.dtype) + no_bond_logits[0] = 10.0 # High logit for "no bond" + no_bond_logits[1:] = -10.0 # Low logit for actual bonds + + logits = torch.where(diag_mask, no_bond_logits, logits) + + return logits + + +class BondMatrixLoss(nn.Module): + """Compute loss for bond matrix prediction. + + Parameters + ---------- + ignore_diagonal : bool, optional + If True, ignore diagonal elements in loss (default: True). + class_weights : torch.Tensor, optional + Weights for each bond type class. Useful for handling class imbalance. + """ + + def __init__( + self, + ignore_diagonal: bool = True, + class_weights: torch.Tensor | None = None, + ): + super().__init__() + self.ignore_diagonal = ignore_diagonal + self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, reduction="none") + + def forward( + self, + logits: torch.Tensor, + target: torch.Tensor, + atom_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compute bond prediction loss. + + Parameters + ---------- + logits : torch.Tensor + Predicted logits with shape [B, N, N, num_bond_types]. + target : torch.Tensor + Target bond types with shape [B, N, N]. + atom_mask : torch.Tensor, optional + Valid atom mask with shape [B, N]. + + Returns + ------- + torch.Tensor + Scalar loss value. + """ + batch_size, num_atoms, _, num_classes = logits.shape + + # Flatten for cross-entropy (use reshape instead of view for non-contiguous tensors) + logits_flat = logits.reshape(-1, num_classes) + target_flat = target.reshape(-1) + + # Compute per-element loss + loss = self.ce_loss(logits_flat, target_flat) + loss = loss.reshape(batch_size, num_atoms, num_atoms) + + # Create mask + if atom_mask is not None: + # Only compute loss where both atoms are valid + pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2) + else: + pair_mask = torch.ones(batch_size, num_atoms, num_atoms, device=loss.device) + + # Ignore diagonal + if self.ignore_diagonal: + diag_mask = ~torch.eye(num_atoms, device=loss.device, dtype=torch.bool) + diag_mask = diag_mask.unsqueeze(0).expand(batch_size, -1, -1) + pair_mask = pair_mask * diag_mask.float() + + # Apply mask and compute mean + loss = (loss * pair_mask).sum() / (pair_mask.sum() + 1e-8) + + return loss diff --git a/src/lobster/model/gen_ume/_gen_ume_protein_ligand_encoder.py b/src/lobster/model/gen_ume/_gen_ume_protein_ligand_encoder.py new file mode 100644 index 00000000..b0f88d13 --- /dev/null +++ b/src/lobster/model/gen_ume/_gen_ume_protein_ligand_encoder.py @@ -0,0 +1,527 @@ +"""Gen-UME Protein-Ligand Encoder Module. + +This module extends the Gen-UME encoder to support both protein-only and +protein-ligand modeling. It maintains backward compatibility with existing +protein-only tasks while adding new capabilities for protein-ligand complexes. + +Key features: +- Unified architecture handling both protein and ligand modalities +- Bond matrix embedding for molecular topology +- Bond matrix prediction for SMILES reconstruction +- Ligand-optional forward pass (behaves like original Gen-UME when no ligand) +""" + +import logging +from dataclasses import dataclass +from typing import Literal + +import torch +import torch.nn as nn +from torch import Tensor + +from lobster.model.latent_generator.utils.residue_constants import ( + ELEMENT_VOCAB_EXTENDED, + NUM_BOND_TYPES, +) + +from ._bond_embedding import BondMatrixEmbedding +from ._bond_prediction import BondMatrixPredictionHead +from ..neobert import NeoBERTModule + +logger = logging.getLogger(__name__) + + +@dataclass +class AuxiliaryTask: + """Configuration for auxiliary tasks.""" + + name: str + output_dim: int + task_type: Literal["regression"] = "regression" + pooling: Literal["cls", "mean"] = "mean" + hidden_size: int | None = None + dropout: float = 0.1 + num_layers: int = 2 + loss_weight: float = 1.0 + + +class ProteinLigandEncoderModule(nn.Module): + """Unified encoder for protein and protein-ligand modeling. + + This encoder handles: + - Protein sequence tokens (per-residue, 33 vocabulary) + - Protein structure tokens (per-residue, 4375 vocabulary) + - Ligand atom type tokens (per-atom, 25 vocabulary) - OPTIONAL + - Ligand structure tokens (per-atom, 4375 vocabulary) - OPTIONAL + - Bond matrix embedding (NxN, embedded into features) - OPTIONAL + + When ligand inputs are None, the encoder behaves identically to the + original Gen-UME for backward compatibility. + + Parameters + ---------- + sequence_token_vocab_size : int + Size of protein sequence vocabulary (default: 33). + structure_token_vocab_size : int + Size of structure token vocabulary (default: 4375 for FSQ). + ligand_atom_vocab_size : int + Size of ligand atom type vocabulary (default: 25 from ELEMENT_VOCAB_EXTENDED). + ligand_structure_vocab_size : int + Size of ligand structure token vocabulary (default: 4375). + sequence_token_pad_token_id : int + Padding token ID for sequences (default: 1). + structure_token_pad_token_id : int + Padding token ID for structures (default: 4374). + ligand_atom_pad_token_id : int + Padding token ID for ligand atoms (default: 0 for PAD). + conditioning_input_dim : int + Dimension of conditioning input (default: 1). + **neobert_kwargs + Additional arguments for NeoBERTModule. + + Examples + -------- + >>> encoder = ProteinLigandEncoderModule(hidden_size=256, num_hidden_layers=6) + >>> # Protein-only forward (backward compatible) + >>> output = encoder(seq_ids, struct_ids, mask, cond) + >>> # Protein-ligand forward + >>> output = encoder(seq_ids, struct_ids, mask, cond, + ... ligand_atom_input_ids=atom_ids, + ... ligand_structure_input_ids=lig_struct_ids, + ... ligand_mask=lig_mask, bond_matrix=bonds) + """ + + def __init__( + self, + auxiliary_tasks: list[AuxiliaryTask] | None = None, + model_ckpt: str | None = None, + cache_dir: str | None = None, + sequence_token_vocab_size: int = 33, + structure_token_vocab_size: int = 4375, + ligand_atom_vocab_size: int = len(ELEMENT_VOCAB_EXTENDED), + ligand_structure_vocab_size: int = 4375, + sequence_token_pad_token_id: int = 1, + structure_token_pad_token_id: int = 4374, + ligand_atom_pad_token_id: int = 0, # PAD in ELEMENT_VOCAB_EXTENDED + ligand_structure_pad_token_id: int = 4374, + conditioning_input_dim: int = 1, + num_bond_types: int = NUM_BOND_TYPES, + **neobert_kwargs, + ) -> None: + super().__init__() + + # Store config + self.sequence_token_vocab_size = sequence_token_vocab_size + self.structure_token_vocab_size = structure_token_vocab_size + self.ligand_atom_vocab_size = ligand_atom_vocab_size + self.ligand_structure_vocab_size = ligand_structure_vocab_size + self.num_bond_types = num_bond_types + + # Initialize NeoBERT backbone + self.neobert = NeoBERTModule(**neobert_kwargs) + hidden_size = self.neobert.config.hidden_size + self.hidden_size = hidden_size # Store for continuous projection initialization + + # === PROTEIN EMBEDDINGS (same as original) === + self.sequence_embedding = nn.Embedding( + sequence_token_vocab_size, + hidden_size, + padding_idx=sequence_token_pad_token_id, + ) + self.structure_embedding = nn.Embedding( + structure_token_vocab_size, + hidden_size, + padding_idx=structure_token_pad_token_id, + ) + self.conditioning_embedding = nn.Linear(conditioning_input_dim, hidden_size, bias=False) + self.combine_embedding = nn.Linear(hidden_size * 3, hidden_size) + + # === LIGAND EMBEDDINGS (new) === + self.ligand_atom_embedding = nn.Embedding( + ligand_atom_vocab_size, + hidden_size, + padding_idx=ligand_atom_pad_token_id, + ) + self.ligand_structure_embedding = nn.Embedding( + ligand_structure_vocab_size, + hidden_size, + padding_idx=ligand_structure_pad_token_id, + ) + self.ligand_combine_embedding = nn.Linear(hidden_size * 2, hidden_size) + + # === CONTINUOUS STRUCTURE EMBEDDING PROJECTION (for diffusion-based generation) === + # When using continuous structure embeddings (from DiffusionLoss), project to hidden_size + # This is set dynamically if continuous_structure_dim is provided + self.continuous_structure_proj: nn.Linear | None = None + self.continuous_ligand_structure_proj: nn.Linear | None = None + + # === BOND MATRIX MODULES (new) === + self.bond_embedding = BondMatrixEmbedding( + hidden_size=hidden_size, + num_bond_types=num_bond_types, + ) + self.bond_prediction_head = BondMatrixPredictionHead( + hidden_size=hidden_size, + num_bond_types=num_bond_types, + symmetric=True, + ) + + # === MODALITY EMBEDDING (new) === + # 0 = protein, 1 = ligand + self.modality_embedding = nn.Embedding(2, hidden_size) + + # === OUTPUT HEADS === + # Protein outputs (same as original) + self.sequence_output = nn.Linear(hidden_size, sequence_token_vocab_size) + self.structure_output = nn.Linear(hidden_size, structure_token_vocab_size) + + # Ligand outputs (new) + self.ligand_atom_output = nn.Linear(hidden_size, ligand_atom_vocab_size) + self.ligand_structure_output = nn.Linear(hidden_size, ligand_structure_vocab_size) + + # Handle auxiliary tasks + self.auxiliary_tasks = None + if auxiliary_tasks is not None: + from ..ume2 import AuxiliaryRegressionTaskHead + + self.auxiliary_tasks = nn.ModuleDict( + { + task.name: AuxiliaryRegressionTaskHead( + input_dim=hidden_size, + output_dim=task.output_dim, + task_name=task.name, + hidden_size=task.hidden_size, + dropout=task.dropout, + num_layers=task.num_layers, + pooling=task.pooling, + ) + for task in auxiliary_tasks + } + ) + + def init_continuous_structure_proj(self, continuous_dim: int) -> None: + """Initialize projection layers for continuous structure embeddings. + + This is called when using DiffusionLoss for structure tokens. + Projects continuous embeddings to hidden_size for input to transformer. + + Parameters + ---------- + continuous_dim : int + Dimension of continuous structure embeddings (e.g., 256). + """ + self.continuous_structure_proj = nn.Linear(continuous_dim, self.hidden_size) + self.continuous_ligand_structure_proj = nn.Linear(continuous_dim, self.hidden_size) + + def _embed_protein( + self, + sequence_input_ids: Tensor, + structure_input_ids: Tensor | None, + conditioning_tensor: Tensor | None, + structure_embeddings: Tensor | None = None, + ) -> Tensor: + """Embed protein sequence and structure tokens. + + Parameters + ---------- + sequence_input_ids : Tensor + Sequence token IDs with shape [B, N_res]. + structure_input_ids : Tensor, optional + Structure token IDs with shape [B, N_res]. Can be None if + structure_embeddings is provided. + conditioning_tensor : Tensor, optional + Conditioning input with shape [B, N_res, C]. + structure_embeddings : Tensor, optional + Pre-computed continuous structure embeddings with shape [B, N_res, D]. + For MAR-style generation: zeros at uncommitted positions, sampled + embeddings at committed positions. The mask embedding from + structure_input_ids is used for uncommitted positions. + + Returns + ------- + Tensor + Combined protein embedding with shape [B, N_res, H]. + """ + batch_size, seq_len = sequence_input_ids.shape + device = sequence_input_ids.device + + seq_emb = self.sequence_embedding(sequence_input_ids) + + # Handle structure embeddings + if structure_embeddings is not None and self.continuous_structure_proj is not None: + # MAR-style: blend continuous embeddings with mask embeddings + # structure_embeddings has zeros for uncommitted positions + # We need mask embeddings from structure_input_ids for those positions + + # Get discrete embedding (mask tokens for uncommitted positions) + discrete_struct_emb = self.structure_embedding(structure_input_ids) + + # Project continuous embeddings to hidden_size + continuous_struct_emb = self.continuous_structure_proj(structure_embeddings) + + # Create blend mask: 1 where continuous embeddings are non-zero (committed) + # This is approximate - assumes committed positions have non-zero embeddings + committed_mask = (structure_embeddings.abs().sum(dim=-1, keepdim=True) > 1e-6).float() + + # Blend: discrete for uncommitted, continuous for committed + struct_emb = discrete_struct_emb * (1 - committed_mask) + continuous_struct_emb * committed_mask + elif structure_input_ids is not None: + struct_emb = self.structure_embedding(structure_input_ids) + else: + # Fallback: if only continuous embeddings provided without projection + struct_emb = structure_embeddings + + if conditioning_tensor is None: + conditioning_tensor = torch.zeros(batch_size, seq_len, 1, device=device) + cond_emb = self.conditioning_embedding(conditioning_tensor) + + # Combine: [seq; struct; cond] -> H + combined = torch.cat([seq_emb, struct_emb, cond_emb], dim=-1) + protein_emb = self.combine_embedding(combined) + + # Add modality embedding (protein = 0) + modality_ids = torch.zeros(batch_size, seq_len, dtype=torch.long, device=device) + protein_emb = protein_emb + self.modality_embedding(modality_ids) + + return protein_emb + + def _embed_ligand( + self, + ligand_atom_input_ids: Tensor, + ligand_structure_input_ids: Tensor | None, + bond_matrix: Tensor, + ligand_mask: Tensor, + ligand_structure_embeddings: Tensor | None = None, + ) -> Tensor: + """Embed ligand atom types and structure tokens with bond information. + + Parameters + ---------- + ligand_atom_input_ids : Tensor + Atom type token IDs with shape [B, N_atoms]. + ligand_structure_input_ids : Tensor, optional + Ligand structure token IDs with shape [B, N_atoms]. Can be None if + ligand_structure_embeddings is provided. + bond_matrix : Tensor + Bond type matrix with shape [B, N_atoms, N_atoms]. + ligand_mask : Tensor + Valid atom mask with shape [B, N_atoms]. + ligand_structure_embeddings : Tensor, optional + Pre-computed continuous structure embeddings with shape [B, N_atoms, D]. + For MAR-style generation: zeros at uncommitted positions, sampled + embeddings at committed positions. + + Returns + ------- + Tensor + Combined ligand embedding with shape [B, N_atoms, H]. + """ + batch_size, num_atoms = ligand_atom_input_ids.shape + device = ligand_atom_input_ids.device + + # Embed atom types + atom_emb = self.ligand_atom_embedding(ligand_atom_input_ids) + + # Handle structure embeddings + if ligand_structure_embeddings is not None and self.continuous_ligand_structure_proj is not None: + # MAR-style: blend continuous embeddings with mask embeddings + discrete_struct_emb = self.ligand_structure_embedding(ligand_structure_input_ids) + continuous_struct_emb = self.continuous_ligand_structure_proj(ligand_structure_embeddings) + + # Create blend mask: 1 where continuous embeddings are non-zero (committed) + committed_mask = (ligand_structure_embeddings.abs().sum(dim=-1, keepdim=True) > 1e-6).float() + + # Blend: discrete for uncommitted, continuous for committed + struct_emb = discrete_struct_emb * (1 - committed_mask) + continuous_struct_emb * committed_mask + elif ligand_structure_input_ids is not None: + struct_emb = self.ligand_structure_embedding(ligand_structure_input_ids) + else: + struct_emb = ligand_structure_embeddings + + # Combine: [atom; struct] -> H + combined = torch.cat([atom_emb, struct_emb], dim=-1) + ligand_emb = self.ligand_combine_embedding(combined) + + # Add bond information + ligand_emb = self.bond_embedding(ligand_emb, bond_matrix, ligand_mask) + + # Add modality embedding (ligand = 1) + modality_ids = torch.ones(batch_size, num_atoms, dtype=torch.long, device=device) + ligand_emb = ligand_emb + self.modality_embedding(modality_ids) + + return ligand_emb + + def forward( + self, + sequence_input_ids: Tensor, + structure_input_ids: Tensor | None, + attention_mask: Tensor, + conditioning_tensor: Tensor | None = None, + position_ids: Tensor | None = None, + # Ligand inputs (optional) + ligand_atom_input_ids: Tensor | None = None, + ligand_structure_input_ids: Tensor | None = None, + ligand_mask: Tensor | None = None, + bond_matrix: Tensor | None = None, + # Validity masks for mixed batches + protein_valid_mask: Tensor | None = None, + ligand_valid_mask: Tensor | None = None, + # Other + timesteps: Tensor | None = None, + return_auxiliary_tasks: bool = False, + # Continuous structure embeddings (for MAR-style generation) + structure_embeddings: Tensor | None = None, + ligand_structure_embeddings: Tensor | None = None, + **kwargs, + ) -> dict[str, Tensor]: + """Forward pass supporting both protein-only and protein-ligand inputs. + + Parameters + ---------- + sequence_input_ids : Tensor + Protein sequence token IDs with shape [B, N_res]. + structure_input_ids : Tensor, optional + Protein structure token IDs with shape [B, N_res]. Can be None if + structure_embeddings is provided. + attention_mask : Tensor + Attention mask for protein with shape [B, N_res]. + conditioning_tensor : Tensor, optional + Conditioning input with shape [B, N_res, C]. + position_ids : Tensor, optional + Position IDs (not currently used). + ligand_atom_input_ids : Tensor, optional + Ligand atom type IDs with shape [B, N_atoms]. + ligand_structure_input_ids : Tensor, optional + Ligand structure token IDs with shape [B, N_atoms]. + ligand_mask : Tensor, optional + Valid atom mask with shape [B, N_atoms]. + bond_matrix : Tensor, optional + Bond type matrix with shape [B, N_atoms, N_atoms]. + protein_valid_mask : Tensor, optional + Which samples have valid protein with shape [B]. + ligand_valid_mask : Tensor, optional + Which samples have valid ligand with shape [B]. + timesteps : Tensor, optional + Timesteps for flow matching. + return_auxiliary_tasks : bool + Whether to return auxiliary task outputs. + structure_embeddings : Tensor, optional + Pre-computed continuous protein structure embeddings [B, N_res, D]. + If provided, bypasses the structure_embedding layer. Used for + MAR-style iterative generation. + ligand_structure_embeddings : Tensor, optional + Pre-computed continuous ligand structure embeddings [B, N_atoms, D]. + + Returns + ------- + dict[str, Tensor] + Output dictionary containing: + - sequence_logits: [B, N_res, vocab_size] + - structure_logits: [B, N_res, struct_vocab_size] + - last_hidden_state: [B, N_total, H] + - ligand_atom_logits: [B, N_atoms, atom_vocab_size] (if ligand present) + - ligand_structure_logits: [B, N_atoms, struct_vocab_size] (if ligand present) + - bond_logits: [B, N_atoms, N_atoms, num_bond_types] (if ligand present) + """ + batch_size = sequence_input_ids.shape[0] + device = sequence_input_ids.device + seq_len = sequence_input_ids.shape[1] + + # Check if we have ligand inputs + has_ligand = ( + ligand_atom_input_ids is not None + and ligand_atom_input_ids.numel() > 0 + and ligand_atom_input_ids.shape[1] > 0 + ) + + # === EMBED PROTEIN === + if seq_len > 0: + protein_emb = self._embed_protein( + sequence_input_ids, + structure_input_ids, + conditioning_tensor, + structure_embeddings=structure_embeddings, + ) + else: + protein_emb = torch.empty(batch_size, 0, self.neobert.config.hidden_size, device=device) + + # === EMBED LIGAND (if present) === + if has_ligand: + num_atoms = ligand_atom_input_ids.shape[1] + if ligand_mask is None: + ligand_mask = torch.ones(batch_size, num_atoms, device=device) + if bond_matrix is None: + bond_matrix = torch.zeros(batch_size, num_atoms, num_atoms, dtype=torch.long, device=device) + + ligand_emb = self._embed_ligand( + ligand_atom_input_ids, + ligand_structure_input_ids, + bond_matrix, + ligand_mask, + ligand_structure_embeddings=ligand_structure_embeddings, + ) + else: + ligand_emb = torch.empty(batch_size, 0, self.neobert.config.hidden_size, device=device) + num_atoms = 0 + + # === CONCATENATE AND CREATE COMBINED MASK === + # [protein_emb; ligand_emb] -> [B, N_res + N_atoms, H] + combined_emb = torch.cat([protein_emb, ligand_emb], dim=1) + + # Combined attention mask + if has_ligand: + combined_mask = torch.cat([attention_mask, ligand_mask], dim=1) + else: + combined_mask = attention_mask + + # === RUN THROUGH TRANSFORMER === + # Position IDs are not used (NeoBERT handles internally) + neobert_output = self.neobert( + input_ids=None, + inputs_embeds=combined_emb, + position_ids=None, + attention_mask=combined_mask, + **kwargs, + ) + + hidden_state = neobert_output["last_hidden_state"] + + # === SPLIT HIDDEN STATE BACK === + protein_hidden = hidden_state[:, :seq_len, :] + ligand_hidden = hidden_state[:, seq_len:, :] if has_ligand else None + + # === COMPUTE PROTEIN OUTPUTS === + output = {} + if seq_len > 0: + output["sequence_logits"] = self.sequence_output(protein_hidden) + output["structure_logits"] = self.structure_output(protein_hidden) + else: + output["sequence_logits"] = torch.empty(batch_size, 0, self.sequence_token_vocab_size, device=device) + output["structure_logits"] = torch.empty(batch_size, 0, self.structure_token_vocab_size, device=device) + + output["last_hidden_state"] = hidden_state + # Expose protein and ligand hidden states for DiffusionLoss conditioning + output["protein_hidden_states"] = protein_hidden + + # === COMPUTE LIGAND OUTPUTS (if present) === + if has_ligand and ligand_hidden is not None: + output["ligand_atom_logits"] = self.ligand_atom_output(ligand_hidden) + output["ligand_structure_logits"] = self.ligand_structure_output(ligand_hidden) + output["bond_logits"] = self.bond_prediction_head(ligand_hidden, ligand_mask) + # Expose ligand hidden states for DiffusionLoss conditioning + output["ligand_hidden_states"] = ligand_hidden + else: + # Empty tensors for consistency + output["ligand_atom_logits"] = torch.empty(batch_size, 0, self.ligand_atom_vocab_size, device=device) + output["ligand_structure_logits"] = torch.empty( + batch_size, 0, self.ligand_structure_vocab_size, device=device + ) + output["bond_logits"] = torch.empty(batch_size, 0, 0, self.num_bond_types, device=device) + output["ligand_hidden_states"] = None + + # === AUXILIARY TASKS === + if self.auxiliary_tasks is not None and return_auxiliary_tasks: + for task_name, task_head in self.auxiliary_tasks.items(): + output[task_name] = task_head(hidden_state) + + return output diff --git a/src/lobster/model/gen_ume/_gen_ume_protein_ligand_lightning.py b/src/lobster/model/gen_ume/_gen_ume_protein_ligand_lightning.py new file mode 100644 index 00000000..7506d90c --- /dev/null +++ b/src/lobster/model/gen_ume/_gen_ume_protein_ligand_lightning.py @@ -0,0 +1,1856 @@ +"""Gen-UME Protein-Ligand Lightning Module. + +This module provides the PyTorch Lightning training wrapper for the +Gen-UME protein-ligand encoder. It handles: +- Multi-modal loss computation (protein_seq, protein_struct, ligand_atom, ligand_struct, bond) +- Mixed batch handling with validity masks +- Independent time sampling for different modalities +- Structure encoding/decoding via LatentGenerator + +Key features: +- Backward compatible with protein-only training +- Supports mixed protein-only + protein-ligand batches +- Flow matching for all modalities +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from lobster.model.losses import DiffusionLoss + +import torch +import torch.nn as nn +import transformers +from lightning import LightningModule +from torch import Tensor +from tqdm import tqdm + +from lobster.model.latent_generator.cmdline import LatentEncoderDecoder +from lobster.model.latent_generator.cmdline import methods as latent_generator_methods +from lobster.model.latent_generator.utils import apply_se3_augmentation_protein_ligand +from lobster.model.latent_generator.utils.residue_constants import NUM_BOND_TYPES + +from ._bond_prediction import BondMatrixLoss +from ._gen_ume_protein_ligand_encoder import AuxiliaryTask, ProteinLigandEncoderModule +from lobster.model.neobert._config import NEOBERT_CONFIGS + +# Bionemo interpolant code +from bionemo.moco.distributions.prior import DiscreteMaskedPrior, DiscreteUniformPrior +from bionemo.moco.distributions.time import UniformTimeDistribution +from bionemo.moco.interpolants import DiscreteFlowMatcher +from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule, LogInferenceSchedule + +logger = logging.getLogger(__name__) + + +class ProteinLigandEncoderLightningModule(LightningModule): + """PyTorch Lightning module for Gen-UME protein-ligand training. + + This module extends the original Gen-UME to support both protein-only + and protein-ligand tasks. When no ligand data is present in a batch, + it behaves identically to the original Gen-UME. + + Parameters + ---------- + mask_token_id : int + Mask token ID for protein sequences. + pad_token_id : int + Padding token ID for protein sequences. + vocab_size : int + Vocabulary size for protein sequences. + ligand_atom_vocab_size : int + Vocabulary size for ligand atom types (default: 25). + ligand_mask_token_id : int + Mask token ID for ligand atoms (default: 1 for MASK). + ligand_pad_token_id : int + Padding token ID for ligand atoms (default: 0 for PAD). + auxiliary_tasks : list[AuxiliaryTask], optional + List of auxiliary task configurations. + seed : int + Random seed. + lr : float + Learning rate. + encoder_kwargs : dict, optional + Additional kwargs for the encoder. + latent_generator_model_name : str + Name of the LatentGenerator model for structure encoding/decoding. + use_masked_prior : bool + Whether to use masked prior for flow matching. + inverse_folding : bool + Whether to train for inverse folding (structure -> sequence). + bond_loss_weight : float + Weight for bond prediction loss. + """ + + def __init__( + self, + mask_token_id: int, + pad_token_id: int, + vocab_size: int, + ligand_atom_vocab_size: int = 25, + ligand_mask_token_id: int = 1, + ligand_pad_token_id: int = 0, + num_bond_types: int = NUM_BOND_TYPES, + auxiliary_tasks: list[AuxiliaryTask] | None = None, + seed: int = 0, + lr: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.98, + eps: float = 1e-12, + num_warmup_steps: int = 20_000, + num_training_steps: int = 100_000, + weight_decay: float = 0.0, + scheduler: str = "constant", + scheduler_kwargs: dict | None = None, + encoder_kwargs: dict | None = None, + ckpt_path: str | None = None, + # LatentGenerator params + decode_tokens_during_training: bool = True, + latent_generator_model_name: str = "LG Protein Ligand fsq 4375", + # Generation params + prior_distribution_seq: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, + prior_distribution_struc: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, + prior_distribution_ligand_atom: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, + prior_distribution_ligand_struc: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, + time_distribution_seq: Callable[..., UniformTimeDistribution] = UniformTimeDistribution, + time_distribution_struc: Callable[..., UniformTimeDistribution] = UniformTimeDistribution, + time_distribution_ligand: Callable[..., UniformTimeDistribution] = UniformTimeDistribution, + interpolant: Callable[..., DiscreteFlowMatcher] = DiscreteFlowMatcher, + inference_schedule: Callable[..., LinearInferenceSchedule] = LinearInferenceSchedule, + use_masked_prior: bool = True, + inverse_folding: bool = False, + # Loss weights + bond_loss_weight: float = 1.0, + ligand_atom_loss_weight: float = 1.0, + ligand_struct_loss_weight: float = 1.0, + # Diffusion loss params for continuous structure tokens + use_diffusion_loss_structure: bool = False, + diffusion_target_dim: int = 256, # Structure embedding dim from LatentGenerator + diffusion_z_dim: int | None = None, # Transformer hidden dim (auto-detect if None) + diffusion_depth: int = 3, # MLP depth (from MAR: diffloss_d) + diffusion_width: int = 1024, # MLP width (from MAR: diffloss_w) + diffusion_num_sampling_steps: str = "100", + diffusion_noise_schedule: Literal["linear", "cosine"] = "cosine", + diffusion_loss_weight: float = 1.0, + # SE(3) augmentation + use_se3_augmentation: bool = True, + se3_translation_scale: float = 1.0, + ): + self.save_hyperparameters() + super().__init__() + + # Store config + self.mask_token_id = mask_token_id + self.pad_token_id = pad_token_id + self.vocab_size = vocab_size + self.ligand_atom_vocab_size = ligand_atom_vocab_size + self.ligand_mask_token_id = ligand_mask_token_id + self.ligand_pad_token_id = ligand_pad_token_id + self.num_bond_types = num_bond_types + + self.lr = lr + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.weight_decay = weight_decay + self.scheduler = scheduler + self.scheduler_kwargs = scheduler_kwargs or {} + self.seed = seed + + # Loss weights + self.bond_loss_weight = bond_loss_weight + self.ligand_atom_loss_weight = ligand_atom_loss_weight + self.ligand_struct_loss_weight = ligand_struct_loss_weight + + # SE(3) augmentation config + self.use_se3_augmentation = use_se3_augmentation + self.se3_translation_scale = se3_translation_scale + + # Auxiliary tasks + self.auxiliary_tasks = auxiliary_tasks + self.auxiliary_task_loss_fns = {"regression": nn.MSELoss()} + + # LatentGenerator for structure encoding/decoding + self.decode_tokens_during_training = decode_tokens_during_training + self.structure_latent_encoder_decoder = LatentEncoderDecoder() + + # Load from registered model name + logger.info(f"Loading LatentGenerator model: {latent_generator_model_name}") + self.structure_latent_encoder_decoder.load_model( + latent_generator_methods[latent_generator_model_name].model_config.checkpoint, + latent_generator_methods[latent_generator_model_name].model_config.config_path, + latent_generator_methods[latent_generator_model_name].model_config.config_name, + overrides=latent_generator_methods[latent_generator_model_name].model_config.overrides, + ) + self.quantizer = self.structure_latent_encoder_decoder.model.quantizer + self.structure_encoder = self.structure_latent_encoder_decoder.model.encoder + self.decoder_factory = self.structure_latent_encoder_decoder.model.decoder_factory + self.loss_factory = self.structure_latent_encoder_decoder.model.loss_factory + + # === CONTINUOUS/DISCRETE MODE SETUP === + # Check if using continuous mode (quantizer is None) + self.use_continuous_structure = self.quantizer is None + if self.use_continuous_structure: + # For continuous mode, we need a way to handle masking + # Use a simple vocab: 0 = visible (placeholder), 1 = mask, 2 = pad + # The actual embedding values come from continuous embeddings, not discrete tokens + self.structure_embed_dim = self.structure_encoder.embed_dim + self.num_struc_classes = 3 # placeholder, mask, pad + self.mask_index_struc_tokens = 1 + self.padding_index_struc_tokens = 2 + logger.info(f"Using CONTINUOUS structure embeddings (dim={self.structure_embed_dim})") + else: + self.structure_embed_dim = None + self.num_struc_classes = self.quantizer.n_tokens + 2 + self.mask_index_struc_tokens = self.quantizer.n_tokens + self.padding_index_struc_tokens = self.quantizer.n_tokens + 1 + logger.info(f"Using DISCRETE structure tokens (vocab={self.quantizer.n_tokens})") + + # === FLOW MATCHING SETUP === + self.inverse_folding = inverse_folding + + # Set up priors and interpolants + device = "cpu" # Will be moved to correct device during training + + # Protein sequence interpolant + if use_masked_prior: + prior_seq = DiscreteMaskedPrior(num_classes=self.vocab_size, mask_dim=self.mask_token_id, inclusive=True) + prior_struc = DiscreteMaskedPrior( + num_classes=self.num_struc_classes, mask_dim=self.mask_index_struc_tokens, inclusive=True + ) + prior_ligand_atom = DiscreteMaskedPrior( + num_classes=self.ligand_atom_vocab_size, mask_dim=self.ligand_mask_token_id, inclusive=True + ) + prior_ligand_struc = DiscreteMaskedPrior( + num_classes=self.num_struc_classes, mask_dim=self.mask_index_struc_tokens, inclusive=True + ) + else: + prior_seq = prior_distribution_seq(num_classes=self.vocab_size) + prior_struc = prior_distribution_struc(num_classes=self.num_struc_classes) + prior_ligand_atom = prior_distribution_ligand_atom(num_classes=self.ligand_atom_vocab_size) + prior_ligand_struc = prior_distribution_ligand_struc(num_classes=self.num_struc_classes) + + time_dist_seq = time_distribution_seq() + time_dist_struc = time_distribution_struc() + time_dist_ligand = time_distribution_ligand() + + # Create interpolants + self.interpolant_seq = interpolant(time_distribution=time_dist_seq, prior_distribution=prior_seq, device=device) + self.interpolant_struc = interpolant( + time_distribution=time_dist_struc, prior_distribution=prior_struc, device=device + ) + self.interpolant_ligand_atom = interpolant( + time_distribution=time_dist_ligand, prior_distribution=prior_ligand_atom, device=device + ) + self.interpolant_ligand_struc = interpolant( + time_distribution=time_dist_ligand, prior_distribution=prior_ligand_struc, device=device + ) + + self.inference_schedule = inference_schedule(nsteps=1000) + + logger.info(f"Initialized protein-ligand flow matching with masked_prior={use_masked_prior}") + + # === ENCODER === + self.encoder = ProteinLigandEncoderModule( + auxiliary_tasks=auxiliary_tasks, + sequence_token_vocab_size=self.vocab_size, + structure_token_vocab_size=self.num_struc_classes, + ligand_atom_vocab_size=self.ligand_atom_vocab_size, + ligand_structure_vocab_size=self.num_struc_classes, + sequence_token_pad_token_id=self.pad_token_id, + structure_token_pad_token_id=self.padding_index_struc_tokens, + ligand_atom_pad_token_id=self.ligand_pad_token_id, + ligand_structure_pad_token_id=self.padding_index_struc_tokens, + num_bond_types=self.num_bond_types, + model_ckpt=ckpt_path, + **encoder_kwargs or {}, + ) + + # === BOND LOSS === + self.bond_loss_fn = BondMatrixLoss(ignore_diagonal=True) + + # === DIFFUSION LOSS FOR STRUCTURE (Option A: Hybrid) === + self.use_diffusion_loss_structure = use_diffusion_loss_structure + self.diffusion_loss_weight = diffusion_loss_weight + self.diffusion_loss_protein_struc: DiffusionLoss | None = None + self.diffusion_loss_ligand_struc: DiffusionLoss | None = None + + if use_diffusion_loss_structure: + from lobster.model.losses import DiffusionLoss + + # Auto-detect transformer hidden dim if not provided + if diffusion_z_dim is not None: + z_dim = diffusion_z_dim + elif encoder_kwargs.get("model_size") in NEOBERT_CONFIGS: + z_dim = NEOBERT_CONFIGS[encoder_kwargs["model_size"]]["hidden_size"] + else: + z_dim = 768 # Default fallback + + # Initialize DiffusionLoss for protein structure + self.diffusion_loss_protein_struc = DiffusionLoss( + target_channels=diffusion_target_dim, + z_channels=z_dim, + depth=diffusion_depth, + width=diffusion_width, + num_sampling_steps=diffusion_num_sampling_steps, + noise_schedule=diffusion_noise_schedule, + ) + + # Initialize DiffusionLoss for ligand structure + self.diffusion_loss_ligand_struc = DiffusionLoss( + target_channels=diffusion_target_dim, + z_channels=z_dim, + depth=diffusion_depth, + width=diffusion_width, + num_sampling_steps=diffusion_num_sampling_steps, + noise_schedule=diffusion_noise_schedule, + ) + + # Initialize continuous structure projection in encoder + # This allows feeding back sampled embeddings during MAR-style generation + self.encoder.init_continuous_structure_proj(diffusion_target_dim) + + logger.info( + f"Using DiffusionLoss for structure tokens " + f"(target_dim={diffusion_target_dim}, z_dim={z_dim}, " + f"depth={diffusion_depth}, width={diffusion_width})" + ) + + def encode_structure( + self, x_gt: Tensor, mask: Tensor, residue_index: Tensor, return_continuous: bool = False + ) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor, Tensor, Tensor]: + """Encode protein structure to tokens using LatentGenerator. + + Handles both FSQLigandTokenizer (returns dicts), standard quantizers, + and continuous mode (quantizer=None). + + Parameters + ---------- + x_gt : Tensor + Ground truth coordinates [B, L, 4*3]. + mask : Tensor + Residue mask [B, L]. + residue_index : Tensor + Residue indices [B, L]. + return_continuous : bool + If True, also return continuous embeddings for DiffusionLoss. + + Returns + ------- + x_quant : Tensor + Soft token distribution with shape [B, L, n_tokens] for FSQLigandTokenizer, + quantized logits for standard quantizers, or dummy tokens for continuous mode. + x_quant_emb : Tensor + Embeddings from encoder. + out_mask : Tensor + Mask for valid positions. + x_continuous : Tensor (only if return_continuous=True) + Raw continuous embeddings [B, L, embed_dim]. + """ + x_emb = self.structure_encoder(x_gt, mask, residue_index=residue_index) + + if self.use_continuous_structure: + # Continuous mode: no quantization + # Return dummy discrete tokens (all zeros, will be masked anyway) + # The continuous embeddings are what we actually use + B, L, D = x_emb.shape + x_quant = torch.zeros(B, L, self.num_struc_classes, device=x_emb.device) + x_quant[:, :, 0] = 1.0 # All "visible" placeholder token + x_quant_emb = x_emb + out_mask = mask + + if return_continuous: + return x_quant, x_quant_emb, out_mask, x_emb + return x_quant, x_quant_emb, out_mask + + # Handle FSQLigandTokenizer which returns dicts + result = self.quantizer.quantize(x_emb, mask=mask) + if isinstance(result[0], dict): + # FSQLigandTokenizer returns (tokens_dict, logits_dict, masks_dict) + # tokens_dict contains soft token distributions [B, L, n_tokens] + # logits_dict contains raw FSQ logits [B, L, 5] (FSQ levels) + tokens_dict, logits_dict, masks_dict = result + # Use tokens (vocabulary size) not logits (FSQ levels) for proper token extraction + x_quant = tokens_dict.get("protein_tokens", tokens_dict.get("ligand_tokens")) + x_quant_emb = x_emb # Use embeddings as-is + out_mask = masks_dict.get("protein_mask", masks_dict.get("ligand_mask")) + else: + # Standard quantizer returns (x_quant, x_quant_emb, mask) + x_quant, x_quant_emb, out_mask = result + + if return_continuous: + return x_quant, x_quant_emb, out_mask, x_emb + return x_quant, x_quant_emb, out_mask + + def encode_ligand_structure( + self, ligand_coords: Tensor, ligand_mask: Tensor, atom_indices: Tensor, return_continuous: bool = False + ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]: + """Encode ligand structure to tokens using LatentGenerator. + + Note: The ViT encoder accepts ligand_coords as a separate argument with shape [B, N_atoms, 3]. + + Parameters + ---------- + ligand_coords : Tensor + Ligand coordinates [B, N_atoms, 3]. + ligand_mask : Tensor + Ligand atom mask [B, N_atoms]. + atom_indices : Tensor + Atom indices [B, N_atoms]. + return_continuous : bool + If True, also return continuous embeddings for DiffusionLoss. + + Returns + ------- + x_quant_argmax : Tensor + Structure token indices with shape [B, N_atoms]. + out_mask : Tensor + Mask for valid positions. + x_continuous : Tensor (only if return_continuous=True) + Raw continuous embeddings [B, N_atoms, embed_dim]. + """ + # Ligand coords: [B, N_atoms, 3] - passed directly to encoder + # The encoder handles ligands separately via ligand_coords argument + x_emb = self.structure_encoder( + coords=None, # No protein coords + seq_mask=torch.zeros(ligand_coords.shape[0], 1, device=ligand_coords.device), # Dummy protein mask + ligand_coords=ligand_coords, + ligand_mask=ligand_mask, + ligand_residue_index=atom_indices, + ) + + if self.use_continuous_structure: + # Continuous mode: no quantization + # Return dummy discrete tokens (all masked) + B, N = ligand_coords.shape[:2] + x_quant_argmax = torch.full((B, N), self.mask_index_struc_tokens, device=ligand_coords.device) + x_quant_argmax[~ligand_mask.bool()] = self.padding_index_struc_tokens + out_mask = ligand_mask + + if return_continuous: + return x_quant_argmax, out_mask, x_emb + return x_quant_argmax, out_mask + + # Handle FSQLigandTokenizer which returns dicts + result = self.quantizer.quantize(x_emb, mask=None, ligand_mask=ligand_mask) + if isinstance(result[0], dict): + # FSQLigandTokenizer returns (tokens_dict, logits_dict, masks_dict) + # tokens_dict contains soft token distributions [B, L, n_tokens] + # logits_dict contains raw FSQ logits [B, L, 5] (FSQ levels) + tokens_dict, logits_dict, masks_dict = result + # Use tokens (vocabulary size) not logits (FSQ levels) for proper token extraction + x_quant = tokens_dict.get("ligand_tokens") + out_mask = masks_dict.get("ligand_mask") + else: + # Standard quantizer + x_quant, _, out_mask = result + + x_quant_argmax = torch.argmax(x_quant, dim=-1) + x_quant_argmax[~out_mask.bool()] = self.padding_index_struc_tokens + + if return_continuous: + return x_quant_argmax, out_mask, x_emb + return x_quant_argmax, out_mask + + def encode_protein_ligand_structure( + self, + protein_coords: Tensor, + protein_mask: Tensor, + protein_indices: Tensor, + ligand_coords: Tensor, + ligand_mask: Tensor, + ligand_indices: Tensor, + ligand_atom_types: Tensor | None = None, + bond_matrix: Tensor | None = None, + ) -> dict[str, Tensor]: + """Encode protein and ligand structure JOINTLY for proper interaction. + + This method encodes both protein and ligand together through the structure + encoder, allowing cross-attention between them. The embeddings are then + split and quantized separately. + + Parameters + ---------- + protein_coords : Tensor + Protein coordinates [B, L, n_atoms, 3]. + protein_mask : Tensor + Protein mask [B, L]. + protein_indices : Tensor + Protein residue indices [B, L]. + ligand_coords : Tensor + Ligand coordinates [B, N_atoms, 3]. + ligand_mask : Tensor + Ligand atom mask [B, N_atoms]. + ligand_indices : Tensor + Ligand atom indices [B, N_atoms]. + ligand_atom_types : Tensor, optional + Ligand atom type indices [B, N_atoms]. + bond_matrix : Tensor, optional + Bond matrix [B, N_atoms, N_atoms]. + + Returns + ------- + dict[str, Tensor] + Dictionary containing: + - "protein_tokens": Protein structure tokens [B, L] + - "protein_mask": Protein mask [B, L] + - "protein_embeddings": Continuous protein embeddings [B, L, D] + - "ligand_tokens": Ligand structure tokens [B, N_atoms] + - "ligand_mask": Ligand mask [B, N_atoms] + - "ligand_embeddings": Continuous ligand embeddings [B, N_atoms, D] + """ + L = protein_coords.shape[1] + + # Joint encoding: encode protein and ligand together + joint_emb = self.structure_encoder( + coords=protein_coords, + seq_mask=protein_mask, + residue_index=protein_indices, + ligand_coords=ligand_coords, + ligand_mask=ligand_mask, + ligand_residue_index=ligand_indices, + ligand_atom_types=ligand_atom_types, + ligand_bond_matrix=bond_matrix, + ) + # joint_emb shape: [B, L + N_atoms, D] + + # Split embeddings back into protein and ligand parts + protein_emb = joint_emb[:, :L, :] # [B, L, D] + ligand_emb = joint_emb[:, L:, :] # [B, N_atoms, D] + + # Quantize protein embeddings + if self.use_continuous_structure: + # Continuous mode: no quantization, dummy tokens + B, L_prot, D = protein_emb.shape + protein_logits = torch.zeros(B, L_prot, self.num_struc_classes, device=protein_emb.device) + protein_logits[:, :, 0] = 1.0 + protein_out_mask = protein_mask + else: + # Discrete mode: quantize + result = self.quantizer.quantize(protein_emb, mask=protein_mask) + if isinstance(result[0], dict): + tokens_dict, _, masks_dict = result + protein_logits = tokens_dict.get("protein_tokens", tokens_dict.get("ligand_tokens")) + protein_out_mask = masks_dict.get("protein_mask", masks_dict.get("ligand_mask")) + else: + protein_logits, _, protein_out_mask = result + + protein_tokens = torch.argmax(protein_logits, dim=-1) + protein_tokens[~protein_out_mask.bool()] = self.padding_index_struc_tokens + + # Quantize ligand embeddings + if self.use_continuous_structure: + # Continuous mode: dummy discrete tokens + B, N = ligand_coords.shape[:2] + ligand_tokens = torch.full((B, N), self.mask_index_struc_tokens, device=ligand_coords.device) + ligand_tokens[~ligand_mask.bool()] = self.padding_index_struc_tokens + ligand_out_mask = ligand_mask + else: + # Discrete mode: quantize + result = self.quantizer.quantize(ligand_emb, mask=None, ligand_mask=ligand_mask) + if isinstance(result[0], dict): + tokens_dict, _, masks_dict = result + ligand_logits = tokens_dict.get("ligand_tokens") + ligand_out_mask = masks_dict.get("ligand_mask") + else: + ligand_logits, _, ligand_out_mask = result + ligand_tokens = torch.argmax(ligand_logits, dim=-1) + ligand_tokens[~ligand_out_mask.bool()] = self.padding_index_struc_tokens + + return { + "protein_tokens": protein_tokens, + "protein_mask": protein_out_mask, + "protein_embeddings": protein_emb, + "ligand_tokens": ligand_tokens, + "ligand_mask": ligand_out_mask, + "ligand_embeddings": ligand_emb, + } + + def decode_structure( + self, + generated_output: dict[str, Tensor], + mask: Tensor, + ligand_mask: Tensor | None = None, + ) -> dict[str, Tensor]: + """Decode protein and ligand structure tokens back to 3D coordinates. + + This method handles both protein-only and protein-ligand decoding. + When ligand data is present, it passes both to the vit_decoder which + returns a dict with "protein_coords" and "ligand_coords". + + For continuous mode (use_diffusion_loss_structure=True), pass continuous + embeddings directly via "structure_embeddings" key. + + Parameters + ---------- + generated_output : dict + Output from forward pass or generate_sample, containing structure_logits + and optionally ligand_structure_logits. For continuous mode, use + "structure_embeddings" and "ligand_structure_embeddings" keys. + mask : Tensor + [B, L] mask for valid protein positions + ligand_mask : Tensor, optional + [B, N_atoms] mask for valid ligand positions + + Returns + ------- + dict + Dictionary with decoder name as key and decoded coords as value. + When ligand is present, value is a dict with "protein_coords" and "ligand_coords". + e.g., {"vit_decoder": {"protein_coords": Tensor[B, L, 3, 3], "ligand_coords": Tensor[B, N, 3]}} + """ + decoder_name = "vit_decoder" + decoded_x = {} + + # Check for continuous mode (direct embeddings instead of logits) + if "structure_embeddings" in generated_output: + # Continuous mode: use embeddings directly + struc_tokens_ = generated_output["structure_embeddings"] + has_ligand = "ligand_structure_embeddings" in generated_output and ligand_mask is not None + + if has_ligand: + lig_struc_tokens_ = generated_output["ligand_structure_embeddings"] + x_quant = { + "protein_tokens": struc_tokens_, + "ligand_tokens": lig_struc_tokens_, + } + mask_dict = { + "protein_mask": mask, + "ligand_mask": ligand_mask, + } + decoded_x[decoder_name] = self.decoder_factory.decoders[decoder_name](x_quant, mask_dict) + else: + decoded_x[decoder_name] = self.decoder_factory.decoders[decoder_name](struc_tokens_, mask) + + return decoded_x + + # Discrete mode: use logits + # Get protein structure logits + if "structure_logits" in generated_output: + struc_logits = generated_output["structure_logits"] + else: + raise ValueError("No structure_logits or structure_embeddings found in output") + + # Handle continuous mode where quantizer may be None + if self.use_continuous_structure or self.quantizer is None: + # For continuous mode with logits (shouldn't normally happen, but handle gracefully) + struc_tokens_ = struc_logits + else: + # Slice to only valid token indices (exclude mask/pad tokens) + struc_tokens = struc_logits[..., : self.quantizer.n_tokens] + # Apply softmax with temperature for soft decoding + temp = 0.1 + struc_tokens_ = torch.softmax(struc_tokens / temp, dim=-1) + + # Check if ligand data is present + has_ligand = "ligand_structure_logits" in generated_output and ligand_mask is not None and ligand_mask.sum() > 0 + + if has_ligand: + # Get ligand structure logits + lig_struc_logits = generated_output["ligand_structure_logits"] + + # Handle continuous mode where quantizer may be None + if self.use_continuous_structure or self.quantizer is None: + lig_struc_tokens_ = lig_struc_logits + else: + lig_struc_tokens = lig_struc_logits[..., : self.quantizer.n_tokens] + temp = 0.1 + lig_struc_tokens_ = torch.softmax(lig_struc_tokens / temp, dim=-1) + + # Prepare dict input for vit_decoder (matching TokenizerMulti.decode pattern) + x_quant = { + "protein_tokens": struc_tokens_, + "ligand_tokens": lig_struc_tokens_, + } + mask_dict = { + "protein_mask": mask, + "ligand_mask": ligand_mask, + } + + # Decode through vit_decoder - returns {"protein_coords": ..., "ligand_coords": ...} + decoded_x[decoder_name] = self.decoder_factory.decoders[decoder_name](x_quant, mask_dict) + else: + # Protein-only decoding + decoded_x[decoder_name] = self.decoder_factory.decoders[decoder_name](struc_tokens_, mask) + + return decoded_x + + def extract_ligand_predictions(self, generated_output: dict[str, Tensor], ligand_mask: Tensor) -> dict[str, Tensor]: + """Extract ligand atom type and bond predictions from model output. + + Note: Coordinate decoding is done by decode_structure() which handles + both protein and ligand together using vit_decoder. This method extracts + the discrete predictions (atom types, bonds) from the flow matching output. + + Parameters + ---------- + generated_output : dict + Output containing ligand_atom_logits and bond_logits + ligand_mask : Tensor + [B, N_atoms] mask for valid ligand atoms + + Returns + ------- + dict + Dictionary with ligand predictions: + - "coords": None (populated later from decode_structure output) + - "atom_types": Tensor[B, N_atoms] (argmax of atom logits) + - "bond_matrix": Tensor[B, N_atoms, N_atoms] (argmax of bond logits) + """ + ligand_predictions = {} + + # Coordinates are decoded by decode_structure() - placeholder for caller to fill + ligand_predictions["coords"] = None + + # Get atom types from atom logits + if "ligand_atom_logits" in generated_output: + ligand_predictions["atom_types"] = generated_output["ligand_atom_logits"].argmax(dim=-1) + + # Get bond matrix from bond logits + if "bond_logits" in generated_output: + ligand_predictions["bond_matrix"] = generated_output["bond_logits"].argmax(dim=-1) + + return ligand_predictions + + def get_timesteps(self, batch_size: int, has_ligand: bool = False) -> dict[str, Tensor]: + """Sample timesteps for all modalities.""" + timesteps = { + "sequence_tokens": self.interpolant_seq.sample_time(batch_size), + "structure_tokens": self.interpolant_struc.sample_time(batch_size), + } + if has_ligand: + timesteps["ligand_atom_tokens"] = self.interpolant_ligand_atom.sample_time(batch_size) + timesteps["ligand_structure_tokens"] = self.interpolant_ligand_struc.sample_time(batch_size) + return timesteps + + def interpolate_tokens(self, input_tokens: dict[str, Tensor], timesteps: dict[str, Tensor]) -> dict[str, Tensor]: + """Interpolate tokens for flow matching.""" + x_t = {} + + # Protein sequence + x_1_seq = input_tokens["sequence_tokens"] + x_0_seq = self.interpolant_seq.sample_prior(x_1_seq.shape) + t_seq = timesteps["sequence_tokens"] + x_t["sequence_tokens"] = self.interpolant_seq.interpolate(x_1_seq, t_seq, x_0_seq) + + # Protein structure + x_1_struc = input_tokens["structure_tokens"] + x_0_struc = self.interpolant_struc.sample_prior(x_1_struc.shape) + t_struc = timesteps["structure_tokens"] + if self.inverse_folding: + t_struc = torch.ones_like(t_struc) + x_t["structure_tokens"] = self.interpolant_struc.interpolate(x_1_struc, t_struc, x_0_struc) + + # Ligand atom types (if present) + if "ligand_atom_tokens" in input_tokens: + x_1_lig_atom = input_tokens["ligand_atom_tokens"] + x_0_lig_atom = self.interpolant_ligand_atom.sample_prior(x_1_lig_atom.shape) + t_lig = timesteps["ligand_atom_tokens"] + x_t["ligand_atom_tokens"] = self.interpolant_ligand_atom.interpolate(x_1_lig_atom, t_lig, x_0_lig_atom) + + # Ligand structure (if present) + if "ligand_structure_tokens" in input_tokens: + x_1_lig_struc = input_tokens["ligand_structure_tokens"] + x_0_lig_struc = self.interpolant_ligand_struc.sample_prior(x_1_lig_struc.shape) + t_lig_struc = timesteps["ligand_structure_tokens"] + x_t["ligand_structure_tokens"] = self.interpolant_ligand_struc.interpolate( + x_1_lig_struc, t_lig_struc, x_0_lig_struc + ) + + return x_t + + def forward( + self, + x_t: dict[str, Tensor], + mask: Tensor, + residue_index: Tensor, + conditioning_tensor: Tensor, + timesteps: dict[str, Tensor] | None = None, + # Ligand inputs + ligand_mask: Tensor | None = None, + bond_matrix: Tensor | None = None, + protein_valid_mask: Tensor | None = None, + ligand_valid_mask: Tensor | None = None, + # Continuous structure embeddings (for MAR-style generation) + structure_embeddings: Tensor | None = None, + ligand_structure_embeddings: Tensor | None = None, + ) -> dict[str, Tensor]: + """Forward pass through encoder. + + Parameters + ---------- + x_t : dict[str, Tensor] + Input tokens (sequence, structure, optionally ligand). + mask : Tensor + Attention mask. + residue_index : Tensor + Residue indices. + conditioning_tensor : Tensor + Conditioning input. + timesteps : dict[str, Tensor], optional + Timesteps for flow matching. + ligand_mask : Tensor, optional + Ligand mask. + bond_matrix : Tensor, optional + Bond matrix. + protein_valid_mask : Tensor, optional + Valid protein mask. + ligand_valid_mask : Tensor, optional + Valid ligand mask. + structure_embeddings : Tensor, optional + Pre-computed continuous protein structure embeddings [B, L, D]. + Used for MAR-style iterative generation where previously sampled + embeddings are fed back into the model. + ligand_structure_embeddings : Tensor, optional + Pre-computed continuous ligand structure embeddings [B, N_atoms, D]. + """ + # Expand timesteps if provided + if timesteps is not None: + timesteps = timesteps.copy() + B, L = x_t["sequence_tokens"].shape + timesteps["sequence_tokens"] = timesteps["sequence_tokens"][:, None].expand(-1, L)[:, :, None] + timesteps["structure_tokens"] = timesteps["structure_tokens"][:, None].expand(-1, L)[:, :, None] + + # Extract ligand tokens if present + ligand_atom_input_ids = x_t.get("ligand_atom_tokens") + ligand_structure_input_ids = x_t.get("ligand_structure_tokens") + + output = self.encoder( + sequence_input_ids=x_t["sequence_tokens"], + structure_input_ids=x_t["structure_tokens"], + attention_mask=mask, + conditioning_tensor=conditioning_tensor, + ligand_atom_input_ids=ligand_atom_input_ids, + ligand_structure_input_ids=ligand_structure_input_ids, + ligand_mask=ligand_mask, + bond_matrix=bond_matrix, + protein_valid_mask=protein_valid_mask, + ligand_valid_mask=ligand_valid_mask, + timesteps=timesteps, + structure_embeddings=structure_embeddings, + ligand_structure_embeddings=ligand_structure_embeddings, + ) + + return output + + def compute_loss( + self, + split: str, + x_gt: dict[str, Tensor], + output: dict[str, Tensor], + timesteps: dict[str, Tensor], + mask: Tensor | None = None, + ligand_mask: Tensor | None = None, + bond_matrix_gt: Tensor | None = None, + protein_valid_mask: Tensor | None = None, + ligand_valid_mask: Tensor | None = None, + decoder_gt: dict[str, Tensor] | None = None, + # DiffusionLoss params (for continuous structure) + structure_embeddings_gt: Tensor | None = None, + ligand_structure_embeddings_gt: Tensor | None = None, + x_t_structure: Tensor | None = None, + x_t_ligand_structure: Tensor | None = None, + ) -> tuple[Tensor, dict[str, Tensor]]: + """Compute multi-modal loss. + + Parameters + ---------- + split : str + Split name (train/val). + x_gt : dict[str, Tensor] + Ground truth tokens. + output : dict[str, Tensor] + Model output logits. + timesteps : dict[str, Tensor] + Timesteps for flow matching. + mask : Tensor, optional + Protein residue mask. + ligand_mask : Tensor, optional + Ligand atom mask. + bond_matrix_gt : Tensor, optional + Ground truth bond matrix. + protein_valid_mask : Tensor, optional + Mask for valid protein samples in mixed batch. + ligand_valid_mask : Tensor, optional + Mask for valid ligand samples in mixed batch. + decoder_gt : dict[str, Tensor], optional + Ground truth batch data for decoder loss (should contain coords, ligand_coords, etc.). + structure_embeddings_gt : Tensor, optional + Continuous structure embeddings for DiffusionLoss [B, L, D]. + ligand_structure_embeddings_gt : Tensor, optional + Continuous ligand structure embeddings for DiffusionLoss [B, M, D]. + x_t_structure : Tensor, optional + Interpolated structure tokens for mask derivation [B, L]. + x_t_ligand_structure : Tensor, optional + Interpolated ligand structure tokens for mask derivation [B, M]. + + Returns + ------- + tuple[Tensor, dict[str, Tensor]] + Total loss and loss dictionary. + """ + total_loss = torch.tensor(0.0, device=output["sequence_logits"].device) + loss_dict = {} + + # === PROTEIN LOSSES === + # Sequence loss (always discrete CE) + loss_seq = self.interpolant_seq.loss( + output["sequence_logits"], x_gt["sequence_tokens"], timesteps["sequence_tokens"] + ) + if protein_valid_mask is not None: + loss_seq = (loss_seq.mean(dim=-1) * protein_valid_mask.float()).sum() / ( + protein_valid_mask.float().sum() + 1e-8 + ) + else: + loss_seq = loss_seq.mean() + + # Structure loss - Option A: Hybrid (DiffusionLoss if enabled) + # Initialize predicted embeddings for later use in decoding + pred_structure_embeddings = None + pred_ligand_structure_embeddings = None + + if self.use_diffusion_loss_structure and structure_embeddings_gt is not None: + # Use DiffusionLoss for structure tokens + # Get hidden states for conditioning (need to expose from encoder) + structure_hidden = output.get("protein_hidden_states", output.get("last_hidden_state")) + if structure_hidden is not None: + # Extract protein portion if full hidden state + if "protein_hidden_states" not in output and mask is not None: + B, L = mask.shape + structure_hidden = structure_hidden[:, :L, :] + + # Create diffusion mask from flow matching: mask=1 where position was corrupted + if x_t_structure is not None: + diffusion_mask = (x_t_structure == self.mask_index_struc_tokens).float() + else: + # Fallback: all positions if interpolated tokens not available + diffusion_mask = torch.ones_like(mask).float() + + # Apply validity masks + if protein_valid_mask is not None: + diffusion_mask = diffusion_mask * protein_valid_mask[:, None].float() + if mask is not None: + diffusion_mask = diffusion_mask * mask.float() + + # Compute diffusion loss and get predicted embeddings for decoding + loss_struc, pred_structure_embeddings = self.diffusion_loss_protein_struc( + target=structure_embeddings_gt, + z=structure_hidden, + mask=diffusion_mask, + return_pred=True, + ) + loss_struc = loss_struc * self.diffusion_loss_weight + else: + # Fallback if hidden states not available + logger.warning("No hidden states for DiffusionLoss, falling back to discrete loss") + loss_struc = self.interpolant_struc.loss( + output["structure_logits"], x_gt["structure_tokens"], timesteps["structure_tokens"] + ).mean() + else: + # Original discrete flow matching loss + loss_struc = self.interpolant_struc.loss( + output["structure_logits"], x_gt["structure_tokens"], timesteps["structure_tokens"] + ) + if protein_valid_mask is not None: + loss_struc = (loss_struc.mean(dim=-1) * protein_valid_mask.float()).sum() / ( + protein_valid_mask.float().sum() + 1e-8 + ) + else: + loss_struc = loss_struc.mean() + + total_loss = total_loss + loss_seq + loss_struc + loss_dict[f"{split}_loss_seq"] = loss_seq + loss_dict[f"{split}_loss_struc"] = loss_struc + + # === LIGAND LOSSES (if present) === + has_ligand = "ligand_atom_tokens" in x_gt and output["ligand_atom_logits"].shape[1] > 0 + + if has_ligand: + # Ligand atom type loss + loss_lig_atom = self.interpolant_ligand_atom.loss( + output["ligand_atom_logits"], + x_gt["ligand_atom_tokens"], + timesteps["ligand_atom_tokens"], + ) + if ligand_valid_mask is not None: + loss_lig_atom = (loss_lig_atom.mean(dim=-1) * ligand_valid_mask.float()).sum() / ( + ligand_valid_mask.float().sum() + 1e-8 + ) + else: + loss_lig_atom = loss_lig_atom.mean() + + # Ligand structure loss - Option A: Hybrid (DiffusionLoss if enabled) + if self.use_diffusion_loss_structure and ligand_structure_embeddings_gt is not None: + # Use DiffusionLoss for ligand structure + ligand_hidden = output.get("ligand_hidden_states") + if ligand_hidden is None: + # Extract from full hidden state + full_hidden = output.get("last_hidden_state") + if full_hidden is not None and mask is not None: + B, L = mask.shape + ligand_hidden = full_hidden[:, L:, :] + + if ligand_hidden is not None: + # Create diffusion mask from flow matching + if x_t_ligand_structure is not None: + lig_diffusion_mask = (x_t_ligand_structure == self.mask_index_struc_tokens).float() + else: + lig_diffusion_mask = torch.ones_like(ligand_mask).float() + + # Apply validity masks + if ligand_valid_mask is not None: + lig_diffusion_mask = lig_diffusion_mask * ligand_valid_mask[:, None].float() + if ligand_mask is not None: + lig_diffusion_mask = lig_diffusion_mask * ligand_mask.float() + + # Compute diffusion loss for ligand structure and get predicted embeddings + loss_lig_struc, pred_ligand_structure_embeddings = self.diffusion_loss_ligand_struc( + target=ligand_structure_embeddings_gt, + z=ligand_hidden, + mask=lig_diffusion_mask, + return_pred=True, + ) + loss_lig_struc = loss_lig_struc * self.diffusion_loss_weight + else: + # Fallback if hidden states not available + loss_lig_struc = self.interpolant_ligand_struc.loss( + output["ligand_structure_logits"], + x_gt["ligand_structure_tokens"], + timesteps["ligand_structure_tokens"], + ).mean() + else: + # Original discrete flow matching loss + loss_lig_struc = self.interpolant_ligand_struc.loss( + output["ligand_structure_logits"], + x_gt["ligand_structure_tokens"], + timesteps["ligand_structure_tokens"], + ) + if ligand_valid_mask is not None: + loss_lig_struc = (loss_lig_struc.mean(dim=-1) * ligand_valid_mask.float()).sum() / ( + ligand_valid_mask.float().sum() + 1e-8 + ) + else: + loss_lig_struc = loss_lig_struc.mean() + + # Bond prediction loss + if bond_matrix_gt is not None: + loss_bond = self.bond_loss_fn(output["bond_logits"], bond_matrix_gt, ligand_mask) + else: + loss_bond = torch.tensor(0.0, device=total_loss.device) + + total_loss = total_loss + ( + self.ligand_atom_loss_weight * loss_lig_atom + + self.ligand_struct_loss_weight * loss_lig_struc + + self.bond_loss_weight * loss_bond + ) + + loss_dict[f"{split}_loss_lig_atom"] = loss_lig_atom + loss_dict[f"{split}_loss_lig_struc"] = loss_lig_struc + loss_dict[f"{split}_loss_bond"] = loss_bond + + # === STRUCTURE DECODER LOSSES (if enabled) === + if self.decode_tokens_during_training and decoder_gt is not None: + if mask is not None: + with torch.no_grad(): + # In continuous mode, use predicted embeddings from diffusion loss + # This allows tracking reconstruction learning during training + decode_output = output + if self.use_continuous_structure: + decode_output = dict(output) # Copy to avoid modifying original + # Use predicted embeddings if available, otherwise fall back to ground truth + if pred_structure_embeddings is not None: + decode_output["structure_embeddings"] = pred_structure_embeddings + elif structure_embeddings_gt is not None: + decode_output["structure_embeddings"] = structure_embeddings_gt + if pred_ligand_structure_embeddings is not None: + decode_output["ligand_structure_embeddings"] = pred_ligand_structure_embeddings + elif ligand_structure_embeddings_gt is not None: + decode_output["ligand_structure_embeddings"] = ligand_structure_embeddings_gt + + # Decode both protein and ligand together using vit_decoder + decoded_x = self.decode_structure( + decode_output, mask, ligand_mask=ligand_mask if has_ligand else None + ) + + # Get the vit_decoder output + vit_output = decoded_x.get("vit_decoder") + + # Check if output is a dict (protein + ligand) or tensor (protein only) + if isinstance(vit_output, dict): + # Unified protein + ligand decode output + protein_coords = vit_output.get("protein_coords") + ligand_coords = vit_output.get("ligand_coords") + + # Apply protein decoder loss (only if we have protein ground truth) + if protein_coords is not None and decoder_gt.get("coords_res") is not None: + protein_decoded_x = {"vit_decoder": protein_coords} + total_loss, loss_dict = self.apply_structure_decoder_loss( + split, decoder_gt, protein_decoded_x, mask, total_loss, loss_dict, prefix="protein_" + ) + + # Apply ligand decoder loss + if has_ligand and ligand_coords is not None and ligand_mask is not None: + ligand_decoder_gt = { + "ligand_coords": decoder_gt.get("ligand_coords"), + "coords_res": decoder_gt.get("coords_res"), # For joint alignment in loss + } + ligand_decoded_x = { + "vit_decoder": {"ligand_coords": ligand_coords, "protein_coords": protein_coords} + } + # Create mask dict for ligand loss computation + mask_dict = {"protein_mask": mask, "ligand_mask": ligand_mask} + total_loss, loss_dict = self.apply_structure_decoder_loss( + split, + ligand_decoder_gt, + ligand_decoded_x, + mask_dict, + total_loss, + loss_dict, + prefix="ligand_", + ) + else: + # Protein-only output (backward compatible, only if we have protein ground truth) + if decoder_gt.get("coords_res") is not None: + total_loss, loss_dict = self.apply_structure_decoder_loss( + split, decoder_gt, decoded_x, mask, total_loss, loss_dict, prefix="protein_" + ) + + return total_loss, loss_dict + + def apply_structure_decoder_loss( + self, + split: str, + decoder_gt: dict[str, Tensor], + decoded_x: dict[str, Tensor], + mask: Tensor, + total_loss: Tensor, + loss_dict: dict[str, Tensor], + just_loss: bool = False, + keep_batch_dim: bool = False, + prefix: str = "", + ) -> tuple[Tensor, dict[str, Tensor]]: + """Apply the structure decoder loss to the model for protein and/or ligand. + + Parameters + ---------- + split : str + Split name (train/val). + decoder_gt : dict[str, Tensor] + Ground truth decoder targets (e.g., coordinates). + decoded_x : dict[str, Tensor] + Decoded outputs from decoder_factory. + mask : Tensor + Mask for valid tokens. + total_loss : Tensor + Current accumulated loss. + loss_dict : dict[str, Tensor] + Dictionary to store individual losses. + just_loss : bool + If True, return only the loss without accumulating. + keep_batch_dim : bool + If True, keep batch dimension in loss computation. + prefix : str + Prefix for loss names (e.g., "protein_" or "ligand_"). + + Returns + ------- + tuple[Tensor, dict[str, Tensor]] + Updated total loss and loss dictionary. + """ + decoder_name = "vit_decoder" + loss2apply = self.decoder_factory.get_loss(decoder_name) + + # Filter losses based on data type to avoid misleading metrics + # Protein losses: l2_loss, pairwise_l2_loss + # Ligand losses: ligand_l2_loss, ligand_pairwise_l2_loss + if prefix == "protein_": + loss2apply = [l for l in loss2apply if not l.startswith("ligand_")] + elif prefix == "ligand_": + loss2apply = [l for l in loss2apply if l.startswith("ligand_")] + + for loss2apply_ in loss2apply: + loss = self.loss_factory( + loss2apply_, decoder_gt, decoded_x[decoder_name], mask, keep_batch_dim=keep_batch_dim + ) + if just_loss: + return loss, loss_dict + # Apply loss weighting from weight_dict in loss_factory; setting to 0 for now + # as we need a different way to set weights for different losses + # total_loss += self.loss_factory.weight_dict[loss2apply_] * loss + total_loss += 0 * loss + loss_dict[f"{split}_{prefix}{loss2apply_}"] = loss + + return total_loss, loss_dict + + def step(self, batch: dict[str, Tensor], batch_idx: int, split: Literal["train", "val"] = "train") -> dict: + """Single training/validation step. + + Handles three cases: + 1. Protein-only batches (from PDB) + 2. Protein-ligand batches (from PDBBind, SAIR) + 3. Ligand-only batches (from GEOM) - still valuable for learning ligand conformations + """ + # Get device from whatever tensor is available + if "sequence" in batch: + device = batch["sequence"].device + elif "ligand_coords" in batch: + device = batch["ligand_coords"].device + else: + device = next(self.parameters()).device + + # Set device for interpolants + self.interpolant_seq.device = device + self.interpolant_struc.device = device + self.interpolant_ligand_atom.device = device + self.interpolant_ligand_struc.device = device + + # Get validity masks if present (mixed batches) + protein_valid_mask = batch.get("protein_valid_mask") + ligand_valid_mask = batch.get("ligand_valid_mask") + + # Check if batch has protein and ligand data + has_protein = "sequence" in batch and batch["sequence"].numel() > 0 + has_ligand = "ligand_coords" in batch and batch["ligand_coords"].numel() > 0 + + # === APPLY SE(3) AUGMENTATION (training only) === + if self.use_se3_augmentation and split == "train": + with torch.no_grad(): + protein_coords = batch.get("coords_res") if has_protein else None + protein_mask = batch.get("mask") if has_protein else None + ligand_coords = batch.get("ligand_coords") if has_ligand else None + ligand_mask = batch.get("ligand_mask") if has_ligand else None + + if protein_coords is not None or ligand_coords is not None: + # Apply same SE(3) transform to protein and ligand jointly + augmented = apply_se3_augmentation_protein_ligand( + protein_coords=protein_coords.clone() if protein_coords is not None else None, + protein_mask=protein_mask, + ligand_coords=ligand_coords.clone() if ligand_coords is not None else None, + ligand_mask=ligand_mask, + random_se3=True, + translation_scale=self.se3_translation_scale, + backbone_noise=0.0, + ) + if has_protein: + batch["coords_res"] = augmented.protein_coords + if has_ligand: + batch["ligand_coords"] = augmented.ligand_coords + + # === ENCODE GROUND TRUTH === + x_gt = {} + mask = None + seq_gt = None + ligand_mask = None + bond_matrix_gt = None + structure_embeddings_gt = None + ligand_structure_embeddings_gt = None + + with torch.no_grad(): + # Get data from batch + protein_coords = batch.get("coords_res") if has_protein else None + protein_mask = batch.get("mask") if has_protein else None + protein_indices = batch.get("indices") if has_protein else None + ligand_coords = batch.get("ligand_coords") if has_ligand else None + ligand_mask = batch.get("ligand_mask") if has_ligand else None + ligand_indices = batch.get("ligand_indices") if has_ligand else None + ligand_atom_types = batch.get("ligand_element_indices") if has_ligand else None + bond_matrix_gt = batch.get("bond_matrix") if has_ligand else None + + if has_protein: + seq_gt = batch["sequence"] + + # Joint encoding when both protein and ligand are present + if has_protein and has_ligand: + encoded = self.encode_protein_ligand_structure( + protein_coords=protein_coords, + protein_mask=protein_mask, + protein_indices=protein_indices, + ligand_coords=ligand_coords, + ligand_mask=ligand_mask, + ligand_indices=ligand_indices, + ligand_atom_types=ligand_atom_types, + bond_matrix=bond_matrix_gt, + ) + + # Extract results + struc_gt = encoded["protein_tokens"] + mask = encoded["protein_mask"] + structure_embeddings_gt = encoded["protein_embeddings"] + + ligand_struc_gt = encoded["ligand_tokens"] + ligand_mask = encoded["ligand_mask"] + ligand_structure_embeddings_gt = encoded["ligand_embeddings"] + + # Get ligand atom types + ligand_atom_gt = ( + ligand_atom_types if ligand_atom_types is not None else torch.full_like(ligand_indices, 3) + ) + + x_gt["sequence_tokens"] = seq_gt + x_gt["structure_tokens"] = struc_gt + x_gt["ligand_atom_tokens"] = ligand_atom_gt + x_gt["ligand_structure_tokens"] = ligand_struc_gt + + elif has_protein: + # Protein-only: use existing encode_structure method + if self.use_diffusion_loss_structure: + result = self.encode_structure( + protein_coords, protein_mask, protein_indices, return_continuous=True + ) + x_quant, _, mask, structure_embeddings_gt = result + else: + x_quant, _, mask = self.encode_structure(protein_coords, protein_mask, protein_indices) + + struc_gt = torch.argmax(x_quant, dim=-1) + struc_gt[~mask.bool()] = self.padding_index_struc_tokens + + x_gt["sequence_tokens"] = seq_gt + x_gt["structure_tokens"] = struc_gt + + elif has_ligand: + # Ligand-only: use existing encode_ligand_structure method + if ligand_atom_types is None: + ligand_atom_gt = torch.full_like(ligand_indices, 3) + else: + ligand_atom_gt = ligand_atom_types + + if self.use_diffusion_loss_structure: + ligand_struc_gt, _, ligand_structure_embeddings_gt = self.encode_ligand_structure( + ligand_coords, ligand_mask, ligand_indices, return_continuous=True + ) + else: + ligand_struc_gt, _ = self.encode_ligand_structure(ligand_coords, ligand_mask, ligand_indices) + + x_gt["ligand_atom_tokens"] = ligand_atom_gt + x_gt["ligand_structure_tokens"] = ligand_struc_gt + + # Get batch size from available data + if has_protein: + B = seq_gt.shape[0] + L = seq_gt.shape[1] + else: + # Ligand-only batch + B = batch["ligand_coords"].shape[0] + L = 1 # Dummy length for protein - we'll create empty protein tensors + + # For ligand-only batches, create dummy protein tensors + if not has_protein: + # Create minimal dummy protein data (will be masked out in loss) + seq_gt = torch.full((B, L), self.pad_token_id, device=device, dtype=torch.long) + struc_gt = torch.full((B, L), self.padding_index_struc_tokens, device=device, dtype=torch.long) + mask = torch.zeros((B, L), device=device) + x_gt["sequence_tokens"] = seq_gt + x_gt["structure_tokens"] = struc_gt + # Set protein_valid_mask to all zeros for ligand-only batches + protein_valid_mask = torch.zeros(B, device=device, dtype=torch.bool) + + # === SAMPLE TIMESTEPS AND INTERPOLATE === + timesteps = self.get_timesteps(B, has_ligand=has_ligand) + x_t = self.interpolate_tokens(x_gt, timesteps) + + # === CONDITIONING === + conditioning_tensor = torch.zeros((B, L, 1), device=device) + + # Get protein indices (or create dummy for ligand-only batches) + if has_protein: + residue_index = batch["indices"] + else: + residue_index = torch.zeros((B, L), device=device, dtype=torch.long) + + # === FORWARD === + output = self.forward( + x_t, + mask, + residue_index, + conditioning_tensor, + timesteps=timesteps, + ligand_mask=ligand_mask, + bond_matrix=bond_matrix_gt, + protein_valid_mask=protein_valid_mask, + ligand_valid_mask=ligand_valid_mask, + ) + + # === COMPUTE LOSS === + # Only pass decoder_gt every 1000 steps to avoid overhead + decoder_gt = batch if (self.decode_tokens_during_training and batch_idx % 1000 == 0) else None + total_loss, loss_dict = self.compute_loss( + split, + x_gt, + output, + timesteps, + mask=mask, + ligand_mask=ligand_mask, + bond_matrix_gt=bond_matrix_gt, + protein_valid_mask=protein_valid_mask, + ligand_valid_mask=ligand_valid_mask, + # DiffusionLoss params + structure_embeddings_gt=structure_embeddings_gt, + ligand_structure_embeddings_gt=ligand_structure_embeddings_gt, + x_t_structure=x_t.get("structure_tokens"), + x_t_ligand_structure=x_t.get("ligand_structure_tokens"), + decoder_gt=decoder_gt, + ) + + # Log metrics + self.log_dict({f"{split}_loss": total_loss, **loss_dict}, batch_size=B) + + # === PREPARE OUTPUT (compatible with callbacks) === + # StructureDecodeCallback, InverseFoldingCallback, ProteinLigandDecodeCallback expect these keys + outputs = { + "loss": total_loss, + "x_gt": x_gt, + "output": output, + # Timesteps (use dummy zeros for ligand-only batches) + "train_timesteps_seq": timesteps.get("sequence_tokens", torch.zeros(B, device=device)), + "train_timesteps_struc": timesteps.get("structure_tokens", torch.zeros(B, device=device)), + # Conditioning + "conditioning": 0, # No conditioning used + # Unmasked logits + "unmasked_x": { + "sequence_logits": output["sequence_logits"], + "structure_logits": output["structure_logits"], + }, + # Protein/Ligand info for callbacks + "has_protein": has_protein, + "has_ligand": has_ligand, + "ligand_mask": ligand_mask, + "mask": mask, + } + + # Add ligand logits if present + if has_ligand: + outputs["ligand_atom_logits"] = output["ligand_atom_logits"] + outputs["ligand_structure_logits"] = output["ligand_structure_logits"] + outputs["bond_logits"] = output["bond_logits"] + outputs["train_timesteps_ligand"] = timesteps.get("ligand_atom_tokens") + + # Decode structure (only during training, not every batch) + if self.decode_tokens_during_training and batch_idx % 1000 == 0: + with torch.no_grad(): + # In continuous mode, inject ground truth embeddings for decoding + # (predicted embeddings are computed inside compute_loss but not exposed here) + decode_output = output + if self.use_continuous_structure: + decode_output = dict(output) # Copy to avoid modifying original + if structure_embeddings_gt is not None: + decode_output["structure_embeddings"] = structure_embeddings_gt + if ligand_structure_embeddings_gt is not None: + decode_output["ligand_structure_embeddings"] = ligand_structure_embeddings_gt + + # Unified decode for both protein and ligand + decoded_x = self.decode_structure(decode_output, mask, ligand_mask=ligand_mask if has_ligand else None) + outputs["decoded_x"] = decoded_x + + # Extract ligand info (atom types, bond matrix) for callbacks + if has_ligand: + decoded_ligand = self.extract_ligand_predictions(output, ligand_mask) + # Add ligand coords from unified decode output + vit_output = decoded_x.get("vit_decoder") + if isinstance(vit_output, dict) and "ligand_coords" in vit_output: + decoded_ligand["coords"] = vit_output["ligand_coords"] + outputs["decoded_ligand_x"] = decoded_ligand + + return outputs + + def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> dict: + """Training step. + + Returns full output dict for callbacks (StructureDecodeCallback, InverseFoldingCallback). + """ + result = self.step(batch, batch_idx, "train") + # Must return dict with "loss" key for Lightning, plus callback-compatible keys + return result + + def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: + """Validation step.""" + result = self.step(batch, batch_idx, "val") + return result["loss"] + + def configure_optimizers(self): + """Configure optimizer and scheduler.""" + optimizer = torch.optim.AdamW( + self.encoder.parameters(), + lr=self.lr, + betas=(self.beta1, self.beta2), + eps=self.eps, + weight_decay=self.weight_decay, + ) + # Use .get() instead of .pop() to avoid mutating scheduler_kwargs + # This ensures configure_optimizers works correctly on checkpoint resumption + scheduler_kwargs_copy = { + k: v for k, v in self.scheduler_kwargs.items() if k not in ("num_training_steps", "num_warmup_steps") + } + scheduler = transformers.get_scheduler( + self.scheduler, + optimizer, + num_training_steps=self.scheduler_kwargs.get("num_training_steps", None), + num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps", None), + scheduler_specific_kwargs=scheduler_kwargs_copy, + ) + + scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} + + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + def generate_sample( + self, + length: int, + num_samples: int, + inference_schedule_seq: Callable = LogInferenceSchedule, + inference_schedule_struc: Callable = LinearInferenceSchedule, + inference_schedule_ligand_atom: Callable = None, + inference_schedule_ligand_struc: Callable = None, + nsteps: int = 200, + stochasticity_seq: int = 20, + stochasticity_struc: int = 20, + temperature_seq: float = 0.5, + temperature_struc: float = 1.0, + inverse_folding: bool = False, + forward_folding: bool = False, + input_structure_coords: Tensor = None, + input_sequence_tokens: Tensor = None, + input_mask: Tensor = None, + input_indices: Tensor = None, + # Ligand generation params + generate_ligand: bool = False, + num_atoms: int = 30, + input_ligand_atom_tokens: Tensor = None, + input_ligand_structure_tokens: Tensor = None, + input_ligand_structure_embeddings: Tensor = None, + input_bond_matrix: Tensor = None, + stochasticity_ligand: int = 20, + temperature_ligand: float = 0.5, + ligand_is_context: bool = False, + # Compatibility with base Gen-UME (not used in protein-ligand model) + asynchronous_sampling: bool = False, + # Diffusion loss sampling params (for continuous structure generation) + diffusion_sampling_steps: int = 100, + diffusion_temperature: float = 1.0, + ): + """Generate samples with optional ligand. + + This extends the original Gen-UME generation to support ligand + generation alongside protein. + + When use_diffusion_loss_structure=True, structure generation uses + DiffusionLoss.sample() to produce continuous embeddings, which are + then decoded by vit_decoder to 3D coordinates. + + Parameters + ---------- + inference_schedule_ligand_atom : Callable, optional + Schedule class for ligand atom token generation. If None, falls back + to inference_schedule_seq (protein sequence schedule). + inference_schedule_ligand_struc : Callable, optional + Schedule class for ligand structure token generation. If None, falls + back to inference_schedule_struc (protein structure schedule). + ligand_is_context : bool + If True, the provided ligand is used as fixed conditioning context + and will NOT be updated during generation. Use this for inverse/forward + folding where the ligand provides structural context. + input_ligand_structure_embeddings : Tensor, optional + Continuous structure embeddings for ligand [B, N_atoms, D]. Required + in continuous mode when ligand_is_context=True. Obtain these via + encode_ligand_structure(..., return_continuous=True). + """ + device = next(self.parameters()).device + + # Update interpolant devices (required for proper device handling during inference) + self.interpolant_seq.device = device + self.interpolant_struc.device = device + if hasattr(self, "interpolant_ligand_atom"): + self.interpolant_ligand_atom.device = device + if hasattr(self, "interpolant_ligand_struc"): + self.interpolant_ligand_struc.device = device + + # Initialize protein tokens + xt_seq = self.interpolant_seq.sample_prior((num_samples, length)).to(device) + xt_struc = self.interpolant_struc.sample_prior((num_samples, length)).to(device) + + # Set up schedules + schedule_seq = inference_schedule_seq(nsteps=nsteps) if inference_schedule_seq else self.inference_schedule + schedule_struc = ( + inference_schedule_struc(nsteps=nsteps) if inference_schedule_struc else self.inference_schedule + ) + + ts_seq = schedule_seq.generate_schedule(device=device) + ts_struc = schedule_struc.generate_schedule(device=device) + dts_seq = schedule_seq.discretize(device=device) + dts_struc = schedule_struc.discretize(device=device) + + # Set up independent ligand schedules (fall back to protein schedules if not specified) + schedule_lig_atom = ( + inference_schedule_ligand_atom(nsteps=nsteps) if inference_schedule_ligand_atom else schedule_seq + ) + schedule_lig_struc = ( + inference_schedule_ligand_struc(nsteps=nsteps) if inference_schedule_ligand_struc else schedule_struc + ) + ts_lig_atom = schedule_lig_atom.generate_schedule(device=device) + ts_lig_struc = schedule_lig_struc.generate_schedule(device=device) + dts_lig_atom = schedule_lig_atom.discretize(device=device) + dts_lig_struc = schedule_lig_struc.discretize(device=device) + + # Initialize defaults (same as original Gen-UME) + mask = torch.ones((num_samples, length), device=device) + residue_index = torch.arange(length, device=device) + conditioning_tensor = torch.zeros((num_samples, length, 1), device=device) + + # Handle inverse/forward folding + # Use t=0.9950 (close to 1) to indicate "clean/conditioned" tokens + # In discrete flow matching: t=0 is noise/mask, t=1 is clean data + if inverse_folding and input_structure_coords is not None: + x_quant, _, mask = self.encode_structure(input_structure_coords, input_mask, input_indices) + xt_struc = x_quant.argmax(dim=-1).to(device) + ts_struc = torch.full_like(ts_struc, 0.9950) # Structure is clean/given (t≈1) + elif forward_folding and input_sequence_tokens is not None: + xt_seq = input_sequence_tokens.to(device) + ts_seq = torch.full_like(ts_seq, 0.9950) # Sequence is clean/given (t≈1) + + # Initialize ligand tokens if generating + xt_lig_atom = None + xt_lig_struc = None + ligand_mask = None + bond_matrix = None + + if generate_ligand: + ligand_mask = torch.ones((num_samples, num_atoms), device=device) + if input_ligand_atom_tokens is not None: + xt_lig_atom = input_ligand_atom_tokens.to(device) + else: + xt_lig_atom = self.interpolant_ligand_atom.sample_prior((num_samples, num_atoms)).to(device) + + if input_ligand_structure_tokens is not None: + xt_lig_struc = input_ligand_structure_tokens.to(device) + else: + xt_lig_struc = self.interpolant_ligand_struc.sample_prior((num_samples, num_atoms)).to(device) + + if input_bond_matrix is not None: + bond_matrix = input_bond_matrix.to(device) + else: + bond_matrix = None + + # Check if using continuous structure generation + use_continuous_gen = self.use_diffusion_loss_structure and self.diffusion_loss_protein_struc is not None + + # For continuous mode: track which positions have been "committed" (generated) + # Start with all positions needing generation + if use_continuous_gen: + # Initialize continuous embeddings storage + protein_structure_embeddings = torch.zeros( + num_samples, length, self.diffusion_loss_protein_struc.target_channels, device=device + ) + protein_committed_mask = torch.zeros( + num_samples, length, device=device + ) # 0 = need to generate, 1 = committed + + if generate_ligand: + # Check if ligand is fixed context with provided continuous embeddings + if ligand_is_context and input_ligand_structure_embeddings is not None: + # Use provided embeddings and mark all positions as committed + ligand_structure_embeddings = input_ligand_structure_embeddings.to(device) + ligand_committed_mask = torch.ones(num_samples, num_atoms, device=device) + else: + # Normal generation: start from zeros, commit as positions are unmasked + ligand_structure_embeddings = torch.zeros( + num_samples, num_atoms, self.diffusion_loss_ligand_struc.target_channels, device=device + ) + ligand_committed_mask = torch.zeros(num_samples, num_atoms, device=device) + + # Generation loop (includes independent ligand schedules) + for step_idx, ( + dt_seq, + dt_struc, + t_seq, + t_struc, + dt_lig_atom, + dt_lig_struc, + t_lig_atom, + t_lig_struc, + ) in enumerate( + tqdm( + zip(dts_seq, dts_struc, ts_seq, ts_struc, dts_lig_atom, dts_lig_struc, ts_lig_atom, ts_lig_struc), + desc="Generating samples", + total=len(dts_seq), + ) + ): + # Ensure all schedule tensors are on the correct device + dt_seq = dt_seq.to(device) + dt_struc = dt_struc.to(device) + dt_lig_atom = dt_lig_atom.to(device) + dt_lig_struc = dt_lig_struc.to(device) + t_seq = schedule_seq.pad_time(num_samples, t_seq, device) + t_struc = schedule_struc.pad_time(num_samples, t_struc, device) + t_lig_atom = schedule_lig_atom.pad_time(num_samples, t_lig_atom, device) + t_lig_struc = schedule_lig_struc.pad_time(num_samples, t_lig_struc, device) + timesteps = {"sequence_tokens": t_seq, "structure_tokens": t_struc} + + x_t = {"sequence_tokens": xt_seq, "structure_tokens": xt_struc} + if generate_ligand: + timesteps["ligand_atom_tokens"] = t_lig_atom # Independent ligand atom schedule + timesteps["ligand_structure_tokens"] = t_lig_struc # Independent ligand structure schedule + x_t["ligand_atom_tokens"] = xt_lig_atom + x_t["ligand_structure_tokens"] = xt_lig_struc + + # For MAR-style generation: pass previously sampled embeddings + # The encoder will use these for committed positions, mask embedding for others + structure_embeddings_for_forward = None + ligand_structure_embeddings_for_forward = None + + if use_continuous_gen: + # Create blended embeddings: sampled for committed, zeros for uncommitted + # The mask token embedding will be added in encoder for uncommitted positions + structure_embeddings_for_forward = protein_structure_embeddings * protein_committed_mask.unsqueeze(-1) + if generate_ligand: + ligand_structure_embeddings_for_forward = ( + ligand_structure_embeddings * ligand_committed_mask.unsqueeze(-1) + ) + + output = self.forward( + x_t, + mask, + residue_index, + conditioning_tensor, + timesteps=timesteps, + ligand_mask=ligand_mask, + bond_matrix=bond_matrix, + structure_embeddings=structure_embeddings_for_forward, + ligand_structure_embeddings=ligand_structure_embeddings_for_forward, + ) + + # Update protein sequence tokens (always discrete) + xt_seq = self.interpolant_seq.step( + output["sequence_logits"], + t_seq, + xt_seq, + dt_seq, + stochasticity=stochasticity_seq, + temperature=temperature_seq, + ) + + # Update protein structure + # Always run discrete interpolant step to get the masking schedule + # This ensures continuous mode uses the SAME masking pattern as discrete mode + prev_xt_struc = xt_struc.clone() + xt_struc = self.interpolant_struc.step( + output["structure_logits"], + t_struc, + xt_struc, + dt_struc, + stochasticity=stochasticity_struc, + temperature=temperature_struc, + ) + + if use_continuous_gen: + # Continuous mode: derive committed mask from discrete masking schedule + # Positions that are no longer mask tokens are "committed" + protein_hidden = output.get("protein_hidden_states") + if protein_hidden is not None: + # Identify positions newly unmasked this step + # prev_xt_struc == mask_token AND xt_struc != mask_token + was_masked = prev_xt_struc == self.mask_index_struc_tokens + is_unmasked = xt_struc != self.mask_index_struc_tokens + newly_committed = was_masked & is_unmasked # [B, L] + + if newly_committed.any(): + # Sample continuous embeddings for newly committed positions + with torch.no_grad(): + sampled_embeddings = self.diffusion_loss_protein_struc.sample( + z=protein_hidden, + temperature=diffusion_temperature, + num_steps=diffusion_sampling_steps, + ) + + # Update embeddings and mask for newly committed positions + protein_structure_embeddings = torch.where( + newly_committed.unsqueeze(-1), + sampled_embeddings, + protein_structure_embeddings, + ) + protein_committed_mask = torch.where( + newly_committed, + torch.ones_like(protein_committed_mask), + protein_committed_mask, + ) + + # Update ligand tokens if generating (skip if ligand is fixed context) + if generate_ligand and not ligand_is_context: + xt_lig_atom = self.interpolant_ligand_atom.step( + output["ligand_atom_logits"], + t_lig_atom, # Use independent ligand atom schedule + xt_lig_atom, + dt_lig_atom, # Use independent ligand atom schedule + stochasticity=stochasticity_ligand, + temperature=temperature_ligand, + ) + + # Always run discrete interpolant step for ligand structure to get masking schedule + prev_xt_lig_struc = xt_lig_struc.clone() + xt_lig_struc = self.interpolant_ligand_struc.step( + output["ligand_structure_logits"], + t_lig_struc, # Use independent ligand structure schedule + xt_lig_struc, + dt_lig_struc, # Use independent ligand structure schedule + stochasticity=stochasticity_ligand, + temperature=temperature_ligand, + ) + + if use_continuous_gen: + # Continuous mode: derive committed mask from discrete masking schedule + ligand_hidden = output.get("ligand_hidden_states") + if ligand_hidden is not None: + # Identify positions newly unmasked this step + was_masked = prev_xt_lig_struc == self.mask_index_struc_tokens + is_unmasked = xt_lig_struc != self.mask_index_struc_tokens + newly_committed = was_masked & is_unmasked # [B, N_atoms] + + if newly_committed.any(): + # Sample continuous embeddings for newly committed positions + with torch.no_grad(): + sampled_lig_embeddings = self.diffusion_loss_ligand_struc.sample( + z=ligand_hidden, + temperature=diffusion_temperature, + num_steps=diffusion_sampling_steps, + ) + + # Update embeddings and mask for newly committed positions + ligand_structure_embeddings = torch.where( + newly_committed.unsqueeze(-1), + sampled_lig_embeddings, + ligand_structure_embeddings, + ) + ligand_committed_mask = torch.where( + newly_committed, + torch.ones_like(ligand_committed_mask), + ligand_committed_mask, + ) + + # Final output + result = output + result["generated_seq_tokens"] = xt_seq + + if use_continuous_gen: + # For continuous mode: return embeddings (decoding done by decode_structure()) + result["generated_structure_embeddings"] = protein_structure_embeddings + result["generated_struc_tokens"] = None # No discrete tokens + # Also add as "structure_embeddings" for decode_structure() compatibility + result["structure_embeddings"] = protein_structure_embeddings + + if generate_ligand: + result["generated_ligand_structure_embeddings"] = ligand_structure_embeddings + result["generated_ligand_struc_tokens"] = None + # Also add as "ligand_structure_embeddings" for decode_structure() compatibility + result["ligand_structure_embeddings"] = ligand_structure_embeddings + else: + result["generated_struc_tokens"] = xt_struc + + if generate_ligand: + result["generated_ligand_atom_tokens"] = xt_lig_atom + if not use_continuous_gen: + result["generated_ligand_struc_tokens"] = xt_lig_struc + result["predicted_bond_matrix"] = output["bond_logits"].argmax(dim=-1) + + return result diff --git a/src/lobster/model/gen_ume/_gen_ume_sequence_structure_encoder_lightning_module.py b/src/lobster/model/gen_ume/_gen_ume_sequence_structure_encoder_lightning_module.py index ed62d6d5..2dae53db 100644 --- a/src/lobster/model/gen_ume/_gen_ume_sequence_structure_encoder_lightning_module.py +++ b/src/lobster/model/gen_ume/_gen_ume_sequence_structure_encoder_lightning_module.py @@ -52,7 +52,7 @@ def __init__( ckpt_path: str | None = None, # LatentGenerator params decode_tokens_during_training: bool = True, - latent_generator_model_name: str = "LG 20A seq 3di c6d Aux", + latent_generator_model_name: str = "LG full attention", # generation params prior_distribution_seq: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, prior_distribution_struc: Callable[..., DiscreteUniformPrior] = DiscreteUniformPrior, @@ -436,6 +436,8 @@ def generate_sample( inpainting_mask_sequence: Tensor = None, inpainting_mask_structure: Tensor = None, asynchronous_sampling: bool = False, + sequence_anchor_tokens: Tensor = None, + sequence_anchor_mask: Tensor = None, ): """Generate with model, with option to return full unmasking trajectory and likelihood.""" device = next(self.parameters()).device @@ -556,6 +558,10 @@ def generate_sample( xt_seq = xt_seq_new xt_struc = xt_struc_new + # Apply sequence anchors: keep anchored positions fixed (mask=0), update free positions (mask=1) + if sequence_anchor_tokens is not None and sequence_anchor_mask is not None: + xt_seq = torch.where(sequence_anchor_mask.bool(), xt_seq, sequence_anchor_tokens) + xt = {"sequence_tokens": xt_seq, "structure_tokens": xt_struc} return unmasked_x diff --git a/src/lobster/model/gen_ume/binder_utils.py b/src/lobster/model/gen_ume/binder_utils.py new file mode 100644 index 00000000..fc025e7d --- /dev/null +++ b/src/lobster/model/gen_ume/binder_utils.py @@ -0,0 +1,296 @@ +""" +Utility functions for binder design generation. + +This module provides helper functions for the binder_design generation mode, +including chain information extraction, binder initialization, and mask creation. +""" + +import torch + + +def get_target_chain_info(structure_data: dict, target_chain_letter: str) -> tuple[int, int, int]: + """ + Get chain information for the target chain. + + Args: + structure_data: Loaded PDB structure dictionary with 'real_chains' and 'chains_ids' + target_chain_letter: Chain letter (e.g., "A", "B") + + Returns: + chain_idx: Chain index (0, 200, 400, etc.) + start_residue_idx: Starting residue index for this chain + end_residue_idx: Ending residue index for this chain (exclusive) + + Example: + For a PDB with chains A (residues 0-99) and B (residues 100-161): + - real_chains: [65, 65, ..., 66, 66, ...] # ord('A')=65, ord('B')=66 + - chains_ids: [0, 0, ..., 200, 200, ...] + + get_target_chain_info(data, "A") -> (0, 0, 100) + get_target_chain_info(data, "B") -> (200, 100, 162) + """ + # Convert chain letter to ASCII code + target_chain_ord = ord(target_chain_letter) + + # Get real_chains tensor (contains ASCII codes for chain letters) + real_chains = structure_data["real_chains"] + # Note: StructureBackboneTransform renames 'chains_ids' to 'chains' + chains_ids = structure_data.get("chains", structure_data.get("chains_ids")) + + # Find where this chain appears + chain_mask = real_chains == target_chain_ord + + if not chain_mask.any(): + available = set(chr(c) for c in real_chains.unique().tolist()) + raise ValueError(f"Chain '{target_chain_letter}' not found in structure. Available chains: {available}") + + # Get the chain index (0, 200, 400, etc.) + chain_idx = chains_ids[chain_mask][0].item() + + # Find start and end indices in the sequence + chain_positions = torch.where(chain_mask)[0] + start_residue_idx = chain_positions[0].item() + end_residue_idx = chain_positions[-1].item() + 1 + + return chain_idx, start_residue_idx, end_residue_idx + + +def initialize_binder_at_origin( + binder_length: int, device: torch.device, target_coords: torch.Tensor = None, epitope_indices: list[int] = None +) -> dict: + """ + Create initial binder structure with coordinates positioned relative to target epitope. + + Args: + binder_length: Length of binder to create + device: Torch device + target_coords: Optional target coordinates tensor (L_target, 3, 3) + epitope_indices: Optional list of residue indices (in coords_res numbering) defining the epitope. + If provided, binder atoms are randomly distributed in a ball of radius 12Å, + centered 5Å away from epitope (in direction away from target COM). + All binder atoms are constrained to be at least 5Å from target atoms. + + Returns: + binder_data: Dictionary with keys: + - 'coords_res': Coordinates tensor (L, 3, 3) initialized based on epitope or COM + - 'sequence': Sequence tokens (L,) initialized to random valid amino acids + - 'mask': Validity mask (L,) all ones + + Example: + For binder_length=100 with epitope: + coords_res shape: (100, 3, 3) + Atoms randomly distributed in 12Å ball, centered 5Å from epitope, ≥5Å from target + + sequence shape: (100,) + Random tokens from 0-19 (valid amino acids, excluding X=20) + + mask shape: (100,) + All ones (all positions valid) + """ + # Constants for initialization + EPITOPE_DISTANCE = 5.0 # Distance from epitope center to ball center + BALL_RADIUS = 12.0 # Radius of random distribution ball + MIN_TARGET_DISTANCE = 5.0 # Minimum distance from target atoms + MAX_ATTEMPTS = 100 # Maximum rejection sampling attempts per atom + + # Calculate initial position + if target_coords is not None: + # Calculate center of mass of target structure using CA atoms (index 1) + ca_coords = target_coords[:, 1, :] # (L_target, 3) + center_of_mass = ca_coords.mean(dim=0) # (3,) + + # Flatten all target atoms for distance checking + all_target_atoms = target_coords.reshape(-1, 3) # (L_target * 3, 3) + + if epitope_indices is not None and len(epitope_indices) > 0: + # Calculate epitope center from specified residues + epitope_ca_coords = ca_coords[epitope_indices] # (n_epitope, 3) + epitope_center = epitope_ca_coords.mean(dim=0) # (3,) + + # Calculate direction vector from COM to epitope (pointing away from COM) + direction = epitope_center - center_of_mass # (3,) + direction_norm = torch.norm(direction) + + if direction_norm > 1e-6: # Avoid division by zero + direction_unit = direction / direction_norm # Normalize + else: + # If epitope is at COM, use arbitrary direction (z-axis) + direction_unit = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32) + + # Calculate ball center: 5Å away from epitope, in direction away from COM + ball_center = epitope_center + direction_unit * EPITOPE_DISTANCE + + # Initialize coordinates tensor + coords_res = torch.zeros((binder_length, 3, 3), dtype=torch.float32) + + # Generate random positions for each residue + for i in range(binder_length): + for attempt in range(MAX_ATTEMPTS): + # Generate random point in unit sphere using rejection sampling + random_vec = torch.randn(3) + random_vec = random_vec / torch.norm(random_vec) # Normalize to unit sphere surface + + # Random radius (uniform in volume, so use cube root) + random_radius = BALL_RADIUS * (torch.rand(1).item() ** (1 / 3)) + + # Random point in ball + random_point = ball_center + random_vec * random_radius + + # Check minimum distance to all target atoms + distances_to_target = torch.norm(all_target_atoms - random_point.unsqueeze(0), dim=1) + min_distance = distances_to_target.min().item() + + if min_distance >= MIN_TARGET_DISTANCE: + # Point is valid, use it for all 3 backbone atoms (N, CA, C) + coords_res[i, :, :] = random_point.unsqueeze(0).expand(3, 3) + break + else: + # Max attempts reached, use ball center as fallback + coords_res[i, :, :] = ball_center.unsqueeze(0).expand(3, 3) + + else: + # No epitope specified, use center of mass with random distribution + coords_res = torch.zeros((binder_length, 3, 3), dtype=torch.float32) + for i in range(binder_length): + for attempt in range(MAX_ATTEMPTS): + # Generate random point around COM + random_vec = torch.randn(3) + random_vec = random_vec / torch.norm(random_vec) + random_radius = BALL_RADIUS * (torch.rand(1).item() ** (1 / 3)) + random_point = center_of_mass + random_vec * random_radius + + # Check minimum distance to target + distances_to_target = torch.norm(all_target_atoms - random_point.unsqueeze(0), dim=1) + min_distance = distances_to_target.min().item() + + if min_distance >= MIN_TARGET_DISTANCE: + coords_res[i, :, :] = random_point.unsqueeze(0).expand(3, 3) + break + else: + # Fallback: place at COM + offset in random direction + fallback_dir = torch.randn(3) + fallback_dir = fallback_dir / torch.norm(fallback_dir) + coords_res[i, :, :] = ( + (center_of_mass + fallback_dir * MIN_TARGET_DISTANCE).unsqueeze(0).expand(3, 3) + ) + else: + # Fall back to origin if no target provided + coords_res = torch.zeros((binder_length, 3, 3), dtype=torch.float32) + + coords_res = coords_res.to(device) + # Shape: (L, 3, 3) where: + # - First dim: residue index + # - Second dim: atom type (0=N, 1=CA, 2=C) + # - Third dim: xyz coordinates + + # Initialize sequence with random valid amino acids (0-19, excluding X=20) + sequence = torch.randint(0, 20, (binder_length,), dtype=torch.int32, device=device) + + # Set first residue to Methionine (M=10 in standard AA ordering) + # This is the canonical start codon and the first residue is kept fixed for chain break + METHIONINE_IDX = 10 # M in alphabetical AA ordering: A=0, C=1, ..., M=10, ... + sequence[0] = METHIONINE_IDX + + # Create validity mask (all ones - all positions are valid) + mask = torch.ones(binder_length, dtype=torch.float32, device=device) + + return { + "coords_res": coords_res, + "sequence": sequence, + "mask": mask, + } + + +def get_next_chain_index(structure_data: dict) -> int: + """ + Get the next available chain index (200, 400, 600, etc.). + + Args: + structure_data: Loaded PDB structure dictionary with 'chains' or 'chains_ids' + + Returns: + next_chain_idx: Next available chain index + + Example: + If chains contains [0, 0, ..., 200, 200, ...]: + - Max chain index is 200 + - Next available is 400 + + If chains contains [0, 0, ...]: + - Max chain index is 0 + - Next available is 200 + """ + # Note: StructureBackboneTransform renames 'chains_ids' to 'chains' + chains_ids = structure_data.get("chains", structure_data.get("chains_ids")) + + # Find max chain index + max_chain_idx = chains_ids.max().item() + + # Next chain index is max + 200 + next_chain_idx = max_chain_idx + 200 + + # Verify it doesn't collide (rare but check) + while next_chain_idx in chains_ids: + next_chain_idx += 200 + + return next_chain_idx + + +def create_binder_inpainting_masks( + chains_ids: torch.Tensor, target_chain_idx: int, binder_chain_idx: int, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Create inpainting masks for binder design. + Target residues get mask=0 (fixed), binder residues get mask=1 (generate). + + IMPORTANT: The first residue of the binder is kept fixed (mask=0) to preserve + the chain break token. This tells the model where the new chain starts, + otherwise it would treat the binder as a continuation of the target chain. + + Args: + chains_ids: Chain ID tensor for all residues (B, L) + target_chain_idx: Index of target chain to keep fixed + binder_chain_idx: Index of binder chain to generate + device: Torch device + + Returns: + mask_sequence: Inpainting mask for sequence (B, L) + mask_structure: Inpainting mask for structure (B, L) + + Example: + For a complex with: + - Chain A (target): chain_idx=0, residues 0-99 + - Chain B (binder): chain_idx=200, residues 100-199 + + chains_ids: [0,0,...,0, 200,200,...,200] (shape: 1, 200) + target_chain_idx: 0 + binder_chain_idx: 200 + + Returns masks of shape (1, 200): + - Positions 0-99: mask=0 (keep target fixed) + - Position 100: mask=0 (keep first binder token fixed for chain break) + - Positions 101-199: mask=1 (generate rest of binder) + """ + # Create masks initialized to zeros + B, L = chains_ids.shape + mask_sequence = torch.zeros((B, L), dtype=torch.float32, device=device) + mask_structure = torch.zeros((B, L), dtype=torch.float32, device=device) + + # Set binder positions to 1 (generate) + binder_mask = chains_ids == binder_chain_idx + mask_sequence[binder_mask] = 1.0 + mask_structure[binder_mask] = 1.0 + + # Keep first binder residue fixed (mask=0) to preserve chain break token + # Find the first position of the binder chain in each batch + for b in range(B): + binder_positions = torch.where(chains_ids[b] == binder_chain_idx)[0] + if len(binder_positions) > 0: + first_binder_idx = binder_positions[0].item() + mask_sequence[b, first_binder_idx] = 0.0 + mask_structure[b, first_binder_idx] = 0.0 + + # Target positions remain 0 (fixed) + # Any other chains in the structure also remain 0 (fixed) + + return mask_sequence, mask_structure diff --git a/src/lobster/model/latent_generator/README.md b/src/lobster/model/latent_generator/README.md index 58ba84b2..d421c7e5 100644 --- a/src/lobster/model/latent_generator/README.md +++ b/src/lobster/model/latent_generator/README.md @@ -4,9 +4,9 @@ A powerful protein and protein-ligand structure representation learning model fo ## Table of Contents - [Performance](#performance) - - [Reconstruction Quality on CASP15 Proteins](#reconstruction-quality-on-casp15-proteins) - - [Reconstruction Quality with Canonical Pose (Mol Frame)](#reconstruction-quality-with-canonical-pose-mol-frame) - - [Fold Prediction Accuracy](#fold-prediction-accuracy) + - [Structure Reconstruction Quality on CASP15 Proteins](#structure-reconstruction-quality-on-casp15-proteins) + - [Ligand Reconstruction Quality](#ligand-reconstruction-quality) + - [Protein-Ligand Complex Reconstruction Quality](#protein-ligand-complex-reconstruction-quality) - [Setup](#setup) - [Environment Setup](#environment-setup) - [Getting Embeddings and Tokens](#getting-embeddings-and-tokens) @@ -14,80 +14,67 @@ A powerful protein and protein-ligand structure representation learning model fo - [Ligand Example](#ligand-example) - [Protein-Ligand Complex Example](#protein-ligand-complex-example) - [Command-line Example](#command-line-example) -- [Training](#training) - - [Protein-only Training](#protein-only-training) - - [Protein+Ligand (Complex) Training](#proteinligand-complex-training) + - [Ligand Structure Minimization](#ligand-structure-minimization) +- [Evaluation](#evaluation) + - [Evaluating Reconstruction Quality on CASP15](#evaluating-reconstruction-quality-on-casp15) - [Model Configurations](#model-configurations) - - [Ligand Models](#ligand-models) - [Protein-Ligand Models](#protein-ligand-models) - [Protein-Only Models](#protein-only-models) ## Performance -### Reconstruction Quality on CASP15 Proteins +### Structure Reconstruction Quality on CASP15 Proteins -We evaluated the reconstruction quality of our models on CASP15 proteins ≤ 512 residues. The table below shows the average RMSD between original and reconstructed structures: +We evaluated the reconstruction quality of our models on CASP15 proteins (≤ 512 residues). The continuous baseline establishes an upper bound for the ViT architecture. -**Evaluation Set**: CASP15 proteins ≤ 512 residues +**Evaluation Set**: CASP15 proteins ≤ 512 residues -| Model | Average RMSD (Å) | Std RMSD (Å) | Min RMSD (Å) | Max RMSD (Å) | -|-------|------------------|--------------|--------------|--------------| -| LG full attention | 1.707 | 0.643 | 0.839 | 3.434 | -| LG 10A | 3.698 | 1.756 | 1.952 | 7.664 | -| LG 20A c6d Aux | 4.395 | 2.671 | 1.678 | 11.306 | -| LG 20A seq 3di c6d Aux | 4.428 | 1.723 | 2.757 | 8.556 | -| LG 20A 3di c6d Aux | 4.484 | 2.458 | 2.390 | 11.696 | -| LG 20A | 4.470 | 3.540 | 1.630 | 12.864 | -| LG 20A seq 3di c6d 512 Aux | 5.761 | 4.349 | 1.188 | 17.442 | -| LG 20A seq Aux | 5.449 | 2.862 | 3.063 | 13.342 | -| LG 20A seq 3di Aux | 6.112 | 3.723 | 2.973 | 17.839 | -| LG 20A 3di Aux | 7.844 | 4.289 | 3.119 | 16.500 | +| Model | Quantizer | Size | RMSD (Å) | Std | Min | Max | +|-------|-----------|------|----------|-----|-----|-----| +| LG Protein (cont.) | None | - | 0.462 | 0.322 | 0.200 | 1.271 | +| LG Protein SLQ | SLQ | 256 | 1.647 | 0.535 | 0.979 | 3.189 | +| LG Prot-Lig SLQ | SLQ | 256 | 1.873 | 1.054 | 0.798 | 5.143 | +| LG Prot-Lig SLQ | SLQ | 4096 | 3.097 | 2.009 | 1.242 | 8.474 | +| LG Protein FSQ | FSQ | 240 | 1.848 | 1.194 | 0.483 | 5.419 | +| LG Prot-Lig FSQ | FSQ | 4375 | 1.260 | 0.632 | 0.651 | 3.117 | +| LG Prot-Lig FSQ | FSQ | 4375/15360 | 1.418 | 0.810 | 0.748 | 3.396 | -### Reconstruction Quality with Canonical Pose (Mol Frame) +### Ligand Reconstruction Quality -We also evaluated the models using canonical pose mode, which makes the model invariant to rotations and translations: +We evaluated ligand reconstruction quality on 30,936 ligand structures from the GEOM dataset. The unified protein-ligand model achieves comparable performance to the specialist ligand-only model, demonstrating the architecture's capacity to handle multimodal distributions within a shared parameter set. -**Evaluation Set**: CASP15 proteins ≤ 512 residues +**Evaluation Set**: 30,936 ligands from GEOM dataset -| Model | Average RMSD (Å) | Std RMSD (Å) | Min RMSD (Å) | Max RMSD (Å) | -|-------|------------------|--------------|--------------|--------------| -| LG full attention | 1.645 | 0.573 | 0.664 | 2.901 | -| LG 10A | 4.005 | 2.173 | 1.981 | 9.883 | -| LG 20A c6d Aux | 4.603 | 3.028 | 1.240 | 12.297 | -| LG 20A seq 3di c6d Aux | 4.614 | 2.103 | 2.811 | 9.061 | -| LG 20A 3di c6d Aux | 4.140 | 2.108 | 2.195 | 9.275 | -| LG 20A | 4.268 | 3.306 | 1.461 | 12.989 | -| LG 20A seq 3di c6d 512 Aux | 5.445 | 3.963 | 1.568 | 15.305 | -| LG 20A seq Aux | 5.759 | 3.248 | 2.246 | 16.543 | -| LG 20A seq 3di Aux | 6.107 | 2.974 | 3.097 | 13.456 | -| LG 20A 3di Aux | 8.288 | 4.434 | 3.043 | 16.252 | +| Model | Size | Avg RMSD (Å) | Std | Min | Max | +|-------|------|--------------|-----|-----|-----| +| LG Ligand SLQ | 512 | 0.752 | 0.305 | 0.065 | 4.943 | +| LG Prot-Lig SLQ | 512 | 0.920 | 0.236 | 0.152 | 3.704 | +| LG Prot-Lig SLQ | 4096 | 1.239 | 0.335 | 0.196 | 4.101 | +| LG Prot-Lig FSQ | 4375 | 0.395 | 0.059 | 0.179 | 1.784 | +| LG Prot-Lig FSQ | 15360 | 0.295 | 0.052 | 0.120 | 1.792 | +### Protein-Ligand Complex Reconstruction Quality -### Fold Prediction Accuracy +Comparison of FSQ and SLQ variants on PDBbind complexes. Token counts represent the specific codebook size for protein and ligand components respectively. -We evaluated the fold prediction accuracy using embeddings from different LatentGenerator models as features for a small MLP trained for protein fold classification: - - -| Model | Val Acc Mean | Val Acc Std | Val Acc Min | Val Acc Max | -|-------|--------------|-------------|-------------|-------------| -| LG 20A seq 3di c6d Aux PDB | 0.385 | 0.001 | 0.383 | 0.386 | -| LG 20A seq 3di c6d Aux PDB Pinder | 0.381 | 0.004 | 0.376 | 0.387 | -| LG 20A seq 3di c6d Aux PDB Pinder Iterative Refine Module | 0.335 | 0.005 | 0.330 | 0.342 | -| LG 20A seq 3di c6d Aux | 0.313 | 0.004 | 0.310 | 0.319 | -| LG 20A seq Aux | 0.298 | 0.010 | 0.287 | 0.311 | -| LG 20A seq 3di Aux | 0.293 | 0.009 | 0.281 | 0.302 | -| LG 20A 3di c6d Aux | 0.237 | 0.009 | 0.224 | 0.245 | -| LG 20A c6d Aux | 0.226 | 0.003 | 0.223 | 0.231 | -| LG full attention | 0.225 | 0.007 | 0.215 | 0.232 | -| LG 20A 3di Aux | 0.196 | 0.003 | 0.192 | 0.200 | -| LG 10A | 0.123 | 0.001 | 0.122 | 0.124 | -| LG 20A | 0.074 | 0.007 | 0.067 | 0.083 | - -**Key Findings:** -- Models trained on PDB datasets achieve the highest fold prediction accuracy -- Sequence-aware models (with "seq" in the name) consistently outperform structure-only models -- All models use standard hyperparameters: learning rate 0.0003, dropout 0.4, label smoothing 0.2, weight decay 0.0001 +**Evaluation Set**: PDBbind complexes +| Model | Metric | Prot Tokens | Lig Tokens | Alignment | Avg RMSD (Å) | Std | Min | Max | +|-------|--------|-------------|------------|-----------|--------------|-----|-----|-----| +| LG Prot-Lig SLQ | Ligand | 256 | 512 | Individual | 1.411 | 0.593 | 0.365 | 4.519 | +| LG Prot-Lig SLQ | Ligand | 4096 | 4096 | Individual | 1.620 | 0.711 | 0.533 | 6.756 | +| LG Prot-Lig FSQ | Ligand | 4375 | 4375 | Individual | 0.705 | 0.139 | 0.345 | 1.935 | +| LG Prot-Lig FSQ | Ligand | 4375 | 15360 | Individual | 0.657 | 0.146 | 0.315 | 2.407 | +| | | | | | | | | | +| LG Prot-Lig SLQ | Complex | 256 | 512 | Joint | 1.567 | 0.343 | 0.939 | 5.579 | +| LG Prot-Lig SLQ | Complex | 4096 | 4096 | Joint | 4.680 | 2.962 | 1.415 | 19.173 | +| LG Prot-Lig FSQ | Complex | 4375 | 4375 | Joint | 1.011 | 0.127 | 0.723 | 2.387 | +| LG Prot-Lig FSQ | Complex | 4375 | 15360 | Joint | 1.009 | 0.138 | 0.739 | 3.578 | +| | | | | | | | | | +| LG Prot-Lig SLQ | Ligand | 256 | 512 | Joint (c) | 2.306 | 0.758 | 0.711 | 5.927 | +| LG Prot-Lig SLQ | Ligand | 4096 | 4096 | Joint (c) | 3.589 | - | - | - | +| LG Prot-Lig FSQ | Ligand | 4375 | 4375 | Joint (c) | 1.011 | 0.271 | 0.507 | 3.729 | +| LG Prot-Lig FSQ | Ligand | 4375 | 15360 | Joint (c) | 0.998 | - | - | - | ## Setup @@ -112,7 +99,7 @@ from lobster.model.latent_generator.io import writepdb, writepdb_ligand_complex, import torch -model_name = 'LG 20A seq 3di c6d Aux' +model_name = 'LG full attention' # Load model using the ModelInfo dataclass structure load_model( @@ -143,7 +130,7 @@ from lobster.model.latent_generator.cmdline import load_model, encode, decode, m from lobster.model.latent_generator.io import writepdb_ligand_complex, load_pdb, load_ligand import torch -model_name = 'LG Ligand 20A' +model_name = 'LG Protein Ligand fsq 4375' # Load model with ligand support using the ModelInfo dataclass structure load_model( @@ -153,7 +140,7 @@ load_model( overrides=methods[model_name].model_config.overrides ) -# Load protein-ligand complex +# Load ligand only (no protein) pdb_data = {"protein_coords": None, "protein_mask": None, "protein_seq": None} ligand_data = load_ligand("src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf") pdb_data["ligand_coords"] = ligand_data["atom_coords"] @@ -161,10 +148,11 @@ pdb_data["ligand_mask"] = ligand_data["mask"] pdb_data["ligand_residue_index"] = ligand_data["atom_indices"] pdb_data["ligand_atom_names"] = ligand_data["atom_names"] pdb_data["ligand_indices"] = ligand_data["atom_indices"] -# Get tokens for the complex + +# Get tokens for the ligand tokens, embeddings = encode(pdb_data, return_embeddings=True) print(tokens["ligand_tokens"].shape) # (batch, length_ligand, n_tokens) -print(embeddings.shape) # (batch, length_protein+length_ligand, embedding_dim) +print(embeddings.shape) # (batch, length_ligand, embedding_dim) # Decode tokens back to structure decoded_outputs = decode(tokens, x_emb=embeddings) @@ -179,14 +167,17 @@ writepdb_ligand_complex( ``` +### Protein-Ligand Complex Example -### Protein-Ligand Complex Example (warning ligand recon not good yet) ```python from lobster.model.latent_generator.cmdline import load_model, encode, decode, methods from lobster.model.latent_generator.io import writepdb_ligand_complex, load_pdb, load_ligand import torch -model_name = 'LG Ligand 20A seq 3di Aux' +# Choose one of the protein-ligand models: +# - 'LG Protein Ligand fsq 4375' (4375 tokens for both protein and ligand) +# - 'LG Protein Ligand fsq 4375 15360' (4375 protein tokens, 15360 ligand tokens) +model_name = 'LG Protein Ligand fsq 4375' # Load model with ligand support using the ModelInfo dataclass structure load_model( @@ -204,6 +195,7 @@ pdb_data["ligand_mask"] = ligand_data["mask"] pdb_data["ligand_residue_index"] = ligand_data["atom_indices"] pdb_data["ligand_atom_names"] = ligand_data["atom_names"] pdb_data["ligand_indices"] = ligand_data["atom_indices"] + # Get tokens for the complex tokens, embeddings = encode(pdb_data, return_embeddings=True) print(tokens["protein_tokens"].shape) # (batch, length_protein, n_tokens) @@ -231,155 +223,169 @@ writepdb_ligand_complex( ### Command-line Example ```bash # Get tokens and decode to structure for protein only -python src/lobster/model/latent_generator/cmdline/inference.py \ - --model_name 'LG 20A seq 3di c6d Aux' \ +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG full attention' \ --pdb_path src/lobster/model/latent_generator/example/example_pdbs/7kdr_protein.pdb \ --decode -# Get tokens and decode to structure for ligand -python src/lobster/model/latent_generator/cmdline/inference.py \ - --model_name 'LG Ligand 20A' \ - --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ +# Get tokens and decode to structure for ligand only +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375' \ + --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ --decode - -# Get tokens and decode to structure for protein-ligand -python src/lobster/model/latent_generator/cmdline/inference.py \ - --model_name 'LG Ligand 20A seq 3di Aux' \ + +# Get tokens and decode to structure for protein-ligand complex using LG Protein Ligand fsq 4375 +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375' \ --pdb_path src/lobster/model/latent_generator/example/example_pdbs/4erk_protein.pdb \ - --ligand_path latent_generator/example/example_pdbs/4erk_ligand.sdf \ + --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ + --decode + +# Get tokens and decode using LG Protein Ligand fsq 4375 15360 (higher ligand resolution) +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375 15360' \ + --pdb_path src/lobster/model/latent_generator/example/example_pdbs/4erk_protein.pdb \ + --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ --decode # Get embeddings (requires Python API) ``` +### Ligand Structure Minimization + +For protein-ligand complexes, you can apply post-decoding geometry correction to improve ligand bond lengths and angles using Open Babel force fields. This is especially useful for improving the quality of decoded ligand structures. + +```bash +# Decode with ligand minimization (bonds and angles correction - recommended) +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375' \ + --pdb_path src/lobster/model/latent_generator/example/example_pdbs/4erk_protein.pdb \ + --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ + --output_pdb decoded_complex.pdb \ + --decode \ + --minimize + +# Specify output paths explicitly +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375' \ + --pdb_path src/lobster/model/latent_generator/example/example_pdbs/4erk_protein.pdb \ + --ligand_path src/lobster/model/latent_generator/example/example_pdbs/4erk_ligand.sdf \ + --output_file_encode encoded_latents.pt \ + --output_file_decode decoded_outputs.pt \ + --output_pdb decoded_complex.pdb \ + --decode \ + --minimize +``` + +#### Minimization Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--minimize` | False | Enable ligand structure minimization after decoding | +| `--minimize_mode` | `bonds_and_angles` | Minimization strategy (see below) | +| `--force_field` | `MMFF94` | Force field: `MMFF94`, `MMFF94s`, `UFF`, `GAFF`, `Ghemical` | +| `--minimize_steps` | `500` | Maximum optimization steps | +| `--minimize_method` | `cg` | Optimization method: `cg` (conjugate gradients) or `sd` (steepest descent) | + +#### Minimization Modes + +| Mode | Description | +|------|-------------| +| `bonds_and_angles` | **Recommended.** Constrained force field minimization that idealizes both bond lengths and angles while preserving overall structure. | +| `bonds_only` | Only corrects bond lengths to ideal values, preserving torsion angles. | + +#### Example with Custom Minimization Settings + +```bash +# Use UFF force field with bonds_only mode +uv run python src/lobster/model/latent_generator/cmdline/inference.py \ + --model_name 'LG Protein Ligand fsq 4375' \ + --pdb_path protein.pdb \ + --ligand_path ligand.sdf \ + --output_pdb output.pdb \ + --decode \ + --minimize \ + --minimize_mode bonds_only \ + --force_field UFF +``` + +#### CONECT Records + +When the ligand SDF file contains bond information, the output PDB will include CONECT records for proper bond visualization in molecular viewers like PyMOL, Chimera, or VMD. + The tokens are discrete representations that can be used for tasks like discrete generation (with LLMs or PLMs) and compact storage of structure information, while embeddings are continuous representations useful for tasks like similarity search, feature extraction, and representation centric tasks. -## Model Configurations +## Evaluation -LatentGenerator provides several pre-configured models optimized for different use cases. These configurations include all necessary settings and overrides, making them easy to use without manual configuration. +### Evaluating Reconstruction Quality on CASP15 -### Ligand Models +The `evaluate_reconstruction.py` script evaluates the reconstruction quality of LatentGenerator models by computing the aligned RMSD between original and reconstructed structures. -#### LG Ligand 20A -- **Description**: Ligand only model with 20Å spatial attention -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Ligand only decoder - - 512 ligand tokens -- **Use Case**: Ligand analysis and generation - -#### LG Ligand 20A 512 1024 -- **Description**: Ligand only model with 20Å spatial attention -- **Features**: - - 512-dim embeddings - - 20Å spatial attention - - Ligand only decoder - - 1024 ligand tokens -- **Use Case**: High-dimensional ligand analysis and generation - -#### LG Ligand 20A 512 1024 element -- **Description**: Ligand only model with 20Å spatial attention and element awareness -- **Features**: - - 512-dim embeddings - - 20Å spatial attention - - Ligand only decoder with element awareness - - 1024 ligand tokens -- **Use Case**: Element-aware ligand analysis and generation - -#### LG Ligand 20A continuous -- **Description**: Ligand only model with 20Å spatial attention and continuous encoding -- **Features**: - - 512-dim embeddings - - 20Å spatial attention - - Ligand only decoder - - Continuous ligand encoding (no quantization) -- **Use Case**: Continuous ligand representation learning +#### Basic Usage -### Protein-Ligand Models +Evaluate a single model on a directory of structures: -#### LG Ligand 20A seq 3di Aux -- **Description**: Protein-ligand model with sequence and 3Di awareness -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence and 3Di decoder - - Ligand encoding support - - 512 ligand tokens - - 512 protein tokens -- **Use Case**: Protein-ligand complex analysis and generation with sequence awareness +```bash +uv run python src/lobster/metrics/evaluate_reconstruction.py \ + --models "LG full attention" \ + --data_dir /path/to/casp15/structures/ \ + --output_file reconstruction_results.json +``` -### Protein-Only Models +#### Using Canonical Pose (Mol Frame) -#### LG 20A seq Aux -- **Description**: Sequence-aware protein model -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence decoder - - 256 protein tokens -- **Use Case**: Protein structure analysis with sequence awareness +Evaluate with canonical pose mode for rotation/translation invariance: -#### LG 20A seq 3di c6d Aux -- **Description**: Sequence, 3Di and C6D-aware protein model -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence + 3Di + C6D decoder - - 256 protein tokens -- **Use Case**: Advanced protein structure analysis with sequence, 3Di and C6D features +```bash +uv run python src/lobster/metrics/evaluate_reconstruction.py \ + --models "LG full attention" \ + --data_dir /path/to/casp15/structures/ \ + --output_file reconstruction_canonical.json \ + --use_canonical_pose +``` -#### LG 20A seq 3di c6d Aux Pinder -- **Description**: Sequence, 3Di and C6D-aware protein model (Pinder dataset) -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence + 3Di + C6D decoder - - 256 protein tokens -- **Use Case**: Advanced protein structure analysis trained on Pinder dataset +#### Input File Formats -#### LG 20A seq 3di c6d Aux PDB -- **Description**: Sequence, 3Di and C6D-aware protein model (PDB dataset) -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence + 3Di + C6D decoder - - 256 protein tokens -- **Use Case**: Advanced protein structure analysis trained on PDB dataset +The evaluation script supports multiple structure file formats: +- **PDB files** (`.pdb`): Standard protein structure files +- **SDF files** (`.sdf`): Ligand structure files +- **PyTorch files** (`.pt`): Pre-processed structure data -#### LG 20A seq 3di c6d Aux PDB Pinder -- **Description**: Sequence, 3Di and C6D-aware protein model (PDB + Pinder datasets) -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence + 3Di + C6D decoder - - 256 protein tokens -- **Use Case**: Advanced protein structure analysis trained on combined PDB and Pinder datasets +#### Performance Metrics -#### LG 20A seq 3di c6d Aux PDB Pinder Finetune -- **Description**: Sequence, 3Di and C6D-aware protein model (finetuned on PDB + Pinder) -- **Features**: - - 256-dim embeddings - - 20Å spatial attention - - Sequence + 3Di + C6D decoder - - 256 protein tokens -- **Use Case**: Finetuned protein structure analysis with sequence, 3Di and C6D features +The evaluation reports: +- **Average RMSD**: Mean reconstruction error across all structures +- **Std RMSD**: Standard deviation of RMSD values +- **Min/Max RMSD**: Best and worst reconstruction quality +- **Success Rate**: Number of successful vs. failed reconstructions + +## Model Configurations + +LatentGenerator provides pre-configured models optimized for different use cases. These configurations include all necessary settings and overrides, making them easy to use without manual configuration. + +### Protein-Ligand Models -#### LG 20A -- **Description**: Basic protein model with 20Å cutoff +#### LG Protein Ligand fsq 4375 +- **Description**: Protein-ligand model with FSQ quantization (4375 tokens) - **Features**: - - Standard configuration - - 20Å spatial attention - - 256 protein tokens -- **Use Case**: Basic protein structure analysis + - 5-dim embeddings + - FSQ quantization + - Ligand encoding support + - 4375 ligand tokens + - 4375 protein tokens +- **Use Case**: Protein-ligand complex analysis and generation with balanced token resolution -#### LG 10A -- **Description**: Basic protein model with 10Å cutoff +#### LG Protein Ligand fsq 4375 15360 +- **Description**: Protein-ligand model with FSQ quantization (4375 protein tokens, 15360 ligand tokens) - **Features**: - - Standard configuration - - 10Å spatial attention - - 256 protein tokens -- **Use Case**: Local protein structure analysis + - 5-dim embeddings + - FSQ quantization + - Ligand encoding support + - 15360 ligand tokens (higher resolution for ligands) + - 4375 protein tokens +- **Use Case**: Protein-ligand complex analysis and generation with higher ligand resolution + +### Protein-Only Models #### LG full attention - **Description**: Full attention model without spatial masking @@ -399,7 +405,7 @@ To use any of these models, simply specify the model name when loading. The `met from lobster.model.latent_generator.latent_generator.cmdline import load_model, methods # Load a pre-configured model using the ModelInfo dataclass structure -model_name = 'LG seq 20A 3di c6d Aux' +model_name = 'LG full attention' load_model( methods[model_name].model_config.checkpoint, methods[model_name].model_config.config_path, @@ -427,10 +433,10 @@ load_model( Or via command line: ```bash # Using pre-configured model -python latent_generator/cmdline/inference.py --model_name 'LG 20A 3di c6d Aux' --pdb_path your_protein.pdb +uv run python latent_generator/cmdline/inference.py --model_name 'LG full attention' --pdb_path your_protein.pdb # Using custom checkpoint -python latent_generator/cmdline/inference.py \ +uv run python latent_generator/cmdline/inference.py \ --ckpt_path path/to/your/checkpoint.ckpt \ --cfg_path path/to/config/ \ --cfg_name config_name \ diff --git a/src/lobster/model/latent_generator/callbacks/_backbone_reconstruction.py b/src/lobster/model/latent_generator/callbacks/_backbone_reconstruction.py index 053a5e48..7cfd31bc 100644 --- a/src/lobster/model/latent_generator/callbacks/_backbone_reconstruction.py +++ b/src/lobster/model/latent_generator/callbacks/_backbone_reconstruction.py @@ -6,6 +6,7 @@ from lobster.model.latent_generator.io import writepdb, writepdb_ligand_complex from lobster.model.latent_generator.utils import residue_constants +from lobster.model.latent_generator.utils import minimize_ligand_structure logger = logging.getLogger(__name__) @@ -13,6 +14,33 @@ idx_to_aa = dict(enumerate(residue_constants.restype_order_with_x)) +def get_element_name(idx: int, use_extended_vocab: bool = False) -> str: + """Get element name from index, supporting both standard and extended vocabularies. + + Parameters + ---------- + idx : int + Element index. + use_extended_vocab : bool + If True, use ELEMENT_VOCAB_EXTENDED (25 tokens). + If False, use ELEMENT_VOCAB (14 tokens). + + Returns + ------- + str + Element name (e.g., 'C', 'N', 'O') or 'X' for unknown. + """ + if use_extended_vocab: + vocab = residue_constants.ELEMENT_VOCAB_EXTENDED + else: + vocab = residue_constants.ELEMENT_VOCAB + + if 0 <= idx < len(vocab): + return vocab[idx] + else: + return "X" # Unknown element + + def get_seq_from_batch(batch): seq = [] for i in range(batch.shape[0]): @@ -21,113 +49,378 @@ def get_seq_from_batch(batch): class BackboneReconstruction(lightning.Callback): - def __init__(self, structure_path: str = None, target_paths: str = None, save_every_n: int = 1000): + def __init__( + self, + structure_path: str = None, + target_paths: str = None, + save_every_n: int = 1000, + max_total_files: int = 1000, + use_extended_element_vocab: bool = False, + minimize_ligand: bool = False, + minimize_mode: str = "bonds_and_angles", + force_field: str = "MMFF94", + minimize_steps: int = 500, + ): + """Initialize BackboneReconstruction callback. + + Args: + structure_path: Path to save reconstructed structures + target_paths: Target paths (unused) + save_every_n: Save structures every N batches + max_total_files: Maximum total number of PDB files to keep. Older files + will be deleted when this limit is exceeded. If None, keeps all files. + Default: None + use_extended_element_vocab: If True, use ELEMENT_VOCAB_EXTENDED (25 tokens) + for mapping element indices to atom names. If False, use ELEMENT_VOCAB + (14 tokens). Default: False + minimize_ligand: If True, apply geometry correction to ligand structures. + Default: False + minimize_mode: Minimization mode - "bonds_only" or "bonds_and_angles" (recommended). + Default: "bonds_and_angles" + force_field: Force field for minimization - "MMFF94", "MMFF94s", "UFF", etc. + Default: "MMFF94" + minimize_steps: Maximum number of minimization steps. Default: 500 + """ self.target_paths = target_paths self.STRUCTURE_PATH = structure_path self.save_every_n = save_every_n + self.max_total_files = max_total_files + self.use_extended_element_vocab = use_extended_element_vocab + self.minimize_ligand = minimize_ligand + self.minimize_mode = minimize_mode + self.force_field = force_field + self.minimize_steps = minimize_steps self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs(f"{self.STRUCTURE_PATH}/recon", exist_ok=True) + if self.max_total_files is not None: + logger.info(f"Will keep maximum {self.max_total_files} total PDB files (oldest will be deleted)") + logger.info(f"Using {'extended' if use_extended_element_vocab else 'standard'} element vocabulary") + if self.minimize_ligand: + logger.info(f"Ligand minimization enabled: mode={minimize_mode}, force_field={force_field}") + + def _cleanup_old_files(self): + """Remove oldest PDB files if total count exceeds max_total_files.""" + if self.max_total_files is None: + return + + recon_dir = f"{self.STRUCTURE_PATH}/recon" + if not os.path.exists(recon_dir): + return + + # Get all PDB files with their creation times + pdb_files = [] + for filename in os.listdir(recon_dir): + if filename.endswith(".pdb"): + filepath = os.path.join(recon_dir, filename) + try: + mtime = os.path.getmtime(filepath) + pdb_files.append((filepath, mtime)) + except OSError: + continue + + # If we exceed the limit, delete oldest files + if len(pdb_files) > self.max_total_files: + # Sort by modification time (oldest first) + pdb_files.sort(key=lambda x: x[1]) + + # Calculate how many to delete + num_to_delete = len(pdb_files) - self.max_total_files + + # Delete oldest files + for filepath, _ in pdb_files[:num_to_delete]: + try: + os.remove(filepath) + logger.debug(f"Deleted old PDB file: {filepath}") + except OSError as e: + logger.warning(f"Failed to delete {filepath}: {e}") + + logger.info(f"Cleaned up {num_to_delete} old PDB files. Total files: {self.max_total_files}") + def on_train_batch_end(self, trainer, tokenizer, outputs, batch, batch_idx): + # Only save on rank 0 to avoid file I/O contention in distributed training + if trainer.is_global_zero: + self._save_reconstruction(trainer, outputs, batch, batch_idx, prefix="") + self._cleanup_old_files() + + def on_validation_batch_end(self, trainer, tokenizer, outputs, batch, batch_idx, dataloader_idx=0): + # Only save on rank 0 to avoid file I/O contention in distributed training + if trainer.is_global_zero: + self._save_reconstruction(trainer, outputs, batch, batch_idx, prefix="val_") + self._cleanup_old_files() + + def _save_reconstruction(self, trainer, outputs, batch, batch_idx, prefix=""): current_step = trainer.global_step - seq = None + + if batch_idx % self.save_every_n != 0: + return + + # Extract reconstructions + x_recon = outputs["x_recon"] x_recon_xyz = None + x_recon_ligand = None + x_recon_element = None - if batch_idx % self.save_every_n == 0: - # save ouputs too - x_recon = outputs["x_recon"] - - x_recon_xyz = None - seq = None - - for decoder_name in x_recon: - if "vit_decoder" == decoder_name or "vit_decoder_simple" == decoder_name: - x_recon_xyz = x_recon[decoder_name] - if isinstance(x_recon_xyz, dict) and "ligand_coords" in x_recon_xyz: - x_recon_ligand = x_recon_xyz["ligand_coords"] - x_recon_xyz = x_recon_xyz["protein_coords"] - elif isinstance(x_recon_xyz, dict) and "protein_coords_refinement" in x_recon_xyz: - x_recon_xyz = x_recon_xyz["protein_coords_refinement"] - x_recon_ligand = None - else: - x_recon_ligand = None - if "element_decoder" == decoder_name: - x_recon_element = x_recon[decoder_name] - x_recon_element = x_recon_element.argmax(dim=-1) - ligand_atom_names = [residue_constants.ELEMENT_VOCAB[int(i)] for i in x_recon_element[0]] + for decoder_name in x_recon: + if "vit_decoder" == decoder_name or "vit_decoder_simple" == decoder_name: + x_recon_xyz = x_recon[decoder_name] + if isinstance(x_recon_xyz, dict) and "ligand_coords" in x_recon_xyz: + x_recon_ligand = x_recon_xyz["ligand_coords"] + x_recon_xyz = x_recon_xyz["protein_coords"] + elif isinstance(x_recon_xyz, dict) and "protein_coords_refinement" in x_recon_xyz: + x_recon_xyz = x_recon_xyz["protein_coords_refinement"] + x_recon_ligand = None else: - ligand_atom_names = None + x_recon_ligand = None + if "element_decoder" == decoder_name: + x_recon_element = x_recon[decoder_name].argmax(dim=-1) - # save the pdb file + # Determine batch size + if x_recon_xyz is not None: + batch_size = x_recon_xyz.shape[0] + elif x_recon_ligand is not None: + batch_size = x_recon_ligand.shape[0] + else: + return + + # Save all batch entries + for i in range(batch_size): + # Save reconstructed structures if x_recon_xyz is not None: - if seq is None: - seq = torch.zeros(x_recon_xyz.shape[1], dtype=torch.long)[None] - filename = f"{self.STRUCTURE_PATH}recon/struc_{batch_idx}_{current_step}_gen.pdb" - if x_recon_ligand is not None: - ligand_atoms = x_recon_ligand[0] - ligand_chain = "L" - ligand_resname = "LIG" - writepdb_ligand_complex( - filename, - ligand_atoms=ligand_atoms, - ligand_atom_names=ligand_atom_names, - ligand_chain=ligand_chain, - ligand_resname=ligand_resname, - protein_atoms=x_recon_xyz[0], - protein_seq=seq[0], - ) + # Apply mask to reconstructed protein (assume mask is in batch) + protein_mask_i = batch.get("mask", None) + if protein_mask_i is not None: + protein_mask_i = protein_mask_i[i].bool() + protein_coords_i = x_recon_xyz[i][protein_mask_i] + seq_i = torch.zeros(protein_coords_i.shape[0], dtype=torch.long) else: - writepdb(filename, x_recon_xyz[0], seq[0]) - logger.info(f"Saved {filename}") + protein_coords_i = x_recon_xyz[i] + seq_i = torch.zeros(x_recon_xyz.shape[1], dtype=torch.long) + + filename = f"{self.STRUCTURE_PATH}recon/{prefix}struc_{batch_idx}_{current_step}_gen_item{i}.pdb" + + if x_recon_ligand is not None: + # Apply mask to reconstructed ligand + ligand_mask_i = batch.get("ligand_mask", None) + ligand_atom_names_i = None + bond_matrix_i = None + + if ligand_mask_i is not None: + ligand_mask_i = ligand_mask_i[i].bool() + ligand_coords_i = x_recon_ligand[i][ligand_mask_i] + + # Get ligand atom names with masking + if x_recon_element is not None: + ligand_elements_masked = x_recon_element[i][ligand_mask_i] + ligand_atom_names_i = [ + get_element_name(int(j), self.use_extended_element_vocab) + for j in ligand_elements_masked + ] + + # Get bond matrix with masking if available + if "ligand_bond_matrix" in batch: + full_bond_matrix = batch["ligand_bond_matrix"][i] + # Apply mask to bond matrix (select rows and columns for valid atoms) + bond_matrix_i = full_bond_matrix[ligand_mask_i][:, ligand_mask_i] + else: + ligand_coords_i = x_recon_ligand[i] + if x_recon_element is not None: + ligand_atom_names_i = [ + get_element_name(int(j), self.use_extended_element_vocab) for j in x_recon_element[i] + ] + if "ligand_bond_matrix" in batch: + bond_matrix_i = batch["ligand_bond_matrix"][i] + + # Apply ligand minimization if enabled + if self.minimize_ligand and ligand_atom_names_i is not None: + try: + ligand_coords_i = minimize_ligand_structure( + ligand_coords_i, + ligand_atom_names_i, + bond_matrix=bond_matrix_i, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed: {e}") - # save batch - filename = f"{self.STRUCTURE_PATH}recon/struc_{batch_idx}_{current_step}_gt.pdb" - seq = torch.zeros(batch["coords_res"].shape[1], dtype=torch.long)[None] - if "ligand_coords" in batch: - ligand_atoms = batch["ligand_coords"][0] - ligand_atom_names = None - ligand_chain = "L" - ligand_resname = "LIG" writepdb_ligand_complex( filename, - ligand_atoms=ligand_atoms, - ligand_atom_names=ligand_atom_names, - ligand_chain=ligand_chain, - ligand_resname=ligand_resname, - protein_atoms=batch["coords_res"][0], - protein_seq=seq[0], + ligand_atoms=ligand_coords_i, + ligand_atom_names=ligand_atom_names_i, + ligand_chain="L", + ligand_resname="LIG", + protein_atoms=protein_coords_i, + protein_seq=seq_i, + ligand_bond_matrix=bond_matrix_i, ) else: - writepdb(filename, batch["coords_res"][0], seq) + writepdb(filename, protein_coords_i, seq_i) logger.info(f"Saved {filename}") + + # Save ground truth + if "coords_res" in batch: + filename_gt = f"{self.STRUCTURE_PATH}recon/{prefix}struc_{batch_idx}_{current_step}_gt_item{i}.pdb" + + # Apply mask to ground truth protein + if protein_mask_i is not None: + gt_protein_coords_i = batch["coords_res"][i][protein_mask_i] + seq_gt_i = torch.zeros(gt_protein_coords_i.shape[0], dtype=torch.long) + else: + gt_protein_coords_i = batch["coords_res"][i] + seq_gt_i = torch.zeros(batch["coords_res"].shape[1], dtype=torch.long) + + if "ligand_coords" in batch: + # Apply mask to ground truth ligand + gt_ligand_mask_i = batch.get("ligand_mask", None) + gt_ligand_atom_names = None + gt_bond_matrix_i = None + + if gt_ligand_mask_i is not None: + gt_ligand_mask_i = gt_ligand_mask_i[i].bool() + gt_ligand_coords_i = batch["ligand_coords"][i][gt_ligand_mask_i] + + # Get ground truth ligand atom names with masking + if "ligand_element_indices" in batch: + gt_ligand_elements_masked = batch["ligand_element_indices"][i][gt_ligand_mask_i] + gt_ligand_atom_names = [ + get_element_name(int(j), self.use_extended_element_vocab) + for j in gt_ligand_elements_masked + ] + + # Get bond matrix with masking if available + if "ligand_bond_matrix" in batch: + full_bond_matrix = batch["ligand_bond_matrix"][i] + gt_bond_matrix_i = full_bond_matrix[gt_ligand_mask_i][:, gt_ligand_mask_i] + else: + gt_ligand_coords_i = batch["ligand_coords"][i] + if "ligand_element_indices" in batch: + gt_ligand_atom_names = [ + get_element_name(int(j), self.use_extended_element_vocab) + for j in batch["ligand_element_indices"][i] + ] + if "ligand_bond_matrix" in batch: + gt_bond_matrix_i = batch["ligand_bond_matrix"][i] + + writepdb_ligand_complex( + filename_gt, + ligand_atoms=gt_ligand_coords_i, + ligand_atom_names=gt_ligand_atom_names, + ligand_chain="L", + ligand_resname="LIG", + protein_atoms=gt_protein_coords_i, + protein_seq=seq_gt_i, + ligand_bond_matrix=gt_bond_matrix_i, + ) + else: + writepdb(filename_gt, gt_protein_coords_i, seq_gt_i) + logger.info(f"Saved {filename_gt}") + + # Ligand-only case elif x_recon_ligand is not None: - filename = f"{self.STRUCTURE_PATH}recon/struc_{batch_idx}_{current_step}_gen_ligand.pdb" - ligand_atoms = x_recon_ligand[0] - ligand_chain = "L" - ligand_resname = "LIG" + # Apply mask to reconstructed ligand + ligand_mask_i = batch.get("ligand_mask", None) + ligand_atom_names_i = None + bond_matrix_i = None + + if ligand_mask_i is not None: + ligand_mask_i = ligand_mask_i[i].bool() + ligand_coords_recon_i = x_recon_ligand[i][ligand_mask_i] + + # Get ligand atom names with masking + if x_recon_element is not None: + ligand_elements_masked = x_recon_element[i][ligand_mask_i] + ligand_atom_names_i = [ + get_element_name(int(j), self.use_extended_element_vocab) for j in ligand_elements_masked + ] + + # Get bond matrix with masking if available + if "ligand_bond_matrix" in batch: + full_bond_matrix = batch["ligand_bond_matrix"][i] + bond_matrix_i = full_bond_matrix[ligand_mask_i][:, ligand_mask_i] + else: + ligand_coords_recon_i = x_recon_ligand[i] + if x_recon_element is not None: + ligand_atom_names_i = [ + get_element_name(int(j), self.use_extended_element_vocab) for j in x_recon_element[i] + ] + if "ligand_bond_matrix" in batch: + bond_matrix_i = batch["ligand_bond_matrix"][i] + + # Apply ligand minimization if enabled + if self.minimize_ligand and ligand_atom_names_i is not None: + try: + ligand_coords_recon_i = minimize_ligand_structure( + ligand_coords_recon_i, + ligand_atom_names_i, + bond_matrix=bond_matrix_i, + steps=self.minimize_steps, + force_field=self.force_field, + mode=self.minimize_mode, + ) + except Exception as e: + logger.warning(f"Ligand minimization failed: {e}") + + # Save reconstructed ligand + filename = f"{self.STRUCTURE_PATH}recon/{prefix}struc_{batch_idx}_{current_step}_gen_ligand_item{i}.pdb" writepdb_ligand_complex( filename, - ligand_atoms=ligand_atoms, - ligand_atom_names=ligand_atom_names, - ligand_chain=ligand_chain, - ligand_resname=ligand_resname, + ligand_atoms=ligand_coords_recon_i, + ligand_atom_names=ligand_atom_names_i, + ligand_chain="L", + ligand_resname="LIG", protein_atoms=None, protein_seq=None, + ligand_bond_matrix=bond_matrix_i, ) logger.info(f"Saved {filename}") - filename = f"{self.STRUCTURE_PATH}recon/struc_{batch_idx}_{current_step}_gt_ligand.pdb" - ligand_atoms = batch["ligand_coords"][0] - ligand_atom_names = [ - residue_constants.ELEMENT_VOCAB[int(i)] for i in batch["ligand_element_indices"][0] - ] - ligand_chain = "L" - ligand_resname = "LIG" - writepdb_ligand_complex( - filename, - ligand_atoms=ligand_atoms, - ligand_atom_names=ligand_atom_names, - ligand_chain=ligand_chain, - ligand_resname=ligand_resname, - protein_atoms=None, - protein_seq=None, - ) + + # Save ground truth ligand + if "ligand_coords" in batch: + filename_gt = ( + f"{self.STRUCTURE_PATH}recon/{prefix}struc_{batch_idx}_{current_step}_gt_ligand_item{i}.pdb" + ) + + # Apply mask to ground truth ligand + gt_ligand_mask_i = batch.get("ligand_mask", None) + gt_ligand_atom_names = None + gt_bond_matrix_i = None + + if gt_ligand_mask_i is not None: + gt_ligand_mask_i = gt_ligand_mask_i[i].bool() + gt_ligand_coords_i = batch["ligand_coords"][i][gt_ligand_mask_i] + + # Get ground truth ligand atom names with masking + if "ligand_element_indices" in batch: + gt_ligand_elements_masked = batch["ligand_element_indices"][i][gt_ligand_mask_i] + gt_ligand_atom_names = [ + get_element_name(int(j), self.use_extended_element_vocab) + for j in gt_ligand_elements_masked + ] + + # Get bond matrix with masking if available + if "ligand_bond_matrix" in batch: + full_bond_matrix = batch["ligand_bond_matrix"][i] + gt_bond_matrix_i = full_bond_matrix[gt_ligand_mask_i][:, gt_ligand_mask_i] + else: + gt_ligand_coords_i = batch["ligand_coords"][i] + if "ligand_element_indices" in batch: + gt_ligand_atom_names = [ + get_element_name(int(j), self.use_extended_element_vocab) + for j in batch["ligand_element_indices"][i] + ] + if "ligand_bond_matrix" in batch: + gt_bond_matrix_i = batch["ligand_bond_matrix"][i] + + writepdb_ligand_complex( + filename_gt, + ligand_atoms=gt_ligand_coords_i, + ligand_atom_names=gt_ligand_atom_names, + ligand_chain="L", + ligand_resname="LIG", + protein_atoms=None, + protein_seq=None, + ligand_bond_matrix=gt_bond_matrix_i, + ) + logger.info(f"Saved {filename_gt}") diff --git a/src/lobster/model/latent_generator/callbacks/_dssp_linear_probe.py b/src/lobster/model/latent_generator/callbacks/_dssp_linear_probe.py index 6b1cc9cf..1692435b 100644 --- a/src/lobster/model/latent_generator/callbacks/_dssp_linear_probe.py +++ b/src/lobster/model/latent_generator/callbacks/_dssp_linear_probe.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F -from lobster.model.latent_generator.datasets import StructureBackboneTransform +from lobster.transforms._structure_transforms import StructureBackboneTransform logger = logging.getLogger(__name__) diff --git a/src/lobster/model/latent_generator/cmdline/__init__.py b/src/lobster/model/latent_generator/cmdline/__init__.py index 70f5354d..825eabf7 100644 --- a/src/lobster/model/latent_generator/cmdline/__init__.py +++ b/src/lobster/model/latent_generator/cmdline/__init__.py @@ -1 +1,4 @@ from .inference import decode, encode, load_model, methods, LatentEncoderDecoder + +# Re-export minimization utilities for convenience +from lobster.model.latent_generator.utils import get_ligand_energy, minimize_ligand_structure diff --git a/src/lobster/model/latent_generator/cmdline/inference.py b/src/lobster/model/latent_generator/cmdline/inference.py index 36191271..c9185c85 100644 --- a/src/lobster/model/latent_generator/cmdline/inference.py +++ b/src/lobster/model/latent_generator/cmdline/inference.py @@ -17,6 +17,7 @@ from lobster.model.latent_generator.io import load_ligand, load_pdb, writepdb, writepdb_ligand_complex from lobster.model.latent_generator.tokenizer import TokenizerMulti +from lobster.model.latent_generator.utils import get_ligand_energy, minimize_ligand_structure py_logger = logging.getLogger(__name__) @@ -46,6 +47,27 @@ class ModelInfo: methods = { # Ligand Models # These models are optimized for ligand structure analysis + "LG Ligand": ModelInfo( + description="Ligand only model with", + features=["256-dim embeddings", "Ligand only decoder", "512 ligand tokens"], + model_config=ModelConfig( + # S3 backup: s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Ligand_2025-11-09.ckpt + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-11-09T14-23-55/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=256", + "tokenizer.quantizer.embed_dim=256", + "tokenizer.structure_encoder.encode_ligand=true", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=512", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_dim=512", + "tokenizer.quantizer.ligand_n_tokens=512", + "tokenizer/quantizer=slq_quantizer_ligand", + "tokenizer/decoder_factory=struc_decoder_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + ], + ), + ), "LG Ligand 20A": ModelInfo( description="Ligand only model with 20Å spatial attention", features=["256-dim embeddings", "20Å spatial attention", "Ligand only decoder", "512 ligand tokens"], @@ -159,6 +181,236 @@ class ModelInfo: ), # Protein-Ligand Models # These models can handle both protein and ligand structures + "LG Protein Ligand": ModelInfo( + description="Protein-ligand model with structure-only encoding", + features=[ + "256-dim embeddings", + "Ligand encoding support", + "512 ligand tokens", + "512 protein tokens", + ], + model_config=ModelConfig( + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-12-07T22-38-42/epoch=830-step=88917-val_loss=16.5010.ckpt", # "/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-11-26T15-51-49/last.ckpt", #"/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-11-25T14-42-33/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=4", + "tokenizer.quantizer.embed_dim=4", + "tokenizer.quantizer.ligand_embed_dim=4", + "tokenizer.structure_encoder.encode_ligand=true", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=512", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_dim=512", + "tokenizer.quantizer.ligand_n_tokens=512", + "tokenizer/quantizer=slq_quantizer_ligand", + "tokenizer/decoder_factory=struc_decoder_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + ], + ), + ), + "LG Protein Ligand 4096": ModelInfo( + description="Protein-ligand model with structure-only encoding", + features=[ + "256-dim embeddings", + "Ligand encoding support", + "4096 ligand tokens", + "4096 protein tokens", + ], + model_config=ModelConfig( + # S3 backup: s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_4096_2026-01-05.ckpt + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2026-01-05T16-48-02/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=4", + "tokenizer.quantizer.embed_dim=4", + "tokenizer.quantizer.ligand_embed_dim=4", + "tokenizer.structure_encoder.encode_ligand=true", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=4096", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=4096", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_dim=512", + "tokenizer.quantizer.ligand_n_tokens=4096", + "tokenizer.quantizer.n_tokens=4096", + "tokenizer/quantizer=slq_quantizer_ligand", + "tokenizer/decoder_factory=struc_decoder_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + ], + ), + ), + "LG Protein Ligand fsq 4375": ModelInfo( + description="Protein-ligand model with FSQ quantization (4375 tokens)", + features=[ + "5-dim embeddings", + "FSQ quantization", + "Ligand encoding support", + "4375 ligand tokens", + "4375 protein tokens", + ], + model_config=ModelConfig( + # S3 backup: s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_2026-01-05.ckpt + # data2: /cv/data/ai4dd/data2/ume/latent_generator_/runs//2026-01-05T16-13-19/last.ckpt + checkpoint="s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_2026-01-05.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=5", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer/quantizer=fsq_quantizer_ligand", + "tokenizer.quantizer.protein_levels=[7,5,5,5,5]", + "tokenizer.quantizer.ligand_levels=[7,5,5,5,5]", + "tokenizer/decoder_factory=struc_decoder_ligand", + "tokenizer/loss_factory=structure_losses_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=4375", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=4375", + ], + ), + ), + "LG Protein Ligand fsq 4375 15360": ModelInfo( + description="Protein-ligand model with FSQ quantization (4375 protein tokens, 15360 ligand tokens)", + features=[ + "5-dim embeddings", + "FSQ quantization", + "Ligand encoding support", + "15360 ligand tokens", + "4375 protein tokens", + ], + model_config=ModelConfig( + # S3 backup: s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_Protein_Ligand_fsq_4375_15360_2026-01-07.ckpt + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2026-01-07T02-17-14/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=5", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer/quantizer=fsq_quantizer_ligand", + "tokenizer.quantizer.protein_levels=[7,5,5,5,5]", + "tokenizer.quantizer.ligand_levels=[8,8,8,6,5]", + "tokenizer/decoder_factory=struc_decoder_ligand", + "tokenizer/loss_factory=structure_losses_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=4375", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=15360", + ], + ), + ), + "LG Protein Ligand fsq 240 4375": ModelInfo( + description="Protein-ligand model with FSQ quantization (240 protein tokens, 4375 ligand tokens)", + features=[ + "5-dim embeddings", + "FSQ quantization", + "Ligand encoding support", + "4375 ligand tokens", + "240 protein tokens", + ], + model_config=ModelConfig( + # Trained with different FSQ levels for smaller protein codebook (240 vs 4375) + checkpoint="/cv/scratch/u/lisanzas/latent_generator_fsq_sair_240_4375/runs//2026-02-02T21-26-19/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=5", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer/quantizer=fsq_quantizer_ligand", + "tokenizer.quantizer.protein_levels=[5,4,3,2,2]", + "tokenizer.quantizer.ligand_levels=[7,5,5,5,5]", + "tokenizer/decoder_factory=struc_decoder_ligand", + "tokenizer/loss_factory=structure_losses_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=240", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=4375", + ], + ), + ), + "LG Protein Ligand fsq 4375 15360 bond": ModelInfo( + description="Protein-ligand model with FSQ quantization, bond matrix embedding, and extended element vocabulary", + features=[ + "5-dim embeddings", + "FSQ quantization", + "Ligand encoding support", + "Bond matrix embedding", + "Extended element vocabulary (25 tokens)", + "15360 ligand tokens", + "4375 protein tokens", + ], + model_config=ModelConfig( + # Trained with: slurm/scripts/train_latent_generator_protein_ligand_fsq_bond_element.sh + checkpoint="/cv/scratch/u/lisanzas/latent_generator_bond_element/runs/2026-01-24T20-54-23/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=5", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer.structure_encoder.ligand_atom_embedding=true", + "tokenizer.structure_encoder.use_ligand_bond_embedding=true", + "tokenizer.structure_encoder.use_extended_element_vocab=true", + "tokenizer/quantizer=fsq_quantizer_ligand", + "tokenizer.quantizer.protein_levels=[7,5,5,5,5]", + "tokenizer.quantizer.ligand_levels=[8,8,8,6,5]", + "tokenizer/decoder_factory=struc_decoder_ligand", + "tokenizer/loss_factory=structure_losses_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=4375", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=15360", + ], + ), + ), + "LG Protein Ligand cont": ModelInfo( + description="Protein-ligand model with CONTINUOUS embeddings (no quantization)", + features=[ + "256-dim continuous embeddings", + "No quantization (quantizer=null)", + "Ligand encoding support", + "Bond matrix embedding", + "Extended element vocabulary (25 tokens)", + "For use with DiffusionLoss in Gen-UME", + ], + model_config=ModelConfig( + # Trained with: slurm/scripts/train_latent_generator_protein_ligand_continuous_bond_element.sh + checkpoint="/cv/scratch/u/lisanzas/latent_generator_continuous_bond_element/runs/2026-01-24T21-03-23/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=256", + "tokenizer.structure_encoder.embed_dim_hidden=512", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer.structure_encoder.ligand_atom_embedding=true", + "tokenizer.structure_encoder.use_ligand_bond_embedding=true", + "tokenizer.structure_encoder.use_extended_element_vocab=true", + "tokenizer.quantizer=null", + "tokenizer/decoder_factory=struc_decoder_ligand", + "tokenizer/loss_factory=structure_losses_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=256", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=256", + ], + ), + ), + "LG Protein Ligand fsq 1000": ModelInfo( + description="Protein-ligand model with FSQ quantization (1000 tokens)", + features=[ + "4-dim embeddings", + "FSQ quantization", + "Ligand encoding support", + "1000 ligand tokens", + "1000 protein tokens", + ], + model_config=ModelConfig( + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-12-13T14-57-53/epoch=210-step=22577-val_loss=17.2066.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=4", + "tokenizer.structure_encoder.encode_ligand=true", + "tokenizer/quantizer=fsq_quantizer_ligand", + "tokenizer.quantizer.protein_levels=[8,5,5,5]", + "tokenizer.quantizer.ligand_levels=[8,5,5,5]", + "tokenizer/decoder_factory=struc_decoder_ligand", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.encode_ligand=true", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=1000", + "+tokenizer.decoder_factory.decoder_mapping.vit_decoder.ligand_struc_token_codebook_size=1000", + ], + ), + ), "LG Ligand 20A seq 3di Aux": ModelInfo( description="Protein-ligand model with sequence and 3Di awareness", features=[ @@ -359,6 +611,31 @@ class ModelInfo: overrides=[], ), ), + "LG full attention 2": ModelInfo( + description="Full attention model without spatial masking", + features=["Standard configuration", "Full attention (no spatial masking)", "256 protein tokens"], + model_config=ModelConfig( + # S3 backup: s3://prescient-pcluster-data/gen_ume/checkpoints/latent_generator/LG_full_attention_2_2025-11-06.ckpt + checkpoint="/cv/data/ai4dd/data2/ume/latent_generator_/runs//2025-11-06T00-40-11/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[], + ), + ), + "LG full attention 512 PDB Pinder FSQ": ModelInfo( + description="Full attention model with 512 protein tokens and FSQ quantization", + features=["240 protein tokens", "FSQ quantization"], + model_config=ModelConfig( + checkpoint="/cv/data/ai4dd/data2/lisanzas/latent_generator/studies/outputs/train/dev/runs/2025-11-09_22-19-12/checkpoints/last.ckpt", + config_path="../../latent_generator/hydra_config/", + config_name="train_multi", + overrides=[ + "tokenizer.structure_encoder.embed_dim=3", + "tokenizer/quantizer=fsq_quantizer", + "tokenizer.decoder_factory.decoder_mapping.vit_decoder.struc_token_codebook_size=240", + ], + ), + ), } @@ -436,8 +713,14 @@ def load_model( try: s3 = boto3.client("s3") bucket_name, key = checkpoint_path[5:].split("/", 1) # Extract bucket and key - local_checkpoint_path = "/tmp/" + os.path.basename(key) # Temporary local path - s3.download_file(bucket_name, key, local_checkpoint_path) + cache_dir = os.path.expanduser("~/.cache/lobster") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(key) + local_checkpoint_path = os.path.join(cache_dir, filename) + if not os.path.exists(local_checkpoint_path): + s3.download_file(bucket_name, key, local_checkpoint_path) + else: + py_logger.info(f"Checkpoint already exists at {local_checkpoint_path}") checkpoint_path = local_checkpoint_path # Update checkpoint_path to the local file except NoCredentialsError as e: raise RuntimeError("AWS credentials not found. Ensure they are configured properly.") from e @@ -452,10 +735,14 @@ def load_model( # Extract filename from URL filename = os.path.basename(urllib.parse.urlparse(checkpoint_path).path) - local_checkpoint_path = "/tmp/" + filename - - py_logger.info(f"Downloading checkpoint from Hugging Face: {checkpoint_path}") - urllib.request.urlretrieve(checkpoint_path, local_checkpoint_path) + cache_dir = os.path.expanduser("~/.cache/lobster") + os.makedirs(cache_dir, exist_ok=True) + local_checkpoint_path = os.path.join(cache_dir, filename) + if not os.path.exists(local_checkpoint_path): + py_logger.info(f"Downloading checkpoint from Hugging Face: {checkpoint_path}") + urllib.request.urlretrieve(checkpoint_path, local_checkpoint_path) + else: + py_logger.info(f"Checkpoint already exists at {local_checkpoint_path}") checkpoint_path = local_checkpoint_path # Update checkpoint_path to the local file py_logger.info(f"Checkpoint downloaded to: {checkpoint_path}") except Exception as e: @@ -614,6 +901,10 @@ def decode(self, latents: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]: encode = encoder_decoder.encode decode = encoder_decoder.decode +# Minimization functions are imported from utils and can be accessed via: +# from lobster.model.latent_generator.cmdline import minimize_ligand_structure, get_ligand_energy +# or directly from: from lobster.model.latent_generator.utils import minimize_ligand_structure, get_ligand_energy + def main(): """ @@ -663,6 +954,32 @@ def main(): Protein-Ligand Models: --------------------- +LG Protein Ligand + Description: Protein-ligand model with structure-only encoding + Features: + - 256-dim embeddings + - Ligand encoding support + - 512 ligand tokens + - 512 protein tokens + +LG Protein Ligand fsq 4375 + Description: Protein-ligand model with FSQ quantization (4375 tokens) + Features: + - 5-dim embeddings + - FSQ quantization + - Ligand encoding support + - 4375 ligand tokens + - 4375 protein tokens + +LG Protein Ligand fsq 1000 + Description: Protein-ligand model with FSQ quantization (1000 tokens) + Features: + - 4-dim embeddings + - FSQ quantization + - Ligand encoding support + - 1000 ligand tokens + - 1000 protein tokens + LG Ligand 20A seq 3di Aux Description: Protein-ligand model with sequence and 3Di awareness Features: @@ -764,8 +1081,48 @@ def main(): parser.add_argument( "--output_file_decode", type=str, default="decoded_outputs.pt", help="Path to save decoded outputs" ) + parser.add_argument( + "--output_pdb", + type=str, + default=None, + help="Path to save decoded structure as PDB (optional, auto-generated if not provided)", + ) parser.add_argument("--overrides", type=str, nargs="+", help="Configuration overrides in the format key=value") + # Minimization options + parser.add_argument( + "--minimize", + action="store_true", + help="Minimize ligand structure after decoding using Open Babel force field", + ) + parser.add_argument( + "--minimize_steps", + type=int, + default=500, + help="Maximum number of minimization steps (default: 500)", + ) + parser.add_argument( + "--force_field", + type=str, + default="MMFF94", + choices=["MMFF94", "MMFF94s", "UFF", "GAFF", "Ghemical"], + help="Force field for minimization (default: MMFF94)", + ) + parser.add_argument( + "--minimize_method", + type=str, + default="cg", + choices=["cg", "sd"], + help="Optimization method: cg (conjugate gradients) or sd (steepest descent)", + ) + parser.add_argument( + "--minimize_mode", + type=str, + default="bonds_and_angles", + choices=["bonds_only", "bonds_and_angles"], + help="Minimization mode: 'bonds_only' (ideal bond lengths), 'bonds_and_angles' (ideal bonds + angles, recommended)", + ) + args = parser.parse_args() # Set up logging @@ -775,13 +1132,22 @@ def main(): py_logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}") # Load the model with overrides if provided - if ( - args.model_name != "LG Ligand 20A seq 3di Aux" - and args.model_name != "LG Ligand 20A" - and args.model_name != "LG Ligand 20A continuous" - ) and args.ligand_path is not None: + ligand_supported_models = [ + "LG Protein Ligand", + "LG Protein Ligand 4096", + "LG Protein Ligand fsq 4375", + "LG Protein Ligand fsq 4375 15360", + "LG Protein Ligand fsq 240 4375", + "LG Protein Ligand fsq 4375 15360 bond", + "LG Protein Ligand fsq 1000", + "LG Protein Ligand cont", + "LG Ligand 20A seq 3di Aux", + "LG Ligand 20A", + "LG Ligand 20A continuous", + ] + if args.model_name not in ligand_supported_models and args.ligand_path is not None: raise ValueError( - "Ligand path is only supported for LG Ligand 20A seq 3di Aux model, LG Ligand 20A model or LG Ligand 20A continuous model" + f"Ligand path is only supported for the following models: {', '.join(ligand_supported_models)}" ) if args.model_name in methods: @@ -815,6 +1181,9 @@ def main(): pdb_data["ligand_residue_index"] = ligand_data["atom_indices"] pdb_data["ligand_atom_names"] = ligand_data["atom_names"] pdb_data["ligand_indices"] = ligand_data["atom_indices"] + # Include bond matrix for CONECT records and minimization + if "bond_matrix" in ligand_data: + pdb_data["ligand_bond_matrix"] = ligand_data["bond_matrix"] if args.model_name in [ "LG ESMC 300M 256 cont", @@ -840,7 +1209,70 @@ def main(): if isinstance(decoded_outputs, dict): x_recon_ligand = decoded_outputs["ligand_coords"] x_recon_xyz = decoded_outputs["protein_coords"] - filename = f"{args.output_file_decode.split('.')[0]}_ligand_decoded.pdb" + + # Get bond matrix and atom names for CONECT records (always retrieve these) + bond_matrix = decoded_outputs.get("ligand_bond_matrix", None) + if bond_matrix is None: + bond_matrix = pdb_data.get("ligand_bond_matrix", None) + if bond_matrix is None: + bond_matrix = pdb_data.get("bond_matrix", None) + + ligand_atom_names = pdb_data.get("ligand_atom_names", None) + if ligand_atom_names is None: + ligand_atom_names = pdb_data.get("atom_names", None) + + # Apply ligand minimization if requested + if args.minimize and x_recon_ligand is not None: + py_logger.info( + f"Minimizing ligand structure (mode={args.minimize_mode}, force_field={args.force_field}, " + f"steps={args.minimize_steps}, method={args.minimize_method})" + ) + # Get atom types from input data if available + ligand_atom_types = ligand_atom_names + if ligand_atom_types is None: + # Default to carbon for unknown atoms + num_atoms = x_recon_ligand.shape[1] + ligand_atom_types = ["C"] * num_atoms + py_logger.warning(f"No atom types provided, defaulting to Carbon for {num_atoms} atoms") + + try: + # Calculate energy before minimization + energy_before = get_ligand_energy( + x_recon_ligand[0], ligand_atom_types, bond_matrix, args.force_field + ) + py_logger.info(f"Energy before minimization: {energy_before:.2f} kcal/mol") + + # Minimize + x_recon_ligand_minimized = minimize_ligand_structure( + x_recon_ligand[0], + ligand_atom_types, + bond_matrix=bond_matrix, + steps=args.minimize_steps, + force_field=args.force_field, + method=args.minimize_method, + mode=args.minimize_mode, + ) + + # Calculate energy after minimization + energy_after = get_ligand_energy( + x_recon_ligand_minimized, ligand_atom_types, bond_matrix, args.force_field + ) + py_logger.info(f"Energy after minimization: {energy_after:.2f} kcal/mol") + py_logger.info(f"Energy reduction: {energy_before - energy_after:.2f} kcal/mol") + + # Update ligand coordinates + x_recon_ligand = x_recon_ligand_minimized.unsqueeze(0) + decoded_outputs["ligand_coords"] = x_recon_ligand + py_logger.info("Ligand minimization completed successfully") + except Exception as e: + py_logger.warning(f"Ligand minimization failed: {e}. Using original coordinates.") + + # Determine PDB output filename + if args.output_pdb: + filename = args.output_pdb + else: + filename = f"{args.output_file_decode.split('.')[0]}_ligand_decoded.pdb" + if x_recon_xyz is not None: if sequence_outputs is not None: seq = sequence_outputs.argmax(dim=-1) @@ -849,10 +1281,10 @@ def main(): seq = torch.zeros(x_recon_xyz.shape[1], dtype=torch.long)[None] else: seq = None - filename = f"{args.output_file_decode.split('.')[0]}_ligand_only_decoded.pdb" + if not args.output_pdb: + filename = f"{args.output_file_decode.split('.')[0]}_ligand_only_decoded.pdb" if x_recon_ligand is not None: ligand_atoms = x_recon_ligand[0] - ligand_atom_names = None ligand_chain = "L" ligand_resname = "LIG" if x_recon_xyz is not None: @@ -864,6 +1296,7 @@ def main(): ligand_resname=ligand_resname, protein_atoms=x_recon_xyz[0], protein_seq=seq[0], + ligand_bond_matrix=bond_matrix, ) else: writepdb_ligand_complex( @@ -872,16 +1305,21 @@ def main(): ligand_atom_names=ligand_atom_names, ligand_chain=ligand_chain, ligand_resname=ligand_resname, + ligand_bond_matrix=bond_matrix, ) + py_logger.info(f"PDB structure saved to {filename}") else: if sequence_outputs is not None: seq = sequence_outputs.argmax(dim=-1) seq[seq == 22] = 21 else: seq = torch.zeros(decoded_outputs.shape[1], dtype=torch.long)[None] - filename = f"{args.output_file_decode.split('.')[0]}_decoded.pdb" - writepdb(filename, decoded_outputs[0], seq[0]) + if args.output_pdb: + filename = args.output_pdb + else: + filename = f"{args.output_file_decode.split('.')[0]}_decoded.pdb" writepdb(filename, decoded_outputs[0], seq[0]) + py_logger.info(f"PDB structure saved to {filename}") # Save decoded outputs torch.save(decoded_outputs, args.output_file_decode) diff --git a/src/lobster/model/latent_generator/datamodules/_utils.py b/src/lobster/model/latent_generator/datamodules/_utils.py index 57a35b1f..fd9cf631 100644 --- a/src/lobster/model/latent_generator/datamodules/_utils.py +++ b/src/lobster/model/latent_generator/datamodules/_utils.py @@ -1,34 +1,215 @@ +import logging + import torch import torch.nn.functional as F from lobster.model.latent_generator.utils import residue_constants +logger = logging.getLogger(__name__) + + +# Padding values - MUST match what's used in collation functions +PROTEIN_PADDING_VALUES = { + "coords_res": 0.0, + "mask": 0.0, + "indices": -1, + "sequence": None, # Set at runtime to residue_constants.PEPTIDE_ALPHABET.index("-") + "chains": -1, + "template_coords": 0.0, + "template_mask": 0.0, + "3di_states": 0.0, + "3di_descriptors": 0.0, + "c6d": 0.0, + "c6d_mask": False, + "c6d_binned": 0.0, + "plm_embeddings": 0.0, + "graph_label": 0.0, + "zernlike_descriptors": 0.0, + "geometric_features": 0.0, +} + +LIGAND_PADDING_VALUES = { + "ligand_coords": 0.0, + "ligand_mask": 0.0, + "ligand_indices": -1, + "ligand_element_indices": 0, + "radius_of_gyration": 0.0, + "solvent_accessible_surface_area": 0.0, +} + + +def get_padding_value(key: str, dtype: torch.dtype | None = None): + """Get padding value for a field, matching standard collation behavior.""" + if key in PROTEIN_PADDING_VALUES: + val = PROTEIN_PADDING_VALUES[key] + if val is None and key == "sequence": + return residue_constants.PEPTIDE_ALPHABET.index("-") + return val + if key in LIGAND_PADDING_VALUES: + return LIGAND_PADDING_VALUES[key] + # Default fallbacks + if "mask" in key: + return False if dtype == torch.bool else 0.0 + elif "indices" in key or "chains" in key: + return -1 + else: + return 0.0 + def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: - """Collate fn for batching protein backbone data.""" - if "protein" and "ligand" in batch[0]: - ligand_batch = [bb_dict["ligand"] for bb_dict in batch] - batch = [bb_dict["protein"] for bb_dict in batch] - # make sure batch is not list of None - if batch[0] is not None: - protein_present = True - batch = collate_fn_backbone(batch) + """Collate function with unified batch dimensions and validity masks. + + BACKWARDS COMPATIBILITY: + - Pure protein-only batches: Use original collation (no validity masks) + - Pure ligand-only batches: Use original collation (no validity masks) + - Pure paired batches: Use original collation (no validity masks) + - Mixed batches: Use unified batch with validity masks (NEW) + + Handles: + - StructureDataset items: {"coords_res": ..., "mask": ..., ...} + - LigandDataset items: {"protein": None or {...}, "ligand": {...}} + """ + batch_size = len(batch) + + # Categorize batch items to determine if homogeneous or heterogeneous + has_structure_items = False # Pure protein items (StructureDataset) + has_ligand_only_items = False # Ligand-only items (protein=None) + has_paired_items = False # Protein-ligand pairs (both present) + + for item in batch: + if "protein" in item and "ligand" in item: + # LigandDataset format + if item["protein"] is not None and item["ligand"] is not None: + has_paired_items = True + elif item["protein"] is not None: + # Shouldn't happen, but count as structure + has_structure_items = True + else: + # Ligand only + has_ligand_only_items = True + else: + # StructureDataset format + has_structure_items = True + + # Check if batch is homogeneous (backwards compatible case) + num_types = sum([has_structure_items, has_ligand_only_items, has_paired_items]) + + if num_types == 1: + # HOMOGENEOUS BATCH - use original collation for backwards compatibility + if has_structure_items: + # Pure protein-only batch - use original collation + logger.debug("Homogeneous protein-only batch - using original collation") + return _collate_proteins(batch) + elif has_ligand_only_items: + # Pure ligand-only batch - extract ligands and use original collation + logger.debug("Homogeneous ligand-only batch - using original collation") + ligand_items = [item["ligand"] for item in batch] + return collate_fn_ligand(ligand_items) + elif has_paired_items: + # Pure paired batch - extract and collate + logger.debug("Homogeneous paired batch - using original collation") + protein_items = [item["protein"] for item in batch] + ligand_items = [item["ligand"] for item in batch] + + protein_collated = _collate_proteins(protein_items) + ligand_collated = collate_fn_ligand(ligand_items) + + # Combine without validity masks (backwards compatible) + return {**protein_collated, **ligand_collated} + + # HETEROGENEOUS BATCH - use new unified batch approach with validity masks + logger.debug( + f"Heterogeneous batch detected (size={batch_size}): " + f"structure={has_structure_items}, ligand_only={has_ligand_only_items}, paired={has_paired_items}" + ) + + # Normalize all items to {"protein": ..., "ligand": ...} format + normalized_batch = [] + for item in batch: + if "protein" in item and "ligand" in item: + # LigandDataset format - already normalized + normalized_batch.append(item) else: - protein_present = False - ligand_batch = collate_fn_ligand(ligand_batch) - if protein_present: - # combine batch and ligand_batch - batch = {**batch, **ligand_batch} + # StructureDataset format - wrap it + normalized_batch.append({"protein": item, "ligand": None}) + + # Extract components and build validity masks + protein_items = [] + ligand_items = [] + protein_valid_mask = [] + ligand_valid_mask = [] + + for item in normalized_batch: + protein_items.append(item["protein"]) + ligand_items.append(item["ligand"]) + protein_valid_mask.append(item["protein"] is not None) + ligand_valid_mask.append(item["ligand"] is not None) + + # Convert to tensors + protein_valid_mask = torch.tensor(protein_valid_mask, dtype=torch.bool) + ligand_valid_mask = torch.tensor(ligand_valid_mask, dtype=torch.bool) + + logger.debug( + f"Validity masks: protein={protein_valid_mask.sum().item()}/{batch_size}, " + f"ligand={ligand_valid_mask.sum().item()}/{batch_size}" + ) + + # Collate and expand + result = {} + + # Collate valid proteins + if protein_valid_mask.any(): + valid_protein_items = [p for p in protein_items if p is not None] + protein_collated = _collate_proteins(valid_protein_items) + + # Expand to full batch size if needed + if protein_valid_mask.all(): + # All items have protein - no expansion needed + result.update(protein_collated) else: - batch = ligand_batch - return batch + # Some items don't have protein - need to expand with padding + expanded = _expand_protein_to_full_batch(protein_collated, protein_valid_mask, batch_size) + result.update(expanded) + else: + # No proteins in batch - create minimal placeholders + result.update(_create_empty_protein_batch(batch_size)) + + # Collate valid ligands + if ligand_valid_mask.any(): + valid_ligand_items = [l for l in ligand_items if l is not None] + ligand_collated = collate_fn_ligand(valid_ligand_items) + # Expand to full batch size if needed + if ligand_valid_mask.all(): + # All items have ligand - no expansion needed + result.update(ligand_collated) + else: + # Some items don't have ligand - need to expand with padding + expanded = _expand_ligand_to_full_batch(ligand_collated, ligand_valid_mask, batch_size) + result.update(expanded) + else: + # No ligands in batch - create minimal placeholders + result.update(_create_empty_ligand_batch(batch_size)) + + # Add validity masks + result["protein_valid_mask"] = protein_valid_mask + result["ligand_valid_mask"] = ligand_valid_mask + + return result + + +def _collate_proteins(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + """Collate protein-only items (original collate_fn_backbone logic). + + This is the ORIGINAL implementation extracted for backwards compatibility. + """ max_length = max(bb_dict["coords_res"].shape[0] for bb_dict in batch) padded_coords_res = [] padded_mask = [] padded_indices = [] padded_sequence = [] padded_chains = [] + if "3di_states" in batch[0]: padded_3di_states = [] padded_3di_descriptors = [] @@ -50,32 +231,16 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch mask = bb_dict["mask"] indices = bb_dict["indices"] chains = bb_dict["chains"] + padded_coords_res.append( torch.cat( - [ - coords_res, - torch.zeros(max_length - coords_res.shape[0], *coords_res.shape[1:]), - ], - dim=0, - ) - ) - padded_mask.append( - torch.cat( - [ - mask, - torch.zeros(max_length - mask.shape[0], *mask.shape[1:]), - ], + [coords_res, torch.zeros(max_length - coords_res.shape[0], *coords_res.shape[1:])], dim=0, ) ) + padded_mask.append(torch.cat([mask, torch.zeros(max_length - mask.shape[0], *mask.shape[1:])], dim=0)) padded_indices.append( - torch.cat( - [ - indices, - torch.full((max_length - indices.shape[0],), -1, dtype=indices.dtype), - ], - dim=0, - ) + torch.cat([indices, torch.full((max_length - indices.shape[0],), -1, dtype=indices.dtype)], dim=0) ) padded_sequence.append( torch.cat( @@ -91,21 +256,17 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch ) ) padded_chains.append( - torch.cat( - [ - chains, - torch.full((max_length - chains.shape[0],), -1, dtype=chains.dtype), - ], - dim=0, - ) + torch.cat([chains, torch.full((max_length - chains.shape[0],), -1, dtype=chains.dtype)], dim=0) ) + if "template_coords" in batch[0]: padded_template_coords.append( torch.cat( [ bb_dict["template_coords"], torch.zeros( - max_length - bb_dict["template_coords"].shape[0], *bb_dict["template_coords"].shape[1:] + max_length - bb_dict["template_coords"].shape[0], + *bb_dict["template_coords"].shape[1:], ), ], dim=0, @@ -176,6 +337,7 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch dim=0, ) ) + out = { "coords_res": torch.stack(padded_coords_res, dim=0), "mask": torch.stack(padded_mask, dim=0), @@ -183,6 +345,7 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch "sequence": torch.stack(padded_sequence, dim=0), "chains": torch.stack(padded_chains, dim=0), } + if "3di_states" in batch[0]: out["3di_states"] = torch.stack(padded_3di_states, dim=0) out["3di_descriptors"] = torch.stack(padded_3di_descriptors, dim=0) @@ -191,6 +354,7 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch out["c6d"] = torch.stack(padded_c6d, dim=0) out["c6d_mask"] = torch.stack(padded_c6d_mask, dim=0) out["c6d_binned"] = torch.stack(padded_c6d_binned, dim=0) + if "graph_label" in batch[0]: out["graph_label"] = torch.stack([bb_dict["graph_label"] for bb_dict in batch], dim=0) if "zernlike_descriptors" in batch[0]: @@ -204,65 +368,167 @@ def collate_fn_backbone(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch out["template_mask"] = torch.stack(padded_template_mask, dim=0) if "name" in batch[0]: - out["name"] = [bb_dict["name"] for bb_dict in batch] + out["name"] = [bb_dict.get("name", None) for bb_dict in batch] return out +def _expand_protein_to_full_batch( + collated: dict[str, torch.Tensor], valid_mask: torch.Tensor, batch_size: int +) -> dict[str, torch.Tensor]: + """Expand collated protein batch to full batch size with padding.""" + result = {} + n_valid = collated["coords_res"].shape[0] + + assert valid_mask.sum() == n_valid, f"valid_mask count {valid_mask.sum()} doesn't match {n_valid}" + + for key, value in collated.items(): + if key == "name": + # Handle name specially (list, not tensor) + expanded_names = [None] * batch_size + valid_idx = 0 + for i in range(batch_size): + if valid_mask[i]: + expanded_names[i] = value[valid_idx] + valid_idx += 1 + result[key] = expanded_names + continue + + if not isinstance(value, torch.Tensor): + result[key] = value + continue + + # Get padding value using centralized function + pad_value = get_padding_value(key, value.dtype) + + # Create full batch tensor + full_shape = (batch_size,) + value.shape[1:] + if value.dtype == torch.bool: + expanded = torch.zeros(full_shape, dtype=torch.bool, device=value.device) + else: + expanded = torch.full(full_shape, pad_value, dtype=value.dtype, device=value.device) + + # Fill in valid positions + expanded[valid_mask] = value + result[key] = expanded + + return result + + +def _expand_ligand_to_full_batch( + collated: dict[str, torch.Tensor], valid_mask: torch.Tensor, batch_size: int +) -> dict[str, torch.Tensor]: + """Expand collated ligand batch to full batch size with padding.""" + result = {} + n_valid = collated["ligand_coords"].shape[0] + + assert valid_mask.sum() == n_valid, f"valid_mask count {valid_mask.sum()} doesn't match {n_valid}" + + for key, value in collated.items(): + if not isinstance(value, torch.Tensor): + result[key] = value + continue + + # Get padding value using centralized function + pad_value = get_padding_value(key, value.dtype) + + # Create full batch tensor + full_shape = (batch_size,) + value.shape[1:] + if value.dtype == torch.bool: + expanded = torch.zeros(full_shape, dtype=torch.bool, device=value.device) + else: + expanded = torch.full(full_shape, pad_value, dtype=value.dtype, device=value.device) + + # Fill in valid positions + expanded[valid_mask] = value + result[key] = expanded + + return result + + +def _create_empty_protein_batch(batch_size: int) -> dict[str, torch.Tensor]: + """Create minimal empty protein batch as placeholder.""" + return { + "coords_res": torch.zeros(batch_size, 1, 3, 3), + "mask": torch.zeros(batch_size, 1, dtype=torch.bool), + "indices": torch.full((batch_size, 1), -1, dtype=torch.long), + "sequence": torch.full((batch_size, 1), -1, dtype=torch.long), + "chains": torch.full((batch_size, 1), -1, dtype=torch.long), + } + + +def _create_empty_ligand_batch(batch_size: int) -> dict[str, torch.Tensor]: + """Create minimal empty ligand batch as placeholder.""" + return { + "ligand_coords": torch.zeros(batch_size, 1, 3), + "ligand_mask": torch.zeros(batch_size, 1, dtype=torch.bool), + "ligand_indices": torch.full((batch_size, 1), -1, dtype=torch.long), + } + + def collate_fn_ligand(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: - """Collate fn for batching ligand data.""" + """Collate fn for batching ligand data. + + Handles: + - atom_coords: [N_atoms, 3] -> [batch, max_atoms, 3] + - mask: [N_atoms] -> [batch, max_atoms] + - atom_indices: [N_atoms] -> [batch, max_atoms] + - element_indices: [N_atoms] -> [batch, max_atoms] (optional) + - bond_matrix: [N_atoms, N_atoms] -> [batch, max_atoms, max_atoms] (optional) + - smiles: str (optional, passed through as list) + """ padded_ligand_coords = [] padded_ligand_mask = [] padded_ligand_indices = [] padded_element_indices = [] + padded_bond_matrices = [] + smiles_list = [] max_length = max(atom_dict["atom_coords"].shape[0] for atom_dict in batch) + has_element_indices = "element_indices" in batch[0] + has_bond_matrix = "bond_matrix" in batch[0] + has_smiles = "smiles" in batch[0] + for atom_dict in batch: ligand_coords = atom_dict["atom_coords"] ligand_mask = atom_dict["mask"] ligand_indices = atom_dict["atom_indices"] + n_atoms = ligand_coords.shape[0] + pad_length = max_length - n_atoms padded_ligand_coords.append( - torch.cat( - [ - ligand_coords, - torch.zeros(max_length - ligand_coords.shape[0], *ligand_coords.shape[1:]), - ], - dim=0, - ) - ) - padded_ligand_mask.append( - torch.cat( - [ - ligand_mask, - torch.zeros(max_length - ligand_mask.shape[0], *ligand_mask.shape[1:]), - ], - dim=0, - ) + torch.cat([ligand_coords, torch.zeros(pad_length, *ligand_coords.shape[1:])], dim=0) ) + padded_ligand_mask.append(torch.cat([ligand_mask, torch.zeros(pad_length, *ligand_mask.shape[1:])], dim=0)) padded_ligand_indices.append( torch.cat( - [ - ligand_indices, - torch.full((max_length - ligand_indices.shape[0],), -1, dtype=ligand_indices.dtype), - ], + [ligand_indices, torch.full((pad_length,), -1, dtype=ligand_indices.dtype)], dim=0, ) ) # Handle element indices if present - if "element_indices" in atom_dict: + if has_element_indices: element_indices = atom_dict["element_indices"] padded_element_indices.append( torch.cat( - [ - element_indices, - torch.zeros(max_length - element_indices.shape[0], dtype=element_indices.dtype), - ], + [element_indices, torch.zeros(pad_length, dtype=element_indices.dtype)], dim=0, ) ) + # Handle bond matrix if present + if has_bond_matrix: + bond_matrix = atom_dict["bond_matrix"] + # Pad bond_matrix from [n_atoms, n_atoms] to [max_length, max_length] + padded_bond = torch.zeros(max_length, max_length, dtype=bond_matrix.dtype) + padded_bond[:n_atoms, :n_atoms] = bond_matrix + padded_bond_matrices.append(padded_bond) + + # Handle SMILES if present + if has_smiles: + smiles_list.append(atom_dict["smiles"]) + out = { "ligand_coords": torch.stack(padded_ligand_coords, dim=0), "ligand_mask": torch.stack(padded_ligand_mask, dim=0), @@ -272,6 +538,12 @@ def collate_fn_ligand(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.T if padded_element_indices: out["ligand_element_indices"] = torch.stack(padded_element_indices, dim=0) + if padded_bond_matrices: + out["bond_matrix"] = torch.stack(padded_bond_matrices, dim=0) + + if smiles_list: + out["smiles"] = smiles_list + # Handle additional properties like radius_of_gyration if "radius_of_gyration" in batch[0]: out["radius_of_gyration"] = torch.tensor( diff --git a/src/lobster/model/latent_generator/datasets/__init__.py b/src/lobster/model/latent_generator/datasets/__init__.py index 0d39aee4..8b137891 100644 --- a/src/lobster/model/latent_generator/datasets/__init__.py +++ b/src/lobster/model/latent_generator/datasets/__init__.py @@ -1,9 +1 @@ -from ._structure_dataset_iterable import ShardedStructureDataset -from ._transforms import ( - BinderTargetTransform, - Structure3diTransform, - StructureBackboneTransform, - StructureC6DTransform, - StructureLigandTransform, - StructureResidueTransform, -) + diff --git a/src/lobster/model/latent_generator/hydra_config/callbacks/backbone_reconstruction.yaml b/src/lobster/model/latent_generator/hydra_config/callbacks/backbone_reconstruction.yaml index b3bcc3eb..dcfd9a9d 100644 --- a/src/lobster/model/latent_generator/hydra_config/callbacks/backbone_reconstruction.yaml +++ b/src/lobster/model/latent_generator/hydra_config/callbacks/backbone_reconstruction.yaml @@ -1,4 +1,10 @@ backbone_reconstruction: _target_: lobster.model.latent_generator.callbacks.BackboneReconstruction - structure_path: "${paths.run_path}/structures/" + structure_path: "${paths.output_dir}/structures/" save_every_n: 10000 + use_extended_element_vocab: false + # Ligand minimization options + minimize_ligand: false + minimize_mode: "bonds_and_angles" # "bonds_only" or "bonds_and_angles" + force_field: "MMFF94" + minimize_steps: 500 diff --git a/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer.yaml b/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer.yaml new file mode 100644 index 00000000..71d7580a --- /dev/null +++ b/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer.yaml @@ -0,0 +1,2 @@ +_target_: lobster.model.latent_generator.quantizer.FiniteScalarQuantizer +levels: [8,6,5] \ No newline at end of file diff --git a/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer_ligand.yaml b/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer_ligand.yaml new file mode 100644 index 00000000..da2a22ff --- /dev/null +++ b/src/lobster/model/latent_generator/hydra_config/tokenizer/quantizer/fsq_quantizer_ligand.yaml @@ -0,0 +1,5 @@ +_target_: lobster.model.latent_generator.quantizer.FSQLigandTokenizer +protein_levels: [8, 6, 5] # 240 tokens for protein +ligand_levels: [8, 6, 5] # 240 tokens for ligand +return_oh_like: true + diff --git a/src/lobster/model/latent_generator/hydra_config/tokenizer/structure_encoder/vit_encoder.yaml b/src/lobster/model/latent_generator/hydra_config/tokenizer/structure_encoder/vit_encoder.yaml index 988f6285..35e0d3f7 100644 --- a/src/lobster/model/latent_generator/hydra_config/tokenizer/structure_encoder/vit_encoder.yaml +++ b/src/lobster/model/latent_generator/hydra_config/tokenizer/structure_encoder/vit_encoder.yaml @@ -12,3 +12,5 @@ backbone_noise: 0.30 concat_sine_pw: true encode_ligand: false ligand_atom_embedding: false +use_ligand_bond_embedding: false +use_extended_element_vocab: false diff --git a/src/lobster/model/latent_generator/io/__init__.py b/src/lobster/model/latent_generator/io/__init__.py index ef62c26d..40ce87d7 100644 --- a/src/lobster/model/latent_generator/io/__init__.py +++ b/src/lobster/model/latent_generator/io/__init__.py @@ -1,4 +1,10 @@ -from ._load_pdb import load_ligand, load_pdb +from ._load_pdb import ( + extract_bond_matrix, + extract_element_indices, + load_ligand, + load_pdb, + load_pdb_atom14, +) from ._write_pdb import writepdb, writepdb_ligand_complex from ._token_from_text import ( parse_tokens_from_text, diff --git a/src/lobster/model/latent_generator/io/_load_pdb.py b/src/lobster/model/latent_generator/io/_load_pdb.py index 1c2d3d5c..977492e8 100644 --- a/src/lobster/model/latent_generator/io/_load_pdb.py +++ b/src/lobster/model/latent_generator/io/_load_pdb.py @@ -9,6 +9,10 @@ from rdkit import Chem from lobster.model.latent_generator.utils import residue_constants +from lobster.model.latent_generator.utils.residue_constants import ( + ELEMENT_TO_IDX, + ELEMENT_VOCAB_EXTENDED_TO_IDX, +) try: import cpdb @@ -17,6 +21,94 @@ logger = logging.getLogger(__name__) + +# RDKit bond type to integer mapping (matches BOND_TYPES in residue_constants.py) +# 0=none, 1=single, 2=double, 3=triple, 4=aromatic, 5=other +RDKIT_BOND_TYPE_MAP = { + Chem.BondType.SINGLE: 1, + Chem.BondType.DOUBLE: 2, + Chem.BondType.TRIPLE: 3, + Chem.BondType.AROMATIC: 4, +} + + +def extract_bond_matrix(mol: Chem.Mol) -> torch.Tensor: + """Extract bond matrix from RDKit molecule. + + Creates a symmetric matrix where entry [i,j] indicates the bond type + between atoms i and j. + + Parameters + ---------- + mol : Chem.Mol + RDKit molecule object with atoms and bonds. + + Returns + ------- + torch.Tensor + Bond type matrix of shape [N_atoms, N_atoms] with values: + 0 = no bond + 1 = single bond + 2 = double bond + 3 = triple bond + 4 = aromatic bond + 5 = other bond type + """ + n_atoms = mol.GetNumAtoms() + bond_matrix = torch.zeros(n_atoms, n_atoms, dtype=torch.long) + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + bond_type = RDKIT_BOND_TYPE_MAP.get(bond.GetBondType(), 5) # 5 = OTHER + bond_matrix[i, j] = bond_type + bond_matrix[j, i] = bond_type # Symmetric + + return bond_matrix + + +def extract_element_indices(mol: Chem.Mol, use_extended_vocab: bool = False) -> torch.Tensor: + """Extract element indices from RDKit molecule. + + Maps each atom's element symbol to its index in the chosen vocabulary. + + Parameters + ---------- + mol : Chem.Mol + RDKit molecule object. + use_extended_vocab : bool + If True, use ELEMENT_VOCAB_EXTENDED (25 tokens) to match Gen-UME. + If False (default), use ELEMENT_VOCAB (14 tokens) for latent generator. + + Returns + ------- + torch.Tensor + Element indices of shape [N_atoms] with integer values. + + If use_extended_vocab=False (default, ELEMENT_VOCAB, 14 tokens): + 0=PAD, 1=B, 2=Bi, 3=Br, 4=C, 5=Cl, 6=F, 7=H, 8=I, 9=N, 10=O, 11=P, 12=S, 13=Si + + If use_extended_vocab=True (ELEMENT_VOCAB_EXTENDED, 25 tokens): + 0=PAD, 1=MASK, 2=UNK, 3=C, 4=N, 5=O, 6=S, 7=P, 8=F, 9=Cl, 10=Br, 11=I, + 12=B, 13=Si, 14=Se, 15=As, 16=Zn, 17=Fe, 18=Cu, 19=Mg, 20=Ca, 21=Na, + 22=K, 23=Bi, 24=H + """ + if use_extended_vocab: + vocab = ELEMENT_VOCAB_EXTENDED_TO_IDX + default_idx = 2 # UNK token + else: + vocab = ELEMENT_TO_IDX + default_idx = 0 # PAD token (no UNK in ELEMENT_VOCAB) + + element_indices = [] + for atom in mol.GetAtoms(): + symbol = atom.GetSymbol() + idx = vocab.get(symbol, default_idx) + element_indices.append(idx) + + return torch.tensor(element_indices, dtype=torch.long) + + aa_3to1 = { "ALA": "A", "ARG": "R", @@ -184,6 +276,142 @@ def load_pdb(filepath: str, add_batch_dim: bool = True) -> dict[str, Any] | None return structure_data +def load_pdb_atom14(pdb_file, add_batch_dim: bool = True) -> dict[str, Any]: + """Convert a PDB file to a PyTorch tensor. + + Args: + filepath (str): Path to the PDB file. Can be a local path or an S3 URI. + + Returns: + dict: A dictionary containing the following keys: + - 'pdb_path': The path to the PDB file. + - 'sequence': A tensor of shape (1, N) containing the amino acid sequence as integer indices. + - 'sequence_str': A string representing the amino acid sequence in one-letter codes. + - 'atom14_coords': A tensor of shape (1, N, 14, 3) containing the coordinates of the atom14 atoms. + - 'chains_ids': A tensor of shape (1, N) containing the chain IDs. + - 'indices': A tensor of shape (1, N) containing the residue numbers. + - 'atom14_mask': A tensor of shape (1, N) containing the mask for the coordinates. + - 'real_chains': A tensor of shape (1, N) containing the real chain IDs. + """ + if pdb_file.startswith("s3://"): + # Parse S3 URI + s3 = boto3.client("s3") + bucket, key = pdb_file[5:].split("/", 1) + + # Download the file locally + local_file = "/tmp/" + os.path.basename(pdb_file) + s3.download_file(bucket, key, local_file) + pdb_file = local_file + + # Read PDB or CIF file to dataframe + if pdb_file.endswith(".cif"): + pmmcif = PandasMmcif() + df = pmmcif.read_mmcif(pdb_file).df["ATOM"] + # rename label_atom_id to atom_name + df = df.rename(columns={"label_atom_id": "atom_name"}) + # rename Cartn_x, Cartn_y, Cartn_z to x_coord, y_coord, z_coord + df = df.rename(columns={"Cartn_x": "x_coord", "Cartn_y": "y_coord", "Cartn_z": "z_coord"}) + # rename auth_comp_id to residue_name + df = df.rename(columns={"label_seq_id": "residue_number"}) + df = df.rename(columns={"auth_comp_id": "residue_name"}) + # ensure that residue_number is an integer + df["residue_number"] = df["residue_number"].astype(int) + df_coords = df + group_chain = df_coords.groupby("auth_asym_id") + else: + df = cpdb.parse(pdb_file, df=True) + df = df[df["record_name"] == "ATOM"] + df_coords = df + group_chain = df_coords.groupby("chain_id") + atom14_coords = [] + atom14_mask = [] + sequence = [] + chains = [] + residue_numbers = [] + + for chain_id, chain in group_chain: + group_residue = chain.groupby("residue_number") + for residue_number, residue in group_residue: + residue_name = residue["residue_name"].iloc[0] + # Skip non-standard residues + if residue_name not in residue_constants.restype_name_to_atom_thin_names: + logger.warning( + f"Skipping non-standard residue {residue_name} at position {residue_number} in chain {chain_id}" + ) + continue + atom14_atom_names = residue_constants.restype_name_to_atom_thin_names[residue_name] + atom14_coords_list = [] + atom14_mask_list = [] + atom14_atom_names_list = [] + for atom_name in atom14_atom_names: + if atom_name != "": + if atom_name in residue["atom_name"].values: + coords_x = residue[residue["atom_name"] == atom_name]["x_coord"].values[0] + coords_y = residue[residue["atom_name"] == atom_name]["y_coord"].values[0] + coords_z = residue[residue["atom_name"] == atom_name]["z_coord"].values[0] + atom14_coords_list.append(np.array([coords_x, coords_y, coords_z])) + atom14_mask_list.append(1) + atom14_atom_names_list.append(atom_name) + else: + atom14_coords_list.append(np.array([0.0, 0.0, 0.0])) + atom14_mask_list.append(0) + atom14_atom_names_list.append("") + else: + atom14_coords_list.append(np.array([0.0, 0.0, 0.0])) + atom14_mask_list.append(0) + atom14_atom_names_list.append("") + atom14_coords.append(np.array(atom14_coords_list)) + atom14_mask.append(np.array(atom14_mask_list)) + sequence.append(residue["residue_name"].values[0]) + chains.append(chain_id) + residue_numbers.append(residue_number) + atom14_coords = np.array(atom14_coords) + atom14_coords = torch.tensor(atom14_coords, dtype=torch.float32) + atom14_mask = np.array(atom14_mask) + atom14_mask = torch.tensor(atom14_mask, dtype=torch.float32) + residue_numbers = np.array(residue_numbers) + + # Convert 3-letter codes to 1-letter codes + sequence_1letter = [aa_3to1.get(aa, "X") for aa in sequence] + + # Create the string sequence + sequence_str = "".join(sequence_1letter) + + # Convert to tensor indices + sequence = [residue_constants.restype_order_with_x[aa] for aa in sequence_1letter] + sequence = torch.tensor(sequence, dtype=torch.int32) + + # get ord of chains but make sure chain is a character + chains = [ord(chain[0]) for chain in chains] + real_chains = torch.tensor(chains, dtype=torch.int32) + + # renumber residue_numbers such that when the chain changes, the residue_numbers are continuous+200 + residue_numbers = torch.tensor(residue_numbers, dtype=torch.int32) + chain_changes = np.diff(chains, prepend=chains[0]) != 0 + chains = np.cumsum(chain_changes) * 200 + chains = torch.tensor(chains) + residue_numbers = residue_numbers + chains + + structure_data = { + "pdb_path": pdb_file, + "sequence": sequence, + "sequence_str": sequence_str, + "atom14_coords": atom14_coords, + "chains_ids": chains, + "indices": residue_numbers, + "atom14_mask": atom14_mask, + "real_chains": real_chains, + } + if add_batch_dim: + structure_data["sequence"] = structure_data["sequence"][None] + structure_data["atom14_coords"] = structure_data["atom14_coords"][None] + structure_data["atom14_mask"] = structure_data["atom14_mask"][None] + structure_data["chains_ids"] = structure_data["chains_ids"][None] + structure_data["indices"] = structure_data["indices"][None] + structure_data["real_chains"] = structure_data["real_chains"][None] + return structure_data + + def reorder_molecule(mol, new_order): """ Create a new molecule with atoms reordered according to new_order. @@ -225,22 +453,43 @@ def reorder_molecule(mol, new_order): return new_mol.GetMol() -def load_ligand(filepath: str, add_batch_dim: bool = True, canonical_order: bool = True) -> dict[str, Any]: +def load_ligand( + filepath: str, + add_batch_dim: bool = True, + canonical_order: bool = True, + use_extended_element_vocab: bool = False, +) -> dict[str, Any]: """Convert a ligand file to a PyTorch tensor. - Args: - filepath (str): Path to the ligand file. Can be a local path or an S3 URI. - Supports .pdb, .mol2, and .sdf formats. - add_batch_dim (bool): Whether to add a batch dimension to the output. - canonical_order (bool): Whether to reorder the atoms to the canonical order. - - Returns: - dict: A dictionary containing the following keys: - - 'pdb_path': The path to the ligand file. Could be .pdb or .mol2 or .sdf - - 'atom_names': A list of strings representing the atom names. - - 'atom_coords': A tensor of shape (1, N, 3) containing the coordinates of the ligand atoms. - - 'atom_indices': A tensor of shape (1, N) containing the atom indices. - - 'mask': A tensor of shape (1, N) containing the mask for the coordinates. + Parameters + ---------- + filepath : str + Path to the ligand file. Can be a local path or an S3 URI. + Supports .pdb, .mol2, and .sdf formats. + add_batch_dim : bool + Whether to add a batch dimension to the output. + canonical_order : bool + Whether to reorder the atoms to the canonical order (mol2/sdf only). + use_extended_element_vocab : bool + If True, use ELEMENT_VOCAB_EXTENDED (25 tokens) for element_indices + to match Gen-UME protein-ligand encoder. + If False (default), use ELEMENT_VOCAB (14 tokens) for latent generator. + + Returns + ------- + dict + A dictionary containing the following keys: + - 'pdb_path': The path to the ligand file. + - 'atom_names': A list of strings representing the atom symbols. + - 'atom_coords': Tensor of shape [N, 3] or [1, N, 3] with coordinates. + - 'atom_indices': Tensor of shape [N] or [1, N] with atom indices. + - 'mask': Tensor of shape [N] or [1, N] with validity mask. + - 'element_indices': Tensor of shape [N] or [1, N] with element type indices + (only for mol2/sdf files). Uses ELEMENT_VOCAB (14 tokens) by default, + or ELEMENT_VOCAB_EXTENDED (25 tokens) if use_extended_element_vocab=True. + - 'bond_matrix': Tensor of shape [N, N] with bond types + (only for mol2/sdf files). Values: 0=none, 1=single, 2=double, + 3=triple, 4=aromatic, 5=other. """ if filepath.startswith("s3://"): # Parse S3 URI @@ -285,17 +534,25 @@ def load_ligand(filepath: str, add_batch_dim: bool = True, canonical_order: bool atom_numbers = torch.tensor(atom_numbers, dtype=torch.int32) mask = torch.ones(coords.shape[0], dtype=torch.float32) + # Extract bond matrix and element indices from RDKit molecule + bond_matrix = extract_bond_matrix(mol) + element_indices = extract_element_indices(mol, use_extended_vocab=use_extended_element_vocab) + structure_data = { "pdb_path": filepath, "atom_names": atom_names, "atom_coords": coords, "atom_indices": atom_numbers, "mask": mask, + "element_indices": element_indices, + "bond_matrix": bond_matrix, } if add_batch_dim: structure_data["atom_coords"] = structure_data["atom_coords"][None] structure_data["atom_indices"] = structure_data["atom_indices"][None] structure_data["mask"] = structure_data["mask"][None] + structure_data["element_indices"] = structure_data["element_indices"][None] + # Note: bond_matrix is NOT batched as it's [N, N] and collation handles it return structure_data diff --git a/src/lobster/model/latent_generator/io/_write_pdb.py b/src/lobster/model/latent_generator/io/_write_pdb.py index f1dbfc71..795b310f 100644 --- a/src/lobster/model/latent_generator/io/_write_pdb.py +++ b/src/lobster/model/latent_generator/io/_write_pdb.py @@ -1,9 +1,152 @@ import logging +import math import torch py_logger = logging.getLogger(__name__) +# Ideal geometry constants for backbone atoms +CA_CB_BOND = 1.521 # Å - standard CA-CB bond length +C_O_BOND = 1.231 # Å - carbonyl C=O bond length +N_CA_CB_ANGLE = math.radians(110.5) # tetrahedral angle +CA_C_O_ANGLE = math.radians(120.5) # sp2 carbonyl angle + +# Glycine index in num2aa (GLY has no CB) +GLY_INDEX = 7 + + +def _normalize(v): + """Normalize a vector, handling zero-length vectors.""" + norm = torch.linalg.norm(v) + if norm < 1e-8: + return v + return v / norm + + +def calculate_idealized_cb(n_pos, ca_pos, c_pos): + """Calculate CB position using tetrahedral geometry. + + Places CB in the standard L-amino acid position using the + tetrahedral geometry around the CA atom. + + Args: + n_pos: N atom coordinates (torch.Tensor, shape [3]) + ca_pos: CA atom coordinates (torch.Tensor, shape [3]) + c_pos: C atom coordinates (torch.Tensor, shape [3]) + + Returns: + CB position as torch.Tensor of shape [3] + """ + # Vectors from CA to N and C + n_vec = n_pos - ca_pos + c_vec = c_pos - ca_pos + + # Normalize + n_unit = _normalize(n_vec) + c_unit = _normalize(c_vec) + + # Calculate the N-CA-C plane normal + plane_normal = torch.linalg.cross(n_unit, c_unit) + plane_normal_norm = torch.linalg.norm(plane_normal) + + if plane_normal_norm > 1e-6: + plane_normal = plane_normal / plane_normal_norm + else: + # Fallback for collinear atoms + plane_normal = torch.tensor([0.0, 0.0, 1.0], dtype=n_pos.dtype, device=n_pos.device) + + # CB direction: solve for position that makes correct angles with N and C + # For tetrahedral geometry, CB should make ~110.5° with both N and C + cos_target = math.cos(N_CA_CB_ANGLE) # cos(110.5°) ≈ -0.35 + cos_ncc = torch.dot(n_unit, c_unit).item() # cos of N-CA-C angle + + # From the constraint equations: + # CB_dir · n_unit = cos(110.5°) + # CB_dir · c_unit = cos(110.5°) + # Solving: a = b = cos_target / (1 + cos_ncc) + denom = 1 + cos_ncc + if abs(denom) < 1e-6: + denom = 1e-6 + a = cos_target / denom + + # c² = 1 - 2*a²*(1 + cos_ncc) + c_sq = 1 - 2 * a * a * (1 + cos_ncc) + if c_sq < 0: + c_sq = 0.01 # Handle numerical issues + + # For L-amino acids, CB is on the positive side of the plane + c_coeff = math.sqrt(c_sq) + + cb_dir = a * n_unit + a * c_unit + c_coeff * plane_normal + cb_dir = _normalize(cb_dir) + + cb_pos = ca_pos + CA_CB_BOND * cb_dir + return cb_pos + + +def calculate_idealized_o(ca_pos, c_pos, next_n_pos=None): + """Calculate carbonyl O position. + + Places O in the peptide plane, trans to the next residue's N + (if available) or using simple geometry. + + Args: + ca_pos: CA atom coordinates (torch.Tensor, shape [3]) + c_pos: C atom coordinates (torch.Tensor, shape [3]) + next_n_pos: Next residue's N atom coordinates (optional) + + Returns: + O position as torch.Tensor of shape [3] + """ + c_to_ca = ca_pos - c_pos + c_to_ca = _normalize(c_to_ca) + + if next_n_pos is not None: + # O is roughly trans to N across the C-CA axis + c_to_n = next_n_pos - c_pos + c_to_n = _normalize(c_to_n) + + # Calculate plane normal + plane_normal = torch.linalg.cross(c_to_ca, c_to_n) + plane_normal_norm = torch.linalg.norm(plane_normal) + + if plane_normal_norm > 1e-6: + plane_normal = plane_normal / plane_normal_norm + + # O direction is in the plane, roughly opposite to N + # Use the CA-C-N angle to place O correctly + # O should be at ~120° from both CA and N (sp2 geometry) + o_dir = torch.linalg.cross(plane_normal, c_to_n) + o_dir = _normalize(o_dir) + + # Blend to get correct angle + # O is at ~120° from N, so mix -c_to_n and perpendicular + o_dir = -c_to_n * 0.5 + o_dir * 0.866 # cos(120°), sin(120°) + o_dir = _normalize(o_dir) + else: + # Fallback: place O perpendicular to CA direction + o_dir = _get_perpendicular(c_to_ca) + else: + # No next N available (terminal residue) + # Place O roughly perpendicular to CA-C bond + o_dir = _get_perpendicular(c_to_ca) + + o_pos = c_pos + C_O_BOND * o_dir + return o_pos + + +def _get_perpendicular(v): + """Get a unit vector perpendicular to v.""" + # Choose reference axis that's not parallel to v + if abs(v[2]) < 0.9: + ref = torch.tensor([0.0, 0.0, 1.0], dtype=v.dtype, device=v.device) + else: + ref = torch.tensor([1.0, 0.0, 0.0], dtype=v.dtype, device=v.device) + + perp = torch.linalg.cross(v, ref) + return _normalize(perp) + + num2aa = [ "ALA", "ARG", @@ -673,7 +816,21 @@ # writepdb -def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None): +def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None, add_cb_o=True): + """Write protein structure to a PDB file. + + Args: + filename: Output PDB filename + atoms: Tensor of atom coordinates. Shape can be: + - [num_residues, 3] for CA-only + - [num_residues, 3, 3] for backbone (N, CA, C) + - [num_residues, 14, 3] or [num_residues, 27, 3] for full atoms + seq: Tensor of residue type indices (into num2aa) + idx_pdb: Optional tensor of residue numbers (default: 1-indexed sequential) + bfacts: Optional tensor of B-factors (default: zeros) + add_cb_o: If True and atoms has shape [N, 3, 3] (backbone only), add + idealized O and CB atoms. CB is not added for glycine. Default: True + """ f = open(filename, "w") ctr = 1 scpu = seq.cpu().squeeze() @@ -684,6 +841,8 @@ def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None): idx_pdb = 1 + torch.arange(atomscpu.shape[0]) Bfacts = torch.clamp(bfacts.cpu(), 0, 1) + num_residues = len(scpu) + for i, s in enumerate(scpu): if len(atomscpu.shape) == 2: f.write( @@ -691,11 +850,52 @@ def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None): ) ctr += 1 elif atomscpu.shape[1] == 3: - for j, atm_j in enumerate([" N ", " CA ", " C "]): + if add_cb_o: + # Write N, CA, C, O, CB (CB only for non-glycine) + n_pos = atomscpu[i, 0] + ca_pos = atomscpu[i, 1] + c_pos = atomscpu[i, 2] + + # Write N + f.write( + f"{'ATOM':<6}{ctr:>5} {' N ':>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {n_pos[0]:8.3f}{n_pos[1]:8.3f}{n_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + ctr += 1 + + # Write CA + f.write( + f"{'ATOM':<6}{ctr:>5} {' CA ':>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {ca_pos[0]:8.3f}{ca_pos[1]:8.3f}{ca_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + ctr += 1 + + # Write C f.write( - f"{'ATOM':<6}{ctr:>5} {atm_j:>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {atomscpu[i, j, 0]:8.3f}{atomscpu[i, j, 1]:8.3f}{atomscpu[i, j, 2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + f"{'ATOM':<6}{ctr:>5} {' C ':>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {c_pos[0]:8.3f}{c_pos[1]:8.3f}{c_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" ) ctr += 1 + + # Write O (carbonyl oxygen) + next_n_pos = atomscpu[i + 1, 0] if i < num_residues - 1 else None + o_pos = calculate_idealized_o(ca_pos, c_pos, next_n_pos) + f.write( + f"{'ATOM':<6}{ctr:>5} {' O ':>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {o_pos[0]:8.3f}{o_pos[1]:8.3f}{o_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + ctr += 1 + + # Write CB (skip for glycine) + if s != GLY_INDEX: + cb_pos = calculate_idealized_cb(n_pos, ca_pos, c_pos) + f.write( + f"{'ATOM':<6}{ctr:>5} {' CB ':>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {cb_pos[0]:8.3f}{cb_pos[1]:8.3f}{cb_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + ctr += 1 + else: + # Original behavior: just N, CA, C + for j, atm_j in enumerate([" N ", " CA ", " C "]): + f.write( + f"{'ATOM':<6}{ctr:>5} {atm_j:>4} {num2aa[s]:>3} {'A'}{idx_pdb[i]:>4} {atomscpu[i, j, 0]:8.3f}{atomscpu[i, j, 1]:8.3f}{atomscpu[i, j, 2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + ctr += 1 else: natoms = atomscpu.shape[1] if natoms != 14 and natoms != 27: @@ -755,6 +955,8 @@ def writepdb_ligand_complex( ligand_bfacts=None, ligand_chain="L", ligand_resname="LIG", + ligand_bond_matrix=None, + add_cb_o=True, ): """Write protein and ligand atoms to a PDB file. @@ -771,6 +973,10 @@ def writepdb_ligand_complex( ligand_bfacts: Optional tensor of ligand B-factors (default: zeros) ligand_chain: Chain ID for ligand (default: "L") ligand_resname: Residue name for ligand atoms (default: "LIG") + ligand_bond_matrix: Optional bond matrix [num_atoms, num_atoms] where non-zero values + indicate bonds. Used to write CONECT records for proper bond visualization. + add_cb_o: If True and protein_atoms has shape [N, 3, 3] (backbone only), add + idealized O and CB atoms. CB is not added for glycine. Default: True """ # Check if protein_atoms and ligand_atoms are provided @@ -791,6 +997,7 @@ def writepdb_ligand_complex( protein_idx = 1 + torch.arange(atomscpu.shape[0]) Bfacts = torch.clamp(protein_bfacts.cpu(), 0, 1) + num_residues = len(scpu) for i, s in enumerate(scpu): if len(atomscpu.shape) == 2: @@ -801,13 +1008,53 @@ def writepdb_ligand_complex( atom_counter += 1 elif atomscpu.shape[1] == 3: - # Backbone atoms (N, CA, C) - for j, atm_j in enumerate([" N ", " CA ", " C "]): + if add_cb_o: + # Write N, CA, C, O, CB (CB only for non-glycine) + n_pos = atomscpu[i, 0] + ca_pos = atomscpu[i, 1] + c_pos = atomscpu[i, 2] + + # Write N + f.write( + f"{'ATOM':<6}{atom_counter:>5} {' N ':>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {n_pos[0]:8.3f}{n_pos[1]:8.3f}{n_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + atom_counter += 1 + + # Write CA + f.write( + f"{'ATOM':<6}{atom_counter:>5} {' CA ':>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {ca_pos[0]:8.3f}{ca_pos[1]:8.3f}{ca_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + atom_counter += 1 + + # Write C + f.write( + f"{'ATOM':<6}{atom_counter:>5} {' C ':>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {c_pos[0]:8.3f}{c_pos[1]:8.3f}{c_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + atom_counter += 1 + + # Write O (carbonyl oxygen) + next_n_pos = atomscpu[i + 1, 0] if i < num_residues - 1 else None + o_pos = calculate_idealized_o(ca_pos, c_pos, next_n_pos) f.write( - f"{'ATOM':<6}{atom_counter:>5} {atm_j:>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {atomscpu[i, j, 0]:8.3f}{atomscpu[i, j, 1]:8.3f}{atomscpu[i, j, 2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + f"{'ATOM':<6}{atom_counter:>5} {' O ':>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {o_pos[0]:8.3f}{o_pos[1]:8.3f}{o_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" ) atom_counter += 1 + # Write CB (skip for glycine) + if s != GLY_INDEX: + cb_pos = calculate_idealized_cb(n_pos, ca_pos, c_pos) + f.write( + f"{'ATOM':<6}{atom_counter:>5} {' CB ':>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {cb_pos[0]:8.3f}{cb_pos[1]:8.3f}{cb_pos[2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + atom_counter += 1 + else: + # Original behavior: just N, CA, C + for j, atm_j in enumerate([" N ", " CA ", " C "]): + f.write( + f"{'ATOM':<6}{atom_counter:>5} {atm_j:>4} {num2aa[s]:>3} {protein_chain}{protein_idx[i]:>4} {atomscpu[i, j, 0]:8.3f}{atomscpu[i, j, 1]:8.3f}{atomscpu[i, j, 2]:8.3f}{1.0:6.2f}{Bfacts[i]:6.2f}\n" + ) + atom_counter += 1 + else: # Full atom representation natoms = atomscpu.shape[1] @@ -909,5 +1156,40 @@ def writepdb_ligand_complex( ) atom_counter += 1 + # Write CONECT records for ligand bonds if bond matrix provided + if ligand_bond_matrix is not None: + bond_mat = ( + ligand_bond_matrix.cpu().numpy() + if isinstance(ligand_bond_matrix, torch.Tensor) + else ligand_bond_matrix + ) + # ligand_start_atom is the first atom serial number for ligand atoms + ligand_start_atom = atom_counter - latoms.shape[0] + + for i in range(latoms.shape[0]): + # Find all atoms bonded to atom i + bonded_atoms = [] + for j in range(latoms.shape[0]): + if i != j and bond_mat[i, j] > 0: + bonded_atoms.append(ligand_start_atom + j) + + if bonded_atoms: + # Write CONECT record: atom serial number followed by bonded atoms + atom_serial = ligand_start_atom + i + # PDB CONECT format: up to 4 bonded atoms per line + conect_line = f"CONECT{atom_serial:5d}" + for bonded in bonded_atoms[:4]: + conect_line += f"{bonded:5d}" + f.write(conect_line + "\n") + + # If more than 4 bonds, write continuation lines + if len(bonded_atoms) > 4: + for batch_start in range(4, len(bonded_atoms), 4): + batch = bonded_atoms[batch_start : batch_start + 4] + conect_line = f"CONECT{atom_serial:5d}" + for bonded in batch: + conect_line += f"{bonded:5d}" + f.write(conect_line + "\n") + # Write TER record to indicate end of chains f.write("TER\nEND\n") diff --git a/src/lobster/model/latent_generator/models/vit/_vit_utils.py b/src/lobster/model/latent_generator/models/vit/_vit_utils.py index c16fe1b7..3608d903 100644 --- a/src/lobster/model/latent_generator/models/vit/_vit_utils.py +++ b/src/lobster/model/latent_generator/models/vit/_vit_utils.py @@ -15,7 +15,11 @@ from einops import rearrange from einops.layers.torch import Rearrange -from lobster.model.latent_generator.utils.residue_constants import ELEMENT_VOCAB +from lobster.model.latent_generator.utils.residue_constants import ( + ELEMENT_VOCAB, + ELEMENT_VOCAB_EXTENDED, + NUM_BOND_TYPES, +) # os.environ["HYDRA_FULL_ERROR"] = "1" @@ -500,6 +504,8 @@ def __init__( add_cls_token: bool = False, sequence_embedding: bool = False, ligand_atom_embedding: bool = False, + use_ligand_bond_embedding: bool = False, + use_extended_element_vocab: bool = False, ): super().__init__() @@ -532,9 +538,24 @@ def __init__( # Ligand atom type embedding using element vocabulary if ligand_atom_embedding: - logger.info("Adding ligand atom type embeddings") - # Use element vocabulary size - self.ligand_atom_type_embedding = nn.Embedding(len(ELEMENT_VOCAB), embed_dim_hidden) + if use_extended_element_vocab: + vocab_size = len(ELEMENT_VOCAB_EXTENDED) + logger.info(f"Adding ligand atom type embeddings (extended vocab: {vocab_size} tokens)") + else: + vocab_size = len(ELEMENT_VOCAB) + logger.info(f"Adding ligand atom type embeddings (standard vocab: {vocab_size} tokens)") + self.ligand_atom_type_embedding = nn.Embedding(vocab_size, embed_dim_hidden) + + # Ligand bond matrix embedding (optional, for topology-aware encoding) + self.use_ligand_bond_embedding = use_ligand_bond_embedding + if use_ligand_bond_embedding: + logger.info("Adding ligand bond matrix embeddings") + from lobster.model.gen_ume._bond_embedding import BondMatrixEmbedding + + self.ligand_bond_embedding = BondMatrixEmbedding( + hidden_size=embed_dim_hidden, + num_bond_types=NUM_BOND_TYPES, + ) transformer_seq_len = seq_len // (2**1) assert transformer_seq_len % patch_size == 0 @@ -700,8 +721,10 @@ def forward( ligand_residue_index=None, sequence=None, ligand_atom_types=None, + ligand_bond_matrix=None, **kwargs, ): + # === STAGE 1: Process protein coordinates === if coords is not None: B, L, n_atoms = coords.shape[:3] coords_gt = coords.clone() @@ -712,6 +735,7 @@ def forward( else: coords_gt = None + # === STAGE 2: Process ligand coordinates === if self.encode_ligand and ligand_coords is not None: ligand_embedding = self.ligand_to_embedding(ligand_coords) @@ -720,6 +744,10 @@ def forward( ligand_type_embedding = self.ligand_atom_type_embedding(ligand_atom_types) ligand_embedding = ligand_embedding + ligand_type_embedding + # Add bond matrix embeddings if available (topology-aware encoding) + if self.use_ligand_bond_embedding and ligand_bond_matrix is not None: + ligand_embedding = self.ligand_bond_embedding(ligand_embedding, ligand_bond_matrix, ligand_mask) + if coords is not None: x = torch.cat([x, ligand_embedding], -2) seq_mask = torch.cat([seq_mask, ligand_mask], -1) @@ -729,6 +757,7 @@ def forward( seq_mask = ligand_mask residue_index = ligand_residue_index + # === STAGE 3: Pairwise distance features (if enabled) === if self.concat_sine_pw: with torch.no_grad(): pw_coords = self.pairwise_distances(coords_gt, ligand_coords=ligand_coords) @@ -762,6 +791,7 @@ def forward( spatial_attention_mask_ = torch.cat([padding_col, spatial_attention_mask_], dim=1) spatial_attention_mask_ = torch.cat([padding_row, spatial_attention_mask_], dim=2) + # === STAGE 4: Transformer === x = self.transformer( x, time=time_cond, @@ -772,6 +802,7 @@ def forward( spatial_attention_mask=spatial_attention_mask_, ) + # === STAGE 5: Final projection === x_out = self.to_hidden(x) if return_embeddings: @@ -894,11 +925,15 @@ def forward( ligand_mask=None, **kwargs, ): + # Check if ligand has actual data (not just empty tensor) + has_ligand_data = self.encode_ligand and ligand_quant is not None and ligand_quant.shape[1] > 0 + if x_quant is not None: x_emb = self.embed_struc_tokens(x_quant) else: x_emb = None - if self.encode_ligand and ligand_quant is not None: + + if has_ligand_data: ligand_emb = self.embed_ligand_tokens(ligand_quant) if x_emb is not None: B, L, D = x_emb.shape @@ -924,7 +959,7 @@ def forward( x = self.ffn(x_out) - if self.encode_ligand and ligand_quant is not None: + if has_ligand_data: if x_quant is not None: ligand_x = x[:, L:, :] x = x[:, :L, :] @@ -937,7 +972,7 @@ def forward( x = self.from_patch(x) x = rearrange(x, "b c n a -> b n a c") - if self.encode_ligand and ligand_quant is not None: + if has_ligand_data: out = {"protein_coords": x, "ligand_coords": ligand_x} else: out = x diff --git a/src/lobster/model/latent_generator/quantizer/__init__.py b/src/lobster/model/latent_generator/quantizer/__init__.py index 64f91a7d..348db925 100644 --- a/src/lobster/model/latent_generator/quantizer/__init__.py +++ b/src/lobster/model/latent_generator/quantizer/__init__.py @@ -1,2 +1,4 @@ from ._ligand_tokenizer import LigandTokenizer from ._slq import SimpleLinearQuantizer +from ._fsq import FiniteScalarQuantizer +from ._fsq_ligand_tokenizer import FSQLigandTokenizer diff --git a/src/lobster/model/latent_generator/quantizer/_fsq.py b/src/lobster/model/latent_generator/quantizer/_fsq.py new file mode 100644 index 00000000..af9a6c12 --- /dev/null +++ b/src/lobster/model/latent_generator/quantizer/_fsq.py @@ -0,0 +1,113 @@ +""" +Torch implementation of Finite Scalar Quantization +https://arxiv.org/abs/2309.15505, Appendix 1 +""" + +import torch + + +def round_ste(z): + """Round with straight through gradients.""" + zhat = torch.round(z) + return z + (zhat - z).detach() + + +class FiniteScalarQuantizer(torch.nn.Module): + def __init__(self, levels: list[int], return_oh_like: bool = True): + super().__init__() + + levels = torch.tensor(levels) + basis = torch.cat([torch.tensor([1]), torch.cumprod(levels[:-1], dim=0)]).to(dtype=torch.int32) + self.levels = levels + self.basis = basis + self.return_oh_like = return_oh_like + # number of dimensions expect from inputs + self.num_dimensions = len(levels) + + # size of the codebook + self.codebook_size = torch.prod(levels) + self.implicit_codebook = self.indexes_to_codes(torch.arange(self.codebook_size)) + self.n_tokens = self.codebook_size + print("Codebook size:", self.codebook_size) + + @property + def codebook(self): + return self.implicit_codebook + + def bound(self, z, eps=1e-3): + """Bound z, an array of shape (..., d).""" + levels = self.levels.to(z.device) + half_l = (levels - 1) * (1 - eps) / 2 + offset = torch.where(levels % 2 == 1, 0.0, 0.5) + shift = torch.tan(offset / half_l) + return torch.tanh(z + shift) * half_l - offset + + def _quantize(self, z, mask=None, **kwargs): + """Quanitzes z, returns quantized zhat as codewords, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self.levels // 2 # Renormalize to [-1, 1]. + half_width = half_width.to(z.device) + z_tokens = quantized / half_width + return z_tokens, z, mask + + def _scale_and_shift(self, zhat_normalized): + levels = self.levels.to(zhat_normalized.device) + half_width = levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + levels = self.levels.to(zhat.device) + half_width = levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indexes(self, zhat): + # assert zhat.shape[-1] == len(self.levels) + basis = self.basis.to(zhat.device) + zhat = self._scale_and_shift(zhat) + return (zhat * basis).sum(axis=-1) + + def indexes_to_codes(self, indices): + indices = indices.unsqueeze(-1) + + # def _maybe_cast_shape(input_arr, target_arr) + # # both should have 2 dimensions + # # but user-specified indices might be batched + # if input_arr.shape != target_arr.shape: + # return input_arr.expand_as(target_arr) + # else: + # return input_arr + + # basis = _maybe_cast_shape(self.basis, indices) + # levels = _maybe_cast_shape(self.levels, indices) + basis = self.basis.to(indices.device) + levels = self.levels.to(indices.device) + codes_non_centered = torch.remainder(torch.floor_divide(indices, basis), levels) + return self._scale_and_shift_inverse(codes_non_centered) + + def quantize(self, z, mask=None, **kwargs): + z_tokens, z, mask = self._quantize(z, mask=mask, **kwargs) + if self.return_oh_like: + # Get continuous indexes (B, L) + continuous_indexes = self.codes_to_indexes(z_tokens) + + # Get codebook entries (codebook_size, num_dimensions) + codebook = self.implicit_codebook.to(z_tokens.device) + + # Compute similarity between z_tokens and each codebook entry + # z_tokens: (B, L, num_dimensions) + # codebook: (codebook_size, num_dimensions) + # Result: (B, L, codebook_size) + oh_like = torch.matmul(z_tokens, codebook.T) + + # Optionally: use straight-through with discrete one-hot for sharper distribution + # Get discrete indexes for one-hot + codebook_size_int = int(self.codebook_size.item()) + discrete_indexes = torch.round(continuous_indexes).long().clamp(0, codebook_size_int - 1) + discrete_oh = torch.nn.functional.one_hot(discrete_indexes, num_classes=codebook_size_int).float() + + # Use straight-through: gradients flow through oh_like, forward uses discrete_oh + oh_like = oh_like + (discrete_oh - oh_like).detach() + + return oh_like, z, mask + else: + return z_tokens, z, mask diff --git a/src/lobster/model/latent_generator/quantizer/_fsq_ligand_tokenizer.py b/src/lobster/model/latent_generator/quantizer/_fsq_ligand_tokenizer.py new file mode 100644 index 00000000..915734fd --- /dev/null +++ b/src/lobster/model/latent_generator/quantizer/_fsq_ligand_tokenizer.py @@ -0,0 +1,73 @@ +import torch + +from lobster.model.latent_generator.quantizer._fsq import FiniteScalarQuantizer + + +class FSQLigandTokenizer(torch.nn.Module): + """Ligand tokenizer using Finite Scalar Quantization (FSQ) for both protein and ligand.""" + + def __init__( + self, + protein_levels: list[int] | None = None, + ligand_levels: list[int] | None = None, + return_oh_like: bool = True, + n_tokens: int | None = None, # For hydra config compatibility (ignored, computed from levels) + ): + """ + Initialize FSQ-based ligand tokenizer. + + Args: + protein_levels: FSQ levels for protein tokenization (e.g., [8, 6, 5] for 240 tokens) + ligand_levels: FSQ levels for ligand tokenization (e.g., [8, 6, 5] for 240 tokens) + return_oh_like: Whether to return one-hot-like representation + n_tokens: Ignored - for hydra config compatibility only (n_tokens is computed from levels) + """ + del n_tokens # Not used - computed from levels + super().__init__() + if protein_levels is None: + protein_levels = [8, 6, 5] # 240 tokens + if ligand_levels is None: + ligand_levels = [8, 6, 5] # 240 tokens + + self.protein_tokenizer = FiniteScalarQuantizer( + levels=protein_levels, + return_oh_like=return_oh_like, + ) + self.ligand_tokenizer = FiniteScalarQuantizer( + levels=ligand_levels, + return_oh_like=return_oh_like, + ) + + # Store codebook sizes for external use + self.n_tokens = self.protein_tokenizer.n_tokens + self.ligand_n_tokens = self.ligand_tokenizer.n_tokens + + def quantize(self, z, mask=None, ligand_mask=None): + """ + Quantize protein and ligand embeddings using FSQ. + + Args: + z: Input embeddings of shape (B, L_protein + L_ligand, embed_dim) + mask: Protein mask of shape (B, L_protein) + ligand_mask: Ligand mask of shape (B, L_ligand) + + Returns: + out_tokens: Dict with 'protein_tokens' and 'ligand_tokens' + out_logits: Dict with 'protein_logits' and 'ligand_logits' + out_masks: Dict with 'protein_mask' and 'ligand_mask' + """ + if mask is not None: + B, L = mask.shape + z_protein = z[:, :L, :] + z_ligand = z[:, L:, :] + protein_tokens, protein_logits, protein_mask = self.protein_tokenizer.quantize(z_protein, mask) + ligand_tokens, ligand_logits, ligand_mask = self.ligand_tokenizer.quantize(z_ligand, ligand_mask) + out_tokens = {"protein_tokens": protein_tokens, "ligand_tokens": ligand_tokens} + out_logits = {"protein_logits": protein_logits, "ligand_logits": ligand_logits} + out_masks = {"protein_mask": protein_mask, "ligand_mask": ligand_mask} + else: + ligand_tokens, ligand_logits, ligand_mask = self.ligand_tokenizer.quantize(z, ligand_mask) + out_tokens = {"ligand_tokens": ligand_tokens} + out_logits = {"ligand_logits": ligand_logits} + out_masks = {"ligand_mask": ligand_mask} + return out_tokens, out_logits, out_masks diff --git a/src/lobster/model/latent_generator/quantizer/_slq.py b/src/lobster/model/latent_generator/quantizer/_slq.py index 0be4ca49..590e32cd 100644 --- a/src/lobster/model/latent_generator/quantizer/_slq.py +++ b/src/lobster/model/latent_generator/quantizer/_slq.py @@ -35,6 +35,10 @@ def quantize(self, z, mask=None, **kwargs): if self.gumbel: z_tokens = gumbel_softmax(z_emb, temperature=self.tau, hard=False, include_noise=self.use_gumbel_noise) + # Create default mask if None to ensure consistent behavior + if mask is None: + mask = torch.ones(z_emb.shape[0], z_emb.shape[1], device=z_emb.device) + return z_tokens, z_emb, mask elif self.softmax: @@ -52,4 +56,8 @@ def quantize(self, z, mask=None, **kwargs): z_tokens = torch.nn.functional.one_hot(z_tokens, num_classes=self.n_tokens).float() + # Create default mask if None to ensure consistent behavior + if mask is None: + mask = torch.ones(z_emb.shape[0], z_emb.shape[1], device=z_emb.device) + return z_emb + (z_tokens - z_emb).detach(), z_emb, mask diff --git a/src/lobster/model/latent_generator/reconstruction_results_table.md b/src/lobster/model/latent_generator/reconstruction_results_table.md index 0098473c..1fb3a047 100644 --- a/src/lobster/model/latent_generator/reconstruction_results_table.md +++ b/src/lobster/model/latent_generator/reconstruction_results_table.md @@ -2,23 +2,12 @@ **Evaluation Set**: CASP15 proteins ≤ 512 residues (26 successful reconstructions out of 30 total structures) -| Model | Average RMSD (Å) | Std RMSD (Å) | Min RMSD (Å) | Max RMSD (Å) | -|-------|------------------|--------------|--------------|--------------| -| LG full attention | 1.707 | 0.643 | 0.839 | 3.434 | -| LG 10A | 3.698 | 1.756 | 1.952 | 7.664 | -| LG 20A c6d Aux | 4.395 | 2.671 | 1.678 | 11.306 | -| LG 20A seq 3di c6d Aux | 4.428 | 1.723 | 2.757 | 8.556 | -| LG 20A 3di c6d Aux | 4.484 | 2.458 | 2.390 | 11.696 | -| LG 20A | 4.470 | 3.540 | 1.630 | 12.864 | -| LG 20A seq 3di c6d 512 Aux | 5.761 | 4.349 | 1.188 | 17.442 | -| LG 20A seq Aux | 5.449 | 2.862 | 3.063 | 13.342 | -| LG 20A seq 3di Aux | 6.112 | 3.723 | 2.973 | 17.839 | -| LG 20A 3di Aux | 7.844 | 4.289 | 3.119 | 16.500 | +| Model | Tokens | Average RMSD (Å) | Std RMSD (Å) | Min RMSD (Å) | Max RMSD (Å) | +|-------|--------|------------------|--------------|--------------|--------------| +| LG full attention | 256 | 1.707 | 0.643 | 0.839 | 3.434 | ## Summary - **Best performing model**: LG full attention (1.707 ± 0.643 Å) -- **Second best**: LG 10A (3.698 ± 1.756 Å) -- **Third best**: LG 20A seq 3di c6d Aux (4.428 ± 1.723 Å) All models successfully reconstructed 26 out of 30 structures (86.7% success rate). \ No newline at end of file diff --git a/src/lobster/model/latent_generator/structure_decoder/_vit_decoder.py b/src/lobster/model/latent_generator/structure_decoder/_vit_decoder.py index ca2c054d..d796a7e1 100644 --- a/src/lobster/model/latent_generator/structure_decoder/_vit_decoder.py +++ b/src/lobster/model/latent_generator/structure_decoder/_vit_decoder.py @@ -105,7 +105,10 @@ def forward( **kwargs, ) - if ligand_present: + # Check if decoder actually returned ligand data (handles protein-only data with ligand-capable model) + ligand_output_present = isinstance(emb, dict) and "ligand_coords" in emb + + if ligand_output_present: ligand_emb = emb["ligand_coords"] emb = emb["protein_coords"] assert not torch.isnan(ligand_emb).any() @@ -121,7 +124,7 @@ def forward( assert not torch.isnan(emb).any() emb *= expand(seq_mask, emb) - if ligand_present: + if ligand_output_present: out = {"protein_coords": emb, "ligand_coords": ligand_emb} elif self.refinement_module: out = {"protein_coords": emb, "protein_coords_refinement": emb_refinement} diff --git a/src/lobster/model/latent_generator/structure_encoder/_vit_encoder.py b/src/lobster/model/latent_generator/structure_encoder/_vit_encoder.py index f62195d6..52f2580b 100644 --- a/src/lobster/model/latent_generator/structure_encoder/_vit_encoder.py +++ b/src/lobster/model/latent_generator/structure_encoder/_vit_encoder.py @@ -49,6 +49,8 @@ def __init__( use_sequence_embedding: bool = False, mask_structure: float = 0.0, ligand_atom_embedding: bool = False, + use_ligand_bond_embedding: bool = False, + use_extended_element_vocab: bool = False, *args, **kwargs, ): @@ -86,8 +88,13 @@ def __init__( logger.info(f"use sequence embedding: {self.use_sequence_embedding}") self.ligand_atom_embedding = ligand_atom_embedding logger.info(f"ligand atom embedding: {self.ligand_atom_embedding}") + self.use_ligand_bond_embedding = use_ligand_bond_embedding + logger.info(f"use ligand bond embedding: {self.use_ligand_bond_embedding}") + self.use_extended_element_vocab = use_extended_element_vocab + logger.info(f"use extended element vocab: {self.use_extended_element_vocab}") self.n_atoms = n_atoms + self.embed_dim = embed_dim n_xyz = 3 # Neural networks @@ -115,6 +122,8 @@ def __init__( add_cls_token=add_cls_token, sequence_embedding=use_sequence_embedding, ligand_atom_embedding=ligand_atom_embedding, + use_ligand_bond_embedding=use_ligand_bond_embedding, + use_extended_element_vocab=use_extended_element_vocab, ) def featurize( @@ -128,67 +137,120 @@ def featurize( apply_stochastic_fa: bool = False, backbone_noise: float = None, ): - if "sequence" in batch: + # NEW: Extract validity masks (only present for heterogeneous batches) + protein_valid = batch.get("protein_valid_mask", None) # Shape: (batch_size,) + ligand_valid = batch.get("ligand_valid_mask", None) # Shape: (batch_size,) + + # Determine what data we have + has_proteins = protein_valid is None or protein_valid.any() + has_ligands = ligand_valid is None or ligand_valid.any() + + # Process protein data + if has_proteins and "sequence" in batch: + coords = batch["coords_res"].clone() seq_mask = batch["mask"].clone() residue_index = batch["indices"].clone() - coords = batch["coords_res"].clone() + + # NOTE: If protein_valid exists, some batch positions may be all-zero padding + # The seq_mask will be False for those positions, so downstream processing + # will naturally ignore them via masking + if self.use_sequence_embedding: sequence = batch["sequence"].clone() else: sequence = None else: + coords = None seq_mask = None residue_index = None - coords = None sequence = None - if "ligand_coords" in batch: - # need to figure out how to rotate and translate the ligand coords the same way as the protein coords + # Process ligand data + if has_ligands and "ligand_coords" in batch: ligand_coords = batch["ligand_coords"].clone() ligand_mask = batch["ligand_mask"].clone() ligand_residue_index = batch["ligand_indices"].clone() ligand_atomic_numbers = batch["ligand_atomic_numbers"].clone() if "ligand_atomic_numbers" in batch else None - # combine protein and ligand coords but note index to splice out the ligand coords after rotation and translation + + # Combine protein and ligand if both present if coords is None: + # Ligand-only case coords = ligand_coords seq_mask = ligand_mask residue_index = ligand_residue_index else: + # Both present - concatenate (NOW SAFE with unified batch from Phase 1!) B, L, n_atoms, _ = coords.shape - coords = coords.reshape(B, -1, 3) - coords = torch.cat([coords, ligand_coords], dim=1) - seq_mask = torch.cat([seq_mask, ligand_mask], dim=1) + B_ligand = ligand_coords.shape[0] + + # Batch sizes MUST match with unified batch approach from Phase 1 + assert B == B_ligand, f"Batch size mismatch: protein {B} vs ligand {B_ligand}" + + # Flatten protein coords and concatenate + coords = coords.reshape(B, -1, 3) # [B, L*n_atoms, 3] + coords = torch.cat([coords, ligand_coords], dim=1) # [B, L*n_atoms + L_ligand, 3] + + # Expand seq_mask to match flattened protein coords: [B, L] -> [B, L*n_atoms] + seq_mask = torch.cat( + [seq_mask.unsqueeze(-1).expand(-1, -1, n_atoms).reshape(B, -1), ligand_mask], dim=1 + ) # [B, L*n_atoms + L_ligand] + # seq_mask = torch.cat([seq_mask, ligand_mask], dim=1) # [B, L + L_ligand] + + # NOTE: For batch positions where ligand_valid=False: + # - ligand_coords[i] is all zeros (from collate padding) + # - ligand_mask[i] is all False (from collate padding) + # For batch positions where protein_valid=False: + # - coords[i] is all zeros (from collate padding) + # - seq_mask[i] is all False (from collate padding) + # The masks handle this naturally! + else: + ligand_coords = None + ligand_mask = None + ligand_residue_index = None + ligand_atomic_numbers = None frame_type = self.frame_type if frame_type is None else frame_type get_all_frames = self.get_all_frames if get_all_frames is None else get_all_frames apply_stochastic_fa = self.apply_stochastic_fa if apply_stochastic_fa is None else apply_stochastic_fa - if random_se3: - if only_rot: - logger.info("only rotating") - translation_scale = 0.0 - else: - translation_scale = self.translation_scale - if only_trans: - logger.info("only translating") - rotation_mode = "none" - coords = apply_random_se3_batched( - coords, translation_scale=translation_scale, rotation_mode=rotation_mode - ) + # Apply SE(3) transformations - only if we have valid data + # Pass atom_mask to ensure we only transform non-masked (valid) regions + if random_se3 and coords is not None: + # Check if we have any valid coordinates to transform + if seq_mask is not None and seq_mask.any(): + if only_rot: + logger.info("only rotating") + translation_scale = 0.0 else: - coords = apply_random_se3_batched(coords, translation_scale=translation_scale) + translation_scale = self.translation_scale + if only_trans: + logger.info("only translating") + rotation_mode = "none" + coords = apply_random_se3_batched( + coords, atom_mask=seq_mask, translation_scale=translation_scale, rotation_mode=rotation_mode + ) + else: + coords = apply_random_se3_batched( + coords, atom_mask=seq_mask, translation_scale=translation_scale + ) + else: + logger.debug("Skipping SE(3) transform - no valid coordinates") else: - logger.info("no se3 applied") - - if frame_type is not None: - # apply global frame - coords = apply_global_frame_to_coords( - coords, - frame_type=frame_type, - mask=seq_mask, - apply_stochastic_fa=apply_stochastic_fa, - get_all_frames=get_all_frames, - ) + if not random_se3: + logger.info("no se3 applied") + + if frame_type is not None and coords is not None: + # Apply global frame only if we have valid coordinates + if seq_mask is not None and seq_mask.any(): + coords = apply_global_frame_to_coords( + coords, + frame_type=frame_type, + mask=seq_mask, # Mask handles padded positions + apply_stochastic_fa=apply_stochastic_fa, + get_all_frames=get_all_frames, + ) + else: + logger.debug("Skipping frame application - no valid coordinates") if self.backbone_noise > 0 and backbone_noise is None: coords = coords + self.backbone_noise * torch.randn_like(coords) @@ -204,13 +266,21 @@ def featurize( else: coords = coords * mask_structure.unsqueeze(-1).unsqueeze(-1) - if "ligand_coords" in batch: - if "sequence" in batch: - # splice out the ligand coords + if has_ligands and "ligand_coords" in batch: + if has_proteins and "sequence" in batch and coords is not None: + # Both present - split them back out + # NOTE: With unified batch, both modalities exist for all batch positions + # The masks (seq_mask, ligand_mask) handle which are valid ligand_coords = coords[:, L * n_atoms :, :] coords = coords[:, : L * n_atoms, :] coords = coords.reshape(B, L, n_atoms, 3) - seq_mask = seq_mask[:, :L] + # seq_mask = seq_mask[:, :L] # Keep only protein mask + # keep only protein mask and make it just a residue mask so from [B, L*n_atoms+L_ligand] to [B, L] + seq_mask = seq_mask[:, : L * n_atoms].reshape(B, L, n_atoms) + seq_mask = seq_mask.sum(dim=-1) > 0 + + # ligand_mask = seq_mask[:, L*n_atoms:] + return ( coords, seq_mask, @@ -222,6 +292,7 @@ def featurize( ligand_atomic_numbers, ) else: + # Ligand-only return None, None, None, None, ligand_coords, ligand_mask, ligand_residue_index, ligand_atomic_numbers return coords, seq_mask, residue_index, sequence @@ -236,6 +307,7 @@ def forward( ligand_mask: Tensor | None = None, ligand_residue_index: Tensor | None = None, ligand_atom_types: Tensor | None = None, + ligand_bond_matrix: Tensor | None = None, return_embeddings: bool = False, **kwargs, ): @@ -245,6 +317,13 @@ def forward( coords = coords[:, :, : self.n_atoms, :] else: B, _, _ = ligand_coords.shape + + # Extract bond_matrix from batch if not passed explicitly + if ligand_bond_matrix is None and "batch" in kwargs: + batch = kwargs["batch"] + if batch is not None and "bond_matrix" in batch: + ligand_bond_matrix = batch["bond_matrix"] + emb = self.net( coords, seq_mask=seq_mask, @@ -253,6 +332,7 @@ def forward( ligand_mask=ligand_mask, ligand_residue_index=ligand_residue_index, ligand_atom_types=ligand_atom_types, + ligand_bond_matrix=ligand_bond_matrix, attn_drop_out_rate=self.attn_drop_out_rate, return_embeddings=return_embeddings, sequence=sequence, diff --git a/src/lobster/model/latent_generator/tokenizer/_losses.py b/src/lobster/model/latent_generator/tokenizer/_losses.py index 8449bf83..92d06099 100644 --- a/src/lobster/model/latent_generator/tokenizer/_losses.py +++ b/src/lobster/model/latent_generator/tokenizer/_losses.py @@ -39,30 +39,73 @@ def __init__(self, clamp=50, ligand_weight=5.0, permute_chains=False): self.ligand_weight = ligand_weight def forward_ligand(self, ground_truth, predictions, mask, eps=1e-5, **kwargs): - # note that we do not consider relative reconstruction for the ligand and the protein + # MODIFIED: Now we align protein+ligand together to preserve relative positioning predicted_protein = predictions["protein_coords"] - if predicted_protein is not None: + # Check both that protein predictions exist AND that we have protein mask + # For ligand-only batches (e.g., GEOM), mask only has "ligand_mask" + if predicted_protein is not None and "protein_mask" in mask: B, L, n_atoms, _ = predicted_protein.shape predicted_protein = predicted_protein[:, :, :3, :] ground_truth_protein = ground_truth["coords_res"] ground_truth_protein = ground_truth_protein[:, :, :3, :] mask_protein = mask["protein_mask"] + else: + # Force ligand-only path if no protein mask available + predicted_protein = None predicted_ligand = predictions["ligand_coords"] ground_truth_ligand = ground_truth["ligand_coords"] mask_ligand = mask["ligand_mask"] - # align ground truth to predictions + # align ground truth to predictions - JOINT ALIGNMENT for protein+ligand complex with torch.no_grad(): with torch.autocast(enabled=False, device_type=predicted_ligand.device.type): if predicted_protein is not None: + # Concatenate protein and ligand for joint alignment mask_protein_expanded = mask_protein.unsqueeze(-1).repeat(1, 1, 3) - ground_truth_protein = kabsch_torch_batched( - ground_truth_protein.reshape(B, -1, 3), - predicted_protein.reshape(B, -1, 3), - mask_protein_expanded.reshape(B, -1), + mask_protein_flat = mask_protein_expanded.reshape(B, -1) + + # Flatten protein coordinates + gt_protein_flat = ground_truth_protein.reshape(B, -1, 3) + pred_protein_flat = predicted_protein.reshape(B, -1, 3) + + # Concatenate protein + ligand for JOINT alignment + gt_complex = torch.cat([gt_protein_flat, ground_truth_ligand], dim=1) + pred_complex = torch.cat([pred_protein_flat, predicted_ligand], dim=1) + mask_complex = torch.cat([mask_protein_flat, mask_ligand], dim=1) + + # Safety for kabsch: Replace invalid samples with noise to prevent SVD NaN + valid_sample = mask_complex.sum(dim=1) > 0 + mask_safe = torch.where(valid_sample[:, None], mask_complex, torch.ones_like(mask_complex)) + + noise_gt = torch.randn_like(gt_complex) + noise_pred = torch.randn_like(pred_complex) + gt_safe = torch.where(valid_sample[:, None, None], gt_complex, noise_gt) + pred_safe = torch.where(valid_sample[:, None, None], pred_complex, noise_pred) + + # Align the ENTIRE complex together + aligned_complex = kabsch_torch_batched(gt_safe, pred_safe, mask_safe) + + # Split back into protein and ligand + n_protein_atoms = gt_protein_flat.shape[1] + ground_truth_protein = aligned_complex[:, :n_protein_atoms, :].reshape(B, L, 3, 3) + ground_truth_ligand = aligned_complex[:, n_protein_atoms:, :] + + else: + # Ligand-only case: align ligand independently + valid_sample_ligand = mask_ligand.sum(dim=1) > 0 + mask_ligand_safe = torch.where( + valid_sample_ligand[:, None], mask_ligand, torch.ones_like(mask_ligand) + ) + noise_gt_ligand = torch.randn_like(ground_truth_ligand) + noise_pred_ligand = torch.randn_like(predicted_ligand) + gt_ligand_safe = torch.where( + valid_sample_ligand[:, None, None], ground_truth_ligand, noise_gt_ligand ) - ground_truth_protein = ground_truth_protein.reshape(B, L, 3, 3) - ground_truth_ligand = kabsch_torch_batched(ground_truth_ligand, predicted_ligand, mask_ligand) + pred_ligand_safe = torch.where( + valid_sample_ligand[:, None, None], predicted_ligand, noise_pred_ligand + ) + + ground_truth_ligand = kabsch_torch_batched(gt_ligand_safe, pred_ligand_safe, mask_ligand_safe) # calculate loss if predicted_protein is not None: @@ -99,13 +142,31 @@ def forward(self, ground_truth_, predictions, mask, eps=1e-5, keep_batch_dim: bo predictions = predictions["protein_coords"] predictions = predictions[:, :, :3, :] + # Handle dict mask (from ligand datasets) + if isinstance(mask, dict): + if "protein_mask" not in mask: + # Ligand-only batch - no protein data to compute protein loss + return torch.tensor(0.0, device=predictions.device, requires_grad=True) + mask = mask["protein_mask"] + # align predictions to ground truth with torch.no_grad(): with torch.autocast(enabled=False, device_type=predictions.device.type): mask_expanded = mask.unsqueeze(-1).repeat(1, 1, 3) - ground_truth = kabsch_torch_batched( - ground_truth.reshape(B, -1, 3), predictions.reshape(B, -1, 3), mask_expanded.reshape(B, -1) - ) + + # Safety for kabsch + mask_flat = mask_expanded.reshape(B, -1) + valid_sample = mask_flat.sum(dim=1) > 0 + mask_safe = torch.where(valid_sample[:, None], mask_flat, torch.ones_like(mask_flat)) + + gt_flat = ground_truth.reshape(B, -1, 3) + pred_flat = predictions.reshape(B, -1, 3) + noise_gt = torch.randn_like(gt_flat) + noise_pred = torch.randn_like(pred_flat) + gt_safe = torch.where(valid_sample[:, None, None], gt_flat, noise_gt) + pred_safe = torch.where(valid_sample[:, None, None], pred_flat, noise_pred) + + ground_truth = kabsch_torch_batched(gt_safe, pred_safe, mask_safe) ground_truth = ground_truth.reshape(B, L, 3, 3) # use MSE loss @@ -144,10 +205,23 @@ def forward(self, ground_truth_, predictions, mask, eps=1e-5, keep_batch_dim: bo # step 2b: realign the permuted chains to the predictions with torch.autocast(enabled=False, device_type=predictions.device.type): mask_expanded = mask.unsqueeze(-1).repeat(1, 1, 3) + + # Safety for kabsch + mask_flat = mask_expanded.reshape(B, -1) + valid_sample = mask_flat.sum(dim=1) > 0 + mask_safe = torch.where(valid_sample[:, None], mask_flat, torch.ones_like(mask_flat)) + + gt_flat = ground_truth_permuted.reshape(B, -1, 3) + pred_flat = predictions.reshape(B, -1, 3) + noise_gt = torch.randn_like(gt_flat) + noise_pred = torch.randn_like(pred_flat) + gt_safe = torch.where(valid_sample[:, None, None], gt_flat, noise_gt) + pred_safe = torch.where(valid_sample[:, None, None], pred_flat, noise_pred) + ground_truth_permuted = kabsch_torch_batched( - ground_truth_permuted.reshape(B, -1, 3), - predictions.reshape(B, -1, 3), - mask_expanded.reshape(B, -1), + gt_safe, + pred_safe, + mask_safe, ) ground_truth_permuted = ground_truth_permuted.reshape(B, L, 3, 3) @@ -178,6 +252,15 @@ def __init__(self, ligand_weight=1.0): self.ligand_weight = ligand_weight def forward(self, ground_truth, predictions, mask, eps=1e-5, **kwargs): + # Skip if no ligand data (protein-only batch) + if not isinstance(predictions, dict) or "ligand_coords" not in predictions: + return torch.tensor( + 0.0, + device=predictions.device + if not isinstance(predictions, dict) + else predictions["protein_coords"].device, + ) + predicted_ligand = self.ligand_weight * predictions["ligand_coords"] ground_truth_ligand = self.ligand_weight * ground_truth["ligand_coords"] mask_ligand = mask["ligand_mask"] @@ -185,7 +268,14 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, **kwargs): # align ground truth to predictions with torch.no_grad(): with torch.autocast(enabled=False, device_type=predicted_ligand.device.type): - ground_truth_ligand = kabsch_torch_batched(ground_truth_ligand, predicted_ligand, mask_ligand) + valid_sample = mask_ligand.sum(dim=1) > 0 + mask_safe = torch.where(valid_sample[:, None], mask_ligand, torch.ones_like(mask_ligand)) + noise_gt = torch.randn_like(ground_truth_ligand) + noise_pred = torch.randn_like(predicted_ligand) + gt_safe = torch.where(valid_sample[:, None, None], ground_truth_ligand, noise_gt) + pred_safe = torch.where(valid_sample[:, None, None], predicted_ligand, noise_pred) + + ground_truth_ligand = kabsch_torch_batched(gt_safe, pred_safe, mask_safe) loss_ligand = nn.MSELoss(reduction="none")(predicted_ligand, ground_truth_ligand) loss_ligand = loss_ligand * mask_ligand[:, :, None] @@ -200,13 +290,30 @@ def __init__(self, ligand_weight=1.0): self.ligand_weight = ligand_weight def forward(self, ground_truth, predictions, mask, eps=1e-5, **kwargs): + # Skip if no ligand data (protein-only batch) + if not isinstance(predictions, dict) or "ligand_coords" not in predictions: + return torch.tensor( + 0.0, + device=predictions.device + if not isinstance(predictions, dict) + else predictions["protein_coords"].device, + ) + predicted_ligand = self.ligand_weight * predictions["ligand_coords"] ground_truth_ligand = self.ligand_weight * ground_truth["ligand_coords"] mask_ligand = mask["ligand_mask"] - Dpred = torch.cdist(predicted_ligand, predicted_ligand, p=2) + # Stabilize masked coordinates with noise + mask_ligand_expanded = mask_ligand.unsqueeze(-1) + noise_pred = torch.randn_like(predicted_ligand) * 1e-3 + predicted_ligand_safe = predicted_ligand * mask_ligand_expanded + noise_pred * (1 - mask_ligand_expanded) + + noise_gt = torch.randn_like(ground_truth_ligand) * 1e-3 + ground_truth_ligand_safe = ground_truth_ligand * mask_ligand_expanded + noise_gt * (1 - mask_ligand_expanded) + + Dpred = torch.cdist(predicted_ligand_safe, predicted_ligand_safe, p=2) Dpred = torch.clamp(Dpred, max=20) - D = torch.cdist(ground_truth_ligand, ground_truth_ligand, p=2) + D = torch.cdist(ground_truth_ligand_safe, ground_truth_ligand_safe, p=2) D = torch.clamp(D, max=20) E = (Dpred - D) ** 2 E = torch.clamp(E, max=25) @@ -323,7 +430,9 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, keep_batch_dim: boo if ligand_present: # get the protein and ligand predictions predicted_protein = predictions["protein_coords"] - if predicted_protein is not None: + # Check both that protein predictions exist AND that we have protein mask + # For ligand-only batches (e.g., GEOM), mask only has "ligand_mask" + if predicted_protein is not None and "protein_mask" in mask: B, L, n_atoms, _ = predicted_protein.shape predicted_protein = predicted_protein[:, :, :3, :] ground_truth_protein = ground_truth["coords_res"] @@ -333,6 +442,10 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, keep_batch_dim: boo Z_hat_protein = predicted_protein.reshape(B, -1, 3) Z_protein = ground_truth_protein.reshape(B, -1, 3) mask_protein = mask_protein.unsqueeze(-1).repeat(1, 1, n_atoms).view(B, -1) + else: + # Force ligand-only path if no protein mask available + predicted_protein = None + ground_truth_ligand = ground_truth["ligand_coords"] predicted_ligand = predictions["ligand_coords"] mask_ligand = mask["ligand_mask"] @@ -340,11 +453,12 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, keep_batch_dim: boo if predicted_protein is not None: Z_hat = torch.cat([Z_hat_protein, predicted_ligand], dim=1) Z = torch.cat([Z_protein, ground_truth_ligand], dim=1) - mask = torch.cat([mask_protein, mask_ligand], dim=1) + mask_combined = torch.cat([mask_protein, mask_ligand], dim=1) + mask_flat = mask_combined else: Z_hat = predicted_ligand Z = ground_truth_ligand - mask = mask_ligand + mask_flat = mask_ligand.clone() n_atoms = 3 else: @@ -352,16 +466,41 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, keep_batch_dim: boo B, L, n_atoms, _ = predictions.shape ground_truth = ground_truth[:, :, :n_atoms, :] + # Handle dict mask (from ligand datasets with protein-only batch) + if isinstance(mask, dict): + if "protein_mask" not in mask: + # No protein data in this batch - return 0 loss + return torch.tensor(0.0, device=predictions.device, requires_grad=True) + mask = mask["protein_mask"] + # Step 1: Flatten predictions and ground_truth Z_hat = predictions.reshape(predictions.size(0), -1, 3) # (B, L*n_atoms,3) Z = ground_truth.reshape(ground_truth.size(0), -1, 3) # (B, L*n_atoms,3) + mask_flat = mask.unsqueeze(-1).repeat(1, 1, n_atoms).view(B, -1) + + # Step 1.5: Stabilize masked coordinates with noise to prevent NaN gradients in cdist + # cdist(0, 0) gradient is NaN. If padding is 0, we get NaNs even if masked later. + # We replace masked 0s with small random noise. + + # CRITICAL: Detach mask_flat once at the beginning to prevent version tracking issues + # Masks don't need gradients, so this is safe + mask_flat_detached = mask_flat.detach() + + # Create expanded mask from detached version + mask_flat_expanded = mask_flat_detached.unsqueeze(-1) + + noise_hat = torch.randn_like(Z_hat) * 1e-3 + Z_hat_safe = Z_hat * mask_flat_expanded + noise_hat * (1 - mask_flat_expanded) + + noise_gt = torch.randn_like(Z) * 1e-3 + Z_safe = Z * mask_flat_expanded + noise_gt * (1 - mask_flat_expanded) # Step 2: Compute Dpred - Dpred = torch.cdist(Z_hat, Z_hat, p=2) # (B, L*n_atoms, L*n_atoms) + Dpred = torch.cdist(Z_hat_safe, Z_hat_safe, p=2) # (B, L*n_atoms, L*n_atoms) Dpred = torch.clamp(Dpred, max=20) # Step 3: Compute D - D = torch.cdist(Z, Z, p=2) # (B, L*n_atoms, L*n_atoms) + D = torch.cdist(Z_safe, Z_safe, p=2) # (B, L*n_atoms, L*n_atoms) # Step 4: Compute E E = (Dpred - D) ** 2 @@ -373,14 +512,13 @@ def forward(self, ground_truth, predictions, mask, eps=1e-5, keep_batch_dim: boo if n_atoms > 3: raise NotImplementedError("nonbackbone PairWiseL2Loss is not implemented correctly yet") else: - if not ligand_present: - mask = mask.unsqueeze(-1).repeat(1, 1, n_atoms).view(B, -1) - mask = mask[:, None, :] * mask[:, :, None] # (B, L*n_atoms, L*n_atoms) - E = E * mask + # Use detached mask_flat to avoid version conflicts + mask_pairwise = mask_flat_detached[:, None, :] * mask_flat_detached[:, :, None] + E = E * mask_pairwise if keep_batch_dim: - l = E.sum(dim=(1, 2)) / mask.sum(dim=1) + l = E.sum(dim=(1, 2)) / mask_pairwise.sum(dim=1) else: - l = E.sum() / (mask.sum() + eps) + l = E.sum() / (mask_pairwise.sum() + eps) return l diff --git a/src/lobster/model/latent_generator/tokenizer/_tokenizer_multi.py b/src/lobster/model/latent_generator/tokenizer/_tokenizer_multi.py index 5555f9e7..54ea847d 100644 --- a/src/lobster/model/latent_generator/tokenizer/_tokenizer_multi.py +++ b/src/lobster/model/latent_generator/tokenizer/_tokenizer_multi.py @@ -7,6 +7,7 @@ from collections.abc import Callable from typing import Literal +import hydra import lightning.pytorch as pl import omegaconf import torch @@ -58,8 +59,38 @@ def __init__( min_mask_timestep: float = 0.5, mask_sequence: bool = False, debug: bool = False, + num_warmup_steps: int = 50000, + num_training_steps: int = 500000, + ckpt_path: str = None, ): super().__init__() + + # Debug: Check what types we received + logger.info(f"Received structure_encoder type: {type(structure_encoder)}") + logger.info(f"Received quantizer type: {type(quantizer)}") + logger.info(f"Received decoder_factory type: {type(decoder_factory)}") + logger.info(f"Received loss_factory type: {type(loss_factory)}") + + # Instantiate modules if they're DictConfig objects (Hydra didn't instantiate them) + if isinstance(structure_encoder, omegaconf.DictConfig): + logger.info("Instantiating structure_encoder from config") + structure_encoder = hydra.utils.instantiate(structure_encoder) + if isinstance(quantizer, omegaconf.DictConfig): + logger.info("Instantiating quantizer from config") + quantizer = hydra.utils.instantiate(quantizer) + if isinstance(decoder_factory, omegaconf.DictConfig): + logger.info("Instantiating decoder_factory from config") + decoder_factory = hydra.utils.instantiate(decoder_factory) + if isinstance(loss_factory, omegaconf.DictConfig): + logger.info("Instantiating loss_factory from config") + loss_factory = hydra.utils.instantiate(loss_factory) + if isinstance(optim, omegaconf.DictConfig): + logger.info("Instantiating optim from config") + optim = hydra.utils.instantiate(optim) + if isinstance(lr_scheduler, omegaconf.DictConfig): + logger.info("Instantiating lr_scheduler from config") + lr_scheduler = hydra.utils.instantiate(lr_scheduler) + self.encoder = structure_encoder self.quantizer = quantizer self.decoder_factory = decoder_factory @@ -80,12 +111,140 @@ def __init__( logger.info(f"Using min_mask_timestep: {self.min_mask_timestep}") logger.info(f"Using schedule: {self.schedule}") + self.num_warmup_steps = num_warmup_steps + self.num_training_steps = num_training_steps + self.ckpt_path = ckpt_path self.structure_path = structure_path if not os.path.exists(f"{self.structure_path}train/") and self.structure_path is not None: os.makedirs(f"{self.structure_path}train/") self.automatic_optimization = automatic_optimization + # Debug: Check number of parameters + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + logger.info(f"Total parameters: {total_params:,}") + logger.info(f"Trainable parameters: {trainable_params:,}") + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + map_location=None, + hparams_file: str | None = None, + strict: bool = True, + load_encoder: bool = True, + load_encoder_strict: bool = True, + load_quantizer: bool = True, + load_quantizer_strict: bool = True, + load_decoder: bool = True, + load_decoder_strict: bool = True, + **kwargs, + ): + """Load model from checkpoint with support for component-specific loading. + + This override allows selective loading of encoder, quantizer, and decoder + components with different strictness levels for each component. + + Args: + checkpoint_path: Path to the checkpoint file + map_location: Device to map tensors to + hparams_file: Path to hyperparameters file (unused, for compatibility) + strict: If True and all component flags are True, use standard loading. + If False, enables component-specific loading. + load_encoder: Whether to load encoder weights + load_encoder_strict: Whether to use strict loading for encoder + load_quantizer: Whether to load quantizer weights + load_quantizer_strict: Whether to use strict loading for quantizer + load_decoder: Whether to load decoder weights + load_decoder_strict: Whether to use strict loading for decoder + **kwargs: Additional arguments passed to model __init__ + + Returns: + Loaded model instance + """ + # Determine if we need component-specific loading + use_component_loading = not strict or any( + [ + not load_encoder, + not load_quantizer, + not load_decoder, + not load_encoder_strict, + not load_quantizer_strict, + not load_decoder_strict, + ] + ) + + if use_component_loading: + logger.info("Using component-specific checkpoint loading") + checkpoint = torch.load(checkpoint_path, map_location=map_location) + checkpoint_state_dict = checkpoint["state_dict"] + + # Instantiate model with new hyperparameters + model = cls(**kwargs) + + # Load encoder weights if requested + if load_encoder: + logger.info("Loading encoder weights from checkpoint") + encoder_state_dict = { + k.replace("encoder.", "", 1): v + for k, v in checkpoint_state_dict.items() + if k.startswith("encoder.") + } + if encoder_state_dict: + model.encoder.load_state_dict(encoder_state_dict, strict=load_encoder_strict) + logger.info(f"Successfully loaded {len(encoder_state_dict)} encoder parameters") + else: + logger.warning("No encoder weights found in checkpoint") + else: + logger.info("Keeping randomly initialized encoder") + + # Load quantizer weights if requested + if load_quantizer and model.quantizer is not None: + logger.info("Loading quantizer weights from checkpoint") + quantizer_state_dict = { + k.replace("quantizer.", "", 1): v + for k, v in checkpoint_state_dict.items() + if k.startswith("quantizer.") + } + if quantizer_state_dict: + model.quantizer.load_state_dict(quantizer_state_dict, strict=load_quantizer_strict) + logger.info(f"Successfully loaded {len(quantizer_state_dict)} quantizer parameters") + else: + logger.warning("No quantizer weights found in checkpoint") + elif model.quantizer is None: + logger.info("No quantizer in model, skipping quantizer loading") + else: + logger.info("Keeping randomly initialized quantizer") + + # Load decoder weights if requested + if load_decoder: + logger.info("Loading decoder weights from checkpoint") + decoder_state_dict = { + k.replace("decoder_factory.", "", 1): v + for k, v in checkpoint_state_dict.items() + if k.startswith("decoder_factory.") + } + if decoder_state_dict: + model.decoder_factory.load_state_dict(decoder_state_dict, strict=load_decoder_strict) + logger.info(f"Successfully loaded {len(decoder_state_dict)} decoder parameters") + else: + logger.warning("No decoder weights found in checkpoint") + else: + logger.info("Keeping randomly initialized decoder") + + return model + else: + # Standard Lightning checkpoint loading + logger.info("Using standard checkpoint loading") + return super().load_from_checkpoint( + checkpoint_path=checkpoint_path, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, + ) + def on_after_backward(self): if self.debug: for name, param in self.named_parameters(): @@ -293,6 +452,7 @@ def single_step(self, batch, batch_idx, split="train"): ligand_mask=x_feat[5], ligand_residue_index=x_feat[6], ligand_atom_types=x_feat[7], + batch=batch, ) elif len(x_feat) == 7: x_emb = self.encoder( @@ -303,9 +463,10 @@ def single_step(self, batch, batch_idx, split="train"): ligand_coords=x_feat[4], ligand_mask=x_feat[5], ligand_residue_index=x_feat[6], + batch=batch, ) else: - x_emb = self.encoder(*x_feat) # Keep original unpacking for backward compatibility + x_emb = self.encoder(*x_feat, batch=batch) # Keep original unpacking for backward compatibility if self.quantizer is not None: # check if cls token is used @@ -384,13 +545,29 @@ def validation_step(self, batch, batch_idx): # mask, t): """Validation step of the model.""" return self.single_step(batch, batch_idx, split="val") + # def configure_optimizers(self): + # """Configure the optimizer and learning rate scheduler.""" + # optimizer = self.optim_factory(params=self.parameters()) + + # out = {"optimizer": optimizer} + + # out["lr_scheduler"] = {"scheduler": self.lr_scheduler(optimizer=optimizer), "interval": "step"} + + # return out + def configure_optimizers(self): """Configure the optimizer and learning rate scheduler.""" optimizer = self.optim_factory(params=self.parameters()) out = {"optimizer": optimizer} - out["lr_scheduler"] = {"scheduler": self.lr_scheduler(optimizer=optimizer), "interval": "step"} + # Pass num_warmup_steps and num_training_steps to the lr_scheduler factory + out["lr_scheduler"] = { + "scheduler": self.lr_scheduler( + optimizer=optimizer, num_warmup_steps=self.num_warmup_steps, num_training_steps=self.num_training_steps + ), + "interval": "step", + } return out diff --git a/src/lobster/model/latent_generator/utils/__init__.py b/src/lobster/model/latent_generator/utils/__init__.py index 0bf9ddaa..911d1c79 100644 --- a/src/lobster/model/latent_generator/utils/__init__.py +++ b/src/lobster/model/latent_generator/utils/__init__.py @@ -3,10 +3,17 @@ apply_global_frame_to_coords, apply_random_se3_2, apply_random_se3_batched, + apply_random_se3_protein_ligand, c6d_to_bins, xyz_to_c6d, ) from ._lrf import compute_geometric_features +from ._minimize_ligand import get_ligand_energy, minimize_ligand_structure +from ._se3_augmentation import ( + SE3AugmentedComplex, + apply_se3_augmentation_batched, + apply_se3_augmentation_protein_ligand, +) from ._utils import ( batch_align_on_calpha, extract_cropped_coordinates, diff --git a/src/lobster/model/latent_generator/utils/_get_ligand_coords.py b/src/lobster/model/latent_generator/utils/_get_ligand_coords.py new file mode 100644 index 00000000..e0f16930 --- /dev/null +++ b/src/lobster/model/latent_generator/utils/_get_ligand_coords.py @@ -0,0 +1,378 @@ +import torch +import tqdm +import os +import pandas as pd +import s3fs +from loguru import logger +from latent_generator.io import load_ligand +from concurrent.futures import ThreadPoolExecutor, as_completed +from collections import defaultdict +import multiprocessing +import hashlib +import random +from rdkit import Chem +from rdkit.Chem import rdFreeSASA, Descriptors3D +from upath import UPath + + +def load_mol_from_s3_sdf(s3_path: str): + """Load RDKit molecule directly from S3 SDF without downloading""" + path = UPath(s3_path) + + # Read file content directly into memory + with path.open("r") as f: + sdf_content = f.read() + + # Create molecule from SDF content + mol = Chem.MolFromMolBlock(sdf_content) + return mol + + +def calc_rg_sasa(conformer_data): + """Calculate radius of gyration and SASA from conformer data""" + s3_path = conformer_data["sdf_filepath"] + + mol = load_mol_from_s3_sdf(s3_path) + + # Calculate RG + rg = Descriptors3D.RadiusOfGyration(mol) + + # Calculate SASA + ptable = Chem.GetPeriodicTable() + radii = [ptable.GetRvdw(atom.GetAtomicNum()) for atom in mol.GetAtoms()] + sasa = rdFreeSASA.CalcSASA(mol, radii) + + return rg, sasa + + +def load_geom_parquet_from_s3(s3_path: str) -> pd.DataFrame: + """ + Load a parquet file from S3 containing GEOM dataset data. + + Args: + s3_path: S3 path to the parquet file + + Returns: + pandas DataFrame containing the data + """ + try: + logger.info(f"Loading parquet file from: {s3_path}") + + # Use s3fs to read parquet file directly + fs = s3fs.S3FileSystem() + df = pd.read_parquet(s3_path, filesystem=fs) + + logger.info(f"Successfully loaded data with shape: {df.shape}") + logger.info(f"Columns: {list(df.columns)}") + + return df + + except Exception as e: + logger.error(f"Error loading parquet file: {e}") + raise + + +def check_parquet_already_processed(parquet_file, save_path): + """ + Check if all ligand PT files for a parquet file already exist. + + Args: + parquet_file: S3 path to the parquet file + save_path: Directory where PT files are saved + + Returns: + bool: True if all PT files exist, False otherwise + """ + try: + # Load parquet file to get all unique SMILES + df = load_geom_parquet_from_s3(parquet_file) + rows = df.to_dict(orient="records") + + # Group by SMILES to get unique SMILES strings + smiles_groups = defaultdict(list) + for row in rows: + smiles_groups[row["smiles"]].append(row) + + unique_smiles = list(smiles_groups.keys()) + + # Check if all corresponding PT files exist + missing_files = 0 + for smiles in unique_smiles: + smiles_hash = hashlib.md5(smiles.encode()).hexdigest()[:8] + filename = f"ligand_{smiles_hash}.pt" + output_path = os.path.join(save_path, filename) + + if not os.path.exists(output_path): + missing_files += 1 + + if missing_files == 0: + logger.info( + f"All {len(unique_smiles)} PT files for {os.path.basename(parquet_file)} already exist, skipping" + ) + return True + else: + logger.info( + f"Missing {missing_files}/{len(unique_smiles)} PT files for {os.path.basename(parquet_file)}, processing" + ) + return False + + except Exception as e: + logger.error(f"Error checking if parquet file is already processed: {e}") + return False + + +def process_single_ligand(row_data, fs): + """ + Process a single ligand from the parquet data. + + Args: + row_data: Dictionary containing 'smiles' and 'sdf_path' + fs: S3 filesystem object + + Returns: + tuple: (success: bool, ligand_data: dict or None, error: str or None) + """ + try: + smiles = row_data["smiles"] + sdf_path = row_data["sdf_path"] + + # Load ligand data from SDF file + ligand_data = load_ligand(sdf_path, add_batch_dim=False) + + # Add SMILES information to the data + ligand_data["smiles"] = smiles + ligand_data["sdf_filepath"] = sdf_path + + # Calculate RG and SASA properties + rg, sasa = calc_rg_sasa(ligand_data) + ligand_data["radius_of_gyration"] = float(rg) + ligand_data["solvent_accessible_surface_area"] = float(sasa) + + return True, ligand_data, None + + except Exception as e: + error_msg = f"Error processing ligand {smiles[:50]}...: {str(e)}" + return False, None, error_msg + + +def save_conformers_group(smiles, conformers_data, save_path): + """ + Save all conformers for a given SMILES string to a single PT file. + + Args: + smiles: SMILES string + conformers_data: List of ligand data dictionaries for all conformers + save_path: Directory to save the processed PT file + + Returns: + tuple: (success: bool, filename: str, error: str or None) + """ + try: + # Create a unique filename based on SMILES hash + smiles_hash = hashlib.md5(smiles.encode()).hexdigest()[:8] + filename = f"ligand_{smiles_hash}.pt" + output_path = os.path.join(save_path, filename) + + # Skip if file already exists + if os.path.exists(output_path): + logger.info(f"File {filename} already exists, skipping") + return True, filename, None + + # Combine all conformers into a single data structure + combined_data = {"smiles": smiles, "num_conformers": len(conformers_data), "conformers": conformers_data} + # Save to PT file + torch.save(combined_data, output_path) + + return True, filename, None + + except Exception as e: + error_msg = f"Error saving conformers for {smiles[:50]}...: {str(e)}" + return False, None, error_msg + + +def process_smiles_group(smiles, conformer_rows, fs, save_path): + """ + Process all conformers for a single SMILES group and save to PT file. + + Args: + smiles: SMILES string + conformer_rows: List of row data for all conformers of this SMILES + fs: S3 filesystem object + save_path: Directory to save the processed PT file + + Returns: + tuple: (success: bool, num_conformers: int, error_count: int, error_msg: str or None) + """ + try: + # Process all conformers for this SMILES sequentially + conformers_data = [] + errors = 0 + + for row in conformer_rows: + success, ligand_data, error = process_single_ligand(row, fs) + + if success: + conformers_data.append(ligand_data) + else: + errors += 1 + logger.error(error) + + # Save all conformers for this SMILES to a single PT file + if conformers_data: + success, filename, error = save_conformers_group(smiles, conformers_data, save_path) + + if success: + return True, len(conformers_data), errors, None + else: + return False, 0, errors + 1, error + else: + error_msg = f"No conformers successfully processed for SMILES {smiles[:50]}..." + return False, 0, errors + 1, error_msg + + except Exception as e: + error_msg = f"Error processing SMILES group {smiles[:50]}...: {str(e)}" + return False, 0, 1, error_msg + + +def process_geom_ligands( + s3_root: str = "s3://prescient-lobster/ume/datasets/geom/processed/test", + save_path: str = "/data/bucket/lisanza/structures/GEOM/processed/", + max_workers: int = 8, + testing: bool = False, + shuffle_files: bool = True, +): + """ + Process all ligand data from GEOM dataset parquet files and save as PT files. + Groups conformers by SMILES string and saves all conformers in a single PT file. + Skips parquet files where all corresponding PT files already exist. + + Args: + s3_root: S3 path to the directory containing parquet files + save_path: Local directory to save processed PT files + max_workers: Maximum number of parallel workers + testing: If True, only process a small subset for testing + shuffle_files: If True, shuffle the order of parquet files for better load distribution + """ + + # Create output directory + os.makedirs(save_path, exist_ok=True) + logger.info(f"Output directory: {save_path}") + + # Initialize S3 filesystem + fs = s3fs.S3FileSystem() + + # Find all parquet files in S3 + try: + all_files = fs.find(s3_root) + parquet_files = [f"s3://{file_path}" for file_path in all_files if file_path.endswith(".parquet")] + logger.info(f"Found {len(parquet_files)} parquet files in S3") + + if testing: + parquet_files = parquet_files[:2] # Only process first 2 files for testing + logger.info(f"Testing mode: processing only {len(parquet_files)} files") + + # Shuffle the parquet files for better load distribution + if shuffle_files: + random.shuffle(parquet_files) + logger.info("Shuffled parquet files order for better load distribution") + + except Exception as e: + logger.error(f"Error listing S3 files: {e}") + raise + + # Process each parquet file + total_processed = 0 + total_errors = 0 + total_skipped = 0 + + for parquet_file in tqdm.tqdm(parquet_files, desc="Processing parquet files"): + try: + # Check if this parquet file has already been fully processed + if check_parquet_already_processed(parquet_file, save_path): + total_skipped += 1 + continue + + # Load parquet file + df = load_geom_parquet_from_s3(parquet_file) + + # Convert to list of dictionaries for processing + rows = df.to_dict(orient="records") + + if testing: + rows = rows[:50] # Only process first 50 rows for testing + + logger.info(f"Processing {len(rows)} ligands from {parquet_file}") + + # Group rows by SMILES string + smiles_groups = defaultdict(list) + for row in rows: + smiles_groups[row["smiles"]].append(row) + + logger.info(f"Found {len(smiles_groups)} unique SMILES strings") + + # Process all SMILES groups in parallel + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all SMILES group processing tasks + future_to_smiles = { + executor.submit(process_smiles_group, smiles, conformer_rows, fs, save_path): smiles + for smiles, conformer_rows in smiles_groups.items() + } + + # Process completed tasks with progress bar + for future in tqdm.tqdm( + as_completed(future_to_smiles), + total=len(smiles_groups), + desc=f"Processing SMILES groups from {os.path.basename(parquet_file)}", + ): + _ = future_to_smiles[future] # Keep mapping for debugging if needed + success, num_conformers, error_count, error_msg = future.result() + + if success: + total_processed += 1 + # logger.info(f"Saved {num_conformers} conformers for SMILES {smiles[:50]}...") + else: + total_errors += 1 + logger.error(error_msg) + + # Add any errors from conformer processing to total + total_errors += error_count + + logger.info(f"Completed {parquet_file}") + + except Exception as e: + logger.error(f"Error processing parquet file {parquet_file}: {e}") + total_errors += 1 + + logger.info("Processing complete!") + logger.info(f"Total SMILES processed: {total_processed}") + logger.info(f"Total errors: {total_errors}") + logger.info(f"Total parquet files skipped: {total_skipped}") + logger.info(f"Output directory: {save_path}") + + +def main(): + """Main function to run the ligand processing.""" + + # Configuration + s3_root = "s3://prescient-lobster/ume/datasets/geom/processed/test" + # save_path = "/data/bucket/lisanza/structures/GEOM/processed/train/" + save_path = "/data/bucket/shmilovk/structures/GEOM/processed/test/" + max_workers = max(1, multiprocessing.cpu_count()) + testing = False # Set to False for full processing + shuffle_files = True # Shuffle parquet files for better load distribution + + logger.info("Starting GEOM ligand processing...") + logger.info(f"S3 root: {s3_root}") + logger.info(f"Output directory: {save_path}") + logger.info(f"Max workers: {max_workers}") + logger.info(f"Testing mode: {testing}") + logger.info(f"Shuffle files: {shuffle_files}") + + # Process the ligands + process_geom_ligands( + s3_root=s3_root, save_path=save_path, max_workers=max_workers, testing=testing, shuffle_files=shuffle_files + ) + + +if __name__ == "__main__": + main() diff --git a/src/lobster/model/latent_generator/utils/_get_protein_ligand_coords.py b/src/lobster/model/latent_generator/utils/_get_protein_ligand_coords.py new file mode 100644 index 00000000..f8fc6a7b --- /dev/null +++ b/src/lobster/model/latent_generator/utils/_get_protein_ligand_coords.py @@ -0,0 +1,78 @@ +import numpy as np +import torch +import tqdm +import glob +import os +from loguru import logger +from latent_generator.io import load_pdb, load_ligand + + +def process_pdb(file_name_protein, file_name_ligand, save_path): + # if file exists, return + if os.path.exists(save_path + file_name_protein.split("/")[-1].split(".")[0] + ".pt") and os.path.exists( + save_path + file_name_ligand.split("/")[-1].split(".")[0] + ".pt" + ): + return None + try: + structure_data_protein = load_pdb(file_name_protein, add_batch_dim=False) + structure_data_ligand = load_ligand(file_name_ligand, add_batch_dim=False) + # Save the processed data + save_path_protein = save_path + file_name_protein.split("/")[-1].split(".")[0] + ".pt" + save_path_ligand = save_path + file_name_ligand.split("/")[-1].split(".")[0] + ".pt" + torch.save(structure_data_protein, save_path_protein) + torch.save(structure_data_ligand, save_path_ligand) + # Clear memory + del structure_data_protein + del structure_data_ligand + torch.cuda.empty_cache() if torch.cuda.is_available() else None + return True + except Exception as e: + logger.error(f"Error processing {file_name_protein} and {file_name_ligand}: {str(e)}") + return None + + +def process_pdb_parallel(file_name_protein, file_name_ligand, save_path): + try: + return process_pdb(file_name_protein, file_name_ligand, save_path) + except Exception as e: + logger.error(f"Error in parallel processing of {file_name_protein} and {file_name_ligand}: {str(e)}") + return None + + +if __name__ == "__main__": + import concurrent.futures + import multiprocessing + + pdb_dir = "/data/bucket/lisanza/structures/pdb_bind/1981-2000/" + # pdb_dir = "/data/bucket/lisanza/structures/pdb_bind/2001-2010/" + # pdb_dir = "/data/bucket/lisanza/structures/pdb_bind/2011-2020/" + # pdb_dir = "/data/bucket/lisanza/structures/pdb_bind/2021-2023/" + # save_path = "/data/bucket/lisanza/structures/pdb_bind/processed/" + save_path = "/data/bucket/lisanza/structures/pdb_bind/processed_2/" + + os.makedirs(save_path, exist_ok=True) + # proteins have *_protein.pdb and ligands have *_ligand.sdf + pdb_paths = glob.glob(pdb_dir + "*/" + "*protein.pdb") + ligand_paths = glob.glob(pdb_dir + "*/" + "*ligand.sdf") + # sort pdb_paths and ligand_paths + pdb_paths.sort() + ligand_paths.sort() + # zip pdb_paths and ligand_paths + pdb_paths = list(zip(pdb_paths, ligand_paths)) + # shuffle pdb_paths + np.random.shuffle(pdb_paths) + logger.info(f"Number of pdb_paths: {len(pdb_paths)}") + + # Calculate optimal number of workers based on available memory + # Use 1/4 of available CPU cores to avoid memory issues + num_workers = max(1, multiprocessing.cpu_count() // 4) + logger.info(f"Using {num_workers} workers") + + logger.info(f"Processing {len(pdb_paths)} pdb files in parallel") + with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = [ + executor.submit(process_pdb_parallel, file_name_protein, file_name_ligand, save_path) + for file_name_protein, file_name_ligand in pdb_paths + ] + for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + result = future.result() diff --git a/src/lobster/model/latent_generator/utils/_kinematics.py b/src/lobster/model/latent_generator/utils/_kinematics.py index 72d5d63d..bbfe3cec 100644 --- a/src/lobster/model/latent_generator/utils/_kinematics.py +++ b/src/lobster/model/latent_generator/utils/_kinematics.py @@ -390,6 +390,9 @@ def apply_random_se3(coords_in, atom_mask=None, translation_scale=1.0, rotation_ if rotation_mode == "svd": random_rot, _ = torch.linalg.qr(torch.randn(3, 3)) + # QR decomposition can produce det=-1 (reflection). Ensure proper rotation (det=+1) + if torch.linalg.det(random_rot) < 0: + random_rot[:, 0] = -random_rot[:, 0] # Flip first column to fix determinant elif rotation_mode == "quaternion": random_rot = uniform_rand_rotation(1).squeeze(0) elif rotation_mode == "none": @@ -416,9 +419,11 @@ def apply_random_se3_2(coords_in, atom_mask=None, translation_scale=1.0, rotatio coords_in -= coords_mean if rotation_mode == "svd": random_rot, _ = torch.linalg.qr(torch.randn(3, 3)) + # QR decomposition can produce det=-1 (reflection). Ensure proper rotation (det=+1) + if torch.linalg.det(random_rot) < 0: + random_rot[:, 0] = -random_rot[:, 0] # Flip first column to fix determinant elif rotation_mode == "quaternion": random_rot = uniform_rand_rotation(1).squeeze(0) - random_rot, _ = torch.linalg.qr(torch.randn(3, 3)) coords_in = coords_in @ random_rot.to(coords_in) random_trans = torch.randn_like(coords_mean) * translation_scale coords_in += random_trans.to(coords_in) @@ -440,6 +445,99 @@ def apply_random_se3_batched(coords_in, atom_mask=None, translation_scale=1.0, r return coords_in +def apply_random_se3_protein_ligand( + protein_coords: torch.Tensor, + ligand_coords: torch.Tensor | None = None, + protein_mask: torch.Tensor | None = None, + ligand_mask: torch.Tensor | None = None, + translation_scale: float = 1.0, + rotation_mode: str = "quaternion", +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply random SE(3) transformation to protein-ligand complex. + + Applies the SAME rotation and translation to both protein and ligand + coordinates, ensuring they remain in the same reference frame. + + Args: + protein_coords: Protein coordinates [B, L, n_atoms, 3] or [B, L, 3] + ligand_coords: Ligand coordinates [B, N_atoms, 3] or None + protein_mask: Protein mask [B, L] or None + ligand_mask: Ligand mask [B, N_atoms] or None + translation_scale: Scale factor for random translation + rotation_mode: Method to generate rotation ("svd", "quaternion", "none") + + Returns: + Tuple of (transformed_protein_coords, transformed_ligand_coords) + """ + B = protein_coords.shape[0] + device = protein_coords.device + dtype = protein_coords.dtype + + # Determine if protein coords are flat [B, L, 3] or structured [B, L, n_atoms, 3] + is_flat = len(protein_coords.shape) == 3 + + for b in range(B): + # Compute center from protein CA atoms (or mean if flat) + if is_flat: + if protein_mask is not None: + valid_coords = protein_coords[b][protein_mask[b].bool()] + if valid_coords.numel() > 0: + center = valid_coords.mean(dim=0, keepdim=True) # [1, 3] + else: + center = protein_coords[b].mean(dim=0, keepdim=True) + else: + center = protein_coords[b].mean(dim=0, keepdim=True) + else: + # Use CA atoms (index 1) for centering + ca_coords = protein_coords[b, :, 1, :] # [L, 3] + if protein_mask is not None: + valid_ca = ca_coords[protein_mask[b].bool()] + if valid_ca.numel() > 0: + center = valid_ca.mean(dim=0, keepdim=True) # [1, 3] + else: + center = ca_coords.mean(dim=0, keepdim=True) + else: + center = ca_coords.mean(dim=0, keepdim=True) + + # Generate random rotation + if rotation_mode == "svd": + R, _ = torch.linalg.qr(torch.randn(3, 3, device=device, dtype=dtype)) + # QR decomposition can produce det=-1 (reflection). Ensure proper rotation (det=+1) + if torch.linalg.det(R) < 0: + R[:, 0] = -R[:, 0] # Flip first column to fix determinant + elif rotation_mode == "quaternion": + R = uniform_rand_rotation(1).squeeze(0).to(device=device, dtype=dtype) + elif rotation_mode == "none": + R = torch.eye(3, device=device, dtype=dtype) + else: + R = torch.eye(3, device=device, dtype=dtype) + + # Generate random translation + trans = torch.randn(1, 3, device=device, dtype=dtype) * translation_scale + + # Apply to protein: center, rotate, translate + if is_flat: + protein_coords[b] = (protein_coords[b] - center) @ R.T + trans + else: + # [L, n_atoms, 3] - need to broadcast center properly + protein_coords[b] = (protein_coords[b] - center.unsqueeze(0)) @ R.T + trans + + # Apply mask to protein + if protein_mask is not None: + if is_flat: + protein_coords[b] = protein_coords[b] * protein_mask[b, :, None].float() + else: + protein_coords[b] = protein_coords[b] * protein_mask[b, :, None, None].float() + + # Apply SAME transform to ligand (if present) + if ligand_coords is not None: + ligand_coords[b] = (ligand_coords[b] - center) @ R.T + trans + if ligand_mask is not None: + ligand_coords[b] = ligand_coords[b] * ligand_mask[b, :, None].float() + + return protein_coords, ligand_coords + + def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12): e1 = xy_plane denom = torch.sqrt((x_axis**2).sum(dim=-1, keepdim=True) + eps) diff --git a/src/lobster/model/latent_generator/utils/_minimize_ligand.py b/src/lobster/model/latent_generator/utils/_minimize_ligand.py new file mode 100644 index 00000000..46c20f24 --- /dev/null +++ b/src/lobster/model/latent_generator/utils/_minimize_ligand.py @@ -0,0 +1,499 @@ +"""Open Babel ligand minimization utilities. + +This module provides functions for energy minimization of ligand structures +using Open Babel force fields. It can be used as a post-processing step +after structure generation to improve ligand geometry. + +Functions +--------- +minimize_ligand_structure : Minimize ligand coordinates using force field optimization +get_ligand_energy : Calculate the potential energy of a ligand structure +""" + +import logging + +import torch + +py_logger = logging.getLogger(__name__) + + +def minimize_ligand_structure( + coords: torch.Tensor, + atom_types: list[str], + bond_matrix: torch.Tensor | None = None, + steps: int = 500, + force_field: str = "MMFF94", + method: str = "cg", + mode: str = "full", +) -> torch.Tensor: + """Minimize ligand structure using Open Babel force field optimization. + + This function performs energy minimization on ligand coordinates to improve + geometry (bond lengths, angles, torsions). It can use provided bond connectivity + or infer bonds from atomic distances. + + Parameters + ---------- + coords : torch.Tensor + Ligand coordinates with shape (num_atoms, 3) or (batch, num_atoms, 3). + Coordinates should be in Angstroms. + atom_types : list[str] + Element symbols for each atom (e.g., ["C", "N", "O", "C", ...]). + Length must match num_atoms. + bond_matrix : torch.Tensor, optional + Bond connectivity matrix with shape (num_atoms, num_atoms). + Values: 0=no bond, 1=single, 2=double, 3=triple, 4=aromatic. + If None, bonds will be inferred from coordinates using Open Babel. + steps : int, default=500 + Maximum number of minimization steps. Ignored if mode="bonds_only". + force_field : str, default="MMFF94" + Force field to use. Options: "MMFF94", "MMFF94s", "UFF", "GAFF", "Ghemical". + - MMFF94: Merck Molecular Force Field (recommended for drug-like molecules) + - MMFF94s: MMFF94 with modified torsion parameters for planar groups + - UFF: Universal Force Field (good fallback, works for all elements) + - GAFF: General AMBER Force Field (good for organic molecules) + - Ghemical: Ghemical force field + method : str, default="cg" + Optimization method. Options: "cg" (conjugate gradients), "sd" (steepest descent). + Conjugate gradients is generally faster and recommended. + mode : str, default="full" + Minimization mode: + - "full": Full energy minimization (default, may change conformation) + - "local": Short minimization (50 steps) to fix bond lengths/angles only + - "bonds_only": Correct bond lengths to ideal values without minimization + - "bonds_and_angles": Correct both bond lengths and angles to ideal values + + Returns + ------- + torch.Tensor + Minimized coordinates with same shape as input. + + Raises + ------ + ImportError + If openbabel is not installed. + ValueError + If force field setup fails or coordinates are invalid. + + Examples + -------- + >>> coords = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.3, 1.2, 0.0]]) + >>> atom_types = ["C", "C", "O"] + >>> minimized = minimize_ligand_structure(coords, atom_types, mode="local") + + Notes + ----- + - If bond_matrix is not provided, Open Babel will infer bonds based on + atomic distances and element types. This works well for most organic molecules. + - The minimization preserves the overall molecular topology and only adjusts + atomic positions to lower the potential energy. + - For best results with drug-like molecules, use MMFF94 force field. + - UFF is recommended as fallback since it supports all elements. + - Use mode="local" or mode="bonds_only" to preserve overall conformation + while fixing local geometry issues. + """ + try: + from openbabel import openbabel as ob + except ImportError as e: + raise ImportError( + "Open Babel is required for ligand minimization. Install with: pip install openbabel-wheel" + ) from e + + # Handle batch dimension + had_batch_dim = coords.dim() == 3 + if had_batch_dim: + if coords.shape[0] != 1: + raise ValueError(f"Batch minimization not supported. Got batch size {coords.shape[0]}, expected 1.") + coords = coords.squeeze(0) + + # Validate inputs + num_atoms = coords.shape[0] + if len(atom_types) != num_atoms: + raise ValueError(f"Number of atom types ({len(atom_types)}) must match number of atoms ({num_atoms})") + + # Convert to numpy for Open Babel + coords_np = coords.detach().cpu().numpy() + + # Create Open Babel molecule + mol = ob.OBMol() + + # Add atoms + for i, (coord, atom_type) in enumerate(zip(coords_np, atom_types)): + atom = mol.NewAtom() + # Handle element lookup - strip any numbers from atom names (e.g., "C1" -> "C") + element = "".join(c for c in atom_type if c.isalpha()) + atomic_num = ob.GetAtomicNum(element) + if atomic_num == 0: + py_logger.warning(f"Unknown element '{element}', defaulting to Carbon") + atomic_num = 6 # Default to Carbon + atom.SetAtomicNum(atomic_num) + atom.SetVector(float(coord[0]), float(coord[1]), float(coord[2])) + + # Add bonds from bond_matrix if provided, otherwise let Open Babel infer + if bond_matrix is not None: + bond_matrix_np = bond_matrix.detach().cpu().numpy() + # Map our bond types to Open Babel bond orders + bond_order_map = { + 1: 1, # Single + 2: 2, # Double + 3: 3, # Triple + 4: 5, # Aromatic (Open Babel uses 5 for aromatic) + 5: 1, # Other -> Single + } + for i in range(num_atoms): + for j in range(i + 1, num_atoms): + bond_val = int(bond_matrix_np[i, j]) + if bond_val > 0: + order = bond_order_map.get(bond_val, 1) + mol.AddBond(i + 1, j + 1, order) # OB uses 1-based indexing + else: + # Infer bonds from coordinates + mol.ConnectTheDots() + mol.PerceiveBondOrders() + + # Helper function to correct bond lengths + def _correct_bond_lengths(molecule): + """Correct bond lengths to ideal values.""" + for bond in ob.OBMolBondIter(molecule): + atom1 = bond.GetBeginAtom() + atom2 = bond.GetEndAtom() + # Get ideal bond length from Open Babel's tables + ideal_length = ob.GetCovalentRad(atom1.GetAtomicNum()) + ob.GetCovalentRad(atom2.GetAtomicNum()) + if bond.GetBondOrder() == 2: + ideal_length *= 0.87 # Double bonds are ~13% shorter + elif bond.GetBondOrder() == 3: + ideal_length *= 0.78 # Triple bonds are ~22% shorter + elif bond.IsAromatic(): + ideal_length *= 0.91 # Aromatic bonds are ~9% shorter + + # Get current bond vector + v1 = atom1.GetVector() + v2 = atom2.GetVector() + current_length = v1.distSq(v2) ** 0.5 + + if current_length > 0.01: # Avoid division by zero + # Scale factor to achieve ideal length + scale = ideal_length / current_length + # Move atoms toward/away from each other equally + midpoint_x = (v1.GetX() + v2.GetX()) / 2 + midpoint_y = (v1.GetY() + v2.GetY()) / 2 + midpoint_z = (v1.GetZ() + v2.GetZ()) / 2 + + # New positions scaled from midpoint + new_x1 = midpoint_x + (v1.GetX() - midpoint_x) * scale + new_y1 = midpoint_y + (v1.GetY() - midpoint_y) * scale + new_z1 = midpoint_z + (v1.GetZ() - midpoint_z) * scale + new_x2 = midpoint_x + (v2.GetX() - midpoint_x) * scale + new_y2 = midpoint_y + (v2.GetY() - midpoint_y) * scale + new_z2 = midpoint_z + (v2.GetZ() - midpoint_z) * scale + + atom1.SetVector(new_x1, new_y1, new_z1) + atom2.SetVector(new_x2, new_y2, new_z2) + + # Helper function to get ideal bond angle based on hybridization + def _get_ideal_angle(central_atom): + """Get ideal bond angle for a central atom based on its hybridization.""" + hyb = central_atom.GetHyb() + if hyb == 1: # sp - linear + return 180.0 + elif hyb == 2: # sp2 - trigonal planar + return 120.0 + elif hyb == 3: # sp3 - tetrahedral + return 109.47 + else: + # Default to sp3 if unknown + return 109.47 + + # Helper function to correct bond angles + def _correct_bond_angles(molecule, num_iterations=3): + """Correct bond angles to ideal values based on hybridization.""" + import math + + for _ in range(num_iterations): + for angle in ob.OBMolAngleIter(molecule): + # angle is a tuple (vertex_idx, atom1_idx, atom2_idx) - 0-based + vertex_idx, idx1, idx2 = angle + + # Get atoms (OBMol uses 1-based indexing) + central_atom = molecule.GetAtom(vertex_idx + 1) + atom1 = molecule.GetAtom(idx1 + 1) + atom2 = molecule.GetAtom(idx2 + 1) + + if central_atom is None or atom1 is None or atom2 is None: + continue + + # Get ideal angle for this central atom + ideal_angle = _get_ideal_angle(central_atom) + ideal_rad = math.radians(ideal_angle) + + # Get current positions + vc = central_atom.GetVector() + v1 = atom1.GetVector() + v2 = atom2.GetVector() + + # Calculate vectors from central atom + vec1_x = v1.GetX() - vc.GetX() + vec1_y = v1.GetY() - vc.GetY() + vec1_z = v1.GetZ() - vc.GetZ() + + vec2_x = v2.GetX() - vc.GetX() + vec2_y = v2.GetY() - vc.GetY() + vec2_z = v2.GetZ() - vc.GetZ() + + # Calculate current angle + len1 = math.sqrt(vec1_x**2 + vec1_y**2 + vec1_z**2) + len2 = math.sqrt(vec2_x**2 + vec2_y**2 + vec2_z**2) + + if len1 < 0.01 or len2 < 0.01: + continue + + dot = vec1_x * vec2_x + vec1_y * vec2_y + vec1_z * vec2_z + cos_angle = max(-1.0, min(1.0, dot / (len1 * len2))) + current_rad = math.acos(cos_angle) + + # Calculate angle difference + angle_diff = ideal_rad - current_rad + + # Skip if angle is already close to ideal (within 5 degrees) + if abs(angle_diff) < math.radians(5.0): + continue + + # Calculate rotation axis (perpendicular to the plane of the angle) + cross_x = vec1_y * vec2_z - vec1_z * vec2_y + cross_y = vec1_z * vec2_x - vec1_x * vec2_z + cross_z = vec1_x * vec2_y - vec1_y * vec2_x + cross_len = math.sqrt(cross_x**2 + cross_y**2 + cross_z**2) + + if cross_len < 0.001: + continue # Vectors are parallel, can't define rotation axis + + # Normalize rotation axis + axis_x = cross_x / cross_len + axis_y = cross_y / cross_len + axis_z = cross_z / cross_len + + # Rotate atom2 around axis by half the angle difference + # (and atom1 in opposite direction by half) + half_diff = angle_diff / 2.0 + + # Rodrigues rotation formula for atom2 + cos_rot = math.cos(half_diff) + sin_rot = math.sin(half_diff) + + # Rotate vec2 + dot_axis_vec2 = axis_x * vec2_x + axis_y * vec2_y + axis_z * vec2_z + cross2_x = axis_y * vec2_z - axis_z * vec2_y + cross2_y = axis_z * vec2_x - axis_x * vec2_z + cross2_z = axis_x * vec2_y - axis_y * vec2_x + + new_vec2_x = vec2_x * cos_rot + cross2_x * sin_rot + axis_x * dot_axis_vec2 * (1 - cos_rot) + new_vec2_y = vec2_y * cos_rot + cross2_y * sin_rot + axis_y * dot_axis_vec2 * (1 - cos_rot) + new_vec2_z = vec2_z * cos_rot + cross2_z * sin_rot + axis_z * dot_axis_vec2 * (1 - cos_rot) + + # Rotate vec1 in opposite direction + cos_rot_neg = math.cos(-half_diff) + sin_rot_neg = math.sin(-half_diff) + + dot_axis_vec1 = axis_x * vec1_x + axis_y * vec1_y + axis_z * vec1_z + cross1_x = axis_y * vec1_z - axis_z * vec1_y + cross1_y = axis_z * vec1_x - axis_x * vec1_z + cross1_z = axis_x * vec1_y - axis_y * vec1_x + + new_vec1_x = vec1_x * cos_rot_neg + cross1_x * sin_rot_neg + axis_x * dot_axis_vec1 * (1 - cos_rot_neg) + new_vec1_y = vec1_y * cos_rot_neg + cross1_y * sin_rot_neg + axis_y * dot_axis_vec1 * (1 - cos_rot_neg) + new_vec1_z = vec1_z * cos_rot_neg + cross1_z * sin_rot_neg + axis_z * dot_axis_vec1 * (1 - cos_rot_neg) + + # Update positions + atom1.SetVector(vc.GetX() + new_vec1_x, vc.GetY() + new_vec1_y, vc.GetZ() + new_vec1_z) + atom2.SetVector(vc.GetX() + new_vec2_x, vc.GetY() + new_vec2_y, vc.GetZ() + new_vec2_z) + + # Handle bonds_only mode - correct bond lengths without energy minimization + if mode == "bonds_only": + builder = ob.OBBuilder() + builder.CorrectStereoAtoms(mol) + _correct_bond_lengths(mol) + elif mode == "bonds_and_angles": + # Use constrained force field minimization with ideal bond lengths and angles + + # Set up constraints for ideal geometry + constraints = ob.OBFFConstraints() + constraints.SetFactor(10000.0) # High weight to enforce constraints + + # Add distance constraints for all bonds at ideal lengths + for bond in ob.OBMolBondIter(mol): + atom1 = bond.GetBeginAtom() + atom2 = bond.GetEndAtom() + + # Calculate ideal bond length based on atom types and bond order + ideal_length = ob.GetCovalentRad(atom1.GetAtomicNum()) + ob.GetCovalentRad(atom2.GetAtomicNum()) + if bond.GetBondOrder() == 2: + ideal_length *= 0.87 + elif bond.GetBondOrder() == 3: + ideal_length *= 0.78 + elif bond.IsAromatic(): + ideal_length *= 0.91 + + constraints.AddDistanceConstraint(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), ideal_length) + + # Add angle constraints for all angles at ideal values based on hybridization + for angle in ob.OBMolAngleIter(mol): + vertex_idx, idx1, idx2 = angle + central_atom = mol.GetAtom(vertex_idx + 1) + + # Determine ideal angle based on hybridization + hyb = central_atom.GetHyb() + if hyb == 1: # sp - linear + ideal_angle = 180.0 + elif hyb == 2: # sp2 - trigonal planar + ideal_angle = 120.0 + else: # sp3 - tetrahedral (default) + ideal_angle = 109.47 + + # OBFFConstraints uses 1-based indexing + constraints.AddAngleConstraint(idx1 + 1, vertex_idx + 1, idx2 + 1, ideal_angle) + + # Run constrained minimization + ff = ob.OBForceField.FindForceField(force_field) + if ff is None: + ff = ob.OBForceField.FindForceField("UFF") + + if ff is not None and ff.Setup(mol, constraints): + ff.ConjugateGradients(min(steps, 500)) + ff.GetCoordinates(mol) + else: + py_logger.warning("Constrained minimization failed, falling back to bonds_only") + _correct_bond_lengths(mol) + else: + # Force field minimization modes + # Determine actual steps based on mode + actual_steps = steps + if mode == "local": + actual_steps = min(50, steps) # Cap at 50 for local mode + elif mode != "full": + raise ValueError(f"Unknown mode: {mode}. Use 'full', 'local', 'bonds_only', or 'bonds_and_angles'.") + + # Set up force field + ff = ob.OBForceField.FindForceField(force_field) + if ff is None: + # Try fallback to UFF + py_logger.warning(f"Force field '{force_field}' not available, falling back to UFF") + ff = ob.OBForceField.FindForceField("UFF") + if ff is None: + raise ValueError("No force field available for minimization") + + # Initialize force field with molecule + if not ff.Setup(mol): + py_logger.warning(f"Force field setup failed with {force_field}, trying UFF as fallback") + ff = ob.OBForceField.FindForceField("UFF") + if ff is None or not ff.Setup(mol): + py_logger.warning("Force field setup failed, returning original coordinates") + if had_batch_dim: + return coords.unsqueeze(0) + return coords + + # Run minimization + if method == "cg": + ff.ConjugateGradients(actual_steps) + elif method == "sd": + ff.SteepestDescent(actual_steps) + else: + raise ValueError(f"Unknown optimization method: {method}. Use 'cg' or 'sd'.") + + # Update coordinates in molecule + ff.GetCoordinates(mol) + + # Extract minimized coordinates + minimized_coords = torch.zeros_like(coords) + for i in range(num_atoms): + atom = mol.GetAtom(i + 1) # OB uses 1-based indexing + minimized_coords[i, 0] = atom.GetX() + minimized_coords[i, 1] = atom.GetY() + minimized_coords[i, 2] = atom.GetZ() + + # Restore batch dimension if needed + if had_batch_dim: + minimized_coords = minimized_coords.unsqueeze(0) + + return minimized_coords + + +def get_ligand_energy( + coords: torch.Tensor, + atom_types: list[str], + bond_matrix: torch.Tensor | None = None, + force_field: str = "MMFF94", +) -> float: + """Calculate the potential energy of a ligand structure. + + This function computes the force field energy of a ligand, which can be + used to compare structures before and after minimization. + + Parameters + ---------- + coords : torch.Tensor + Ligand coordinates with shape (num_atoms, 3) or (batch, num_atoms, 3). + atom_types : list[str] + Element symbols for each atom. + bond_matrix : torch.Tensor, optional + Bond connectivity matrix. If None, bonds will be inferred. + force_field : str, default="MMFF94" + Force field to use for energy calculation. + + Returns + ------- + float + Potential energy in kcal/mol. Returns float('inf') if calculation fails. + + Examples + -------- + >>> coords = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]) + >>> atom_types = ["C", "C"] + >>> energy = get_ligand_energy(coords, atom_types) + """ + try: + from openbabel import openbabel as ob + except ImportError: + py_logger.warning("Open Babel not available for energy calculation") + return float("inf") + + # Handle batch dimension + if coords.dim() == 3: + coords = coords.squeeze(0) + + num_atoms = coords.shape[0] + coords_np = coords.detach().cpu().numpy() + + # Create molecule + mol = ob.OBMol() + for i, (coord, atom_type) in enumerate(zip(coords_np, atom_types)): + atom = mol.NewAtom() + element = "".join(c for c in atom_type if c.isalpha()) + atomic_num = ob.GetAtomicNum(element) + if atomic_num == 0: + atomic_num = 6 + atom.SetAtomicNum(atomic_num) + atom.SetVector(float(coord[0]), float(coord[1]), float(coord[2])) + + # Add bonds + if bond_matrix is not None: + bond_matrix_np = bond_matrix.detach().cpu().numpy() + bond_order_map = {1: 1, 2: 2, 3: 3, 4: 5, 5: 1} + for i in range(num_atoms): + for j in range(i + 1, num_atoms): + bond_val = int(bond_matrix_np[i, j]) + if bond_val > 0: + order = bond_order_map.get(bond_val, 1) + mol.AddBond(i + 1, j + 1, order) + else: + mol.ConnectTheDots() + mol.PerceiveBondOrders() + + # Calculate energy + ff = ob.OBForceField.FindForceField(force_field) + if ff is None: + ff = ob.OBForceField.FindForceField("UFF") + + if ff is None or not ff.Setup(mol): + return float("inf") + + return ff.Energy() diff --git a/src/lobster/model/latent_generator/utils/_se3_augmentation.py b/src/lobster/model/latent_generator/utils/_se3_augmentation.py new file mode 100644 index 00000000..ae29e97d --- /dev/null +++ b/src/lobster/model/latent_generator/utils/_se3_augmentation.py @@ -0,0 +1,369 @@ +"""Standalone SE3 augmentation utilities for protein-ligand complexes. + +This module provides functions to apply SE(3) transformations (rotation + translation) +to protein-ligand complexes while maintaining their relative positions. +""" + +import logging +from typing import Literal, NamedTuple + +import torch +from torch import Tensor + +from ._kinematics import ( + apply_global_frame_to_coords, + apply_random_se3_batched, +) + +logger = logging.getLogger(__name__) + + +class SE3AugmentedComplex(NamedTuple): + """Output of SE3 augmentation for protein-ligand complexes.""" + + protein_coords: Tensor | None # [B, L, n_atoms, 3] + protein_mask: Tensor | None # [B, L] + ligand_coords: Tensor | None # [B, N_ligand, 3] + ligand_mask: Tensor | None # [B, N_ligand] + + +def apply_se3_augmentation_protein_ligand( + protein_coords: Tensor | None = None, + protein_mask: Tensor | None = None, + ligand_coords: Tensor | None = None, + ligand_mask: Tensor | None = None, + random_se3: bool = True, + only_rot: bool = False, + only_trans: bool = False, + translation_scale: float = 1.0, + rotation_mode: str = "svd", + frame_type: Literal["norm_frame", "pca_frame", "mol_frame"] | None = None, + apply_stochastic_fa: bool = False, + get_all_frames: bool = False, + backbone_noise: float = 0.0, +) -> SE3AugmentedComplex: + """Apply SE(3) augmentation to protein-ligand complex. + + This function applies the SAME SE(3) transformation (rotation + translation) + to both protein and ligand coordinates, ensuring they remain in the same + reference frame. The transformation is applied as: + 1. Concatenate protein and ligand coordinates + 2. Apply random SE(3) transformation + 3. Optionally apply global frame (PCA/norm/mol frame) + 4. Add optional backbone noise + 5. Split back into protein and ligand coordinates + + Parameters + ---------- + protein_coords : Tensor | None + Protein backbone coordinates of shape [B, L, n_atoms, 3] where n_atoms + is typically 4 (N, CA, C, O) or 3 (N, CA, C). Can be None for ligand-only. + protein_mask : Tensor | None + Boolean mask of shape [B, L] indicating valid residues. Can be None. + ligand_coords : Tensor | None + Ligand atom coordinates of shape [B, N_ligand, 3]. Can be None for protein-only. + ligand_mask : Tensor | None + Boolean mask of shape [B, N_ligand] indicating valid atoms. Can be None. + random_se3 : bool + Whether to apply random SE(3) transformation. Default True. + only_rot : bool + If True, only apply rotation (no translation). Default False. + only_trans : bool + If True, only apply translation (no rotation). Default False. + translation_scale : float + Scale factor for random translation. Default 1.0. + rotation_mode : str + Method to generate random rotation. One of "svd", "quaternion", or "none". + Default "svd". + frame_type : str | None + Type of global frame to apply. One of "norm_frame", "pca_frame", "mol_frame", + or None to skip frame application. Default None. + apply_stochastic_fa : bool + Whether to apply stochastic frame alignment. Only used with frame_type. + Default False. + get_all_frames : bool + Whether to get all possible frame orientations. Only used with frame_type. + Default False. + backbone_noise : float + Standard deviation of Gaussian noise to add to coordinates. Default 0.0. + + Returns + ------- + SE3AugmentedComplex + Named tuple containing: + - protein_coords: Transformed protein coordinates [B, L, n_atoms, 3] or None + - protein_mask: Protein mask [B, L] or None + - ligand_coords: Transformed ligand coordinates [B, N_ligand, 3] or None + - ligand_mask: Ligand mask [B, N_ligand] or None + + Examples + -------- + >>> # Protein-only augmentation + >>> protein = torch.randn(2, 100, 4, 3) + >>> mask = torch.ones(2, 100, dtype=torch.bool) + >>> result = apply_se3_augmentation_protein_ligand( + ... protein_coords=protein, + ... protein_mask=mask, + ... ) + + >>> # Protein-ligand complex augmentation + >>> protein = torch.randn(2, 100, 4, 3) + >>> ligand = torch.randn(2, 30, 3) + >>> result = apply_se3_augmentation_protein_ligand( + ... protein_coords=protein, + ... protein_mask=torch.ones(2, 100, dtype=torch.bool), + ... ligand_coords=ligand, + ... ligand_mask=torch.ones(2, 30, dtype=torch.bool), + ... ) + """ + # Handle edge cases + has_protein = protein_coords is not None + has_ligand = ligand_coords is not None + + if not has_protein and not has_ligand: + raise ValueError("At least one of protein_coords or ligand_coords must be provided") + + # Clone inputs to avoid modifying originals + if has_protein: + protein_coords = protein_coords.clone() + if protein_mask is not None: + protein_mask = protein_mask.clone() + if has_ligand: + ligand_coords = ligand_coords.clone() + if ligand_mask is not None: + ligand_mask = ligand_mask.clone() + + # Determine batch size and device + if has_protein: + B, L, n_atoms, _ = protein_coords.shape + device = protein_coords.device + else: + B = ligand_coords.shape[0] + L = 0 + n_atoms = 0 + device = ligand_coords.device + + # Prepare combined coordinates for joint transformation + if has_protein and has_ligand: + # Both present - concatenate for joint transformation + B_ligand = ligand_coords.shape[0] + assert B == B_ligand, f"Batch size mismatch: protein {B} vs ligand {B_ligand}" + + # Flatten protein coords: [B, L, n_atoms, 3] -> [B, L*n_atoms, 3] + coords_flat = protein_coords.reshape(B, -1, 3) + # Concatenate with ligand: [B, L*n_atoms + N_ligand, 3] + coords = torch.cat([coords_flat, ligand_coords], dim=1) + + # Expand protein mask to atom level: [B, L] -> [B, L*n_atoms] + if protein_mask is not None: + seq_mask = protein_mask.unsqueeze(-1).expand(-1, -1, n_atoms).reshape(B, -1) + else: + seq_mask = torch.ones(B, L * n_atoms, device=device, dtype=torch.bool) + + # Concatenate masks + if ligand_mask is not None: + seq_mask = torch.cat([seq_mask, ligand_mask], dim=1) + else: + ligand_n = ligand_coords.shape[1] + seq_mask = torch.cat([seq_mask, torch.ones(B, ligand_n, device=device, dtype=torch.bool)], dim=1) + + elif has_protein: + # Protein-only - flatten + coords = protein_coords.reshape(B, -1, 3) + if protein_mask is not None: + seq_mask = protein_mask.unsqueeze(-1).expand(-1, -1, n_atoms).reshape(B, -1) + else: + seq_mask = torch.ones(B, L * n_atoms, device=device, dtype=torch.bool) + + else: + # Ligand-only + coords = ligand_coords + if ligand_mask is not None: + seq_mask = ligand_mask + else: + seq_mask = torch.ones(B, ligand_coords.shape[1], device=device, dtype=torch.bool) + + # Apply SE(3) transformation + if random_se3 and seq_mask.any(): + if only_rot: + logger.debug("Only applying rotation") + actual_translation_scale = 0.0 + actual_rotation_mode = rotation_mode + elif only_trans: + logger.debug("Only applying translation") + actual_translation_scale = translation_scale + actual_rotation_mode = "none" + else: + actual_translation_scale = translation_scale + actual_rotation_mode = rotation_mode + + coords = apply_random_se3_batched( + coords, + atom_mask=seq_mask, + translation_scale=actual_translation_scale, + rotation_mode=actual_rotation_mode, + ) + elif not random_se3: + logger.debug("No SE(3) applied") + + # Apply global frame transformation + if frame_type is not None and seq_mask.any(): + coords = apply_global_frame_to_coords( + coords, + frame_type=frame_type, + mask=seq_mask, + apply_stochastic_fa=apply_stochastic_fa, + get_all_frames=get_all_frames, + ) + + # Add backbone noise + if backbone_noise > 0: + coords = coords + backbone_noise * torch.randn_like(coords) + + # Split back into protein and ligand + if has_protein and has_ligand: + # Split coordinates + n_protein_atoms = L * n_atoms + ligand_coords_out = coords[:, n_protein_atoms:, :] + protein_coords_flat = coords[:, :n_protein_atoms, :] + protein_coords_out = protein_coords_flat.reshape(B, L, n_atoms, 3) + + # Reconstruct protein mask from atom-level mask + protein_mask_out = seq_mask[:, :n_protein_atoms].reshape(B, L, n_atoms).any(dim=-1) + ligand_mask_out = seq_mask[:, n_protein_atoms:] + + return SE3AugmentedComplex( + protein_coords=protein_coords_out, + protein_mask=protein_mask_out, + ligand_coords=ligand_coords_out, + ligand_mask=ligand_mask_out, + ) + + elif has_protein: + # Protein-only - reshape back + protein_coords_out = coords.reshape(B, L, n_atoms, 3) + protein_mask_out = seq_mask.reshape(B, L, n_atoms).any(dim=-1) if protein_mask is not None else None + + return SE3AugmentedComplex( + protein_coords=protein_coords_out, + protein_mask=protein_mask_out, + ligand_coords=None, + ligand_mask=None, + ) + + else: + # Ligand-only + return SE3AugmentedComplex( + protein_coords=None, + protein_mask=None, + ligand_coords=coords, + ligand_mask=seq_mask if ligand_mask is not None else None, + ) + + +def apply_se3_augmentation_batched( + coords: Tensor, + mask: Tensor | None = None, + random_se3: bool = True, + only_rot: bool = False, + only_trans: bool = False, + translation_scale: float = 1.0, + rotation_mode: str = "svd", + frame_type: Literal["norm_frame", "pca_frame", "mol_frame"] | None = None, + apply_stochastic_fa: bool = False, + get_all_frames: bool = False, + backbone_noise: float = 0.0, +) -> Tensor: + """Apply SE(3) augmentation to a batch of coordinates. + + Simplified interface for single coordinate tensor (either protein or ligand). + + Parameters + ---------- + coords : Tensor + Coordinates of shape [B, N, 3] (flat) or [B, L, n_atoms, 3] (structured). + mask : Tensor | None + Boolean mask. For flat coords: [B, N], for structured: [B, L]. + random_se3 : bool + Whether to apply random SE(3) transformation. Default True. + only_rot : bool + If True, only apply rotation (no translation). Default False. + only_trans : bool + If True, only apply translation (no rotation). Default False. + translation_scale : float + Scale factor for random translation. Default 1.0. + rotation_mode : str + Method to generate random rotation. One of "svd", "quaternion", "none". + frame_type : str | None + Type of global frame to apply. One of "norm_frame", "pca_frame", "mol_frame". + apply_stochastic_fa : bool + Whether to apply stochastic frame alignment. + get_all_frames : bool + Whether to get all possible frame orientations. + backbone_noise : float + Standard deviation of Gaussian noise to add. + + Returns + ------- + Tensor + Transformed coordinates with same shape as input. + """ + coords = coords.clone() + is_structured = len(coords.shape) == 4 + + if is_structured: + B, L, n_atoms, _ = coords.shape + device = coords.device + + # Flatten for processing + coords_flat = coords.reshape(B, -1, 3) + + # Expand mask + if mask is not None: + seq_mask = mask.unsqueeze(-1).expand(-1, -1, n_atoms).reshape(B, -1) + else: + seq_mask = torch.ones(B, L * n_atoms, device=device, dtype=torch.bool) + else: + coords_flat = coords + if mask is not None: + seq_mask = mask + else: + seq_mask = torch.ones(coords.shape[0], coords.shape[1], device=coords.device, dtype=torch.bool) + + # Apply SE(3) transformation + if random_se3 and seq_mask.any(): + if only_rot: + actual_translation_scale = 0.0 + actual_rotation_mode = rotation_mode + elif only_trans: + actual_translation_scale = translation_scale + actual_rotation_mode = "none" + else: + actual_translation_scale = translation_scale + actual_rotation_mode = rotation_mode + + coords_flat = apply_random_se3_batched( + coords_flat, + atom_mask=seq_mask, + translation_scale=actual_translation_scale, + rotation_mode=actual_rotation_mode, + ) + + # Apply global frame + if frame_type is not None and seq_mask.any(): + coords_flat = apply_global_frame_to_coords( + coords_flat, + frame_type=frame_type, + mask=seq_mask, + apply_stochastic_fa=apply_stochastic_fa, + get_all_frames=get_all_frames, + ) + + # Add noise + if backbone_noise > 0: + coords_flat = coords_flat + backbone_noise * torch.randn_like(coords_flat) + + # Reshape if needed + if is_structured: + return coords_flat.reshape(B, L, n_atoms, 3) + return coords_flat diff --git a/src/lobster/model/latent_generator/utils/residue_constants.py b/src/lobster/model/latent_generator/utils/residue_constants.py index b06e8135..09a88ce9 100644 --- a/src/lobster/model/latent_generator/utils/residue_constants.py +++ b/src/lobster/model/latent_generator/utils/residue_constants.py @@ -753,6 +753,54 @@ def convert_lobster_aa_tokenization_to_standard_aa(sequence_logits, device=None) ELEMENT_VOCAB = ["PAD", "B", "Bi", "Br", "C", "Cl", "F", "H", "I", "N", "O", "P", "S", "Si"] ELEMENT_TO_IDX = {elem: idx for idx, elem in enumerate(ELEMENT_VOCAB)} +# Extended element vocabulary for drug-like molecules (25 tokens) +# Covers ~99.9% of drug-like molecules with special tokens for masking and unknown elements +ELEMENT_VOCAB_EXTENDED = [ + # Special tokens + "PAD", # 0: Padding token + "MASK", # 1: Mask token for masked language modeling + "UNK", # 2: Unknown element + # Common organic elements (highest frequency in drugs) + "C", # 3: Carbon + "N", # 4: Nitrogen + "O", # 5: Oxygen + "S", # 6: Sulfur + "P", # 7: Phosphorus + "F", # 8: Fluorine + "Cl", # 9: Chlorine + "Br", # 10: Bromine + "I", # 11: Iodine + # Common metal/metalloid in drugs + "B", # 12: Boron (boronic acids in drugs) + "Si", # 13: Silicon (silanes, silanols) + "Se", # 14: Selenium (selenocysteine analogs) + "As", # 15: Arsenic (rare, but exists in some drugs) + # Metals found in coordination complexes and some drugs + "Zn", # 16: Zinc + "Fe", # 17: Iron + "Cu", # 18: Copper + "Mg", # 19: Magnesium + "Ca", # 20: Calcium + "Na", # 21: Sodium + "K", # 22: Potassium + # Additional elements occasionally found in bioactive molecules + "Bi", # 23: Bismuth (bismuth subsalicylate) + "H", # 24: Hydrogen (explicit H, rare in our representation) +] +ELEMENT_VOCAB_EXTENDED_TO_IDX = {elem: idx for idx, elem in enumerate(ELEMENT_VOCAB_EXTENDED)} + +# Bond types for bond matrix representation +# 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic, 5: unknown/other +BOND_TYPES = { + "NONE": 0, + "SINGLE": 1, + "DOUBLE": 2, + "TRIPLE": 3, + "AROMATIC": 4, + "OTHER": 5, +} +NUM_BOND_TYPES = len(BOND_TYPES) + # from funcbind diff --git a/src/lobster/model/losses/__init__.py b/src/lobster/model/losses/__init__.py index 56d2ee84..3700f544 100644 --- a/src/lobster/model/losses/__init__.py +++ b/src/lobster/model/losses/__init__.py @@ -9,6 +9,11 @@ NaturalGaussianLoss, MixtureGaussianNLLLoss, ) +from ._diffusion_loss import ( + DiffusionLoss, + SimpleMLPAdaLN, + create_diffusion_loss, +) # Import registry data from ._registry import ( @@ -28,6 +33,10 @@ "ExponentialParameterizedLoss", "NaturalGaussianLoss", "MixtureGaussianNLLLoss", + # Diffusion loss for continuous tokens (MAR-style) + "DiffusionLoss", + "SimpleMLPAdaLN", + "create_diffusion_loss", # Registry constants and function "AVAILABLE_LOSS_FUNCTIONS", "DEFAULT_LOSS_FUNCTIONS", diff --git a/src/lobster/model/losses/_diffusion_loss.py b/src/lobster/model/losses/_diffusion_loss.py new file mode 100644 index 00000000..8fc4b426 --- /dev/null +++ b/src/lobster/model/losses/_diffusion_loss.py @@ -0,0 +1,559 @@ +""" +Diffusion Loss for continuous structure token modeling. + +Adapted from MAR (Masked Autoregressive Models) paper: +"Autoregressive Image Generation without Vector Quantization" +https://arxiv.org/abs/2406.11838 +https://github.com/LTH14/mar + +This module provides a self-contained implementation of Diffusion Loss +that can replace categorical cross-entropy for continuous token spaces. +""" + +import math +from typing import Literal + +import torch +import torch.nn as nn +from torch import Tensor +from torch.utils.checkpoint import checkpoint + + +def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + """Apply AdaLN modulation: x * (1 + scale) + shift.""" + return x * (1 + scale) + shift + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + + From MAR/DiT: uses sinusoidal embeddings followed by MLP. + """ + + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor: + """ + Create sinusoidal timestep embeddings. + + :param t: a 1-D Tensor of N indices, one per batch element. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: Tensor) -> Tensor: + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ResBlock(nn.Module): + """ + A residual block with AdaLN modulation. + + From MAR: uses shift, scale, and gate modulation. + :param channels: the number of input channels. + """ + + def __init__(self, channels: int): + super().__init__() + self.channels = channels + + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)) + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class FinalLayer(nn.Module): + """ + The final layer with AdaLN modulation. + + From MAR/DiT: applies final normalization and linear projection. + """ + + def __init__(self, model_channels: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)) + + def forward(self, x: Tensor, c: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP denoiser for Diffusion Loss. + + From MAR: Simple MLP with AdaLN conditioning. + + :param in_channels: channels in the input Tensor (target dim). + :param model_channels: base channel count for the model (width). + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition from transformer. + :param num_res_blocks: number of residual blocks (depth). + :param grad_checkpointing: whether to use gradient checkpointing. + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + grad_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append(ResBlock(model_channels)) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + """Initialize weights following MAR/DiT conventions.""" + + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x: Tensor, t: Tensor, c: Tensor) -> Tensor: + """ + Apply the model to an input batch. + + :param x: an [N x C] or [N x L x C] Tensor of noisy inputs. + :param t: a 1-D batch of timesteps [N]. + :param c: conditioning from transformer [N x C] or [N x L x C]. + :return: an [N x C] or [N x L x C] Tensor of outputs. + """ + # Handle both 2D (per-sample) and 3D (per-token) inputs + has_seq_dim = x.dim() == 3 + if has_seq_dim: + B, L, C = x.shape + # Flatten to [B*L, C] for processing + x = x.reshape(B * L, C) + c = c.reshape(B * L, -1) + # Expand timesteps to match + t = t.unsqueeze(1).expand(-1, L).reshape(B * L) + + x = self.input_proj(x) + t_emb = self.time_embed(t) + c_emb = self.cond_embed(c) + + y = t_emb + c_emb + + if self.grad_checkpointing and self.training: + for block in self.res_blocks: + x = checkpoint(block, x, y, use_reentrant=False) + else: + for block in self.res_blocks: + x = block(x, y) + + out = self.final_layer(x, y) + + # Reshape back to 3D if needed + if has_seq_dim: + out = out.reshape(B, L, -1) + + return out + + def forward_with_cfg(self, x: Tensor, t: Tensor, c: Tensor, cfg_scale: float) -> Tensor: + """Forward with classifier-free guidance.""" + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, c) + eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +class DiffusionLoss(nn.Module): + """ + Diffusion Loss for continuous structure tokens. + + Models per-token probability p(z|c) using diffusion, eliminating + the need for vector quantization. This is a self-contained implementation + adapted from MAR (https://github.com/LTH14/mar). + + Parameters + ---------- + target_channels : int + Dimension of target continuous embeddings (e.g., 256). + z_channels : int + Dimension of conditioning from transformer. + depth : int + Number of residual blocks in the MLP denoiser. + width : int + Hidden dimension of the MLP denoiser. + num_sampling_steps : str + Number of steps for sampling (e.g., "100" or "250"). + diffusion_steps : int + Total diffusion timesteps for training. + noise_schedule : str + Type of noise schedule: "linear" or "cosine". + learn_sigma : bool + Whether to learn the variance (doubles output channels). + grad_checkpointing : bool + Whether to use gradient checkpointing for memory efficiency. + """ + + def __init__( + self, + target_channels: int, + z_channels: int, + depth: int = 3, + width: int = 1024, + num_sampling_steps: str = "100", + diffusion_steps: int = 1000, + noise_schedule: Literal["linear", "cosine"] = "cosine", + learn_sigma: bool = True, + grad_checkpointing: bool = False, + ): + super().__init__() + + self.target_channels = target_channels + self.diffusion_steps = diffusion_steps + self.learn_sigma = learn_sigma + + # Output channels: double if learning sigma (for mean + variance) + out_channels = target_channels * 2 if learn_sigma else target_channels + + self.net = SimpleMLPAdaLN( + in_channels=target_channels, + model_channels=width, + out_channels=out_channels, + z_channels=z_channels, + num_res_blocks=depth, + grad_checkpointing=grad_checkpointing, + ) + + # Precompute noise schedule + betas = self._get_beta_schedule(noise_schedule, diffusion_steps) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]]) + + # Register buffers + self.register_buffer("betas", betas) + self.register_buffer("alphas_cumprod", alphas_cumprod) + self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) + self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) + self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) + self.register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) + self.register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) + + # For sampling + self.register_buffer("posterior_variance", betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + self.register_buffer( + "posterior_log_variance_clipped", torch.log(torch.clamp(self.posterior_variance, min=1e-20)) + ) + self.register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + self.register_buffer( + "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) + ) + + # Parse sampling steps + self.num_sampling_steps = int(num_sampling_steps) if num_sampling_steps else diffusion_steps + + def _get_beta_schedule(self, schedule: str, num_timesteps: int) -> Tensor: + """Create beta schedule for noise.""" + if schedule == "linear": + beta_start = 0.0001 + beta_end = 0.02 + return torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32) + elif schedule == "cosine": + # Cosine schedule from "Improved Denoising Diffusion Probabilistic Models" + s = 0.008 + steps = num_timesteps + 1 + x = torch.linspace(0, num_timesteps, steps, dtype=torch.float32) + alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clamp(betas, 0.0001, 0.9999) + else: + raise ValueError(f"Unknown schedule: {schedule}") + + def q_sample(self, x_start: Tensor, t: Tensor, noise: Tensor | None = None) -> Tensor: + """ + Forward diffusion: add noise to x_start at timestep t. + + q(x_t | x_0) = sqrt(α̅_t) * x_0 + sqrt(1 - α̅_t) * ε + """ + if noise is None: + noise = torch.randn_like(x_start) + + sqrt_alpha = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) + sqrt_one_minus_alpha = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + + return sqrt_alpha * x_start + sqrt_one_minus_alpha * noise + + def _extract(self, arr: Tensor, timesteps: Tensor, broadcast_shape: tuple) -> Tensor: + """Extract values from arr at timesteps and broadcast to shape.""" + res = arr[timesteps] + while len(res.shape) < len(broadcast_shape): + res = res.unsqueeze(-1) + return res.expand(broadcast_shape) + + def forward( + self, + target: Tensor, + z: Tensor, + mask: Tensor | None = None, + return_pred: bool = False, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Compute diffusion loss for structure tokens. + + Parameters + ---------- + target : Tensor [B, L, D] or [B, D] + Ground truth continuous structure embeddings from encoder. + z : Tensor [B, L, D] or [B, D] + Conditioning from transformer (predicted token features). + mask : Tensor [B, L] or [B], optional + Valid token mask. + return_pred : bool + If True, also return the predicted (denoised) embeddings. + + Returns + ------- + loss : Tensor + Scalar diffusion loss. + pred_x0 : Tensor [B, L, D] or [B, D], optional + Predicted denoised embeddings (only if return_pred=True). + """ + # Sample random timesteps + t = torch.randint(0, self.diffusion_steps, (target.shape[0],), device=target.device) + + # Sample noise + noise = torch.randn_like(target) + + # Forward diffusion: x_t = sqrt(α̅_t) * x_0 + sqrt(1 - α̅_t) * ε + x_t = self.q_sample(target, t, noise=noise) + + # Predict noise (and optionally variance) + model_output = self.net(x_t, t, z) + + if self.learn_sigma: + # Split output into noise prediction and variance prediction + pred_noise, pred_var = model_output.chunk(2, dim=-1) + else: + pred_noise = model_output + + # MSE loss on noise prediction + loss = (pred_noise - noise) ** 2 + + # Average over feature dimension + loss = loss.mean(dim=-1) # [B, L] or [B] + + # Apply mask if provided + if mask is not None: + if loss.dim() == 2: # [B, L] + loss = (loss * mask).sum() / (mask.sum() + 1e-8) + else: # [B] + loss = (loss * mask).sum() / (mask.sum() + 1e-8) + else: + loss = loss.mean() + + if return_pred: + # Predict x_0 from noise prediction (same formula as sample()): + # x_0 = sqrt(1/α̅_t) * x_t - sqrt(1/α̅_t - 1) * pred_noise + pred_x0 = ( + self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise + ) + return loss, pred_x0 + + return loss + + @torch.no_grad() + def sample( + self, + z: Tensor, + temperature: float = 1.0, + num_steps: int | None = None, + cfg_scale: float = 1.0, + ) -> Tensor: + """ + Sample continuous embeddings via reverse diffusion (DDPM). + + Parameters + ---------- + z : Tensor [B, L, D] or [B, D] + Conditioning from transformer. + temperature : float + Sampling temperature (scales initial noise). + num_steps : int, optional + Number of sampling steps (defaults to num_sampling_steps). + cfg_scale : float + Classifier-free guidance scale (1.0 = no guidance). + + Returns + ------- + x : Tensor [B, L, D] or [B, D] + Sampled continuous embeddings. + """ + num_steps = num_steps or self.num_sampling_steps + + # Start from pure noise + shape = (*z.shape[:-1], self.target_channels) + x = torch.randn(shape, device=z.device) * temperature + + # Compute timestep indices for sampling + step_indices = torch.linspace(self.diffusion_steps - 1, 0, num_steps, device=z.device).long() + + for i, t in enumerate(step_indices): + t_batch = torch.full((z.shape[0],), t, device=z.device, dtype=torch.long) + + # Predict noise + if cfg_scale != 1.0: + model_output = self.net.forward_with_cfg(x, t_batch, z, cfg_scale) + else: + model_output = self.net(x, t_batch, z) + + if self.learn_sigma: + pred_noise, pred_var = model_output.chunk(2, dim=-1) + else: + pred_noise = model_output + + # DDPM update step (posterior uses precomputed coefs) + + # Predict x_0 from noise prediction + pred_x0 = ( + self._extract(self.sqrt_recip_alphas_cumprod, t_batch, x.shape) * x + - self._extract(self.sqrt_recipm1_alphas_cumprod, t_batch, x.shape) * pred_noise + ) + + # Compute posterior mean + posterior_mean = ( + self._extract(self.posterior_mean_coef1, t_batch, x.shape) * pred_x0 + + self._extract(self.posterior_mean_coef2, t_batch, x.shape) * x + ) + + # Add noise (except at t=0) + if t > 0: + noise = torch.randn_like(x) + posterior_var = self._extract(self.posterior_variance, t_batch, x.shape) + x = posterior_mean + torch.sqrt(posterior_var) * noise + else: + x = posterior_mean + + return x + + +# Convenience function to match MAR API +def create_diffusion_loss( + target_channels: int, + z_channels: int, + depth: int = 3, + width: int = 1024, + num_sampling_steps: str = "100", + noise_schedule: str = "cosine", + grad_checkpointing: bool = False, +) -> DiffusionLoss: + """ + Factory function to create DiffusionLoss with MAR-like defaults. + + Parameters + ---------- + target_channels : int + Dimension of target embeddings (your structure token dim, e.g., 256). + z_channels : int + Dimension of conditioning (transformer hidden dim). + depth : int + Number of MLP residual blocks (default: 3). + width : int + MLP hidden dimension (default: 1024). + num_sampling_steps : str + Steps for inference sampling (default: "100"). + noise_schedule : str + "linear" or "cosine" (default: "cosine"). + grad_checkpointing : bool + Enable gradient checkpointing for memory efficiency. + + Returns + ------- + DiffusionLoss + Configured diffusion loss module. + """ + return DiffusionLoss( + target_channels=target_channels, + z_channels=z_channels, + depth=depth, + width=width, + num_sampling_steps=num_sampling_steps, + diffusion_steps=1000, + noise_schedule=noise_schedule, + learn_sigma=True, + grad_checkpointing=grad_checkpointing, + ) diff --git a/src/lobster/transforms/__init__.py b/src/lobster/transforms/__init__.py index e81c60f3..d43aed1e 100644 --- a/src/lobster/transforms/__init__.py +++ b/src/lobster/transforms/__init__.py @@ -24,6 +24,12 @@ from ._tokenizer_transform import TokenizerTransform from ._transform import Transform +# Note: _ligand_chemistry and _ligand_inference have lazy imports to avoid +# circular dependencies with lobster.model.latent_generator.utils.residue_constants. +# Import them directly from the submodules when needed: +# from lobster.transforms._ligand_chemistry import smiles_to_graph +# from lobster.transforms._ligand_inference import smiles_to_ligand_input + __all__ = [ "AutoTokenizerTransform", "BinarizeTransform", @@ -43,3 +49,31 @@ "ComposedModalityAwareTransform", "SmilesToRDKitDescriptorsTransform", ] + + +def __getattr__(name): + """Lazy import for ligand chemistry functions to avoid circular imports.""" + ligand_chemistry_exports = { + "smiles_to_graph", + "graph_to_smiles", + "atom_types_to_indices", + "indices_to_atom_types", + "mol_to_bond_matrix", + "sdf_to_bond_matrix", + } + ligand_inference_exports = { + "smiles_to_ligand_input", + "sdf_to_ligand_input", + "reconstruct_smiles", + "reconstruct_smiles_from_tokens", + } + + if name in ligand_chemistry_exports: + from . import _ligand_chemistry + + return getattr(_ligand_chemistry, name) + elif name in ligand_inference_exports: + from . import _ligand_inference + + return getattr(_ligand_inference, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/lobster/transforms/_ligand_chemistry.py b/src/lobster/transforms/_ligand_chemistry.py new file mode 100644 index 00000000..01250d40 --- /dev/null +++ b/src/lobster/transforms/_ligand_chemistry.py @@ -0,0 +1,315 @@ +"""Ligand chemistry utilities for SMILES <-> graph conversion. + +This module provides functions for converting between SMILES strings and +graph representations (atom types + bond matrices) for use in Gen-UME +protein-ligand modeling. + +Functions: + smiles_to_graph: Convert SMILES string to (atom_types, bond_matrix) + graph_to_smiles: Convert (atom_types, bond_matrix) to SMILES string + atom_types_to_indices: Convert element strings to vocabulary indices + indices_to_atom_types: Convert vocabulary indices to element strings +""" + +import torch +from rdkit import Chem + +from lobster.model.latent_generator.utils.residue_constants import ( + BOND_TYPES, + ELEMENT_VOCAB_EXTENDED, + ELEMENT_VOCAB_EXTENDED_TO_IDX, +) + + +def atom_types_to_indices(atom_types: list[str]) -> torch.Tensor: + """Convert element strings to vocabulary indices. + + Parameters + ---------- + atom_types : list[str] + List of element symbols (e.g., ["C", "N", "O"]). + + Returns + ------- + torch.Tensor + Tensor of vocabulary indices with shape (N,). + + Examples + -------- + >>> atom_types_to_indices(["C", "N", "O"]) + tensor([3, 4, 5]) + + Notes + ----- + Unknown elements are mapped to the UNK token (index 2). + """ + indices = [] + unk_idx = ELEMENT_VOCAB_EXTENDED_TO_IDX["UNK"] + + for atom in atom_types: + idx = ELEMENT_VOCAB_EXTENDED_TO_IDX.get(atom, unk_idx) + indices.append(idx) + + return torch.tensor(indices, dtype=torch.long) + + +def indices_to_atom_types(indices: torch.Tensor) -> list[str]: + """Convert vocabulary indices to element strings. + + Parameters + ---------- + indices : torch.Tensor + Tensor of vocabulary indices. + + Returns + ------- + list[str] + List of element symbols. + + Examples + -------- + >>> indices_to_atom_types(torch.tensor([3, 4, 5])) + ['C', 'N', 'O'] + """ + return [ELEMENT_VOCAB_EXTENDED[idx.item()] for idx in indices] + + +def smiles_to_graph(smiles: str) -> tuple[list[str], torch.Tensor]: + """Convert SMILES string to graph representation. + + Parameters + ---------- + smiles : str + SMILES string representing a molecule. + + Returns + ------- + tuple[list[str], torch.Tensor] + atom_types: List of element symbols for each heavy atom. + bond_matrix: Symmetric NxN tensor where N is number of atoms. + Values: 0=no bond, 1=single, 2=double, 3=triple, 4=aromatic, 5=other. + + Raises + ------ + ValueError + If SMILES string is invalid. + + Examples + -------- + >>> atom_types, bond_matrix = smiles_to_graph("CCO") + >>> atom_types + ['C', 'C', 'O'] + >>> bond_matrix.shape + torch.Size([3, 3]) + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError(f"Invalid SMILES: {smiles}") + + # Get heavy atoms (non-hydrogen) + num_atoms = mol.GetNumAtoms() + atom_types = [] + + for atom in mol.GetAtoms(): + atom_types.append(atom.GetSymbol()) + + # Build bond matrix + bond_matrix = torch.zeros(num_atoms, num_atoms, dtype=torch.long) + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + + # Map RDKit bond type to our encoding + bond_type = bond.GetBondType() + if bond_type == Chem.BondType.SINGLE: + val = BOND_TYPES["SINGLE"] + elif bond_type == Chem.BondType.DOUBLE: + val = BOND_TYPES["DOUBLE"] + elif bond_type == Chem.BondType.TRIPLE: + val = BOND_TYPES["TRIPLE"] + elif bond_type == Chem.BondType.AROMATIC: + val = BOND_TYPES["AROMATIC"] + else: + val = BOND_TYPES["OTHER"] + + # Symmetric + bond_matrix[i, j] = val + bond_matrix[j, i] = val + + return atom_types, bond_matrix + + +def graph_to_smiles( + atom_types: list[str], + bond_matrix: torch.Tensor, + coords: torch.Tensor | None = None, +) -> str: + """Convert graph representation to SMILES string. + + Parameters + ---------- + atom_types : list[str] + List of element symbols for each atom. + bond_matrix : torch.Tensor + Symmetric NxN tensor of bond types. + Values: 0=no bond, 1=single, 2=double, 3=triple, 4=aromatic, 5=other. + coords : torch.Tensor, optional + 3D coordinates with shape (N, 3). If provided, used for stereochemistry. + + Returns + ------- + str + Canonical SMILES string. + + Examples + -------- + >>> atom_types = ["C", "C", "O"] + >>> bond_matrix = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]]) + >>> graph_to_smiles(atom_types, bond_matrix) + 'CCO' + """ + # Create editable molecule + mol = Chem.RWMol() + + # Add atoms + for elem in atom_types: + atom = Chem.Atom(elem) + mol.AddAtom(atom) + + # Add bonds + num_atoms = len(atom_types) + for i in range(num_atoms): + for j in range(i + 1, num_atoms): + bond_val = bond_matrix[i, j].item() + if bond_val == 0: + continue + + # Map our encoding to RDKit bond type + if bond_val == BOND_TYPES["SINGLE"]: + bond_type = Chem.BondType.SINGLE + elif bond_val == BOND_TYPES["DOUBLE"]: + bond_type = Chem.BondType.DOUBLE + elif bond_val == BOND_TYPES["TRIPLE"]: + bond_type = Chem.BondType.TRIPLE + elif bond_val == BOND_TYPES["AROMATIC"]: + bond_type = Chem.BondType.AROMATIC + else: + bond_type = Chem.BondType.SINGLE # Default to single + + mol.AddBond(i, j, bond_type) + + # Convert to regular molecule + mol = mol.GetMol() + + # Set aromaticity if we have aromatic bonds + if (bond_matrix == BOND_TYPES["AROMATIC"]).any(): + # Mark atoms with aromatic bonds as aromatic + for i in range(num_atoms): + if (bond_matrix[i] == BOND_TYPES["AROMATIC"]).any(): + mol.GetAtomWithIdx(i).SetIsAromatic(True) + + # Add stereochemistry from 3D coordinates if provided + if coords is not None: + _assign_stereochemistry_from_coords(mol, coords) + + # Sanitize molecule + try: + Chem.SanitizeMol(mol) + except Exception: + # If sanitization fails, try without aromaticity perception + Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_KEKULIZE) + + return Chem.MolToSmiles(mol) + + +def _assign_stereochemistry_from_coords(mol: Chem.Mol, coords: torch.Tensor) -> None: + """Assign stereochemistry to molecule from 3D coordinates. + + Parameters + ---------- + mol : Chem.Mol + RDKit molecule (modified in place). + coords : torch.Tensor + 3D coordinates with shape (N, 3). + """ + # Create conformer + conf = Chem.Conformer(mol.GetNumAtoms()) + for i, coord in enumerate(coords): + conf.SetAtomPosition(i, coord.tolist()) + + mol.AddConformer(conf, assignId=True) + + # Assign stereochemistry from 3D structure + Chem.AssignStereochemistryFrom3D(mol) + + +def mol_to_bond_matrix(mol: Chem.Mol) -> torch.Tensor: + """Convert RDKit molecule to bond matrix. + + Parameters + ---------- + mol : Chem.Mol + RDKit molecule object. + + Returns + ------- + torch.Tensor + Symmetric NxN tensor of bond types. + + Notes + ----- + This is useful for extracting bond matrices from SDF files. + """ + num_atoms = mol.GetNumAtoms() + bond_matrix = torch.zeros(num_atoms, num_atoms, dtype=torch.long) + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + + bond_type = bond.GetBondType() + if bond_type == Chem.BondType.SINGLE: + val = BOND_TYPES["SINGLE"] + elif bond_type == Chem.BondType.DOUBLE: + val = BOND_TYPES["DOUBLE"] + elif bond_type == Chem.BondType.TRIPLE: + val = BOND_TYPES["TRIPLE"] + elif bond_type == Chem.BondType.AROMATIC: + val = BOND_TYPES["AROMATIC"] + else: + val = BOND_TYPES["OTHER"] + + bond_matrix[i, j] = val + bond_matrix[j, i] = val + + return bond_matrix + + +def sdf_to_bond_matrix(sdf_content: str) -> tuple[list[str], torch.Tensor]: + """Extract atom types and bond matrix from SDF content. + + Parameters + ---------- + sdf_content : str + Content of an SDF file. + + Returns + ------- + tuple[list[str], torch.Tensor] + atom_types: List of element symbols. + bond_matrix: Symmetric NxN tensor of bond types. + + Raises + ------ + ValueError + If SDF content is invalid. + """ + mol = Chem.MolFromMolBlock(sdf_content, removeHs=True) + if mol is None: + raise ValueError("Invalid SDF content") + + atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()] + bond_matrix = mol_to_bond_matrix(mol) + + return atom_types, bond_matrix diff --git a/src/lobster/transforms/_ligand_inference.py b/src/lobster/transforms/_ligand_inference.py new file mode 100644 index 00000000..d0d873bc --- /dev/null +++ b/src/lobster/transforms/_ligand_inference.py @@ -0,0 +1,429 @@ +"""Inference transforms for ligand processing in Gen-UME. + +This module provides transforms for preparing ligand inputs during inference +and reconstructing SMILES from model outputs. + +Functions: + smiles_to_ligand_input: Convert SMILES to model input tensors + sdf_to_ligand_input: Convert SDF to model input tensors + reconstruct_smiles: Convert model outputs to SMILES strings +""" + +import torch +from rdkit import Chem +from torch import Tensor + +from lobster.model.latent_generator.utils.residue_constants import ( + ELEMENT_VOCAB_EXTENDED_TO_IDX, +) + +from ._ligand_chemistry import ( + atom_types_to_indices, + graph_to_smiles, + indices_to_atom_types, + mol_to_bond_matrix, + smiles_to_graph, +) + + +def smiles_to_ligand_input( + smiles: str | list[str], + max_atoms: int | None = None, + device: torch.device | str = "cpu", +) -> dict[str, Tensor]: + """Convert SMILES string(s) to model input tensors. + + Parameters + ---------- + smiles : str or list[str] + SMILES string or list of SMILES strings. + max_atoms : int, optional + Maximum number of atoms (for padding). If None, uses the max + in the batch without padding. + device : torch.device or str + Device to place tensors on. + + Returns + ------- + dict[str, Tensor] + Dictionary containing: + - ligand_atom_input_ids: [B, N_atoms] atom type indices + - ligand_mask: [B, N_atoms] valid atom mask + - bond_matrix: [B, N_atoms, N_atoms] bond types + - atom_types: list[list[str]] original atom type strings + - smiles: list[str] original SMILES strings + + Examples + -------- + >>> inputs = smiles_to_ligand_input("CCO") + >>> inputs["ligand_atom_input_ids"].shape + torch.Size([1, 3]) + >>> inputs = smiles_to_ligand_input(["CCO", "CC(=O)O"]) + >>> inputs["ligand_atom_input_ids"].shape + torch.Size([2, 4]) + """ + # Handle single SMILES + if isinstance(smiles, str): + smiles = [smiles] + + # batch_size = len(smiles) + + # Parse all SMILES + all_atom_types = [] + all_bond_matrices = [] + + for smi in smiles: + atom_types, bond_matrix = smiles_to_graph(smi) + all_atom_types.append(atom_types) + all_bond_matrices.append(bond_matrix) + + # Determine max atoms + if max_atoms is None: + max_atoms = max(len(atoms) for atoms in all_atom_types) + + # Pad and stack + ligand_atom_ids = [] + ligand_masks = [] + padded_bond_matrices = [] + + pad_idx = ELEMENT_VOCAB_EXTENDED_TO_IDX["PAD"] + + for atom_types, bond_matrix in zip(all_atom_types, all_bond_matrices): + n_atoms = len(atom_types) + + # Convert atom types to indices + atom_indices = atom_types_to_indices(atom_types) + + # Pad atom indices + padded_atoms = torch.full((max_atoms,), pad_idx, dtype=torch.long) + padded_atoms[:n_atoms] = atom_indices + ligand_atom_ids.append(padded_atoms) + + # Create mask + mask = torch.zeros(max_atoms) + mask[:n_atoms] = 1.0 + ligand_masks.append(mask) + + # Pad bond matrix + padded_bonds = torch.zeros(max_atoms, max_atoms, dtype=torch.long) + padded_bonds[:n_atoms, :n_atoms] = bond_matrix + padded_bond_matrices.append(padded_bonds) + + return { + "ligand_atom_input_ids": torch.stack(ligand_atom_ids).to(device), + "ligand_mask": torch.stack(ligand_masks).to(device), + "bond_matrix": torch.stack(padded_bond_matrices).to(device), + "atom_types": all_atom_types, + "smiles": smiles, + } + + +def sdf_to_ligand_input( + sdf_path: str | list[str], + max_atoms: int | None = None, + device: torch.device | str = "cpu", + include_coords: bool = True, +) -> dict[str, Tensor]: + """Convert SDF file(s) to model input tensors. + + Parameters + ---------- + sdf_path : str or list[str] + Path to SDF file or list of paths. + max_atoms : int, optional + Maximum number of atoms (for padding). + device : torch.device or str + Device to place tensors on. + include_coords : bool + Whether to include 3D coordinates. + + Returns + ------- + dict[str, Tensor] + Dictionary containing: + - ligand_atom_input_ids: [B, N_atoms] atom type indices + - ligand_mask: [B, N_atoms] valid atom mask + - bond_matrix: [B, N_atoms, N_atoms] bond types + - ligand_coords: [B, N_atoms, 3] 3D coordinates (if include_coords) + - atom_types: list[list[str]] original atom type strings + - smiles: list[str] SMILES derived from SDF + + Examples + -------- + >>> inputs = sdf_to_ligand_input("ligand.sdf") + >>> inputs["ligand_atom_input_ids"].shape + torch.Size([1, N]) + """ + # Handle single path + if isinstance(sdf_path, str): + sdf_path = [sdf_path] + + # batch_size = len(sdf_path) + + # Parse all SDF files + all_atom_types = [] + all_bond_matrices = [] + all_coords = [] + all_smiles = [] + + for path in sdf_path: + # Read SDF file + supplier = Chem.SDMolSupplier(path, removeHs=True) + mol = next(supplier) + + if mol is None: + raise ValueError(f"Failed to parse SDF: {path}") + + # Get atom types + atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()] + all_atom_types.append(atom_types) + + # Get bond matrix + bond_matrix = mol_to_bond_matrix(mol) + all_bond_matrices.append(bond_matrix) + + # Get coordinates + if include_coords and mol.GetNumConformers() > 0: + conf = mol.GetConformer() + coords = torch.tensor( + [list(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())], + dtype=torch.float32, + ) + all_coords.append(coords) + elif include_coords: + # No conformer, create zeros + all_coords.append(torch.zeros(len(atom_types), 3)) + + # Get SMILES + smiles = Chem.MolToSmiles(mol) + all_smiles.append(smiles) + + # Determine max atoms + if max_atoms is None: + max_atoms = max(len(atoms) for atoms in all_atom_types) + + # Pad and stack + ligand_atom_ids = [] + ligand_masks = [] + padded_bond_matrices = [] + padded_coords = [] + + pad_idx = ELEMENT_VOCAB_EXTENDED_TO_IDX["PAD"] + + for i, (atom_types, bond_matrix) in enumerate(zip(all_atom_types, all_bond_matrices)): + n_atoms = len(atom_types) + + # Convert atom types to indices + atom_indices = atom_types_to_indices(atom_types) + + # Pad atom indices + padded_atoms = torch.full((max_atoms,), pad_idx, dtype=torch.long) + padded_atoms[:n_atoms] = atom_indices + ligand_atom_ids.append(padded_atoms) + + # Create mask + mask = torch.zeros(max_atoms) + mask[:n_atoms] = 1.0 + ligand_masks.append(mask) + + # Pad bond matrix + padded_bonds = torch.zeros(max_atoms, max_atoms, dtype=torch.long) + padded_bonds[:n_atoms, :n_atoms] = bond_matrix + padded_bond_matrices.append(padded_bonds) + + # Pad coordinates + if include_coords: + coords = all_coords[i] + padded_coord = torch.zeros(max_atoms, 3) + padded_coord[:n_atoms] = coords + padded_coords.append(padded_coord) + + result = { + "ligand_atom_input_ids": torch.stack(ligand_atom_ids).to(device), + "ligand_mask": torch.stack(ligand_masks).to(device), + "bond_matrix": torch.stack(padded_bond_matrices).to(device), + "atom_types": all_atom_types, + "smiles": all_smiles, + } + + if include_coords: + result["ligand_coords"] = torch.stack(padded_coords).to(device) + + return result + + +def reconstruct_smiles( + ligand_atom_logits: Tensor, + bond_logits: Tensor, + ligand_mask: Tensor | None = None, + ligand_coords: Tensor | None = None, + temperature: float = 1.0, +) -> list[str]: + """Reconstruct SMILES strings from model outputs. + + Parameters + ---------- + ligand_atom_logits : Tensor + Atom type logits with shape [B, N_atoms, atom_vocab_size]. + bond_logits : Tensor + Bond type logits with shape [B, N_atoms, N_atoms, num_bond_types]. + ligand_mask : Tensor, optional + Valid atom mask with shape [B, N_atoms]. + ligand_coords : Tensor, optional + 3D coordinates with shape [B, N_atoms, 3] for stereochemistry. + temperature : float + Temperature for sampling (1.0 = argmax). + + Returns + ------- + list[str] + Reconstructed SMILES strings for each sample in batch. + + Examples + -------- + >>> smiles = reconstruct_smiles(atom_logits, bond_logits, mask) + >>> print(smiles[0]) + 'CCO' + """ + batch_size = ligand_atom_logits.shape[0] + reconstructed_smiles = [] + + for b in range(batch_size): + # Get predicted atom types + if temperature == 1.0: + atom_indices = ligand_atom_logits[b].argmax(dim=-1) + else: + probs = torch.softmax(ligand_atom_logits[b] / temperature, dim=-1) + atom_indices = torch.multinomial(probs, num_samples=1).squeeze(-1) + + # Get predicted bond matrix + bond_matrix = bond_logits[b].argmax(dim=-1) + + # Apply mask if provided + if ligand_mask is not None: + mask = ligand_mask[b].bool() + n_valid = mask.sum().item() + atom_indices = atom_indices[:n_valid] + bond_matrix = bond_matrix[:n_valid, :n_valid] + else: + n_valid = atom_indices.shape[0] + + # Skip if no valid atoms + if n_valid == 0: + reconstructed_smiles.append("") + continue + + # Convert indices to atom types + atom_types = indices_to_atom_types(atom_indices) + + # Filter out special tokens + valid_atoms = [] + valid_indices = [] + for i, atom in enumerate(atom_types): + if atom not in ["PAD", "MASK", "UNK"]: + valid_atoms.append(atom) + valid_indices.append(i) + + if len(valid_atoms) == 0: + reconstructed_smiles.append("") + continue + + # Extract valid portion of bond matrix + valid_indices_t = torch.tensor(valid_indices) + valid_bond_matrix = bond_matrix[valid_indices_t][:, valid_indices_t] + + # Get coordinates if provided + coords = None + if ligand_coords is not None: + coords = ligand_coords[b, valid_indices_t] + + # Convert to SMILES + try: + smiles = graph_to_smiles(valid_atoms, valid_bond_matrix, coords) + reconstructed_smiles.append(smiles) + except Exception: + # If reconstruction fails, return empty string + reconstructed_smiles.append("") + + return reconstructed_smiles + + +def reconstruct_smiles_from_tokens( + ligand_atom_tokens: Tensor, + bond_matrix: Tensor, + ligand_mask: Tensor | None = None, + ligand_coords: Tensor | None = None, +) -> list[str]: + """Reconstruct SMILES from discrete tokens (not logits). + + This is useful when you have the final generated tokens rather + than probability distributions. + + Parameters + ---------- + ligand_atom_tokens : Tensor + Atom type token indices with shape [B, N_atoms]. + bond_matrix : Tensor + Bond type matrix with shape [B, N_atoms, N_atoms]. + ligand_mask : Tensor, optional + Valid atom mask with shape [B, N_atoms]. + ligand_coords : Tensor, optional + 3D coordinates with shape [B, N_atoms, 3] for stereochemistry. + + Returns + ------- + list[str] + Reconstructed SMILES strings for each sample in batch. + """ + batch_size = ligand_atom_tokens.shape[0] + reconstructed_smiles = [] + + for b in range(batch_size): + atom_indices = ligand_atom_tokens[b] + bonds = bond_matrix[b] + + # Apply mask if provided + if ligand_mask is not None: + mask = ligand_mask[b].bool() + n_valid = mask.sum().item() + atom_indices = atom_indices[:n_valid] + bonds = bonds[:n_valid, :n_valid] + else: + n_valid = atom_indices.shape[0] + + if n_valid == 0: + reconstructed_smiles.append("") + continue + + # Convert indices to atom types + atom_types = indices_to_atom_types(atom_indices) + + # Filter out special tokens + valid_atoms = [] + valid_indices = [] + for i, atom in enumerate(atom_types): + if atom not in ["PAD", "MASK", "UNK"]: + valid_atoms.append(atom) + valid_indices.append(i) + + if len(valid_atoms) == 0: + reconstructed_smiles.append("") + continue + + # Extract valid portion of bond matrix + valid_indices_t = torch.tensor(valid_indices, device=bonds.device) + valid_bond_matrix = bonds[valid_indices_t][:, valid_indices_t] + + # Get coordinates if provided + coords = None + if ligand_coords is not None: + coords = ligand_coords[b, valid_indices_t] + + # Convert to SMILES + try: + smiles = graph_to_smiles(valid_atoms, valid_bond_matrix, coords) + reconstructed_smiles.append(smiles) + except Exception: + reconstructed_smiles.append("") + + return reconstructed_smiles diff --git a/src/lobster/transforms/_structure_transforms.py b/src/lobster/transforms/_structure_transforms.py index 4e82a2df..a5852f79 100644 --- a/src/lobster/transforms/_structure_transforms.py +++ b/src/lobster/transforms/_structure_transforms.py @@ -589,16 +589,23 @@ def __init__(self, max_length=512, rand_permute_ligand=False, **kwargs): self.periodic_table = Chem.GetPeriodicTable() def __call__(self, x: dict) -> dict: - # Convert atom names to element indices using our vocabulary + # Convert atom names to element indices using the extended vocabulary (25 elements) if "atom_names" in x: - element_indices = torch.tensor( - [ - residue_constants.ELEMENT_TO_IDX[atom_name] # Will raise KeyError if element not in vocab - for atom_name in x["atom_names"] - ], - dtype=torch.long, - ) - x["element_indices"] = element_indices + element_indices = [] + unknown_elements = set() + # UNK token index for unknown elements (better than PAD) + unk_idx = residue_constants.ELEMENT_VOCAB_EXTENDED_TO_IDX.get("UNK", 2) + for atom_name in x["atom_names"]: + # Use ELEMENT_VOCAB_EXTENDED_TO_IDX which includes Se, Fe, Cu, Zn, etc. + idx = residue_constants.ELEMENT_VOCAB_EXTENDED_TO_IDX.get(atom_name, unk_idx) + if idx == unk_idx and atom_name not in ("UNK", "PAD"): + unknown_elements.add(atom_name) + element_indices.append(idx) + + if unknown_elements: + logger.warning(f"Unknown elements mapped to UNK: {unknown_elements}") + + x["element_indices"] = torch.tensor(element_indices, dtype=torch.long) if self.rand_permute_ligand: random_order = torch.randperm(x["atom_coords"].shape[0]) @@ -608,6 +615,9 @@ def __call__(self, x: dict) -> dict: x["atom_names"] = [x["atom_names"][i] for i in random_order_list] if "element_indices" in x: x["element_indices"] = x["element_indices"][random_order] + # Permute bond_matrix: reindex both rows and columns + if "bond_matrix" in x: + x["bond_matrix"] = x["bond_matrix"][random_order][:, random_order] return x diff --git a/uv.lock b/uv.lock index 2add4cab..c34a32a9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", @@ -101,13 +101,13 @@ name = "aiohttp" version = "3.12.15" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "aiohappyeyeballs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "aiosignal", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "attrs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "frozenlist", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "multidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "propcache", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "yarl", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "aiohappyeyeballs", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "aiosignal", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "attrs", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "frozenlist", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "multidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "propcache", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "yarl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9b/e7/d92a237d8802ca88483906c388f7c201bbe96cd80a165ffd0ac2f6a8d59f/aiohttp-3.12.15.tar.gz", hash = "sha256:4fc61385e9c98d72fcdf47e6dd81833f47b2f77c114c29cd64a361be57a763a2", size = 7823716, upload-time = "2025-07-29T05:52:32.215Z" } wheels = [ @@ -161,8 +161,8 @@ name = "aiosignal" version = "1.4.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "frozenlist", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "frozenlist", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } wheels = [ @@ -320,12 +320,12 @@ name = "biopandas" version = "0.5.0.dev0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "looseversion", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "mmtf-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "looseversion", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "mmtf-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "pandas", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "setuptools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pandas", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "setuptools", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9a/c2/8db303e82e7c7de1980c64a20607d0d947aff5cdc957658bd0023ad7058c/biopandas-0.5.0.dev0.tar.gz", hash = "sha256:e5ca32f0e1a5971d664bac931436cf3c3205746bdf6693fe661a6b55a234811d", size = 990273, upload-time = "2023-04-03T17:02:34.223Z" } wheels = [ @@ -362,13 +362,13 @@ name = "biotite" version = "1.4.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "biotraj", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "msgpack", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "biotraj", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "msgpack", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ff/64/3526f99fe09add35decea977bd3049672fc0be689d7e0557b0564a55600e/biotite-1.4.0.tar.gz", hash = "sha256:0428427fff47e046a36ecdda1cbb38fc61e652e8df4339bf0a0b7a248a051a8b", size = 37035933, upload-time = "2025-07-07T12:08:53.956Z" } wheels = [ @@ -389,7 +389,7 @@ source = { registry = "https://pypi.python.org/simple" } dependencies = [ { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/07/21/2287edfd0d2569639eea706e25c39e63b46a384cf1712db8ea05768317b0/biotraj-1.2.2.tar.gz", hash = "sha256:4bcba92101ed50f369cc1487fb5dfcfe1d8402ad47adaa9232b080553271663a", size = 3909030, upload-time = "2024-11-02T11:30:54.974Z" } wheels = [ @@ -408,9 +408,9 @@ name = "boto3" version = "1.40.18" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "botocore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "jmespath", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "s3transfer", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "botocore", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "jmespath", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "s3transfer", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/36/35/a30dc21ca6582358e0ce963f38e85d42ea619f12e7be4101a834c21d749d/boto3-1.40.18.tar.gz", hash = "sha256:64301d39adecc154e3e595eaf0d4f28998ef0a5551f1d033aeac51a9e1a688e5", size = 111994, upload-time = "2025-08-26T19:21:38.61Z" } wheels = [ @@ -422,9 +422,9 @@ name = "botocore" version = "1.40.18" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "jmespath", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "jmespath", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6a/91/2e745382793fa7d30810a7d5ca3e05f6817b6db07601ca5aaab12720caf9/botocore-1.40.18.tar.gz", hash = "sha256:afd69bdadd8c55cc89d69de0799829e555193a352d87867f746e19020271cc0f", size = 14375007, upload-time = "2025-08-26T19:21:24.996Z" } wheels = [ @@ -1614,7 +1614,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ @@ -1804,6 +1804,7 @@ dependencies = [ { name = "onnx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "onnxruntime", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "onnxscript", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "openbabel-wheel", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "pandas", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "peft", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "polars", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, @@ -1965,6 +1966,7 @@ requires-dist = [ { name = "onnx" }, { name = "onnxruntime" }, { name = "onnxscript" }, + { name = "openbabel-wheel", specifier = ">=3.1.1" }, { name = "optree", marker = "extra == 'struct-cpu'" }, { name = "optree", marker = "extra == 'struct-gpu'" }, { name = "pandas" }, @@ -2093,9 +2095,9 @@ name = "lightning-utilities" version = "0.15.2" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "setuptools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "setuptools", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b8/39/6fc58ca81492db047149b4b8fd385aa1bfb8c28cd7cacb0c7eb0c44d842f/lightning_utilities-0.15.2.tar.gz", hash = "sha256:cdf12f530214a63dacefd713f180d1ecf5d165338101617b4742e8f22c032e24", size = 31090, upload-time = "2025-08-06T13:57:39.242Z" } wheels = [ @@ -2167,7 +2169,7 @@ name = "markdown-it-py" version = "4.0.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "mdurl", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "mdurl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } wheels = [ @@ -2207,16 +2209,16 @@ name = "matplotlib" version = "3.10.6" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "contourpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "cycler", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "kiwisolver", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "contourpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "cycler", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "kiwisolver", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pyparsing", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "pyparsing", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a0/59/c3e6453a9676ffba145309a73c462bb407f4400de7de3f2b41af70720a3c/matplotlib-3.10.6.tar.gz", hash = "sha256:ec01b645840dd1996df21ee37f208cd8ba57644779fa20464010638013d3203c", size = 34804264, upload-time = "2025-08-30T00:14:25.137Z" } wheels = [ @@ -2339,7 +2341,7 @@ name = "mmtf-python" version = "1.1.3" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "msgpack", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "msgpack", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d8/0f/f3c132dc9aac9a3f32a0eba7a80f07d14e7624e96f9245eeac5fe48f42cd/mmtf-python-1.1.3.tar.gz", hash = "sha256:12a02fe1b7131f0a2b8ce45b46f1e0cdd28b9818fe4499554c26884987ea0c32", size = 46032, upload-time = "2022-07-06T03:06:25.993Z" } wheels = [ @@ -2643,7 +2645,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -2656,7 +2658,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -2688,9 +2690,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -2703,7 +2705,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -2800,8 +2802,8 @@ name = "omegaconf" version = "2.3.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "antlr4-python3-runtime", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "antlr4-python3-runtime", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } wheels = [ @@ -2954,6 +2956,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/dd/b3fd642260cb17532f66cc1e8250f3507d1e580483e209dc1e9d13bd980d/openapi_spec_validator-0.7.2-py3-none-any.whl", hash = "sha256:4bbdc0894ec85f1d1bea1d6d9c8b2c3c8d7ccaa13577ef40da9c006c9fd0eb60", size = 39713, upload-time = "2025-06-07T14:48:54.077Z" }, ] +[[package]] +name = "openbabel-wheel" +version = "3.1.1.22" +source = { registry = "https://pypi.python.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/55/c916675c49beb64256bd108f4a797111cedca592c9be0c931014b8654806/openbabel-wheel-3.1.1.22.tar.gz", hash = "sha256:d12e07f8e2b2a8a9007b486087fdeb731cedb5b14129bf93ad3d2c5a72e54d08", size = 13246, upload-time = "2025-05-20T14:27:41.725Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/de/be239b210500c0332007e18dd0a8ca4939d13f05433d453e3a0983defa70/openbabel_wheel-3.1.1.22-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:2499af11947d3454439182a75516bc4008fe0ff22a24f16b3b964317004d5e86", size = 11514597, upload-time = "2025-05-20T14:26:20.174Z" }, + { url = "https://files.pythonhosted.org/packages/0d/cb/167d9bace82f82fa9231e562e723d489f5c07e789290e668b552874c69f8/openbabel_wheel-3.1.1.22-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:4beda0224caa906d1c2ed35652786da244757aaea361cdc0aef13e935c0149a7", size = 12381697, upload-time = "2025-05-20T14:26:22.251Z" }, + { url = "https://files.pythonhosted.org/packages/40/61/3f6f6e78c45d67453628596dff1640e8d56d58e6317ddace5c3be3d550f4/openbabel_wheel-3.1.1.22-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b33c21c987b3a22f7a4c812b8283edbe5ad6c6deae29e9568f1b16a8f5ad20b", size = 15633459, upload-time = "2025-05-20T14:26:24.554Z" }, + { url = "https://files.pythonhosted.org/packages/28/59/f7e7424da6a6ea81f987fac2093e440b8689f4e27929f92167f2a81fc01d/openbabel_wheel-3.1.1.22-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2ed691f3f5ac2cfebfaa0d9718acd58e41fcbc0e114735f518c92167f7c5e2f", size = 16076298, upload-time = "2025-05-20T14:26:26.823Z" }, + { url = "https://files.pythonhosted.org/packages/d5/27/c219228b4eb582ab11ac0ae2a8acd895e880a4f427e9bdf17d0e41783d91/openbabel_wheel-3.1.1.22-cp311-cp311-win_amd64.whl", hash = "sha256:1303272ecdcfcdfbe7648f2761aa6639c5a4ec3073b472f766d3d2567a44b523", size = 5095182, upload-time = "2025-05-20T14:26:29.03Z" }, + { url = "https://files.pythonhosted.org/packages/0e/91/df492446daefef9ba3c97e1555b1a7abb4ab9a21f00744689fec8e10d53e/openbabel_wheel-3.1.1.22-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:48846ac5c8dff5589fd7834fbc1ee0d3a5d67ed544ef81b716db6932a27c34c2", size = 11517993, upload-time = "2025-05-20T14:26:30.741Z" }, + { url = "https://files.pythonhosted.org/packages/64/35/b5ccaa112d253fb9089db9c5390d066cab08b5401cd2ec59709937f0a746/openbabel_wheel-3.1.1.22-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:9189315a466da529e6422fa588163be230f0fcfb8d79e240fd9d8330fd2af8a6", size = 12390413, upload-time = "2025-05-20T14:26:33.304Z" }, + { url = "https://files.pythonhosted.org/packages/bc/ce/e9a9553a10ec6f0a3f1a3112cd00619bb77651779b59507ce1be4dcef2a1/openbabel_wheel-3.1.1.22-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17588f6823a1a2bc361fbdfec167e8d30d4764176b0b63a3e3127efae71aed8f", size = 15632153, upload-time = "2025-05-20T14:26:36.864Z" }, + { url = "https://files.pythonhosted.org/packages/bc/79/a713470f17697b5ebee488929a88b2daaa3b20eb5be502d6c5cab9dbc4e2/openbabel_wheel-3.1.1.22-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc3d01459489334e9acf68526db48be91ce88692c59192a0d0fb9f422f2bc9a", size = 16075937, upload-time = "2025-05-20T14:26:39.362Z" }, + { url = "https://files.pythonhosted.org/packages/21/f3/12a0e194e735cb0f3fc6122cfd39d63f05ab2426ee4083b7de1bc4ea1242/openbabel_wheel-3.1.1.22-cp312-cp312-win_amd64.whl", hash = "sha256:e369179c9edfd35b79faa252be0b28b7b8932aee80acfced1922595b67594dbd", size = 5095492, upload-time = "2025-05-20T14:26:46.743Z" }, +] + [[package]] name = "opt-einsum" version = "3.4.0" @@ -2984,7 +3004,7 @@ name = "optree" version = "0.17.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/56/c7/0853e0c59b135dff770615d2713b547b6b3b5cde7c10995b4a5825244612/optree-0.17.0.tar.gz", hash = "sha256:5335a5ec44479920620d72324c66563bd705ab2a698605dd4b6ee67dbcad7ecd", size = 163111, upload-time = "2025-07-25T11:26:11.586Z" } wheels = [ @@ -3093,9 +3113,9 @@ source = { registry = "https://pypi.python.org/simple" } dependencies = [ { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pytz", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "tzdata", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "pytz", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "tzdata", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/79/8e/0e90233ac205ad182bd6b422532695d2b9414944a280488105d598c70023/pandas-2.3.2.tar.gz", hash = "sha256:ab7b58f8f82706890924ccdfb5f48002b83d2b5a3845976a9fb705d36c34dcdb", size = 4488684, upload-time = "2025-08-21T10:28:29.257Z" } wheels = [ @@ -3590,7 +3610,7 @@ name = "python-dateutil" version = "2.9.0.post0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "six", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "six", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } wheels = [ @@ -3867,10 +3887,10 @@ name = "requests" version = "2.32.5" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "charset-normalizer", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "charset-normalizer", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } wheels = [ @@ -3894,8 +3914,8 @@ name = "rich" version = "14.1.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "markdown-it-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pygments", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "markdown-it-py", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "pygments", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fe/75/af448d8e52bf1d8fa6a9d089ca6c07ff4453d86c65c145d0a300bb073b9b/rich-14.1.0.tar.gz", hash = "sha256:e497a48b844b0320d45007cdebfeaeed8db2a4f4bcf49f15e455cfc4af11eaa8", size = 224441, upload-time = "2025-07-25T07:32:58.125Z" } wheels = [ @@ -3920,8 +3940,8 @@ name = "rotary-embedding-torch" version = "0.8.9" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "einops", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "torch", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "einops", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "torch", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/71/38/74783585b1f0282fddd3faf1abd6dd20977255c27e31737eced2d7ec05f1/rotary_embedding_torch-0.8.9.tar.gz", hash = "sha256:b213f153cad1d108064d930544fb3af678d56515893d3f869a7a146f87997e3f", size = 7497, upload-time = "2025-07-27T01:26:14.675Z" } wheels = [ @@ -4028,7 +4048,7 @@ name = "s3transfer" version = "0.13.1" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "botocore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "botocore", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6d/05/d52bf1e65044b4e5e27d4e63e8d1579dbdec54fce685908ae09bc3720030/s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf", size = 150589, upload-time = "2025-07-18T19:22:42.31Z" } wheels = [ @@ -4062,11 +4082,11 @@ name = "scikit-learn" version = "1.7.2" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "joblib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "joblib", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.python.org/simple" }, marker = "(sys_platform == 'darwin' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra != 'extra-6-lbster-struct-cpu' and extra != 'extra-6-lbster-struct-gpu') or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, - { name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "threadpoolctl", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "scipy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "threadpoolctl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } wheels = [ @@ -4286,7 +4306,7 @@ name = "sympy" version = "1.14.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "mpmath", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "mpmath", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ @@ -4465,28 +4485,28 @@ name = "torch" version = "2.8.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "fsspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin') or (python_full_version >= '3.12' and sys_platform == 'linux')" }, - { name = "sympy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "fsspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'darwin') or (python_full_version >= '3.12' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'darwin' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform == 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "sympy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu') or (sys_platform != 'linux' and extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391, upload-time = "2025-08-06T14:53:20.937Z" }, @@ -4793,7 +4813,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform == 'linux'" }, + { name = "setuptools", marker = "sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, @@ -5038,9 +5058,9 @@ name = "yarl" version = "1.20.1" source = { registry = "https://pypi.python.org/simple" } dependencies = [ - { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "multidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "propcache", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "multidict", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, + { name = "propcache", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or (extra == 'extra-6-lbster-struct-cpu' and extra == 'extra-6-lbster-struct-gpu')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3c/fb/efaa23fa4e45537b827620f04cf8f3cd658b76642205162e072703a5b963/yarl-1.20.1.tar.gz", hash = "sha256:d017a4997ee50c91fd5466cef416231bb82177b93b029906cefc542ce14c35ac", size = 186428, upload-time = "2025-06-10T00:46:09.923Z" } wheels = [ diff --git a/wandb_sweeps/wandb_generate.py b/wandb_sweeps/wandb_generate.py index a4a66c54..515e1263 100644 --- a/wandb_sweeps/wandb_generate.py +++ b/wandb_sweeps/wandb_generate.py @@ -56,7 +56,11 @@ def objective(config): logger.info(f"Completed generation for length {length}, output: {gen_config.output_dir}") # Collect metrics from all length runs - metrics = collect_metrics_from_all_lengths(output_dirs, lengths) + num_samples = config.get("num_samples", 10) # Get requested designs from config + metrics = collect_metrics_from_all_lengths(output_dirs, lengths, num_samples) + + # Store num_samples in metrics for score calculation + metrics["num_samples"] = num_samples # Extract score weights from config score_weights = { @@ -108,7 +112,7 @@ def create_config_from_wandb(config, length: int | None = None) -> DictConfig: "seed": 12345, "model": { "_target_": "lobster.model.gen_ume.UMESequenceStructureEncoderLightningModule", - "ckpt_path": "/data2/ume/gen_ume/runs//2025-10-08T23-54-39/last.ckpt", + "ckpt_path": config.get("ckpt_path", "/data2/ume/gen_ume/runs//2025-10-08T23-54-39/last.ckpt"), }, "generation": { "mode": mode, @@ -119,6 +123,8 @@ def create_config_from_wandb(config, length: int | None = None) -> DictConfig: "temperature_struc": config.get("temperature_struc", 0.5), "stochasticity_seq": config.get("stochasticity_seq", 20), "stochasticity_struc": config.get("stochasticity_struc", 20), + "inference_schedule_seq": config.get("inference_schedule_seq", "LogInferenceSchedule"), + "inference_schedule_struc": config.get("inference_schedule_struc", "LinearInferenceSchedule"), "use_esmfold": True, "max_length": 512, "save_csv_metrics": True, @@ -311,13 +317,16 @@ def collect_metrics_from_output(output_dir: str) -> dict[str, float]: return metrics -def collect_metrics_from_all_lengths(output_dirs: list[str], lengths: list[int]) -> dict[str, float]: +def collect_metrics_from_all_lengths( + output_dirs: list[str], lengths: list[int], num_samples: int = 10 +) -> dict[str, float]: """ Collect metrics from multiple length-specific output directories. Args: output_dirs: List of output directory paths, one per length lengths: List of generation lengths corresponding to output_dirs + num_samples: Number of designs requested per length (for diversity normalization) Returns: Dictionary with both per-length metrics and aggregated metrics @@ -339,19 +348,22 @@ def collect_metrics_from_all_lengths(output_dirs: list[str], lengths: list[int]) # Calculate aggregated metrics across all lengths logger.info("Aggregating metrics across all lengths") - aggregate_metrics = aggregate_across_lengths(per_length_metrics) + aggregate_metrics = aggregate_across_lengths(per_length_metrics, num_samples=num_samples) all_metrics.update(aggregate_metrics) logger.info(f"Total metrics collected: {len(all_metrics)} (per-length + aggregated)") return all_metrics -def aggregate_across_lengths(per_length_metrics: dict[int, dict[str, float]]) -> dict[str, float]: +def aggregate_across_lengths( + per_length_metrics: dict[int, dict[str, float]], num_samples: int = 10 +) -> dict[str, float]: """ Aggregate metrics across all lengths using simple averaging. Args: per_length_metrics: Dict mapping length → metrics dict + num_samples: Number of designs requested per length (for diversity normalization) Returns: Dictionary of aggregated metrics with 'agg_' prefix @@ -379,32 +391,53 @@ def aggregate_across_lengths(per_length_metrics: dict[int, dict[str, float]]) -> if values: aggregated[f"agg_{metric_key}"] = sum(values) / len(values) - # Special handling for diversity: sum total clusters across ALL lengths - cluster_keys = [k for k in all_metric_keys if k.startswith("diversity_num_clusters_")] - if cluster_keys: + # Special handling for diversity: calculate average diversity percentage per length + # IMPORTANT: Divide by num_samples (requested), not successful structures + # This penalizes failures and gives a more informative signal + # Process all lengths, even if some/all have 0 diversity results + if per_length_metrics: total_clusters = 0 + total_structures = 0 + diversity_percentages = [] + for length, metrics in per_length_metrics.items(): - for k in cluster_keys: - if k in metrics: - total_clusters += metrics[k] + # Find cluster and structure keys for this length + cluster_key = f"diversity_num_clusters_length_{length}" + structure_key = f"diversity_total_structures_length_{length}" - aggregated["agg_total_clusters_all_lengths"] = total_clusters - logger.info(f"Total clusters across all lengths: {total_clusters}") + # Get clusters (default 0 if no diversity results for this length) + num_clusters = metrics.get(cluster_key, 0) + total_clusters += num_clusters - # Also calculate total structures across all lengths - total_structures = 0 - for length, metrics in per_length_metrics.items(): - for k in all_metric_keys: - if k.startswith("diversity_total_structures_"): - total_structures += metrics.get(k, 0) + # Also track successful structures for logging + num_structures = metrics.get(structure_key, 0) + total_structures += num_structures + + # Calculate diversity percentage: clusters / REQUESTED designs (not successful) + # IMPORTANT: Always include all lengths, even if 0 clusters (penalizes failure) + diversity_pct = num_clusters / num_samples + diversity_percentages.append(diversity_pct) + + logger.debug( + f"Length {length}: {num_clusters} clusters / {num_samples} requested " + f"= {diversity_pct:.3f} ({diversity_pct * 100:.1f}%) [{num_structures} succeeded]" + ) - if total_structures > 0: + # Store totals for logging/debugging + aggregated["agg_total_clusters_all_lengths"] = total_clusters aggregated["agg_total_structures_all_lengths"] = total_structures - if total_clusters > 0: - aggregated["agg_diversity_percentage_all_lengths"] = (total_clusters / total_structures) * 100 + + # NEW: Calculate average diversity percentage across lengths + if diversity_percentages: + avg_diversity_pct = sum(diversity_percentages) / len(diversity_percentages) + aggregated["agg_avg_diversity_percentage"] = avg_diversity_pct logger.info( - f"Overall diversity: {total_clusters}/{total_structures} = {aggregated['agg_diversity_percentage_all_lengths']:.1f}%" + f"Diversity metrics: {total_clusters} clusters / {num_samples * len(diversity_percentages)} requested " + f"across {len(diversity_percentages)} lengths = {avg_diversity_pct:.3f} ({avg_diversity_pct * 100:.1f}% average diversity per length) " + f"[{total_structures} structures succeeded]" ) + else: + logger.info(f"Total clusters: {total_clusters}, Total structures: {total_structures}") return aggregated @@ -450,28 +483,42 @@ def calculate_composite_score(metrics: dict[str, float], score_weights: dict[str mode = "forward_folding" if mode == "unconditional": - # MAIN METRIC: Number of foldseek clusters (diversity) across ALL lengths + # MAIN METRIC: Average diversity percentage across lengths + # (Normalized by REQUESTED designs, not successful) # Try to use aggregated metrics first (multi-length mode) - if "agg_total_clusters_all_lengths" in metrics: - # Multi-length mode: use aggregated total clusters - total_clusters = metrics["agg_total_clusters_all_lengths"] - diversity_score = total_clusters * score_weights["diversity"] + if "agg_avg_diversity_percentage" in metrics: + # Multi-length mode: use average diversity percentage + avg_diversity_pct = metrics["agg_avg_diversity_percentage"] + diversity_score = avg_diversity_pct * score_weights["diversity"] score += diversity_score logger.info( - f"Diversity contribution to score (MAIN METRIC, ALL LENGTHS): {total_clusters} clusters * " - f"{score_weights['diversity']:.2f} weight = {diversity_score:.2f} points" + f"Diversity contribution to score (MAIN METRIC): {avg_diversity_pct:.3f} ({avg_diversity_pct * 100:.1f}%) avg diversity * " + f"{score_weights['diversity']:.2f} weight = {diversity_score:.4f} points" ) else: - # Single-length mode (backward compatible): sum clusters from individual length keys + # Single-length mode (backward compatible): calculate diversity percentage + # Need num_samples from metrics for proper calculation cluster_keys = [k for k in metrics.keys() if k.startswith("diversity_num_clusters_")] if cluster_keys: - total_clusters = sum(metrics[k] for k in cluster_keys) - diversity_score = total_clusters * score_weights["diversity"] - score += diversity_score - logger.info( - f"Diversity contribution to score (MAIN METRIC): {total_clusters} clusters * " - f"{score_weights['diversity']:.2f} weight = {diversity_score:.2f} points" - ) + # Try to get num_samples from metrics, fallback to 10 + num_samples = metrics.get("num_samples", 10) + diversity_percentages = [] + + for cluster_key in cluster_keys: + num_clusters = metrics[cluster_key] + # Divide by requested designs, not successful structures + diversity_pct = num_clusters / num_samples + diversity_percentages.append(diversity_pct) + + if diversity_percentages: + avg_diversity_pct = sum(diversity_percentages) / len(diversity_percentages) + diversity_score = avg_diversity_pct * score_weights["diversity"] + score += diversity_score + logger.info( + f"Diversity contribution to score (MAIN METRIC): {avg_diversity_pct:.3f} ({avg_diversity_pct * 100:.1f}%) " + f"avg diversity across {len(diversity_percentages)} lengths * " + f"{score_weights['diversity']:.2f} weight = {diversity_score:.4f} points" + ) # Secondary metrics (configurable weights) # Use aggregated metrics if available (multi-length), otherwise use direct metrics (single-length) @@ -515,13 +562,13 @@ def calculate_composite_score(metrics: dict[str, float], score_weights: dict[str elif mode == "inverse_folding": # Higher is better: percent_identity, plddt, tm_score if "avg_percent_identity" in metrics: - score += metrics["avg_percent_identity"] * 0.4 # Most important for inverse folding + score += metrics["avg_percent_identity"] * 0.1 # Most important for inverse folding if "avg_plddt" in metrics: score += metrics["avg_plddt"] * 0.2 if "avg_tm_score" in metrics: - score += metrics["avg_tm_score"] * 0.4 + score += metrics["avg_tm_score"] * 1.0 # Lower is better: predicted_aligned_error, rmsd if "avg_predicted_aligned_error" in metrics: diff --git a/wandb_sweeps/wandb_slurm.sh b/wandb_sweeps/wandb_slurm.sh old mode 100644 new mode 100755 index 187a588e..76f3f54b --- a/wandb_sweeps/wandb_slurm.sh +++ b/wandb_sweeps/wandb_slurm.sh @@ -1,34 +1,55 @@ #!/bin/bash -#SBATCH --partition b200 -#SBATCH --array=1-16 +#SBATCH --partition=preempt +#SBATCH --account=llm +#SBATCH --array=1-50 #SBATCH --nodes 1 #SBATCH --ntasks-per-node 1 -#SBATCH --gpus-per-node 1 +#SBATCH --gres=gpu:b200:1 #SBATCH --cpus-per-task 4 -#SBATCH -o /data2/ume/gen_ume/slurm/logs/inference/%J_%x.out +#SBATCH -o /cv/scratch/u/lisanzas/sweeps/logs/%J_%x.out #SBATCH -q preempt #SBATCH --mem=256G -#SBATCH --job-name=gen_ume_hyp_param -#SBATCH -t 7-00:00:00 +#SBATCH --job-name=pl_sweep +#SBATCH -t 2-00:00:00 + +# ============================================================================= +# Protein-Ligand W&B Sweep SLURM Submission Script +# ============================================================================= +# Usage: +# 1. Create sweep: wandb sweep wandb_sweep_config_protein_ligand_inverse_folding.yaml +# 2. Update SWEEP_ID below with the returned sweep ID +# 3. Submit: sbatch wandb_slurm.sh +# ============================================================================= + +# Set your sweep ID here (from wandb sweep command output) +SWEEP_ID="${SWEEP_ID:-f42gu2mv}" +WANDB_PROJECT="${WANDB_PROJECT:-lobster-wandb_sweeps}" nvidia-smi -#source .venv/bin/activate -source /homefs/home/lisanzas/scratch/Develop/lobster/lobster_env/bin/activate +# Change to script directory +cd /cv/home/lisanzas/lobster/wandb_sweeps + echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" +echo "SLURM_ARRAY_TASK_ID = ${SLURM_ARRAY_TASK_ID}" +echo "SWEEP_ID = ${SWEEP_ID}" export LD_LIBRARY_PATH=/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/amazon/ofi-nccl/lib64 export WANDB_INSECURE_DISABLE_SSL=true export HYDRA_FULL_ERROR=1 export PYTHONUNBUFFERED=1 -export NCCL_DEBUG=INFO -export LOBSTER_RUNS_DIR="/data2/ume/gen_ume/runs/" #"s3://prescient-lobster/ume/runs" # CHANGE TO YOUR S3 BUCKET -export LOBSTER_DATA_DIR="/data2/ume/.cache2/" # CHANGE TO YOUR DATA DIRECTORY -export LOBSTER_USER=$(whoami) # CHANGE TO YOUR WANDB USERNAME IF NOT YOUR UNIXID +export LOBSTER_RUNS_DIR="/cv/scratch/u/lisanzas/gen_ume_protein_ligand/runs/" +export LOBSTER_DATA_DIR="/cv/scratch/u/lisanzas/.cache/" +export LOBSTER_USER=$(whoami) export WANDB_BASE_URL=https://genentech.wandb.io export TOKENIZERS_PARALLELISM=true -srun -u --cpus-per-task $SLURM_CPUS_PER_TASK --cpu-bind=cores,verbose wandb agent prescient-design/lobster-wandb_sweeps/b7wlmyg8 \ No newline at end of file +# Create log directory if it doesn't exist +mkdir -p /cv/scratch/u/lisanzas/sweeps/logs + +# Run wandb agent with uv +srun -u --cpus-per-task $SLURM_CPUS_PER_TASK --cpu-bind=cores,verbose \ + uv run wandb agent "prescient-design/${WANDB_PROJECT}/${SWEEP_ID}"