-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathode_model.py
105 lines (95 loc) · 4.58 KB
/
ode_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# MIT License
#
# Copyright (c) 2022 Matthieu Kirchmeyer & Yuan Yin
import numpy as np
from torchdiffeq import odeint
from network import *
from utils import *
class Derivative(nn.Module):
    """Vector field du/dt = f(u) with a shared root net modulated per-environment.

    A shared "root" network provides the common dynamics; a linear hypernetwork
    maps a learned per-environment code to a (possibly masked) set of parameter
    offsets applied on top of the root, realized through ``HyperEnvNet``.

    Args:
        state_c: number of state channels of u.
        hidden_c: hidden channel width of the root/ghost networks.
        code_c: dimensionality of each environment code.
        n_env: number of environments (one code per environment).
        factor: width factor forwarded to the conv networks.
        nl: non-linearity name forwarded to the networks.
        dataset: one of "gray", "wave", "navier" (spatial datasets) — ignored
            when ``is_ode`` is True, which selects the MLP variant.
        is_ode: if True, use the ODE (MLP) networks instead of spatial convs.
        size: spatial grid size for conv networks.
        is_layer: if True, restrict the hypernet mask to the given ``layers``.
        layers: layer indices used when ``is_layer`` (default ``[0]``).
        logger: optional logger for parameter-count reporting.
        mask: optional precomputed mask; generated when ``is_layer`` and None.
        codes_init: optional initial codes; zeros when None.
        device: device string forwarded to ``HyperEnvNet``.
        htype: hypernetwork type tag (kept for interface compatibility).
    """
    def __init__(self, state_c, hidden_c, code_c, n_env, factor, nl, dataset, is_ode, size=64, is_layer=False,
                 layers=None, logger=None, mask=None, codes_init=None, device="cuda", htype="hyper", **kwargs):
        super().__init__()
        # Avoid the shared-mutable-default pitfall: [0] was the original default.
        layers = [0] if layers is None else layers
        self.is_ode = is_ode
        self.size = size
        self.is_layer = is_layer
        self.logger = logger
        # One learnable code per environment, initialized at zero unless provided.
        self.codes = nn.Parameter(torch.zeros(n_env, code_c)) if codes_init is None else codes_init
        # Bias: shared root network (groups=1 — common to all environments)
        if self.is_ode:
            self.net_root = GroupConvMLP(state_c, hidden_c, groups=1, factor=factor, nl=nl)
        elif dataset == "gray" or dataset == "wave":
            self.net_root = GroupConv(state_c, hidden_c, groups=1, factor=factor, nl=nl, size=size)
        elif dataset == "navier":
            self.net_root = GroupFNO2d(state_c, nl=nl, groups=1)
        else:
            # Fail fast with the same message the ghost dispatch uses, instead of
            # a confusing NameError at count_parameters below.
            raise Exception(f"{dataset} net not implemented")
        n_param_tot = count_parameters(self.net_root)
        # Only the masked subset of parameters is produced by the hypernet.
        n_param_mask = n_param_tot if not is_layer else get_n_param_layer(self.net_root, layers)
        n_param_hypernet = n_param_mask
        if logger:
            self.logger.info(f"Params: n_mask {n_param_mask} / n_tot {n_param_tot} / n_hypernet {n_param_hypernet}")
        # Hypernet: linear map from code space to the masked parameter vector.
        self.net_hyper = nn.Linear(code_c, n_param_hypernet, bias=False)
        # Ghost: per-environment copy of the network structure (groups=n_env);
        # its parameters are written by the hypernet, never trained directly.
        if self.is_ode:
            self.ghost_structure = GroupConvMLP(state_c, hidden_c, groups=n_env, factor=factor, nl=nl)
        elif dataset == "gray" or dataset == "wave":
            self.ghost_structure = GroupConv(state_c, hidden_c, groups=n_env, factor=factor, nl=nl, size=size)
        elif dataset == "navier":
            self.ghost_structure = GroupFNO2d(state_c, nl=nl, groups=n_env)
        else:
            raise Exception(f"{dataset} net not implemented")
        set_requires_grad(self.ghost_structure, False)
        # Mask selecting which root parameters the hypernet modulates.
        if is_layer and mask is None:
            self.mask = {"mask": generate_mask(self.net_root, "layer", layers)}
        else:
            self.mask = {"mask": mask}
        # Total: combine root + hypernet(code) through the ghost structure.
        self.net_leaf = HyperEnvNet(self.net_root, self.ghost_structure, self.net_hyper, self.codes, logger, self.mask["mask"], device, **kwargs)

    def update_ghost(self):
        """Refresh the ghost network's parameters from the current codes."""
        self.net_leaf.update_ghost()

    def forward(self, t, u):
        """Evaluate du/dt at state u (t is unused; signature required by odeint)."""
        return self.net_leaf(u)
