diff --git a/include/sampling/random_point_generators.hpp b/include/sampling/random_point_generators.hpp index 998d8dafc..22e57e259 100644 --- a/include/sampling/random_point_generators.hpp +++ b/include/sampling/random_point_generators.hpp @@ -10,11 +10,12 @@ template < - typename Walk + typename Walk, + bool storeValue = true > struct RandomPointGenerator { - template + template < typename Polytope, typename Point, @@ -36,7 +37,10 @@ struct RandomPointGenerator for (unsigned int i=0; i struct MultivariateGaussianRandomPointGenerator { @@ -96,7 +104,10 @@ struct MultivariateGaussianRandomPointGenerator for (unsigned int i=0; i struct GaussianRandomPointGenerator { @@ -156,7 +171,10 @@ struct GaussianRandomPointGenerator for (unsigned int i=0; i +template +< + typename Walk, + bool storeValue = true +> struct BoundaryRandomPointGenerator { template @@ -216,8 +241,11 @@ struct BoundaryRandomPointGenerator for (unsigned int i=0; i struct LogconcaveRandomPointGenerator { @@ -250,14 +279,18 @@ struct LogconcaveRandomPointGenerator walk.apply(rng, walk_length); // Use PushBackWalkPolicy - policy.apply(randPoints, walk.x); + if constexpr (storeValue) // Only store points if storeValue is true + { + policy.apply(randPoints, walk.x); + } } } }; template < - typename Walk + typename Walk, + bool storeValue = true > struct CrhmcRandomPointGenerator { @@ -302,14 +335,20 @@ struct CrhmcRandomPointGenerator if((i + 1) * simdLen > rnum){ for(int j = 0; j < rnum-simdLen*i; j++){ Point p = Point(x.col(j)); + if constexpr (storeValue) // Only store points if storeValue is true + { policy.apply(randPoints, p); + } } break; } // Use PushBackWalkPolicy for(int j=0; j struct ExponentialRandomPointGenerator { @@ -349,7 +389,10 @@ struct ExponentialRandomPointGenerator //return; throw std::range_error("A generated point is outside polytope"); } - policy.apply(randPoints, p); + if constexpr (storeValue) // Only store points if storeValue is true + { + policy.apply(randPoints, p); + } } } @@ -384,7 +427,10 @@ struct ExponentialRandomPointGenerator //return; throw std::range_error("A generated point is outside polytope"); } - policy.apply(randPoints, p); + if constexpr (storeValue) // Only store points if storeValue is true + { + policy.apply(randPoints, p); + } } } diff --git a/include/sampling/sampling.hpp b/include/sampling/sampling.hpp index 696186098..72062ff24 100644 --- a/include/sampling/sampling.hpp +++ b/include/sampling/sampling.hpp @@ -48,13 +48,14 @@ void uniform_sampling(PointList &randPoints, Point p = starting_point; - typedef RandomPointGenerator RandomPointGenerator; + typedef RandomPointGenerator SamplingGenerator; + typedef RandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, nburns, walk_len, randPoints, push_back_policy, rng); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, rnum, walk_len, randPoints, + + SamplingGenerator::apply(P, p, rnum, walk_len, randPoints, push_back_policy, rng); @@ -84,15 +85,16 @@ void uniform_sampling(PointList &randPoints, //RandomNumberGenerator rng(P.dimension()); PushBackWalkPolicy push_back_policy; - typedef RandomPointGenerator RandomPointGenerator; - + typedef RandomPointGenerator SamplingGenerator; + typedef RandomPointGenerator BurnInGenerator; + Point p = starting_point; if (nburns > 0) { - RandomPointGenerator::apply(P, p, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, nburns, walk_len, randPoints, push_back_policy, rng, WalkType.param); - randPoints.clear(); + } - RandomPointGenerator::apply(P, p, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, rnum, walk_len, randPoints, push_back_policy, rng, WalkType.param); } @@ -124,14 +126,14 @@ void uniform_sampling_boundary(PointList &randPoints, Point p = starting_point; - typedef BoundaryRandomPointGenerator BoundaryRandomPointGenerator; + typedef BoundaryRandomPointGenerator SamplingGenerator; + typedef BoundaryRandomPointGenerator BurnInGenerator; if (nburns > 0) { - BoundaryRandomPointGenerator::apply(P, p, nburns, walk_len, + BurnInGenerator::apply(P, p, nburns, walk_len, randPoints, push_back_policy, rng); - randPoints.clear(); } unsigned int n = rnum / 2; - BoundaryRandomPointGenerator::apply(P, p, rnum / 2, walk_len, + SamplingGenerator::apply(P, p, rnum / 2, walk_len, randPoints, push_back_policy, rng); } @@ -167,13 +169,13 @@ void gaussian_sampling(PointList &randPoints, Point p = starting_point; - typedef GaussianRandomPointGenerator RandomPointGenerator; + typedef GaussianRandomPointGenerator SamplingGenerator; + typedef GaussianRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, a, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, a, nburns, walk_len, randPoints, push_back_policy, rng); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, a, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, a, rnum, walk_len, randPoints, push_back_policy, rng); @@ -210,13 +212,13 @@ void gaussian_sampling(PointList &randPoints, Point p = starting_point; - typedef GaussianRandomPointGenerator RandomPointGenerator; + typedef GaussianRandomPointGenerator SamplingGenerator; + typedef GaussianRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, a, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, a, nburns, walk_len, randPoints, push_back_policy, rng, WalkType.param); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, a, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, a, rnum, walk_len, randPoints, push_back_policy, rng, WalkType.param); } @@ -271,16 +273,15 @@ void logconcave_sampling(PointList &randPoints, walk logconcave_walk(&P, p, F, f, params); - typedef LogconcaveRandomPointGenerator RandomPointGenerator; - + typedef LogconcaveRandomPointGenerator SamplingGenerator; + typedef LogconcaveRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(nburns, walk_len, randPoints, + BurnInGenerator::apply(nburns, walk_len, randPoints, push_back_policy, rng, logconcave_walk); } logconcave_walk.disable_adaptive(); - randPoints.clear(); - RandomPointGenerator::apply(rnum, walk_len, randPoints, + SamplingGenerator::apply(rnum, walk_len, randPoints, push_back_policy, rng, logconcave_walk); } #include "preprocess/crhmc/crhmc_input.h" @@ -348,13 +349,12 @@ void crhmc_sampling(PointList &randPoints, walk crhmc_walk = walk(problem, p, input.df, input.f, params); - typedef CrhmcRandomPointGenerator RandomPointGenerator; - - RandomPointGenerator::apply(problem, p, nburns, walk_len, randPoints, + typedef CrhmcRandomPointGenerator SamplingGenerator; + typedef CrhmcRandomPointGenerator BurnInGenerator; + BurnInGenerator::apply(problem, p, nburns, walk_len, randPoints, push_back_policy, rng, F, f, params, crhmc_walk); //crhmc_walk.disable_adaptive(); - randPoints.clear(); - RandomPointGenerator::apply(problem, p, rnum, walk_len, randPoints, + SamplingGenerator::apply(problem, p, rnum, walk_len, randPoints, push_back_policy, rng, F, f, params, crhmc_walk, simdLen, raw_output); } #include "ode_solvers/ode_solvers.hpp" @@ -464,13 +464,13 @@ void exponential_sampling(PointList &randPoints, Point p = starting_point; - typedef ExponentialRandomPointGenerator RandomPointGenerator; + typedef ExponentialRandomPointGenerator SamplingGenerator; + typedef ExponentialRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, c, a, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, c, a, nburns, walk_len, randPoints, push_back_policy, rng); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, c, a, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, c, a, rnum, walk_len, randPoints, push_back_policy, rng); } @@ -505,13 +505,13 @@ void exponential_sampling(PointList &randPoints, Point p = starting_point; - typedef ExponentialRandomPointGenerator RandomPointGenerator; + typedef ExponentialRandomPointGenerator SamplingGenerator; + typedef ExponentialRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, c, a, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, c, a, nburns, walk_len, randPoints, push_back_policy, rng, WalkType.param); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, c, a, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, c, a, rnum, walk_len, randPoints, push_back_policy, rng, WalkType.param); } @@ -547,13 +547,13 @@ void exponential_sampling(PointList &randPoints, Point p = starting_point; - typedef ExponentialRandomPointGenerator RandomPointGenerator; + typedef ExponentialRandomPointGenerator SamplingGenerator; + typedef ExponentialRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, c, a, eta, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, c, a, eta, nburns, walk_len, randPoints, push_back_policy, rng); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, c, a, eta, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, c, a, eta, rnum, walk_len, randPoints, push_back_policy, rng); } @@ -589,13 +589,13 @@ void exponential_sampling(PointList &randPoints, Point p = starting_point; - typedef ExponentialRandomPointGenerator RandomPointGenerator; + typedef ExponentialRandomPointGenerator SamplingGenerator; + typedef ExponentialRandomPointGenerator BurnInGenerator; if (nburns > 0) { - RandomPointGenerator::apply(P, p, c, a, eta, nburns, walk_len, randPoints, + BurnInGenerator::apply(P, p, c, a, eta, nburns, walk_len, randPoints, push_back_policy, rng, WalkType.param); - randPoints.clear(); } - RandomPointGenerator::apply(P, p, c, a, eta, rnum, walk_len, randPoints, + SamplingGenerator::apply(P, p, c, a, eta, rnum, walk_len, randPoints, push_back_policy, rng, WalkType.param); }