-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathfunctional.py
349 lines (287 loc) · 16.8 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, Mapping, Sequence
import torch
from monai.apps.utils import get_logger
from monai.config import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta, MetaObj
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.utils import (
affine_from_pending,
combine_transforms,
is_compatible_apply_kwargs,
kwargs_from_pending,
resample,
)
from monai.transforms.traits import LazyTrait
from monai.transforms.transform import MapTransform
from monai.utils import LazyAttr, MetaKeys, convert_to_tensor, look_up_option
__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending", "apply_to_geometry"]
__override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"}
def _log_pending_info(
transform: Any,
data: Any,
activity: str,
*,
lazy: bool | None = None,
key: str | None = None,
logger_name: bool | str = False,
):
if logger_name is False:
return
logger_name = logger_name if isinstance(logger_name, str) else "apply_pending_transforms"
logger = get_logger(logger_name)
tcname = type(transform).__name__
if isinstance(transform, LazyTrait):
tlazy = f", transform.lazy: {transform.lazy}"
if lazy is not None and lazy != transform.lazy:
tlazy += " (overridden)"
else:
tlazy = ", transform is not lazy"
msg = f"{activity} - lazy: {lazy}, {{key_msg}}pending: {{pcount}}, upcoming '{tcname}'{tlazy}"
if isinstance(transform, MapTransform):
transform_keys = transform.keys if key is None else (key,)
for k in transform_keys:
if k in data:
pcount = len(data[k].pending_operations) if isinstance(data[k], MetaTensor) else 0
logger.info(msg.format(pcount=pcount, key_msg=f"key: '{k}', "))
else:
pcount = len(data.pending_operations) if isinstance(data, MetaTensor) else 0
logger.info(msg.format(pcount=pcount, key_msg="" if key is None else f"key: '{key}', "))
def _log_applied_info(data: Any, key=None, logger_name: bool | str = False):
if logger_name is False:
return
logger_name = logger_name if isinstance(logger_name, str) else "apply_pending_transforms"
logger = get_logger(logger_name)
key_str = "" if key is None else f"key: '{key}', "
logger.info(f"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}")
def apply_pending_transforms(
    data: NdarrayOrTensor | Sequence[Any | NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
    keys: tuple | None,
    overrides: dict | None = None,
    logger_name: bool | str = False,
):
    """
    Execute any pending transforms found on ``data``, which may be a tensor, a sequence of
    tensors, or a dictionary whose values include tensors.

    For dictionary input, ``keys`` restricts which entries are examined; when ``keys`` is None
    every entry is considered. Lists and tuples are processed element-wise, preserving their type.

    An optional ``overrides`` mapping can adjust resampling parameters when the pending
    transforms are executed (see ``Compose`` for details); for dictionary data it should hold a
    per-key dictionary of overrides. ``logger_name`` optionally names a logger used for telemetry
    while the pending transforms run.

    This function exists primarily for ``execute_compose`` and other pipeline-execution
    machinery; you should not normally need to call it directly.

    Args:
        data: a ``torch.Tensor`` or ``MetaTensor``, or a sequence or dictionary of tensors.
        keys: an optional tuple of keys used to filter the entries of ``data`` when it is a dict.
        overrides: optional transform-argument overrides; a per-key dict when ``data`` is a dict.
        logger_name: optional logger name for telemetry; ``False`` suppresses logging.

    Returns:
        an object of the same type as ``data`` with pending transforms applied, or ``data``
        itself when there was nothing to apply.
    """
    if isinstance(data, list):
        return [apply_pending_transforms(item, keys, overrides, logger_name) for item in data]
    if isinstance(data, tuple):
        return tuple(apply_pending_transforms(item, keys, overrides, logger_name) for item in data)

    if isinstance(data, dict):
        # select the requested keys (all of them when 'keys' is None), keeping only entries
        # that are MetaTensors with outstanding pending operations
        pending_keys = [
            k
            for k in data
            if (keys is None or k in keys) and isinstance(data[k], MetaTensor) and data[k].has_pending_operations
        ]
        if not pending_keys:
            return data
        result = dict(data)  # shallow copy so the caller's dict is left untouched
        for k in pending_keys:
            per_key_overrides = None if overrides is None else overrides.get(k, None)
            result[k], _ = apply_pending(data[k], overrides=per_key_overrides)
            _log_applied_info(result[k], key=k, logger_name=logger_name)
        return result

    if isinstance(data, MetaTensor) and data.has_pending_operations:
        applied, _ = apply_pending(data, overrides=overrides)
        _log_applied_info(applied, logger_name=logger_name)
        return applied
    return data
def apply_pending_transforms_in_order(
    transform, data, lazy: bool | None = None, overrides: dict | None = None, logger_name: bool | str = False
):
    """
    Perform "in order" processing of pending transforms.

    "In order" processing guarantees that every pending transform has been applied to a tensor
    before a non-lazy transform (or a lazy transform executing non-lazily) runs, so that nothing
    is appended to a metatensor's applied operations while pending operations are still
    outstanding. Note that lazy resampling currently has a single execution mechanism, but this
    is expected to change in future releases.

    Pending transforms are evaluated when any of the following holds:

    * the transform is lazy and either checks its input data during execution or is not
      executing lazily
    * the transform is an ``ApplyPending[d]`` transform
    * the transform is not a lazy transform

    This function is intended solely for implementing lazy resampling; its API may change
    between releases without warning. See the :ref:`Lazy Resampling topic<lazy_resampling>` for
    more information about lazy resampling.

    Args:
        transform: the transform used to decide whether pending transforms must be applied now.
        data: a tensor / MetaTensor, or dictionary of tensors / MetaTensors, that may carry
            pending transforms.
        lazy: the lazy mode being applied (``False``, ``True`` or ``None``).
        overrides: optional overrides applied when pending transforms execute; a per-key dict
            when ``data`` is a dict.
        logger_name: optional logger name for telemetry; ``False`` suppresses logging.

    Returns:
        an object of the same type as ``data`` with pending transforms applied, or ``data``
        itself when they were not.
    """
    from monai.transforms.lazy.dictionary import ApplyPendingd

    keys = transform.keys if isinstance(transform, ApplyPendingd) else None
    apply_now = True
    if isinstance(transform, LazyTrait) and not transform.requires_current_data:
        # a lazy transform that never inspects its data lets pending ops keep accumulating
        effective_lazy = transform.lazy if lazy is None else lazy
        apply_now = not effective_lazy
    if not apply_now:
        _log_pending_info(transform, data, "Accumulate pending transforms", lazy=lazy, logger_name=logger_name)
        return data
    _log_pending_info(transform, data, "Apply pending transforms", lazy=lazy, logger_name=logger_name)
    return apply_pending_transforms(data, keys, overrides, logger_name)
def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None):
    """
    This method applies pending transforms to `data` tensors.
    Currently, only 2d and 3d inputs are supported.

    This method is designed to be called by ``apply_pending_transforms`` and other methods / classes
    that are part of the implementation of lazy resampling. In general, you should not need to call
    this method unless you are directly developing custom lazy execution strategies.

    It works by calculating the overall effect of the accumulated pending transforms. When it runs
    out of pending transforms or when it finds incompatibilities between the accumulated pending
    transform and the next pending transform, it then applies the accumulated transform in a call to
    ``resample``.

    Pending transforms are incompatible with each other if one or more of the arguments in the pending
    transforms differ. These are parameters such as 'mode', 'padding_mode', 'dtype' and so forth. If
    a pending transform doesn't have a given parameter, it is considered compatible with the
    accumulated transform. If a subsequent transform has a parameter that is incompatible with
    the accumulated transform (e.g. 'mode' of 'bilinear' vs. 'mode' of 'nearest'), an intermediate
    resample will be performed and the accumulated transform reset to its starting state.

    After resampling, the pending transforms are pushed to the ``applied_transforms`` field of the
    resulting MetaTensor. Note, if a torch.tensor is passed to this method along with a list of
    pending transforms, the resampled tensor will be wrapped in a MetaTensor before being returned.

    Args:
        data: A torch Tensor or a monai MetaTensor.
        pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
        overrides: a dictionary of overrides for the transform arguments. The keys must be one of:

            - mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order ``0-5`` (integers).
              Interpolation mode to calculate output values. Defaults to None.
              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
              When it's `an integer`, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
              and the value represents the order of the spline interpolation.
              See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
              Padding mode for outside grid values. Defaults to None.
              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
              When `mode` is an integer, using numpy/cupy backends, this argument accepts
              {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
              See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            - dtype: data type for resampling computation. Defaults to ``float64``.
              If ``None``, use the data type of input data, this option may not be compatible the resampling backend.
            - align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using
              the PyTorch resampling backend. Defaults to ``False``.
              See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            - device: device for resampling computation. Defaults to ``None``.
            - resample_mode: the mode of resampling, currently support ``"auto"``. Setting to other values will use the
              :py:class:`monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad).

    Returns:
        a two-element tuple of the resampled data and the list of pending transforms that were applied.
    """
    overrides = (overrides or {}).copy()
    for k in overrides:
        look_up_option(k, __override_keywords)  # check existence of the key
    if isinstance(data, MetaTensor) and pending is None:
        # consume the tensor's own pending list so the operations cannot be applied twice
        pending = data.pending_operations.copy()
        data.clear_pending_operations()
    pending = [] if pending is None else pending
    if not pending:
        return data, []
    cumulative_xform = affine_from_pending(pending[0])
    if cumulative_xform.shape[0] == 3:
        # promote a 2d (3x3) affine to 3d (4x4) so all accumulation is done in a single rank
        cumulative_xform = to_affine_nd(3, cumulative_xform)
    cur_kwargs = kwargs_from_pending(pending[0])
    override_kwargs: dict[str, Any] = {}
    # translate user-facing override names into the LazyAttr keys the resampler consumes
    if "mode" in overrides:
        override_kwargs[LazyAttr.INTERP_MODE] = overrides["mode"]
    if "padding_mode" in overrides:
        override_kwargs[LazyAttr.PADDING_MODE] = overrides["padding_mode"]
    if "align_corners" in overrides:
        override_kwargs[LazyAttr.ALIGN_CORNERS] = overrides["align_corners"]
    if "resample_mode" in overrides:
        override_kwargs[LazyAttr.RESAMPLE_MODE] = overrides["resample_mode"]
    override_dtype = overrides.get("dtype", torch.float64)
    override_kwargs[LazyAttr.DTYPE] = data.dtype if override_dtype is None else override_dtype
    device = overrides.get("device")
    for p in pending[1:]:
        new_kwargs = kwargs_from_pending(p)
        if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
            # carry out an intermediate resample here due to incompatibility between arguments
            # NOTE(review): cumulative_xform is not reset after this intermediate resample,
            # although the docstring says the accumulated transform is "reset to its starting
            # state" — verify the intended behavior.
            _cur_kwargs = cur_kwargs.copy()
            _cur_kwargs.update(override_kwargs)
            data = resample(data.to(device), cumulative_xform, _cur_kwargs)
        next_matrix = affine_from_pending(p)
        if next_matrix.shape[0] == 3:
            next_matrix = to_affine_nd(3, next_matrix)
        cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
        cur_kwargs.update(new_kwargs)
    # final resample applying whatever transform remains accumulated
    cur_kwargs.update(override_kwargs)
    data = resample(data.to(device), cumulative_xform, cur_kwargs)
    if isinstance(data, MetaTensor):
        # record the executed pending operations on the result's applied-operations list
        for p in pending:
            data.push_applied_operation(p)
    return data, pending
