Skip to content

Commit 5dfb3a8

Browse files
committed
New interface for sampling. Function is parameterized by walk and distribution. Added support for logconcave distributions (crhmc)
1 parent d7ea259 commit 5dfb3a8

File tree

4 files changed

+233
-60
lines changed

4 files changed

+233
-60
lines changed

examples/sample_points/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ GetBoost()
1818
include("../../external/cmake-files/LPSolve.cmake")
1919
GetLPSolve()
2020

21+
include("../../external/cmake-files/QD.cmake")
22+
GetQD()
23+
2124
set(CMAKE_EXPORT_COMPILE_COMMANDS "ON")
2225

2326
add_definitions(${CMAKE_CXX_FLAGS} "-g")
@@ -26,4 +29,4 @@ include_directories (BEFORE ../../external)
2629
include_directories (BEFORE ../../include)
2730

2831
add_executable (example simple_example.cpp)
29-
target_link_libraries(example PUBLIC lp_solve)
32+
target_link_libraries(example PUBLIC lp_solve QD_LIB)

examples/sample_points/simple_example.cpp

Lines changed: 95 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ void sample_points_eigen_matrix(HPolytopeType const& HP, Point const& q, Walk co
4444
struct CustomFunctor {
4545

4646
// Custom density with neg log prob equal to c^T x
47-
template <
47+
template
48+
<
4849
typename NT,
4950
typename Point
5051
>
@@ -59,10 +60,7 @@ struct CustomFunctor {
5960

6061
};
6162

62-
template
63-
<
64-
typename Point
65-
>
63+
template <typename Point>
6664
struct GradientFunctor {
6765
typedef typename Point::FT NT;
6866
typedef std::vector<Point> pts;
@@ -80,13 +78,71 @@ struct CustomFunctor {
8078
return xs[i + 1]; // returns derivative
8179
}
8280
}
81+
};
82+
83+
template<typename Point>
84+
struct FunctionFunctor {
85+
typedef typename Point::FT NT;
86+
87+
parameters<NT, Point> &params;
88+
89+
FunctionFunctor(parameters<NT, Point> &params_) : params(params_) {};
8390

91+
// The index i represents the state vector index
92+
NT operator() (Point const& x) const {
93+
Point y = x - params.x0;
94+
return 0.5 * y.dot(y);
95+
}
8496
};
97+
};
98+
99+
struct CustomGaussianFunctor {
85100

86101
template
87102
<
88-
typename Point
103+
typename NT,
104+
typename Point
89105
>
106+
struct parameters {
107+
Point x0;
108+
NT a;
109+
NT eta;
110+
unsigned int order;
111+
NT L; // Lipschitz constant for gradient
112+
NT m; // Strong convexity constant
113+
NT kappa; // Condition number
114+
115+
parameters(Point x0_, NT a_, NT eta_) :
116+
x0(x0_), a(a_), eta(eta_), order(2), L(2 * a_), m(2 * a_), kappa(1) {};
117+
118+
};
119+
120+
template<typename Point>
121+
struct GradientFunctor {
122+
typedef typename Point::FT NT;
123+
typedef std::vector<Point> pts;
124+
125+
parameters<NT, Point> &params;
126+
127+
GradientFunctor(parameters<NT, Point> &params_) : params(params_) {};
128+
129+
// The index i represents the state vector index
130+
/*
131+
Point operator() (unsigned int const& i, pts const& xs, NT const& t) const {
132+
if (i == params.order - 1) {
133+
Point y = (-2.0 * params.a) * (xs[0] - params.x0);
134+
return y;
135+
} else {
136+
return xs[i + 1]; // returns derivative
137+
}
138+
}*/
139+
Point operator()(Point const&x){
140+
Point y = (-2.0 * params.a) * (x - params.x0);
141+
return y;
142+
}
143+
};
144+
145+
template<typename Point>
90146
struct FunctionFunctor {
91147
typedef typename Point::FT NT;
92148

@@ -97,9 +153,22 @@ struct CustomFunctor {
97153
// The index i represents the state vector index
98154
NT operator() (Point const& x) const {
99155
Point y = x - params.x0;
100-
return 0.5 * y.dot(y);
156+
return params.a * y.dot(y);
101157
}
158+
};
102159

160+
template<typename Point>
161+
struct HessianFunctor {
162+
typedef typename Point::FT NT;
163+
164+
parameters<NT, Point> &params;
165+
166+
HessianFunctor(parameters<NT, Point> &params_) : params(params_) {};
167+
168+
// The index i represents the state vector index
169+
Point operator() (Point const& x) const {
170+
return (2.0 * params.a) * Point::all_ones(x.dimension());
171+
}
103172
};
104173

105174
};
@@ -140,6 +209,7 @@ int main() {
140209

141210
HamiltonianMonteCarloWalk hmc_walk;
142211
NutsHamiltonianMonteCarloWalk nhmc_walk;
212+
CRHMCWalk crhmc_walk;
143213

144214
// Distributions
145215

@@ -161,45 +231,30 @@ int main() {
161231
ExponentialDistribution edistr(c, variance);
162232

163233
// 4. LogConcave
164-
using NegativeGradientFunctor = CustomFunctor::GradientFunctor<Point>;
165-
using NegativeLogprobFunctor = CustomFunctor::FunctionFunctor<Point>;
166-
using Solver = LeapfrogODESolver<Point, NT, HPolytopeType, NegativeGradientFunctor>;
167234

168235
std::pair<Point, NT> inner_ball = HP.ComputeInnerBall();
169236
Point x0 = inner_ball.first;
170237

171-
CustomFunctor::parameters<NT, Point> params(x0);
238+
// Reflective HMC and Remmannian HMC are using slightly different functor interfaces
239+
// TODO: check if this could be unified
172240

173-
NegativeGradientFunctor g(params);
174-
NegativeLogprobFunctor f(params);
175-
LogConcaveDistribution logconcave(g, f, params.L);
241+
using NegativeGradientFunctorR = CustomFunctor::GradientFunctor<Point>;
242+
using NegativeLogprobFunctorR = CustomFunctor::FunctionFunctor<Point>;
243+
CustomFunctor::parameters<NT, Point> params_r(x0);
244+
NegativeGradientFunctorR gr(params_r);
245+
NegativeLogprobFunctorR fr(params_r);
246+
LogConcaveDistribution logconcave_reflective(gr, fr, params_r.L);
176247

177-
/*
178-
NegativeGradientFunctor F(params);
248+
using NegativeGradientFunctor = CustomGaussianFunctor::GradientFunctor<Point>;
249+
using NegativeLogprobFunctor = CustomGaussianFunctor::FunctionFunctor<Point>;
250+
using HessianFunctor = CustomGaussianFunctor::HessianFunctor<Point>;
251+
CustomGaussianFunctor::parameters<NT, Point> params(x0, 0.5, 1);
252+
NegativeGradientFunctor g(params);
179253
NegativeLogprobFunctor f(params);
254+
HessianFunctor h(params);
255+
LogConcaveDistribution logconcave_crhmc(g, f, h, params.L);
180256

181-
HamiltonianMonteCarloWalk::parameters<NT, NegativeGradientFunctor> hmc_params(F, HP.dimension());
182-
183-
HamiltonianMonteCarloWalk hmc(F, f, hmc_params);
184-
185-
int n_samples = 80000;
186-
int n_burns = 0;
187-
188-
MT samples;
189-
samples.resize(dim, n_samples - n_burns);
190257

191-
hmc.solver->eta0 = 0.5;
192-
193-
for (int i = 0; i < n_samples; i++) {
194-
if (i % 1000 == 0) std::cerr << ".";
195-
hmc.apply(rng, 3);
196-
if (i >= n_burns) {
197-
samples.col(i - n_burns) = hmc.x.getCoefficients();
198-
std::cout << hmc.x.getCoefficients().transpose() << std::endl;
199-
}
200-
}
201-
std::cerr << std::endl;
202-
*/
203258
// Sampling
204259

205260
using NT = double;
@@ -236,9 +291,10 @@ int main() {
236291
sample_points_eigen_matrix(HP, q, ehmc_walk, edistr, rng, walk_len, rnum, nburns);
237292

238293
std::cout << "logconcave" << std::endl;
239-
sample_points_eigen_matrix(HP, q, hmc_walk, logconcave, rng, walk_len, rnum, nburns);
240-
sample_points_eigen_matrix(HP, q, nhmc_walk, logconcave, rng, walk_len, rnum, nburns);
294+
sample_points_eigen_matrix(HP, q, hmc_walk, logconcave_reflective, rng, walk_len, rnum, nburns);
295+
sample_points_eigen_matrix(HP, q, nhmc_walk, logconcave_reflective, rng, walk_len, rnum, nburns);
241296

297+
sample_points_eigen_matrix(HP, q, crhmc_walk, logconcave_crhmc, rng, walk_len, rnum, nburns);
242298

243299

244300
std::cout << "fix the following" << std::endl;

include/random_walks/crhmc/crhmc_walk.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ struct CRHMCWalk {
4343
eta = 1.0 / (dim * sqrt(F.params.L));
4444
momentum = 1 - std::min(1.0, eta / effectiveStepSize);
4545
}
46+
parameters(NT const& L,
47+
unsigned int dim,
48+
Opts &user_options,
49+
NT epsilon_ = 2)
50+
: options(user_options)
51+
{
52+
epsilon = epsilon_;
53+
eta = 1.0 / (dim * sqrt(L));
54+
momentum = 1 - std::min(1.0, eta / effectiveStepSize);
55+
}
4656
};
4757

4858
template

0 commit comments

Comments
 (0)