Skip to content

Commit

Permalink
Merge branch 'facebookresearch:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
kuds authored Sep 14, 2024
2 parents 2778513 + 69577db commit d397336
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,14 @@ def forward(self, x: Tensor, z: Tensor) -> Tensor:
Output:
ensemble output of x weighted by epistemic index vector z.
"""
outputs = torch.vmap(self.call_single_model, (0, 0, None))(
self.params, self.buffers, x
)
# vmap is not compatible with torchscript
# outputs = torch.vmap(self.call_single_model, (0, 0, None))(
# self.params, self.buffers, x
# )
outputs = []
for model in self.models:
outputs.append(model(x))
outputs = torch.stack(outputs, dim=0)
return torch.einsum("ijk,ji->jk", outputs, z)


Expand Down

0 comments on commit d397336

Please sign in to comment.