Skip to content

Commit 7d398e4

Browse files
committed
modified plotmetatester
1 parent 60fe9ba commit 7d398e4

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

tests/conftest.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,15 @@ def test_sdata_single_image():
7777
np.zeros((1, 10, 10)), dims=("c", "y", "x"), transformations={"data1": sd.transformations.Identity()}
7878
)
7979
}
80-
sdata = sd.SpatialData(images=images)
81-
return sdata
80+
return sd.SpatialData(images=images)
8281

8382

8483
@pytest.fixture
8584
def test_sdata_single_image_with_label():
8685
"""Creates a simple sdata object."""
8786
images = {"data1": sd.models.Image2DModel.parse(np.zeros((1, 10, 10)), dims=("c", "y", "x"))}
8887
labels = {"label1": sd.models.Labels2DModel.parse(np.zeros((10, 10)), dims=("y", "x"))}
89-
sdata = sd.SpatialData(images=images, labels=labels)
90-
return sdata
88+
return sd.SpatialData(images=images, labels=labels)
9189

9290

9391
@pytest.fixture
@@ -104,8 +102,7 @@ def test_sdata_multiple_images():
104102
np.zeros((1, 10, 10)), dims=("c", "y", "x"), transformations={"data1": sd.transformations.Identity()}
105103
),
106104
}
107-
sdata = sd.SpatialData(images=images)
108-
return sdata
105+
return sd.SpatialData(images=images)
109106

110107

111108
@pytest.fixture
@@ -141,8 +138,7 @@ def test_sdata_multiple_images_dims():
141138
"data2": sd.models.Image2DModel.parse(np.zeros((3, 10, 10)), dims=("c", "y", "x")),
142139
"data3": sd.models.Image2DModel.parse(np.zeros((3, 10, 10)), dims=("c", "y", "x")),
143140
}
144-
sdata = sd.SpatialData(images=images)
145-
return sdata
141+
return sd.SpatialData(images=images)
146142

147143

148144
@pytest.fixture
@@ -153,8 +149,7 @@ def test_sdata_multiple_images_diverging_dims():
153149
"data2": sd.models.Image2DModel.parse(np.zeros((6, 10, 10)), dims=("c", "y", "x")),
154150
"data3": sd.models.Image2DModel.parse(np.zeros((3, 10, 10)), dims=("c", "y", "x")),
155151
}
156-
sdata = sd.SpatialData(images=images)
157-
return sdata
152+
return sd.SpatialData(images=images)
158153

159154

160155
@pytest.fixture
@@ -226,27 +221,28 @@ def empty_table() -> SpatialData:
226221
)
227222
def sdata(request) -> SpatialData:
228223
if request.param == "full":
229-
s = SpatialData(
224+
return SpatialData(
230225
images=_get_images(),
231226
labels=_get_labels(),
232227
shapes=_get_shapes(),
233228
points=_get_points(),
234229
table=_get_table("sample1"),
235230
)
236-
elif request.param == "empty":
237-
s = SpatialData()
238-
else:
239-
s = request.getfixturevalue(request.param)
240-
return s
231+
if request.param == "empty":
232+
return SpatialData()
233+
return request.getfixturevalue(request.param)
241234

242235

243236
def _get_images() -> dict[str, DataArray | DataTree]:
244-
out = {}
245237
dims_2d = ("c", "y", "x")
246238
dims_3d = ("z", "y", "x", "c")
247-
out["image2d"] = Image2DModel.parse(
248-
get_standard_RNG().normal(size=(3, 64, 64)), dims=dims_2d, c_coords=["r", "g", "b"]
249-
)
239+
out = {
240+
"image2d": Image2DModel.parse(
241+
get_standard_RNG().normal(size=(3, 64, 64)),
242+
dims=dims_2d,
243+
c_coords=["r", "g", "b"],
244+
)
245+
}
250246
out["image2d_multiscale"] = Image2DModel.parse(
251247
get_standard_RNG().normal(size=(3, 64, 64)), scale_factors=[2, 2], dims=dims_2d, c_coords=["r", "g", "b"]
252248
)
@@ -274,11 +270,10 @@ def _get_images() -> dict[str, DataArray | DataTree]:
274270

275271

276272
def _get_labels() -> dict[str, DataArray | DataTree]:
277-
out = {}
278273
dims_2d = ("y", "x")
279274
dims_3d = ("z", "y", "x")
280275

281-
out["labels2d"] = Labels2DModel.parse(get_standard_RNG().integers(0, 100, size=(64, 64)), dims=dims_2d)
276+
out = {"labels2d": Labels2DModel.parse(get_standard_RNG().integers(0, 100, size=(64, 64)), dims=dims_2d)}
282277
out["labels2d_multiscale"] = Labels2DModel.parse(
283278
get_standard_RNG().integers(0, 100, size=(64, 64)), scale_factors=[2, 4], dims=dims_2d
284279
)
@@ -347,8 +342,9 @@ def _get_polygons() -> dict[str, GeoDataFrame]:
347342

348343

349344
def _get_shapes() -> dict[str, AnnData]:
350-
out = {}
351345
arr = get_standard_RNG().normal(size=(100, 2))
346+
347+
out = {}
352348
out["shapes_0"] = ShapesModel.parse(arr, shape_type="Square", shape_size=3)
353349
out["shapes_1"] = ShapesModel.parse(arr, shape_type="Circle", shape_size=np.repeat(1, len(arr)))
354350

@@ -411,28 +407,34 @@ def compare(cls, basename: str, tolerance: float | None = None):
411407
ACTUAL.mkdir(parents=True, exist_ok=True)
412408
out_path = ACTUAL / f"{basename}.png"
413409

414-
width, height = 400, 300 # fixed dimensions so runners don't change
410+
width, height = 400, 300 # base dimensions; actual PNG may grow/shrink
415411
fig = plt.gcf()
416412
fig.set_size_inches(width / DPI, height / DPI)
417413
fig.set_dpi(DPI)
418414

419-
# Ensure all elements (including colorbars) are visible
420-
# Use constrained_layout first (better for colorbars), fallback to tight_layout with padding
421-
# Check if constrained_layout is already enabled
415+
# Try to get a reasonable layout first (helps with axes/labels)
422416
if not fig.get_constrained_layout():
423417
try:
424418
fig.set_constrained_layout(True)
425419
except (ValueError, RuntimeError):
426-
# If constrained_layout fails, use tight_layout with extra padding for colorbars
427420
try:
428421
fig.tight_layout(pad=2.0, rect=[0.02, 0.02, 0.98, 0.98])
429422
except (ValueError, RuntimeError):
430-
# Last resort: use subplots_adjust to add padding
431423
fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
432-
plt.figure(fig.number) # Ensure this figure is current
433424

434-
plt.savefig(out_path, dpi=DPI)
435-
plt.close()
425+
plt.figure(fig.number) # ensure this figure is current
426+
427+
# Force a draw so that tight bbox "sees" all artists (including colorbars)
428+
fig.canvas.draw()
429+
430+
# Let matplotlib adjust the output size so that all artists are included
431+
fig.savefig(
432+
out_path,
433+
dpi=DPI,
434+
bbox_inches="tight",
435+
pad_inches=0.02, # small margin around everything
436+
)
437+
plt.close(fig)
436438

437439
if tolerance is None:
438440
# see https://github.com/scverse/squidpy/pull/302

0 commit comments

Comments
 (0)