Skip to content

Commit aba9765

Browse files
rpachauricopybara-github
authored andcommitted
Use correct type information in tests.
PiperOrigin-RevId: 797102858 Change-Id: Ie49cf8102dee57a507d60464c746cf92fa70192f
1 parent 29f0821 commit aba9765

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

alphafold/model/all_atom_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from absl.testing import parameterized
1919
from alphafold.model import all_atom
2020
from alphafold.model import r3
21+
import jax
2122
import numpy as np
2223

2324
L1_CLAMP_DISTANCE = 10
@@ -80,7 +81,7 @@ def test_frame_aligned_point_error_perfect_on_global_transform(
8081
global_rigid_transform = get_global_rigid_transform(
8182
rot_angle, translation, 1)
8283

83-
target_positions = r3.vecs_from_tensor(target_positions)
84+
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
8485
pred_positions = r3.rigids_mul_vecs(
8586
global_rigid_transform, target_positions)
8687
positions_mask = np.ones(target_positions.x.shape[0])
@@ -93,7 +94,7 @@ def test_frame_aligned_point_error_perfect_on_global_transform(
9394
pred_frames, target_frames, frames_mask, pred_positions,
9495
target_positions, positions_mask, L1_CLAMP_DISTANCE,
9596
L1_CLAMP_DISTANCE, epsilon=0)
96-
self.assertAlmostEqual(fape, 0.)
97+
self.assertAlmostEqual(fape, 0., places=6)
9798

9899
@parameterized.named_parameters(
99100
('identity',
@@ -120,8 +121,8 @@ def test_frame_aligned_point_error_matches_expected(
120121
pred_frames = target_frames
121122
frames_mask = np.ones(2)
122123

123-
target_positions = r3.vecs_from_tensor(np.array(target_positions))
124-
pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
124+
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
125+
pred_positions = r3.vecs_from_tensor(jax.numpy.array(pred_positions))
125126
positions_mask = np.ones(target_positions.x.shape[0])
126127

127128
alddt = all_atom.frame_aligned_point_error(

alphafold/notebooks/notebook_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ def test_merge_chunked_msa(self):
159159
def test_show_msa_info(self, mocked_stdout):
160160
single_chain_msas = [
161161
parsers.Msa(sequences=['A', 'B', 'C', 'C'],
162-
deletion_matrix=[None] * 4,
162+
deletion_matrix=[[0]] * 4,
163163
descriptions=[''] * 4),
164164
parsers.Msa(sequences=['A', 'A', 'A', 'D'],
165-
deletion_matrix=[None] * 4,
165+
deletion_matrix=[[0]] * 4,
166166
descriptions=[''] * 4)
167167
]
168168
notebook_utils.show_msa_info(

0 commit comments

Comments
 (0)