2828
2929
3030def blobs (
31- length : int = 512 , n_points : int = 200 , n_shapes : int = 5 , extra_coord_system : Optional [str ] = None
31+ length : int = 512 ,
32+ n_points : int = 200 ,
33+ n_shapes : int = 5 ,
34+ extra_coord_system : Optional [str ] = None ,
35+ n_channels : int = 3 ,
3236) -> SpatialData :
3337 """
3438 Blobs dataset.
@@ -43,7 +47,9 @@ def blobs(
4347 Number of max shapes to generate.
4448 At most, as if overlapping they will be discarded
4549 extra_coord_system
46- Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
50+ Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
51+ n_channels
52+ Number of channels of the image
4753
4854
4955 Returns
@@ -52,7 +58,11 @@ def blobs(
5258 SpatialData object with blobs dataset.
5359 """
5460 return BlobsDataset (
55- length = length , n_points = n_points , n_shapes = n_shapes , extra_coord_system = extra_coord_system
61+ length = length ,
62+ n_points = n_points ,
63+ n_shapes = n_shapes ,
64+ extra_coord_system = extra_coord_system ,
65+ n_channels = n_channels ,
5666 ).blobs ()
5767
5868
@@ -84,7 +94,12 @@ class BlobsDataset:
8494 """Blobs dataset."""
8595
8696 def __init__ (
87- self , length : int = 512 , n_points : int = 200 , n_shapes : int = 5 , extra_coord_system : Optional [str ] = None
97+ self ,
98+ length : int = 512 ,
99+ n_points : int = 200 ,
100+ n_shapes : int = 5 ,
101+ extra_coord_system : Optional [str ] = None ,
102+ n_channels : int = 3 ,
88103 ) -> None :
89104 """
90105 Blobs dataset.
@@ -100,20 +115,23 @@ def __init__(
100115 At most, as if overlapping they will be discarded
101116 extra_coord_system
102117 Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
118+ n_channels
119+ Number of channels of the image
103120 """
104121 self .length = length
105122 self .n_points = n_points
106123 self .n_shapes = n_shapes
107124 self .transformations = {"global" : Identity ()}
125+ self .n_channels = n_channels
108126 if extra_coord_system :
109127 self .transformations [extra_coord_system ] = Identity ()
110128
111129 def blobs (
112130 self ,
113131 ) -> SpatialData :
114132 """Blobs dataset."""
115- image = self ._image_blobs (self .transformations , self .length )
116- multiscale_image = self ._image_blobs (self .transformations , self .length , multiscale = True )
133+ image = self ._image_blobs (self .transformations , self .length , self . n_channels )
134+ multiscale_image = self ._image_blobs (self .transformations , self .length , self . n_channels , multiscale = True )
117135 labels = self ._labels_blobs (self .transformations , self .length )
118136 multiscale_labels = self ._labels_blobs (self .transformations , self .length , multiscale = True )
119137 points = self ._points_blobs (self .transformations , self .length , self .n_points )
@@ -138,10 +156,11 @@ def _image_blobs(
138156 self ,
139157 transformations : Optional [dict [str , Any ]] = None ,
140158 length : int = 512 ,
159+ n_channels : int = 3 ,
141160 multiscale : bool = False ,
142161 ) -> Union [SpatialImage , MultiscaleSpatialImage ]:
143162 masks = []
144- for i in range (3 ):
163+ for i in range (n_channels ):
145164 mask = self ._generate_blobs (length = length , seed = i )
146165 mask = (mask - mask .min ()) / mask .ptp ()
147166 masks .append (mask )
0 commit comments