@@ -77,17 +77,15 @@ def test_sdata_single_image():
7777 np .zeros ((1 , 10 , 10 )), dims = ("c" , "y" , "x" ), transformations = {"data1" : sd .transformations .Identity ()}
7878 )
7979 }
80- sdata = sd .SpatialData (images = images )
81- return sdata
80+ return sd .SpatialData (images = images )
8281
8382
8483@pytest .fixture
8584def test_sdata_single_image_with_label ():
8685 """Creates a simple sdata object."""
8786 images = {"data1" : sd .models .Image2DModel .parse (np .zeros ((1 , 10 , 10 )), dims = ("c" , "y" , "x" ))}
8887 labels = {"label1" : sd .models .Labels2DModel .parse (np .zeros ((10 , 10 )), dims = ("y" , "x" ))}
89- sdata = sd .SpatialData (images = images , labels = labels )
90- return sdata
88+ return sd .SpatialData (images = images , labels = labels )
9189
9290
9391@pytest .fixture
@@ -104,8 +102,7 @@ def test_sdata_multiple_images():
104102 np .zeros ((1 , 10 , 10 )), dims = ("c" , "y" , "x" ), transformations = {"data1" : sd .transformations .Identity ()}
105103 ),
106104 }
107- sdata = sd .SpatialData (images = images )
108- return sdata
105+ return sd .SpatialData (images = images )
109106
110107
111108@pytest .fixture
@@ -141,8 +138,7 @@ def test_sdata_multiple_images_dims():
141138 "data2" : sd .models .Image2DModel .parse (np .zeros ((3 , 10 , 10 )), dims = ("c" , "y" , "x" )),
142139 "data3" : sd .models .Image2DModel .parse (np .zeros ((3 , 10 , 10 )), dims = ("c" , "y" , "x" )),
143140 }
144- sdata = sd .SpatialData (images = images )
145- return sdata
141+ return sd .SpatialData (images = images )
146142
147143
148144@pytest .fixture
@@ -153,8 +149,7 @@ def test_sdata_multiple_images_diverging_dims():
153149 "data2" : sd .models .Image2DModel .parse (np .zeros ((6 , 10 , 10 )), dims = ("c" , "y" , "x" )),
154150 "data3" : sd .models .Image2DModel .parse (np .zeros ((3 , 10 , 10 )), dims = ("c" , "y" , "x" )),
155151 }
156- sdata = sd .SpatialData (images = images )
157- return sdata
152+ return sd .SpatialData (images = images )
158153
159154
160155@pytest .fixture
@@ -226,27 +221,28 @@ def empty_table() -> SpatialData:
226221)
227222def sdata (request ) -> SpatialData :
228223 if request .param == "full" :
229- s = SpatialData (
224+ return SpatialData (
230225 images = _get_images (),
231226 labels = _get_labels (),
232227 shapes = _get_shapes (),
233228 points = _get_points (),
234229 table = _get_table ("sample1" ),
235230 )
236- elif request .param == "empty" :
237- s = SpatialData ()
238- else :
239- s = request .getfixturevalue (request .param )
240- return s
231+ if request .param == "empty" :
232+ return SpatialData ()
233+ return request .getfixturevalue (request .param )
241234
242235
243236def _get_images () -> dict [str , DataArray | DataTree ]:
244- out = {}
245237 dims_2d = ("c" , "y" , "x" )
246238 dims_3d = ("z" , "y" , "x" , "c" )
247- out ["image2d" ] = Image2DModel .parse (
248- get_standard_RNG ().normal (size = (3 , 64 , 64 )), dims = dims_2d , c_coords = ["r" , "g" , "b" ]
249- )
239+ out = {
240+ "image2d" : Image2DModel .parse (
241+ get_standard_RNG ().normal (size = (3 , 64 , 64 )),
242+ dims = dims_2d ,
243+ c_coords = ["r" , "g" , "b" ],
244+ )
245+ }
250246 out ["image2d_multiscale" ] = Image2DModel .parse (
251247 get_standard_RNG ().normal (size = (3 , 64 , 64 )), scale_factors = [2 , 2 ], dims = dims_2d , c_coords = ["r" , "g" , "b" ]
252248 )
@@ -274,11 +270,10 @@ def _get_images() -> dict[str, DataArray | DataTree]:
274270
275271
276272def _get_labels () -> dict [str , DataArray | DataTree ]:
277- out = {}
278273 dims_2d = ("y" , "x" )
279274 dims_3d = ("z" , "y" , "x" )
280275
281- out [ "labels2d" ] = Labels2DModel .parse (get_standard_RNG ().integers (0 , 100 , size = (64 , 64 )), dims = dims_2d )
276+ out = { "labels2d" : Labels2DModel .parse (get_standard_RNG ().integers (0 , 100 , size = (64 , 64 )), dims = dims_2d )}
282277 out ["labels2d_multiscale" ] = Labels2DModel .parse (
283278 get_standard_RNG ().integers (0 , 100 , size = (64 , 64 )), scale_factors = [2 , 4 ], dims = dims_2d
284279 )
@@ -347,8 +342,9 @@ def _get_polygons() -> dict[str, GeoDataFrame]:
347342
348343
349344def _get_shapes () -> dict [str , AnnData ]:
350- out = {}
351345 arr = get_standard_RNG ().normal (size = (100 , 2 ))
346+
347+ out = {}
352348 out ["shapes_0" ] = ShapesModel .parse (arr , shape_type = "Square" , shape_size = 3 )
353349 out ["shapes_1" ] = ShapesModel .parse (arr , shape_type = "Circle" , shape_size = np .repeat (1 , len (arr )))
354350
@@ -411,28 +407,34 @@ def compare(cls, basename: str, tolerance: float | None = None):
411407 ACTUAL .mkdir (parents = True , exist_ok = True )
412408 out_path = ACTUAL / f"{ basename } .png"
413409
414- width , height = 400 , 300 # fixed dimensions so runners don't change
410+ width , height = 400 , 300 # base dimensions; actual PNG may grow/shrink
415411 fig = plt .gcf ()
416412 fig .set_size_inches (width / DPI , height / DPI )
417413 fig .set_dpi (DPI )
418414
419- # Ensure all elements (including colorbars) are visible
420- # Use constrained_layout first (better for colorbars), fallback to tight_layout with padding
421- # Check if constrained_layout is already enabled
415+ # Try to get a reasonable layout first (helps with axes/labels)
422416 if not fig .get_constrained_layout ():
423417 try :
424418 fig .set_constrained_layout (True )
425419 except (ValueError , RuntimeError ):
426- # If constrained_layout fails, use tight_layout with extra padding for colorbars
427420 try :
428421 fig .tight_layout (pad = 2.0 , rect = [0.02 , 0.02 , 0.98 , 0.98 ])
429422 except (ValueError , RuntimeError ):
430- # Last resort: use subplots_adjust to add padding
431423 fig .subplots_adjust (left = 0.1 , right = 0.9 , top = 0.9 , bottom = 0.1 )
432- plt .figure (fig .number ) # Ensure this figure is current
433424
434- plt .savefig (out_path , dpi = DPI )
435- plt .close ()
425+ plt .figure (fig .number ) # ensure this figure is current
426+
427+ # Force a draw so that tight bbox "sees" all artists (including colorbars)
428+ fig .canvas .draw ()
429+
430+ # Let matplotlib adjust the output size so that all artists are included
431+ fig .savefig (
432+ out_path ,
433+ dpi = DPI ,
434+ bbox_inches = "tight" ,
435+ pad_inches = 0.02 , # small margin around everything
436+ )
437+ plt .close (fig )
436438
437439 if tolerance is None :
438440 # see https://github.com/scverse/squidpy/pull/302
0 commit comments