Skip to content

Commit a855547

Browse files
authored
Replaces the Poisson rejection method implementation (#1560)
- [x] Added a `CHANGELOG.md` entry # Summary As discussed in #1515, this PR replaces the implementation of `poisson::RejectionMethod` with a new algorithm based on the [paper ](https://dl.acm.org/doi/10.1145/355993.355997). # Motivation The new implementation offers improved performance and maintains better sampling distribution, especially for extreme values of lambda (> 1e9). # Details In terms of performance, here are the benchmarks I ran, with the current implementation as the baseline: ```text poisson/100 time: [45.5242 cycles 45.6734 cycles 45.8337 cycles] change: [-86.572% -86.507% -86.438%] (p = 0.00 < 0.05) Performance has improved. Found 5 outliers among 100 measurements (5.00%) 2 (2.00%) low mild 2 (2.00%) high mild 1 (1.00%) high severe poisson/variable time: [5494.6626 cycles 5508.2882 cycles 5523.2298 cycles] thrpt: [5523.2298 cycles/100 5508.2882 cycles/100 5494.6626 cycles/100] change: time: [-76.728% -76.573% -76.430%] (p = 0.00 < 0.05) thrpt: [+324.27% +326.85% +329.69%] Performance has improved. Found 5 outliers among 100 measurements (5.00%) 1 (1.00%) low mild 3 (3.00%) high mild 1 (1.00%) high severe ```
1 parent 67fd92e commit a855547

File tree

5 files changed

+125
-120
lines changed

5 files changed

+125
-120
lines changed

distr_test/tests/cdf.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,9 @@ fn hypergeometric() {
427427
fn poisson() {
428428
use rand_distr::Poisson;
429429
let parameters = [
430-
0.1, 1.0, 7.5,
431-
45.0, // 1e9, passed case but too slow
432-
// 1.844E+19, // fail case
430+
0.1, 1.0, 7.5, 15.0, 45.0, 98.0, 230.0, 4567.5,
431+
4.4541e7, // 1e10, //passed case but too slow
432+
// 1.844E+19, // fail case
433433
];
434434

435435
for (seed, lambda) in parameters.into_iter().enumerate() {

rand_distr/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4848
This breaks serialization compatibility with older versions.
4949
- Add plots for `rand_distr` distributions to documentation (#1434)
5050
- Move some of the computations in Binomial from `sample` to `new` (#1484)
51+
- Reimplement `Poisson`'s rejection method to improve performance and correct sampling inaccuracies for large lambda values, this is a Value-breaking change (#1560)
5152

5253
## [0.4.3] - 2021-12-30
5354
- Fix `no_std` build (#1208)

rand_distr/src/poisson.rs

+120-73
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
//! The Poisson distribution `Poisson(λ)`.
1111
12-
use crate::{Cauchy, Distribution, StandardUniform};
12+
use crate::{Distribution, Exp1, Normal, StandardNormal, StandardUniform};
1313
use core::fmt;
1414
use num_traits::{Float, FloatConst};
1515
use rand::Rng;
@@ -101,21 +101,37 @@ impl<F: Float> KnuthMethod<F> {
101101
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
102102
struct RejectionMethod<F> {
103103
lambda: F,
104-
log_lambda: F,
105-
sqrt_2lambda: F,
106-
magic_val: F,
104+
s: F,
105+
d: F,
106+
l: F,
107+
c: F,
108+
c0: F,
109+
c1: F,
110+
c2: F,
111+
c3: F,
112+
omega: F,
107113
}
108114

109-
impl<F: Float> RejectionMethod<F> {
115+
impl<F: Float + FloatConst> RejectionMethod<F> {
110116
pub(crate) fn new(lambda: F) -> Self {
111-
let log_lambda = lambda.ln();
112-
let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt();
113-
let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda);
117+
let b1 = F::from(1.0 / 24.0).unwrap() / lambda;
118+
let b2 = F::from(0.3).unwrap() * b1 * b1;
119+
let c3 = F::from(1.0 / 7.0).unwrap() * b1 * b2;
120+
let c2 = b2 - F::from(15).unwrap() * c3;
121+
let c1 = b1 - F::from(6).unwrap() * b2 + F::from(45).unwrap() * c3;
122+
let c0 = F::one() - b1 + F::from(3).unwrap() * b2 - F::from(15).unwrap() * c3;
123+
114124
RejectionMethod {
115125
lambda,
116-
log_lambda,
117-
sqrt_2lambda,
118-
magic_val,
126+
s: lambda.sqrt(),
127+
d: F::from(6.0).unwrap() * lambda.powi(2),
128+
l: (lambda - F::from(1.1484).unwrap()).floor(),
129+
c: F::from(0.1069).unwrap() / lambda,
130+
c0,
131+
c1,
132+
c2,
133+
c3,
134+
omega: F::one() / (F::from(2).unwrap() * F::PI()).sqrt() / lambda.sqrt(),
119135
}
120136
}
121137
}
@@ -189,56 +205,114 @@ impl<F> Distribution<F> for RejectionMethod<F>
189205
where
190206
F: Float + FloatConst,
191207
StandardUniform: Distribution<F>,
208+
StandardNormal: Distribution<F>,
209+
Exp1: Distribution<F>,
192210
{
193211
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
194-
// The algorithm from Numerical Recipes in C
212+
// The algorithm is based on:
213+
// J. H. Ahrens and U. Dieter. 1982.
214+
// Computer Generation of Poisson Deviates from Modified Normal Distributions.
215+
// ACM Trans. Math. Softw. 8, 2 (June 1982), 163–179. https://doi.org/10.1145/355993.355997
216+
217+
// Step F
218+
let f = |k: F| {
219+
const FACT: [f64; 10] = [
220+
1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
221+
]; // factorial of 0..10
222+
const A: [f64; 10] = [
223+
-0.5000000002,
224+
0.3333333343,
225+
-0.2499998565,
226+
0.1999997049,
227+
-0.1666848753,
228+
0.1428833286,
229+
-0.1241963125,
230+
0.1101687109,
231+
-0.1142650302,
232+
0.1055093006,
233+
]; // coefficients from Table 1
234+
let (px, py) = if k < F::from(10.0).unwrap() {
235+
let px = -self.lambda;
236+
let py = self.lambda.powf(k) / F::from(FACT[k.to_usize().unwrap()]).unwrap();
237+
238+
(px, py)
239+
} else {
240+
let delta = (F::from(12.0).unwrap() * k).recip();
241+
let delta = delta - F::from(4.8).unwrap() * delta.powi(3);
242+
let v = (self.lambda - k) / k;
243+
244+
let px = if v.abs() <= F::from(0.25).unwrap() {
245+
k * v.powi(2)
246+
* A.iter()
247+
.rev()
248+
.fold(F::zero(), |acc, &a| {
249+
acc * v + F::from(a).unwrap()
250+
}) // Σ a_i * v^i
251+
- delta
252+
} else {
253+
k * (F::one() + v).ln() - (self.lambda - k) - delta
254+
};
255+
256+
let py = F::one() / (F::from(2.0).unwrap() * F::PI()).sqrt() / k.sqrt();
257+
258+
(px, py)
259+
};
260+
261+
let x = (k - self.lambda + F::from(0.5).unwrap()) / self.s;
262+
let fx = -F::from(0.5).unwrap() * x * x;
263+
let fy =
264+
self.omega * (((self.c3 * x * x + self.c2) * x * x + self.c1) * x * x + self.c0);
265+
266+
(px, py, fx, fy)
267+
};
268+
269+
// Step N
270+
let normal = Normal::new(self.lambda, self.s).unwrap();
271+
let g = normal.sample(rng);
272+
if g >= F::zero() {
273+
let k1 = g.floor();
274+
275+
// Step I
276+
if k1 >= self.l {
277+
return k1;
278+
}
195279

196-
// we use the Cauchy distribution as the comparison distribution
197-
// f(x) ~ 1/(1+x^2)
198-
let cauchy = Cauchy::new(F::zero(), F::one()).unwrap();
199-
let mut result;
280+
// Step S
281+
let u: F = rng.random();
282+
if self.d * u >= (self.lambda - k1).powi(3) {
283+
return k1;
284+
}
285+
286+
let (px, py, fx, fy) = f(k1);
287+
288+
if fy * (F::one() - u) <= py * (px - fx).exp() {
289+
return k1;
290+
}
291+
}
200292

201293
loop {
202-
let mut comp_dev;
203-
204-
loop {
205-
// draw from the Cauchy distribution
206-
comp_dev = rng.sample(cauchy);
207-
// shift the peak of the comparison distribution
208-
result = self.sqrt_2lambda * comp_dev + self.lambda;
209-
// repeat the drawing until we are in the range of possible values
210-
if result >= F::zero() {
211-
break;
294+
// Step E
295+
let e = Exp1.sample(rng);
296+
let u: F = rng.random() * F::from(2.0).unwrap() - F::one();
297+
let t = F::from(1.8).unwrap() + e * u.signum();
298+
if t > F::from(-0.6744).unwrap() {
299+
let k2 = (self.lambda + self.s * t).floor();
300+
let (px, py, fx, fy) = f(k2);
301+
// Step H
302+
if self.c * u.abs() <= py * (px + e).exp() - fy * (fx + e).exp() {
303+
return k2;
212304
}
213305
}
214-
// now the result is a random variable greater than 0 with Cauchy distribution
215-
// the result should be an integer value
216-
result = result.floor();
217-
218-
// this is the ratio of the Poisson distribution to the comparison distribution
219-
// the magic value scales the distribution function to a range of approximately 0-1
220-
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
221-
// this doesn't change the resulting distribution, only increases the rate of failed drawings
222-
let check = F::from(0.9).unwrap()
223-
* (F::one() + comp_dev * comp_dev)
224-
* (result * self.log_lambda
225-
- crate::utils::log_gamma(F::one() + result)
226-
- self.magic_val)
227-
.exp();
228-
229-
// check with uniform random value - if below the threshold, we are within the target distribution
230-
if rng.random::<F>() <= check {
231-
break;
232-
}
233306
}
234-
result
235307
}
236308
}
237309

238310
impl<F> Distribution<F> for Poisson<F>
239311
where
240312
F: Float + FloatConst,
241313
StandardUniform: Distribution<F>,
314+
StandardNormal: Distribution<F>,
315+
Exp1: Distribution<F>,
242316
{
243317
#[inline]
244318
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@@ -253,33 +327,6 @@ where
253327
mod test {
254328
use super::*;
255329

256-
fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
257-
where
258-
StandardUniform: Distribution<F>,
259-
{
260-
let poisson = Poisson::new(lambda).unwrap();
261-
let mut rng = crate::test::rng(123);
262-
let mut sum = F::zero();
263-
for _ in 0..1000 {
264-
sum = sum + poisson.sample(&mut rng);
265-
}
266-
let avg = sum / F::from(1000.0).unwrap();
267-
assert!((avg - lambda).abs() < tol);
268-
}
269-
270-
#[test]
271-
fn test_poisson_avg() {
272-
test_poisson_avg_gen::<f64>(10.0, 0.1);
273-
test_poisson_avg_gen::<f64>(15.0, 0.1);
274-
275-
test_poisson_avg_gen::<f32>(10.0, 0.1);
276-
test_poisson_avg_gen::<f32>(15.0, 0.1);
277-
278-
// Small lambda will use Knuth's method with exp_lambda == 1.0
279-
test_poisson_avg_gen::<f32>(0.00000000000000005, 0.1);
280-
test_poisson_avg_gen::<f64>(0.00000000000000005, 0.1);
281-
}
282-
283330
#[test]
284331
#[should_panic]
285332
fn test_poisson_invalid_lambda_zero() {

rand_distr/src/utils.rs

-43
Original file line numberDiff line numberDiff line change
@@ -9,52 +9,9 @@
99
//! Math helper functions
1010
1111
use crate::ziggurat_tables;
12-
use num_traits::Float;
1312
use rand::distr::hidden_export::IntoFloat;
1413
use rand::Rng;
1514

16-
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
17-
/// function) using the Lanczos approximation.
18-
///
19-
/// The approximation expresses the gamma function as:
20-
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)`
21-
/// `g` is an arbitrary constant; we use the approximation with `g=5`.
22-
///
23-
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides:
24-
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)`
25-
///
26-
/// `Ag(z)` is an infinite series with coefficients that can be calculated
27-
/// ahead of time - we use just the first 6 terms, which is good enough
28-
/// for most purposes.
29-
pub(crate) fn log_gamma<F: Float>(x: F) -> F {
30-
// precalculated 6 coefficients for the first 6 terms of the series
31-
let coefficients: [F; 6] = [
32-
F::from(76.18009172947146).unwrap(),
33-
F::from(-86.50532032941677).unwrap(),
34-
F::from(24.01409824083091).unwrap(),
35-
F::from(-1.231739572450155).unwrap(),
36-
F::from(0.1208650973866179e-2).unwrap(),
37-
F::from(-0.5395239384953e-5).unwrap(),
38-
];
39-
40-
// (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
41-
let tmp = x + F::from(5.5).unwrap();
42-
let log = (x + F::from(0.5).unwrap()) * tmp.ln() - tmp;
43-
44-
// the first few terms of the series for Ag(x)
45-
let mut a = F::from(1.000000000190015).unwrap();
46-
let mut denom = x;
47-
for &coeff in &coefficients {
48-
denom = denom + F::one();
49-
a = a + (coeff / denom);
50-
}
51-
52-
// get everything together
53-
// a is Ag(x)
54-
// 2.5066... is sqrt(2pi)
55-
log + (F::from(2.5066282746310005).unwrap() * a / x).ln()
56-
}
57-
5815
/// Sample a random number using the Ziggurat method (specifically the
5916
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
6017
/// directly from the paper:

rand_distr/tests/value_stability.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ fn poisson_stability() {
207207
test_samples(
208208
223,
209209
Poisson::new(27.0).unwrap(),
210-
&[28.0f32, 32.0, 36.0, 36.0],
210+
&[30.0f32, 33.0, 23.0, 25.0],
211211
);
212212
}
213213

0 commit comments

Comments
 (0)