Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/torchphysics/models/FNO.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def __init__(self, channels, mode_num,
self.skip_connection : bool = skip_connection
# Values for Fourier transformation
self.mode_num = torch.tensor(mode_num)
self.mode_num[-1] = self.mode_num[-1] // 2 + 1
self.data_dim = len(mode_num)
self.fourier_dims = list(range(1, self.data_dim+1))

# Learnable parameters
self.fourier_kernel = nn.Parameter(
torch.empty((*self.mode_num, self.channels), dtype=torch.cfloat))
torch.empty((*self.mode_num, self.channels, self.channels), dtype=torch.cfloat))
nn.init.xavier_normal_(self.fourier_kernel, gain=xavier_gain)

self.linear_connection : bool = linear_connection
Expand All @@ -56,23 +57,44 @@ def __init__(self, channels, mode_num,
_Permute([0, *range(2, self.data_dim+2), 1])
)

self.mode_slice = self.compute_mode_slice(self.mode_num)

def compute_mode_slice(self, mode_nums):
    """Build an index tuple that selects the retained Fourier modes.

    For every data dimension except the last, both the lowest positive
    and the highest negative frequencies are kept (the usual FNO mode
    truncation); the last dimension comes from an rfft, so only its
    first ``mode_nums[-1]`` entries are selected.

    Parameters
    ----------
    mode_nums : sequence of int
        Number of modes to keep per data dimension. The last entry is
        assumed to already be halved for the real-valued FFT.

    Returns
    -------
    tuple
        An index of the form ``(batch, *mode_grids, last_dim, channels)``
        usable to slice a tensor of shape ``(batch, *freqs, channels)``.
    """
    if len(mode_nums) <= 1:
        # Single data dimension: just the leading rfft modes.
        return (slice(None), slice(0, mode_nums[-1]), slice(None))
    index_lists = []
    for count in mode_nums[:-1]:
        # Negative (high) frequencies first, then the non-negative ones.
        negative = list(range(-(count // 2), 0))
        positive = list(range(0, count // 2 + count % 2))
        index_lists.append(negative + positive)
    grids = torch.meshgrid(
        *(torch.tensor(indices) for indices in index_lists), indexing='ij'
    )
    return (slice(None), *grids, slice(0, mode_nums[-1]), slice(None))

def forward(self, points):
fft = torch.fft.rfftn(points, dim=self.fourier_dims)
fft = torch.fft.rfftn(points, dim=self.fourier_dims, norm='forward')

# Next add zeros or remove fourier modes to fit input for wanted freq.
original_fft_shape = torch.tensor(fft.shape[1:-1])
# padding needs two extra values, since torch.nn.functional.pad starts
# from the last dimension (the channels in our case); there we don't need to
# change anything, so only zeros in the padding.
padding = torch.zeros(2*self.data_dim + 2, device=points.device, dtype=torch.int32)
padding[3::2] = torch.flip((self.mode_num - original_fft_shape), dims=(0,))
if torch.any(original_fft_shape < self.mode_num):
min_mode_nums = torch.minimum(self.mode_num, original_fft_shape)
zeros = torch.zeros(points.shape[0], *self.mode_num, points.shape[-1], device=fft.device, dtype=fft.dtype)
slc = self.compute_mode_slice(min_mode_nums)
zeros[slc] = fft[slc]
fft = zeros
fft_in_shape = tuple(fft.shape)

fft = fft[self.mode_slice]

# fft is of shape (batch_dim, *mode_nums, channels)
fft = (self.fourier_kernel @ fft[..., None]).squeeze(-1)

fft = nn.functional.pad(fft, padding.tolist())

fft *= self.fourier_kernel
out_zeros = torch.zeros(*fft_in_shape, device=fft.device, dtype=fft.dtype)
out_zeros[self.mode_slice] = fft

ifft = torch.fft.irfftn(fft, s=points.shape[1:-1], dim=self.fourier_dims)
ifft = torch.fft.irfftn(out_zeros, s=points.shape[1:-1], dim=self.fourier_dims, norm='forward')

if self.linear_connection:
ifft += self.linear_transform(points)
Expand Down Expand Up @@ -104,7 +126,7 @@ class FNO(Model):
The number of hidden channels.
fourier_modes : int or list, tuple
The number of Fourier modes that will be used for the spectral convolution
in each layer. Modes over the given value will be truncated, and in case
in each layer. Modes above the given value will be truncated, and in case
of not enough modes they are padded with 0.
In case of a 1D space domain you can pass in one integer or a list of
integers, such that in each layer a different amount of modes is used.
Expand Down Expand Up @@ -156,7 +178,7 @@ class FNO(Model):
Differential Equations", 2020
"""
def __init__(self, input_space, output_space, fourier_layers : int,
hidden_channels : int = 16, fourier_modes = 16, activations=torch.nn.Tanh(),
hidden_channels : int = 16, fourier_modes = 16, activations=torch.nn.GELU(),
skip_connections = False, linear_connections = True, bias = True,
channel_up_sample_network = None, channel_down_sample_network = None,
xavier_gains=5.0/3, space_resolution = None):
Expand Down
33 changes: 25 additions & 8 deletions src/torchphysics/problem/conditions/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class DataCondition(Condition):
The 'norm' which should be computed for evaluation. If 'inf', the maximum norm will
be used. Else, the result will be raised to the n-th power (without computing the
root!)
relative : bool
Whether to compute the relative error (the norm of the error divided by the
norm of the target, with a small epsilon added for stability) or the absolute error.
root : float
the n-th root to be computed to obtain the final loss. E.g., if norm=2, root=2, the
loss is the 2-norm.
Expand All @@ -116,20 +118,22 @@ def __init__(
self,
module,
dataloader,
norm,
root=1.0,
norm=2,
relative=True,
use_full_dataset=False,
name="datacondition",
constrain_fn=None,
weight=1.0,
epsilon=1e-8,
):
super().__init__(name=name, weight=weight, track_gradients=False)
self.module = module
self.dataloader = dataloader
self.norm = norm
self.root = root
self.relative = relative
self.use_full_dataset = use_full_dataset
self.constrain_fn = constrain_fn
self.epsilon = epsilon
if self.constrain_fn:
self.constrain_fn = UserFunction(self.constrain_fn)

Expand All @@ -141,7 +145,22 @@ def _compute_dist(self, batch, device):
model_out = self.constrain_fn({**model_out.coordinates, **x.coordinates})
else:
model_out = model_out.as_tensor
return torch.abs(model_out - y.as_tensor)
if self.relative:
if self.norm == "inf":
out_norm = torch.max(torch.abs(model_out - y.as_tensor),
dim=list(range(1, len(model_out.shape))))
y_norm = torch.max(torch.abs(y.as_tensor), dim=list(range(1, len(model_out.shape)))) + self.epsilon
out = out_norm / y_norm
else:
out_norm = torch.norm(model_out - y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape))))
y_norm = torch.norm(y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape)))) + self.epsilon
out = out_norm / y_norm
else:
if self.norm == "inf":
out = torch.abs(model_out - y.as_tensor)
else:
out = torch.norm(model_out - y.as_tensor, p=self.norm, dim=list(range(1, len(model_out.shape))))
return out

def forward(self, device="cpu", iteration=None):
if self.use_full_dataset:
Expand All @@ -151,7 +170,7 @@ def forward(self, device="cpu", iteration=None):
if self.norm == "inf":
loss = torch.maximum(loss, torch.max(a))
else:
loss = loss + torch.mean(a**self.norm) / len(self.dataloader)
loss = loss + torch.mean(a) / len(self.dataloader)
else:
try:
batch = next(self.iterator)
Expand All @@ -162,9 +181,7 @@ def forward(self, device="cpu", iteration=None):
if self.norm == "inf":
loss = torch.max(a)
else:
loss = torch.mean(a**self.norm)
if self.root != 1.0:
loss = loss ** (1 / self.root)
loss = torch.mean(a)
return loss


Expand Down
5 changes: 4 additions & 1 deletion src/torchphysics/problem/samplers/function_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,14 @@ class FunctionSamplerOrdered(FunctionSampler):
def __init__(self, n_functions, function_set : FunctionSet, function_creation_interval : int = 0):
super().__init__(n_functions, function_set, function_creation_interval)
self.current_indices = torch.arange(self.n_functions, dtype=torch.int64)
self.new_indieces = torch.zeros_like(self.current_indices, dtype=torch.int64)


def sample_functions(self, device="cpu"):
self._check_recreate_functions(device=device)
self.current_indices = self.new_indieces.clone()
current_out = self.function_set.get_function(self.current_indices)
self.current_indices = (self.current_indices + self.n_functions) % self.function_set.function_set_size
self.new_indieces = (self.current_indices + self.n_functions) % self.function_set.function_set_size
return current_out


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,10 @@ def jac(model_out, *derivative_variable):
Du_i = []
for vari in derivative_variable:
Du_i.append(
torch.autograd.grad(model_out[:, i].sum(), vari, create_graph=True)[0]
torch.autograd.grad(model_out[..., i].sum(), vari, create_graph=True)[0]
)
Du_rows.append(torch.cat(Du_i, dim=1))
Du = torch.stack(Du_rows, dim=1)
Du_rows.append(torch.cat(Du_i, dim=-1))
Du = torch.stack(Du_rows, dim=-2)
return Du


Expand Down Expand Up @@ -284,10 +284,10 @@ def rot(model_out, *derivative_variable):
""
"""
jacobian = jac(model_out, *derivative_variable)
rotation = torch.zeros((len(derivative_variable[0]), 3))
rotation[:, 0] = jacobian[:, 2, 1] - jacobian[:, 1, 2]
rotation[:, 1] = jacobian[:, 0, 2] - jacobian[:, 2, 0]
rotation[:, 2] = jacobian[:, 1, 0] - jacobian[:, 0, 1]
rotation = torch.zeros((*(jacobian.shape[:-2]), 3))
rotation[..., 0] = jacobian[..., 2, 1] - jacobian[..., 1, 2]
rotation[..., 1] = jacobian[..., 0, 2] - jacobian[..., 2, 0]
rotation[..., 2] = jacobian[..., 1, 0] - jacobian[..., 0, 1]
return rotation


Expand Down Expand Up @@ -359,7 +359,7 @@ def sym_grad(model_out, *derivative_variable):
symmetric gradient.
"""
jac_matrix = jac(model_out, *derivative_variable)
return 0.5 * (jac_matrix + torch.transpose(jac_matrix, 1, 2))
return 0.5 * (jac_matrix + torch.transpose(jac_matrix, -2, -1))


def matrix_div(model_out, *derivative_variable):
Expand Down
16 changes: 14 additions & 2 deletions tests/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,24 @@ def test_datacondition_forward():
def test_datacondition_forward_2():
module = UserFunction(helper_fn)
loader = PointsDataLoader((Points(torch.tensor([[0.0], [2.0]]), R1('x')),
Points(torch.tensor([[0.0], [1.0]]), R1('u'))),
Points(torch.tensor([[0.0], [2.0]]), R1('u'))),
batch_size=1)
cond = DataCondition(module=module, dataloader=loader,
norm=2, relative=False, use_full_dataset=True)
out = cond()
assert out == 1.0


def test_datacondition_forward_relative():
module = UserFunction(helper_fn)
loader = PointsDataLoader((Points(torch.tensor([[0.0], [2.0]]), R1('x')),
Points(torch.tensor([[0.0], [2.0]]), R1('u'))),
batch_size=1)
cond = DataCondition(module=module, dataloader=loader,
norm=2, use_full_dataset=True)
out = cond()
assert out == 4.5
print(out)
assert out == 0.5


def test_create_pinncondition():
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_models/test_fno.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
def test_create_fourier_layer():
fourier_layer = _FourierLayer(4, 4)
assert fourier_layer.data_dim == 1
assert fourier_layer.fourier_kernel.shape == (4, 4)
assert fourier_layer.fourier_kernel.shape == (3, 4, 4)


def test_create_fourier_layer_higher_dim():
fourier_layer = _FourierLayer(8, (4, 6))
assert fourier_layer.data_dim == 2
assert fourier_layer.fourier_kernel.shape == (4, 6, 8)
assert fourier_layer.fourier_kernel.shape == (4, 4, 8, 8)


def test_create_fourier_layer_with_linear_transform():
Expand Down
16 changes: 8 additions & 8 deletions tests/tests_sampler/test_function_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_ordered_function_sampler_sample():
fn_set = make_default_fn_set()
fn_sampler = FunctionSamplerOrdered(20, fn_set, 100)
fns = fn_sampler.sample_functions()
assert torch.all(fn_sampler.current_indices >= 20)
assert torch.all(fn_sampler.current_indices < 40)
assert torch.all(fn_sampler.new_indieces >= 20)
assert torch.all(fn_sampler.new_indieces < 40)
assert callable(fns)


Expand All @@ -81,8 +81,8 @@ def test_ordered_function_sampler_sample_two_times():
fn_sampler = FunctionSamplerOrdered(20, fn_set, 100)
fns = fn_sampler.sample_functions()
fns = fn_sampler.sample_functions()
assert torch.all(fn_sampler.current_indices >= 40)
assert torch.all(fn_sampler.current_indices < 60)
assert torch.all(fn_sampler.new_indieces >= 40)
assert torch.all(fn_sampler.new_indieces < 60)
assert callable(fns)


Expand All @@ -91,8 +91,8 @@ def test_ordered_function_sampler_sample_multiple_times():
fn_sampler = FunctionSamplerOrdered(20, fn_set, 100)
for _ in range(5):
_ = fn_sampler.sample_functions()
assert torch.all(fn_sampler.current_indices >= 0)
assert torch.all(fn_sampler.current_indices < 20)
assert torch.all(fn_sampler.new_indieces >= 0)
assert torch.all(fn_sampler.new_indieces < 20)


def test_create_coupled_function_sampler():
Expand All @@ -110,5 +110,5 @@ def test_coupled_function_sampler_sample():
fn_sampler2 = FunctionSamplerCoupled(fn_set, fn_sampler)
_ = fn_sampler.sample_functions()
_ = fn_sampler2.sample_functions()
assert torch.all(fn_sampler2.current_indices >= 20)
assert torch.all(fn_sampler2.current_indices < 40)
assert torch.all(fn_sampler2.current_indices >= 0)
assert torch.all(fn_sampler2.current_indices < 20)
17 changes: 17 additions & 0 deletions tests/tests_utils/test_differentialoperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,23 @@ def f(x, y):
assert torch.allclose(d[1], torch.tensor([[12.0, 9.0], [torch.exp(a[1]), 1.0]]))


def test_jac_for_different_input_shape():
def f(x, y):
out = torch.zeros((*x.shape[:-1], 2))
out[..., :1] = x**2 * y
out[..., 1:] = y + torch.exp(x)
return out
a = torch.tensor([[[0.0], [3.0]], [[0.0], [3.0]]], requires_grad=True)
b = torch.tensor([[[1.0], [2.0]], [[1.0], [2.0]]], requires_grad=True)
output = f(a, b)
d = jac(output, a, b)
exp_value = torch.exp(torch.tensor(3.0))
assert d.shape == (2, 2, 2, 2)
assert torch.allclose(d[:, 0], torch.tensor([[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [1.0, 1.0]]]))
assert torch.allclose(d[:, 1], torch.tensor([[[12.0, 9.0], [exp_value, 1.0]],
[[12.0, 9.0], [exp_value, 1.0]]]), atol=0.001)


# Test rot
def rot_function(x):
out = torch.zeros((len(x), 3))
Expand Down