Skip to content

Commit 9546d7e

Browse files
authored
dde.gradients support 3D outputs (#1928)
1 parent 5d1567f commit 9546d7e

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

deepxde/gradients/gradients.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def jacobian(ys, xs, i=None, j=None):
1919
computation.
2020
2121
Args:
22-
ys: Output Tensor of shape (batch_size, dim_y).
22+
ys: Output Tensor of shape (batch_size, dim_y) or (batch_size_out, batch_size,
23+
dim_y). Here, the `batch_size` is the same one for `xs`, and
24+
`batch_size_out` is the batch size for an additional/outer dimension.
2325
xs: Input Tensor of shape (batch_size, dim_x).
2426
i (int or None): `i`th row. If `i` is ``None``, returns the `j`th column
2527
J[:, `j`].
@@ -29,6 +31,9 @@ def jacobian(ys, xs, i=None, j=None):
2931
3032
Returns:
3133
(`i`, `j`)th entry J[`i`, `j`], `i`th row J[`i`, :], or `j`th column J[:, `j`].
34+
When `ys` has shape (batch_size, dim_y), the output shape is (batch_size, 1).
35+
When `ys` has shape (batch_size_out, batch_size, dim_y), the output shape is
36+
(batch_size_out, batch_size, 1).
3237
"""
3338
if config.autodiff == "reverse":
3439
return gradients_reverse.jacobian(ys, xs, i=i, j=j)
@@ -48,14 +53,18 @@ def hessian(ys, xs, component=0, i=0, j=0):
4853
computation.
4954
5055
Args:
51-
ys: Output Tensor of shape (batch_size, dim_y).
56+
ys: Output Tensor of shape (batch_size, dim_y) or (batch_size_out, batch_size,
57+
dim_y). Here, the `batch_size` is the same one for `xs`, and
58+
`batch_size_out` is the batch size for an additional/outer dimension.
5259
xs: Input Tensor of shape (batch_size, dim_x).
5360
component: `ys[:, component]` is used as y to compute the Hessian.
5461
i (int): `i`th row.
5562
j (int): `j`th column.
5663
5764
Returns:
58-
H[`i`, `j`].
65+
H[`i`, `j`]. When `ys` has shape (batch_size, dim_y), the output shape is
66+
(batch_size, 1). When `ys` has shape (batch_size_out, batch_size, dim_y),
67+
the output shape is (batch_size_out, batch_size, 1).
5968
"""
6069
if config.autodiff == "reverse":
6170
return gradients_reverse.hessian(ys, xs, component=component, i=i, j=j)

deepxde/gradients/gradients_forward.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def grad_fn(x):
8787
# Compute J[i, j]
8888
if (i, j) not in self.J:
8989
if backend_name == "tensorflow.compat.v1":
90-
self.J[i, j] = self.J[j][:, i : i + 1]
90+
self.J[i, j] = self.J[j][..., i : i + 1]
9191
elif backend_name in ["tensorflow", "pytorch", "jax"]:
9292
# In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
9393
# and a callable is returned, so that it is consistent with the argument,
9494
# which is also a tuple. This is useful for further computation, e.g.,
9595
# Hessian.
9696
self.J[i, j] = (
97-
self.J[j][0][:, i : i + 1],
97+
self.J[j][0][..., i : i + 1],
9898
lambda x: self.J[j][1](x)[i : i + 1],
9999
)
100100
return self.J[i, j]

deepxde/gradients/gradients_reverse.py

+9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__all__ = ["hessian", "jacobian"]
44

55
from .jacobian import Jacobian, Jacobians
6+
from .. import backend as bkd
67
from ..backend import backend_name, tf, torch, jax, paddle
78

89

@@ -17,6 +18,14 @@ def __call__(self, i=None, j=None):
1718
"Reverse-mode autodiff doesn't support computing a column."
1819
)
1920
i = 0
21+
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
22+
ndim_y = bkd.ndim(self.ys)
23+
elif backend_name == "jax":
24+
ndim_y = bkd.ndim(self.ys[0])
25+
if ndim_y == 3:
26+
raise NotImplementedError(
27+
"Reverse-mode autodiff doesn't support 3D output"
28+
)
2029

2130
# Compute J[i, :]
2231
if i not in self.J:

deepxde/gradients/jacobian.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ class Jacobian(ABC):
1111
It is lazy evaluation, i.e., it only computes J[i, j] when needed.
1212
1313
Args:
14-
ys: Output Tensor of shape (batch_size, dim_y).
14+
ys: Output Tensor of shape (batch_size, dim_y) or (batch_size_out, batch_size,
15+
dim_y). Here, the `batch_size` is the same one for `xs`, and
16+
`batch_size_out` is the batch size for an additional/outer dimension.
1517
xs: Input Tensor of shape (batch_size, dim_x).
1618
"""
1719

@@ -20,22 +22,22 @@ def __init__(self, ys, xs):
2022
self.xs = xs
2123

2224
if backend_name in ["tensorflow.compat.v1", "paddle"]:
23-
self.dim_y = ys.shape[1]
25+
self.dim_y = ys.shape[-1]
2426
elif backend_name in ["tensorflow", "pytorch"]:
2527
if config.autodiff == "reverse":
2628
# For reverse-mode AD, only a tensor is passed.
27-
self.dim_y = ys.shape[1]
29+
self.dim_y = ys.shape[-1]
2830
elif config.autodiff == "forward":
2931
# For forward-mode AD, a tuple of a tensor and a callable is passed,
3032
# similar to backend jax.
31-
self.dim_y = ys[0].shape[1]
33+
self.dim_y = ys[0].shape[-1]
3234
elif backend_name == "jax":
3335
# For backend jax, a tuple of a jax array and a callable is passed as one of
3436
# the arguments, since jax does not support computational graph explicitly.
3537
# The array is used to control the dimensions and the callable is used to
3638
# obtain the derivative function, which can be used to compute the
3739
# derivatives.
38-
self.dim_y = ys[0].shape[1]
40+
self.dim_y = ys[0].shape[-1]
3941
self.dim_x = xs.shape[1]
4042

4143
self.J = {}
@@ -114,7 +116,9 @@ def __call__(self, ys, xs, i=None, j=None):
114116
# x = torch.from_numpy(x)
115117
# x.requires_grad_()
116118
# f(x)
117-
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
119+
if backend_name == "tensorflow.compat.v1":
120+
key = (ys.ref(), xs.ref())
121+
elif backend_name == "tensorflow":
118122
if config.autodiff == "reverse":
119123
key = (ys.ref(), xs.ref())
120124
elif config.autodiff == "forward":

0 commit comments

Comments
 (0)