diff --git a/pearl/neural_networks/common/epistemic_neural_networks.py b/pearl/neural_networks/common/epistemic_neural_networks.py index b99ded4c..b7c22cca 100644 --- a/pearl/neural_networks/common/epistemic_neural_networks.py +++ b/pearl/neural_networks/common/epistemic_neural_networks.py @@ -192,9 +192,14 @@ def forward(self, x: Tensor, z: Tensor) -> Tensor: Output: ensemble output of x weighted by epistemic index vector z. """ - outputs = torch.vmap(self.call_single_model, (0, 0, None))( - self.params, self.buffers, x - ) + # vmap is not compatible with torchscript + # outputs = torch.vmap(self.call_single_model, (0, 0, None))( + # self.params, self.buffers, x + # ) + outputs = [] + for model in self.models: + outputs.append(model(x)) + outputs = torch.stack(outputs, dim=0) return torch.einsum("ijk,ji->jk", outputs, z)