def apply_to_geometry(
    data: torch.Tensor,
    meta_info: dict | MetaObj | None = None,
    transform: torch.Tensor | None = None,
):
    """
    Apply an affine geometric transform or deformation field to geometry.
    At present this is limited to the transformation of points.

    The points must be provided as a tensor and must be compatible with a homogeneous
    transform. This means that:

    - 2D points are of the form (x, y, 1)
    - 3D points are of the form (x, y, z, 1)

    The affine transform or deformation field is applied to the points and a tensor of
    the same shape as the input tensor is returned. Exactly one of ``meta_info`` and
    ``transform`` must be provided.

    Args:
        data: the tensor of points to be transformed.
        meta_info: the metadata containing the affine transformation; the matrix is read from
            ``meta_info.meta[MetaKeys.AFFINE]``. Mutually exclusive with ``transform``.
        transform: an explicit homogeneous affine matrix to apply. Mutually exclusive with
            ``meta_info``.

    Returns:
        a tensor of transformed points with the same shape as ``data``.

    Raises:
        ValueError: if neither or both of ``meta_info`` and ``transform`` are given, or if the
            final dimension of ``data`` is incompatible with the transform matrix.
        TypeError: if ``data`` is not a ``torch.Tensor`` or ``MetaTensor``.
    """
    if meta_info is None and transform is None:
        raise ValueError("either meta_info or transform must be provided")
    if meta_info is not None and transform is not None:
        raise ValueError("only one of meta_info or transform can be provided")
    if not isinstance(data, (torch.Tensor, MetaTensor)):
        raise TypeError(f"data {type(data)} must be a torch.Tensor or MetaTensor")
    data = convert_to_tensor(data, track_meta=get_track_meta())

    transform_ = meta_info.meta[MetaKeys.AFFINE] if meta_info is not None else transform
    if transform_.dtype != data.dtype:
        transform_ = transform_.to(data.dtype)
    if data.shape[-1] == 3 and transform_.shape[0] == 4:
        # 2D homogeneous points with a 3D (4x4) affine: reduce the matrix to 3x3 by folding the
        # translation column and homogeneous row inward, dropping the z row/column.
        # Clone first: the in-place row/column writes below must not mutate the caller's matrix
        # (when dtypes already match, transform_ still aliases the input / meta affine).
        transform_ = transform_.clone()
        transform_[2, 0:2] = transform_[3, 0:2]
        transform_[2, 2] = transform_[3, 3]
        transform_[0:2, 2] = transform_[0:2, 3]
        transform_ = transform_[:-1, :-1]
    if data.shape[-1] != transform_.shape[0]:
        raise ValueError(f"final element of data.shape {data.shape} must match transform shape {transform_.shape}")
    # row-vector points: p' = p @ A^T applies the homogeneous affine to each point
    result = torch.matmul(data, transform_.T)
    return result