Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
NumpyTransform,
OneHot,
Rename,
Reshape,
SerializableCustomTransform,
Squeeze,
Sqrt,
Expand Down Expand Up @@ -746,6 +747,25 @@ def rename(self, from_key: str, to_key: str):
self.transforms.append(Rename(from_key, to_key))
return self

def reshape(self, keys: str | Sequence[str], *, to: int | Sequence[int]):
"""Append a :py:class:`~transforms.Reshape` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
Variables that should be reshaped
to : int or tuple of int
Target shape of the variables
"""
from .transforms import Reshape

if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: Reshape(shape=to) for key in keys})
self.transforms.append(transform)
return self

def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
from .transforms import Scale

Expand Down
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .numpy_transform import NumpyTransform
from .one_hot import OneHot
from .rename import Rename
from .reshape import Reshape
from .scale import Scale
from .serializable_custom_transform import SerializableCustomTransform
from .shift import Shift
Expand Down
28 changes: 28 additions & 0 deletions bayesflow/adapters/transforms/reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np

from collections.abc import Sequence
from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable("bayesflow.adapters")
class Reshape(ElementwiseTransform):

def __init__(self, shape: int | Sequence[int]):
super().__init__()

if isinstance(shape, Sequence):
shape = tuple(shape)
self.shape = shape

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.reshape(data, self.shape)


def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.reshape(data, self.shape)


def get_config(self) -> dict:
return {"shape": self.shape}
Loading
Loading