From 7ab95f66e10d9c0fa33b425f27e8f56ec440cbd7 Mon Sep 17 00:00:00 2001
From: Oliver Dutton
Date: Sat, 20 Apr 2024 13:38:20 +0100
Subject: [PATCH] feat: bfloat16 support for monomer

---
 alphafold/model/config.py  |   2 +
 alphafold/model/modules.py | 420 +++++++++++++++++++------------------
 2 files changed, 220 insertions(+), 202 deletions(-)

diff --git a/alphafold/model/config.py b/alphafold/model/config.py
index 447c3e34b..0e66b981d 100644
--- a/alphafold/model/config.py
+++ b/alphafold/model/config.py
@@ -378,6 +378,8 @@ def model_config(name: str) -> ml_collections.ConfigDict:
             }
         },
         'global_config': {
+            'bfloat16': False,
+            'bfloat16_output': False,
             'deterministic': False,
             'multimer_mode': False,
             'subbatch_size': 4,
diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py
index 554c078c0..692177686 100644
--- a/alphafold/model/modules.py
+++ b/alphafold/model/modules.py
@@ -30,6 +30,7 @@
 import haiku as hk
 import jax
 import jax.numpy as jnp
+from jax.tree_util import tree_map

 _SOFTMAX_MASK = -1e9

@@ -1792,214 +1793,223 @@ def __call__(self, batch, is_training, safe_key=None):
     c = self.config
     gc = self.global_config

+    dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32
+
     if safe_key is None:
       safe_key = prng.SafeKey(hk.next_rng_key())

-    # Embed clustered MSA.
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
-    # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
-    preprocess_1d = common_modules.Linear(
-        c.msa_channel, name='preprocess_1d')(
-            batch['target_feat'])
-
-    preprocess_msa = common_modules.Linear(
-        c.msa_channel, name='preprocess_msa')(
-            batch['msa_feat'])
-
-    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
-
-    left_single = common_modules.Linear(
-        c.pair_channel, name='left_single')(
-            batch['target_feat'])
-    right_single = common_modules.Linear(
-        c.pair_channel, name='right_single')(
-            batch['target_feat'])
-    pair_activations = left_single[:, None] + right_single[None]
-    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
-
-    # Inject previous outputs for recycling.
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
-    # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
-    if c.recycle_pos:
-      prev_pseudo_beta = pseudo_beta_fn(
-          batch['aatype'], batch['prev_pos'], None)
-      dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
-      pair_activations += common_modules.Linear(
-          c.pair_channel, name='prev_pos_linear')(
-              dgram)
-
-    if c.recycle_features:
-      prev_msa_first_row = common_modules.LayerNorm(
-          axis=[-1],
-          create_scale=True,
-          create_offset=True,
-          name='prev_msa_first_row_norm')(
-              batch['prev_msa_first_row'])
-      msa_activations = msa_activations.at[0].add(prev_msa_first_row)
-
-      pair_activations += common_modules.LayerNorm(
-          axis=[-1],
-          create_scale=True,
-          create_offset=True,
-          name='prev_pair_norm')(
-              batch['prev_pair'])
-
-    # Relative position encoding.
-    # Jumper et al. (2021) Suppl. Alg. 4 "relpos"
-    # Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
-    if c.max_relative_feature:
-      # Add one-hot-encoded clipped residue distances to the pair activations.
-      pos = batch['residue_index']
-      offset = pos[:, None] - pos[None, :]
-      rel_pos = jax.nn.one_hot(
-          jnp.clip(
-              offset + c.max_relative_feature,
-              a_min=0,
-              a_max=2 * c.max_relative_feature),
-          2 * c.max_relative_feature + 1)
-      pair_activations += common_modules.Linear(
-          c.pair_channel, name='pair_activiations')(
-              rel_pos)
-
-    # Embed templates into the pair activations.
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
-    if c.template.enabled:
-      template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
-      template_pair_representation = TemplateEmbedding(c.template, gc)(
-          pair_activations,
-          template_batch,
-          mask_2d,
-          is_training=is_training)
-
-      pair_activations += template_pair_representation
-
-    # Embed extra MSA features.
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
-    extra_msa_feat = create_extra_msa_feature(batch)
-    extra_msa_activations = common_modules.Linear(
-        c.extra_msa_channel,
-        name='extra_msa_activations')(
-            extra_msa_feat)
-
-    # Extra MSA Stack.
-    # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
-    extra_msa_stack_input = {
-        'msa': extra_msa_activations,
-        'pair': pair_activations,
-    }
-
-    extra_msa_stack_iteration = EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
-
-    def extra_msa_stack_fn(x):
-      act, safe_key = x
-      safe_key, safe_subkey = safe_key.split()
-      extra_evoformer_output = extra_msa_stack_iteration(
-          activations=act,
-          masks={
-              'msa': batch['extra_msa_mask'],
-              'pair': mask_2d
-          },
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (extra_evoformer_output, safe_key)
-
-    if gc.use_remat:
-      extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)
-
-    extra_msa_stack = layer_stack.layer_stack(
-        c.extra_msa_stack_num_block)(
-            extra_msa_stack_fn)
-    extra_msa_output, safe_key = extra_msa_stack(
-        (extra_msa_stack_input, safe_key))
-
-    pair_activations = extra_msa_output['pair']
-
-    evoformer_input = {
-        'msa': msa_activations,
-        'pair': pair_activations,
-    }
-
-    evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}
-
-    # Append num_templ rows to msa_activations with template embeddings.
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
-    if c.template.enabled and c.template.embed_torsion_angles:
-      num_templ, num_res = batch['template_aatype'].shape
-
-      # Embed the templates aatypes.
-      aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
-
-      # Embed the templates aatype, torsion angles and masks.
-      # Shape (templates, residues, msa_channels)
-      ret = all_atom.atom37_to_torsion_angles(
-          aatype=batch['template_aatype'],
-          all_atom_pos=batch['template_all_atom_positions'],
-          all_atom_mask=batch['template_all_atom_masks'],
-          # Ensure consistent behaviour during testing:
-          placeholder_for_undefined=not gc.zero_init)
-
-      template_features = jnp.concatenate([
-          aatype_one_hot,
-          jnp.reshape(
-              ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
-          jnp.reshape(
-              ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
-          ret['torsion_angles_mask']], axis=-1)
-
-      template_activations = common_modules.Linear(
-          c.msa_channel,
-          initializer='relu',
-          name='template_single_embedding')(
-              template_features)
-      template_activations = jax.nn.relu(template_activations)
-      template_activations = common_modules.Linear(
-          c.msa_channel,
-          initializer='relu',
-          name='template_projection')(
-              template_activations)
-
-      # Concatenate the templates to the msa.
-      evoformer_input['msa'] = jnp.concatenate(
-          [evoformer_input['msa'], template_activations], axis=0)
-      # Concatenate templates masks to the msa masks.
-      # Use mask from the psi angle, as it only depends on the backbone atoms
-      # from a single residue.
-      torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
-      torsion_angle_mask = torsion_angle_mask.astype(
-          evoformer_masks['msa'].dtype)
-      evoformer_masks['msa'] = jnp.concatenate(
-          [evoformer_masks['msa'], torsion_angle_mask], axis=0)
-
-    # Main trunk of the network
-    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
-    evoformer_iteration = EvoformerIteration(
-        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
-
-    def evoformer_fn(x):
-      act, safe_key = x
-      safe_key, safe_subkey = safe_key.split()
-      evoformer_output = evoformer_iteration(
-          activations=act,
-          masks=evoformer_masks,
-          is_training=is_training,
-          safe_key=safe_subkey)
-      return (evoformer_output, safe_key)
-
-    if gc.use_remat:
-      evoformer_fn = hk.remat(evoformer_fn)
-
-    evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
-        evoformer_fn)
-    evoformer_output, safe_key = evoformer_stack(
-        (evoformer_input, safe_key))
-
-    msa_activations = evoformer_output['msa']
-    pair_activations = evoformer_output['pair']
-
-    single_activations = common_modules.Linear(
-        c.seq_channel, name='single_activations')(
-            msa_activations[0])
+    with utils.bfloat16_context():
+      # Embed clustered MSA.
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
+      # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
+      target_feat = batch['target_feat'].astype(dtype)
+
+      preprocess_1d = common_modules.Linear(
+          c.msa_channel, name='preprocess_1d')(
+              target_feat)
+
+      preprocess_msa = common_modules.Linear(
+          c.msa_channel, name='preprocess_msa')(
+              batch['msa_feat'].astype(dtype))
+
+      msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
+
+      left_single = common_modules.Linear(
+          c.pair_channel, name='left_single')(
+              target_feat)
+      right_single = common_modules.Linear(
+          c.pair_channel, name='right_single')(
+              target_feat)
+      pair_activations = left_single[:, None] + right_single[None]
+      mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
+      mask_2d = mask_2d.astype(dtype)
+
+      # Inject previous outputs for recycling.
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
+      # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
+      if c.recycle_pos:
+        prev_pseudo_beta = pseudo_beta_fn(
+            batch['aatype'], batch['prev_pos'], None)
+        dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos).astype(dtype)
+        pair_activations += common_modules.Linear(
+            c.pair_channel, name='prev_pos_linear')(
+                dgram)
+
+      if c.recycle_features:
+        prev_msa_first_row = common_modules.LayerNorm(
+            axis=[-1],
+            create_scale=True,
+            create_offset=True,
+            name='prev_msa_first_row_norm')(
+                batch['prev_msa_first_row']).astype(dtype)
+        msa_activations = msa_activations.at[0].add(prev_msa_first_row)
+
+        pair_activations += common_modules.LayerNorm(
+            axis=[-1],
+            create_scale=True,
+            create_offset=True,
+            name='prev_pair_norm')(
+                batch['prev_pair']).astype(dtype)
+
+      # Relative position encoding.
+      # Jumper et al. (2021) Suppl. Alg. 4 "relpos"
+      # Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
+      if c.max_relative_feature:
+        # Add one-hot-encoded clipped residue distances to the pair activations.
+        pos = batch['residue_index']
+        offset = pos[:, None] - pos[None, :]
+        rel_pos = jax.nn.one_hot(
+            jnp.clip(
+                offset + c.max_relative_feature,
+                a_min=0,
+                a_max=2 * c.max_relative_feature),
+            2 * c.max_relative_feature + 1).astype(dtype)
+        pair_activations += common_modules.Linear(
+            c.pair_channel, name='pair_activiations')(
+                rel_pos)
+
+      # Embed templates into the pair activations.
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
+      if c.template.enabled:
+        template_batch = {k: batch[k] for k in batch if k.startswith('template_')}
+        template_pair_representation = TemplateEmbedding(c.template, gc)(
+            pair_activations,
+            template_batch,
+            mask_2d,
+            is_training=is_training)
+
+        pair_activations += template_pair_representation
+
+      # Embed extra MSA features.
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
+      extra_msa_feat = create_extra_msa_feature(batch)
+      extra_msa_activations = common_modules.Linear(
+          c.extra_msa_channel,
+          name='extra_msa_activations')(
+              extra_msa_feat).astype(dtype)
+
+      # Extra MSA Stack.
+      # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
+      extra_msa_stack_input = {
+          'msa': extra_msa_activations,
+          'pair': pair_activations,
+      }
+
+      extra_msa_stack_iteration = EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
+
+      def extra_msa_stack_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        extra_evoformer_output = extra_msa_stack_iteration(
+            activations=act,
+            masks={
+                'msa': batch['extra_msa_mask'].astype(dtype),
+                'pair': mask_2d
+            },
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (extra_evoformer_output, safe_key)
+
+      if gc.use_remat:
+        extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)
+
+      extra_msa_stack = layer_stack.layer_stack(
+          c.extra_msa_stack_num_block)(
+              extra_msa_stack_fn)
+      extra_msa_output, safe_key = extra_msa_stack(
+          (extra_msa_stack_input, safe_key))
+
+      pair_activations = extra_msa_output['pair']
+
+      evoformer_input = {
+          'msa': msa_activations,
+          'pair': pair_activations,
+      }
+
+      evoformer_masks = {
+          'msa': batch['msa_mask'].astype(dtype),
+          'pair': mask_2d
+      }
+
+      # Append num_templ rows to msa_activations with template embeddings.
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
+      if c.template.enabled and c.template.embed_torsion_angles:
+        num_templ, num_res = batch['template_aatype'].shape
+
+        # Embed the templates aatypes.
+        aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
+
+        # Embed the templates aatype, torsion angles and masks.
+        # Shape (templates, residues, msa_channels)
+        ret = all_atom.atom37_to_torsion_angles(
+            aatype=batch['template_aatype'],
+            all_atom_pos=batch['template_all_atom_positions'],
+            all_atom_mask=batch['template_all_atom_masks'],
+            # Ensure consistent behaviour during testing:
+            placeholder_for_undefined=not gc.zero_init)
+
+        template_features = jnp.concatenate([
+            aatype_one_hot,
+            jnp.reshape(
+                ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]),
+            jnp.reshape(
+                ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]),
+            ret['torsion_angles_mask']], axis=-1).astype(dtype)
+
+        template_activations = common_modules.Linear(
+            c.msa_channel,
+            initializer='relu',
+            name='template_single_embedding')(
+                template_features)
+        template_activations = jax.nn.relu(template_activations)
+        template_activations = common_modules.Linear(
+            c.msa_channel,
+            initializer='relu',
+            name='template_projection')(
+                template_activations)
+
+        # Concatenate the templates to the msa.
+        evoformer_input['msa'] = jnp.concatenate(
+            [evoformer_input['msa'], template_activations], axis=0)
+        # Concatenate templates masks to the msa masks.
+        # Use mask from the psi angle, as it only depends on the backbone atoms
+        # from a single residue.
+        torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
+        torsion_angle_mask = torsion_angle_mask.astype(
+            evoformer_masks['msa'].dtype)
+        evoformer_masks['msa'] = jnp.concatenate(
+            [evoformer_masks['msa'], torsion_angle_mask], axis=0)
+
+      # Main trunk of the network
+      # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
+      evoformer_iteration = EvoformerIteration(
+          c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
+
+      def evoformer_fn(x):
+        act, safe_key = x
+        safe_key, safe_subkey = safe_key.split()
+        evoformer_output = evoformer_iteration(
+            activations=act,
+            masks=evoformer_masks,
+            is_training=is_training,
+            safe_key=safe_subkey)
+        return (evoformer_output, safe_key)
+
+      if gc.use_remat:
+        evoformer_fn = hk.remat(evoformer_fn)
+
+      evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
+          evoformer_fn)
+      evoformer_output, safe_key = evoformer_stack(
+          (evoformer_input, safe_key))
+
+      msa_activations = evoformer_output['msa']
+      pair_activations = evoformer_output['pair']
+
+      single_activations = common_modules.Linear(
+          c.seq_channel, name='single_activations')(
+              msa_activations[0])

     num_sequences = batch['msa_feat'].shape[0]

     output = {
@@ -2010,6 +2020,12 @@ def evoformer_fn(x):
         'msa_first_row': msa_activations[0],
     }

+    # Convert back to float32 if we're not saving memory.
+    if not gc.bfloat16_output:
+      output = tree_map(
+          lambda v: v.astype(jnp.float32) if (v.dtype==jnp.bfloat16) else v,
+          output)
+
     return output
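
Note on the utils.bfloat16_context() dependency: the patch wraps the monomer trunk in this context manager but does not add it, so it relies on the helper that already ships in alphafold/model/utils.py for the multimer path. The sketch below is an abridged reconstruction of that helper for orientation only; it is not part of this diff, and the upstream code may differ in detail. The idea is that Haiku keeps master parameters in float32 while modules that request bfloat16 receive bfloat16 views, so checkpoint compatibility is preserved.

    import contextlib

    import haiku as hk
    import jax.numpy as jnp


    def _bfloat16_creator(next_creator, shape, dtype, init):
      # Create master parameters in float32 even when a module asks
      # for bfloat16.
      if dtype == jnp.bfloat16:
        dtype = jnp.float32
      return next_creator(shape, dtype, init)


    def _bfloat16_getter(next_getter, value, context):
      # Hand a bfloat16 view of the float32 master parameter to any
      # module that originally requested bfloat16.
      if context.original_dtype == jnp.bfloat16:
        value = value.astype(jnp.bfloat16)
      return next_getter(value)


    @contextlib.contextmanager
    def bfloat16_context():
      # Must run inside an hk.transform-ed function, as in modules.py.
      with hk.custom_creator(_bfloat16_creator), \
           hk.custom_getter(_bfloat16_getter):
        yield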
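Usage note: both new flags default to False, so existing monomer configs behave exactly as before. A minimal sketch of opting in from the standard inference path follows; config.model_config and model.RunModel are the existing entry points used by run_alphafold.py, while parameter loading and featurisation are elided.

    from alphafold.model import config, model

    cfg = config.model_config('model_1')
    # Run the embeddings/Evoformer trunk in bfloat16 to reduce
    # activation memory.
    cfg.model.global_config.bfloat16 = True
    # Leave False to cast the returned representations back to float32
    # (the tree_map at the end of the patch); set True to keep the
    # outputs in bfloat16 as well.
    cfg.model.global_config.bfloat16_output = False

    # Parameter loading elided; see run_alphafold.py:
    # model_runner = model.RunModel(cfg, model_params)
    # prediction = model_runner.predict(processed_feature_dict,
    #                                   random_seed=0)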