1818from absl .testing import parameterized
1919from alphafold .model import all_atom
2020from alphafold .model import r3
21+ import jax
2122import numpy as np
2223
2324L1_CLAMP_DISTANCE = 10
@@ -80,7 +81,7 @@ def test_frame_aligned_point_error_perfect_on_global_transform(
8081 global_rigid_transform = get_global_rigid_transform (
8182 rot_angle , translation , 1 )
8283
83- target_positions = r3 .vecs_from_tensor (target_positions )
84+ target_positions = r3 .vecs_from_tensor (jax . numpy . array ( target_positions ) )
8485 pred_positions = r3 .rigids_mul_vecs (
8586 global_rigid_transform , target_positions )
8687 positions_mask = np .ones (target_positions .x .shape [0 ])
@@ -93,7 +94,7 @@ def test_frame_aligned_point_error_perfect_on_global_transform(
9394 pred_frames , target_frames , frames_mask , pred_positions ,
9495 target_positions , positions_mask , L1_CLAMP_DISTANCE ,
9596 L1_CLAMP_DISTANCE , epsilon = 0 )
96- self .assertAlmostEqual (fape , 0. )
97+ self .assertAlmostEqual (fape , 0. , places = 6 )
9798
9899 @parameterized .named_parameters (
99100 ('identity' ,
@@ -120,8 +121,8 @@ def test_frame_aligned_point_error_matches_expected(
120121 pred_frames = target_frames
121122 frames_mask = np .ones (2 )
122123
123- target_positions = r3 .vecs_from_tensor (np .array (target_positions ))
124- pred_positions = r3 .vecs_from_tensor (np .array (pred_positions ))
124+ target_positions = r3 .vecs_from_tensor (jax . numpy .array (target_positions ))
125+ pred_positions = r3 .vecs_from_tensor (jax . numpy .array (pred_positions ))
125126 positions_mask = np .ones (target_positions .x .shape [0 ])
126127
127128 alddt = all_atom .frame_aligned_point_error (
0 commit comments