|
20 | 20 | #include <llvm/ADT/TypeSwitch.h> |
21 | 21 | #include <llvm/Support/Debug.h> |
22 | 22 | #include <llvm/Support/ErrorHandling.h> |
| 23 | +#include <mlir/Analysis/TopologicalSortUtils.h> |
23 | 24 | #include <mlir/Dialect/Func/IR/FuncOps.h> |
24 | 25 | #include <mlir/IR/Block.h> |
25 | 26 | #include <mlir/IR/BuiltinOps.h> |
@@ -321,6 +322,65 @@ struct MappingPass : impl::MappingPassBase<MappingPass> { |
321 | 322 |
|
322 | 323 | using MinQueue = std::priority_queue<Node, std::vector<Node>, std::greater<>>; |
323 | 324 |
|
| 325 | + struct [[nodiscard]] TrialResult { |
| 326 | + explicit TrialResult(Layout layout) : layout(std::move(layout)) {} |
| 327 | + |
| 328 | + /// @brief The computed initial layout. |
| 329 | + Layout layout; |
| 330 | + /// @brief A vector of SWAPs for each layer. |
| 331 | + SmallVector<SmallVector<IndexGate>> swaps; |
| 332 | + /// @brief The number of inserted SWAPs. |
| 333 | + std::size_t nswaps{}; |
| 334 | + }; |
| 335 | + |
| 336 | + struct SynchronizationMap { |
| 337 | + /** |
| 338 | + * @returns true if the operation is contained in the map. |
| 339 | + */ |
| 340 | + bool contains(Operation* op) const { return onHold.contains(op); } |
| 341 | + |
| 342 | + /** |
| 343 | + * @brief Add op with respective iterator and ref count to the map. |
| 344 | + */ |
| 345 | + void add(Operation* op, WireIterator* it, const std::size_t cnt) { |
| 346 | + onHold.try_emplace(op, SmallVector<WireIterator*>{it}); |
| 347 | + // Decrease the cnt by one because the op was visited when adding. |
| 348 | + refCount.try_emplace(op, cnt - 1); |
| 349 | + } |
| 350 | + |
| 351 | + /** |
| 352 | + * @brief Decrement ref count of op and potentially release its iterators. |
| 353 | + */ |
| 354 | + std::optional<SmallVector<WireIterator*, 0>> visit(Operation* op, |
| 355 | + WireIterator* it) { |
| 356 | + assert(refCount.contains(op) && "expected sync map to contain op"); |
| 357 | + |
| 358 | + // Add iterator for later release. |
| 359 | + onHold[op].push_back(it); |
| 360 | + |
| 361 | + // Release iterators whenever the ref count reaches zero. |
| 362 | + if (--refCount[op] == 0) { |
| 363 | + return onHold[op]; |
| 364 | + } |
| 365 | + |
| 366 | + return std::nullopt; |
| 367 | + } |
| 368 | + |
| 369 | + /** |
| 370 | + * @brief Clear the contents of the map. |
| 371 | + */ |
| 372 | + void clear() { |
| 373 | + onHold.clear(); |
| 374 | + refCount.clear(); |
| 375 | + } |
| 376 | + |
| 377 | + private: |
| 378 | + /// @brief Maps operations to to-be-released iterators. |
| 379 | + DenseMap<Operation*, SmallVector<WireIterator*, 0>> onHold; |
| 380 | + /// @brief Maps operations to ref counts. |
| 381 | + DenseMap<Operation*, std::size_t> refCount; |
| 382 | + }; |
| 383 | + |
324 | 384 | public: |
325 | 385 | using MappingPassBase::MappingPassBase; |
326 | 386 |
|
@@ -382,17 +442,6 @@ struct MappingPass : impl::MappingPassBase<MappingPass> { |
382 | 442 | } |
383 | 443 |
|
384 | 444 | private: |
385 | | - struct [[nodiscard]] TrialResult { |
386 | | - explicit TrialResult(Layout layout) : layout(std::move(layout)) {} |
387 | | - |
388 | | - /// @brief The computed initial layout. |
389 | | - Layout layout; |
390 | | - /// @brief A vector of SWAPs for each layer. |
391 | | - SmallVector<SmallVector<IndexGate>> swaps; |
392 | | - /// @brief The number of inserted SWAPs. |
393 | | - std::size_t nswaps{}; |
394 | | - }; |
395 | | - |
396 | 445 | /** |
397 | 446 | * @brief Find the best trial result in terms of the number of SWAPs. |
398 | 447 | * @returns the best trial result or nullptr if no result is valid. |
@@ -615,28 +664,33 @@ struct MappingPass : impl::MappingPassBase<MappingPass> { |
615 | 664 | while (shouldContinue(it)) { |
616 | 665 | const auto res = |
617 | 666 | TypeSwitch<Operation*, WalkResult>(it.operation()) |
618 | | - .Case<UnitaryOpInterface>([&](UnitaryOpInterface op) { |
619 | | - assert(op.getNumQubits() > 0 && op.getNumQubits() <= 2); |
| 667 | + .Case<BarrierOp>([&](auto) { |
| 668 | + std::ranges::advance(it, step); |
| 669 | + return WalkResult::advance(); |
| 670 | + }) |
| 671 | + .template Case<UnitaryOpInterface>( |
| 672 | + [&](UnitaryOpInterface op) { |
| 673 | + assert(op.getNumQubits() > 0 && op.getNumQubits() <= 2); |
620 | 674 |
|
621 | | - if (op.getNumQubits() == 1) { |
622 | | - std::ranges::advance(it, step); |
623 | | - return WalkResult::advance(); |
624 | | - } |
| 675 | + if (op.getNumQubits() == 1) { |
| 676 | + std::ranges::advance(it, step); |
| 677 | + return WalkResult::advance(); |
| 678 | + } |
625 | 679 |
|
626 | | - if (visited.contains(op)) { |
627 | | - const auto otherIndex = visited[op]; |
628 | | - layer.insert(std::make_pair(index, otherIndex)); |
| 680 | + if (visited.contains(op)) { |
| 681 | + const auto otherIndex = visited[op]; |
| 682 | + layer.insert(std::make_pair(index, otherIndex)); |
629 | 683 |
|
630 | | - std::ranges::advance(wires[index], step); |
631 | | - std::ranges::advance(wires[otherIndex], step); |
| 684 | + std::ranges::advance(wires[index], step); |
| 685 | + std::ranges::advance(wires[otherIndex], step); |
632 | 686 |
|
633 | | - visited.erase(op); |
634 | | - } else { |
635 | | - visited.try_emplace(op, index); |
636 | | - } |
| 687 | + visited.erase(op); |
| 688 | + } else { |
| 689 | + visited.try_emplace(op, index); |
| 690 | + } |
637 | 691 |
|
638 | | - return WalkResult::interrupt(); |
639 | | - }) |
| 692 | + return WalkResult::interrupt(); |
| 693 | + }) |
640 | 694 | .template Case<AllocOp, StaticOp, ResetOp, MeasureOp, |
641 | 695 | DeallocOp>([&](auto) { |
642 | 696 | std::ranges::advance(it, step); |
@@ -709,25 +763,30 @@ struct MappingPass : impl::MappingPassBase<MappingPass> { |
709 | 763 | // Helper function that advances the iterator to the input qubit (the |
710 | 764 | // operation producing it) of a deallocation or two-qubit op. |
711 | 765 | const auto advFront = [](WireIterator& it) { |
| 766 | + auto next = std::next(it); |
712 | 767 | while (true) { |
713 | | - const auto next = std::next(it); |
714 | 768 | if (isa<DeallocOp>(next.operation())) { |
715 | 769 | break; |
716 | 770 | } |
717 | 771 |
|
| 772 | + if (isa<BarrierOp>(next.operation())) { |
| 773 | + break; |
| 774 | + } |
| 775 | + |
718 | 776 | auto op = dyn_cast<UnitaryOpInterface>(next.operation()); |
719 | 777 | if (op && op.getNumQubits() > 1) { |
720 | 778 | break; |
721 | 779 | } |
722 | 780 |
|
723 | 781 | std::ranges::advance(it, 1); |
| 782 | + std::ranges::advance(next, 1); |
724 | 783 | } |
725 | 784 | }; |
726 | 785 |
|
727 | 786 | auto wires = toWires(place(dynQubits, result.layout, funcBody, rewriter)); |
728 | 787 |
|
729 | | - DenseMap<Operation*, WireIterator*> seen; |
730 | | - for (const auto [i, swaps] : enumerate(result.swaps)) { |
| 788 | + SynchronizationMap ready; |
| 789 | + for (const auto& swaps : result.swaps) { |
731 | 790 | // Advance all wires to the next front of one-qubit outputs |
732 | 791 | // (the SSA values). |
733 | 792 | for_each(wires, advFront); |
@@ -760,21 +819,30 @@ struct MappingPass : impl::MappingPassBase<MappingPass> { |
760 | 819 | std::ranges::advance(wires[hw1], 1); |
761 | 820 | } |
762 | 821 |
|
763 | | - // Jump over "ready" two-qubit gates. |
| 822 | + // Jump over "ready" gates. |
764 | 823 | for (auto& it : wires) { |
765 | 824 | auto op = dyn_cast<UnitaryOpInterface>(std::next(it).operation()); |
766 | | - if (op && op.getNumQubits() > 1) { |
767 | | - if (seen.contains(op)) { |
768 | | - std::ranges::advance(it, 1); |
769 | | - std::ranges::advance(*seen[op], 1); |
770 | | - continue; |
771 | | - } |
| 825 | + if (!op) { |
| 826 | + continue; |
| 827 | + } |
| 828 | + |
| 829 | + if (op.getNumQubits() < 2) { |
| 830 | + continue; |
| 831 | + } |
772 | 832 |
|
773 | | - seen.try_emplace(op, &it); |
| 833 | + if (!ready.contains(op)) { |
| 834 | + ready.add(op, &it, op.getNumQubits()); |
| 835 | + continue; |
| 836 | + } |
| 837 | + |
| 838 | + if (auto opt = ready.visit(op, &it)) { |
| 839 | + for (WireIterator* wire : *opt) { |
| 840 | + std::ranges::advance(*wire, 1); |
| 841 | + } |
774 | 842 | } |
775 | 843 | } |
776 | 844 |
|
777 | | - seen.clear(); // Prepare for next iteration. |
| 845 | + ready.clear(); // Prepare for next iteration. |
778 | 846 | } |
779 | 847 | } |
780 | 848 | }; |
|
0 commit comments