Skip to content

Commit

Permalink
[CP-SAT] more minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Jan 30, 2025
1 parent 15b93ce commit bf87fd8
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 61 deletions.
1 change: 1 addition & 0 deletions ortools/sat/2d_rectangle_presolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ bool PresolveFixed2dRectangles(
// The whole rectangle was outside of the domain, remove it.
std::swap(rectangle, (*fixed_boxes)[fixed_boxes->size() - 1]);
fixed_boxes->resize(fixed_boxes->size() - 1);
changed = true;
continue;
} else {
new_size++;
Expand Down
15 changes: 11 additions & 4 deletions ortools/sat/disjunctive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -871,18 +871,19 @@ bool DisjunctiveDetectablePrecedences::Propagate() {
const auto by_shifted_smin = helper_->TaskByIncreasingShiftedStartMin();
int rank = -1;
IntegerValue window_end = kMinIntegerValue;
int* const ranks = ranks_.data();
for (const auto [task, presence_lit, start_min] : by_shifted_smin) {
if (!helper_->IsPresent(presence_lit)) {
ranks_[task] = -1;
ranks[task] = -1;
continue;
}

const IntegerValue size_min = helper_->SizeMin(task);
if (start_min < window_end) {
ranks_[task] = rank;
ranks[task] = rank;
window_end += size_min;
} else {
ranks_[task] = ++rank;
ranks[task] = ++rank;
window_end = start_min + size_min;
}
}
Expand Down Expand Up @@ -1329,9 +1330,15 @@ bool DisjunctivePrecedences::PropagateSubwindow() {

int DisjunctivePrecedences::RegisterWith(GenericLiteralWatcher* watcher) {
// This propagator reach the fixed point in one go.
// Maybe not in corner cases, but since it is expansive, it is okay not to
// run it again right away
//
// Note also that technically, we don't need to be waked up if only the upper
// bound of the task changes, but this require to use more memory and the gain
// is unclear as this runs with the highest priority.
const int id = watcher->Register(this);
helper_->SetTimeDirection(time_direction_);
helper_->WatchAllTasks(id, /*watch_max_side=*/false);
helper_->WatchAllTasks(id);
return id;
}

Expand Down
14 changes: 3 additions & 11 deletions ortools/sat/integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2119,10 +2119,7 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
const Literal literal = (*trail)[propagation_trail_index_++];
if (literal.Index() >= literal_limit) continue;
for (const auto entry : literal_to_watcher_[literal]) {
if (!in_queue_[entry.id]) {
in_queue_[entry.id] = true;
queue_by_priority_[id_to_priority_[entry.id]].push_back(entry.id);
}
CallOnNextPropagate(entry.id);
if (entry.watch_index >= 0) {
id_to_watch_indices_[entry.id].push_back(entry.watch_index);
}
Expand All @@ -2134,10 +2131,7 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) {
if (var.value() >= var_limit) continue;
for (const auto entry : var_to_watcher_[var]) {
if (!in_queue_[entry.id]) {
in_queue_[entry.id] = true;
queue_by_priority_[id_to_priority_[entry.id]].push_back(entry.id);
}
CallOnNextPropagate(entry.id);
if (entry.watch_index >= 0) {
id_to_watch_indices_[entry.id].push_back(entry.watch_index);
}
Expand All @@ -2161,9 +2155,7 @@ bool GenericLiteralWatcher::Propagate(Trail* trail) {
const int level = trail->CurrentDecisionLevel();
if (level == 0) {
for (const int id : propagator_ids_to_call_at_level_zero_) {
if (in_queue_[id]) continue;
in_queue_[id] = true;
queue_by_priority_[id_to_priority_[id]].push_back(id);
CallOnNextPropagate(id);
}
}

Expand Down
7 changes: 6 additions & 1 deletion ortools/sat/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,13 +787,18 @@ class IntegerTrail final : public SatPropagator {
void AddAllGreaterThanConstantReason(absl::Span<AffineExpression> exprs,
IntegerValue target_min,
std::vector<int>* indices) const {
int64_t num_processed = 0;
for (const AffineExpression& expr : exprs) {
if (expr.IsConstant()) {
DCHECK_GE(expr.constant, target_min);
continue;
}
DCHECK_NE(expr.var, kNoIntegerVariable);

// On large routing problems, we can spend a lot of time in this loop.
// We check the time limit every 5 processed expressions.
if (++num_processed % 5 == 0 && time_limit_->LimitReached()) return;

// Skip if we already have an explanation for expr >= target_min. Note
// that we already do that while processing the returned indices, so this
// mainly save a FindLowestTrailIndexThatExplainBound() call per skipped
Expand Down Expand Up @@ -897,7 +902,7 @@ class IntegerTrail final : public SatPropagator {
// Returns some debugging info.
std::string DebugString();

// Used internally to return the next conlict number.
// Used internally to return the next conflict number.
int64_t NextConflictId();

// Information for each integer variable about its current lower bound and
Expand Down
7 changes: 7 additions & 0 deletions ortools/sat/intervals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ SchedulingConstraintHelper* IntervalsRepository::GetOrCreateHelper(
std::vector<AffineExpression> sizes;
std::vector<LiteralIndex> reason_for_presence;

const int num_variables = variables.size();
starts.reserve(num_variables);
ends.reserve(num_variables);
sizes.reserve(num_variables);
reason_for_presence.reserve(num_variables);

for (const IntervalVariable i : variables) {
if (IsOptional(i)) {
reason_for_presence.push_back(PresenceLiteral(i).Index());
Expand All @@ -224,6 +230,7 @@ SchedulingConstraintHelper* IntervalsRepository::GetOrCreateHelper(
starts.push_back(Start(i));
ends.push_back(End(i));
}

SchedulingConstraintHelper* helper = new SchedulingConstraintHelper(
std::move(starts), std::move(ends), std::move(sizes),
std::move(reason_for_presence), model_);
Expand Down
56 changes: 24 additions & 32 deletions ortools/sat/scheduling_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/str_cat.h"
Expand All @@ -47,8 +46,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(
std::vector<AffineExpression> sizes,
std::vector<LiteralIndex> reason_for_presence, Model* model)
: model_(model),
trail_(model->GetOrCreate<Trail>()),
sat_solver_(model->GetOrCreate<SatSolver>()),
assignment_(sat_solver_->Assignment()),
integer_trail_(model->GetOrCreate<IntegerTrail>()),
watcher_(model->GetOrCreate<GenericLiteralWatcher>()),
precedence_relations_(model->GetOrCreate<PrecedenceRelations>()),
Expand All @@ -64,15 +63,17 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(
cached_negated_end_max_(new IntegerValue[capacity_]),
cached_shifted_start_min_(new IntegerValue[capacity_]),
cached_negated_shifted_end_max_(new IntegerValue[capacity_]) {
minus_ends_.clear();
minus_starts_.clear();
DCHECK_EQ(starts_.size(), ends_.size());
DCHECK_EQ(starts_.size(), sizes_.size());
DCHECK_EQ(starts_.size(), reason_for_presence_.size());

minus_starts_.clear();
minus_starts_.reserve(starts_.size());
minus_ends_.clear();
minus_ends_.reserve(starts_.size());
for (int i = 0; i < starts_.size(); ++i) {
minus_ends_.push_back(ends_[i].Negated());
minus_starts_.push_back(starts_[i].Negated());
minus_ends_.push_back(ends_[i].Negated());
}

InitSortedVectors();
Expand All @@ -84,8 +85,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(
SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks,
Model* model)
: model_(model),
trail_(model->GetOrCreate<Trail>()),
sat_solver_(model->GetOrCreate<SatSolver>()),
assignment_(sat_solver_->Assignment()),
integer_trail_(model->GetOrCreate<IntegerTrail>()),
precedence_relations_(model->GetOrCreate<PrecedenceRelations>()),
capacity_(num_tasks),
Expand Down Expand Up @@ -120,6 +121,16 @@ void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) {
watcher->WatchIntegerVariable(sizes_[t].var, id, t);
watcher->WatchIntegerVariable(starts_[t].var, id, t);
watcher->WatchIntegerVariable(ends_[t].var, id, t);

// This class do not need to be waked up on presence change, since this is
// not cached. However given that we can have many propagators that use the
// same helper, it is nicer to only register this one, and wake up all
// propagator through it rather than registering all of them individually.
// Note that IncrementalPropagate() will do nothing if this is the only
// change except waking up registered propagators.
if (!IsPresent(t) && !IsAbsent(t)) {
watcher_->WatchLiteral(Literal(reason_for_presence_[t]), id);
}
}
watcher->SetPropagatorPriority(id, 0);
}
Expand Down Expand Up @@ -352,7 +363,7 @@ IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks(
bool SchedulingConstraintHelper::PropagatePrecedence(int a, int b) {
CHECK(IsPresent(a));
CHECK(IsPresent(b));
CHECK_EQ(trail_->CurrentDecisionLevel(), 0);
CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 0);

const AffineExpression before = ends_[a];
const AffineExpression after = starts_[b];
Expand All @@ -370,7 +381,7 @@ bool SchedulingConstraintHelper::PropagatePrecedence(int a, int b) {
AddWeightedSumLowerOrEqual({}, {before.var, after.var},
{int64_t{1}, int64_t{-1}}, -offset.value(),
model_);
if (model_->GetOrCreate<SatSolver>()->ModelIsUnsat()) return false;
if (sat_solver_->ModelIsUnsat()) return false;
}
return true;
}
Expand Down Expand Up @@ -609,30 +620,11 @@ bool SchedulingConstraintHelper::ReportConflict() {
return integer_trail_->ReportConflict(literal_reason_, integer_reason_);
}

void SchedulingConstraintHelper::WatchAllTasks(int id, bool watch_max_side) {
// In all cases, we watch presence literals since this class is not waked up
// when those changes.
const int num_tasks = starts_.size();
for (int t = 0; t < num_tasks; ++t) {
if (!IsPresent(t) && !IsAbsent(t)) {
watcher_->WatchLiteral(Literal(reason_for_presence_[t]), id);
}
}

// If everything is watched, it is slighlty more efficient to enqueue the
// propagator when the helper Propagate() is called. This result in less
// entries in our watched lists.
if (watch_max_side) {
propagator_ids_.push_back(id);
return;
}

// We only watch "min" side.
for (int t = 0; t < num_tasks; ++t) {
watcher_->WatchLowerBound(starts_[t], id);
watcher_->WatchLowerBound(ends_[t], id);
watcher_->WatchLowerBound(sizes_[t], id);
}
void SchedulingConstraintHelper::WatchAllTasks(int id) {
// It is more efficient to enqueue the propagator
// when the helper Propagate() is called. This result in less entries in our
// watched lists.
propagator_ids_.push_back(id);
}

void SchedulingConstraintHelper::AddOtherReason(int t) {
Expand Down
30 changes: 20 additions & 10 deletions ortools/sat/scheduling_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class SchedulingConstraintHelper : public PropagatorInterface {

// Registers the given propagator id to be called if any of the tasks
// in this class change. Note that we do not watch size max though.
void WatchAllTasks(int id, bool watch_max_side = true);
void WatchAllTasks(int id);

// Manages the other helper (used by the diffn constraint).
//
Expand Down Expand Up @@ -349,7 +349,9 @@ class SchedulingConstraintHelper : public PropagatorInterface {
// not handle this correctly.
bool InPropagationLoop() const { return integer_trail_->InPropagationLoop(); }

int CurrentDecisionLevel() const { return trail_->CurrentDecisionLevel(); }
int CurrentDecisionLevel() const {
return sat_solver_->CurrentDecisionLevel();
}

private:
// Tricky: when a task is optional, it is possible it size min is negative,
Expand Down Expand Up @@ -384,8 +386,8 @@ class SchedulingConstraintHelper : public PropagatorInterface {
void ImportOtherReasons();

Model* model_;
Trail* trail_;
SatSolver* sat_solver_;
const VariablesAssignment& assignment_;
IntegerTrail* integer_trail_;
GenericLiteralWatcher* watcher_;
PrecedenceRelations* precedence_relations_;
Expand Down Expand Up @@ -610,17 +612,25 @@ inline bool SchedulingConstraintHelper::SizeIsFixed(int t) const {
}

inline bool SchedulingConstraintHelper::IsOptional(int t) const {
return reason_for_presence_[t] != kNoLiteralIndex;
DCHECK_GE(t, 0);
DCHECK_LT(t, reason_for_presence_.size());
return reason_for_presence_.data()[t] != kNoLiteralIndex;
}

inline bool SchedulingConstraintHelper::IsPresent(int t) const {
if (reason_for_presence_[t] == kNoLiteralIndex) return true;
return trail_->Assignment().LiteralIsTrue(Literal(reason_for_presence_[t]));
DCHECK_GE(t, 0);
DCHECK_LT(t, reason_for_presence_.size());
const LiteralIndex lit = reason_for_presence_.data()[t];
if (lit == kNoLiteralIndex) return true;
return assignment_.LiteralIsTrue(Literal(lit));
}

inline bool SchedulingConstraintHelper::IsAbsent(int t) const {
if (reason_for_presence_[t] == kNoLiteralIndex) return false;
return trail_->Assignment().LiteralIsFalse(Literal(reason_for_presence_[t]));
DCHECK_GE(t, 0);
DCHECK_LT(t, reason_for_presence_.size());
const LiteralIndex lit = reason_for_presence_.data()[t];
if (lit == kNoLiteralIndex) return false;
return assignment_.LiteralIsFalse(Literal(lit));
}

inline bool SchedulingConstraintHelper::IsOptional(LiteralIndex lit) const {
Expand All @@ -629,12 +639,12 @@ inline bool SchedulingConstraintHelper::IsOptional(LiteralIndex lit) const {

inline bool SchedulingConstraintHelper::IsPresent(LiteralIndex lit) const {
if (lit == kNoLiteralIndex) return true;
return trail_->Assignment().LiteralIsTrue(Literal(lit));
return assignment_.LiteralIsTrue(Literal(lit));
}

inline bool SchedulingConstraintHelper::IsAbsent(LiteralIndex lit) const {
if (lit == kNoLiteralIndex) return false;
return trail_->Assignment().LiteralIsFalse(Literal(lit));
return assignment_.LiteralIsFalse(Literal(lit));
}

inline void SchedulingConstraintHelper::ClearReason() {
Expand Down
3 changes: 1 addition & 2 deletions ortools/sat/timetable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "absl/types/span.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/intervals.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/scheduling_helpers.h"
Expand Down Expand Up @@ -344,7 +343,7 @@ TimeTablingPerTask::TimeTablingPerTask(AffineExpression capacity,

void TimeTablingPerTask::RegisterWith(GenericLiteralWatcher* watcher) {
const int id = watcher->Register(this);
helper_->WatchAllTasks(id, watcher);
helper_->WatchAllTasks(id);
watcher->WatchUpperBound(capacity_.var, id);
for (int t = 0; t < num_tasks_; t++) {
watcher->WatchLowerBound(demands_->Demands()[t], id);
Expand Down
2 changes: 1 addition & 1 deletion ortools/sat/timetable_edgefinding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TimeTableEdgeFinding::TimeTableEdgeFinding(AffineExpression capacity,
void TimeTableEdgeFinding::RegisterWith(GenericLiteralWatcher* watcher) {
const int id = watcher->Register(this);
watcher->WatchUpperBound(capacity_, id);
helper_->WatchAllTasks(id, watcher);
helper_->WatchAllTasks(id);
for (int t = 0; t < num_tasks_; t++) {
watcher->WatchLowerBound(demands_->Demands()[t], id);
}
Expand Down

0 comments on commit bf87fd8

Please sign in to comment.