diff --git a/.github/workflows/docs-gh-pages.yml b/.github/workflows/docs-gh-pages.yml index eff0dbf..fa1940d 100644 --- a/.github/workflows/docs-gh-pages.yml +++ b/.github/workflows/docs-gh-pages.yml @@ -22,7 +22,6 @@ concurrency: jobs: build-docs: - if: github.repository == 'boschresearch/torchphysics' runs-on: [ubuntu-latest] container: python:3.10-bookworm steps: @@ -42,7 +41,6 @@ jobs: # Deployment job deploy-docs: - if: github.repository == 'boschresearch/torchphysics' environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} diff --git a/.gitignore b/.gitignore index 7587e5c..8688b9b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ __pycache__/* **/bosch/** **/experiments/** **/fluid_logs/** +**/logs/** # Project files .ropeproject .project diff --git a/src/torchphysics/models/FNO.py b/src/torchphysics/models/FNO.py index 5788921..53b0682 100644 --- a/src/torchphysics/models/FNO.py +++ b/src/torchphysics/models/FNO.py @@ -3,6 +3,7 @@ from .model import Model from ..problem.spaces import Points +from .embedding_layers import PositionalEmbedding class _Permute(nn.Module): def __init__(self, permute_dims): @@ -152,6 +153,9 @@ class FNO(Model): The network that transforms the hidden channel dimension to the output channel dimension. (The mapping Q in [1], Figure 2) Default is a linear mapping. + positional_embedding : torchphysics.models.PositionalEmbedding or bool + An additional embedding layer, which adds positional information to the input. + Default is True, adding an embedding whose dimension is inferred from the Fourier modes. xavier_gains : int or list, tuple For the weight initialization a Xavier/Glorot algorithm will be used. The gain can be specified over this value. 
@@ -181,6 +185,7 @@ def __init__(self, input_space, output_space, fourier_layers : int, hidden_channels : int = 16, fourier_modes = 16, activations=torch.nn.GELU(), skip_connections = False, linear_connections = True, bias = True, channel_up_sample_network = None, channel_down_sample_network = None, + positional_embedding=True, xavier_gains=5.0/3, space_resolution = None): super().__init__(input_space, output_space) @@ -205,17 +210,32 @@ def __init__(self, input_space, output_space, fourier_layers : int, in_channels = self.input_space.dim out_channels = self.output_space.dim + if positional_embedding is True: + dim = 1 if isinstance(fourier_modes[0], int) else len(fourier_modes[0]) + positional_embedding = PositionalEmbedding(dim) + elif positional_embedding is False: + positional_embedding = None + if not channel_up_sample_network: + if positional_embedding is not None: + in_channels += positional_embedding.dim self.channel_up_sampling = nn.Linear(in_channels, - hidden_channels, - bias=True) + hidden_channels) else: + if positional_embedding is not None: + print("Note: Positional embedding is used, make sure that the network for " \ + f"lifiting (channel up sampling) expects inputs of size {in_channels+positional_embedding.dim}.") self.channel_up_sampling = channel_up_sample_network + # combine embedding with up_sampling layer: + if positional_embedding is not None: + self.channel_up_sampling = nn.Sequential( + positional_embedding, self.channel_up_sampling + ) + if not channel_down_sample_network: self.channel_down_sampling = nn.Linear(hidden_channels, - out_channels, - bias=True) + out_channels) else: self.channel_down_sampling = channel_down_sample_network diff --git a/src/torchphysics/models/__init__.py b/src/torchphysics/models/__init__.py index 891c7da..b3f95b6 100644 --- a/src/torchphysics/models/__init__.py +++ b/src/torchphysics/models/__init__.py @@ -26,6 +26,6 @@ # FNO: from .FNO import FNO, _FourierLayer - +from .embedding_layers import 
PositionalEmbedding # PCA: from .PCANN import PCANN, PCANN_FC \ No newline at end of file diff --git a/src/torchphysics/models/deeponet/deeponet.py b/src/torchphysics/models/deeponet/deeponet.py index c26285c..9f32b25 100644 --- a/src/torchphysics/models/deeponet/deeponet.py +++ b/src/torchphysics/models/deeponet/deeponet.py @@ -4,7 +4,7 @@ from ..model import Model, Sequential from .branchnets import BranchNet from .trunknets import TrunkNet - +from ...utils.user_fun import UserFunction class DeepONet(Model): """Implementation of the architecture used in the DeepONet paper [#]_. @@ -28,7 +28,10 @@ class DeepONet(Model): For higher dimensional outputs, will be multiplied my the dimension of the output space, so each dimension will have the same number of intermediate neurons. - + constrain_fn : callable + A constrain function that can be used to constrain the network + outputs. + Notes ----- The number of output neurons in the branch and trunk net have to be the same! @@ -38,12 +41,16 @@ class DeepONet(Model): based on the universal approximation theorem of operators", 2021 """ - def __init__(self, trunk_net, branch_net, output_space, output_neurons): + def __init__(self, trunk_net, branch_net, output_space, output_neurons, + constrain_fn = None): self._check_trunk_and_branch_correct(trunk_net, branch_net) super().__init__(input_space=trunk_net.input_space, output_space=output_space) self.trunk = trunk_net self.branch = branch_net self._finalize_trunk_and_branch(output_space, output_neurons) + self.constrain_fn = constrain_fn + if self.constrain_fn: + self.constrain_fn = UserFunction(self.constrain_fn) def _check_trunk_and_branch_correct(self, trunk_net, branch_net): """Checks if the trunk and branch net are compatible @@ -92,7 +99,15 @@ def forward(self, trunk_inputs=None, branch_inputs=None, device="cpu"): ) out = torch.sum(trunk_out * branch_out, dim=-1) - return Points(out, self.output_space) + points_out = Points(out, self.output_space) + if not 
self.constrain_fn is None: + if trunk_inputs is None: + trunk_inputs = self.trunk.default_trunk_input.repeat(self.branch.current_out.shape[0]) + + return Points(self.constrain_fn(trunk_inputs.join(points_out).coordinates), + self.output_space) + else: + return points_out def _forward_branch(self, function_set, iteration_num=-1, device="cpu"): """Branch evaluation for training.""" diff --git a/src/torchphysics/models/embedding_layers.py b/src/torchphysics/models/embedding_layers.py new file mode 100644 index 0000000..c4da03b --- /dev/null +++ b/src/torchphysics/models/embedding_layers.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + + +class PositionalEmbedding(nn.Module): + """ + Adds positional information to data provided on a uniform grid. + The input data is expected to have shape (batch, axis_1, ..., axis_n, channels), + and the output will have shape (batch, axis_1, ..., axis_n, channels + n). + + Parameters + ---------- + dim : int + The dimension of the underlying space for which positional + information should be appended. + coordinate_boundaries : list, optional + Boundaries of the underlying grid for each dimension. Expected to be a list + of lists, where each inner list contains the bounds for a dimension. + The default is [[0, 1]] * dim. + + Notes + ----- + The positional information is generated on the fly, depending on the input + shape. This makes the embedding resolution-independent. Since this class is mainly + used in connection with Fourier Neural Operators (FNOs), a uniform grid is created + in each direction. + """ + + def __init__(self, dim, coordinate_boundaries=None): + super().__init__() + self.dim = dim + if coordinate_boundaries is not None: + assert self.dim == len(coordinate_boundaries), \ + f"Dimension is {self.dim} and does not fit provided coordinates of shape {len(coordinate_boundaries)}!" 
+ self.bounds = coordinate_boundaries + else: + self.bounds = [[0, 1]] * self.dim + self.register_buffer("_positions", torch.empty(0)) + + + def forward(self, points): + input_shape = points.as_tensor.shape + # If we have a new shape of the data, we need to create a new positional embedding + if not input_shape[1:-1] == self._positions.shape[1:-1]: + self._build_positional_embedding(input_shape, points.device, points.as_tensor.dtype) + + repeated_embedding = self._positions.repeat(input_shape[0], *[1] * (self.dim+1)) + return torch.cat((points.as_tensor, repeated_embedding), dim=-1) + + + def _build_positional_embedding(self, data_shape, device="cpu", dtype=torch.float32): + coordinate_grid = [] + for i in range(self.dim): + coordinate_grid.append( + torch.linspace(self.bounds[i][0], self.bounds[i][1], + data_shape[i+1], dtype=dtype, device=device) + ) + coordinate_meshgrid = torch.meshgrid(*coordinate_grid, indexing='ij') + self._positions = torch.cat([x.unsqueeze(-1) for x in coordinate_meshgrid], dim=-1) + self._positions = self._positions.unsqueeze(0) \ No newline at end of file diff --git a/src/torchphysics/problem/conditions/condition.py b/src/torchphysics/problem/conditions/condition.py index af4814b..c284baa 100644 --- a/src/torchphysics/problem/conditions/condition.py +++ b/src/torchphysics/problem/conditions/condition.py @@ -24,7 +24,7 @@ def forward(self, x): x : torch.tensor The values for which the squared error should be computed. 
""" - return torch.sum(torch.square(x), dim=1) + return torch.sum(torch.square(x), dim=list(range(1, len(x.shape)))) class Condition(torch.nn.Module): diff --git a/src/torchphysics/problem/conditions/deeponet_condition.py b/src/torchphysics/problem/conditions/deeponet_condition.py index 298d9c9..c96958a 100644 --- a/src/torchphysics/problem/conditions/deeponet_condition.py +++ b/src/torchphysics/problem/conditions/deeponet_condition.py @@ -60,9 +60,7 @@ def forward(self, device="cpu", iteration=None): trunk_points = self.trunk_points_sampler.sample_points(device=device) # TODO: make this more memory efficient (e.g. in DeepONet we know when data is just copied???) - trunk_points = trunk_points.unsqueeze(0).repeat( - self.branch_function_sampler.n_functions, 1, 1 - ) + trunk_points = trunk_points.unsqueeze(0).repeat(self.branch_function_sampler.n_functions) trunk_coordinates, trunk_points = trunk_points.track_coord_gradients() # 2) sample branch inputs diff --git a/src/torchphysics/problem/conditions/operator_condition.py b/src/torchphysics/problem/conditions/operator_condition.py index 5345a83..23dd358 100644 --- a/src/torchphysics/problem/conditions/operator_condition.py +++ b/src/torchphysics/problem/conditions/operator_condition.py @@ -5,17 +5,51 @@ from ...models.deeponet.deeponet import DeepONet class OperatorCondition(Condition): + """ + General condition used for the (data-driven) training of different + operator approaches. + Parameters + ---------- + module : torchphysics.Model + The torch module which should be fitted to data. + input_function_sampler : torch.utils.FunctionSampler + The sampler providing the input data to the module. + output_function_sampler : torch.utils.FunctionSampler + The expected output to a given input. + residual_fn : callable, optional + An optional function that computes the residual, by default + the network output minus the expected output is taken. + relative : bool, optional + Whether to compute the relative error (i.e. 
error / target) or absolute error. + Default is True, hence, the relative error is used. + error_fn : callable, optional + The function used to compute the final loss, e.g., the squared error or + any other norm. + reduce_fn : callable, optional + Function that will be applied to reduce the loss to a scalar. Defaults to + torch.mean. + name : str, optional + The name of this condition which will be monitored in logging. + weight : float, optional + The weight multiplied with the loss of this condition during + training. + epsilon : float, optional + For the relative loss, a small epsilon is added to error_fn(target) to + avoid division by zero. The default is 1e-8. + """ def __init__( self, module, input_function_sampler, output_function_sampler, residual_fn=None, + relative=True, reduce_fn=torch.mean, error_fn=SquaredError(), name="operator_condition", weight=1.0, + epsilon=1e-8 ): super().__init__(name=name, weight=weight, track_gradients=False) assert input_function_sampler.function_set.is_discretized, \ @@ -32,6 +66,9 @@ def __init__( else: self.residual_fn = None + self.relative = relative + self.epsilon = epsilon + self.error_fn = error_fn self.reduce_fn = reduce_fn @@ -55,8 +92,13 @@ def forward(self, device="cpu", iteration=None): else: first_error = model_out.as_tensor - output_functions.as_tensor - return self.reduce_fn(self.error_fn(first_error)) + out = self.error_fn(first_error) + + if self.relative: + y_norm = self.error_fn(output_functions.as_tensor) + self.epsilon + out = out / y_norm + return self.reduce_fn(out) class PIOperatorCondition(Condition): diff --git a/src/torchphysics/problem/domains/domain2D/triangle.py b/src/torchphysics/problem/domains/domain2D/triangle.py index 6375fc7..0db2b97 100644 --- a/src/torchphysics/problem/domains/domain2D/triangle.py +++ b/src/torchphysics/problem/domains/domain2D/triangle.py @@ -103,14 +103,14 @@ def sample_random_uniform( origin, _, _, dir_1, _, dir_3 = self._construct_triangle(params, device) num_of_params = 
self.len_of_params(params) bary_coords = torch.rand((num_of_params, n, 2), device=device) - bary_coords = self._handle_sum_greater_1(d, bary_coords) + bary_coords = self._handle_sum_greater_1(d, bary_coords, device=device) points_in_dir_1 = bary_coords[:, :, :1] * dir_1[:, None] points_in_dir_2 = -bary_coords[:, :, 1:] * dir_3[:, None] points = points_in_dir_1 + points_in_dir_2 points += origin[:, None, :] return Points(points.reshape(-1, self.space.dim), self.space) - def _handle_sum_greater_1(self, d, bary_coords): + def _handle_sum_greater_1(self, d, bary_coords, device="cpu"): sum_bigger_one = bary_coords.sum(axis=2) >= 1 if d: # for a given density just remove the points index = torch.where(torch.logical_not(sum_bigger_one)) @@ -120,7 +120,7 @@ def _handle_sum_greater_1(self, d, bary_coords): # This stays uniform. index = torch.where(sum_bigger_one) bary_coords[index] = torch.subtract( - torch.tensor([[1.0, 1.0]]), bary_coords[index] + torch.tensor([[1.0, 1.0]], device=device), bary_coords[index] ) return bary_coords diff --git a/src/torchphysics/problem/domains/domainoperations/union.py b/src/torchphysics/problem/domains/domainoperations/union.py index ce3a46a..f5c0c44 100644 --- a/src/torchphysics/problem/domains/domainoperations/union.py +++ b/src/torchphysics/problem/domains/domainoperations/union.py @@ -155,6 +155,7 @@ class UnionBoundaryDomain(BoundaryDomain): def __init__(self, domain: UnionDomain): assert not isinstance(domain.domain_a, BoundaryDomain) assert not isinstance(domain.domain_b, BoundaryDomain) + self.overlap_tol = 0.5 super().__init__(domain) def _contains(self, points, params=Points.empty()): @@ -162,10 +163,33 @@ def _contains(self, points, params=Points.empty()): in_b = self.domain.domain_b._contains(points, params) on_a_bound = self.domain.domain_a.boundary._contains(points, params) on_b_bound = self.domain.domain_b.boundary._contains(points, params) - on_both = torch.logical_and(on_b_bound, on_a_bound) + on_both = 
torch.logical_and(on_b_bound, on_a_bound) on_a_part = torch.logical_and(on_a_bound, torch.logical_not(in_b)) on_b_part = torch.logical_and(on_b_bound, torch.logical_not(in_a)) - return torch.logical_or(on_a_part, torch.logical_or(on_b_part, on_both)) + + # If a point lies on both boundaries, it can still be located in the + # interior of the union. This can only be detected by comparing the + # boundary normals of both domains. + overlap_points = torch.ones_like(on_both, dtype=torch.bool) + if torch.any(on_both): + index_tensor = on_both.clone().flatten() + if not params.isempty: + sliced_params = params[index_tensor] + else: + sliced_params = params + + normals_a = self.domain.domain_a.boundary.normal( + points[index_tensor], sliced_params, device=points.device + ) + normals_b = self.domain.domain_b.boundary.normal( + points[index_tensor], sliced_params, device=points.device + ) + + inner_product_ok = torch.sum(normals_a*normals_b, dim=-1, keepdim=True) >= self.overlap_tol + overlap_points[index_tensor] = inner_product_ok + + default_check = torch.logical_or(on_a_part, torch.logical_or(on_b_part, on_both)) + return torch.logical_and(default_check, overlap_points) def _get_volume(self, params=Points.empty(), device="cpu"): if not self.domain.disjoint: diff --git a/tests/tests_models/test_fno.py b/tests/tests_models/test_fno.py index e707edd..83fb623 100644 --- a/tests/tests_models/test_fno.py +++ b/tests/tests_models/test_fno.py @@ -70,7 +70,8 @@ def test_create_fno_optional(): bias=[False, False], channel_up_sample_network=in_network, channel_down_sample_network=out_network, - activations=torch.nn.Tanh()) + activations=torch.nn.Tanh(), + positional_embedding=False) assert fno.channel_up_sampling == in_network assert fno.channel_down_sampling == out_network