diff --git a/src/torchphysics/models/FNO.py b/src/torchphysics/models/FNO.py index dd2592d..5788921 100644 --- a/src/torchphysics/models/FNO.py +++ b/src/torchphysics/models/FNO.py @@ -30,12 +30,13 @@ def __init__(self, channels, mode_num, self.skip_connection : bool = skip_connection # Values for Fourier transformation self.mode_num = torch.tensor(mode_num) + self.mode_num[-1] = self.mode_num[-1] // 2 + 1 self.data_dim = len(mode_num) self.fourier_dims = list(range(1, self.data_dim+1)) # Learnable parameters self.fourier_kernel = nn.Parameter( - torch.empty((*self.mode_num, self.channels), dtype=torch.cfloat)) + torch.empty((*self.mode_num, self.channels, self.channels), dtype=torch.cfloat)) nn.init.xavier_normal_(self.fourier_kernel, gain=xavier_gain) self.linear_connection : bool = linear_connection @@ -56,23 +57,44 @@ def __init__(self, channels, mode_num, _Permute([0, *range(2, self.data_dim+2), 1]) ) + self.mode_slice = self.compute_mode_slice(self.mode_num) + + def compute_mode_slice(self, mode_nums): + mode_slice = [] + if len(mode_nums) > 1: + for n in mode_nums[:-1]: + mode_ls = list(range(-(n//2), 0)) + list(range(0, n // 2 + n % 2)) + mode_slice.append(mode_ls) + grids = torch.meshgrid(*[torch.tensor(idxs) for idxs in mode_slice], indexing='ij') + return (slice(None), *grids, slice(0, mode_nums[-1]), slice(None)) + else: + return (slice(None), slice(0, mode_nums[-1]), slice(None)) def forward(self, points): - fft = torch.fft.rfftn(points, dim=self.fourier_dims) + fft = torch.fft.rfftn(points, dim=self.fourier_dims, norm='forward') # Next add zeros or remove fourier modes to fit input for wanted freq. original_fft_shape = torch.tensor(fft.shape[1:-1]) # padding needs to extra values, since the torch.nn.functional.pad starts # from the last dimension (the channels in our case), there we dont need to # change anything so only zeros in the padding. - padding = torch.zeros(2*self.data_dim + 2, device=points.device, dtype=torch.int32) - padding[3::2] = torch.flip((self.mode_num - original_fft_shape), dims=(0,)) + if torch.any(original_fft_shape < self.mode_num): + min_mode_nums = torch.minimum(self.mode_num, original_fft_shape) + zeros = torch.zeros(points.shape[0], *self.mode_num, points.shape[-1], device=fft.device, dtype=fft.dtype) + slc = self.compute_mode_slice(min_mode_nums) + zeros[slc] = fft[slc] + fft = zeros + fft_in_shape = tuple(fft.shape) + + fft = fft[self.mode_slice] + + # fft is of shape (batch_dim, *mode_nums, channels) + fft = (self.fourier_kernel @ fft[..., None]).squeeze(-1) - fft = nn.functional.pad(fft, padding.tolist()) - - fft *= self.fourier_kernel + out_zeros = torch.zeros(*fft_in_shape, device=fft.device, dtype=fft.dtype) + out_zeros[self.mode_slice] = fft - ifft = torch.fft.irfftn(fft, s=points.shape[1:-1], dim=self.fourier_dims) + ifft = torch.fft.irfftn(out_zeros, s=points.shape[1:-1], dim=self.fourier_dims, norm='forward') if self.linear_connection: ifft += self.linear_transform(points) @@ -104,7 +126,7 @@ class FNO(Model): The number of hidden channels. fourier_modes : int or list, tuple The number of Fourier modes that will be used for the spectral convolution - in each layer. Modes over the given value will be truncated, and in case + in each layer. Modes above the given value will be truncated, and in case of not enough modes they are padded with 0. In case of a 1D space domain you can pass in one integer or a list of integers, such that in each layer a different amount of modes is used. @@ -156,7 +178,7 @@ class FNO(Model): Differential Equations", 2020 """ def __init__(self, input_space, output_space, fourier_layers : int, - hidden_channels : int = 16, fourier_modes = 16, activations=torch.nn.Tanh(), + hidden_channels : int = 16, fourier_modes = 16, activations=torch.nn.GELU(), skip_connections = False, linear_connections = True, bias = True, channel_up_sample_network = None, channel_down_sample_network = None, xavier_gains=5.0/3, space_resolution = None): diff --git a/src/torchphysics/problem/conditions/condition.py b/src/torchphysics/problem/conditions/condition.py index 9e98a47..af4814b 100644 --- a/src/torchphysics/problem/conditions/condition.py +++ b/src/torchphysics/problem/conditions/condition.py @@ -94,6 +94,8 @@ class DataCondition(Condition): The 'norm' which should be computed for evaluation. If 'inf', maximum norm will be used. Else, the result will be taken to the n-th potency (without computing the root!) + relative : bool + Whether to compute the relative error (i.e. error / target) or absolute error. root : float the n-th root to be computed to obtain the final loss. E.g., if norm=2, root=2, the loss is the 2-norm. @@ -116,20 +118,22 @@ def __init__( self, module, dataloader, - norm, - root=1.0, + norm=2, + relative=True, use_full_dataset=False, name="datacondition", constrain_fn=None, weight=1.0, + epsilon=1e-8, ): super().__init__(name=name, weight=weight, track_gradients=False) self.module = module self.dataloader = dataloader self.norm = norm - self.root = root + self.relative = relative self.use_full_dataset = use_full_dataset self.constrain_fn = constrain_fn + self.epsilon = epsilon if self.constrain_fn: self.constrain_fn = UserFunction(self.constrain_fn) @@ -141,7 +145,22 @@ def _compute_dist(self, batch, device): model_out = self.constrain_fn({**model_out.coordinates, **x.coordinates}) else: model_out = model_out.as_tensor - return torch.abs(model_out - y.as_tensor) + if self.relative: + if self.norm == "inf": + out_norm = torch.max(torch.abs(model_out - y.as_tensor), + dim=list(range(1, len(model_out.shape)))) + y_norm = torch.max(torch.abs(y.as_tensor), dim=list(range(1, len(model_out.shape)))) + self.epsilon + out = out_norm / y_norm + else: + out_norm = torch.norm(model_out - y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape)))) + y_norm = torch.norm(y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape)))) + self.epsilon + out = out_norm / y_norm + else: + if self.norm == "inf": + out = torch.abs(model_out - y.as_tensor) + else: + out = torch.norm(model_out - y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape)))) + return out def forward(self, device="cpu", iteration=None): if self.use_full_dataset: @@ -151,7 +170,7 @@ def forward(self, device="cpu", iteration=None): if self.norm == "inf": loss = torch.maximum(loss, torch.max(a)) else: - loss = loss + torch.mean(a**self.norm) / len(self.dataloader) + loss = loss + torch.mean(a) / len(self.dataloader) else: try: batch = next(self.iterator) @@ -162,9 +181,7 @@ def forward(self, device="cpu", iteration=None): if self.norm == "inf": loss = torch.max(a) else: - loss = torch.mean(a**self.norm) - if self.root != 1.0: - loss = loss ** (1 / self.root) + loss = torch.mean(a) return loss diff --git a/src/torchphysics/problem/samplers/function_sampler.py b/src/torchphysics/problem/samplers/function_sampler.py index aeea683..8c1117e 100644 --- a/src/torchphysics/problem/samplers/function_sampler.py +++ b/src/torchphysics/problem/samplers/function_sampler.py @@ -76,11 +76,14 @@ class FunctionSamplerOrdered(FunctionSampler): def __init__(self, n_functions, function_set : FunctionSet, function_creation_interval : int = 0): super().__init__(n_functions, function_set, function_creation_interval) self.current_indices = torch.arange(self.n_functions, dtype=torch.int64) + self.new_indieces = torch.zeros_like(self.current_indices, dtype=torch.int64) + def sample_functions(self, device="cpu"): self._check_recreate_functions(device=device) + self.current_indices = self.new_indieces.clone() current_out = self.function_set.get_function(self.current_indices) - self.current_indices = (self.current_indices + self.n_functions) % self.function_set.function_set_size + self.new_indieces = (self.current_indices + self.n_functions) % self.function_set.function_set_size return current_out diff --git a/src/torchphysics/utils/differentialoperators/differentialoperators.py b/src/torchphysics/utils/differentialoperators/differentialoperators.py index 588e43a..ba2ba4a 100644 --- a/src/torchphysics/utils/differentialoperators/differentialoperators.py +++ b/src/torchphysics/utils/differentialoperators/differentialoperators.py @@ -251,10 +251,10 @@ def jac(model_out, *derivative_variable): Du_i = [] for vari in derivative_variable: Du_i.append( - torch.autograd.grad(model_out[:, i].sum(), vari, create_graph=True)[0] + torch.autograd.grad(model_out[..., i].sum(), vari, create_graph=True)[0] ) - Du_rows.append(torch.cat(Du_i, dim=1)) - Du = torch.stack(Du_rows, dim=1) + Du_rows.append(torch.cat(Du_i, dim=-1)) + Du = torch.stack(Du_rows, dim=-2) return Du @@ -284,10 +284,10 @@ def rot(model_out, *derivative_variable): "" """ jacobian = jac(model_out, *derivative_variable) - rotation = torch.zeros((len(derivative_variable[0]), 3)) - rotation[:, 0] = jacobian[:, 2, 1] - jacobian[:, 1, 2] - rotation[:, 1] = jacobian[:, 0, 2] - jacobian[:, 2, 0] - rotation[:, 2] = jacobian[:, 1, 0] - jacobian[:, 0, 1] + rotation = torch.zeros((*(jacobian.shape[:-2]), 3)) + rotation[..., 0] = jacobian[..., 2, 1] - jacobian[..., 1, 2] + rotation[..., 1] = jacobian[..., 0, 2] - jacobian[..., 2, 0] + rotation[..., 2] = jacobian[..., 1, 0] - jacobian[..., 0, 1] return rotation @@ -359,7 +359,7 @@ def sym_grad(model_out, *derivative_variable): symmetric gradient. """ jac_matrix = jac(model_out, *derivative_variable) - return 0.5 * (jac_matrix + torch.transpose(jac_matrix, 1, 2)) + return 0.5 * (jac_matrix + torch.transpose(jac_matrix, -2, -1)) def matrix_div(model_out, *derivative_variable): diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 70526be..b46f860 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -106,12 +106,24 @@ def test_datacondition_forward(): def test_datacondition_forward_2(): module = UserFunction(helper_fn) loader = PointsDataLoader((Points(torch.tensor([[0.0], [2.0]]), R1('x')), - Points(torch.tensor([[0.0], [1.0]]), R1('u'))), + Points(torch.tensor([[0.0], [2.0]]), R1('u'))), + batch_size=1) + cond = DataCondition(module=module, dataloader=loader, + norm=2, relative=False, use_full_dataset=True) + out = cond() + assert out == 1.0 + + +def test_datacondition_forward_relative(): + module = UserFunction(helper_fn) + loader = PointsDataLoader((Points(torch.tensor([[0.0], [2.0]]), R1('x')), + Points(torch.tensor([[0.0], [2.0]]), R1('u'))), batch_size=1) cond = DataCondition(module=module, dataloader=loader, norm=2, use_full_dataset=True) out = cond() - assert out == 4.5 + print(out) + assert out == 0.5 def test_create_pinncondition(): diff --git a/tests/tests_models/test_fno.py b/tests/tests_models/test_fno.py index 2fbed8c..e707edd 100644 --- a/tests/tests_models/test_fno.py +++ b/tests/tests_models/test_fno.py @@ -8,13 +8,13 @@ def test_create_fourier_layer(): fourier_layer = _FourierLayer(4, 4) assert fourier_layer.data_dim == 1 - assert fourier_layer.fourier_kernel.shape == (4, 4) + assert fourier_layer.fourier_kernel.shape == (3, 4, 4) def test_create_fourier_layer_higher_dim(): fourier_layer = _FourierLayer(8, (4, 6)) assert fourier_layer.data_dim == 2 - assert fourier_layer.fourier_kernel.shape == (4, 6, 8) + assert fourier_layer.fourier_kernel.shape == (4, 4, 8, 8) def test_create_fourier_layer_with_linear_transform(): diff --git a/tests/tests_sampler/test_function_sampler.py b/tests/tests_sampler/test_function_sampler.py index 2d29426..ebe5f13 100644 --- a/tests/tests_sampler/test_function_sampler.py +++ b/tests/tests_sampler/test_function_sampler.py @@ -71,8 +71,8 @@ def test_ordered_function_sampler_sample(): fn_set = make_default_fn_set() fn_sampler = FunctionSamplerOrdered(20, fn_set, 100) fns = fn_sampler.sample_functions() - assert torch.all(fn_sampler.current_indices >= 20) - assert torch.all(fn_sampler.current_indices < 40) + assert torch.all(fn_sampler.new_indieces >= 20) + assert torch.all(fn_sampler.new_indieces < 40) assert callable(fns) @@ -81,8 +81,8 @@ def test_ordered_function_sampler_sample_two_times(): fn_sampler = FunctionSamplerOrdered(20, fn_set, 100) fns = fn_sampler.sample_functions() fns = fn_sampler.sample_functions() - assert torch.all(fn_sampler.current_indices >= 40) - assert torch.all(fn_sampler.current_indices < 60) + assert torch.all(fn_sampler.new_indieces >= 40) + assert torch.all(fn_sampler.new_indieces < 60) assert callable(fns) @@ -91,8 +91,8 @@ def test_ordered_function_sampler_sample_multiple_times(): fn_sampler = FunctionSamplerOrdered(20, fn_set, 100) for _ in range(5): _ = fn_sampler.sample_functions() - assert torch.all(fn_sampler.current_indices >= 0) - assert torch.all(fn_sampler.current_indices < 20) + assert torch.all(fn_sampler.new_indieces >= 0) + assert torch.all(fn_sampler.new_indieces < 20) def test_create_coupled_function_sampler(): @@ -110,5 +110,5 @@ def test_coupled_function_sampler_sample(): fn_sampler2 = FunctionSamplerCoupled(fn_set, fn_sampler) _ = fn_sampler.sample_functions() _ = fn_sampler2.sample_functions() - assert torch.all(fn_sampler2.current_indices >= 20) - assert torch.all(fn_sampler2.current_indices < 40) \ No newline at end of file + assert torch.all(fn_sampler2.current_indices >= 0) + assert torch.all(fn_sampler2.current_indices < 20) \ No newline at end of file diff --git a/tests/tests_utils/test_differentialoperators.py b/tests/tests_utils/test_differentialoperators.py index 677a7d2..2e28411 100644 --- a/tests/tests_utils/test_differentialoperators.py +++ b/tests/tests_utils/test_differentialoperators.py @@ -511,6 +511,23 @@ def f(x, y): assert torch.allclose(d[1], torch.tensor([[12.0, 9.0], [torch.exp(a[1]), 1.0]])) +def test_jac_for_different_input_shape(): + def f(x, y): + out = torch.zeros((*x.shape[:-1], 2)) + out[..., :1] = x**2 * y + out[..., 1:] = y + torch.exp(x) + return out + a = torch.tensor([[[0.0], [3.0]], [[0.0], [3.0]]], requires_grad=True) + b = torch.tensor([[[1.0], [2.0]], [[1.0], [2.0]]], requires_grad=True) + output = f(a, b) + d = jac(output, a, b) + exp_value = torch.exp(torch.tensor(3.0)) + assert d.shape == (2, 2, 2, 2) + assert torch.allclose(d[:, 0], torch.tensor([[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]])) + assert torch.allclose(d[:, 1], torch.tensor([[[12.0, 9.0], [exp_value, 1.0]], + [[12.0, 9.0], [exp_value, 1.0]]]), atol=0.001) + + # Test rot def rot_function(x): out = torch.zeros((len(x), 3))