Demo code not converging on L4 or A100 GPU #1446

@baptiste-pasquier

Description
I am encountering a convergence issue when running the Meridian_Getting_Started.ipynb demo. The code runs and the model converges on a T4 GPU, but it fails to converge with the exact same code and environment on an L4 or A100 GPU.

To Reproduce

%%time
mmm.sample_prior(500)
mmm.sample_posterior(
    n_chains=10, n_adapt=2000, n_burnin=500, n_keep=1000, seed=0
)
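For context, `mmm` is the Meridian model object built earlier in the notebook from the simulated demo data. A minimal sketch of that construction, following the getting-started flow (the data-loading step is omitted and the arguments shown are illustrative, not my exact configuration):

import tensorflow_probability as tfp

from meridian import constants
from meridian.model import model, prior_distribution, spec

# `data` is the InputData produced by the notebook's CSV-loading step
# (simulated demo dataset); omitted here for brevity.
prior = prior_distribution.PriorDistribution(
    roi_m=tfp.distributions.LogNormal(0.2, 0.9, name=constants.ROI_M)
)
model_spec = spec.ModelSpec(prior=prior)

mmm = model.Meridian(input_data=data, model_spec=model_spec)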

Expected Behavior
The model should converge similarly to how it performs on a T4 GPU.

Actual Behavior
On L4/A100 GPUs, I receive PTXAS warnings during compilation, followed by a total convergence failure.

Logs & Output
Compilation Warnings:

2026-02-01 19:40:50.474780: I external/local_xla/xla/stream_executor/cuda/subprocess_compilation.cc:346] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_427', 8 bytes spill stores, 8 bytes spill loads

2026-02-01 19:41:53.997341: I external/local_xla/xla/stream_executor/cuda/subprocess_compilation.cc:346] ptxas warning : Registers are spilled to local memory in function 'input_concatenate_fusion_36', 4 bytes spill stores, 4 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_concatenate_fusion_37', 4 bytes spill stores, 4 bytes spill loads

Model Reviewer Output:

reviewer.ModelReviewer(mmm).run()
========================================
Model Quality Checks
========================================
Overall Status: FAIL
Summary: Failed: Model did not converge. Other checks were skipped.

Check Results:
----------------------------------------
Convergence Check:
  Status: FAIL
  Recommendation: The model hasn't converged, and the `max_r_hat` for parameter `mu_t` is inf. We recommend increasing MCMC iterations or investigating model misspecification (e.g., priors, multicollinearity) before proceeding.
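To see which parameters are actually diverging (the reviewer only reports the worst offender, `mu_t`, with an infinite `r_hat`), the per-parameter R-hat can be inspected directly with arviz. A minimal sketch, assuming `mmm.inference_data` exposes the fitted posterior as an arviz `InferenceData`, as in recent Meridian releases:

import arviz as az

# Split R-hat per parameter across all 10 chains; values well above ~1.01
# (or inf/NaN, as reported above) indicate non-convergence.
rhat = az.rhat(mmm.inference_data.posterior)
print(rhat.max())

# Closer look at the parameter flagged by the reviewer.
print(az.summary(mmm.inference_data, var_names=["mu_t"]))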

Environment

  • Hardware: L4 or A100 GPU (Issue present), T4 (Issue absent)
  • Meridian Version: 1.5
  • TensorFlow Version: 2.19.0
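The active accelerator and TensorFlow build can be confirmed from inside the notebook with standard TensorFlow calls (nothing Meridian-specific):

import tensorflow as tf

# Report the TF build and the compute capability of the visible GPU
# (T4 = 7.5, L4 = 8.9, A100 = 8.0).
print('TF version:', tf.__version__)
for gpu in tf.config.list_physical_devices('GPU'):
    print(tf.config.experimental.get_device_details(gpu))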

Attempted Solutions
I attempted the following fixes (suggested by Gemini) to mitigate potential XLA/OneDNN issues, but neither resolved the problem (a combined sketch follows the list):

  1. Disabling XLA Auto JIT:
    os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=0'
  2. Disabling OneDNN:
    os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
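For reference, a combined sketch of how the two flags are applied. Both environment variables are only picked up if they are set before TensorFlow is first imported, so the ordering below is an assumption about where they need to go in the notebook:

import os

# Disable XLA auto-JIT clustering and oneDNN custom ops; both must be set
# before TensorFlow is imported to have any effect.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=0'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import tensorflow as tf  # imported only after the environment variables are set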
