
Commit 297c79e

Author: Tobias Olenyi (committed)

Merge branch 'add_meta' into 'main'

Add meta script with all parameters of all available scripts

See merge request olenyi/vespa-cli!7

2 parents 7cb985a + 74de550 · commit 297c79e

File tree

12 files changed: +734 −187 lines


README.md

Lines changed: 80 additions & 81 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 3 deletions
@@ -1,8 +1,8 @@
 [tool.poetry]
 name = "vespa"
-version = "0.3.0-beta"
+version = "0.9.0-beta"
 description = ""
-authors = ["Tobias O <tobias.olenyi@tum.de>"]
+authors = ["Tobias O <tobias.olenyi@tum.de>", "Duc Anh Le <ducanh.le@tum.de>"]

 [tool.poetry.dependencies]
 python = ">=3.9,<3.11"
@@ -25,8 +25,10 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry.scripts]
 vespa_logodds = 'vespa.scripts.logodds:main'
-vespa = 'vespa.scripts.vespa:main'
+vespa = 'vespa.scripts.meta:main'
 vespa_conspred = 'vespa.scripts.conspred:main'
+vespa_run = 'vespa.scripts.vespa_run:main'
+vespa_emb = 'vespa.scripts.embedding:main'

 [tool.poetry2conda]
 name = "vespa-env"

vespa/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
-__author__ = "Tobias O, Michael H., Celine M."
-__copyright__ = "Copyright 2021, Rostlab"
+__author__ = "Tobias O, Michael H., Celine M., Duc Anh L."
+__copyright__ = "Copyright 2022, Rostlab"
 __license__ = "AGPL-3.0-or-later"
-__version__ = "1.0.0"
-__maintainer__ = "Tobias O."
+__version__ = "0.9.0b"
+__maintainer__ = ["Duc Anh L.", "Tobias O."]
 __email__ = ""
-__status__ = "Production"
+__status__ = "Production"

vespa/predict/config.py

Lines changed: 8 additions & 4 deletions
@@ -21,14 +21,15 @@
 from pathlib import Path
 import torch

+VESPA_LOCATION = Path(__file__).resolve().parent.parent.parent

 VESPA = "VESPA"
 VESPAL = "VESPAl"

 MODEL_PATH_DICT = {
-    VESPA: Path("models/VESPA-10LR_Cons_Blsm_Prob.pkl"),
-    VESPAL: Path("models/VESPAl-10LR_Cons_Blsm.pkl"),
-    "CONSCNN": Path("models/ProtT5cons_checkpoint.pt"),
+    VESPA: Path(VESPA_LOCATION.joinpath("models/VESPA-10LR_Cons_Blsm_Prob.pkl")),
+    VESPAL: Path(VESPA_LOCATION.joinpath("models/VESPAl-10LR_Cons_Blsm.pkl")),
+    "CONSCNN": Path(VESPA_LOCATION.joinpath("models/ProtT5cons_checkpoint.pt")),
 }

 OUTPUT_MAP_NAME = "map.json"
@@ -37,7 +38,7 @@
 # https://huggingface.co/transformers/v3.1.0/_modules/transformers/tokenization_t5.html
 SPIECE_UNDERLINE = "▁"

-CACHE_DIR = Path("./cache")
+CACHE_DIR = "./cache"

 TRANSFORMER_LINK = "Rostlab/prot_t5_xl_uniref50"

@@ -55,3 +56,6 @@

 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 EMBEDDING_HALF_PREC = True
+
+EMBED, LOGODDS = 0, 1
+EMB_MAX_SEQ_LEN, EMB_MAX_RESIDUES, EMB_MAX_BATCH, EMB_STORE_FREQ = 600, 8000, 5, 200
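
Note: VESPA_LOCATION anchors the model checkpoints to the package directory instead of the current working directory, so the paths in MODEL_PATH_DICT resolve no matter where the CLI is launched from. A small illustrative sketch (not part of the commit):

from vespa.predict.config import MODEL_PATH_DICT, VESPA

checkpoint = MODEL_PATH_DICT[VESPA]       # absolute path ending in models/VESPA-10LR_Cons_Blsm_Prob.pkl
print(checkpoint, checkpoint.is_file())   # True only if the model files ship with the checkout/package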

vespa/predict/embedding.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import numpy
+import torch
+import h5py
+from tqdm import tqdm
+from pathlib import Path
+
+from vespa.predict.config import (
+    DEVICE, CACHE_DIR, VERBOSE,
+    EMBED, EMB_MAX_SEQ_LEN, EMB_MAX_RESIDUES, EMB_MAX_BATCH, EMB_STORE_FREQ
+)
+from vespa.predict.utils import parse_fasta_input
+from vespa.predict.utils_t5 import ProtT5
+
+
+class T5_Embed:
+    def __init__(self, cache_dir):
+        self.prott5 = ProtT5(cache_dir)
+        self.saving_pattern = 'w'
+
+    def embed_from_fasta(self, fasta_path, output_path):
+        self.saving_pattern = 'w'
+        if VERBOSE:
+            print('Load model: ProtT5')
+        self.model, self.tokenizer = self.prott5.get_model(EMBED)
+        if VERBOSE:
+            print('Compute embeddings!')
+        self.get_embeddings(fasta_path, output_path)
+
+    def embedding_init(self, fasta_path):
+        seq_dict = parse_fasta_input(fasta_path)
+        seq_dict = sorted(seq_dict.items(), key=lambda kv: len(seq_dict[kv[0]]), reverse=True)
+        return seq_dict
+
+    def process_batch(self, batch, emb_dict):
+        pdb_ids, seqs, seq_lens = zip(*batch)
+
+        token_encoding = self.tokenizer(seqs, add_special_tokens=True, padding='longest', return_tensors="pt")
+        input_ids = token_encoding['input_ids'].to(DEVICE)
+        attention_mask = token_encoding['attention_mask'].to(DEVICE)
+
+        try:
+            # batch-size x seq_len x embedding_dim
+            with torch.no_grad():
+                embedding_repr = self.model(input_ids, attention_mask=attention_mask)
+        except RuntimeError:
+            print("RuntimeError for {} (L={})".format(pdb_ids, seq_lens))
+            return emb_dict
+
+        new_emb_dict = dict()
+        for batch_idx, identifier in enumerate(pdb_ids):
+            s_len = seq_lens[batch_idx]
+            emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
+            new_emb_dict[identifier] = emb.detach().cpu().numpy().squeeze()
+
+        if new_emb_dict:
+            emb_dict.update(new_emb_dict)
+        return emb_dict
+
+    def save_embeddings(self, output_path, emb_dict):
+        Path(str(output_path.absolute())).parent.mkdir(parents=True, exist_ok=True)
+        with h5py.File(str(output_path.absolute()), self.saving_pattern) as hf:
+            for sequence_id, embedding in emb_dict.items():
+                hf.create_dataset(sequence_id, data=embedding)
+        self.saving_pattern = 'a'
+
+    def get_embeddings(self, fasta_path, output_path):
+        seq_dict = self.embedding_init(fasta_path)
+
+        emb_dict = dict()
+        batch, n_res_batch = [], 0
+
+        for seq_idx, (pdb_id, seq) in tqdm(enumerate(seq_dict, 1), total=len(seq_dict)):
+            seq_len = len(seq)
+            seq = ' '.join(list(seq))
+
+            if seq_len >= EMB_MAX_SEQ_LEN:
+                emb_dict = self.process_batch([(pdb_id, seq, seq_len)], emb_dict)
+            else:
+                if len(batch) >= EMB_MAX_BATCH or n_res_batch >= EMB_MAX_RESIDUES:
+                    emb_dict = self.process_batch(batch, emb_dict)
+                    batch = []
+                    n_res_batch = 0
+
+                batch.append((pdb_id, seq, seq_len))
+                n_res_batch += seq_len
+
+            if len(emb_dict) > EMB_STORE_FREQ:
+                self.save_embeddings(output_path, emb_dict)
+                emb_dict = dict()
+
+        if batch:
+            emb_dict = self.process_batch(batch, emb_dict)
+
+        if emb_dict:
+            self.save_embeddings(output_path, emb_dict)
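
Usage sketch for the new embedder (file names below are illustrative, not from the commit): it writes per-residue ProtT5 embeddings for every sequence in a FASTA file into an H5 container keyed by the sequence identifier.

from pathlib import Path
import h5py

from vespa.predict.config import CACHE_DIR
from vespa.predict.embedding import T5_Embed

embedder = T5_Embed(CACHE_DIR)
embedder.embed_from_fasta(Path("sequences.fasta"), Path("output/embeddings.h5"))

# Each dataset holds an L x 1024 matrix (L = sequence length) of per-residue ProtT5 embeddings.
with h5py.File("output/embeddings.h5", "r") as hf:
    for seq_id in hf:
        print(seq_id, hf[seq_id].shape)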

vespa/predict/logodds.py

Lines changed: 5 additions & 22 deletions
@@ -28,7 +28,6 @@
 import torch
 import h5py
 from tqdm import tqdm
-from transformers import T5ForConditionalGeneration, T5Tokenizer

 from vespa.predict.config import (
     CACHE_DIR,
@@ -40,8 +39,10 @@
     SPIECE_UNDERLINE,
     VERBOSE,
     DEVICE,
+    LOGODDS
 )
 from vespa.predict import utils
+from vespa.predict.utils_t5 import ProtT5


 if VERBOSE:
@@ -58,8 +59,8 @@ class _ProbaVector:

 class T5_condProbas:
     def __init__(self, cache_dir):
-        self.cache_dir = cache_dir
-        self.tokenizer = self.get_tokenizer()
+        self.prott5 = ProtT5(cache_dir)
+        self.tokenizer = self.prott5.get_tokenizer()
         self.AAs = MUTANT_ORDER + "X"
         self.AA2class = {AA: idx for idx, AA in enumerate(self.AAs)}
         self.class2AA = {idx: AA for idx, AA in enumerate(self.AAs)}
@@ -69,24 +70,6 @@ def __init__(self, cache_dir):
         ]
         self.softmax = torch.nn.Softmax(dim=0)

-    def get_model(self):
-        model = T5ForConditionalGeneration.from_pretrained(
-            TRANSFORMER_LINK, cache_dir=self.cache_dir
-        )
-        model = model.eval()
-        model = model.to(DEVICE)
-        vocab = T5Tokenizer.from_pretrained(
-            TRANSFORMER_LINK, do_lower_case=False, cache_dir=self.cache_dir
-        )
-
-        return model, vocab
-
-    def get_tokenizer(self):
-        vocab = T5Tokenizer.from_pretrained(
-            TRANSFORMER_LINK, do_lower_case=False, cache_dir=self.cache_dir
-        )
-        return vocab
-
     def reconstruct_sequence(self, probs):
         return [self.class2AA[yhat] for yhat in probs.argmax(axis=1)]

@@ -116,7 +99,7 @@ def get_proba_dict(self, seq_dict, mutation_gen: utils.MutationGenerator):
         Compute for all residues in a protein the conditional probabilities for reconstructing single, masked tokens.
         """

-        self.model, self.tokenizer = self.get_model()
+        self.model, self.tokenizer = self.prott5.get_model(LOGODDS)

         result_dict = dict()

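Note: T5_condProbas no longer loads ProtT5 itself; model and tokenizer now come from the shared ProtT5 helper with the LOGODDS flag. A minimal sketch of the new flow, mirroring what get_proba_dict() does internally (not part of the commit):

from vespa.predict.config import CACHE_DIR, LOGODDS
from vespa.predict.logodds import T5_condProbas

t5_probas = T5_condProbas(CACHE_DIR)                    # tokenizer is obtained via ProtT5
model, tokenizer = t5_probas.prott5.get_model(LOGODDS)  # the same call get_proba_dict() makes before scoring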
vespa/predict/utils_t5.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Tokenizer
+from transformers import logging
+logging.set_verbosity_error()
+from vespa.predict.config import (
+    TRANSFORMER_LINK,
+    DEVICE,
+    EMBED, EMBEDDING_HALF_PREC, LOGODDS
+)
+
+
+class ProtT5:
+    def __init__(self, cache_dir):
+        self.cache_dir = cache_dir
+
+    def get_model(self, model_usage: EMBED | LOGODDS):
+        if model_usage == EMBED:
+            model = T5EncoderModel.from_pretrained(
+                TRANSFORMER_LINK, cache_dir=self.cache_dir
+            )
+            if EMBEDDING_HALF_PREC:
+                model = model.half()
+        elif model_usage == LOGODDS:
+            model = T5ForConditionalGeneration.from_pretrained(
+                TRANSFORMER_LINK, cache_dir=self.cache_dir
+            )
+        else:
+            raise NotImplementedError(
+                "The intended use of ProtT5 is not implemented."
+            )
+        model = model.eval()
+        model = model.to(DEVICE)
+        return model, self.get_tokenizer()
+
+    def get_tokenizer(self):
+        tokenizer = T5Tokenizer.from_pretrained(
+            TRANSFORMER_LINK, do_lower_case=False, cache_dir=self.cache_dir
+        )
+        return tokenizer
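
Note: EMBED and LOGODDS are plain integer flags from config.py, so ProtT5.get_model() simply switches between the encoder-only model (for embeddings) and the conditional-generation model (for log-odds). A short usage sketch, not part of the commit:

from vespa.predict.config import CACHE_DIR, EMBED, LOGODDS
from vespa.predict.utils_t5 import ProtT5

prott5 = ProtT5(CACHE_DIR)
encoder, tokenizer = prott5.get_model(EMBED)    # T5EncoderModel, halved if EMBEDDING_HALF_PREC
seq2seq, tokenizer = prott5.get_model(LOGODDS)  # T5ForConditionalGeneration for masked-token log-odds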

vespa/scripts/conspred.py

Lines changed: 44 additions & 21 deletions
@@ -31,7 +31,7 @@

 # Lib Imports
 import torch.utils.data
-from vespa.predict.config import MODEL_PATH_DICT
+from vespa.predict.config import MODEL_PATH_DICT, VERBOSE

 # Module Imports
 from vespa.predict.conspred import ProtT5Cons, get_dataloader
@@ -50,7 +50,7 @@ def create_arg_parser():

     # Required positional argument
     parser.add_argument(
-        "Input",
+        "input",
         type=Path,
         help="A path to a h5 embedding file, containing per-residue ProtT5 embeddings.",
     )
@@ -71,15 +71,15 @@ def create_arg_parser():
         "--checkpoint",
         required=False,
         type=Path,
-        default= None,
+        default=None,
         help="A path for the pre-trained checkpoint for the conservation CNN",
     )

     # Optional argument
     parser.add_argument(
         "--output_probs",
         type=bool,
-        default=True, 
+        default=True,
         action=argparse.BooleanOptionalAction,
         help="Output probabilities for all classes, not only class with highest probability. The probabilities are stored in an h5 file with a dataset per-protein of shape Lx20 (L being the protein length). This output is written to <output_prefix>_probs.h5)",
     )
@@ -88,37 +88,60 @@ def create_arg_parser():
     parser.add_argument(
         "--output_classes",
         type=bool,
-        default=True,
+        default=False,
         action=argparse.BooleanOptionalAction,
         help="Output the conservation class prediction per residue in a fasta-like format with comma-separated per-residue classes. The output is written to <output_prefix>_class.fast)",
     )

     return parser


-def main():
-    parser = create_arg_parser()
-    args = parser.parse_args()
-
-    checkpoint_path = args.checkpoint if args.checkpoint else Path(MODEL_PATH_DICT["CONSCNN"])
-    out_prefix = args.output_prefix
-    out_class = Path(out_prefix + "_class.fasta")
-    out_probs = Path(out_prefix + "_probs.h5")
-
-    write_probs = args.output_probs
-    write_classes = args.output_classes
-
-    out_class.parent.mkdir(parents=True, exist_ok=True)
-
+def run_conspred(seq_path, checkpoint_path, write_probs, write_classes, out_prefix):
     try:
-        embeddings = h5py.File(str(args.Input.resolve()), 'r')
+        if VERBOSE:
+            print(f" Start Conservation Prediction ".center(80, "#"))
+        embeddings = h5py.File(str(seq_path.resolve()), "r")
         data_loader = get_dataloader(embeddings, batch_size=128)
+        if VERBOSE:
+            print(f" Load model! ")
         conspred = ProtT5Cons(checkpoint_path)
-        predictions = conspred.conservation_prediction(data_loader, prob_return=write_probs, class_return=write_classes)
+        if VERBOSE:
+            print(f" Predict Conservation! ")
+        predictions = conspred.conservation_prediction(
+            data_loader, prob_return=write_probs, class_return=write_classes
+        )
+
+        out_class = Path(str(out_prefix) + "_class.fasta")
+        out_probs = Path(str(out_prefix) + "_probs.h5")
+        out_class.parent.mkdir(parents=True, exist_ok=True)
+
         if write_classes:
             conspred.write_cons_class_pred(predictions, out_class)
         if write_probs:
             conspred.write_probabilities(predictions, out_probs)
+
+        if VERBOSE:
+            print(f">> Finished Conservation Prediction!")
     finally:
         embeddings.close()

+
+def main():
+    parser = create_arg_parser()
+    args = parser.parse_args()
+
+    arguments = {
+        "seq_path": args.input,
+        "checkpoint_path": args.checkpoint
+        if args.checkpoint
+        else Path(MODEL_PATH_DICT["CONSCNN"]),
+        "out_prefix": args.output_prefix,
+        "write_probs": args.output_probs,
+        "write_classes": args.output_classes,
+    }
+
+    run_conspred(**arguments)
+
+
+if __name__ == "__main__":
+    main()