Skip to content

Commit

Permalink
refactor: use functions in prec to follow crate precision defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Mar 5, 2025
1 parent b4a7556 commit a2c6c23
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 93 deletions.
2 changes: 1 addition & 1 deletion src/distribution/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ mod tests {
// let res = n.std_dev();
// for i in 1..11 {
// let f = i as f64;
// assert_almost_eq!(res[i-1], (f * (sum - f) / (sum * sum * (sum + 1.0))).sqrt(), 1e-15);
// prec::assert_abs_diff_eq!(res[i-1], (f * (sum - f) / (sum * sum * (sum + 1.0))).sqrt(), epsilon = 1e-15);
// }
// }

Expand Down
8 changes: 4 additions & 4 deletions src/distribution/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,10 @@ pub mod test {
}

prec::assert_abs_diff_eq!(sum, dist.cdf(i), epsilon = 1e-10);
// assert_almost_eq!(sum, dist.cdf(i as f64), 1e-10);
// assert_almost_eq!(sum, dist.cdf(i as f64 + 0.1), 1e-10);
// assert_almost_eq!(sum, dist.cdf(i as f64 + 0.5), 1e-10);
// assert_almost_eq!(sum, dist.cdf(i as f64 + 0.9), 1e-10);
// prec::assert_abs_diff_eq!(sum, dist.cdf(i as f64), epsilon = 1e-10);
// prec::assert_abs_diff_eq!(sum, dist.cdf(i as f64 + 0.1), epsilon = 1e-10);
// prec::assert_abs_diff_eq!(sum, dist.cdf(i as f64 + 0.5), epsilon = 1e-10);
// prec::assert_abs_diff_eq!(sum, dist.cdf(i as f64 + 0.9), epsilon = 1e-10);
}

assert!(sum > 0.99);
Expand Down
6 changes: 3 additions & 3 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,13 +551,13 @@ mod tests {
// let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
// let n = Multinomial::new(large_p, 45).unwrap();
// let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9];
// assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13);
// prec::assert_abs_diff_eq!(n.pmf(x).ln(), n.ln_pmf(x), epsilon = 1e-13);
// let n2 = Multinomial::new(large_p, 18).unwrap();
// let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3];
// assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13);
// prec::assert_abs_diff_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), epsilon = 1e-13);
// let n3 = Multinomial::new(large_p, 51).unwrap();
// let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3];
// assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13);
// prec::assert_abs_diff_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), epsilon = 1e-13);
// }

// #[test]
Expand Down
60 changes: 30 additions & 30 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ mod tests {
use core::fmt::Debug;
use crate::prec;

use approx::RelativeEq;
use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, U1, U2};

use crate::{
Expand All @@ -405,6 +404,9 @@ mod tests {

use super::MultivariateStudentError;

// Module-specific precision constants
const MODULE_RELATIVE_EQ: f64 = 1e-15; // Default relative precision for relative comparisons

fn try_create(location: Vec<f64>, scale: Vec<f64>, freedom: f64) -> MultivariateStudent<Dyn>
{
let mvs = MultivariateStudent::new(location, scale, freedom);
Expand Down Expand Up @@ -435,22 +437,22 @@ mod tests {
assert_eq!(expected, x);
}

fn test_almost<F>(
fn test_relative<F>(
location: Vec<f64>,
scale: Vec<f64>,
freedom: f64,
expected: f64,
acc: f64,
max_relative: f64,
eval: F,
) where
F: FnOnce(MultivariateStudent<Dyn>) -> f64,
{
let mvs = try_create(location, scale, freedom);
let x = eval(mvs);
prec::assert_abs_diff_eq!(expected, x, epsilon = acc);
prec::assert_relative_eq!(expected, x, epsilon = prec::DEFAULT_EPS, max_relative = max_relative);
}

fn test_almost_multivariate_normal<F1, F2>(
fn test_abs_diff_multivariate_normal<F1, F2>(
location: Vec<f64>,
scale: Vec<f64>,
freedom: f64,
Expand All @@ -468,11 +470,9 @@ mod tests {
let mvn = mvn0.unwrap();
let mvs_x = eval_mvs(mvs, x.clone());
let mvn_x = eval_mvn(mvn, x.clone());
assert!(mvs_x.relative_eq(&mvn_x, acc, acc), "mvn: {mvn_x} =/=\nmvs: {mvs_x}");
// assert_relative_eq!(mvs_x, mvn_x, acc);
prec::assert_abs_diff_eq!(mvs_x, mvn_x, epsilon = acc);
}


macro_rules! dvec {
($($x:expr),*) => (DVector::from_vec(vec![$($x),*]));
}
Expand Down Expand Up @@ -559,45 +559,45 @@ mod tests {
#[test]
fn test_pdf() {
let pdf = |arg: DVector<f64>| move |x: MultivariateStudent<Dyn>| x.pdf(&arg);
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.]));
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.]));
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.]));
test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.]));
test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.]));
test_relative(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, MODULE_RELATIVE_EQ, pdf(dvec![1., 1.]));
test_relative(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, MODULE_RELATIVE_EQ, pdf(dvec![1., 2.]));
test_relative(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, MODULE_RELATIVE_EQ, pdf(dvec![1., 2.]));
test_relative(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, MODULE_RELATIVE_EQ, pdf(dvec![1., 10.]));
test_relative(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, MODULE_RELATIVE_EQ, pdf(dvec![10., 10.]));
// These three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf.
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134]));
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185]));
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.951631724511314e-16, 1e-30, pdf(dvec![0.3020, 0.1491, 0.5008]));
test_relative(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, MODULE_RELATIVE_EQ, pdf(dvec![0.9718, 0.1298, 0.8134]));
test_relative(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, MODULE_RELATIVE_EQ, pdf(dvec![0.4922, 0.5522, 0.7185]));
test_relative(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.951631724511314e-16, MODULE_RELATIVE_EQ, pdf(dvec![0.3020, 0.1491, 0.5008]));
test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.]));
}

