@@ -44,7 +44,8 @@ void sample_points_eigen_matrix(HPolytopeType const& HP, Point const& q, Walk co
4444struct CustomFunctor {
4545
4646 // Custom density with neg log prob equal to c^T x
47- template <
47+ template
48+ <
4849 typename NT,
4950 typename Point
5051 >
@@ -59,10 +60,7 @@ struct CustomFunctor {
5960
6061 };
6162
62- template
63- <
64- typename Point
65- >
63+ template <typename Point>
6664 struct GradientFunctor {
6765 typedef typename Point::FT NT;
6866 typedef std::vector<Point> pts;
@@ -80,13 +78,71 @@ struct CustomFunctor {
8078 return xs[i + 1 ]; // returns derivative
8179 }
8280 }
81+ };
82+
83+ template <typename Point>
84+ struct FunctionFunctor {
85+ typedef typename Point::FT NT;
86+
87+ parameters<NT, Point> ¶ms;
88+
89+ FunctionFunctor (parameters<NT, Point> ¶ms_) : params(params_) {};
8390
91+ // The index i represents the state vector index
92+ NT operator () (Point const & x) const {
93+ Point y = x - params.x0 ;
94+ return 0.5 * y.dot (y);
95+ }
8496 };
97+ };
98+
99+ struct CustomGaussianFunctor {
85100
86101 template
87102 <
88- typename Point
103+ typename NT,
104+ typename Point
89105 >
106+ struct parameters {
107+ Point x0;
108+ NT a;
109+ NT eta;
110+ unsigned int order;
111+ NT L; // Lipschitz constant for gradient
112+ NT m; // Strong convexity constant
113+ NT kappa; // Condition number
114+
115+ parameters (Point x0_, NT a_, NT eta_) :
116+ x0 (x0_), a(a_), eta(eta_), order(2 ), L(2 * a_), m(2 * a_), kappa(1 ) {};
117+
118+ };
119+
120+ template <typename Point>
121+ struct GradientFunctor {
122+ typedef typename Point::FT NT;
123+ typedef std::vector<Point> pts;
124+
125+ parameters<NT, Point> ¶ms;
126+
127+ GradientFunctor (parameters<NT, Point> ¶ms_) : params(params_) {};
128+
129+ // The index i represents the state vector index
130+ /*
131+ Point operator() (unsigned int const& i, pts const& xs, NT const& t) const {
132+ if (i == params.order - 1) {
133+ Point y = (-2.0 * params.a) * (xs[0] - params.x0);
134+ return y;
135+ } else {
136+ return xs[i + 1]; // returns derivative
137+ }
138+ }*/
139+ Point operator ()(Point const &x){
140+ Point y = (-2.0 * params.a ) * (x - params.x0 );
141+ return y;
142+ }
143+ };
144+
145+ template <typename Point>
90146 struct FunctionFunctor {
91147 typedef typename Point::FT NT;
92148
@@ -97,9 +153,22 @@ struct CustomFunctor {
97153 // The index i represents the state vector index
98154 NT operator () (Point const & x) const {
99155 Point y = x - params.x0 ;
100- return 0.5 * y.dot (y);
156+ return params. a * y.dot (y);
101157 }
158+ };
102159
160+ template <typename Point>
161+ struct HessianFunctor {
162+ typedef typename Point::FT NT;
163+
164+ parameters<NT, Point> ¶ms;
165+
166+ HessianFunctor (parameters<NT, Point> ¶ms_) : params(params_) {};
167+
168+ // The index i represents the state vector index
169+ Point operator () (Point const & x) const {
170+ return (2.0 * params.a ) * Point::all_ones (x.dimension ());
171+ }
103172 };
104173
105174};
@@ -140,6 +209,7 @@ int main() {
140209
141210 HamiltonianMonteCarloWalk hmc_walk;
142211 NutsHamiltonianMonteCarloWalk nhmc_walk;
212+ CRHMCWalk crhmc_walk;
143213
144214 // Distributions
145215
@@ -161,45 +231,30 @@ int main() {
161231 ExponentialDistribution edistr (c, variance);
162232
163233 // 4. LogConcave
164- using NegativeGradientFunctor = CustomFunctor::GradientFunctor<Point>;
165- using NegativeLogprobFunctor = CustomFunctor::FunctionFunctor<Point>;
166- using Solver = LeapfrogODESolver<Point, NT, HPolytopeType, NegativeGradientFunctor>;
167234
168235 std::pair<Point, NT> inner_ball = HP.ComputeInnerBall ();
169236 Point x0 = inner_ball.first ;
170237
171- CustomFunctor::parameters<NT, Point> params (x0);
238+ // Reflective HMC and Remmannian HMC are using slightly different functor interfaces
239+ // TODO: check if this could be unified
172240
173- NegativeGradientFunctor g (params);
174- NegativeLogprobFunctor f (params);
175- LogConcaveDistribution logconcave (g, f, params.L );
241+ using NegativeGradientFunctorR = CustomFunctor::GradientFunctor<Point>;
242+ using NegativeLogprobFunctorR = CustomFunctor::FunctionFunctor<Point>;
243+ CustomFunctor::parameters<NT, Point> params_r (x0);
244+ NegativeGradientFunctorR gr (params_r);
245+ NegativeLogprobFunctorR fr (params_r);
246+ LogConcaveDistribution logconcave_reflective (gr, fr, params_r.L );
176247
177- /*
178- NegativeGradientFunctor F(params);
248+ using NegativeGradientFunctor = CustomGaussianFunctor::GradientFunctor<Point>;
249+ using NegativeLogprobFunctor = CustomGaussianFunctor::FunctionFunctor<Point>;
250+ using HessianFunctor = CustomGaussianFunctor::HessianFunctor<Point>;
251+ CustomGaussianFunctor::parameters<NT, Point> params (x0, 0.5 , 1 );
252+ NegativeGradientFunctor g (params);
179253 NegativeLogprobFunctor f (params);
254+ HessianFunctor h (params);
255+ LogConcaveDistribution logconcave_crhmc (g, f, h, params.L );
180256
181- HamiltonianMonteCarloWalk::parameters<NT, NegativeGradientFunctor> hmc_params(F, HP.dimension());
182-
183- HamiltonianMonteCarloWalk hmc(F, f, hmc_params);
184-
185- int n_samples = 80000;
186- int n_burns = 0;
187-
188- MT samples;
189- samples.resize(dim, n_samples - n_burns);
190257
191- hmc.solver->eta0 = 0.5;
192-
193- for (int i = 0; i < n_samples; i++) {
194- if (i % 1000 == 0) std::cerr << ".";
195- hmc.apply(rng, 3);
196- if (i >= n_burns) {
197- samples.col(i - n_burns) = hmc.x.getCoefficients();
198- std::cout << hmc.x.getCoefficients().transpose() << std::endl;
199- }
200- }
201- std::cerr << std::endl;
202- */
203258 // Sampling
204259
205260 using NT = double ;
@@ -236,9 +291,10 @@ int main() {
236291 sample_points_eigen_matrix (HP, q, ehmc_walk, edistr, rng, walk_len, rnum, nburns);
237292
238293 std::cout << " logconcave" << std::endl;
239- sample_points_eigen_matrix (HP, q, hmc_walk, logconcave , rng, walk_len, rnum, nburns);
240- sample_points_eigen_matrix (HP, q, nhmc_walk, logconcave , rng, walk_len, rnum, nburns);
294+ sample_points_eigen_matrix (HP, q, hmc_walk, logconcave_reflective , rng, walk_len, rnum, nburns);
295+ sample_points_eigen_matrix (HP, q, nhmc_walk, logconcave_reflective , rng, walk_len, rnum, nburns);
241296
297+ sample_points_eigen_matrix (HP, q, crhmc_walk, logconcave_crhmc, rng, walk_len, rnum, nburns);
242298
243299
244300 std::cout << " fix the following" << std::endl;
0 commit comments