Skip to content

Commit ad5cc6e

Browse files
committed
[CP-SAT] cleanup C and go API
1 parent d432627 commit ad5cc6e

File tree

7 files changed

+56
-72
lines changed

7 files changed

+56
-72
lines changed

ortools/sat/c_api/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ cc_library(
2525
"//ortools/sat:cp_model_solver",
2626
"//ortools/sat:model",
2727
"//ortools/sat:sat_parameters_cc_proto",
28-
"//ortools/util:time_limit",
2928
"@com_google_absl//absl/log:check",
3029
],
3130
)

ortools/sat/c_api/cp_solver_c.cc

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
#include "ortools/sat/c_api/cp_solver_c.h"
1515

16-
#include <atomic>
1716
#include <string>
1817

1918
#include "absl/log/check.h"
@@ -22,19 +21,15 @@
2221
#include "ortools/sat/cp_model_solver.h"
2322
#include "ortools/sat/model.h"
2423
#include "ortools/sat/sat_parameters.pb.h"
25-
#include "ortools/util/time_limit.h"
2624

2725
namespace operations_research::sat {
2826

2927
namespace {
3028

31-
CpSolverResponse solveWithParameters(std::atomic<bool>* const limit_reached,
32-
const CpModelProto& proto,
29+
CpSolverResponse solveWithParameters(Model* model, const CpModelProto& proto,
3330
const SatParameters& params) {
34-
Model model;
35-
model.Add(NewSatParameters(params));
36-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(limit_reached);
37-
return SolveCpModel(proto, &model);
31+
model->Add(NewSatParameters(params));
32+
return SolveCpModel(proto, model);
3833
}
3934

4035
} // namespace
@@ -44,31 +39,28 @@ extern "C" {
4439
void SolveCpModelWithParameters(const void* creq, int creq_len,
4540
const void* cparams, int cparams_len,
4641
void** cres, int* cres_len) {
47-
return SolveCpInterruptible(nullptr, creq, creq_len, cparams, cparams_len,
48-
cres, cres_len);
42+
Model model;
43+
SolveCpInterruptible(&model, creq, creq_len, cparams, cparams_len, cres,
44+
cres_len);
4945
}
5046

51-
void* SolveCpNewAtomicBool() { return new std::atomic<bool>(false); }
47+
void* SolveCpNewEnv() { return new Model(); }
5248

53-
void SolveCpDestroyAtomicBool(void* const atomic_bool) {
54-
delete static_cast<std::atomic<bool>*>(atomic_bool);
55-
}
49+
void SolveCpDestroyEnv(void* const cenv) { delete static_cast<Model*>(cenv); }
5650

57-
void SolveCpStopSolve(void* const atomic_bool) {
58-
*static_cast<std::atomic<bool>*>(atomic_bool) = true;
59-
}
51+
void SolveCpStopSolve(void* cenv) { StopSearch(static_cast<Model*>(cenv)); }
6052

61-
void SolveCpInterruptible(void* const limit_reached, const void* creq,
62-
int creq_len, const void* cparams, int cparams_len,
63-
void** cres, int* cres_len) {
53+
void SolveCpInterruptible(void* const cenv, const void* creq, int creq_len,
54+
const void* cparams, int cparams_len, void** cres,
55+
int* cres_len) {
6456
CpModelProto req;
6557
CHECK(req.ParseFromArray(creq, creq_len));
6658

6759
SatParameters params;
6860
CHECK(params.ParseFromArray(cparams, cparams_len));
6961

70-
CpSolverResponse res = solveWithParameters(
71-
static_cast<std::atomic<bool>*>(limit_reached), req, params);
62+
CpSolverResponse res =
63+
solveWithParameters(static_cast<Model*>(cenv), req, params);
7264

7365
std::string res_str;
7466
CHECK(res.SerializeToString(&res_str));

ortools/sat/c_api/cp_solver_c.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ void SolveCpModelWithParameters(const void* creq, int creq_len,
2525
const void* cparams, int cparams_len,
2626
void** cres, int* cres_len);
2727

28-
void* SolveCpNewAtomicBool();
29-
void SolveCpDestroyAtomicBool(void* atomic_bool);
30-
void SolveCpStopSolve(void* atomic_bool);
28+
void* SolveCpNewEnv();
29+
void SolveCpDestroyEnv(void* cenv);
30+
void SolveCpStopSolve(void* cenv);
3131
// Allows for interruptible solves. Solves can be interrupted by calling
32-
// `SolveCpStopSolve` with the `limit_reached` atomic Boolean.
33-
void SolveCpInterruptible(void* limit_reached, const void* creq, int creq_len,
32+
// `SolveCpStopSolve` with the `cenv` argument.
33+
void SolveCpInterruptible(void* cenv, const void* creq, int creq_len,
3434
const void* cparams, int cparams_len, void** cres,
3535
int* cres_len);
3636

ortools/sat/cp_model.proto

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -827,10 +827,12 @@ message CpSolverResponse {
827827

828828
// Advanced usage.
829829
//
830-
// A lower bound on the inner integer expression of the objective. This is
831-
// either a bound on the expression in the returned integer_objective or on
832-
// the integer expression of the original objective if the problem already has
833-
// an integer objective.
830+
// A lower bound on the integer expression of the objective. This is either a
831+
// bound on the expression in the returned integer_objective or on the integer
832+
// expression of the original objective if the problem already has an integer
833+
// objective.
834+
//
835+
// TODO(user): This should be renamed integer_objective_lower_bound.
834836
int64 inner_objective_lower_bound = 29;
835837

836838
// Some statistics about the solve.

ortools/sat/go/cpmodel/cp_solver.go

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ func SolveCpModelWithParameters(input *cmpb.CpModelProto, params *sppb.SatParame
7171
}
7272

7373
// SolveCpModelInterruptibleWithParameters solves a CP Model with the given input proto
74-
// and parameters and returns a CPSolverResponse. The solve can be interrupted by triggering
75-
// the `interrupt`.
74+
// and parameters and returns a CPSolverResponse. The solve can be interrupted by calling
75+
// the `stopSolve`.
7676
func SolveCpModelInterruptibleWithParameters(input *cmpb.CpModelProto, params *sppb.SatParameters, interrupt <-chan struct{}) (*cmpb.CpSolverResponse, error) {
77-
// Create the atomic bool for interrupting solves.
78-
limitReached := newAtomicBoolWrapper()
79-
defer limitReached.delete()
77+
// Create the environment for interrupting solves.
78+
env := newEnvWrapper()
79+
defer env.delete()
8080

8181
// Transform `input` into bytes.
8282
bReq, err := proto.Marshal(input)
@@ -100,25 +100,24 @@ func SolveCpModelInterruptibleWithParameters(input *cmpb.CpModelProto, params *s
100100
go func() {
101101
select {
102102
case <-interrupt:
103-
limitReached.trigger()
103+
env.stopSolve()
104104
case <-solveDone:
105105
}
106106
}()
107107

108-
// We want to make sure we trigger the atomic Bool before we call the solver
109-
// if the input `interrupt` is already closed. We can't trust the
108+
// We want to make sure we stop the search before we call the solver. We can't trust the
110109
// scheduler to execute the previous goroutine immediately, even calling
111110
// `runtime.Gosched()` (the unit test failed 3 out of 1000 times when doing
112111
// so).
113112
select {
114113
case <-interrupt:
115-
limitReached.trigger()
114+
env.stopSolve()
116115
default:
117116
}
118117

119118
var cRes unsafe.Pointer
120119
var cResLen C.int
121-
C.SolveCpInterruptible(limitReached.ptr, cReq, C.int(len(bReq)), cParams, C.int(len(bParams)), &cRes, &cResLen)
120+
C.SolveCpInterruptible(env.ptr, cReq, C.int(len(bReq)), cParams, C.int(len(bParams)), &cRes, &cResLen)
122121
defer C.free(cRes)
123122

124123
// Transform `cRes` into the Go response proto.
@@ -131,29 +130,29 @@ func SolveCpModelInterruptibleWithParameters(input *cmpb.CpModelProto, params *s
131130
return result, nil
132131
}
133132

134-
// atomicBoolWrapper keeps a pointer on a C++ AtomicBool instance.
135-
type atomicBoolWrapper struct {
133+
// envWrapper keeps a pointer on a C++ Model instance.
134+
type envWrapper struct {
136135
mutex sync.Mutex
137136
ptr unsafe.Pointer // Guarded by mutex.
138137
}
139138

140-
// newAtomicBoolWrapper returns a new instance of a C++ AtomicBool.
139+
// newEnvWrapper returns a new instance of a C++ Model.
141140
//
142141
// The returned object must be destroyed with delete() for the C++ object not to
143142
// leak.
144143
//
145-
// This object is thread-safe: delete() and trigger() can be called
144+
// This object is thread-safe: delete() and stopSolve() can be called
146145
// concurrently.
147-
func newAtomicBoolWrapper() *atomicBoolWrapper {
148-
return &atomicBoolWrapper{
149-
ptr: C.SolveCpNewAtomicBool(),
146+
func newEnvWrapper() *envWrapper {
147+
return &envWrapper{
148+
ptr: C.SolveCpNewEnv(),
150149
}
151150
}
152151

153-
// trigger triggers the C++ SolveCpStopSolve method with the atomic bool.
152+
// stopSolve triggers the C++ SolveCpStopSolve method with the environment.
154153
//
155-
// If the atomic bool has been deleted this has no effect.
156-
func (intr *atomicBoolWrapper) trigger() {
154+
// If the environment has been deleted this has no effect.
155+
func (intr *envWrapper) stopSolve() {
157156
intr.mutex.Lock()
158157
defer intr.mutex.Unlock()
159158
if uintptr(intr.ptr) != 0 {
@@ -164,12 +163,12 @@ func (intr *atomicBoolWrapper) trigger() {
164163
// delete deletes the underlying C++ object.
165164
//
166165
// Calling it multiple times has not effect.
167-
func (intr *atomicBoolWrapper) delete() {
166+
func (intr *envWrapper) delete() {
168167
intr.mutex.Lock()
169168
defer intr.mutex.Unlock()
170169
// We don't test that intr.ptr is not nullptr here since C++ `delete` can be
171170
// called with nullptr.
172-
C.SolveCpDestroyAtomicBool(intr.ptr)
171+
C.SolveCpDestroyEnv(intr.ptr)
173172
intr.ptr = unsafe.Pointer(uintptr(0))
174173
}
175174

ortools/sat/presolve_context.cc

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,22 +2017,15 @@ bool PresolveContext::CanonicalizeObjective(bool simplify_domain) {
20172017
objective_offset_ /= static_cast<double>(gcd);
20182018
objective_scaling_factor_ *= static_cast<double>(gcd);
20192019

2020-
// We update the offset accordingly.
2021-
absl::int128 offset = absl::int128(objective_integer_before_offset_) *
2022-
absl::int128(objective_integer_scaling_factor_) +
2023-
absl::int128(objective_integer_after_offset_);
2024-
2025-
if (objective_domain_.IsFixed()) {
2026-
// To avoid overflow in (fixed_value * gcd + before_offset) * factor +
2027-
// after_offset because the objective is constant (and should fit on an
2028-
// int64_t), we can rewrite it as fixed_value + offset.
2029-
objective_integer_scaling_factor_ = 1;
2030-
offset +=
2031-
absl::int128(gcd - 1) * absl::int128(objective_domain_.FixedValue());
2032-
} else {
2033-
objective_integer_scaling_factor_ *= gcd;
2034-
}
2035-
2020+
// We update the integer offsets accordingly.
2021+
//
2022+
// We compute the old "a * objective_scaling_factor_ + b" offset and rewrite
2023+
// it in term of the new "objective_scaling_factor_".
2024+
const absl::int128 offset =
2025+
absl::int128(objective_integer_before_offset_) *
2026+
absl::int128(objective_integer_scaling_factor_) +
2027+
absl::int128(objective_integer_after_offset_);
2028+
objective_integer_scaling_factor_ *= gcd;
20362029
objective_integer_before_offset_ = static_cast<int64_t>(
20372030
offset / absl::int128(objective_integer_scaling_factor_));
20382031
objective_integer_after_offset_ = static_cast<int64_t>(

ortools/sat/swig_helper.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "ortools/sat/cp_model_utils.h"
2727
#include "ortools/sat/model.h"
2828
#include "ortools/sat/sat_parameters.pb.h"
29-
#include "ortools/sat/util.h"
3029
#include "ortools/util/logging.h"
3130
#include "ortools/util/sorted_interval_list.h"
3231

@@ -156,7 +155,7 @@ operations_research::sat::CpSolverResponse SolveWrapper::Solve(
156155
}
157156

158157
void SolveWrapper::StopSearch() {
159-
model_.GetOrCreate<ModelSharedTimeLimit>()->Stop();
158+
::operations_research::sat::StopSearch(&model_);
160159
}
161160

162161
std::string CpSatHelper::ModelStats(

0 commit comments

Comments
 (0)