diff --git a/notebooks/OpenFoldLocal.ipynb b/notebooks/OpenFoldLocal.ipynb new file mode 100644 index 000000000..36f72ccd1 --- /dev/null +++ b/notebooks/OpenFoldLocal.ipynb @@ -0,0 +1,360 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OpenFold Local Notebook\n", + "\n", + "Provides the flexibility to run inference on a target sequence using a local Docker installation of [OpenFold](https://github.com/aqlaboratory/openfold), along with the convenience of visualizing results using the same plots from the OpenFold Colab Notebook.\n", + "\n", + "This notebook utilizes the provided utility functions to execute OpenFold via Docker. It includes logic to handle results, allowing you to experiment with different parameters, reuse computed MSAs, filter the best model, and plot metrics. It also supports asynchronous and long-running executions.\n", + "\n", + "If you have access to a machine and want to perform quick inference and visualize results, this notebook offers several useful features:\n", + "\n", + "- Use precomputed alignments, enabling you to run inference with different model parameters for result comparison.\n", + "- Identify the best model and generate metric plots.\n", + "- Manage long-running executions.\n", + "- Work with large datasets by splitting your input and performing asynchronous runs using threads on multiple GPUs.\n", + "\n", + "While you can achieve this entirely through Docker commands in the terminal, you would need to code or adjust the Colab functions to work with local data. 
This notebook gives you a head start.\n", + "\n", + "**Citing this work**\n", + "\n", + "Any publication that discloses findings arising from using this notebook should [cite](https://github.com/deepmind/alphafold/#citing-this-work) DeepMind's [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).\n", + "\n", + "**Licenses**\n", + "\n", + "This Notebook supports inference with the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license), made available under the Creative Commons Attribution 4.0 International ([CC BY 4.0](https://creativecommons.org/licenses/by/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n", + "\n", + "**More information**\n", + "\n", + "You can find more information about how AlphaFold/OpenFold works in DeepMind's two Nature papers:\n", + "\n", + "* [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2)\n", + "* [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1)\n", + "\n", + "FAQ on how to interpret AlphaFold/OpenFold predictions are [here](https://alphafold.ebi.ac.uk/faq)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup the notebook\n", + "\n", + "First, build OpenFold using Docker. 
Follow this [guide](https://openfold.readthedocs.io/en/latest/original_readme.html#building-and-using-the-docker-container).\n", + "\n", + "Then, go to the notebook folder\n", + "\n", + "`cd notebooks`\n", + "\n", + "Create an environment to run Jupyter with the requirements\n", + "\n", + "`mamba create -n openfold_notebook python==3.10`\n", + "\n", + "Activate the environment\n", + "\n", + "`mamba activate openfold_notebook`\n", + "\n", + "Install the requirements\n", + "\n", + "`pip install -r src/requirements.txt`\n", + "\n", + "Start your Jupyter server in the current folder\n", + "\n", + "`jupyter lab . --ip=\"0.0.0.0\"`\n", + "\n", + "Access the notebook URL or connect remotely using VSCode.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running Inference \n", + "\n", + "**Inputs:** files or strings with sequences\n", + "\n", + "**Output:** \n", + "\n", + "```bash\n", + "data/ \n", + "├── run_{datetime}_{run_id}/ # each run is stored in a folder named with its start time and a random run ID; this ID can be used to re-run inference \n", + "│ ├── fasta_dir/ \n", + "│ │ ├── tmp/ # generated .fasta file per sequence\n", + "│ │ └── sequences.fasta # validated input sequences are merged into a .fasta file\n", + "│ └── output/\n", + "│ ├── alignments/ # one folder per sequence with the resulting MSAs\n", + "│ ├── msa_plots/ # one .png file per alignment\n", + "│ ├── predictions/ # inference results .pkl and .pdb files\n", + "│ ├── selected_predictions/ # selected best inference and metrics plots\n", + "│ └── timings.json # inference time\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Initialize the client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import docker\n", + "from src.inference import InferenceClientOpenFold\n", + "\n", + "# You can also use a remote docker server \n", + "docker_client = docker.from_env()\n", + "\n", + "# e.g. connect to a remote Docker daemon\n", + "# 
remote_docker_client = docker.DockerClient(base_url='tcp://:2375')\n", + "\n", + "# Initialize the OpenFold Docker client setting the database path \n", + "\n", + "databases_dir = \"/path/to/databases\"\n", + "\n", + "openfold_client = InferenceClientOpenFold(databases_dir, docker_client)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inference using a sequence string" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For multiple sequences, separate sequences with a colon `:`\n", + "input_string = \"DAGAQGAAIGSPGVLSGNVVQVPVHVPVNVCGNTVSVIGLLNPAFGNTCVNA:AGETGRTGVLVTSSATNDGDSGWGRFAG\"\n", + "\n", + "model_name = \"multimer\" # or \"monomer\"\n", + "weight_set = 'AlphaFold' # or 'OpenFold'\n", + "\n", + "# Run inference\n", + "run_id = openfold_client.run_inference(weight_set, model_name, inference_input=input_string)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inference using a fasta file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_file = \"/path/to/test.fasta\"\n", + "\n", + "run_id = openfold_client.run_inference(weight_set, model_name, inference_input=input_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inference using pre-computed aligments for a run_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"monomer\"\n", + "weight_set = 'OpenFold'\n", + "\n", + "openfold_client.run_inference(weight_set, model_name, use_precomputed_alignments=True, run_id=run_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Metrics and Visualizations " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Get the MSA Plots for one sequence in a run" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from src.plot_msas import get_msa_plot\n", + "\n", + "# Provide the fasta sequence id and the run_id\n", + "get_msa_plot(run_id, fasta_id=\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# To get all sequence aligments\n", + "get_msa_plot(run_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Get the best prediction by pLDDT and metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from src.metrics import get_metrics_and_visualizations, plot_plddt_legend\n", + "\n", + "model_name = \"multimer\"\n", + "weight_set = 'AlphaFold'\n", + "\n", + "plot_plddt_legend()\n", + "get_metrics_and_visualizations(run_id, weight_set, model_name, \"\", relax_prediction=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Concurrent/Async inference\n", + "\n", + "If you have multiple cards and want to run concurrent inference for experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from concurrent.futures import ProcessPoolExecutor, as_completed\n", + "import signal\n", + "\n", + "def experiment_1():\n", + " print(\"Experiment 1 is running\")\n", + " input_file = \"/path/to/experiment_1.fasta\"\n", + " gpu = \"cuda:0\"\n", + " model_name = \"multimer\"\n", + " weight_set = 'AlphaFold'\n", + " run_id = openfold_client.run_inference(weight_set, model_name, inference_input=input_file, gpu=gpu) \n", + " return \"Experiment 1 completed\"\n", + "\n", + "def experiment_2():\n", + " print(\"Experiment 2 is running\")\n", + " input_file = \"/path/to/experiment_2.fasta\"\n", + " gpu = \"cuda:1\"\n", + " model_name = \"monomer\"\n", + " weight_set = 'OpenFold'\n", + " run_id = openfold_client.run_inference(weight_set, model_name, 
inference_input=input_file, gpu=gpu)\n", + " return \"Experiment 2 completed\"\n", + "\n", + "experiments = [experiment_1, experiment_2]\n", + "\n", + "# Function to handle keyboard interrupt\n", + "def signal_handler(sig, frame):\n", + " print(\"Interrupt received, stopping...\")\n", + " raise KeyboardInterrupt\n", + "\n", + "# Register the signal handler\n", + "signal.signal(signal.SIGINT, signal_handler)\n", + "\n", + "try:\n", + " # Execute tasks in parallel\n", + " with ProcessPoolExecutor() as executor:\n", + " futures = [executor.submit(task) for task in experiments]\n", + " results = []\n", + " for future in as_completed(futures):\n", + " results.append(future.result())\n", + " print(\"Results:\", results)\n", + "except KeyboardInterrupt:\n", + " print(\"Execution interrupted by user.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# License and Disclaimer\n", + "\n", + "This notebook and other information provided is for theoretical modelling only, caution should be exercised in its use. It is provided ‘as-is’ without any warranty of any kind, whether expressed or implied. Information is not intended to be a substitute for professional medical advice, diagnosis, or treatment, and does not constitute medical or other professional advice.\n", + "\n", + "## AlphaFold/OpenFold Code License\n", + "\n", + "Copyright 2021 AlQuraishi Laboratory\n", + "\n", + "Copyright 2021 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n", + "\n", + "## Model Parameters License\n", + "\n", + "DeepMind's AlphaFold parameters are made available under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode\n", + "\n", + "\n", + "## Third-party software\n", + "\n", + "Use of the third-party software, libraries or code referred to in this notebook may be governed by separate terms and conditions or license provisions. Your use of the third-party software, libraries or code is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use.\n", + "\n", + "\n", + "## Mirrored Databases\n", + "\n", + "The following databases have been mirrored by DeepMind, and are available with reference to the following:\n", + "* UniRef90: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n", + "* MGnify: v2019\\_05 (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n", + "* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details." 
class DockerRunner:
    """Build and run OpenFold commands inside the ``openfold:latest`` image.

    The run directory is mounted at ``/run_path`` and the database directory
    at ``/database``; every path in the generated commands is expressed
    relative to those two mount points.
    """

    def __init__(self, client, run_path, database_dir, default_fasta_dir_name):
        """Store the Docker client and the host paths to mount.

        Args:
            client: Object exposing ``containers.run(...)``.
            run_path: Host directory of the run (mounted at ``/run_path``).
            database_dir: Host directory of the databases (mounted at
                ``/database``).
            default_fasta_dir_name: Name of the FASTA directory inside
                ``run_path``; its ``tmp`` sub-folder is passed to OpenFold.
        """
        self.client = client
        self.run_path = run_path
        self.database_dir = database_dir
        self.default_fasta_dir_name = default_fasta_dir_name

    def _setup_volumes(self):
        """Return the host-path -> bind-spec mapping passed to ``containers.run``."""
        return {
            self.run_path: {'bind': '/run_path', 'mode': 'rw'},
            self.database_dir: {'bind': '/database', 'mode': 'rw'}
        }

    def _stream_logs(self, container):
        """Print the container's log stream to stdout as it arrives.

        The stream yields raw byte chunks, and a multi-byte UTF-8 character
        may be split across two chunks. Decoding each chunk on its own would
        raise ``UnicodeDecodeError``, so an incremental decoder is used and
        undecodable bytes are replaced instead of aborting the stream.
        """
        import codecs  # local import: the module header only imports ``os``

        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        for log in container.logs(stream=True):
            print(decoder.decode(log), end='')
        # Flush whatever is still buffered (e.g. a truncated final character).
        print(decoder.decode(b'', final=True), end='')

    def run_inference_for_model(self, weight_set, config_preset, gpu, use_precomputed_alignments=True):
        """Run ``run_pretrained_openfold.py`` once for a single preset/checkpoint.

        Args:
            weight_set: ``"AlphaFold"`` or ``"OpenFold"``; selects which
                model flag is appended to the command.
            config_preset: Preset name (AlphaFold) or checkpoint file name
                under ``/database/openfold_params`` (OpenFold).
            gpu: Value for ``--model_device`` (e.g. ``"cuda:0"``).
            use_precomputed_alignments: When true, point OpenFold at
                ``/run_path/output/alignments``.
        """
        command = self._build_inference_command(weight_set, config_preset, gpu, use_precomputed_alignments)
        self._run_container(command)

    def run_msa_alignment(self, cpus_per_task=32, no_tasks=1):
        """Run ``precompute_alignments.py`` for the FASTA files of this run.

        The host-side ``output/alignments`` directory is created first so the
        container can write into the mounted volume.
        """
        precomputed_alignments_dir = f"{self.run_path}/output/alignments"
        os.makedirs(precomputed_alignments_dir, exist_ok=True)
        command = self._build_msa_alignment_command(cpus_per_task, no_tasks)
        self._run_container(command)

    def _build_inference_command(self, weight_set, config_preset, gpu, use_precomputed_alignments):
        """Return the argv list for ``run_pretrained_openfold.py``.

        ``--config_preset`` is appended for the ``"AlphaFold"`` weight set and
        ``--openfold_checkpoint_path`` for ``"OpenFold"``; any other value
        adds neither flag.
        """
        fasta_dir = f"/run_path/{self.default_fasta_dir_name}/tmp"
        output_dir = "/run_path/output"
        precomputed_alignments_dir = "/run_path/output/alignments"

        command = [
            "python3", "/opt/openfold/run_pretrained_openfold.py",
            fasta_dir,
            "/database/pdb_mmcif/mmcif_files/",
            "--uniref90_database_path", "/database/uniref90/uniref90.fasta",
            "--mgnify_database_path", "/database/mgnify/mgy_clusters_2022_05.fa",
            "--pdb70_database_path", "/database/pdb70/pdb70",
            "--uniclust30_database_path", "/database/uniclust30/uniclust30_2018_08/uniclust30_2018_08",
            "--bfd_database_path", "/database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
            "--jackhmmer_binary_path", "/opt/conda/bin/jackhmmer",
            "--hhblits_binary_path", "/opt/conda/bin/hhblits",
            "--hhsearch_binary_path", "/opt/conda/bin/hhsearch",
            "--kalign_binary_path", "/opt/conda/bin/kalign",
            "--pdb_seqres_database_path", "/database/pdb_seqres/pdb_seqres.txt",
            "--uniref30_database_path", "/database/uniref30/UniRef30_2021_03",
            "--uniprot_database_path", "/database/uniprot/uniprot.fasta",
            "--hmmsearch_binary_path", "/opt/conda/bin/hmmsearch",
            "--hmmbuild_binary_path", "/opt/conda/bin/hmmbuild",
            "--model_device", gpu,
            "--save_outputs",
            "--output_dir", output_dir
        ]

        if weight_set == "AlphaFold":
            command.extend(["--config_preset", config_preset])

        if weight_set == "OpenFold":
            command.extend(["--openfold_checkpoint_path", f"/database/openfold_params/{config_preset}"])

        if use_precomputed_alignments:
            command.extend(["--use_precomputed_alignments", precomputed_alignments_dir])

        return command

    def _build_msa_alignment_command(self, cpus_per_task, no_tasks):
        """Return the argv list for ``scripts/precompute_alignments.py``."""
        fasta_dir = f"/run_path/{self.default_fasta_dir_name}/tmp"
        precomputed_alignments_dir_docker = "/run_path/output/alignments"

        command = [
            "python3", "/opt/openfold/scripts/precompute_alignments.py",
            fasta_dir,
            precomputed_alignments_dir_docker,
            "--uniref90_database_path", "/database/uniref90/uniref90.fasta",
            "--mgnify_database_path", "/database/mgnify/mgy_clusters_2022_05.fa",
            "--pdb70_database_path", "/database/pdb70/pdb70",
            "--uniclust30_database_path", "/database/uniclust30/uniclust30_2018_08/uniclust30_2018_08",
            "--bfd_database_path", "/database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
            "--cpus_per_task", str(cpus_per_task),
            "--no_tasks", str(no_tasks),
            "--jackhmmer_binary_path", "/opt/conda/bin/jackhmmer",
            "--hhblits_binary_path", "/opt/conda/bin/hhblits",
            "--hhsearch_binary_path", "/opt/conda/bin/hhsearch",
            "--kalign_binary_path", "/opt/conda/bin/kalign"
        ]

        return command

    def _run_container(self, command):
        """Start ``openfold:latest`` with ``command`` and stream its logs.

        Raises:
            docker.errors.ContainerError, docker.errors.ImageNotFound,
            docker.errors.APIError: re-raised after printing a short message.
        """
        # The ``except`` clauses below name ``docker.errors.*``. Without this
        # import any Docker failure surfaced as a NameError instead of the
        # real error, because the module header only imports ``os``.
        import docker

        volumes = self._setup_volumes()

        try:
            print("Running Docker container...")
            container = self.client.containers.run(
                image="openfold:latest",
                command=command,
                volumes=volumes,
                runtime="nvidia",
                remove=True,
                detach=True,
                stdout=True,
                stderr=True
            )

            # NOTE(review): the container runs detached, so its exit status
            # is never inspected here; a failing command is presumably only
            # visible in the streamed logs - confirm against the Docker SDK.
            self._stream_logs(container)

        except docker.errors.ContainerError as e:
            print(f"ContainerError: {e}")
            raise
        except docker.errors.ImageNotFound as e:
            print(f"ImageNotFound: {e}")
            raise
        except docker.errors.APIError as e:
            print(f"APIError: {e}")
            raise
def _run_new_inference(self, run_id, weight_set, model_name, inference_input, gpu):
    """Run the whole pipeline for a new input in a freshly named run folder.

    The run folder is ``data/run_{timestamp}_{run_id}`` under the current
    working directory. Steps, in order: create the Docker runner for that
    folder, prepare the FASTA directory from ``inference_input``, generate
    the individual sequence files, compute the MSAs, then run the model
    preset(s).

    Returns:
        The ``run_id`` that was passed in.
    """
    # Timestamp layout: day_month_year_hour_minute_second (two digits each).
    timestamp = datetime.now().strftime('%d_%m_%y_%H_%M_%S')
    run_dir_name = f'run_{timestamp}_{run_id}'
    run_path = os.path.join(os.getcwd(), 'data', run_dir_name)
    fasta_root = os.path.join(run_path, self.default_fasta_dir_name)

    self._initialize_docker_runner(run_path)

    # Input preparation happens on the host before any container starts.
    self._prepare_fasta_directory(fasta_root, inference_input, weight_set)
    generate_individual_sequence_files(run_path)

    # Alignments first; inference then reuses them.
    self.run_msa_alignment()
    self.run_model_with_preset(run_path, weight_set, model_name, gpu)

    return run_id
def _copy_input_file(self, input_file, fasta_dir_root_path):
    """Copy a user-supplied FASTA file into the run as ``sequences.fasta``.

    Args:
        input_file: Path of the existing input file.
        fasta_dir_root_path: Directory that receives ``sequences.fasta``.
    """
    # copy2 keeps the file metadata along with the contents.
    shutil.copy2(input_file, os.path.join(fasta_dir_root_path, 'sequences.fasta'))
    print(f"Using input file: {input_file}")
def get_best_prediction_by_plddt(run_id, weight_set, model_name, fasta_id, relax_prediction=False):
    """Pick the best prediction of a run and write it to ``selected_predictions``.

    Every preset returned by ``get_config_preset_list_for_model`` is loaded
    from ``{run}/output/predictions``. The winner is the preset with the
    highest mean pLDDT for ``monomer`` / ``monomer_ptm`` and the highest
    ``weighted_ptm_score`` for ``multimer``.

    Files written to ``{run}/output/selected_predictions``:
      * ``{fasta_id}_mean_plddt.txt``
      * ``{fasta_id}_selected_prediction.pdb``
      * ``{fasta_id}_selected_prediction_visualize.pdb`` (B-factors replaced
        by the pLDDT band index)
      * ``{fasta_id}_predicted_aligned_error.json`` (multimer only)

    Side effects: refills the module-level ``plddts``, ``pae_outputs`` and
    ``weighted_ptms`` dictionaries and sets ``best_unrelaxed_prot``.

    Args:
        run_id: Identifier resolved by ``get_run_folder_by_id``.
        weight_set: Passed through to ``get_config_preset_list_for_model``.
        model_name: ``'monomer'``, ``'monomer_ptm'`` or ``'multimer'``.
        fasta_id: Sequence id used as the prefix of the prediction files.
        relax_prediction: Select the ``_relaxed.pdb`` file instead of the
            ``_unrelaxed.pdb`` one.

    Returns:
        ``(best_model_name, path_of_the_visualisation_pdb)``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    global plddts, pae_outputs, weighted_ptms, best_unrelaxed_prot

    run_folder = get_run_folder_by_id(run_id)

    pkl_dir = os.path.join(f"{run_folder}/output", 'predictions')
    output_dir = os.path.join(f"{run_folder}/output", 'selected_predictions')
    os.makedirs(output_dir, exist_ok=True)

    config_preset_list = get_config_preset_list_for_model(weight_set, model_name)

    # The metric dictionaries are module-level state. Without this reset,
    # entries from an earlier call (another run, sequence or model type)
    # stay in place and can win the max() below or trigger the PAE export
    # for a model that has no PAE. Cleared in place so existing references
    # to the dictionaries remain valid.
    plddts.clear()
    pae_outputs.clear()
    weighted_ptms.clear()

    # Keep each model's own atom mask so the mask applied below belongs to
    # the selected model rather than to whichever model was loaded last.
    atom_masks = {}

    for config_preset_item in config_preset_list:
        prediction_result = load_prediction_results(pkl_dir, fasta_id, config_preset_item)

        plddts[config_preset_item] = prediction_result['plddt']
        atom_masks[config_preset_item] = prediction_result['final_atom_mask']

        if model_name == 'multimer':
            pae_outputs[config_preset_item] = (
                prediction_result['predicted_aligned_error'],
                prediction_result['max_predicted_aligned_error'],
            )
            weighted_ptms[config_preset_item] = prediction_result['weighted_ptm_score']

    # Find the best model: mean pLDDT for monomers, weighted pTM for multimers.
    if model_name in ('monomer', 'monomer_ptm'):
        best_model_name = max(plddts, key=lambda name: plddts[name].mean())
    elif model_name == 'multimer':
        best_model_name = max(weighted_ptms, key=lambda name: weighted_ptms[name])
    else:
        raise ValueError(
            f"Unsupported model_name '{model_name}'; expected 'monomer', 'monomer_ptm' or 'multimer'."
        )

    print(f"Best model is: {best_model_name}, relaxed: {relax_prediction}")

    best_model_plddts = plddts[best_model_name].mean()

    print(f"Mean PLDDT: {best_model_plddts} ")

    # Save the mean pLDDT
    pred_output_path = os.path.join(output_dir, f'{fasta_id}_mean_plddt.txt')
    with open(pred_output_path, 'w') as f:
        f.write(str(best_model_plddts))

    unrelaxed_file_name = f'{pkl_dir}/{fasta_id}_{best_model_name}_unrelaxed.pdb'

    if relax_prediction:
        pdb_file_name = f'{pkl_dir}/{fasta_id}_{best_model_name}_relaxed.pdb'
    else:
        pdb_file_name = unrelaxed_file_name

    with open(pdb_file_name, 'r') as file:
        best_pdb = file.read()

    # The unrelaxed structure is always parsed: get_plddt_pae reads
    # residue and chain indices from it.
    with open(unrelaxed_file_name, 'r') as file:
        best_unrelaxed_pdb_str = file.read()

    best_unrelaxed_prot = protein.from_pdb_string(best_unrelaxed_pdb_str)

    pred_output_path = os.path.join(output_dir, f'{fasta_id}_selected_prediction.pdb')
    with open(pred_output_path, 'w') as f:
        f.write(best_pdb)

    # Map every residue's pLDDT to the index of the first band containing it;
    # that index becomes the B-factor in the visualisation PDB.
    banded_b_factors = []
    for plddt in plddts[best_model_name]:
        for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):
            if min_val <= plddt <= max_val:
                banded_b_factors.append(idx)
                break
    banded_b_factors = np.array(banded_b_factors)[:, None] * atom_masks[best_model_name]
    to_visualize_pdb = overwrite_b_factors(best_pdb, banded_b_factors)

    visualize_output_path = os.path.join(output_dir, f'{fasta_id}_selected_prediction_visualize.pdb')
    with open(visualize_output_path, 'w') as f:
        f.write(to_visualize_pdb)

    pae_output_path = os.path.join(output_dir, f'{fasta_id}_predicted_aligned_error.json')
    if best_model_name in pae_outputs:
        best_pae, best_max_pae = pae_outputs[best_model_name]
        rounded_errors = np.round(best_pae.astype(np.float64), decimals=1)
        # 1-based residue indices for every (row, column) cell of the matrix.
        indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1
        indices_1 = indices[0].flatten().tolist()
        indices_2 = indices[1].flatten().tolist()
        pae_data = json.dumps([{
            'residue1': indices_1,
            'residue2': indices_2,
            'distance': rounded_errors.flatten().tolist(),
            'max_predicted_aligned_error': best_max_pae.item()
        }],
            indent=None,
            separators=(',', ':'))
        with open(pae_output_path, 'w') as f:
            f.write(pae_data)

    return best_model_name, visualize_output_path
def plot_3d_structure(pdb_file_path):
    """Plots the 3D structure for use in a Jupyter notebook.

    The PDB text is read from ``pdb_file_path`` and shown in an 800x600
    py3Dmol view. The cartoon is coloured by the ``b`` property through a
    map from band index to the colour stored in ``PLDDT_BANDS``; side chains
    are drawn as sticks.
    """
    with open(pdb_file_path, 'r') as pdb_handle:
        pdb_text = pdb_handle.read()

    # band index -> hex colour (third element of each PLDDT_BANDS entry)
    band_colors = {}
    for band_index, band in enumerate(PLDDT_BANDS):
        band_colors[band_index] = band[2]

    viewer = py3Dmol.view(width=800, height=600)
    viewer.addModelsAsFrames(pdb_text)
    viewer.setStyle(
        {'model': -1},
        {
            'cartoon': {'colorscheme': {'prop': 'b', 'map': band_colors}},
            # side chains are always shown
            'stick': {},
        },
    )
    viewer.zoomTo()
    viewer.show()
def create_msa_plot(fasta_id, original_sequence, output_file, input_msa_folder):
    """Plot the per-residue count of non-gap residues in the MSA of one sequence.

    Reads ``uniref90_hits.sto``, ``mgnify_hits.sto`` and
    ``bfd_uniclust_hits.a3m`` from ``{input_msa_folder}/{fasta_id}``, merges
    and de-duplicates their sequences, and saves a line plot to
    ``output_file``. A hits file that does not exist is skipped with a
    message instead of aborting the whole plot.

    Args:
        fasta_id: Name of the alignment sub-folder (and of the sequence).
        original_sequence: Query sequence; only used as a label in messages.
        output_file: Path of the image to write.
        input_msa_folder: Directory holding one sub-folder per sequence.

    Returns:
        The matplotlib figure that was saved.

    Raises:
        ValueError: if no aligned sequence could be loaded, or a hits file
            has an extension other than ``sto`` / ``a3m``.
    """
    # Path to the search results files
    search_results_files = {
        original_sequence: {
            'uniref90': os.path.join(input_msa_folder, f'{fasta_id}/uniref90_hits.sto'),
            'mgnify': os.path.join(input_msa_folder, f'{fasta_id}/mgnify_hits.sto'),
            'smallbfd': os.path.join(input_msa_folder, f'{fasta_id}/bfd_uniclust_hits.a3m'),
            # Add other databases if needed
        },
    }

    MAX_HITS_BY_DB = {
        'uniref90': 10000,
        'mgnify': 501,
        'smallbfd': 5000,
    }

    full_msa_by_seq = {seq: [] for seq in search_results_files}

    def parse_msa_file(file_path, file_type):
        """Parse a Stockholm or A3M file with the matching OpenFold parser."""
        if file_type == 'sto':
            parse = parsers.parse_stockholm
        elif file_type == 'a3m':
            parse = parsers.parse_a3m
        else:
            # Previously fell through and failed on an unbound local.
            raise ValueError(f"Unsupported MSA file type '{file_type}': {file_path}")
        with open(file_path, 'r') as f:
            return parse(f.read())

    # Load the search results from files
    for seq_name, db_files in search_results_files.items():
        print(f'Loading results for sequence: {fasta_id}:{seq_name}')
        for db_name, result_file in db_files.items():
            if not os.path.isfile(result_file):
                # Not every database is necessarily searched for every run.
                print(f' Skipping database {db_name}: {result_file} not found')
                continue

            print(f' Loading database: {db_name}')
            file_type = result_file.split('.')[-1]
            msa_obj = parse_msa_file(result_file, file_type)

            msas, del_matrix, targets = msa_obj.sequences, msa_obj.deletion_matrix, msa_obj.descriptions
            db_msas = parsers.Msa(msas, del_matrix, targets)
            if db_msas:
                if db_name in MAX_HITS_BY_DB:
                    # NOTE(review): the return value is discarded and the
                    # untruncated ``msas`` list is what gets plotted below;
                    # if ``truncate`` returns a new object rather than
                    # mutating, this call has no effect - confirm intent.
                    db_msas.truncate(MAX_HITS_BY_DB[db_name])
                full_msa_by_seq[seq_name].extend(msas)
                msa_size = len(set(msas))
                print(f'{msa_size} Sequences Found in {db_name}')

    # Deduplicate full MSA (keeping first-seen order) and report its size
    for seq_name in full_msa_by_seq:
        full_msa_by_seq[seq_name] = list(dict.fromkeys(full_msa_by_seq[seq_name]))
        total_msa_size = len(full_msa_by_seq[seq_name])
        print(f'\n{total_msa_size} Sequences Found in Total for {seq_name}\n')

    # Fail with a readable message before creating a figure; an empty MSA
    # previously died on the shape unpacking of an empty array.
    if not any(full_msa_by_seq.values()):
        raise ValueError(
            f"No aligned sequences found for '{fasta_id}' in {os.path.join(input_msa_folder, fasta_id)}"
        )

    # Visualize the results
    fig = plt.figure(figsize=(12, 3))
    max_num_alignments = 0
    # Letter -> integer code; '-' (gap) gets the last code.
    aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}

    for seq_idx, seq_name in enumerate(search_results_files):
        deduped_full_msa = full_msa_by_seq[seq_name]
        if not deduped_full_msa:
            continue

        msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])
        num_alignments = msa_arr.shape[0]
        # Column-wise count of residues that are not gaps.
        plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), label=f'Chain {seq_idx}')
        max_num_alignments = max(num_alignments, max_num_alignments)

    plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')
    plt.ylabel('Non-Gap Count')
    plt.yticks(range(0, max_num_alignments + 1, max(1, int(max_num_alignments / 3))))
    plt.legend()

    # Save the plot to a file
    plt.savefig(output_file)
    print(f'MSA plot saved to: {output_file}')
    return fig
import os


def get_run_folder_by_id(run_id):
    """Return the path of the run folder whose name ends with ``_<run_id>``.

    Run folders are looked up directly under ``<current working dir>/data``.

    Raises:
        ValueError: If no matching directory exists.
    """
    base_path = os.path.join(os.getcwd(), 'data')
    suffix = f'_{run_id}'
    # The context manager closes the directory handle even though iteration
    # stops at the first match.
    with os.scandir(base_path) as entries:
        for entry in entries:
            if entry.is_dir() and entry.name.endswith(suffix):
                return entry.path
    raise ValueError(f"Run ID '{run_id}' does not exist.")


def read_fasta_file(fasta_file_path):
    """Return the sequence in a FASTA file as a single string.

    Header lines (starting with ``>``) are dropped and all remaining lines
    are stripped and concatenated; a file with several records therefore
    yields their sequences joined together.
    """
    with open(fasta_file_path, 'r') as file:
        return ''.join(line.strip() for line in file if not line.startswith('>'))


def generate_random_run_id(length=6):
    """Generate, print and return a random run id of upper-case letters and digits."""
    run_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))
    print(f"Run ID: {run_id}")
    return run_id


def generate_random_sequence_name(length=8):
    """Return a random name made of upper-case letters and digits."""
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))


def write_sequences_to_fasta(sequences, fasta_dir_root_path):
    """Write ``sequences`` to ``<fasta_dir_root_path>/sequences.fasta``.

    Each sequence gets a generated header of the form ``sequence_<random name>``.
    The target directory is created when it does not exist yet.

    Returns:
        The path of the FASTA file that was written.
    """
    os.makedirs(fasta_dir_root_path, exist_ok=True)
    fasta_file = os.path.join(fasta_dir_root_path, "sequences.fasta")

    with open(fasta_file, 'w') as f:
        for seq in sequences:
            sequence_name = f"sequence_{generate_random_sequence_name()}"
            f.write(f">{sequence_name}\n{seq}\n")

    return fasta_file


def validate_sequence(input_sequence, weight_set):
    """Normalise an input sequence and check that it can be folded.

    All whitespace is removed and the sequence is upper-cased. Only the 20
    standard amino acids are accepted, plus ``:`` as chain separator for
    multimers.

    Args:
        input_sequence: Raw sequence text, possibly containing whitespace.
        weight_set: Name of the selected weight set; multimer input requires
            ``'AlphaFold'``.

    Returns:
        The cleaned, upper-cased sequence.

    Raises:
        ValueError: If the sequence contains other characters, or is a
            multimer while ``weight_set`` is not ``'AlphaFold'``.
    """
    # split()/join removes every kind of whitespace, including the '\r' of
    # Windows line endings, which used to be reported as an invalid residue.
    input_sequence = ''.join(input_sequence.split()).upper()
    aatypes = set('ACDEFGHIKLMNPQRSTVWY')  # 20 standard amino acids
    allowed_chars = aatypes.union({':'})

    invalid_chars = set(input_sequence) - allowed_chars
    if invalid_chars:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(f'Input sequence contains non-amino acid letters: {invalid_chars}. OpenFold only supports 20 standard amino acids as inputs.')

    if ':' in input_sequence and weight_set != 'AlphaFold':
        raise ValueError('Input sequence is a multimer, must select AlphaFold weight set')

    return input_sequence


def _save_sequence_record(target_dir, sequence_id, sequence_lines):
    """Write one FASTA record to ``<target_dir>/<sequence_id>.fasta``."""
    output_file = os.path.join(target_dir, f"{sequence_id}.fasta")
    with open(output_file, "w") as outfile:
        outfile.write(f">{sequence_id}\n")
        outfile.write("\n".join(sequence_lines) + "\n")
    print(f"Saved {sequence_id} to {output_file}")


def generate_individual_sequence_files(output_path):
    """Split ``<output_path>/fasta_dir/sequences.fasta`` into one file per record.

    Files are written to ``<output_path>/fasta_dir/tmp/<id>.fasta`` where
    ``<id>`` is the header text after ``>`` up to the first ``.``. Blank lines
    in the input are ignored.
    """
    split_dir = f"{output_path}/fasta_dir/tmp"
    sequence_id = None
    sequence_lines = []

    with open(f"{output_path}/fasta_dir/sequences.fasta", "r") as infile:
        os.makedirs(split_dir, exist_ok=True)

        for raw_line in infile:
            line = raw_line.strip()
            if not line:
                # Blank lines used to be copied into the per-sequence files.
                continue
            if line.startswith(">"):
                if sequence_id is not None:
                    # A new header closes the previous record.
                    _save_sequence_record(split_dir, sequence_id, sequence_lines)
                # Remove '>' and anything after the first '.'
                sequence_id = line[1:].split('.')[0]
                sequence_lines = []
            else:
                sequence_lines.append(line)

    # Save the last sequence
    if sequence_id is not None:
        _save_sequence_record(split_dir, sequence_id, sequence_lines)


# Entries available for each (weight_set, model_name) combination.
_MODEL_CONFIGURATIONS = {
    ("OpenFold", "monomer"): (
        'finetuning_3.pt',
        'finetuning_4.pt',
        'finetuning_5.pt',
        'finetuning_ptm_2.pt',
        'finetuning_no_templ_ptm_1.pt',
    ),
    ("AlphaFold", "multimer"): tuple(f'model_{i}_multimer_v3' for i in range(1, 6)),
    ("AlphaFold", "monomer"): tuple(f'model_{i}' for i in range(1, 6)),
}


def get_config_preset_list_for_model(weight_set, model_name):
    """Return the list of entries configured for a weight set / model pair.

    Raises:
        ValueError: If the combination is not one of the known ones.
    """
    config_preset_list = _MODEL_CONFIGURATIONS.get((weight_set, model_name))
    if not config_preset_list:
        raise ValueError(f"Invalid combination of weight_set '{weight_set}' and model_name '{model_name}'")

    # A fresh list per call, so callers may modify the result freely.
    return list(config_preset_list)