Skip to content

Commit ca59171

Browse files
Fix SSA dominance issues
1 parent 8d7cc71 commit ca59171

2 files changed

Lines changed: 120 additions & 40 deletions

File tree

mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp

Lines changed: 108 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <llvm/ADT/TypeSwitch.h>
2121
#include <llvm/Support/Debug.h>
2222
#include <llvm/Support/ErrorHandling.h>
23+
#include <mlir/Analysis/TopologicalSortUtils.h>
2324
#include <mlir/Dialect/Func/IR/FuncOps.h>
2425
#include <mlir/IR/Block.h>
2526
#include <mlir/IR/BuiltinOps.h>
@@ -321,6 +322,65 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
321322

322323
using MinQueue = std::priority_queue<Node, std::vector<Node>, std::greater<>>;
323324

325+
struct [[nodiscard]] TrialResult {
326+
explicit TrialResult(Layout layout) : layout(std::move(layout)) {}
327+
328+
/// @brief The computed initial layout.
329+
Layout layout;
330+
/// @brief A vector of SWAPs for each layer.
331+
SmallVector<SmallVector<IndexGate>> swaps;
332+
/// @brief The number of inserted SWAPs.
333+
std::size_t nswaps{};
334+
};
335+
336+
struct SynchronizationMap {
337+
/**
338+
* @returns true if the operation is contained in the map.
339+
*/
340+
bool contains(Operation* op) const { return onHold.contains(op); }
341+
342+
/**
343+
* @brief Add op with respective iterator and ref count to the map.
344+
*/
345+
void add(Operation* op, WireIterator* it, const std::size_t cnt) {
346+
onHold.try_emplace(op, SmallVector<WireIterator*>{it});
347+
// Decrease the cnt by one because the op was visited when adding.
348+
refCount.try_emplace(op, cnt - 1);
349+
}
350+
351+
/**
352+
* @brief Decrement ref count of op and potentially release its iterators.
353+
*/
354+
std::optional<SmallVector<WireIterator*, 0>> visit(Operation* op,
355+
WireIterator* it) {
356+
assert(refCount.contains(op) && "expected sync map to contain op");
357+
358+
// Add iterator for later release.
359+
onHold[op].push_back(it);
360+
361+
// Release iterators whenever the ref count reaches zero.
362+
if (--refCount[op] == 0) {
363+
return onHold[op];
364+
}
365+
366+
return std::nullopt;
367+
}
368+
369+
/**
370+
* @brief Clear the contents of the map.
371+
*/
372+
void clear() {
373+
onHold.clear();
374+
refCount.clear();
375+
}
376+
377+
private:
378+
/// @brief Maps operations to to-be-released iterators.
379+
DenseMap<Operation*, SmallVector<WireIterator*, 0>> onHold;
380+
/// @brief Maps operations to ref counts.
381+
DenseMap<Operation*, std::size_t> refCount;
382+
};
383+
324384
public:
325385
using MappingPassBase::MappingPassBase;
326386

@@ -382,17 +442,6 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
382442
}
383443

