Skip to content

Commit 8dd4924

Browse files
committed
[CP-SAT] add StopSearch C++ function.
1 parent 35c27ab commit 8dd4924

File tree

9 files changed

+41
-37
lines changed

9 files changed

+41
-37
lines changed

examples/cpp/network_routing_sat.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,6 @@ class NetworkRoutingSolver {
403403
cp_model.AddAllDifferent(node_vars);
404404

405405
Model model;
406-
// Create an atomic Boolean that will be periodically checked by the limit.
407-
std::atomic<bool> stopped(false);
408-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
409406

410407
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
411408
const int path_id = all_paths_[demand_index].size();
@@ -415,7 +412,7 @@ class NetworkRoutingSolver {
415412
all_paths_[demand_index].back().insert(arc);
416413
}
417414
if (all_paths_[demand_index].size() >= max_paths) {
418-
stopped = true;
415+
StopSearch(&model);
419416
}
420417
}));
421418

examples/cpp/nqueens.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,14 @@ void CheckNumberOfSolutions(int size, int num_solutions) {
182182
if (absl::GetFlag(FLAGS_use_symmetry)) {
183183
if (size - 1 < kKnownUniqueSolutions) {
184184
CHECK_EQ(num_solutions, kNumUniqueSolutions[size - 1]);
185+
} else if (!absl::GetFlag(FLAGS_cp_disable_solve)) {
186+
CHECK_GT(num_solutions, 0);
185187
}
186188
} else {
187189
if (size - 1 < kKnownSolutions) {
188190
CHECK_EQ(num_solutions, kNumSolutions[size - 1]);
191+
} else if (!absl::GetFlag(FLAGS_cp_disable_solve)) {
192+
CHECK_GT(num_solutions, 0);
189193
}
190194
}
191195
}

examples/cpp/variable_intervals_sat.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ void Solve() {
5353
parameters.set_enumerate_all_solutions(true);
5454
model.Add(NewSatParameters(parameters));
5555

56-
// Create an atomic Boolean that will be periodically checked by the limit.
57-
std::atomic<bool> stopped(false);
58-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
59-
6056
const int kSolutionLimit = 100;
6157
int num_solutions = 0;
6258
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -68,7 +64,7 @@ void Solve() {
6864
LOG(INFO) << " start_ins = " << SolutionIntegerValue(r, start_ins);
6965
num_solutions++;
7066
if (num_solutions >= kSolutionLimit) {
71-
stopped = true;
67+
StopSearch(&model);
7268
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
7369
}
7470
}));

ortools/sat/cp_model_presolve.cc

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5457,19 +5457,26 @@ bool CpModelPresolver::PresolveTable(ConstraintProto* ct) {
54575457
namespace {
54585458

54595459
// A container that is valid if only one value was added.
5460-
struct UniqueNonNegativeValue {
5461-
int index = -1;
5462-
5463-
void Add(int new_index) {
5464-
DCHECK_GE(index, 0);
5465-
if (index == -1) {
5466-
index = new_index;
5460+
class UniqueNonNegativeValue {
5461+
public:
5462+
void Add(int value) {
5463+
DCHECK_GE(value, 0);
5464+
if (value_ == -1) {
5465+
value_ = value;
54675466
} else {
5468-
index = -2;
5467+
value_ = -2;
54695468
}
54705469
}
54715470

5472-
bool IsValid() const { return index >= 0; }
5471+
bool HasUniqueValue() const { return value_ >= 0; }
5472+
5473+
int64_t value() const {
5474+
DCHECK(HasUniqueValue());
5475+
return value_;
5476+
}
5477+
5478+
private:
5479+
int value_ = -1;
54735480
};
54745481

54755482
} // namespace
@@ -5492,7 +5499,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
54925499
return RemoveConstraint(ct);
54935500
}
54945501
if (size == 1) {
5495-
context_->UpdateRuleStats("all_diff: only one expression");
5502+
context_->UpdateRuleStats("all_diff: one expression");
54965503
return RemoveConstraint(ct);
54975504
}
54985505

@@ -5530,7 +5537,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
55305537
}
55315538
}
55325539
if (propagated) {
5533-
context_->UpdateRuleStats("all_diff: propagate fixed values");
5540+
context_->UpdateRuleStats("all_diff: propagate fixed expressions");
55345541
}
55355542
}
55365543

@@ -5610,9 +5617,10 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
56105617

