This repository was archived by the owner on Apr 23, 2025. It is now read-only.
generated from amazon-archives/__template_Apache-2.0
bug: ConcretizationTypeError when trying to use prob_model.predictive() #101
Closed
Labels: bug (Something isn't working)
Description
Bug Report
Hi! I've trained a prob_model and created checkpoints. I then run prob_model.load_state and attempt to produce predictions on the test set. However, I'm getting the following error:
...
pspec=PartitionSpec('processes',)
] b
from line /home/pscemama/bayesian-conformal-sets/.venv/lib/python3.10/site-packages/orbax/checkpoint/utils.py:63 (sync_global_devices)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The only non-standard thing I've done is use my own custom model, which is here:
from typing import Any
import flax.linen as nn
import jax.numpy as jnp
import jax
act = jax.nn.swish


class AlexNet(nn.Module):
    """
    An AlexNet model for Cifar10.
    """

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        self.hidden_layers = AlexNetHiddenLayers(dtype=self.dtype)
        self.last_layer = AlexNetLastLayer(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        x = self.hidden_layers(x, train)
        x = self.last_layer(x, train)
        return x


class AlexNetHiddenLayers(nn.Module):
    """
    Hidden Convolutional layers of AlexNet model
    """

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        # [32, 32, 3]
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        # [32, 32, 64]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [16, 16, 64]
        x = nn.Conv(features=128, kernel_size=(3,))(x)
        # [16, 16, 128]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [8, 8, 128]
        x = nn.Conv(features=256, kernel_size=(2,))(x)
        # [8, 8, 256]
        x = act(x)
        x = nn.Conv(features=128, kernel_size=(2,))(x)
        # [8, 8, 128]
        x = act(x)
        x = nn.Conv(features=64, kernel_size=(2,))(x)
        # [8, 8, 64]
        x = act(x)
        # Flatten everything except the batch dimension
        x = x.reshape((x.shape[0], -1))
        return x


class AlexNetLastLayer(nn.Module):
    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=self.output_dim, dtype=self.dtype)(x)
        return x

Steps to reproduce:
# Model
prob_model = ProbClassifier(
    model=AlexNet(output_dim=10),
    posterior_approximator=LaplacePosteriorApproximator(),
    prior=IsotropicGaussianPrior(log_var=jnp.log(PRIOR_VAR)),
)
prob_model.load_state("../sgd_checkpoints/checkpoint_11532/")
test_log_probs = prob_model.predictive.log_prob(data_loader=test_loader)
# RAISES ERROR

Other information:
The data comes from a torch dataloader and is converted with .from_torch_dataloader(). Let me know if you need more information on the actual data.
My hunch is that maybe I'm doing something wrong here. Any guidance is appreciated :)
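
For readers unfamiliar with the error class named above, here is a minimal, generic sketch of how a ConcretizationTypeError can arise in JAX. It is not a diagnosis of this report: the traceback here points into orbax's sync_global_devices, so the cause in this issue may well be different. The function names below are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp


def abs_with_python_branch(x):
    # A Python `if` needs a concrete True/False, but under jit `x` is an
    # abstract tracer with no concrete value, so the comparison cannot be
    # converted to a bool.
    if x > 0:
        return x
    return -x


def abs_with_jnp_where(x):
    # Same computation expressed with array ops; nothing has to be concrete.
    return jnp.where(x > 0, x, -x)


def raises_concretization_error(fn, *args):
    """Return True if jitting and calling `fn` raises ConcretizationTypeError."""
    try:
        jax.jit(fn)(*args)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


failed = raises_concretization_error(abs_with_python_branch, jnp.float32(-2.0))
ok_value = float(jax.jit(abs_with_jnp_where)(jnp.float32(-2.0)))
print(failed, ok_value)
```

The general pattern is that some code path asks for the concrete value of an array while it is being traced; the linked JAX error page lists the common variants.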