This repository was archived by the owner on Apr 23, 2025. It is now read-only.

bug: ConcretizationTypeError when trying to use prob_model.predictive() #101

@PaulScemama

Bug Report

Hi! I've trained a `prob_model` and created checkpoints. I then call `prob_model.load_state` and try to compute predictions on the test set, but I get the following error:

```
...
  pspec=PartitionSpec('processes',)
] b
    from line /home/pscemama/bayesian-conformal-sets/.venv/lib/python3.10/site-packages/orbax/checkpoint/utils.py:63 (sync_global_devices)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
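For context, JAX raises `ConcretizationTypeError` when Python code needs a concrete value (for example, to evaluate an `if` or call `bool()`) but receives an abstract tracer, which is what array arguments become inside `jax.jit`. The error type can be reproduced without Fortuna or Orbax; the sketch below shows it together with two common fixes. The function names are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def relu_python_if(x):
    # A Python `if` calls bool() on a traced value, which fails while tracing under jit.
    if x > 0:
        return x
    return jnp.zeros_like(x)


@jax.jit
def relu_where(x):
    # Fix 1: express the branch with traceable array ops instead of Python control flow.
    return jnp.where(x > 0, x, 0.0)


@partial(jax.jit, static_argnums=1)
def scale_if(x, flag):
    # Fix 2: mark the argument that drives Python control flow as static,
    # so it stays a concrete Python value during tracing.
    if flag:
        return 2.0 * x
    return x


def raises_concretization_error(fn, *args):
    """Return True if calling fn(*args) raises a ConcretizationTypeError."""
    try:
        fn(*args)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


x = jnp.array(3.0)
print(raises_concretization_error(relu_python_if, x))  # True
print(relu_where(x))  # 3.0
print(scale_if(x, True))  # 6.0
```

Newer JAX versions raise `TracerBoolConversionError` for the `if` case, which is a subclass of `ConcretizationTypeError`, so the `except` clause catches both.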

The only non-standard thing I've done is use my own custom model:

```python
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

act = jax.nn.swish


class AlexNet(nn.Module):
    """An AlexNet model for CIFAR-10."""

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        self.hidden_layers = AlexNetHiddenLayers(dtype=self.dtype)
        self.last_layer = AlexNetLastLayer(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        x = self.hidden_layers(x, train)
        x = self.last_layer(x, train)
        return x


class AlexNetHiddenLayers(nn.Module):
    """Hidden convolutional layers of the AlexNet model."""

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        # [32, 32, 3]
        # A one-element kernel_size such as (3,) makes Flax treat the height
        # axis as a batch axis, i.e. a 1D convolution along the width.
        x = nn.Conv(features=64, kernel_size=(3,))(x)
        # [32, 32, 64]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [16, 16, 64]

        x = nn.Conv(features=128, kernel_size=(3,))(x)
        # [16, 16, 128]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [8, 8, 128]

        x = nn.Conv(features=256, kernel_size=(2,))(x)
        # [8, 8, 256]
        x = act(x)

        x = nn.Conv(features=128, kernel_size=(2,))(x)
        # [8, 8, 128]
        x = act(x)

        x = nn.Conv(features=64, kernel_size=(2,))(x)
        # [8, 8, 64]
        x = act(x)

        # Flatten to [8 * 8 * 64] = [4096] features per example.
        x = x.reshape((x.shape[0], -1))
        return x


class AlexNetLastLayer(nn.Module):
    """Fully connected head mapping the flattened features to output_dim logits."""

    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=self.output_dim, dtype=self.dtype)(x)
        return x
```

Steps to reproduce:

```python
# Model
prob_model = ProbClassifier(
    model=AlexNet(output_dim=10),
    posterior_approximator=LaplacePosteriorApproximator(),
    prior=IsotropicGaussianPrior(log_var=jnp.log(PRIOR_VAR)),
)
prob_model.load_state("../sgd_checkpoints/checkpoint_11532/")
test_log_probs = prob_model.predictive.log_prob(data_loader=test_loader)
# Raises ConcretizationTypeError
```
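One generic way to locate the Python line that needs a concrete value is to rerun the failing call with JIT disabled: values then stay concrete, so the offending code either runs or fails with an ordinary traceback. `jax.disable_jit()` is a JAX context manager; wrapping the `prob_model.predictive.log_prob(...)` call in it is an untested suggestion, it can be slow, and it may not help if the failing code traces explicitly rather than through `jit`. A self-contained sketch of the mechanism, with a made-up function `clipped_log`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def clipped_log(x):
    # Python control flow on an array value: fails while tracing under jit.
    if x <= 0:
        return jnp.array(-jnp.inf)
    return jnp.log(x)


def call_with_and_without_jit(x):
    """Return (jitted_result_or_None, eager_result) for clipped_log(x)."""
    try:
        jitted = clipped_log(x)
    except jax.errors.ConcretizationTypeError:
        jitted = None  # tracing failed on the `if`
    # Inside disable_jit, x stays concrete, so the `if` can be evaluated
    # and the function runs eagerly, op by op.
    with jax.disable_jit():
        eager = clipped_log(x)
    return jitted, eager


jitted, eager = call_with_and_without_jit(jnp.array(1.0))
print(jitted, eager)  # None 0.0
```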

Other information:

The data comes from a PyTorch DataLoader and is converted with `.from_torch_dataloader()`. Let me know if you need more information about the data itself.
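One thing worth ruling out when feeding a PyTorch loader into a Flax model: torchvision's `ToTensor` transform produces channel-first `(C, H, W)` images, so batches are `(N, C, H, W)`, whereas `flax.linen.Conv` expects channels last, `(N, H, W, C)`, as the shape comments in the model assume. Nothing in the report links this to the `ConcretizationTypeError`; it is only an assumption to check. A minimal NumPy sketch of the conversion, using a hypothetical helper `to_nhwc`:

```python
import numpy as np


def to_nhwc(images: np.ndarray) -> np.ndarray:
    """Convert a channel-first (N, C, H, W) batch to channel-last (N, H, W, C)."""
    if images.ndim != 4:
        raise ValueError(f"expected a 4D batch, got shape {images.shape}")
    return np.transpose(images, (0, 2, 3, 1))


# Stand-in for one image batch from a torch DataLoader after calling .numpy() on it.
batch_nchw = np.zeros((128, 3, 32, 32), dtype=np.float32)
batch_nhwc = to_nhwc(batch_nchw)
print(batch_nhwc.shape)  # (128, 32, 32, 3)
```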

I suspect I may be doing something wrong here. Any guidance is appreciated :)

Metadata

- Assignees: No one assigned
- Labels: bug (Something isn't working)
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests