Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def _spatial_norm_callback(
spatial = TaggedArray._extract_data(adata, attr=attr, key=key)

logger.info(f"Normalizing spatial coordinates of `{term}`.")
spatial = (spatial - spatial.mean()) / spatial.std()
spatial = (spatial - spatial.mean(axis=0)) / spatial.std()
return TaggedArray(spatial, tag=Tag.POINT_CLOUD)

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def test_prepare_sequential(
ap = ap.prepare(batch_key="batch", joint_attr=joint_attr, normalize_spatial=normalize_spatial)
assert len(ap) == 2
if normalize_spatial:
np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].x.data_src.std(), atol=1e-15)
np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), 1.0, atol=1e-15)
np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].y.data_src.std(), atol=1e-15)
np.testing.assert_allclose(ap[("1", "2")].x.data_src.mean(), 0, atol=1e-15)
np.testing.assert_allclose(ap[("0", "1")].x.data_src.mean(), 0, atol=1e-15)

Expand Down