# ProtFlash

A lightweight protein language model for protein representation learning.


## Installation

Prerequisite: Install PyTorch first.

Choose one of the following installation methods:

```shell
# Latest version from GitHub
pip install git+https://github.com/isyslab-hust/ProtFlash

# Stable release from PyPI
pip install ProtFlash
```

## Model details

| Model | Parameters | Hidden size | Pretraining dataset | Proteins | Download |
|---|---|---|---|---|---|
| ProtFlash-base | 174M | 768 | UniRef50 | 51M | ProtFlash-base |
| ProtFlash-small | 79M | 512 | UniRef50 | 51M | ProtFlash-small |

## Usage

### Protein sequence embedding

```python
import torch
from ProtFlash.pretrain import load_prot_flash_base
from ProtFlash.utils import batchConverter

data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]

ids, batch_token, lengths = batchConverter(data)
model = load_prot_flash_base()

with torch.no_grad():
    token_embedding = model(batch_token, lengths)

# Generate per-sequence representations by averaging over residue positions
sequence_representations = []
for i, (_, seq) in enumerate(data):
    sequence_representations.append(token_embedding[i, 0 : len(seq) + 1].mean(0))
```
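The per-sequence vectors can then be compared directly, for example with cosine similarity. A minimal, self-contained sketch using random stand-in embeddings (768-dimensional, matching ProtFlash-base's hidden size) in place of real model output:

```python
import torch

torch.manual_seed(0)
# Stand-in embeddings; in practice use the `sequence_representations`
# produced by the averaging loop above.
sequence_representations = [torch.randn(768), torch.randn(768)]

# Stack into a (num_sequences, hidden_size) matrix and L2-normalise rows.
emb = torch.stack(sequence_representations)
emb = emb / emb.norm(dim=1, keepdim=True)

# Pairwise cosine similarity between all sequences; diagonal entries are 1.
similarity = emb @ emb.T
```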

### Load weight files

```python
import torch
from ProtFlash.model import FLASHTransformer

model_data = torch.load(your_parameters_file)
hyper_parameter = model_data["hyper_parameters"]

model = FLASHTransformer(
    hyper_parameter["dim"],
    hyper_parameter["num_tokens"],
    hyper_parameter["num_layers"],
    group_size=hyper_parameter["num_tokens"],
    query_key_dim=hyper_parameter["qk_dim"],
    max_rel_dist=hyper_parameter["max_rel_dist"],
    expansion_factor=hyper_parameter["expansion_factor"],
)

model.load_state_dict(model_data["state_dict"])
```
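The snippet above assumes a checkpoint dictionary with `"hyper_parameters"` and `"state_dict"` keys. A self-contained sketch of that save/load round trip, using a hypothetical toy module in place of `FLASHTransformer` so it runs without downloading weights:

```python
import torch
import torch.nn as nn

# Toy stand-in for FLASHTransformer, only to illustrate the assumed
# checkpoint layout: {"hyper_parameters": ..., "state_dict": ...}.
class TinyModel(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

hyper = {"dim": 8}
checkpoint = {
    "hyper_parameters": hyper,
    "state_dict": TinyModel(hyper["dim"]).state_dict(),
}
torch.save(checkpoint, "toy_checkpoint.pt")

# Rebuild the model from the stored hyper-parameters, then load weights.
model_data = torch.load("toy_checkpoint.pt")
model = TinyModel(model_data["hyper_parameters"]["dim"])
model.load_state_dict(model_data["state_dict"])
model.eval()  # switch to inference mode before computing embeddings
```

Calling `eval()` after loading disables dropout and other training-only behaviour, which matters when the loaded model is used for embedding extraction.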

## License

This project is licensed under the MIT License. See LICENSE for details.

## Citation

If you use this code or one of the pretrained models in your research, please cite:

```bibtex
@article{wang2023deciphering,
  title={Deciphering the protein landscape with ProtFlash, a lightweight language model},
  author={Wang, Lei and Zhang, Hui and Xu, Wei and Xue, Zhidong and Wang, Yan},
  journal={Cell Reports Physical Science},
  volume={4},
  number={10},
  year={2023},
  publisher={Elsevier}
}
```