class Forecaster(nn.Module):
    """Integrate a learned ``Derivative`` over time to forecast trajectories.

    Wraps ``torchdiffeq.odeint``. With ``epsilon == 0`` the whole horizon is
    integrated from the initial state ``y[0]``. With ``epsilon > 0`` a form of
    scheduled sampling is applied: at randomly chosen intermediate time points
    (each with probability ``epsilon``) integration restarts from the ground
    truth ``y[start_i]`` instead of the model's own prediction.

    Args:
        state_c, hidden_c, code_c, n_env, factor: forwarded to ``Derivative``.
        options: solver options dict for ``odeint``.
        method: solver method name for ``odeint``.
        nl: non-linearity name (default "swish").
        dataset: dataset key selecting the network family (default "lotka").
        size, is_layer, is_ode, layers, mask, codes_init, device, htype:
            forwarded to ``Derivative``; ``layers`` defaults to ``[0]``.
        logger: optional logger.
    """
    def __init__(self, state_c, hidden_c, code_c, n_env, factor, options=None, method=None, nl="swish", dataset="lotka",
                 size=64, is_layer=False, is_ode=True, layers=None, logger=None, mask=None, codes_init=None, device="cuda", htype='hyper', **kwargs):
        super().__init__()
        # Avoid the shared-mutable-default pitfall: [0] was the original default.
        layers = [0] if layers is None else layers
        self.method = method
        self.options = options
        self.is_layer = is_layer
        self.int_ = odeint
        self.is_ode = is_ode
        self.logger = logger
        self.derivative = Derivative(state_c, hidden_c, code_c, n_env, factor, nl, dataset, is_ode, size, is_layer,
                                     layers, self.logger, mask, codes_init, device, htype, **kwargs)

    def forward(self, y, t, epsilon=0):
        """Forecast y over time grid t.

        Args:
            y: ground-truth trajectory, time last — (B, C, T) for ODE data or
               (B, C, T, H, W) for spatial data.
            t: 1-D tensor of time points.
            epsilon: probability of restarting from ground truth at each
                intermediate time point (values below 1e-3 are treated as 0).

        Returns:
            Predicted trajectory with the same layout as ``y``.
        """
        if epsilon < 1e-3:
            epsilon = 0
        # Move time to the leading axis, as expected by odeint.
        y = y.permute(2, 0, 1) if self.is_ode else y.permute(2, 0, 1, 3, 4)
        if epsilon == 0:
            # Single integration over the full horizon from the initial state.
            res = self.int_(self.derivative, y0=y[0], t=t, method=self.method, options=self.options)
        else:
            # Choose restart points among intermediate times; never the last
            # point, never the first (handled by slicing off index 0).
            eval_points = np.random.random(len(t)) < epsilon
            eval_points[-1] = False
            eval_points = eval_points[1:]
            start_i, end_i = 0, None
            res = []
            for i, eval_point in enumerate(eval_points):
                if eval_point:
                    end_i = i + 1
                    # Integrate the segment from the ground-truth state at start_i.
                    t_seg = t[start_i:end_i + 1]
                    res_seg = self.int_(self.derivative, y0=y[start_i], t=t_seg,
                                        method=self.method, options=self.options)
                    # Drop the duplicated first time point of every segment after
                    # the first so time points are not repeated in the output.
                    if len(res) == 0:
                        res.append(res_seg)
                    else:
                        res.append(res_seg[1:])
                    start_i = end_i
            # Final segment up to the end of the horizon.
            t_seg = t[start_i:]
            res_seg = self.int_(self.derivative, y0=y[start_i], t=t_seg, method=self.method,
                                options=self.options)
            if len(res) == 0:
                res.append(res_seg)
            else:
                res.append(res_seg[1:])
            res = torch.cat(res, dim=0)
        # Restore the original layout (time back to the last data axis).
        return res.permute(1, 2, 0) if self.is_ode else res.permute(1, 2, 0, 3, 4)