Skip to content

Commit 5bb3dd7

Browse files
committed
Support for constraint problem in sample_points
1 parent 5dfb3a8 commit 5bb3dd7

File tree

2 files changed

+74
-96
lines changed

2 files changed

+74
-96
lines changed

examples/sample_points/simple_example.cpp

Lines changed: 60 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ using NT = double;
2525
using MT = Eigen::Matrix<NT,Eigen::Dynamic,Eigen::Dynamic>;
2626
using VT = Eigen::Matrix<NT,Eigen::Dynamic,1>;
2727

28-
template <typename Walk, typename Distribution>
29-
void sample_points_eigen_matrix(HPolytopeType const& HP, Point const& q, Walk const& walk,
28+
template <typename PolytopeOrProblem, typename Walk, typename Distribution>
29+
void sample_points_eigen_matrix(PolytopeOrProblem const& HP, Point const& q, Walk const& walk,
3030
Distribution const& distr, RNGType rng, int walk_len, int rnum,
3131
int nburns)
3232
{
33-
MT samples(HP.dimension(), rnum);
33+
MT samples(q.dimension(), rnum);
3434

3535
sample_points(HP, q, walk, distr, rng, walk_len, rnum, nburns, samples);
3636

@@ -96,84 +96,48 @@ struct CustomFunctor {
9696
};
9797
};
9898

99-
struct CustomGaussianFunctor {
100-
101-
template
102-
<
103-
typename NT,
104-
typename Point
105-
>
106-
struct parameters {
107-
Point x0;
108-
NT a;
109-
NT eta;
110-
unsigned int order;
111-
NT L; // Lipschitz constant for gradient
112-
NT m; // Strong convexity constant
113-
NT kappa; // Condition number
114-
115-
parameters(Point x0_, NT a_, NT eta_) :
116-
x0(x0_), a(a_), eta(eta_), order(2), L(2 * a_), m(2 * a_), kappa(1) {};
117-
118-
};
119-
120-
template<typename Point>
121-
struct GradientFunctor {
122-
typedef typename Point::FT NT;
123-
typedef std::vector<Point> pts;
124-
125-
parameters<NT, Point> &params;
126-
127-
GradientFunctor(parameters<NT, Point> &params_) : params(params_) {};
128-
129-
// The index i represents the state vector index
130-
/*
131-
Point operator() (unsigned int const& i, pts const& xs, NT const& t) const {
132-
if (i == params.order - 1) {
133-
Point y = (-2.0 * params.a) * (xs[0] - params.x0);
134-
return y;
135-
} else {
136-
return xs[i + 1]; // returns derivative
137-
}
138-
}*/
139-
Point operator()(Point const&x){
140-
Point y = (-2.0 * params.a) * (x - params.x0);
141-
return y;
142-
}
143-
};
144-
145-
template<typename Point>
146-
struct FunctionFunctor {
147-
typedef typename Point::FT NT;
148-
149-
parameters<NT, Point> &params;
150-
151-
FunctionFunctor(parameters<NT, Point> &params_) : params(params_) {};
152-
153-
// The index i represents the state vector index
154-
NT operator() (Point const& x) const {
155-
Point y = x - params.x0;
156-
return params.a * y.dot(y);
157-
}
158-
};
159-
160-
template<typename Point>
161-
struct HessianFunctor {
162-
typedef typename Point::FT NT;
99+
inline bool exists_check(const std::string &name) {
100+
std::ifstream f(name.c_str());
101+
return f.good();
102+
}
163103

164-
parameters<NT, Point> &params;
104+
template< typename SpMat, typename VT>
105+
void load_crhmc_problem(SpMat &A, VT &b, VT &lb, VT &ub, int &dimension,
106+
std::string problem_name) {
107+
{
108+
std::string fileName("../crhmc_sampling/data/");
109+
fileName.append(problem_name);
110+
fileName.append(".mm");
111+
if(!exists_check(fileName)){
112+
std::cerr<<"Problem does not exist.\n";
113+
exit(1);}
114+
SpMat X;
115+
loadMarket(X, fileName);
116+
int m = X.rows();
117+
dimension = X.cols() - 1;
118+
A = X.leftCols(dimension);
119+
b = VT(X.col(dimension));
120+
}
121+
{
122+
std::string fileName("../crhmc_sampling/data/");
123+
fileName.append(problem_name);
124+
fileName.append("_bounds.mm");
125+
if(!exists_check(fileName)){
126+
std::cerr<<"Problem does not exist.\n";
127+
exit(1);}
128+
SpMat bounds;
129+
loadMarket(bounds, fileName);
130+
lb = VT(bounds.col(0));
131+
ub = VT(bounds.col(1));
132+
}
133+
}
165134

166-
HessianFunctor(parameters<NT, Point> &params_) : params(params_) {};
167135

168-
// The index i represents the state vector index
169-
Point operator() (Point const& x) const {
170-
return (2.0 * params.a) * Point::all_ones(x.dimension());
171-
}
172-
};
136+
int main() {
137+
// NEW INTERFACE Sampling
173138

174-
};
139+
// Inputs:
175140

176-
int main() {
177141
// Generating a 3-dimensional cube centered at origin
178142
HPolytopeType HP = generate_cube<HPolytopeType>(10, false);
179143
std::cout<<"Polytope: \n";
@@ -185,7 +149,19 @@ int main() {
185149
Point q(HP.dimension());
186150
RNGType rng(HP.dimension());
187151

188-
// NEW INTERFACE Sampling
152+
// Generating a sparse polytope/problem
153+
using SpMat = Eigen::SparseMatrix<NT>;
154+
using ConstraintProblem =constraint_problem<SpMat, Point>;
155+
std::string problem_name("simplex3");
156+
std::cerr << "CRHMC on " << problem_name << "\n";
157+
SpMat As;
158+
VT b, lb, ub;
159+
int dimension;
160+
load_crhmc_problem(As, b, lb, ub, dimension, problem_name);
161+
ConstraintProblem problem = ConstraintProblem(dimension);
162+
problem.set_equality_constraints(As, b);
163+
problem.set_bounds(lb, ub);
164+
189165

190166
// Walks
191167
AcceleratedBilliardWalk abill_walk;
@@ -245,15 +221,16 @@ int main() {
245221
NegativeLogprobFunctorR fr(params_r);
246222
LogConcaveDistribution logconcave_reflective(gr, fr, params_r.L);
247223

248-
using NegativeGradientFunctor = CustomGaussianFunctor::GradientFunctor<Point>;
249-
using NegativeLogprobFunctor = CustomGaussianFunctor::FunctionFunctor<Point>;
250-
using HessianFunctor = CustomGaussianFunctor::HessianFunctor<Point>;
251-
CustomGaussianFunctor::parameters<NT, Point> params(x0, 0.5, 1);
224+
using NegativeGradientFunctor = GaussianFunctor::GradientFunctor<Point>;
225+
using NegativeLogprobFunctor = GaussianFunctor::FunctionFunctor<Point>;
226+
using HessianFunctor = GaussianFunctor::HessianFunctor<Point>;
227+
GaussianFunctor::parameters<NT, Point> params(x0, 0.5, 1);
252228
NegativeGradientFunctor g(params);
253229
NegativeLogprobFunctor f(params);
254230
HessianFunctor h(params);
255231
LogConcaveDistribution logconcave_crhmc(g, f, h, params.L);
256232

233+
LogConcaveDistribution logconcave_ref_gaus(g, f, params.L);
257234

258235
// Sampling
259236

@@ -293,8 +270,10 @@ int main() {
293270
std::cout << "logconcave" << std::endl;
294271
sample_points_eigen_matrix(HP, q, hmc_walk, logconcave_reflective, rng, walk_len, rnum, nburns);
295272
sample_points_eigen_matrix(HP, q, nhmc_walk, logconcave_reflective, rng, walk_len, rnum, nburns);
273+
sample_points_eigen_matrix(HP, q, nhmc_walk, logconcave_ref_gaus, rng, walk_len, rnum, nburns);
296274

297275
sample_points_eigen_matrix(HP, q, crhmc_walk, logconcave_crhmc, rng, walk_len, rnum, nburns);
276+
sample_points_eigen_matrix(problem, q, crhmc_walk, logconcave_crhmc, rng, walk_len, rnum, nburns);
298277

299278

300279
std::cout << "fix the following" << std::endl;

include/sampling/sample_points.hpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,8 @@ void sample_points(Polytope& P, // TODO: make it a const&
258258
using HPolytope = typename std::remove_const<Polytope>::type;
259259
HPolytope HP = P; //TODO: avoid the copy
260260

261-
constexpr int simdLen = 8;
261+
constexpr int simdLen = 8; //TODO: input parameter
262262
using NT = double;
263-
using MT = typename HPolytope::MT;
264263

265264
int dimension = HP.dimension();
266265

@@ -298,19 +297,19 @@ void sample_points(Polytope& P, // TODO: make it a const&
298297
>;
299298

300299
using Walk = typename WalkType::template Walk
301-
<
302-
Point,
303-
CrhmcProblem,
304-
RandomNumberGenerator,
305-
NegativeGradientFunctor,
306-
NegativeLogprobFunctor,
307-
Solver
308-
>;
300+
<
301+
Point,
302+
CrhmcProblem,
303+
RandomNumberGenerator,
304+
NegativeGradientFunctor,
305+
NegativeLogprobFunctor,
306+
Solver
307+
>;
309308
using WalkParams = typename WalkType::template parameters
310-
<
311-
NT,
312-
NegativeGradientFunctor
313-
>;
309+
<
310+
NT,
311+
NegativeGradientFunctor
312+
>;
314313
Point p = Point(problem.center);
315314
problem.options.simdLen=simdLen;
316315
WalkParams params(distribution.L, p.dimension(), problem.options);
@@ -333,7 +332,7 @@ void sample_points(Polytope& P, // TODO: make it a const&
333332
{
334333
walk.apply(rng, walk_len);
335334
if (walk.P.terminate) {return;}
336-
MT x = raw_output ? walk.x : walk.getPoints();
335+
auto x = raw_output ? walk.x : walk.getPoints();
337336

338337
if ((i + 1) * simdLen > rnum)
339338
{

0 commit comments

Comments
 (0)