PyTorch implementation of the code2seq model.
You can install the model with pip:
pip install code2seq
To prepare your own dataset in a storage format supported by this implementation, use one of the following:
- Original dataset preprocessing from the vanilla repository.
- astminer: a tool for mining path-based representations and more, with multiple language support.
- PSIMiner: a tool for extracting PSI trees from the IntelliJ Platform and creating datasets from them.
Dataset (with link) | Checkpoint | # epochs | F1-score | Precision | Recall | ChrF |
---|---|---|---|---|---|---|
Java-small | link | 11 | 41.49 | 54.26 | 33.59 | 30.21 |
Java-med | link | 10 | 48.17 | 58.87 | 40.76 | 42.32 |
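
The F1-score column can be cross-checked against the precision and recall columns. The sketch below assumes the reported F1 is the harmonic mean of the listed precision and recall, which the numbers in the table are consistent with. The helper name `f1_score` is illustrative and not part of the code2seq package.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (both given in percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from the table above.
reported = {
    "Java-small": (54.26, 33.59, 41.49),
    "Java-med": (58.87, 40.76, 48.17),
}

for name, (precision, recall, f1) in reported.items():
    # Print the recomputed value next to the reported one.
    print(name, round(f1_score(precision, recall), 2), f1)
```

Both rows reproduce the reported F1 to two decimal places.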
The model is fully configurable through a standalone YAML file. See the config directory for example configs.
Model training can be done with the PyTorch Lightning trainer. See its documentation for more information.
from argparse import ArgumentParser

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Seq


def train(config: DictConfig):
    # Define data module
    data_module = PathContextDataModule(config.data_folder, config.data)

    # Define model
    model = Code2Seq(
        config.model,
        config.optimizer,
        data_module.vocabulary,
        config.train.teacher_forcing
    )

    # Define trainer and its hyperparameters
    trainer = Trainer(max_epochs=config.train.n_epochs)

    # Train model
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    __arg_parser = ArgumentParser()
    __arg_parser.add_argument("config", help="Path to YAML configuration file", type=str)
    __args = __arg_parser.parse_args()

    __config = OmegaConf.load(__args.config)
    train(__config)
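
The `train` function above reads a small set of config keys: `data_folder`, `data`, `model`, `optimizer`, `train.n_epochs`, and `train.teacher_forcing`. As a minimal sketch, a helper could check that these keys exist before any training starts. The helper `missing_keys` and the example values are hypothetical and not part of code2seq. The contents of the `data`, `model`, and `optimizer` sections are left empty here because they depend on the configs shipped with the repository.

```python
from typing import Any, Iterable, List, Mapping

# Dotted paths that the train() function above reads from the config.
REQUIRED_KEYS = (
    "data_folder",
    "data",
    "model",
    "optimizer",
    "train.n_epochs",
    "train.teacher_forcing",
)


def missing_keys(config: Mapping[str, Any], required: Iterable[str] = REQUIRED_KEYS) -> List[str]:
    """Return the dotted paths from `required` that are absent in `config`."""
    missing = []
    for path in required:
        node: Any = config
        for part in path.split("."):
            # Stop at the first path component that cannot be resolved.
            if not isinstance(node, Mapping) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing


# Hypothetical config values, for illustration only.
example = {
    "data_folder": "path/to/dataset",
    "data": {},
    "model": {},
    "optimizer": {},
    "train": {"n_epochs": 10},
}
print(missing_keys(example))
```

The helper works on plain nested dicts. A config loaded with `OmegaConf.load` can be converted to one with `OmegaConf.to_container(config)` before being passed in.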