#[test]
fn test_ln_pdf() {
let ln_pdf = |arg: DVector<f64>| move |x: MultivariateStudent<Dyn>| x.ln_pdf(&arg);
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, 1e-14, ln_pdf(dvec![1., 1.]));
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, 1e-14, ln_pdf(dvec![1., 2.]));
test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, 1e-14, ln_pdf(dvec![1., 10.]));
test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, 1e-14, ln_pdf(dvec![10., 10.]));
// test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., f64::NEG_INFINITY, ln_pdf(dvec![10., 10.]));
test_relative(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, MODULE_RELATIVE_EQ, ln_pdf(dvec![1., 1.]));
test_relative(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, MODULE_RELATIVE_EQ, ln_pdf(dvec![1., 2.]));
test_relative(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, MODULE_RELATIVE_EQ, ln_pdf(dvec![1., 10.]));
test_relative(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, MODULE_RELATIVE_EQ, ln_pdf(dvec![10., 10.]));
}

#[test]
fn test_pdf_freedom_large() {
let pdf_mvs = |mv: MultivariateStudent<Dyn>, arg: DVector<f64>| mv.pdf(&arg);
let pdf_mvn = |mv: MultivariateNormal<Dyn>, arg: DVector<f64>| mv.pdf(&arg);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1e-4, 1e-4], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-6, dvec![1e-4, 1e-4], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn);
}

#[test]
fn test_ln_pdf_freedom_large() {
let pdf_mvs = |mv: MultivariateStudent<Dyn>, arg: DVector<f64>| mv.ln_pdf(&arg);
let pdf_mvn = |mv: MultivariateNormal<Dyn>, arg: DVector<f64>| mv.ln_pdf(&arg);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
test_abs_diff_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
}

#[test]
Expand All @@ -606,7 +606,7 @@ mod tests {
let mvs = MultivariateStudent::new(vec![1., 1.], vec![1., 0., 0., 1.], 2.)
.expect("hard coded valid construction");
assert_eq!(mvs.freedom(), 2.);
prec::assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = 1e-15);
prec::assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = prec::DEFAULT_EPS, max_relative = MODULE_RELATIVE_EQ);

// compare to static
assert_eq!(mvs.dim(), 2);
Expand Down
4 changes: 2 additions & 2 deletions src/distribution/negative_binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ mod tests {
let theoretical_mean = dist.mean().unwrap();
let theoretical_variance = dist.variance().unwrap();

assert!(prec::almost_eq(sample_mean, theoretical_mean, tol));
assert!(prec::almost_eq(sample_variance, theoretical_variance, tol));
prec::assert_abs_diff_eq!(sample_mean, theoretical_mean, epsilon = tol);
prec::assert_abs_diff_eq!(sample_variance, theoretical_variance, epsilon = tol);
}
}
Loading

0 comments on commit a2c6c23

Please sign in to comment.