@@ -62,6 +62,12 @@ BlackJAX offers another JAX-based sampling implementation focused on flexibility
 
 +++
 
+## Installation Requirements
+
+To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
+
++++
+
 ## Performance Guidelines
 
 Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.
@@ -73,28 +79,57 @@ Models containing **discrete variables** must use PyMC's built-in sampler, as it
 **Numba** excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. **JAX** offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The **C** backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives.
 
 ```{code-cell} ipython3
-import platform
+import time
+
+from collections import defaultdict
 
 import arviz as az
 import matplotlib.pyplot as plt
 import numpy as np
+import numpyro
+import pandas as pd
 import pymc as pm
 
-if platform.system() == "linux":
-    import multiprocessing
+numpyro.set_host_device_count(4)
 
-    multiprocessing.set_start_method("spawn", force=True)
+%config InlineBackend.figure_format = 'retina'
+az.style.use("arviz-darkgrid")
 
 rng = np.random.default_rng(seed=42)
 print(f"Running on PyMC v{pm.__version__}")
 ```
 
 ```{code-cell} ipython3
-%config InlineBackend.figure_format = 'retina'
-az.style.use("arviz-darkgrid")
-```
+import time
+
+from collections import defaultdict
+
+# Dictionary to store all results
+results = defaultdict(dict)
+
+
+class TimingContext:
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        self.start_wall = time.perf_counter()
+        self.start_cpu = time.process_time()
+        return self
+
+    def __exit__(self, *args):
+        self.end_wall = time.perf_counter()
+        self.end_cpu = time.process_time()
+
+        wall_time = self.end_wall - self.start_wall
+        cpu_time = self.end_cpu - self.start_cpu
 
-We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.
+        results[self.name]["wall_time"] = wall_time
+        results[self.name]["cpu_time"] = cpu_time
+
+        print(f"Wall time: {wall_time:.1f} s")
+        print(f"CPU time: {cpu_time:.1f} s")
+```
 
 ```{code-cell} ipython3
 def build_toy_dataset(N, D, K, sigma=1):
@@ -129,10 +164,14 @@ plt.title("Simulated data set")
 ```
 
 ```{code-cell} ipython3
-with pm.Model() as PPCA:
-    w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
-    z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
-    x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+def ppca_model():
+    with pm.Model() as model:
+        w = pm.Normal(
+            "w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered()
+        )
+        z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
+        x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
+    return model
 ```
 
 ## Performance Comparison
@@ -142,44 +181,154 @@ Now let's compare the performance of different sampling backends on our PPCA mod
 ### 1. PyMC Default Sampler (Python NUTS)
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_pymc = pm.sample(progressbar=False)
+n_draws = 2000
+n_tune = 2000
+
+with TimingContext("PyMC Default"):
+    with ppca_model():
+        idata_pymc = pm.sample(draws=n_draws, tune=n_tune, progressbar=False)
+
+ess_pymc = az.ess(idata_pymc)
+min_ess = min([ess_pymc[var].values.min() for var in ess_pymc.data_vars])
+mean_ess = np.mean([ess_pymc[var].values.mean() for var in ess_pymc.data_vars])
+results["PyMC Default"]["min_ess"] = min_ess
+results["PyMC Default"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-### 2. Nutpie with Numba Backend
+### 2. Nutpie Sampler with Numba Backend
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_nutpie_numba = pm.sample(
-        nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"}, progressbar=False
-    )
+with TimingContext("Nutpie Numba"):
+    with ppca_model():
+        idata_nutpie_numba = pm.sample(
+            draws=n_draws,
+            tune=n_tune,
+            nuts_sampler="nutpie",
+            nuts_sampler_kwargs={"backend": "numba"},
+            progressbar=False,
+        )
+
+ess_nutpie_numba = az.ess(idata_nutpie_numba)
+min_ess = min([ess_nutpie_numba[var].values.min() for var in ess_nutpie_numba.data_vars])
+mean_ess = np.mean([ess_nutpie_numba[var].values.mean() for var in ess_nutpie_numba.data_vars])
+results["Nutpie Numba"]["min_ess"] = min_ess
+results["Nutpie Numba"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-### 3. Nutpie with JAX Backend
+### 3. Nutpie Sampler with JAX Backend
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_nutpie_jax = pm.sample(
-        nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"}, progressbar=False
-    )
+with TimingContext("Nutpie JAX"):
+    with ppca_model():
+        idata_nutpie_jax = pm.sample(
+            draws=n_draws,
+            tune=n_tune,
+            nuts_sampler="nutpie",
+            nuts_sampler_kwargs={"backend": "jax"},
+            progressbar=False,
+        )
+
+ess_nutpie_jax = az.ess(idata_nutpie_jax)
+min_ess = min([ess_nutpie_jax[var].values.min() for var in ess_nutpie_jax.data_vars])
+mean_ess = np.mean([ess_nutpie_jax[var].values.mean() for var in ess_nutpie_jax.data_vars])
+results["Nutpie JAX"]["min_ess"] = min_ess
+results["Nutpie JAX"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
 ### 4. NumPyro Sampler
 
 ```{code-cell} ipython3
-%%time
-with PPCA:
-    idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
+with TimingContext("NumPyro"):
+    with ppca_model():
+        idata_numpyro = pm.sample(
+            draws=n_draws, tune=n_tune, nuts_sampler="numpyro", progressbar=False
+        )
+
+ess_numpyro = az.ess(idata_numpyro)
+min_ess = min([ess_numpyro[var].values.min() for var in ess_numpyro.data_vars])
+mean_ess = np.mean([ess_numpyro[var].values.mean() for var in ess_numpyro.data_vars])
+results["NumPyro"]["min_ess"] = min_ess
+results["NumPyro"]["mean_ess"] = mean_ess
+print(f"Min ESS: {min_ess:.0f}, Mean ESS: {mean_ess:.0f}")
 ```
 
-## Installation Requirements
+```{code-cell} ipython3
+timing_data = []
+for backend_name, metrics in results.items():
+    wall_time = metrics.get("wall_time", 0)
+    cpu_time = metrics.get("cpu_time", 0)
+    min_ess = metrics.get("min_ess", 0)
+    mean_ess = metrics.get("mean_ess", 0)
+    ess_per_sec = mean_ess / wall_time if wall_time > 0 else 0
+
+    timing_data.append(
+        {
+            "Sampling Backend": backend_name,
+            "Wall Time (s)": f"{wall_time:.1f}",
+            "CPU Time (s)": f"{cpu_time:.1f}",
+            "Min ESS": f"{min_ess:.0f}",
+            "Mean ESS": f"{mean_ess:.0f}",
+            "ESS/sec": f"{ess_per_sec:.0f}",
+            "Parallel Efficiency": f"{cpu_time/wall_time:.2f}" if wall_time > 0 else "N/A",
+        }
+    )
 
-To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
+timing_df = pd.DataFrame(timing_data)
+# Sort numerically; the "ESS/sec" column holds formatted strings.
+timing_df = timing_df.sort_values("ESS/sec", ascending=False, key=lambda s: s.astype(float))
 
-+++
+print("\nPerformance Summary Table:")
+print("=" * 100)
+print(timing_df.to_string(index=False))
+print("=" * 100)
+
+best_backend = timing_df.iloc[0]["Sampling Backend"]
+best_ess_per_sec = timing_df.iloc[0]["ESS/sec"]
+print(f"\nMost efficient backend: {best_backend} with {best_ess_per_sec} ESS/second")
+```
+
+```{code-cell} ipython3
+fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+backends = timing_df["Sampling Backend"].tolist()
+wall_times = [float(val) for val in timing_df["Wall Time (s)"].tolist()]
+mean_ess_values = [float(val) for val in timing_df["Mean ESS"].tolist()]
+ess_per_sec_values = [float(val) for val in timing_df["ESS/sec"].tolist()]
+
+ax1.bar(backends, wall_times, color="skyblue")
+ax1.set_ylabel("Wall Time (seconds)")
+ax1.set_title("Sampling Time")
+ax1.tick_params(axis="x", rotation=45)
+
+ax2.bar(backends, mean_ess_values, color="lightgreen")
+ax2.set_ylabel("Mean ESS")
+ax2.set_title("Effective Sample Size")
+ax2.tick_params(axis="x", rotation=45)
+
+ax3.bar(backends, ess_per_sec_values, color="coral")
+ax3.set_ylabel("ESS per Second")
+ax3.set_title("Sampling Efficiency")
+ax3.tick_params(axis="x", rotation=45)
+
+ax4.scatter(wall_times, mean_ess_values, s=200, alpha=0.6)
+for i, backend in enumerate(backends):
+    ax4.annotate(
+        backend,
+        (wall_times[i], mean_ess_values[i]),
+        xytext=(5, 5),
+        textcoords="offset points",
+        fontsize=9,
+    )
+ax4.set_xlabel("Wall Time (seconds)")
+ax4.set_ylabel("Mean ESS")
+ax4.set_title("Time vs. Effective Sample Size")
+ax4.grid(True, alpha=0.3)
+
+plt.tight_layout()
+plt.show()
+```
 
 ## Special Cases and Advanced Usage
 
@@ -190,13 +339,13 @@ In certain scenarios, you may need to use PyMC's Python-based sampler while stil
 The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.
 
 ```{code-cell} ipython3
-with PPCA:
+with ppca_model():
     idata_c = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "fast_run"}, progressbar=False)
 
-# with PPCA:
+# with ppca_model():
 #     idata_pymc_numba = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "numba"}, progressbar=False)
 
-# with PPCA:
+# with ppca_model():
 #     idata_pymc_jax = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "jax"}, progressbar=False)
 ```
 
@@ -221,12 +370,9 @@ with pm.Model() as discrete_model:
 ## Authors
 
 - Originally authored by Thomas Wiecki in July 2023
-- Substantially updated and expanded by Chris Fonnesbeck in May 2025
+- Updated and expanded by Chris Fonnesbeck in May 2025
 
 ```{code-cell} ipython3
 %load_ext watermark
 %watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
 ```
-
-:::{include} ../page_footer.md
-:::
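
The `TimingContext` helper added in this diff can be exercised on its own, without PyMC. A minimal standalone sketch, assuming only the standard library; the busy-loop workload is a hypothetical stand-in for `pm.sample`, and the final line reproduces the "Parallel Efficiency" ratio (`cpu_time / wall_time`) used in the summary table:

```python
import time
from collections import defaultdict

# Same pattern as the TimingContext class introduced in the diff above.
results = defaultdict(dict)


class TimingContext:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start_wall = time.perf_counter()
        self.start_cpu = time.process_time()
        return self

    def __exit__(self, *args):
        wall_time = time.perf_counter() - self.start_wall
        cpu_time = time.process_time() - self.start_cpu
        results[self.name]["wall_time"] = wall_time
        results[self.name]["cpu_time"] = cpu_time
        print(f"Wall time: {wall_time:.1f} s")
        print(f"CPU time: {cpu_time:.1f} s")


# Illustrative workload only (stands in for pm.sample).
with TimingContext("busy loop"):
    _ = sum(i * i for i in range(10**6))

# Parallel efficiency as reported in the summary table.
eff = results["busy loop"]["cpu_time"] / results["busy loop"]["wall_time"]
print(f"Parallel efficiency: {eff:.2f}")
```

For a single-threaded workload like this one the ratio stays near 1; multi-chain samplers that fan out across cores report values above 1.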