56115618
bool propagated = false;
56125619
for (const auto& [value, unique_index] : value_to_index) {
5613-
if (!unique_index.IsValid()) continue;
5620+
if (!unique_index.HasUniqueValue()) continue;
56145621

5615-
const LinearExpressionProto& expr = all_diff.exprs(unique_index.index);
5622+
const LinearExpressionProto& expr =
5623+
all_diff.exprs(unique_index.value());
56165624
if (!context_->IntersectDomainWith(expr, Domain(value), &propagated)) {
56175625
return true;
56185626
}
@@ -7762,6 +7770,8 @@ void CpModelPresolver::Probe() {
77627770
return (void)context_->NotifyThatModelIsUnsat("during probing");
77637771
}
77647772

7773+
time_limit_->ResetHistory();
7774+
77657775
// Update the presolve context with fixed Boolean variables.
77667776
int num_fixed = 0;
77677777
CHECK_EQ(sat_solver->CurrentDecisionLevel(), 0);
@@ -8694,6 +8704,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() {
86948704
// We reuse the max-clique code from sat.
86958705
Model local_model;
86968706
local_model.GetOrCreate<Trail>()->Resize(num_constraints);
8707+
local_model.GetOrCreate<TimeLimit>()->MergeWithGlobalTimeLimit(time_limit_);
86978708
auto* graph = local_model.GetOrCreate<BinaryImplicationGraph>();
86988709
graph->Resize(num_constraints);
86998710
for (const std::vector<Literal>& clique : cliques) {
@@ -8730,6 +8741,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() {
87308741
new_num_intervals, " intervals).");
87318742
context_->UpdateRuleStats("no_overlap: merged constraints");
87328743
}
8744+
time_limit_->ResetHistory();
87338745
}
87348746

87358747
// TODO(user): Should we take into account the exactly_one constraints? note

ortools/sat/cp_model_solver.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2269,6 +2269,10 @@ std::function<SatParameters(Model*)> NewSatParameters(
22692269
};
22702270
}
22712271

2272+
void StopSearch(Model* model) {
2273+
model->GetOrCreate<ModelSharedTimeLimit>()->Stop();
2274+
}
2275+
22722276
namespace {
22732277
void RegisterSearchStatisticCallback(Model* global_model) {
22742278
global_model->GetOrCreate<SharedResponseManager>()

ortools/sat/cp_model_solver.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ std::function<SatParameters(Model*)> NewSatParameters(
128128
std::function<SatParameters(Model*)> NewSatParameters(
129129
const SatParameters& parameters);
130130

131+
/// Stops the current search.
132+
void StopSearch(Model* model);
133+
131134
// TODO(user): Clean this up.
132135
/// Solves a CpModelProto without any processing. Only used for unit tests.
133136
void LoadAndSolveCpModelForTest(const CpModelProto& model_proto, Model* model);

ortools/sat/docs/solver.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,6 @@ void StopAfterNSolutionsSampleSat() {
10251025
parameters.set_enumerate_all_solutions(true);
10261026
model.Add(NewSatParameters(parameters));
10271027

1028-
// Create an atomic Boolean that will be periodically checked by the limit.
1029-
std::atomic<bool> stopped(false);
1030-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
1031-
10321028
const int kSolutionLimit = 5;
10331029
int num_solutions = 0;
10341030
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -1038,7 +1034,7 @@ void StopAfterNSolutionsSampleSat() {
10381034
LOG(INFO) << " z = " << SolutionIntegerValue(r, z);
10391035
num_solutions++;
10401036
if (num_solutions >= kSolutionLimit) {
1041-
stopped = true;
1037+
StopSearch(&model);
10421038
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
10431039
}
10441040
}));

ortools/sat/samples/nurses_sat.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,6 @@ void NurseSat() {
138138

139139
// Display the first five solutions.
140140
// [START solution_printer]
141-
// Create an atomic Boolean that will be periodically checked by the limit.
142-
std::atomic<bool> stopped(false);
143-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
144-
145141
const int kSolutionLimit = 5;
146142
int num_solutions = 0;
147143
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -165,7 +161,7 @@ void NurseSat() {
165161
}
166162
num_solutions++;
167163
if (num_solutions >= kSolutionLimit) {
168-
stopped = true;
164+
StopSearch(&model);
169165
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
170166
}
171167
}));

ortools/sat/samples/stop_after_n_solutions_sample_sat.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@ void StopAfterNSolutionsSampleSat() {
4343
parameters.set_enumerate_all_solutions(true);
4444
model.Add(NewSatParameters(parameters));
4545

46-
// Create an atomic Boolean that will be periodically checked by the limit.
47-
std::atomic<bool> stopped(false);
48-
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
49-
5046
const int kSolutionLimit = 5;
5147
int num_solutions = 0;
5248
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -56,7 +52,7 @@ void StopAfterNSolutionsSampleSat() {
5652
LOG(INFO) << " z = " << SolutionIntegerValue(r, z);
5753
num_solutions++;
5854
if (num_solutions >= kSolutionLimit) {
59-
stopped = true;
55+
StopSearch(&model);
6056
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
6157
}
6258
}));

0 commit comments

Comments
 (0)