97 changes: 97 additions & 0 deletions models/saprot_vh_vl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Saprot_VH_VL Baseline

Ridge regression on embeddings from the **SaProt** protein language model, computed separately for the VH (variable heavy) and VL (variable light) chains (two-chain encoding).

## Description

SaProt (Structure-aware Protein Language Model) is a protein language model with a structure-aware vocabulary: its embeddings draw on both sequence and structure information.

This baseline uses locally computed SaProt embeddings (SaProt_35M_AF2) to produce a fixed-length embedding for each VH and VL chain from its sequence and structure, concatenates the two embeddings (VH + VL), and trains simple Ridge regression models on top to predict antibody developability properties.

The rationale behind this joint representation is the same as in the ESM2 case (no token contamination, and features are learned from each chain independently).

Note: At the time of writing, structures could be fetched from two sources, MOE and ABB3. This implementation uses only the MOE structures.

## Method

### 1. Separate Chain Embedding

For each antibody, we embed the heavy and light chains independently:

**VH Embedding:**
```
Complexed .pdb files from MOE → Extract VH pdb → FoldSeek 3di Descriptors → Interleaved with VH_seq → SaProt Tokenizer → Last Hidden State → Mean Pool → vh_embedding
```

**VL Embedding:**
```
Complexed .pdb files from MOE → Extract VL pdb → FoldSeek 3di Descriptors → Interleaved with VL_seq → SaProt Tokenizer → Last Hidden State → Mean Pool → vl_embedding
```
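The interleaving step in both pipelines can be sketched as follows. This is a minimal illustration, not the code from this PR: the helper name `interleave_sa_tokens` is hypothetical, and it assumes the upstream SaProt convention of pairing each uppercase amino acid with a lowercase 3Di letter, with `#` standing in for a missing structure token.

```python
from typing import Optional


def interleave_sa_tokens(aa_seq: str, tdi_seq: Optional[str] = None) -> str:
    """Pair each residue with its 3Di letter to form a structure-aware sequence.

    Assumption: follows the upstream SaProt convention (uppercase amino acid
    followed by a lowercase 3Di letter, '#' when no structure token exists).
    """
    if tdi_seq is None:
        # No structure available: mask every structure slot.
        tdi_seq = "#" * len(aa_seq)
    if len(aa_seq) != len(tdi_seq):
        raise ValueError("sequence and 3Di string must have the same length")
    return "".join(aa.upper() + tdi.lower() for aa, tdi in zip(aa_seq, tdi_seq))


print(interleave_sa_tokens("EVQL", "dvpa"))  # EdVvQpLa
print(interleave_sa_tokens("EVQL"))          # E#V#Q#L#
```

The resulting string is what would be handed to the SaProt tokenizer; the length check guards against a Foldseek output that does not cover every residue of the chain.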

### 2. Feature Concatenation

After generating embeddings for both chains, we concatenate them:

```
combined_embedding = np.concatenate([vh_embedding, vl_embedding])
```

For SaProt_35M_AF2, the embedding dimension is 480, so:
- VH embedding: 480D
- VL embedding: 480D
- Combined: 960D
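A minimal sketch of the pooling and concatenation arithmetic, using random arrays as stand-ins for the model's last hidden state. The function name `mean_pool` and the use of a padding mask are assumptions for illustration; the PR's own pooling code is not shown here and may treat special tokens differently.

```python
import numpy as np


def mean_pool(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average per-token hidden states over the positions where mask == 1.

    hidden: (seq_len, dim) last hidden state for one chain.
    mask:   (seq_len,) 1 for real tokens, 0 for padding.
    """
    weights = mask.astype(hidden.dtype)[:, None]
    return (hidden * weights).sum(axis=0) / weights.sum()


rng = np.random.default_rng(0)
dim = 480  # embedding size of SaProt_35M_AF2, as stated above

# Stand-in hidden states; a real run would take these from the model output.
vh_hidden = rng.normal(size=(120, dim))
vl_hidden = rng.normal(size=(110, dim))

vh_vec = mean_pool(vh_hidden, np.ones(120))
vl_vec = mean_pool(vl_hidden, np.ones(110))

combined = np.concatenate([vh_vec, vl_vec])
print(combined.shape)  # (960,)
```

Because each chain is pooled to a fixed 480-D vector before concatenation, the combined feature length stays at 960 regardless of the VH and VL sequence lengths.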


## Requirements

- The complexed (VH+VL) PDB structures for training are in `../../data/structures/MOE_structures/GDPa1/`, in the format `{antibody_name}.csv`
- The complexed (VH+VL) PDB structures for the heldout data are in `../../data/structures/MOE_structures/heldout_test/`, in the format `{antibody_name}.csv`
- Heavy chains are labelled 'B' and light chains are labelled 'A'
- foldseek installed
- BioPython installed
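To illustrate how the chain labels are used, the sketch below separates a complexed structure into VH and VL records by reading the chain ID from column 22 of each `ATOM` line. It is a dependency-free illustration only: the PR presumably performs this step with BioPython, and the names `CHAIN_ROLES` and `split_chains` are hypothetical.

```python
from typing import Dict, Iterable, List

# Chain labels stated in the requirements above: heavy = 'B', light = 'A'.
CHAIN_ROLES = {"B": "VH", "A": "VL"}


def split_chains(pdb_lines: Iterable[str]) -> Dict[str, List[str]]:
    """Group ATOM records by role using the chain ID in column 22."""
    chains: Dict[str, List[str]] = {"VH": [], "VL": []}
    for line in pdb_lines:
        # Skip non-ATOM records and lines too short to hold a chain ID.
        if not line.startswith("ATOM") or len(line) < 22:
            continue
        role = CHAIN_ROLES.get(line[21])
        if role is not None:
            chains[role].append(line.rstrip("\n"))
    return chains


demo = [
    "ATOM      1  N   GLU B   1      11.104  13.207   2.100  1.00 20.00           N",
    "ATOM      2  N   ASP A   1      21.500   3.300   7.800  1.00 20.00           N",
    "HETATM    3  O   HOH C   1       0.000   0.000   0.000  1.00 20.00           O",
]
counts = {role: len(records) for role, records in split_chains(demo).items()}
print(counts)
```

Each per-chain record list can then be written to its own file and passed to Foldseek to obtain the 3Di descriptors for that chain.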

Note: SaProt embeddings are normally calculated from both sequence and structure information. When structure information is absent, the embeddings are calculated from sequence information only.

Also note: The current implementation works around abdev-core by hard-coding the size of the heldout data. This is not good practice and is only a temporary fix.

### Train

From the repository root:

```bash
cd models/saprot_vh_vl
pixi install

# Train on GDPa1 dataset
pixi run python -m saprot_vh_vl train \
--data ../../data/GDPa1_v1.2_20250814.csv \
--run-dir ./runs/my_run
```

### Predict

```bash
# Predict on training data
pixi run python -m saprot_vh_vl predict \
--data ../../data/GDPa1_v1.2_20250814.csv \
--run-dir ./runs/my_run
```

### Full Workflow via Orchestrator

From repository root:

```bash
pixi run all
```

This automatically discovers and runs all models, including SaProt_VH_VL, with 5-fold cross-validation.

## Citation

SaProt: Su J, et al. (2023). "SaProt: Protein Language Modeling with Structure-aware Vocabulary." bioRxiv.

## Code References

- SaProt: https://github.com/westlake-repl/SaProt
- Foldseek: https://github.com/steineggerlab/foldseek

247 changes: 247 additions & 0 deletions models/saprot_vh_vl/outputs/heldout/predictions.csv

Large diffs are not rendered by default.

247 changes: 247 additions & 0 deletions models/saprot_vh_vl/outputs/train/predictions.csv

Large diffs are not rendered by default.

4,032 changes: 4,032 additions & 0 deletions models/saprot_vh_vl/pixi.lock

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions models/saprot_vh_vl/pixi.toml
@@ -0,0 +1,35 @@
[workspace]
name = "saprot_vh_vl"
version = "0.1.0"
description = "Saprot_VH_VL baseline - protein language model predictions on VH and VL sequences and structures"
channels = ["conda-forge", "bioconda", "pytorch"]
platforms = ["linux-64", "osx-64", "osx-arm64"]

[dependencies]
python = "3.11.*"
numpy = ">=1.24"
pandas = ">=2.0"
scikit-learn = ">=1.3"
typer = ">=0.9"
foldseek = "*"
biopython = ">=1.81"


[pypi-dependencies]
abdev-core = { path = "../../libs/abdev_core", editable = true }
saprot_vh_vl = { path = ".", editable = true }
transformers = ">=4.30"
torch = ">=2.0"

[environments]
default = []
dev = ["dev"]

[feature.dev.dependencies]
pytest = ">=7.0"
ruff = ">=0.1"

[feature.dev.tasks]
lint = "ruff check src && ruff format --check src"
test = "pytest tests -v"

20 changes: 20 additions & 0 deletions models/saprot_vh_vl/pyproject.toml
@@ -0,0 +1,20 @@
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "saprot_vh_vl"
version = "0.1.0"
description = "Saprot_VH_VL baseline - protein language model predictions on VH and VL sequences and structures"
requires-python = ">=3.11"
dependencies = [
"abdev-core",
"pandas>=2.0",
]

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-dir]
"" = "src"

Binary file added models/saprot_vh_vl/runs/my_run/embeddings.npy
Binary file not shown.
Binary file added models/saprot_vh_vl/runs/my_run/models.pkl
Binary file not shown.
4 changes: 4 additions & 0 deletions models/saprot_vh_vl/src/saprot_vh_vl/__init__.py
@@ -0,0 +1,4 @@
"""Saprot_VH_VL baseline - protein language model predictions."""

__version__ = "0.1.0"

7 changes: 7 additions & 0 deletions models/saprot_vh_vl/src/saprot_vh_vl/__main__.py
@@ -0,0 +1,7 @@
"""Entry point for model CLI."""

from .run import app

if __name__ == "__main__":
app()
