Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hessian preconditioner wrappers #410

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/api/nlopt-in.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace nlopt {

typedef nlopt_func func; // nlopt::func synoynm
typedef nlopt_mfunc mfunc; // nlopt::mfunc synoynm
typedef nlopt_precond pfunc;

// alternative to nlopt_func that takes std::vector<double>
// ... unfortunately requires a data copy
Expand Down Expand Up @@ -88,7 +89,7 @@ namespace nlopt {

typedef struct {
opt *o;
mfunc mf; func f; void *f_data;
pfunc pf; mfunc mf; func f; void *f_data;
vfunc vf;
nlopt_munge munge_destroy, munge_copy; // non-NULL for SWIG wrappers
} myfunc_data;
Expand Down Expand Up @@ -142,6 +143,27 @@ namespace nlopt {
return HUGE_VAL;
}

// nlopt_precond wrapper that catches exceptions
static void mypfunc( unsigned n, const double *x, const double *v, double *vpre, void *d_) {
myfunc_data *d = reinterpret_cast<myfunc_data*>(d_);
try {
d->pf(n, x, v, vpre, d->f_data);
return;
}
catch (std::bad_alloc&)
{ d->o->forced_stop_reason = NLOPT_OUT_OF_MEMORY; }
catch (std::invalid_argument&)
{ d->o->forced_stop_reason = NLOPT_INVALID_ARGS; }
catch (roundoff_limited&)
{ d->o->forced_stop_reason = NLOPT_ROUNDOFF_LIMITED; }
catch (forced_stop&)
{ d->o->forced_stop_reason = NLOPT_FORCED_STOP; }
catch (...)
{ d->o->forced_stop_reason = NLOPT_FAILURE; }
d->o->force_stop(); // stop gracefully, opt::optimize will re-throw
//return HUGE_VAL;
}

// nlopt_mfunc wrapper that catches exceptions
static void mymfunc(unsigned m, double *result,
unsigned n, const double *x, double *grad, void *d_) {
Expand Down Expand Up @@ -314,6 +336,16 @@ namespace nlopt {
mythrow(nlopt_set_max_objective(o, myvfunc, d)); // d freed via o
alloc_tmp();
}
void set_precond_min_objective(func f, pfunc pf, void *f_data) {
myfunc_data *d = new myfunc_data;
if (!d) throw std::bad_alloc();
d->o = this; d->f = f; d->f_data = f_data; d->mf = NULL; d->vf = NULL;
d->pf = pf;
d->munge_destroy = d->munge_copy = NULL;
printf("ALEC\n");
mythrow(nlopt_set_precond_min_objective(o, myfunc, mypfunc, d)); // d freed via o
alloc_tmp();
}

// for internal use in SWIG wrappers -- variant that
// takes ownership of f_data, with munging for destroy/copy
Expand All @@ -333,6 +365,16 @@ namespace nlopt {
d->munge_destroy = md; d->munge_copy = mc;
mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o
}
void set_precond_min_objective(func f, pfunc pf, void *f_data,
nlopt_munge md, nlopt_munge mc) {
myfunc_data *d = new myfunc_data;
if (!d) throw std::bad_alloc();
d->o = this; d->f = f; d->f_data = f_data; d->mf = NULL; d->vf = NULL;
d->pf = pf;
d->munge_destroy = md; d->munge_copy = mc;
printf("ALEC\n");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this line is an artifact of some debugging?

mythrow(nlopt_set_precond_min_objective(o, myfunc, mypfunc, d)); // d freed via o
}

// Nonlinear constraints:

Expand Down
3 changes: 3 additions & 0 deletions src/swig/nlopt-exceptions.i
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_max_objective(func f, void *f_data);
%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_max_objective(vfunc vf, void *f_data);

%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_precond_min_objective(func f, pfunc pf, void *f_data);

%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_min_objective(func f, void *f_data, nlopt_munge md, nlopt_munge mc);
%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_max_objective(func f, void *f_data, nlopt_munge md, nlopt_munge mc);
%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::set_precond_min_objective(func f, pfunc pf, void *f_data, nlopt_munge md, nlopt_munge mc);

%catches(std::invalid_argument) nlopt::opt::remove_inequality_constraints();
%catches(std::bad_alloc,std::invalid_argument) nlopt::opt::add_inequality_constraint(func f, void *f_data, double tol=0);
Expand Down
29 changes: 28 additions & 1 deletion src/swig/nlopt-guile.i
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ static void *free_guilefunc(void *p) {
static void *dup_guilefunc(void *p) {
scm_gc_protect_object((SCM) p); return p; }

// func wrapper around Guile function val = f(x, grad)
static void pfunc_guile(unsigned n, const double *x, const double *v, double *vpre, void *f) {
SCM xscm = scm_c_make_vector(n, SCM_UNSPECIFIED);
SCM vscm = scm_c_make_vector(n, SCM_UNSPECIFIED);
for (unsigned i = 0; i < n; ++i){
SCM_SIMPLE_VECTOR_SET(xscm, i, scm_from_double(x[i]));
SCM_SIMPLE_VECTOR_SET(vscm, i, scm_from_double(v[i]));
}
SCM vpre_scm = vpre ? scm_c_make_vector(n, SCM_UNSPECIFIED) : SCM_BOOL_F;
scm_call_3((SCM) f, xscm, vscm, vpre_scm);
if (vpre) {
for (unsigned i = 0; i < n; ++i) {
vpre[i] = scm_to_double(SCM_SIMPLE_VECTOR_REF(vpre_scm, i));
}
}
}

// func wrapper around Guile function val = f(x, grad)
static double func_guile(unsigned n, const double *x, double *grad, void *f) {
SCM xscm = scm_c_make_vector(n, SCM_UNSPECIFIED);
Expand All @@ -70,10 +87,20 @@ static double func_guile(unsigned n, const double *x, double *grad, void *f) {
$3 = free_guilefunc;
$4 = dup_guilefunc;
}
%typemap(in)(nlopt::func f, nlopt::pfunc pf, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = func_guile;
$2 = pfunc_guile;
$3 = dup_guilefunc((void*) $input); // input = SCM pointer to Scheme function
$4 = free_guilefunc;
$5 = dup_guilefunc;
}
%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = scm_is_true(scm_procedure_p($input));
}

%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, nlopt::pfunc pf, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = scm_is_true(scm_procedure_p($input));
//$2 = scm_is_true(scm_procedure_p($input));
}
// export constants as variables, rather than as functions returning the value
%feature("constasvar", "1");

Expand Down
39 changes: 39 additions & 0 deletions src/swig/nlopt-python.i
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,33 @@ static double func_python(unsigned n, const double *x, double *grad, void *f)
return val;
}

static void pfunc_python(unsigned n, const double *x, const double *v, double *vpre, void *f)
{
npy_intp sz = npy_intp(n), sz0 = 0, stride1 = sizeof(double);
PyObject *xpy = PyArray_New(&PyArray_Type, 1, &sz, NPY_DOUBLE, &stride1,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like these calls are equivalent to calling PyArray_SimpleNewFromData?

const_cast<double*>(x), // not NPY_WRITEABLE
0, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL);
PyObject *vpy = PyArray_New(&PyArray_Type, 1, &sz, NPY_DOUBLE, &stride1,
const_cast<double*>(v), // not NPY_WRITEABLE
0, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL);
PyObject *vprepy = vpre
? PyArray_SimpleNewFromData(1, &sz, NPY_DOUBLE, vpre)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't vpre always be non-NULL here?

: PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);

PyObject *arglist = Py_BuildValue("OO", xpy, vpy, vprepy);
PyEval_CallObject((PyObject *) f, arglist);
Py_DECREF(arglist);

Py_DECREF(vprepy);
Py_DECREF(vpy);
Py_DECREF(xpy);

double val = HUGE_VAL;
if (PyErr_Occurred()) {
throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
}
}

static void mfunc_python(unsigned m, double *result,
unsigned n, const double *x, double *grad, void *f)
{
Expand Down Expand Up @@ -195,6 +222,18 @@ static void mfunc_python(unsigned m, double *result,
$1 = PyCallable_Check($input);
}

%typemap(in)(nlopt::func f, nlopt::pfunc pf, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = func_python;
$2 = pfunc_python;
$3 = dup_pyfunc((void*) $input);
$4 = free_pyfunc;
$5 = dup_pyfunc;
}

%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, nlopt::pfunc pf, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = PyCallable_Check($input);
}

%typemap(in)(nlopt::mfunc mf, void *f_data, nlopt_munge md, nlopt_munge mc) {
$1 = mfunc_python;
$2 = dup_pyfunc((void*) $input);
Expand Down