The following works fine for me:
python eat.py --lookup test.fasta --queries test.fasta --output test/
But when I add --use_tucker 1, I get:
Start loading ProtT5...
Finished loading Rostlab/prot_t5_xl_half_uniref50-enc in 28.2[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 1.4[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 0.7[s]
No existing model found. Start downloading pre-trained ProtTucker(ProtT5)...
Loading Tucker checkpoint from: temp/tucker_weights.pt
Traceback (most recent call last):
  File "/home/jgreener/soft/EAT/eat.py", line 515, in <module>
    main()
  File "/home/jgreener/soft/EAT/eat.py", line 496, in main
    eater = EAT(lookup_p, query_p, output_d,
  File "/home/jgreener/soft/EAT/eat.py", line 220, in __init__
    self.lookup_embs = self.tucker_embeddings(self.lookup_embs)
  File "/home/jgreener/soft/EAT/eat.py", line 245, in tucker_embeddings
    dataset = model.single_pass(dataset)
  File "/home/jgreener/soft/EAT/eat.py", line 36, in single_pass
    return self.tucker(x)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Half
I am on Python 3.9.16, PyTorch 1.10.0, h5py 3.6.0, numpy 1.22.0, scikit-learn 0.24.2, and transformers 4.17.0. test.fasta is attached as test.txt.
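The traceback ends in a plain nn.Linear inside an nn.Sequential (container.py, then linear.py), and the checkpoint name Rostlab/prot_t5_xl_half_uniref50-enc suggests the embeddings come out as float16 while the Tucker weights are float32. The minimal sketch below, assuming that is the cause, reproduces the same kind of dtype mismatch with a stand-in model and shows one possible workaround: casting the embeddings to the model's parameter dtype before the forward pass (for example, something like x.float() before self.tucker(x) in single_pass). The layer sizes (1024 → 256 → 128) and the helper name project are hypothetical and not taken from eat.py.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the ProtTucker projection. The traceback only shows
# an nn.Sequential containing nn.Linear layers; the 256/128 sizes are assumed.
tucker = nn.Sequential(
    nn.Linear(1024, 256),
    nn.Tanh(),
    nn.Linear(256, 128),
)  # parameters default to float32

# Per-protein embeddings as a half-precision encoder would emit them (fp16).
embeddings = torch.randn(50, 1024).half()


def project(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast inputs to the model's parameter dtype, then run it."""
    target_dtype = next(model.parameters()).dtype
    with torch.no_grad():
        return model(x.to(target_dtype))


# Feeding fp16 inputs straight into float32 Linear layers raises a RuntimeError
# (the exact message differs between PyTorch versions).
try:
    tucker(embeddings)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

projected = project(tucker, embeddings)
print(mismatch_raised, projected.dtype, tuple(projected.shape))
```

Casting on the input side keeps the Tucker weights in float32; converting the model to half instead may fail on CPU, where older PyTorch builds lack some fp16 kernels.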