384444
private:
385-
struct [[nodiscard]] TrialResult {
386-
explicit TrialResult(Layout layout) : layout(std::move(layout)) {}
387-
388-
/// @brief The computed initial layout.
389-
Layout layout;
390-
/// @brief A vector of SWAPs for each layer.
391-
SmallVector<SmallVector<IndexGate>> swaps;
392-
/// @brief The number of inserted SWAPs.
393-
std::size_t nswaps{};
394-
};
395-
396445
/**
397446
* @brief Find the best trial result in terms of the number of SWAPs.
398447
* @returns the best trial result or nullptr if no result is valid.
@@ -615,28 +664,33 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
615664
while (shouldContinue(it)) {
616665
const auto res =
617666
TypeSwitch<Operation*, WalkResult>(it.operation())
618-
.Case<UnitaryOpInterface>([&](UnitaryOpInterface op) {
619-
assert(op.getNumQubits() > 0 && op.getNumQubits() <= 2);
667+
.Case<BarrierOp>([&](auto) {
668+
std::ranges::advance(it, step);
669+
return WalkResult::advance();
670+
})
671+
.template Case<UnitaryOpInterface>(
672+
[&](UnitaryOpInterface op) {
673+
assert(op.getNumQubits() > 0 && op.getNumQubits() <= 2);
620674

621-
if (op.getNumQubits() == 1) {
622-
std::ranges::advance(it, step);
623-
return WalkResult::advance();
624-
}
675+
if (op.getNumQubits() == 1) {
676+
std::ranges::advance(it, step);
677+
return WalkResult::advance();
678+
}
625679

626-
if (visited.contains(op)) {
627-
const auto otherIndex = visited[op];
628-
layer.insert(std::make_pair(index, otherIndex));
680+
if (visited.contains(op)) {
681+
const auto otherIndex = visited[op];
682+
layer.insert(std::make_pair(index, otherIndex));
629683

630-
std::ranges::advance(wires[index], step);
631-
std::ranges::advance(wires[otherIndex], step);
684+
std::ranges::advance(wires[index], step);
685+
std::ranges::advance(wires[otherIndex], step);
632686

633-
visited.erase(op);
634-
} else {
635-
visited.try_emplace(op, index);
636-
}
687+
visited.erase(op);
688+
} else {
689+
visited.try_emplace(op, index);
690+
}
637691

638-
return WalkResult::interrupt();
639-
})
692+
return WalkResult::interrupt();
693+
})
640694
.template Case<AllocOp, StaticOp, ResetOp, MeasureOp,
641695
DeallocOp>([&](auto) {
642696
std::ranges::advance(it, step);
@@ -709,25 +763,30 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
709763
// Helper function that advances the iterator to the input qubit (the
710764
// operation producing it) of a deallocation or two-qubit op.
711765
const auto advFront = [](WireIterator& it) {
766+
auto next = std::next(it);
712767
while (true) {
713-
const auto next = std::next(it);
714768
if (isa<DeallocOp>(next.operation())) {
715769
break;
716770
}
717771

772+
if (isa<BarrierOp>(next.operation())) {
773+
break;
774+
}
775+
718776
auto op = dyn_cast<UnitaryOpInterface>(next.operation());
719777
if (op && op.getNumQubits() > 1) {
720778
break;
721779
}
722780

723781
std::ranges::advance(it, 1);
782+
std::ranges::advance(next, 1);
724783
}
725784
};
726785

727786
auto wires = toWires(place(dynQubits, result.layout, funcBody, rewriter));
728787

729-
DenseMap<Operation*, WireIterator*> seen;
730-
for (const auto [i, swaps] : enumerate(result.swaps)) {
788+
SynchronizationMap ready;
789+
for (const auto& swaps : result.swaps) {
731790
// Advance all wires to the next front of one-qubit outputs
732791
// (the SSA values).
733792
for_each(wires, advFront);
@@ -760,21 +819,30 @@ struct MappingPass : impl::MappingPassBase<MappingPass> {
760819
std::ranges::advance(wires[hw1], 1);
761820
}
762821

763-
// Jump over "ready" two-qubit gates.
822+
// Jump over "ready" gates.
764823
for (auto& it : wires) {
765824
auto op = dyn_cast<UnitaryOpInterface>(std::next(it).operation());
766-
if (op && op.getNumQubits() > 1) {
767-
if (seen.contains(op)) {
768-
std::ranges::advance(it, 1);
769-
std::ranges::advance(*seen[op], 1);
770-
continue;
771-
}
825+
if (!op) {
826+
continue;
827+
}
828+
829+
if (op.getNumQubits() < 2) {
830+
continue;
831+
}
772832

773-
seen.try_emplace(op, &it);
833+
if (!ready.contains(op)) {
834+
ready.add(op, &it, op.getNumQubits());
835+
continue;
836+
}
837+
838+
if (auto opt = ready.visit(op, &it)) {
839+
for (WireIterator* wire : *opt) {
840+
std::ranges::advance(*wire, 1);
841+
}
774842
}
775843
}
776844

777-
seen.clear(); // Prepare for next iteration.
845+
ready.clear(); // Prepare for next iteration.
778846
}
779847
}
780848
};

mlir/unittests/Dialect/QCO/Transforms/Mapping/test_mapping.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include <gtest/gtest.h>
2222
#include <llvm/ADT/STLExtras.h>
23+
#include <llvm/Support/Debug.h>
2324
#include <llvm/Support/LogicalResult.h>
2425
#include <mlir/Dialect/Arith/IR/Arith.h>
2526
#include <mlir/Dialect/Func/IR/FuncOps.h>
@@ -63,6 +64,9 @@ class MappingPassTest : public testing::Test,
6364

6465
bool executable = true;
6566
std::ignore = moduleOp->walk([&](qc::UnitaryOpInterface op) {
67+
if (isa<qc::BarrierOp>(op)) {
68+
return WalkResult::advance();
69+
}
6670
if (op.getNumQubits() > 1) {
6771
assert(op.getNumQubits() == 2 &&
6872
"Expected only 2-qubit gates after decomposition");
@@ -178,6 +182,14 @@ TEST_P(MappingPassTest, Sabre) {
178182

179183
builder.cx(q3, q0);
180184

185+
builder.barrier({q0, q1, q2, q3, q4, q5});
186+
builder.measure(q0);
187+
builder.measure(q1);
188+
builder.measure(q2);
189+
builder.measure(q3);
190+
builder.measure(q4);
191+
builder.measure(q5);
192+
181193
builder.dealloc(q0);
182194
builder.dealloc(q1);
183195
builder.dealloc(q2);

0 commit comments

Comments
 (0)