Skip to content

Commit fa096da

Browse files
authored
specify n_channels (#307)
* specify n_channels * Add n_channels to multiscale * updated submodule * add notebooks gitignore
1 parent b2c43b9 commit fa096da

2 files changed

Lines changed: 29 additions & 7 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,6 @@ spatialdata-sandbox
4545

4646
# version file
4747
_version.py
48+
49+
# prevent submodule notebooks getting out of sync
50+
docs/tutorials/notebooks

src/spatialdata/datasets.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828

2929

3030
def blobs(
31-
length: int = 512, n_points: int = 200, n_shapes: int = 5, extra_coord_system: Optional[str] = None
31+
length: int = 512,
32+
n_points: int = 200,
33+
n_shapes: int = 5,
34+
extra_coord_system: Optional[str] = None,
35+
n_channels: int = 3,
3236
) -> SpatialData:
3337
"""
3438
Blobs dataset.
@@ -43,7 +47,9 @@ def blobs(
4347
Number of max shapes to generate.
4448
At most, as if overlapping they will be discarded
4549
extra_coord_system
46-
Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
50+
Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
51+
n_channels
52+
Number of channels of the image
4753
4854
4955
Returns
@@ -52,7 +58,11 @@ def blobs(
5258
SpatialData object with blobs dataset.
5359
"""
5460
return BlobsDataset(
55-
length=length, n_points=n_points, n_shapes=n_shapes, extra_coord_system=extra_coord_system
61+
length=length,
62+
n_points=n_points,
63+
n_shapes=n_shapes,
64+
extra_coord_system=extra_coord_system,
65+
n_channels=n_channels,
5666
).blobs()
5767

5868

@@ -84,7 +94,12 @@ class BlobsDataset:
8494
"""Blobs dataset."""
8595

8696
def __init__(
87-
self, length: int = 512, n_points: int = 200, n_shapes: int = 5, extra_coord_system: Optional[str] = None
97+
self,
98+
length: int = 512,
99+
n_points: int = 200,
100+
n_shapes: int = 5,
101+
extra_coord_system: Optional[str] = None,
102+
n_channels: int = 3,
88103
) -> None:
89104
"""
90105
Blobs dataset.
@@ -100,20 +115,23 @@ def __init__(
100115
At most, as if overlapping they will be discarded
101116
extra_coord_system
102117
Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
118+
n_channels
119+
Number of channels of the image
103120
"""
104121
self.length = length
105122
self.n_points = n_points
106123
self.n_shapes = n_shapes
107124
self.transformations = {"global": Identity()}
125+
self.n_channels = n_channels
108126
if extra_coord_system:
109127
self.transformations[extra_coord_system] = Identity()
110128

111129
def blobs(
112130
self,
113131
) -> SpatialData:
114132
"""Blobs dataset."""
115-
image = self._image_blobs(self.transformations, self.length)
116-
multiscale_image = self._image_blobs(self.transformations, self.length, multiscale=True)
133+
image = self._image_blobs(self.transformations, self.length, self.n_channels)
134+
multiscale_image = self._image_blobs(self.transformations, self.length, self.n_channels, multiscale=True)
117135
labels = self._labels_blobs(self.transformations, self.length)
118136
multiscale_labels = self._labels_blobs(self.transformations, self.length, multiscale=True)
119137
points = self._points_blobs(self.transformations, self.length, self.n_points)
@@ -138,10 +156,11 @@ def _image_blobs(
138156
self,
139157
transformations: Optional[dict[str, Any]] = None,
140158
length: int = 512,
159+
n_channels: int = 3,
141160
multiscale: bool = False,
142161
) -> Union[SpatialImage, MultiscaleSpatialImage]:
143162
masks = []
144-
for i in range(3):
163+
for i in range(n_channels):
145164
mask = self._generate_blobs(length=length, seed=i)
146165
mask = (mask - mask.min()) / mask.ptp()
147166
masks.append(mask)

0 commit comments

Comments
 (0)