diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d798926f7..85e6f7e9ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between Jeff and QCO ([#1479], [#1548], [#1565]) ([**@denialhaag**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568]) ([**@MatthiasReumann**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1570], [#1572], [#1573]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1569], [#1570], [#1572], [#1573]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -337,6 +337,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1572]: https://github.com/munich-quantum-toolkit/core/pull/1572 [#1571]: https://github.com/munich-quantum-toolkit/core/pull/1571 [#1570]: https://github.com/munich-quantum-toolkit/core/pull/1570 +[#1569]: https://github.com/munich-quantum-toolkit/core/pull/1569 [#1568]: https://github.com/munich-quantum-toolkit/core/pull/1568 [#1565]: https://github.com/munich-quantum-toolkit/core/pull/1565 [#1564]: https://github.com/munich-quantum-toolkit/core/pull/1564 diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index fdf5ab7310..f602903950 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -110,7 +110,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Get a static qubit by index - * @param index The qubit index (must be non-negative) + * @param index The qubit index * @return A qubit reference * * @par Example: @@ -121,7 +121,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * %q0 = qc.static 0 : !qc.qubit * ``` */ - Value staticQubit(int64_t index); + Value staticQubit(uint64_t index); /** * @brief Allocate a qubit register diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 9ca5480d2f..1272bc0f24 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -26,6 +26,21 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class QCOp traits = []> : Op; +//===----------------------------------------------------------------------===// +// Type Constraints +//===----------------------------------------------------------------------===// + +def DynamicQubit + : Type($_self)">, + CPred<"!::mlir::cast<::mlir::qc::QubitType>($_self)." + "getIsStatic()">]>, + "dynamic qubit type (!qc.qubit)">; + +def StaticQubit : Type($_self)">, + CPred<"::mlir::cast<::mlir::qc::QubitType>($_self)." + "getIsStatic()">]>, + "static qubit type (!qc.qubit)">; + //===----------------------------------------------------------------------===// // Resource Operations //===----------------------------------------------------------------------===// @@ -57,20 +72,20 @@ def AllocOp : QCOp<"alloc", [MemoryEffects<[MemAlloc]>]> { let arguments = (ins OptionalAttr:$register_name, OptionalAttr>:$register_size, OptionalAttr>:$register_index); - let results = (outs QubitType:$result); + let results = (outs DynamicQubit:$result); let assemblyFormat = [{ (`(` $register_name^ `,` $register_size `,` $register_index `)`)? attr-dict `:` type($result) }]; let builders = [OpBuilder<(ins), [{ - build($_builder, $_state, QubitType::get($_builder.getContext()), nullptr, nullptr, nullptr); + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/false), nullptr, nullptr, nullptr); }]>, OpBuilder<(ins "::mlir::StringAttr":$register_name, "::mlir::IntegerAttr":$register_size, "::mlir::IntegerAttr":$register_index), [{ - build($_builder, $_state, QubitType::get($_builder.getContext()), + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/false), register_name, register_size, register_index); }]>]; @@ -78,9 +93,10 @@ def AllocOp : QCOp<"alloc", [MemoryEffects<[MemAlloc]>]> { } def DeallocOp : QCOp<"dealloc", [MemoryEffects<[MemFree]>]> { - let summary = "Deallocate a qubit"; + let summary = "Deallocate a dynamically allocated qubit"; let description = [{ - Deallocates a qubit, releasing its resources. + Deallocates a dynamically allocated qubit, releasing its resources. + Static qubits (`!qc.qubit`) cannot be deallocated. Example: ```mlir @@ -88,7 +104,7 @@ def DeallocOp : QCOp<"dealloc", [MemoryEffects<[MemFree]>]> { ``` }]; - let arguments = (ins QubitType:$qubit); + let arguments = (ins DynamicQubit:$qubit); let assemblyFormat = "$qubit attr-dict `:` type($qubit)"; let hasCanonicalizer = 1; @@ -103,13 +119,17 @@ def StaticOp : QCOp<"static", [Pure]> { Example: ```mlir - %q = qc.static 0 : !qc.qubit + %q = qc.static 0 : !qc.qubit ``` }]; let arguments = (ins ConfinedAttr:$index); - let results = (outs QubitType:$qubit); + let results = (outs StaticQubit:$qubit); let assemblyFormat = "$index attr-dict `:` type($qubit)"; + + let builders = [OpBuilder<(ins "uint64_t":$index), [{ + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/true), index); + }]>]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QC/IR/QCTypes.td b/mlir/include/mlir/Dialect/QC/IR/QCTypes.td index 85d36311fa..bc4ab8852c 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCTypes.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCTypes.td @@ -25,7 +25,15 @@ def QubitType : QCType<"Qubit", "qubit"> { QC dialect. Operations using this type modify qubits in place using reference semantics, similar to how classical imperative languages handle mutable references. + + `!qc.qubit` (default) denotes a dynamically allocated qubit. + `!qc.qubit` denotes a qubit with a statically known identifier. }]; + let parameters = (ins DefaultValuedParameter<"bool", "false">:$isStatic); + let builders = [TypeBuilder<(ins), [{ + return $_get($_ctxt, /*isStatic=*/false); + }]>]; + let hasCustomAssemblyFormat = 1; } #endif // MLIR_DIALECT_QC_IR_QCTYPES_TD diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index a09c49ea7a..f86737ba5a 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -118,7 +118,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Get a static qubit by index - * @param index The qubit index (must be non-negative) + * @param index The qubit index * @return A tracked, valid qubit SSA value * * @par Example: @@ -129,7 +129,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * %q0 = qco.static 0 : !qco.qubit * ``` */ - Value staticQubit(int64_t index); + Value staticQubit(uint64_t index); /** * @brief Allocate a qubit register diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 28a9034e80..dcdeb1bcbb 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -26,6 +26,22 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class QCOOp traits = []> : Op; +//===----------------------------------------------------------------------===// +// Type Constraints +//===----------------------------------------------------------------------===// + +def DynamicQubit + : Type($_self)">, + CPred<"!::mlir::cast<::mlir::qco::QubitType>($_self)." + "getIsStatic()">]>, + "dynamic qubit type (!qco.qubit)">; + +def StaticQubit + : Type($_self)">, + CPred<"::mlir::cast<::mlir::qco::QubitType>($_self)." + "getIsStatic()">]>, + "static qubit type (!qco.qubit)">; + //===----------------------------------------------------------------------===// // Resource Operations //===----------------------------------------------------------------------===// @@ -57,20 +73,20 @@ def AllocOp : QCOOp<"alloc", [MemoryEffects<[MemAlloc]>]> { let arguments = (ins OptionalAttr:$register_name, OptionalAttr>:$register_size, OptionalAttr>:$register_index); - let results = (outs QubitType:$result); + let results = (outs DynamicQubit:$result); let assemblyFormat = [{ (`(` $register_name^ `,` $register_size `,` $register_index `)`)? attr-dict `:` type($result) }]; let builders = [OpBuilder<(ins), [{ - build($_builder, $_state, QubitType::get($_builder.getContext()), nullptr, nullptr, nullptr); + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/false), nullptr, nullptr, nullptr); }]>, OpBuilder<(ins "::mlir::StringAttr":$register_name, "::mlir::IntegerAttr":$register_size, "::mlir::IntegerAttr":$register_index), [{ - build($_builder, $_state, QubitType::get($_builder.getContext()), + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/false), register_name, register_size, register_index); }]>]; @@ -80,7 +96,12 @@ def AllocOp : QCOOp<"alloc", [MemoryEffects<[MemAlloc]>]> { def DeallocOp : QCOOp<"dealloc", [MemoryEffects<[MemFree]>]> { let summary = "Deallocate a qubit"; let description = [{ - Deallocates a qubit, releasing its resources. + Deallocates a qubit. + + In QCO's value/linear semantics, this operation also serves as the sink + that consumes the qubit SSA value (ensuring every qubit value is used + exactly once). When converting back to QC (reference semantics), deallocs + corresponding to static qubits may be erased. Example: ```mlir @@ -102,20 +123,26 @@ def StaticOp : QCOOp<"static", [Pure]> { Example: ```mlir - %q = qco.static 0 : !qco.qubit + %q = qco.static 0 : !qco.qubit ``` }]; let arguments = (ins ConfinedAttr:$index); - let results = (outs QubitType:$qubit); + let results = (outs StaticQubit:$qubit); let assemblyFormat = "$index attr-dict `:` type($qubit)"; + + let builders = [OpBuilder<(ins "uint64_t":$index), [{ + build($_builder, $_state, QubitType::get($_builder.getContext(), /*isStatic=*/true), index); + }]>]; } //===----------------------------------------------------------------------===// // Measurement and Reset Operations //===----------------------------------------------------------------------===// -def MeasureOp : QCOOp<"measure"> { +def MeasureOp + : QCOOp<"measure", [TypesMatchWith<"qubit output type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Measure a qubit in the computational basis"; let description = [{ Measures a qubit in the computational (Z) basis, collapsing the state @@ -150,7 +177,7 @@ def MeasureOp : QCOOp<"measure"> { }]; let builders = [OpBuilder<(ins "Value":$qubit_in), [{ - build($_builder, $_state, QubitType::get($_builder.getContext()), $_builder.getI1Type(), + build($_builder, $_state, qubit_in.getType(), $_builder.getI1Type(), qubit_in, nullptr, nullptr, nullptr); }]>]; @@ -226,7 +253,10 @@ def GPhaseOp let hasCanonicalizer = 1; } -def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def IdOp : QCOOp<"id", + traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an Id gate to a qubit"; let description = [{ Applies an Id gate to a qubit and returns the transformed qubit. @@ -250,7 +280,10 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def XOp + : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an X gate to a qubit"; let description = [{ Applies an X gate to a qubit and returns the transformed qubit. @@ -274,7 +307,10 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def YOp + : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a Y gate to a qubit"; let description = [{ Applies a Y gate to a qubit and returns the transformed qubit. @@ -298,7 +334,10 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def ZOp + : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a Z gate to a qubit"; let description = [{ Applies a Z gate to a qubit and returns the transformed qubit. @@ -322,7 +361,10 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def HOp + : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a H gate to a qubit"; let description = [{ Applies a H gate to a qubit and returns the transformed qubit. @@ -346,7 +388,10 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SOp + : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an S gate to a qubit"; let description = [{ Applies an S gate to a qubit and returns the transformed qubit. @@ -371,7 +416,10 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { } def SdgOp - : QCOOp<"sdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { + : QCOOp<"sdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit_in", + "qubit_out", "$_self">]> { let summary = "Apply an Sdg gate to a qubit"; let description = [{ Applies an Sdg gate to a qubit and returns the transformed qubit. @@ -395,7 +443,10 @@ def SdgOp let hasCanonicalizer = 1; } -def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TOp + : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a T gate to a qubit"; let description = [{ Applies a T gate to a qubit and returns the transformed qubit. @@ -420,7 +471,10 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { } def TdgOp - : QCOOp<"tdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { + : QCOOp<"tdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit_in", + "qubit_out", "$_self">]> { let summary = "Apply a Tdg gate to a qubit"; let description = [{ Applies a Tdg gate to a qubit and returns the transformed qubit. @@ -444,7 +498,10 @@ def TdgOp let hasCanonicalizer = 1; } -def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXOp : QCOOp<"sx", + traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an SX gate to a qubit"; let description = [{ Applies an SX gate to a qubit and returns the transformed qubit. @@ -469,7 +526,10 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { } def SXdgOp - : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { + : QCOOp<"sxdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit_in", + "qubit_out", "$_self">]> { let summary = "Apply an SXdg gate to a qubit"; let description = [{ Applies an SXdg gate to a qubit and returns the transformed qubit. @@ -493,7 +553,10 @@ def SXdgOp let hasCanonicalizer = 1; } -def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RXOp : QCOOp<"rx", + traits = [UnitaryOpInterface, OneTargetOneParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an RX gate to a qubit"; let description = [{ Applies an RX gate to a qubit and returns the transformed qubit. @@ -521,7 +584,10 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RYOp : QCOOp<"ry", + traits = [UnitaryOpInterface, OneTargetOneParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an RY gate to a qubit"; let description = [{ Applies an RY gate to a qubit and returns the transformed qubit. @@ -549,7 +615,10 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RZOp : QCOOp<"rz", + traits = [UnitaryOpInterface, OneTargetOneParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an RZ gate to a qubit"; let description = [{ Applies an RZ gate to a qubit and returns the transformed qubit. @@ -577,7 +646,10 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def POp + : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a P gate to a qubit"; let description = [{ Applies a P gate to a qubit and returns the transformed qubit. @@ -605,7 +677,10 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def ROp + : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply an R gate to a qubit"; let description = [{ Applies an R gate to a qubit and returns the transformed qubit. @@ -635,7 +710,10 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def U2Op : QCOOp<"u2", + traits = [UnitaryOpInterface, OneTargetTwoParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a U2 gate to a qubit"; let description = [{ Applies a U2 gate to a qubit and returns the transformed qubit. @@ -665,7 +743,10 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { +def UOp + : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter, + TypesMatchWith<"result type matches input", + "qubit_in", "qubit_out", "$_self">]> { let summary = "Apply a U gate to a qubit"; let description = [{ Applies a U gate to a qubit and returns the transformed qubit. @@ -698,7 +779,12 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { } def SWAPOp - : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"swap", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply a SWAP gate to two qubits"; let description = [{ Applies a SWAP gate to two qubits and returns the transformed qubits. @@ -726,7 +812,12 @@ def SWAPOp } def iSWAPOp - : QCOOp<"iswap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"iswap", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply a iSWAP gate to two qubits"; let description = [{ Applies a iSWAP gate to two qubits and returns the transformed qubits. @@ -752,7 +843,12 @@ def iSWAPOp } def DCXOp - : QCOOp<"dcx", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"dcx", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply a DCX gate to two qubits"; let description = [{ Applies a DCX gate to two qubits and returns the transformed qubits. @@ -780,7 +876,12 @@ def DCXOp } def ECROp - : QCOOp<"ecr", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"ecr", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an ECR gate to two qubits"; let description = [{ Applies an ECR gate to two qubits and returns the transformed qubits. @@ -807,7 +908,13 @@ def ECROp let hasCanonicalizer = 1; } -def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RXXOp + : QCOOp<"rxx", + traits = [UnitaryOpInterface, TwoTargetOneParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an RXX gate to two qubits"; let description = [{ Applies an RXX gate to two qubits and returns the transformed qubits. @@ -838,7 +945,13 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RYYOp + : QCOOp<"ryy", + traits = [UnitaryOpInterface, TwoTargetOneParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an RYY gate to two qubits"; let description = [{ Applies an RYY gate to two qubits and returns the transformed qubits. @@ -869,7 +982,13 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZXOp + : QCOOp<"rzx", + traits = [UnitaryOpInterface, TwoTargetOneParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an RZX gate to two qubits"; let description = [{ Applies an RZX gate to two qubits and returns the transformed qubits. @@ -900,7 +1019,13 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZZOp + : QCOOp<"rzz", + traits = [UnitaryOpInterface, TwoTargetOneParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an RZZ gate to two qubits"; let description = [{ Applies an RZZ gate to two qubits and returns the transformed qubits. @@ -931,8 +1056,13 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def XXPlusYYOp : QCOOp<"xx_plus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXPlusYYOp + : QCOOp<"xx_plus_yy", + traits = [UnitaryOpInterface, TwoTargetTwoParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an XX+YY gate to two qubits"; let description = [{ Applies an XX+YY gate to two qubits and returns the transformed qubits. @@ -965,8 +1095,13 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", let hasCanonicalizer = 1; } -def XXMinusYYOp : QCOOp<"xx_minus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXMinusYYOp + : QCOOp<"xx_minus_yy", + traits = [UnitaryOpInterface, TwoTargetTwoParameter, + TypesMatchWith<"result type matches input", "qubit0_in", + "qubit0_out", "$_self">, + TypesMatchWith<"result type matches input", "qubit1_in", + "qubit1_out", "$_self">]> { let summary = "Apply an XX-YY gate to two qubits"; let description = [{ Applies an XX-YY gate to two qubits and returns the transformed qubits. @@ -1003,6 +1138,8 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { let summary = "Apply a barrier gate to a set of qubits"; let description = [{ Applies a barrier gate to a set of qubits and returns the transformed qubits. + Each output qubit type must match its corresponding input type (pairwise + type preservation, e.g., disallows !qco.qubit -> !qco.qubit). Example: ```mlir @@ -1016,6 +1153,8 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { let assemblyFormat = "$qubits_in attr-dict `:` type($qubits_in) `->` type($qubits_out)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubitsIn().size(); } @@ -1060,14 +1199,13 @@ def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike]> { }]; let arguments = (ins Variadic:$targets); - let assemblyFormat = "$targets attr-dict"; + let assemblyFormat = "$targets attr-dict (`:` type($targets)^)?"; } def CtrlOp : QCOOp<"ctrl", traits = [UnitaryOpInterface, AttrSizedOperandSegments, - AttrSizedResultSegments, SameOperandsAndResultType, - SameOperandsAndResultShape, + AttrSizedResultSegments, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOTypes.td b/mlir/include/mlir/Dialect/QCO/IR/QCOTypes.td index 2438b98d6c..358afffc8d 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOTypes.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOTypes.td @@ -26,6 +26,9 @@ def QubitType : QCOType<"Qubit", "qubit"> { and produce new output qubits following value semantics and the SSA paradigm, enabling powerful dataflow analysis and optimization. + `!qco.qubit` (default) denotes a dynamically allocated qubit. + `!qco.qubit` denotes a qubit with a statically known identifier. + Example: ```mlir %q0 = qco.alloc : !qco.qubit @@ -33,6 +36,11 @@ def QubitType : QCOType<"Qubit", "qubit"> { %q2 = qco.x %q1 : !qco.qubit -> !qco.qubit ``` }]; + let parameters = (ins DefaultValuedParameter<"bool", "false">:$isStatic); + let builders = [TypeBuilder<(ins), [{ + return $_get($_ctxt, /*isStatic=*/false); + }]>]; + let hasCustomAssemblyFormat = 1; } #endif // MLIR_DIALECT_QCO_IR_QCOTYPES_TD diff --git a/mlir/include/mlir/Dialect/QCO/Utils/ValueOrdering.h b/mlir/include/mlir/Dialect/QCO/Utils/ValueOrdering.h new file mode 100644 index 0000000000..e111820e06 --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Utils/ValueOrdering.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include + +namespace mlir::qco { + +/** + * @brief Deterministic order for SSA values. + * + * @details Uses block order when both defining ops are in the same block; + * otherwise fall back to opaque pointer order for a deterministic total order. + */ +struct SSAOrder { + bool operator()(Value a, Value b) const { + auto* opA = a.getDefiningOp(); + auto* opB = b.getDefiningOp(); + if (!opA || !opB || opA->getBlock() != opB->getBlock()) { + return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + } + return opA->isBeforeInBlock(opB); + } +}; + +} // namespace mlir::qco diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 381486b02b..0f74f99045 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -55,9 +55,9 @@ class QCOToQCTypeConverter final : public TypeConverter { // Identity conversion for all types by default addConversion([](Type type) { return type; }); - // Convert QCO qubit values to QC qubit references - addConversion([ctx](qco::QubitType /*type*/) -> Type { - return qc::QubitType::get(ctx); + // Convert QCO qubit values to QC qubit references, preserving isStatic + addConversion([ctx](qco::QubitType type) -> Type { + return qc::QubitType::get(ctx, type.getIsStatic()); }); } }; @@ -96,19 +96,26 @@ struct ConvertQCOAllocOp final : OpConversionPattern { }; /** - * @brief Converts qco.dealloc to qc.dealloc + * @brief Converts qco.dealloc to qc.dealloc (dynamic) or erases it (static). * * @details - * Deallocates a qubit, releasing its resources. The OpAdaptor automatically - * provides the type-converted qubit operand (!qc.qubit instead of - * !qco.qubit), so we simply pass it through to the new operation. + * For dynamic qubits (`!qco.qubit`), converts to `qc.dealloc`. + * For static qubits (`!qco.qubit`), erases the op since QC does not + * require explicit deallocation of static qubits. * - * Example transformation: + * Example transformation (dynamic): * ```mlir * qco.dealloc %q_qco : !qco.qubit * // becomes: * qc.dealloc %q_qc : !qc.qubit * ``` + * + * Example transformation (static): + * ```mlir + * qco.dealloc %q_qco : !qco.qubit + * // becomes: + * (erased) + * ``` */ struct ConvertQCODeallocOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -116,7 +123,11 @@ struct ConvertQCODeallocOp final : OpConversionPattern { LogicalResult matchAndRewrite(qco::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // OpAdaptor provides the already type-converted qubit + if (op.getQubit().getType().getIsStatic()) { + rewriter.eraseOp(op); + return success(); + } + rewriter.replaceOpWithNewOp(op, adaptor.getQubit()); return success(); } diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 474f80aacf..1999197563 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -14,14 +14,17 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Utils/ValueOrdering.h" #include +#include #include +#include #include #include #include -#include #include +#include #include #include #include @@ -30,7 +33,6 @@ #include #include -#include #include namespace mlir { @@ -42,7 +44,6 @@ using namespace qc; #include "mlir/Conversion/QCToQCO/QCToQCO.h.inc" namespace { - /** * @brief State object for tracking qubit value flow during conversion * @@ -72,13 +73,25 @@ namespace { * - %q2 after the X gate */ struct LoweringState { - /// Map from original QC qubit references to their latest QCO SSA values - llvm::DenseMap qubitMap; - - /// Modifier information - int64_t inNestedRegion = 0; - DenseMap> targetsIn; - DenseMap> targetsOut; + struct ModifierFrame { + /// QC qubits yielded from the current modifier region, in yield order. + SmallVector yieldOrder; + + /// Latest QCO SSA values for QC qubits that are remapped inside the + /// modifier region. + llvm::DenseMap currentQubits; + }; + + /// Per-region map from QC qubit references to latest QCO SSA values. + /// + /// @details Keys are `Operation::getParentRegion()` for ops being converted + /// (typically a `func.func` body, or a `qc.ctrl` / `qc.inv` region). This + /// avoids clearing state at the first `func.return` while later functions + /// still convert. + llvm::DenseMap> qubitMap; + + /// Stack of active modifier regions (`qc.ctrl` / `qc.inv`). + SmallVector modifierFrames; }; /** @@ -110,6 +123,211 @@ class StatefulOpConversionPattern : public OpConversionPattern { private: LoweringState* state_; }; +} // namespace + +/** + * @brief Helper function to look up the latest QCO qubit value for a given QC + * qubit reference + * + * @param qubitMap The mapping from QC qubits to QCO qubits for the current + * region + * @param qcQubit The QC qubit reference to look up + * @return The latest QCO qubit value corresponding to the given QC qubit + * reference + */ +[[nodiscard]] static Value +lookupMappedQubit(llvm::DenseMap& qubitMap, Value qcQubit) { + auto it = qubitMap.find(qcQubit); + assert(it != qubitMap.end() && "QC qubit not found"); + return it->second; +} + +/** @brief Returns whether lowering currently processes a modifier body. */ +[[nodiscard]] static bool isInsideModifier(const LoweringState& state) { + return !state.modifierFrames.empty(); +} + +/** @brief Returns the active modifier frame. */ +[[nodiscard]] static LoweringState::ModifierFrame& +currentModifierFrame(LoweringState& state) { + assert(isInsideModifier(state) && "expected active modifier frame"); + return state.modifierFrames.back(); +} + +/** @brief Finds the nearest region-local qubit map containing @p qcQubit. */ +[[nodiscard]] static llvm::DenseMap* +findMappedQubitMap(LoweringState& state, Operation* anchor, Value qcQubit) { + for (Region* current = anchor->getParentRegion(); current != nullptr; + current = current->getParentRegion()) { + auto mapIt = state.qubitMap.find(current); + if (mapIt != state.qubitMap.end() && mapIt->second.contains(qcQubit)) { + return &mapIt->second; + } + } + return nullptr; +} + +/** @brief Resolves the latest QCO SSA value for a QC qubit reference. */ +[[nodiscard]] static Value lookupMappedQubit(LoweringState& state, + Operation* anchor, Value qcQubit) { + if (isInsideModifier(state)) { + auto& frame = currentModifierFrame(state); + if (auto it = frame.currentQubits.find(qcQubit); + it != frame.currentQubits.end()) { + return it->second; + } + } + + auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit); + assert(qubitMap != nullptr && "QC qubit not found"); + return lookupMappedQubit(*qubitMap, qcQubit); +} + +/** @brief Updates the latest QCO SSA value for a QC qubit reference. */ +static void assignMappedQubit(LoweringState& state, Operation* anchor, + Value qcQubit, Value qcoQubit) { + if (isInsideModifier(state)) { + auto& frame = currentModifierFrame(state); + if (auto it = frame.currentQubits.find(qcQubit); + it != frame.currentQubits.end()) { + it->second = qcoQubit; + return; + } + } + + if (auto* qubitMap = findMappedQubitMap(state, anchor, qcQubit)) { + (*qubitMap)[qcQubit] = qcoQubit; + return; + } + + state.qubitMap[anchor->getParentRegion()][qcQubit] = qcoQubit; +} + +/** @brief Resolves a range of QC qubits to their latest QCO values. */ +template +[[nodiscard]] static SmallVector +resolveMappedQubits(LoweringState& state, Operation* anchor, + const Range& qcQubits) { + return llvm::to_vector(llvm::map_range(qcQubits, [&](Value qcQubit) { + return lookupMappedQubit(state, anchor, qcQubit); + })); +} + +/** @brief Updates mappings for matching QC and QCO qubit ranges. */ +template +static void assignMappedQubits(LoweringState& state, Operation* anchor, + const QcRange& qcQubits, + const QcoRange& qcoQubits) { + for (auto [qcQubit, qcoQubit] : llvm::zip_equal(qcQubits, qcoQubits)) { + assignMappedQubit(state, anchor, qcQubit, qcoQubit); + } +} + +/** @brief Collects the target qubits of a variadic QC unitary op. */ +template +[[nodiscard]] static SmallVector collectTargets(OpType op) { + SmallVector targets; + targets.reserve(op.getNumTargets()); + for (size_t i = 0; i < op.getNumTargets(); ++i) { + targets.emplace_back(op.getTarget(i)); + } + return targets; +} + +/** @brief Pushes a new modifier frame seeded with aliased target values. */ +static void pushModifierFrame(LoweringState& state, ValueRange qcTargets, + ValueRange qcoTargets) { + auto& [yieldOrder, currentQubits] = state.modifierFrames.emplace_back(); + llvm::append_range(yieldOrder, qcTargets); + for (auto [qcTarget, qcoTarget] : llvm::zip_equal(qcTargets, qcoTargets)) { + currentQubits.try_emplace(qcTarget, qcoTarget); + } +} + +/** @brief Pops the active modifier frame after lowering its yield. */ +static void popModifierFrame(LoweringState& state) { + assert(isInsideModifier(state) && "expected active modifier frame"); + state.modifierFrames.pop_back(); +} + +/** @brief Adds entry block aliases for modifier target values. */ +template +[[nodiscard]] static SmallVector +addModifierAliases(OpType op, ValueRange qcoTargets, + PatternRewriter& rewriter) { + auto& entryBlock = op.getRegion().front(); + SmallVector aliases; + aliases.reserve(qcoTargets.size()); + const auto opLoc = op.getLoc(); + rewriter.modifyOpInPlace(op, [&] { + for (Value qcoTarget : qcoTargets) { + aliases.emplace_back(entryBlock.addArgument(qcoTarget.getType(), opLoc)); + } + }); + return aliases; +} + +namespace { + +/** + * @brief Converts func.return and sinks remaining live qubits. + * + * @details + * QC uses reference semantics and does not enforce linear typing for qubits. + * After conversion, QCO requires that every qubit SSA value is consumed + * exactly once. For allocations (including static qubits), the sink is + * `qco.dealloc`. This pattern inserts `qco.dealloc` operations for all + * still-live qubits tracked in the lowering state right before the return. + */ +struct ConvertFuncReturnOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + Region* funcRegion = op->getParentRegion(); + auto& map = state.qubitMap[funcRegion]; + + // Build return values from qubitMap (adaptor.getOperands() may carry stale + // root values because gate patterns use eraseOp instead of replaceOp). + llvm::SmallVector returnValues; + llvm::DenseSet escapedQubits; + returnValues.reserve(op.getNumOperands()); + for (auto [qcOperand, adaptorOperand] : + llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { + if (map.contains(qcOperand)) { + auto latest = map[qcOperand]; + returnValues.emplace_back(latest); + escapedQubits.insert(latest); + } else { + returnValues.emplace_back(adaptorOperand); + } + } + + // Collect non-escaped live qubits for deallocation. + llvm::DenseSet liveQubits; + for (Value qcoQubit : llvm::make_second_range(map)) { + if (!escapedQubits.contains(qcoQubit)) { + liveQubits.insert(qcoQubit); + } + } + // Copy to a vector before sorting: DenseSet iterators are not + // random-access. + llvm::SmallVector liveQubitsSorted(liveQubits.begin(), + liveQubits.end()); + llvm::sort(liveQubitsSorted, SSAOrder{}); + + for (Value qubit : liveQubitsSorted) { + qco::DeallocOp::create(rewriter, op.getLoc(), qubit); + } + + state.qubitMap.erase(funcRegion); + + rewriter.replaceOpWithNewOp(op, returnValues); + return success(); + } +}; /** * @brief Type converter for QC-to-QCO conversion @@ -128,9 +346,9 @@ class QCToQCOTypeConverter final : public TypeConverter { // Identity conversion for all types by default addConversion([](Type type) { return type; }); - // Convert QC qubit references to QCO qubit values - addConversion([ctx](qc::QubitType /*type*/) -> Type { - return qco::QubitType::get(ctx); + // Convert QC qubit references to QCO qubit values, preserving isStatic + addConversion([ctx](qc::QubitType type) -> Type { + return qco::QubitType::get(ctx, type.getIsStatic()); }); } }; @@ -159,7 +377,8 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& state = getState(); + auto* operation = op.getOperation(); auto qcQubit = op.getResult(); // Create the qco.alloc operation with preserved register metadata @@ -171,7 +390,7 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { // Establish initial mapping: this QC qubit reference now corresponds // to this QCO SSA value - qubitMap.try_emplace(qcQubit, qcoQubit); + assignMappedQubit(state, operation, qcQubit, qcoQubit); return success(); } @@ -198,12 +417,12 @@ struct ConvertQCDeallocOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; - auto qcQubit = op.getQubit(); - - // Look up the latest QCO value for this QC qubit - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - auto qcoQubit = qubitMap[qcQubit]; + auto& state = getState(); + auto* operation = op.getOperation(); + auto* region = operation->getParentRegion(); + auto& qubitMap = state.qubitMap[region]; + Value qcQubit = op.getQubit(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the dealloc operation rewriter.replaceOpWithNewOp(op, qcoQubit); @@ -236,20 +455,12 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; - auto qcQubit = op.getQubit(); - - // Create new qco.static operation with the same index - auto qcoOp = qco::StaticOp::create(rewriter, op.getLoc(), op.getIndex()); - - // Collect QCO qubit SSA value - auto qcoQubit = qcoOp.getQubit(); - - // Establish mapping from QC reference to QCO value - qubitMap[qcQubit] = qcoQubit; + auto& state = getState(); + auto* operation = op.getOperation(); + Value qcQubit = op.getQubit(); - // Replace the old operation result with the new result - rewriter.replaceOp(op, qcoQubit); + auto qcoOp = rewriter.replaceOpWithNewOp(op, op.getIndex()); + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubit()); return success(); } @@ -284,26 +495,21 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::MeasureOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; - auto qcQubit = op.getQubit(); - - // Get the latest QCO qubit value from the state map - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - auto qcoQubit = qubitMap[qcQubit]; + auto& state = getState(); + auto* operation = op.getOperation(); + Value qcQubit = op.getQubit(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.measure (returns both output qubit and bit result) auto qcoOp = qco::MeasureOp::create( rewriter, op.getLoc(), qcoQubit, op.getRegisterNameAttr(), op.getRegisterSizeAttr(), op.getRegisterIndexAttr()); - auto outQcoQubit = qcoOp.getQubitOut(); - auto newBit = qcoOp.getResult(); - // Update mapping: the QC qubit now corresponds to the output qubit - qubitMap[qcQubit] = outQcoQubit; + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); // Replace the QC operation's bit result with the QCO bit result - rewriter.replaceOp(op, newBit); + rewriter.replaceOp(op, qcoOp.getResult()); return success(); } @@ -335,18 +541,16 @@ struct ConvertQCResetOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::ResetOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; - auto qcQubit = op.getQubit(); - - // Get the latest QCO qubit value from the state map - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - auto qcoQubit = qubitMap[qcQubit]; + auto& state = getState(); + auto* operation = op.getOperation(); + Value qcQubit = op.getQubit(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create qco.reset (consumes input, produces output) auto qcoOp = qco::ResetOp::create(rewriter, op.getLoc(), qcoQubit); // Update mapping: the QC qubit now corresponds to the reset output - qubitMap[qcQubit] = qcoOp.getQubitOut(); + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); // Erase the old (it has no results to replace) rewriter.eraseOp(op); @@ -378,18 +582,8 @@ struct ConvertQCZeroTargetOneParameterToQCO final LogicalResult matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& state = this->getState(); - const auto inNestedRegion = state.inNestedRegion; - QCOOpType::create(rewriter, op.getLoc(), op.getParameter(0)); - // Update the state - if (inNestedRegion != 0) { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut; - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } - rewriter.eraseOp(op); return success(); @@ -420,32 +614,14 @@ struct ConvertQCOneTargetZeroParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubit - auto qcQubit = op.getQubitIn(); - Value qcoQubit; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubit = qubitMap[qcQubit]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 1 && - "Invalid number of input targets"); - qcoQubit = state.targetsIn[inNestedRegion].front(); - } + auto* operation = op.getOperation(); + Value qcQubit = op.getQubitIn(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit] = qcoOp.getQubitOut(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut({qcoOp.getQubitOut()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); rewriter.eraseOp(op); @@ -477,33 +653,15 @@ struct ConvertQCOneTargetOneParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubit - auto qcQubit = op.getQubitIn(); - Value qcoQubit; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubit = qubitMap[qcQubit]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 1 && - "Invalid number of input targets"); - qcoQubit = state.targetsIn[inNestedRegion].front(); - } + auto* operation = op.getOperation(); + Value qcQubit = op.getQubitIn(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit, op.getParameter(0)); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit] = qcoOp.getQubitOut(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut({qcoOp.getQubitOut()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); rewriter.eraseOp(op); @@ -535,33 +693,15 @@ struct ConvertQCOneTargetTwoParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubit - auto qcQubit = op.getQubitIn(); - Value qcoQubit; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubit = qubitMap[qcQubit]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 1 && - "Invalid number of input targets"); - qcoQubit = state.targetsIn[inNestedRegion].front(); - } + auto* operation = op.getOperation(); + Value qcQubit = op.getQubitIn(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit, op.getParameter(0), op.getParameter(1)); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit] = qcoOp.getQubitOut(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut({qcoOp.getQubitOut()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); rewriter.eraseOp(op); @@ -593,34 +733,16 @@ struct ConvertQCOneTargetThreeParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubit - auto qcQubit = op.getQubitIn(); - Value qcoQubit; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubit = qubitMap[qcQubit]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 1 && - "Invalid number of input targets"); - qcoQubit = state.targetsIn[inNestedRegion].front(); - } + auto* operation = op.getOperation(); + Value qcQubit = op.getQubitIn(); + Value qcoQubit = lookupMappedQubit(state, operation, qcQubit); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit, op.getParameter(0), op.getParameter(1), op.getParameter(2)); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit] = qcoOp.getQubitOut(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut({qcoOp.getQubitOut()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit, qcoOp.getQubitOut()); rewriter.eraseOp(op); @@ -653,40 +775,17 @@ struct ConvertQCTwoTargetZeroParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubits - auto qcQubit0 = op.getQubit0In(); - auto qcQubit1 = op.getQubit1In(); - Value qcoQubit0; - Value qcoQubit1; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit0) && "QC qubit not found"); - assert(qubitMap.contains(qcQubit1) && "QC qubit not found"); - qcoQubit0 = qubitMap[qcQubit0]; - qcoQubit1 = qubitMap[qcQubit1]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 2 && - "Invalid number of input targets"); - const auto& targetsIn = state.targetsIn[inNestedRegion]; - qcoQubit0 = targetsIn[0]; - qcoQubit1 = targetsIn[1]; - } + auto* operation = op.getOperation(); + Value qcQubit0 = op.getQubit0In(); + Value qcQubit1 = op.getQubit1In(); + Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit0] = qcoOp.getQubit0Out(); - qubitMap[qcQubit1] = qcoOp.getQubit1Out(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut( - {qcoOp.getQubit0Out(), qcoOp.getQubit1Out()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit0, qcoOp.getQubit0Out()); + assignMappedQubit(state, operation, qcQubit1, qcoOp.getQubit1Out()); rewriter.eraseOp(op); @@ -719,41 +818,18 @@ struct ConvertQCTwoTargetOneParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubits - auto qcQubit0 = op.getQubit0In(); - auto qcQubit1 = op.getQubit1In(); - Value qcoQubit0; - Value qcoQubit1; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit0) && "QC qubit not found"); - assert(qubitMap.contains(qcQubit1) && "QC qubit not found"); - qcoQubit0 = qubitMap[qcQubit0]; - qcoQubit1 = qubitMap[qcQubit1]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 2 && - "Invalid number of input targets"); - const auto& targetsIn = state.targetsIn[inNestedRegion]; - qcoQubit0 = targetsIn[0]; - qcoQubit1 = targetsIn[1]; - } + auto* operation = op.getOperation(); + Value qcQubit0 = op.getQubit0In(); + Value qcQubit1 = op.getQubit1In(); + Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, op.getParameter(0)); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit0] = qcoOp.getQubit0Out(); - qubitMap[qcQubit1] = qcoOp.getQubit1Out(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut( - {qcoOp.getQubit0Out(), qcoOp.getQubit1Out()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit0, qcoOp.getQubit0Out()); + assignMappedQubit(state, operation, qcQubit1, qcoOp.getQubit1Out()); rewriter.eraseOp(op); @@ -786,41 +862,18 @@ struct ConvertQCTwoTargetTwoParameterToQCO final matchAndRewrite(QCOpType op, QCOpType::Adaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - auto& qubitMap = state.qubitMap; - const auto inNestedRegion = state.inNestedRegion; - - // Get the latest QCO qubits - auto qcQubit0 = op.getQubit0In(); - auto qcQubit1 = op.getQubit1In(); - Value qcoQubit0; - Value qcoQubit1; - if (inNestedRegion == 0) { - assert(qubitMap.contains(qcQubit0) && "QC qubit not found"); - assert(qubitMap.contains(qcQubit1) && "QC qubit not found"); - qcoQubit0 = qubitMap[qcQubit0]; - qcoQubit1 = qubitMap[qcQubit1]; - } else { - assert(state.targetsIn[inNestedRegion].size() == 2 && - "Invalid number of input targets"); - const auto& targetsIn = state.targetsIn[inNestedRegion]; - qcoQubit0 = targetsIn[0]; - qcoQubit1 = targetsIn[1]; - } + auto* operation = op.getOperation(); + Value qcQubit0 = op.getQubit0In(); + Value qcQubit1 = op.getQubit1In(); + Value qcoQubit0 = lookupMappedQubit(state, operation, qcQubit0); + Value qcoQubit1 = lookupMappedQubit(state, operation, qcQubit1); // Create the QCO operation (consumes input, produces output) auto qcoOp = QCOOpType::create(rewriter, op.getLoc(), qcoQubit0, qcoQubit1, op.getParameter(0), op.getParameter(1)); - // Update the state map - if (inNestedRegion == 0) { - qubitMap[qcQubit0] = qcoOp.getQubit0Out(); - qubitMap[qcQubit1] = qcoOp.getQubit1Out(); - } else { - state.targetsIn.erase(inNestedRegion); - const SmallVector targetsOut( - {qcoOp.getQubit0Out(), qcoOp.getQubit1Out()}); - state.targetsOut.try_emplace(inNestedRegion, targetsOut); - } + assignMappedQubit(state, operation, qcQubit0, qcoOp.getQubit0Out()); + assignMappedQubit(state, operation, qcQubit1, qcoOp.getQubit1Out()); rewriter.eraseOp(op); @@ -848,25 +901,14 @@ struct ConvertQCBarrierOp final : StatefulOpConversionPattern { matchAndRewrite(qc::BarrierOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto& qubitMap = state.qubitMap; - - // Get QCO qubits from state map - auto qcQubits = op.getQubits(); - SmallVector qcoQubits; - qcoQubits.reserve(qcQubits.size()); - for (const auto& qcQubit : qcQubits) { - assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubits.push_back(qubitMap[qcQubit]); - } + auto* operation = op.getOperation(); + const auto qcQubits = llvm::to_vector(op.getQubits()); + auto qcoQubits = resolveMappedQubits(state, operation, qcQubits); // Create qco.barrier auto qcoOp = qco::BarrierOp::create(rewriter, op.getLoc(), qcoQubits); - // Update the state map - for (const auto& [qcQubit, qcoQubitOut] : - llvm::zip(qcQubits, qcoOp.getQubitsOut())) { - qubitMap[qcQubit] = qcoQubitOut; - } + assignMappedQubits(state, operation, qcQubits, qcoOp.getQubitsOut()); rewriter.eraseOp(op); return success(); @@ -897,48 +939,18 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { matchAndRewrite(qc::CtrlOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto& qubitMap = state.qubitMap; - - // Get QCO controls from state map - auto qcControls = op.getControls(); - SmallVector qcoControls; - qcoControls.reserve(qcControls.size()); - for (const auto& qcControl : qcControls) { - assert(qubitMap.contains(qcControl) && "QC qubit not found"); - qcoControls.push_back(qubitMap[qcControl]); - } - - // Get QCO targets from state map - const auto numTargets = op.getNumTargets(); - SmallVector qcoTargets; - qcoTargets.reserve(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - auto qcTarget = op.getTarget(i); - assert(qubitMap.contains(qcTarget) && "QC qubit not found"); - auto qcoTarget = qubitMap[qcTarget]; - qcoTargets.push_back(qcoTarget); - } + auto* operation = op.getOperation(); + const auto qcControls = llvm::to_vector(op.getControls()); + const auto qcTargets = collectTargets(op); + auto qcoControls = resolveMappedQubits(state, operation, qcControls); + auto qcoTargets = resolveMappedQubits(state, operation, qcTargets); // Create qco.ctrl auto qcoOp = qco::CtrlOp::create(rewriter, op.getLoc(), qcoControls, qcoTargets); - // Update the state map if this is a top-level CtrlOp - // Nested CtrlOps are managed via the targetsIn and targetsOut maps - if (state.inNestedRegion == 0) { - for (const auto& [qcControl, qcoControl] : - llvm::zip(qcControls, qcoOp.getControlsOut())) { - qubitMap[qcControl] = qcoControl; - } - auto qcoTargetsOut = qcoOp.getTargetsOut(); - for (size_t i = 0; i < numTargets; ++i) { - auto qcTarget = op.getTarget(i); - qubitMap[qcTarget] = qcoTargetsOut[i]; - } - } - - // Update modifier information - state.inNestedRegion++; + assignMappedQubits(state, operation, qcControls, qcoOp.getControlsOut()); + assignMappedQubits(state, operation, qcTargets, qcoOp.getTargetsOut()); // Clone body region from QC to QCO auto& dstRegion = qcoOp.getRegion(); @@ -948,16 +960,8 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { auto& entryBlock = dstRegion.front(); assert(entryBlock.getNumArguments() == 0 && "QC ctrl region unexpectedly has entry block arguments"); - SmallVector qcoTargetAliases; - qcoTargetAliases.reserve(numTargets); - const auto qubitType = qco::QubitType::get(qcoOp.getContext()); - const auto opLoc = op.getLoc(); - rewriter.modifyOpInPlace(qcoOp, [&] { - for (auto i = 0UL; i < numTargets; i++) { - qcoTargetAliases.emplace_back(entryBlock.addArgument(qubitType, opLoc)); - } - }); - state.targetsIn[state.inNestedRegion] = std::move(qcoTargetAliases); + auto qcoTargetAliases = addModifierAliases(qcoOp, qcoTargets, rewriter); + pushModifierFrame(state, qcTargets, qcoTargetAliases); rewriter.eraseOp(op); return success(); @@ -987,61 +991,26 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(qc::InvOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& [qubitMap, inNestedRegion, targetsIn, targetsOut] = getState(); - - // Get QCO targets from state map - const auto numTargets = op.getNumTargets(); - SmallVector qcoTargets; - if (inNestedRegion == 0) { - qcoTargets.reserve(numTargets); - for (size_t i = 0; i < numTargets; ++i) { - auto qcTarget = op.getTarget(i); - assert(qubitMap.contains(qcTarget) && "QC qubit not found"); - qcoTargets.emplace_back(qubitMap[qcTarget]); - } - } else { - assert(targetsIn[inNestedRegion].size() == numTargets && - "Invalid number of input targets"); - qcoTargets = targetsIn[inNestedRegion]; - } + auto& state = getState(); + auto* operation = op.getOperation(); + const auto qcTargets = collectTargets(op); + auto qcoTargets = resolveMappedQubits(state, operation, qcTargets); // Create qco.inv auto qcoOp = qco::InvOp::create(rewriter, op.getLoc(), qcoTargets); - // Update state map - if (inNestedRegion == 0) { - const auto qubitsOut = qcoOp.getQubitsOut(); - for (size_t i = 0; i < numTargets; ++i) { - auto qcTarget = op.getTarget(i); - qubitMap[qcTarget] = qubitsOut[i]; - } - } else { - targetsIn.erase(inNestedRegion); - targetsOut.try_emplace(inNestedRegion, qcoOp.getQubitsOut()); - } - - // Update modifier information - inNestedRegion++; + assignMappedQubits(state, operation, qcTargets, qcoOp.getQubitsOut()); // Clone body region from QC to QCO auto& dstRegion = qcoOp.getRegion(); rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - // Create block arguments for target qubits and store them in - // `state.targetsIn`. + // Create block arguments for target qubits and seed the nested frame. auto& entryBlock = dstRegion.front(); assert(entryBlock.getNumArguments() == 0 && "QC inv region unexpectedly has entry block arguments"); - SmallVector qcoTargetAliases; - qcoTargetAliases.reserve(numTargets); - const auto qubitType = qco::QubitType::get(qcoOp.getContext()); - const auto opLoc = op.getLoc(); - rewriter.modifyOpInPlace(qcoOp, [&] { - for (auto i = 0UL; i < numTargets; i++) { - qcoTargetAliases.emplace_back(entryBlock.addArgument(qubitType, opLoc)); - } - }); - targetsIn[inNestedRegion] = std::move(qcoTargetAliases); + auto qcoTargetAliases = addModifierAliases(qcoOp, qcoTargets, rewriter); + pushModifierFrame(state, qcTargets, qcoTargetAliases); rewriter.eraseOp(op); return success(); @@ -1067,10 +1036,11 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { matchAndRewrite(qc::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - const auto& targets = state.targetsOut[state.inNestedRegion]; + auto* operation = op.getOperation(); + auto& frame = currentModifierFrame(state); + auto targets = resolveMappedQubits(state, operation, frame.yieldOrder); rewriter.replaceOpWithNewOp(op, targets); - state.targetsOut.erase(state.inNestedRegion); - state.inNestedRegion--; + popModifierFrame(state); return success(); } }; @@ -1158,9 +1128,19 @@ struct QCToQCO final : impl::QCToQCOBase { }); // Conversion of qc types in func.return - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](const func::ReturnOp op) { return typeConverter.isLegal(op); }); + // + // Note: `func.return` may already be type-legal even though we still need + // to insert sink operations (`qco.dealloc`) for remaining live + // qubits. Therefore, we make it dynamically illegal unless the lowering + // state has no remaining qubits. + patterns.add(typeConverter, context, &state); + target.addDynamicallyLegalOp([&](const func::ReturnOp op) { + if (!typeConverter.isLegal(op)) { + return false; + } + const auto it = state.qubitMap.find(op->getParentRegion()); + return it == state.qubitMap.end() || it->second.empty(); + }); // Conversion of qc types in func.call populateCallOpTypeConversionPattern(patterns, typeConverter); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index da5cf312dd..991dba1c1b 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -79,16 +79,10 @@ Value QCProgramBuilder::allocQubit() { return qubit; } -Value QCProgramBuilder::staticQubit(const int64_t index) { +Value QCProgramBuilder::staticQubit(const uint64_t index) { checkFinalized(); - if (index < 0) { - llvm::reportFatalUsageError("Index must be non-negative"); - } - - // Create the StaticOp with the given index - auto indexAttr = getI64IntegerAttr(index); - auto staticOp = StaticOp::create(*this, indexAttr); + auto staticOp = StaticOp::create(*this, index); return staticOp.getQubit(); } diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index 5b93c2ebaa..0b573461f0 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -16,6 +16,8 @@ // IWYU pragma: begin_keep #include #include +#include +#include // IWYU pragma: end_keep using namespace mlir; @@ -46,6 +48,24 @@ void QCDialect::initialize() { // Types //===----------------------------------------------------------------------===// +/// Print `!qc.qubit` (dynamic, default) or `!qc.qubit`. +void QubitType::print(AsmPrinter& printer) const { + if (getIsStatic()) { + printer << ""; + } +} + +/// Parse `!qc.qubit` or `!qc.qubit`. +Type QubitType::parse(AsmParser& parser) { + if (succeeded(parser.parseOptionalLess())) { + if (parser.parseKeyword("static") || parser.parseGreater()) { + return {}; + } + return get(parser.getContext(), /*isStatic=*/true); + } + return get(parser.getContext(), /*isStatic=*/false); +} + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/QC/IR/QCOpsTypes.cpp.inc" diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 05b55ff6dc..327266b804 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Utils/ValueOrdering.h" #include "mlir/Dialect/QTensor/IR/QTensorDialect.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/Utils/Utils.h" @@ -83,15 +84,10 @@ Value QCOProgramBuilder::allocQubit() { return qubit; } -Value QCOProgramBuilder::staticQubit(const int64_t index) { +Value QCOProgramBuilder::staticQubit(const uint64_t index) { checkFinalized(); - if (index < 0) { - llvm::reportFatalUsageError("Index must be non-negative"); - } - - auto indexAttr = getI64IntegerAttr(index); - auto staticOp = StaticOp::create(*this, indexAttr); + auto staticOp = StaticOp::create(*this, index); const auto qubit = staticOp.getQubit(); // Track the static qubit as valid @@ -747,9 +743,8 @@ std::pair QCOProgramBuilder::ctrl( auto ctrlOp = CtrlOp::create(*this, controls, targets); auto& block = ctrlOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); for (const auto target : targets) { - const auto arg = block.addArgument(qubitType, getLoc()); + const auto arg = block.addArgument(target.getType(), getLoc()); updateQubitTracking(target, arg); } const InsertionGuard guard(*this); @@ -785,9 +780,8 @@ ValueRange QCOProgramBuilder::inv( // Add block arguments for all qubits auto& block = invOp.getBodyRegion().emplaceBlock(); - const auto qubitType = QubitType::get(getContext()); for (const auto qubit : qubits) { - const auto arg = block.addArgument(qubitType, getLoc()); + const auto arg = block.addArgument(qubit.getType(), getLoc()); updateQubitTracking(qubit, arg); } @@ -924,19 +918,10 @@ OwningOpRef QCOProgramBuilder::finalize() { "Insertion point is not in entry block of main function"); } - auto blockOrderComparator = [](Value a, Value b) { - auto* opA = a.getDefiningOp(); - auto* opB = b.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - } - return opA->isBeforeInBlock(opB); - }; - // Automatically deallocate all still-allocated qubits // Sort qubits for deterministic output llvm::SmallVector sortedQubits(validQubits.begin(), validQubits.end()); - llvm::sort(sortedQubits, blockOrderComparator); + llvm::sort(sortedQubits, SSAOrder{}); for (auto qubit : sortedQubits) { DeallocOp::create(*this, qubit); @@ -946,7 +931,7 @@ OwningOpRef QCOProgramBuilder::finalize() { // Sort tensors for deterministic output llvm::SmallVector sortedTensors(validTensors.begin(), validTensors.end()); - llvm::sort(sortedTensors, blockOrderComparator); + llvm::sort(sortedTensors, SSAOrder{}); for (auto tensor : sortedTensors) { qtensor::DeallocOp::create(*this, tensor); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 96c811eed8..91db9e582a 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -8,7 +8,6 @@ * Licensed under the MIT License */ -#include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include @@ -115,6 +114,10 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } + // Capture the promoted control's type before adjusting segments; after + // setAttr, getControlsIn().back() would point to a different control. + const auto promotedControlType = op.getControlsIn().back().getType(); + // Adjust the segment sizes of the control and target operands const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); auto segmentsAttr = @@ -125,9 +128,9 @@ struct ReduceCtrl final : OpRewritePattern { const auto opResultSegmentsAttrName = CtrlOp::getResultSegmentSizeAttr(); op->setAttr(opResultSegmentsAttrName, newSegments); - // Add a block argument for the target qubit - auto arg = op.getBody()->addArgument(QubitType::get(rewriter.getContext()), - op.getLoc()); + // Add a block argument for the promoted target qubit, preserving the + // control's type (including isStatic) + auto arg = op.getBody()->addArgument(promotedControlType, op.getLoc()); // Replace the current GPhaseOp with a PhaseOp const OpBuilder::InsertionGuard guard(rewriter); @@ -242,9 +245,8 @@ void CtrlOp::build( build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); - for (size_t i = 0; i < targets.size(); ++i) { - block.addArgument(qubitType, odsState.location); + for (auto target : targets) { + block.addArgument(target.getType(), odsState.location); } const OpBuilder::InsertionGuard guard(odsBuilder); @@ -254,6 +256,17 @@ void CtrlOp::build( } LogicalResult CtrlOp::verify() { + // Allows !qco.qubit and !qco.qubit to differ between controls and + // targets, but requires pairwise equality within each group. + if (!llvm::equal(getControlsIn().getTypes(), getControlsOut().getTypes())) { + return emitOpError("qco.ctrl control qubit input types must match output " + "types pairwise"); + } + if (!llvm::equal(getTargetsIn().getTypes(), getTargetsOut().getTypes())) { + return emitOpError("qco.ctrl target qubit input types must match output " + "types pairwise"); + } + auto& block = *getBody(); if (block.getOperations().size() < 2) { return emitOpError("body region must have at least two operations"); @@ -263,9 +276,8 @@ LogicalResult CtrlOp::verify() { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { - if (block.getArgument(i).getType() != qubitType) { + if (block.getArgument(i).getType() != getTargetsIn()[i].getType()) { return emitOpError("block argument type at index ") << i << " does not match target type"; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 18cc331c5b..dd97a43687 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -8,7 +8,6 @@ * Licensed under the MIT License */ -#include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" #include @@ -326,9 +325,8 @@ void InvOp::build( build(odsBuilder, odsState, qubits); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); - for (size_t i = 0; i < qubits.size(); ++i) { - block.addArgument(qubitType, odsState.location); + for (auto qubit : qubits) { + block.addArgument(qubit.getType(), odsState.location); } const OpBuilder::InsertionGuard guard(odsBuilder); @@ -347,9 +345,8 @@ LogicalResult InvOp::verify() { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { - if (block.getArgument(i).getType() != qubitType) { + if (block.getArgument(i).getType() != getQubitsIn()[i].getType()) { return emitOpError("block argument type at index ") << i << " does not match target type"; } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp index d168624987..c7488fd6b0 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp @@ -77,6 +77,14 @@ struct MergeSubsequentBarrier final : OpRewritePattern { } // namespace +LogicalResult BarrierOp::verify() { + if (!llvm::equal(getQubitsIn().getTypes(), getQubitsOut().getTypes())) { + return emitOpError("qco.barrier qubit input types must match output types " + "pairwise"); + } + return success(); +} + Value BarrierOp::getInputTarget(const size_t i) { if (i < getNumTargets()) { return getQubitsIn()[i]; diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index 36aa66ad54..02105cdc8b 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -14,10 +14,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -68,10 +70,19 @@ parseTargetAliasing(OpAsmParser& parser, Region& region, } operands.push_back(oldOperand); - // Hard-code QubitType since targets in qco.ctrl are always qubits. - // This avoids double-binding type($targets_in) in the assembly format - // while keeping the parser simple and the assembly format clean. - newArg.type = QubitType::get(parser.getBuilder().getContext()); + // Parse optional inline type to preserve isStatic; when absent, default + // to dynamic to avoid double-binding type($targets_in) in the assembly + // format. + Type operandType; + if (succeeded(parser.parseOptionalColon())) { + if (parser.parseType(operandType)) { + return failure(); + } + } else { + operandType = QubitType::get(parser.getBuilder().getContext(), + /*isStatic=*/false); + } + newArg.type = operandType; blockArgs.push_back(newArg); } while (succeeded(parser.parseOptionalComma())); @@ -109,6 +120,12 @@ static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, printer.printOperand(entryBlock.getArgument(i)); printer << " = "; printer.printOperand(targetsIn[i]); + // Print inline type when static to preserve isStatic on round-trip + if (auto qubitType = llvm::dyn_cast(targetsIn[i].getType()); + qubitType && qubitType.getIsStatic()) { + printer << " : "; + printer.printType(qubitType); + } } printer << ") "; @@ -202,6 +219,24 @@ void QCODialect::initialize() { // Types //===----------------------------------------------------------------------===// +/// Print `!qco.qubit` (dynamic, default) or `!qco.qubit`. +void QubitType::print(AsmPrinter& printer) const { + if (getIsStatic()) { + printer << ""; + } +} + +/// Parse `!qco.qubit` or `!qco.qubit`. +Type QubitType::parse(AsmParser& parser) { + if (succeeded(parser.parseOptionalLess())) { + if (parser.parseKeyword("static") || parser.parseGreater()) { + return {}; + } + return get(parser.getContext(), /*isStatic=*/true); + } + return get(parser.getContext(), /*isStatic=*/false); +} + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/QCO/IR/QCOOpsTypes.cpp.inc" diff --git a/mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp b/mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp index ebd8fa2cf8..10ad23cd55 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Mapping/Mapping.cpp @@ -18,17 +18,21 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include +#include #include #include #include +#include #include #include @@ -55,6 +59,288 @@ namespace mlir::qco { #define GEN_PASS_DEF_MAPPINGPASS #include "mlir/Dialect/QCO/Transforms/Passes.h.inc" +//===----------------------------------------------------------------------===// +// After alloc→static placement, operand qubit types match the mapped static +// type but many op results (and ctrl/inv region args) may still be typed as +// plain !qco.qubit. IR is recreated with IRRewriter. +//===----------------------------------------------------------------------===// + +/** True if any declared qubit result type disagrees with the paired operand. */ +[[nodiscard]] static bool +qubitOperandAndResultTypesDiffer(UnitaryOpInterface unitary) { + return llvm::any_of( + llvm::zip_equal(unitary.getInputQubits(), unitary.getOutputQubits()), + [](const auto& io) { + const auto& [input, output] = io; + const auto qIn = dyn_cast(input.getType()); + const auto qOut = dyn_cast(output.getType()); + return qIn && qOut && qIn != qOut; + }); +} + +/** Region entry args vs target operands, for ctrl/inv verifier alignment. */ +[[nodiscard]] static bool +modifierBodyArgsMismatchTargetOperands(Block& body, + OperandRange targetOperands) { + return llvm::any_of(llvm::zip_equal(body.getArguments(), targetOperands), + [](const auto& pair) { + const auto& [arg, target] = pair; + return arg.getType() != target.getType(); + }); +} + +/** + * @brief True if @p ctrl must be rebuilt so qubit types match after placement. + * + * @param ctrl Operation examined: sole region (targets + body args), operands + * (`getControlsIn()`, `getTargetsIn()`), and qubit results + * (`getOutputQubits()`). + * @return True when either (1) a body block argument type differs from the + * matching `getTargetsIn()` operand, or (2) any qubit result type + * differs from the paired input qubit type (via + * @ref modifierBodyArgsMismatchTargetOperands and + * @ref qubitOperandAndResultTypesDiffer). + * + * @details Pure predicate: no IR changes. Callers use this to skip redundant + * `replaceOpWithNewOp` work. + */ +[[nodiscard]] static bool ctrlNeedsQubitTypeResync(CtrlOp ctrl) { + return modifierBodyArgsMismatchTargetOperands(*ctrl.getBody(), + ctrl.getTargetsIn()) || + qubitOperandAndResultTypesDiffer(ctrl); +} + +/** + * @brief True if @p inv must be rebuilt so qubit types match after placement. + * + * @param inv Operation examined: sole region (`getBody()` args vs + * `getQubitsIn()` operands) and paired qubit results. + * @return True when body argument types disagree with `getQubitsIn()`, or qubit + * result types disagree with operands (same predicates as for @ref + * ctrlNeedsQubitTypeResync). + * + * @details Pure predicate: no IR changes. + */ +[[nodiscard]] static bool invNeedsQubitTypeResync(InvOp inv) { + return modifierBodyArgsMismatchTargetOperands(*inv.getBody(), + inv.getQubitsIn()) || + qubitOperandAndResultTypesDiffer(inv); +} + +/** @brief Copies op results into `SmallVector` (OpResult → Value). */ +[[nodiscard]] static llvm::SmallVector opResultsAsValues(Operation* op) { + llvm::SmallVector vals; + llvm::append_range(vals, op->getResults()); + return vals; +} + +/** + * @brief Duplicate @p op with new result types; clone attached regions. + * @param rewriter Insertion point is set to @p op before create. + */ +static Operation* cloneOpWithNewResultTypes(Operation* op, + ArrayRef newResultTypes, + IRRewriter& rewriter) { + assert(op->getNumResults() == newResultTypes.size() && + "result type count must match the operation"); + rewriter.setInsertionPoint(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(op->getOperands()); + state.addTypes(newResultTypes); + state.addAttributes(op->getAttrs()); + state.addSuccessors(op->getSuccessors()); + for (Region& region : op->getRegions()) { + Region* newRegion = state.addRegion(); + IRMapping mapper; + rewriter.cloneRegionBefore(region, *newRegion, newRegion->end(), mapper); + } + return rewriter.create(state); +} + +[[nodiscard]] static llvm::SmallVector +resyncClonedUnitaryAndGetResults(Operation* cloned, IRRewriter& rewriter); + +/** + * @brief Rebuild a leaf unitary so each qubit result type matches its operand. + * + * Skips ops with nested regions (handled via @ref + * resyncClonedUnitaryAndGetResults). + */ +[[nodiscard]] static llvm::SmallVector +replaceLeafUnitaryWithAlignedQubitTypes(UnitaryOpInterface unitary, + IRRewriter& rewriter) { + Operation* op = unitary.getOperation(); + if (op->getNumRegions() != 0 || !qubitOperandAndResultTypesDiffer(unitary)) { + return opResultsAsValues(op); + } + SmallVector newTypes(op->getResultTypes()); + for (auto [input, output] : + llvm::zip_equal(unitary.getInputQubits(), unitary.getOutputQubits())) { + const auto qIn = dyn_cast(input.getType()); + const auto qOut = dyn_cast(output.getType()); + if (qIn && qOut && qIn != qOut) { + newTypes[llvm::cast(output).getResultNumber()] = qIn; + } + } + Operation* newOp = cloneOpWithNewResultTypes(op, newTypes, rewriter); + rewriter.replaceOp(op, newOp->getResults()); + return opResultsAsValues(newOp); +} + +/** + * @brief Clone all ops in the modifier body before `qco.yield`, then align the + * body unitary’s qubit types. + * + * The body may contain `arith.constant` (or similar) before the single + * `UnitaryOpInterface`; cloning only that unitary would leave uses pointing at + * the old region, which is destroyed when the modifier is replaced. + */ +[[nodiscard]] static llvm::SmallVector +cloneModifierBodyAndResyncUnitary(Block& oldBody, ValueRange newTargetArgs, + IRRewriter& rewriter) { + IRMapping mapping; + for (auto [oldArg, newArg] : + llvm::zip_equal(oldBody.getArguments(), newTargetArgs)) { + mapping.map(oldArg, newArg); + } + Operation* clonedUnitary = nullptr; + for (Operation& op : oldBody) { + if (llvm::isa(op)) { + break; + } + Operation* cloned = rewriter.clone(op, mapping); + for (auto [oldRes, newRes] : + llvm::zip_equal(op.getResults(), cloned->getResults())) { + mapping.map(oldRes, newRes); + } + if (llvm::isa(cloned)) { + clonedUnitary = cloned; + } + } + assert(clonedUnitary != nullptr && + "modifier body must contain a unitary before yield"); + return resyncClonedUnitaryAndGetResults(clonedUnitary, rewriter); +} + +/** + * @brief Replace @p ctrl with a new `qco.ctrl` that keeps the same controls, + * targets, and logical body (ops before `qco.yield` are cloned). + * + * @param ctrl Modifier to replace. Must satisfy the usual `qco.ctrl` verifier + * (single unitary before yield, etc.). + * @param rewriter Drives insertion and erasure; insertion point is updated by + * `replaceOpWithNewOp`. + * @return The new `CtrlOp` whose region holds cloned ops and refreshed types. + * + * @note Side effects: erases @p ctrl and rewires uses to the new op's results; + * runs @ref cloneModifierBodyAndResyncUnitary in the builder callback + * (may recurse for nested ctrl/inv in the body). + */ +[[nodiscard]] static CtrlOp replaceCtrlPreservingBody(CtrlOp ctrl, + IRRewriter& rewriter) { + return rewriter.replaceOpWithNewOp( + ctrl, ctrl.getControlsIn(), ctrl.getTargetsIn(), + [&](ValueRange newArgs) -> llvm::SmallVector { + return cloneModifierBodyAndResyncUnitary(*ctrl.getBody(), newArgs, + rewriter); + }); +} + +/** + * @brief Replace @p inv with a new `qco.inv` that preserves operands and body + * structure (clone ops before `qco.yield`, then align qubit types). + * + * @param inv Modifier to replace; must verify as `qco.inv`. + * @param rewriter Drives insertion and erasure. + * @return The new `InvOp` after `replaceOp` has rewired uses away from @p inv. + * + * @note Side effects: erases @p inv; uses @ref + * cloneModifierBodyAndResyncUnitary in the body callback (may recurse for + * nested modifiers). + */ +[[nodiscard]] static InvOp replaceInvPreservingBody(InvOp inv, + IRRewriter& rewriter) { + rewriter.setInsertionPoint(inv); + InvOp newInv = + InvOp::create(rewriter, inv.getLoc(), inv.getQubitsIn(), + [&](ValueRange newArgs) -> llvm::SmallVector { + return cloneModifierBodyAndResyncUnitary( + *inv.getBody(), newArgs, rewriter); + }); + rewriter.replaceOp(inv, newInv.getResults()); + return newInv; +} + +/** + * @brief If @p cloned is ctrl/inv, rebuild it when qubit types are stale; else + * align a leaf unitary. Returns values that replace @p cloned's results. + */ +[[nodiscard]] static llvm::SmallVector +resyncClonedUnitaryAndGetResults(Operation* cloned, IRRewriter& rewriter) { + if (auto ctrl = dyn_cast(cloned)) { + if (!ctrlNeedsQubitTypeResync(ctrl)) { + return opResultsAsValues(ctrl.getOperation()); + } + return opResultsAsValues( + replaceCtrlPreservingBody(ctrl, rewriter).getOperation()); + } + if (auto inv = dyn_cast(cloned)) { + if (!invNeedsQubitTypeResync(inv)) { + return opResultsAsValues(inv.getOperation()); + } + return opResultsAsValues( + replaceInvPreservingBody(inv, rewriter).getOperation()); + } + return replaceLeafUnitaryWithAlignedQubitTypes( + llvm::cast(cloned), rewriter); +} + +/** + * @brief Walk the function body and fix qubit SSA types after mapping. + * + * @details Post-order allows erasing/replacing the visited op safely. Ctrl and + * Inv are handled before the generic `UnitaryOpInterface` case so nested + * modifiers are not double-processed as leaf unitaries. + */ +static void synchronizeMappedQubitTypes(Region& region, IRRewriter& rewriter) { + region.walk([&](Operation* op) { + rewriter.setInsertionPoint(op); + if (auto ctrl = dyn_cast(op)) { + if (ctrlNeedsQubitTypeResync(ctrl)) { + (void)replaceCtrlPreservingBody(ctrl, rewriter); + } + return WalkResult::advance(); + } + if (auto inv = dyn_cast(op)) { + if (invNeedsQubitTypeResync(inv)) { + (void)replaceInvPreservingBody(inv, rewriter); + } + return WalkResult::advance(); + } + if (auto measure = dyn_cast(op)) { + if (measure.getQubitOut().getType() != measure.getQubitIn().getType()) { + SmallVector newTypes(measure.getResultTypes()); + newTypes[0] = measure.getQubitIn().getType(); + Operation* newOp = + cloneOpWithNewResultTypes(measure, newTypes, rewriter); + rewriter.replaceOp(measure, newOp->getResults()); + } + return WalkResult::advance(); + } + if (auto reset = dyn_cast(op)) { + if (reset.getQubitOut().getType() != reset.getQubitIn().getType()) { + rewriter.replaceOpWithNewOp(reset, reset.getQubitIn()); + } + return WalkResult::advance(); + } + if (auto unitary = dyn_cast(op)) { + (void)replaceLeafUnitaryWithAlignedQubitTypes(unitary, rewriter); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); +} + namespace { struct MappingPass : impl::MappingPassBase { @@ -62,7 +348,6 @@ struct MappingPass : impl::MappingPassBase { using QubitValue = TypedValue; using IndexType = std::size_t; using IndexGate = std::pair; - using IndexGateSet = DenseSet; using Layer = DenseSet; /** @@ -378,6 +663,7 @@ struct MappingPass : impl::MappingPassBase { } commitTrial(*best, dyn, func.getFunctionBody(), rewriter); + synchronizeMappedQubitTypes(func.getFunctionBody(), rewriter); } } @@ -454,8 +740,10 @@ struct MappingPass : impl::MappingPassBase { */ [[nodiscard]] static SmallVector collectDynamicQubits(Region& funcBody) { - return SmallVector(map_range( - funcBody.getOps(), [](AllocOp op) { return op.getResult(); })); + return SmallVector( + map_range(funcBody.getOps(), [](AllocOp op) { + return llvm::cast>(op.getResult()); + })); } /** @@ -487,17 +775,17 @@ struct MappingPass : impl::MappingPassBase { const auto hw = layout.getHardwareIndex(p); rewriter.setInsertionPoint(q.getDefiningOp()); auto op = rewriter.replaceOpWithNewOp(q.getDefiningOp(), hw); - statics[hw] = op.getQubit(); + statics[hw] = llvm::cast>(op.getQubit()); } // 2. Create static qubits for the remaining (unused) hardware indices. for (std::size_t p = dynQubits.size(); p < layout.nqubits(); ++p) { rewriter.setInsertionPointToStart(&funcBody.front()); const auto hw = layout.getHardwareIndex(p); - auto op = StaticOp::create(rewriter, rewriter.getUnknownLoc(), hw); + auto op = StaticOp::create(rewriter, funcBody.getLoc(), hw); rewriter.setInsertionPoint(funcBody.back().getTerminator()); - DeallocOp::create(rewriter, rewriter.getUnknownLoc(), op.getQubit()); - statics[hw] = op.getQubit(); + DeallocOp::create(rewriter, funcBody.getLoc(), op.getQubit()); + statics[hw] = llvm::cast>(op.getQubit()); } return statics; diff --git a/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt index 98aa36b0c5..3be9dcf3f1 100644 --- a/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_dialect_library( PRIVATE MLIRIR MLIRArithDialect + MLIRDialectUtils MLIRInferTypeOpInterface MLIRSideEffectInterfaces PUBLIC diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp index 898b8b6412..335b0e3620 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/AllocOp.cpp @@ -31,9 +31,9 @@ void AllocOp::build(OpBuilder& builder, OperationState& result, Value size) { assert(*sizeValue > 0 && "qtensor.alloc size must be positive"); } - auto resultType = - RankedTensorType::get({sizeValue ? *sizeValue : ShapedType::kDynamic}, - qco::QubitType::get(builder.getContext())); + auto resultType = RankedTensorType::get( + {sizeValue ? *sizeValue : ShapedType::kDynamic}, + qco::QubitType::get(builder.getContext(), /*isStatic=*/false)); build(builder, result, resultType, size); } @@ -57,5 +57,12 @@ LogicalResult AllocOp::verify() { << resultSize << ")"; } + auto elementType = resultType.getElementType(); + if (auto qubitType = dyn_cast(elementType); + qubitType && qubitType.getIsStatic()) { + return emitOpError("qtensor.alloc cannot allocate static qubits; element " + "type must be a dynamic qubit type (!qco.qubit)"); + } + return success(); } diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index a81f6a7c01..5f648fe113 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -193,6 +193,36 @@ INSTANTIATE_TEST_SUITE_P( "StaticQubits", nullptr, MQT_NAMED_BUILDER(mlir::qc::staticQubits), MQT_NAMED_BUILDER(mlir::qc::staticQubits), MQT_NAMED_BUILDER(mlir::qir::staticQubits), false}, + CompilerPipelineTestCase{ + "StaticQubitsWithOps", nullptr, + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithOps), + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithOps), + MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithOps), false}, + CompilerPipelineTestCase{ + "StaticQubitsWithParametricOps", nullptr, + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithParametricOps), false}, + CompilerPipelineTestCase{ + "StaticQubitsWithTwoTargetOps", nullptr, + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithTwoTargetOps), false}, + CompilerPipelineTestCase{ + "StaticQubitsWithCtrl", nullptr, + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithCtrl), + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithCtrl), + MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithCtrl), false}, + CompilerPipelineTestCase{ + "StaticQubitsWithInv", nullptr, + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithInv), + MQT_NAMED_BUILDER(mlir::qc::staticQubitsWithInv), + MQT_NAMED_BUILDER(mlir::qir::staticQubitsWithInv), false}, + CompilerPipelineTestCase{ + "MixedStaticDynamicQubits", nullptr, + MQT_NAMED_BUILDER(mlir::qc::mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(mlir::qc::mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(mlir::qir::mixedStaticDynamicQubits), false}, CompilerPipelineTestCase{"AllocQubit", MQT_NAMED_BUILDER(qc::allocQubit), nullptr, MQT_NAMED_BUILDER(mlir::qc::allocQubit), @@ -415,8 +445,9 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(mlir::qc::p), MQT_NAMED_BUILDER(mlir::qir::p)}, CompilerPipelineTestCase{ - "SingleControlledP", MQT_NAMED_BUILDER(qc::singleControlledP), - nullptr, MQT_NAMED_BUILDER(mlir::qc::singleControlledP), + "SingleControlledP", + MQT_NAMED_BUILDER(qc::singleControlledP), nullptr, + MQT_NAMED_BUILDER(mlir::qc::singleControlledP), MQT_NAMED_BUILDER(mlir::qir::singleControlledP)}, CompilerPipelineTestCase{ "MultipleControlledP", MQT_NAMED_BUILDER(qc::multipleControlledP), diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index bb81b94ecc..3112362ac9 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -113,6 +113,33 @@ TEST_P(QCOToQCTest, ProgramEquivalence) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +/// \name QCOToQC/QubitManagement/QubitManagement.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCOQubitManagementTest, QCOToQCTest, + testing::Values( + QCOToQCTestCase{"StaticQubits", MQT_NAMED_BUILDER(qco::staticQubits), + MQT_NAMED_BUILDER(qc::staticQubits)}, + QCOToQCTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(qco::staticQubitsWithOps), + MQT_NAMED_BUILDER(qc::staticQubitsWithOps)}, + QCOToQCTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(qco::staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(qc::staticQubitsWithParametricOps)}, + QCOToQCTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(qco::staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(qc::staticQubitsWithTwoTargetOps)}, + QCOToQCTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(qco::staticQubitsWithCtrl), + MQT_NAMED_BUILDER(qc::staticQubitsWithCtrl)}, + QCOToQCTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(qco::staticQubitsWithInv), + MQT_NAMED_BUILDER(qc::staticQubitsWithInv)}, + QCOToQCTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(qco::mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(qc::mixedStaticDynamicQubits)})); +/// @} + /// \name QCOToQC/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 98d6c9e012..ba43c30b98 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -112,6 +112,33 @@ TEST_P(QCToQCOTest, ProgramEquivalence) { areModulesEquivalentWithPermutations(program.get(), reference.get())); } +/// \name QCToQCO/QubitManagement/StaticOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCStaticOpTest, QCToQCOTest, + testing::Values( + QCToQCOTestCase{"StaticQubits", MQT_NAMED_BUILDER(qc::staticQubits), + MQT_NAMED_BUILDER(qco::staticQubits)}, + QCToQCOTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithOps), + MQT_NAMED_BUILDER(qco::staticQubitsWithOps)}, + QCToQCOTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(qco::staticQubitsWithParametricOps)}, + QCToQCOTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(qco::staticQubitsWithTwoTargetOps)}, + QCToQCOTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(qc::staticQubitsWithCtrl), + MQT_NAMED_BUILDER(qco::staticQubitsWithCtrl)}, + QCToQCOTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(qc::staticQubitsWithInv), + MQT_NAMED_BUILDER(qco::staticQubitsWithInv)}, + QCToQCOTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(qc::mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(qco::mixedStaticDynamicQubits)})); +/// @} + /// \name QCToQCO/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index 82ab251e41..58e832305b 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -624,6 +624,24 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qir::emptyQIR)}, QCToQIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(qc::staticQubits), MQT_NAMED_BUILDER(qir::emptyQIR)}, + QCToQIRTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithOps), + MQT_NAMED_BUILDER(qir::staticQubitsWithOps)}, + QCToQIRTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(qir::staticQubitsWithParametricOps)}, + QCToQIRTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(qc::staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(qir::staticQubitsWithTwoTargetOps)}, + QCToQIRTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(qc::staticQubitsWithCtrl), + MQT_NAMED_BUILDER(qir::staticQubitsWithCtrl)}, + QCToQIRTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(qc::staticQubitsWithInv), + MQT_NAMED_BUILDER(qir::staticQubitsWithInv)}, + QCToQIRTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(qc::mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(qir::mixedStaticDynamicQubits)}, QCToQIRTestCase{"AllocDeallocPair", MQT_NAMED_BUILDER(qc::allocDeallocPair), MQT_NAMED_BUILDER(qir::emptyQIR)})); diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 221b750f7e..33b8f0ac75 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -899,6 +899,24 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(emptyQC)}, QCTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), MQT_NAMED_BUILDER(emptyQC)}, + QCTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(staticQubitsWithOps), + MQT_NAMED_BUILDER(staticQubitsWithOps)}, + QCTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(staticQubitsWithParametricOps)}, + QCTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps)}, + QCTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(staticQubitsWithCtrl), + MQT_NAMED_BUILDER(staticQubitsWithCtrl)}, + QCTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(staticQubitsWithInv), + MQT_NAMED_BUILDER(staticQubitsWithInv)}, + QCTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(mixedStaticDynamicQubits)}, QCTestCase{"AllocDeallocPair", MQT_NAMED_BUILDER(allocDeallocPair), MQT_NAMED_BUILDER(emptyQC)})); /// @} diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index ad4981d1b7..aad6ebe783 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -1060,6 +1060,24 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(emptyQCO)}, QCOTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(staticQubitsWithOps), + MQT_NAMED_BUILDER(staticQubitsWithOps)}, + QCOTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(staticQubitsWithParametricOps)}, + QCOTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps)}, + QCOTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(staticQubitsWithCtrl), + MQT_NAMED_BUILDER(staticQubitsWithCtrl)}, + QCOTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(staticQubitsWithInv), + MQT_NAMED_BUILDER(staticQubitsWithInv)}, + QCOTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(mixedStaticDynamicQubits)}, QCOTestCase{"AllocDeallocPair", MQT_NAMED_BUILDER(allocDeallocPair), MQT_NAMED_BUILDER(emptyQCO)})); /// @} diff --git a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt index 3ea628104a..1ced00dc34 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt @@ -14,6 +14,7 @@ target_link_libraries( PRIVATE MLIRParser GTest::gtest_main MLIRQCProgramBuilder + MLIRQCOProgramBuilder MLIRQCOUtils MLIRQCToQCO MLIRQCOToQC diff --git a/mlir/unittests/Dialect/QCO/Transforms/Mapping/test_mapping.cpp b/mlir/unittests/Dialect/QCO/Transforms/Mapping/test_mapping.cpp index 8b65bf57f6..5e465716ab 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Mapping/test_mapping.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Mapping/test_mapping.cpp @@ -14,7 +14,9 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCInterfaces.h" #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Transforms/Mapping/Architecture.h" #include "mlir/Dialect/QCO/Transforms/Passes.h" @@ -27,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -45,8 +48,7 @@ struct ArchitectureParam { Architecture (*factory)(); }; -class MappingPassTest : public testing::Test, - public testing::WithParamInterface { +class MappingPassTestBase : public testing::Test { public: /** * @brief Walks the IR and validates if each two-qubit op is executable on the @@ -115,10 +117,151 @@ class MappingPassTest : public testing::Test, ASSERT_TRUE(succeeded(res)); } + static void runQCOMapping(OwningOpRef& moduleOp) { + PassManager pm(moduleOp->getContext()); + pm.addPass(qco::createMappingPass(qco::MappingPassOptions{.nlookahead = 5, + .alpha = 1, + .lambda = 0.85, + .niterations = 2, + .ntrials = 8, + .seed = 1337})); + auto res = pm.run(*moduleOp); + ASSERT_TRUE(succeeded(res)); + ASSERT_TRUE(succeeded(verify(*moduleOp))); + } + + static void expectSameStaticQubitType(Value input, Value output) { + const auto qIn = dyn_cast(input.getType()); + const auto qOut = dyn_cast(output.getType()); + ASSERT_TRUE(qIn); + ASSERT_TRUE(qOut); + EXPECT_EQ(qIn, qOut); + EXPECT_TRUE(qOut.getIsStatic()); + } + std::unique_ptr context; }; + +class MappingPassTest : public MappingPassTestBase, + public testing::WithParamInterface { +}; }; // namespace +TEST_F(MappingPassTestBase, SynchronizesParameterizedGateAndMeasureTypes) { + qco::QCOProgramBuilder builder(context.get()); + builder.initialize(); + + auto q0 = builder.allocQubit(); + auto q1 = builder.allocQubit(); + + q0 = builder.rx(0.25, q0); + std::tie(q0, q1) = builder.rxx(0.5, q0, q1); + std::tie(q0, q1) = builder.xx_plus_yy(0.75, 1.25, q0, q1); + const auto [measuredQubit, measuredBit] = builder.measure(q0); + (void)measuredBit; + q0 = measuredQubit; + + builder.dealloc(q0); + builder.dealloc(q1); + + auto moduleOp = builder.finalize(); + runQCOMapping(moduleOp); + + auto entry = *moduleOp->getOps().begin(); + qco::RXOp rxOp; + qco::RXXOp rxxOp; + qco::XXPlusYYOp xxPlusYYOp; + qco::MeasureOp measureOp; + entry.walk([&](Operation* op) { + if (auto rx = dyn_cast(op)) { + rxOp = rx; + } else if (auto rxx = dyn_cast(op)) { + rxxOp = rxx; + } else if (auto xxPlusYY = dyn_cast(op)) { + xxPlusYYOp = xxPlusYY; + } else if (auto measure = dyn_cast(op)) { + measureOp = measure; + } + }); + + ASSERT_TRUE(rxOp); + ASSERT_TRUE(rxxOp); + ASSERT_TRUE(xxPlusYYOp); + ASSERT_TRUE(measureOp); + + expectSameStaticQubitType(rxOp.getQubitIn(), rxOp.getQubitOut()); + expectSameStaticQubitType(rxxOp.getQubit0In(), rxxOp.getQubit0Out()); + expectSameStaticQubitType(rxxOp.getQubit1In(), rxxOp.getQubit1Out()); + expectSameStaticQubitType(xxPlusYYOp.getQubit0In(), + xxPlusYYOp.getQubit0Out()); + expectSameStaticQubitType(xxPlusYYOp.getQubit1In(), + xxPlusYYOp.getQubit1Out()); + expectSameStaticQubitType(measureOp.getQubitIn(), measureOp.getQubitOut()); +} + +TEST_F(MappingPassTestBase, SynchronizesModifierRegionArguments) { + qco::QCOProgramBuilder builder(context.get()); + builder.initialize(); + + auto q0 = builder.allocQubit(); + auto q1 = builder.allocQubit(); + + auto [controlsOut, targetsOut] = builder.ctrl( + {q0}, {q1}, [&](ValueRange targets) -> llvm::SmallVector { + return {builder.rx(0.5, targets[0])}; + }); + q0 = controlsOut.front(); + q1 = targetsOut.front(); + + auto invOut = + builder.inv({q1}, [&](ValueRange qubits) -> llvm::SmallVector { + return {builder.rx(0.25, qubits[0])}; + }); + q1 = invOut.front(); + + builder.dealloc(q0); + builder.dealloc(q1); + + auto moduleOp = builder.finalize(); + runQCOMapping(moduleOp); + + auto entry = *moduleOp->getOps().begin(); + qco::CtrlOp ctrlOp; + qco::InvOp invOp; + qco::RXOp ctrlBodyRxOp; + qco::RXOp invBodyRxOp; + entry.walk([&](Operation* op) { + if (auto ctrl = dyn_cast(op)) { + ctrlOp = ctrl; + } else if (auto inv = dyn_cast(op)) { + invOp = inv; + } else if (auto rx = dyn_cast(op)) { + if (rx->getParentOfType()) { + ctrlBodyRxOp = rx; + } else if (rx->getParentOfType()) { + invBodyRxOp = rx; + } + } + }); + + ASSERT_TRUE(ctrlOp); + ASSERT_TRUE(invOp); + ASSERT_TRUE(ctrlBodyRxOp); + ASSERT_TRUE(invBodyRxOp); + + ASSERT_EQ(ctrlOp.getBody()->getNumArguments(), ctrlOp.getNumTargets()); + expectSameStaticQubitType(ctrlOp.getTargetsIn()[0], + ctrlOp.getBody()->getArgument(0)); + expectSameStaticQubitType(ctrlBodyRxOp.getQubitIn(), + ctrlBodyRxOp.getQubitOut()); + + ASSERT_EQ(invOp.getBody()->getNumArguments(), invOp.getNumTargets()); + expectSameStaticQubitType(invOp.getQubitsIn()[0], + invOp.getBody()->getArgument(0)); + expectSameStaticQubitType(invBodyRxOp.getQubitIn(), + invBodyRxOp.getQubitOut()); +} + TEST_P(MappingPassTest, GHZ) { auto arch = GetParam().factory(); diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index e177c90329..157fb5cabc 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -522,17 +522,34 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QIRQubitManagementTest, QIRTest, - testing::Values(QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), - MQT_NAMED_BUILDER(allocQubit)}, - QIRTestCase{"AllocQubitRegister", - MQT_NAMED_BUILDER(allocQubitRegister), - MQT_NAMED_BUILDER(allocQubitRegister)}, - QIRTestCase{"AllocMultipleQubitRegisters", - MQT_NAMED_BUILDER(allocMultipleQubitRegisters), - MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, - QIRTestCase{"AllocLargeRegister", - MQT_NAMED_BUILDER(allocLargeRegister), - MQT_NAMED_BUILDER(allocLargeRegister)}, - QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), - MQT_NAMED_BUILDER(staticQubits)})); + testing::Values( + QIRTestCase{"AllocQubit", MQT_NAMED_BUILDER(allocQubit), + MQT_NAMED_BUILDER(allocQubit)}, + QIRTestCase{"AllocQubitRegister", MQT_NAMED_BUILDER(allocQubitRegister), + MQT_NAMED_BUILDER(allocQubitRegister)}, + QIRTestCase{"AllocMultipleQubitRegisters", + MQT_NAMED_BUILDER(allocMultipleQubitRegisters), + MQT_NAMED_BUILDER(allocMultipleQubitRegisters)}, + QIRTestCase{"AllocLargeRegister", MQT_NAMED_BUILDER(allocLargeRegister), + MQT_NAMED_BUILDER(allocLargeRegister)}, + QIRTestCase{"StaticQubits", MQT_NAMED_BUILDER(staticQubits), + MQT_NAMED_BUILDER(staticQubits)}, + QIRTestCase{"StaticQubitsWithOps", + MQT_NAMED_BUILDER(staticQubitsWithOps), + MQT_NAMED_BUILDER(staticQubitsWithOps)}, + QIRTestCase{"StaticQubitsWithParametricOps", + MQT_NAMED_BUILDER(staticQubitsWithParametricOps), + MQT_NAMED_BUILDER(staticQubitsWithParametricOps)}, + QIRTestCase{"StaticQubitsWithTwoTargetOps", + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps), + MQT_NAMED_BUILDER(staticQubitsWithTwoTargetOps)}, + QIRTestCase{"StaticQubitsWithCtrl", + MQT_NAMED_BUILDER(staticQubitsWithCtrl), + MQT_NAMED_BUILDER(staticQubitsWithCtrl)}, + QIRTestCase{"StaticQubitsWithInv", + MQT_NAMED_BUILDER(staticQubitsWithInv), + MQT_NAMED_BUILDER(staticQubitsWithInv)}, + QIRTestCase{"MixedStaticDynamicQubits", + MQT_NAMED_BUILDER(mixedStaticDynamicQubits), + MQT_NAMED_BUILDER(mixedStaticDynamicQubits)})); /// @} diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 2134b1f368..c8c5d7d9c4 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -34,6 +34,45 @@ void staticQubits(QCProgramBuilder& b) { b.staticQubit(1); } +void staticQubitsWithOps(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.h(q0); + b.h(q1); +} + +void staticQubitsWithParametricOps(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.rx(std::numbers::pi / 4., q0); + b.p(std::numbers::pi / 2., q1); +} + +void staticQubitsWithTwoTargetOps(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.rzz(0.123, q0, q1); +} + +void staticQubitsWithCtrl(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.cx(q0, q1); +} + +void staticQubitsWithInv(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + b.inv([&]() { b.t(q0); }); +} + +void mixedStaticDynamicQubits(QCProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.allocQubit(); + b.rzz(0.123, q0, q1); + b.h(q0); + b.h(q1); +} + void allocDeallocPair(QCProgramBuilder& b) { auto q = b.allocQubit(); b.dealloc(q); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 21225c5b44..a0c379197c 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -33,6 +33,24 @@ void allocLargeRegister(QCProgramBuilder& b); /// Allocates two inline qubits. void staticQubits(QCProgramBuilder& b); +/// Allocates two static qubits and applies operations. +void staticQubitsWithOps(QCProgramBuilder& b); + +/// Allocates two static qubits and applies parametric gates. +void staticQubitsWithParametricOps(QCProgramBuilder& b); + +/// Allocates two static qubits and applies a two-target gate. +void staticQubitsWithTwoTargetOps(QCProgramBuilder& b); + +/// Allocates two static qubits and applies a controlled gate. +void staticQubitsWithCtrl(QCProgramBuilder& b); + +/// Allocates a static qubit and applies an inverse modifier. +void staticQubitsWithInv(QCProgramBuilder& b); + +/// Allocates one static and one dynamic qubit and applies mixed operations. +void mixedStaticDynamicQubits(QCProgramBuilder& b); + /// Allocates and explicitly deallocates a single qubit. void allocDeallocPair(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 1ef16df6d0..5c2995f34a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -38,6 +38,47 @@ void staticQubits(QCOProgramBuilder& b) { b.staticQubit(1); } +void staticQubitsWithOps(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + q0 = b.h(q0); + q1 = b.h(q1); +} + +void staticQubitsWithParametricOps(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + q0 = b.rx(std::numbers::pi / 4., q0); + q1 = b.p(std::numbers::pi / 2., q1); +} + +void staticQubitsWithTwoTargetOps(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + std::tie(q0, q1) = b.rzz(0.123, q0, q1); +} + +void staticQubitsWithCtrl(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + std::tie(q0, q1) = b.cx(q0, q1); +} + +void staticQubitsWithInv(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + q0 = b.inv({q0}, [&](auto targets) -> llvm::SmallVector { + return {b.t(targets[0])}; + })[0]; +} + +void mixedStaticDynamicQubits(QCOProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.allocQubit(); + std::tie(q0, q1) = b.rzz(0.123, q0, q1); + q0 = b.h(q0); + q1 = b.h(q1); +} + void allocDeallocPair(QCOProgramBuilder& b) { auto q = b.allocQubit(); b.dealloc(q); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index ba26e42969..29768bb731 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -33,6 +33,24 @@ void allocLargeRegister(QCOProgramBuilder& b); /// Allocates two inline qubits. void staticQubits(QCOProgramBuilder& b); +/// Allocates two static qubits and applies operations. +void staticQubitsWithOps(QCOProgramBuilder& b); + +/// Allocates two static qubits and applies parametric gates. +void staticQubitsWithParametricOps(QCOProgramBuilder& b); + +/// Allocates two static qubits and applies a two-target gate. +void staticQubitsWithTwoTargetOps(QCOProgramBuilder& b); + +/// Allocates two static qubits and applies a controlled gate. +void staticQubitsWithCtrl(QCOProgramBuilder& b); + +/// Allocates a static qubit and applies an inverse modifier. +void staticQubitsWithInv(QCOProgramBuilder& b); + +/// Allocates one static and one dynamic qubit and applies mixed operations. +void mixedStaticDynamicQubits(QCOProgramBuilder& b); + /// Allocates and explicitly deallocates a single qubit. void allocDeallocPair(QCOProgramBuilder& b); diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index 883cf59af8..fa1102d04d 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/QIR/Builder/QIRProgramBuilder.h" +#include + namespace mlir::qir { void emptyQIR([[maybe_unused]] QIRProgramBuilder& builder) {} @@ -32,6 +34,46 @@ void staticQubits(QIRProgramBuilder& b) { b.staticQubit(1); } +void staticQubitsWithOps(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.h(q0); + b.h(q1); +} + +void staticQubitsWithParametricOps(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.rx(std::numbers::pi / 4., q0); + b.p(std::numbers::pi / 2., q1); +} + +void staticQubitsWithTwoTargetOps(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.rzz(0.123, q0, q1); +} + +void staticQubitsWithCtrl(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto q1 = b.staticQubit(1); + b.cx(q0, q1); +} + +void staticQubitsWithInv(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + b.tdg(q0); +} + +void mixedStaticDynamicQubits(QIRProgramBuilder& b) { + auto q0 = b.staticQubit(0); + auto qDyn = b.allocQubitRegister(1); + auto q1 = qDyn[0]; + b.rzz(0.123, q0, q1); + b.h(q0); + b.h(q1); +} + void singleMeasurementToSingleBit(QIRProgramBuilder& b) { auto q = b.allocQubitRegister(1); const auto c = b.allocClassicalBitRegister(1); diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index f379f19785..d14a5d57dc 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -33,6 +33,24 @@ void allocLargeRegister(QIRProgramBuilder& b); /// Allocates two inline qubits. void staticQubits(QIRProgramBuilder& b); +/// Allocates two static qubits and applies operations. +void staticQubitsWithOps(QIRProgramBuilder& b); + +/// Allocates two static qubits and applies parametric gates. +void staticQubitsWithParametricOps(QIRProgramBuilder& b); + +/// Allocates two static qubits and applies a two-target gate. +void staticQubitsWithTwoTargetOps(QIRProgramBuilder& b); + +/// Allocates two static qubits and applies a controlled gate. +void staticQubitsWithCtrl(QIRProgramBuilder& b); + +/// Allocates a static qubit and applies the inverse of a T gate (Tdg). +void staticQubitsWithInv(QIRProgramBuilder& b); + +/// Allocates one static and one dynamic qubit and applies mixed operations. +void mixedStaticDynamicQubits(QIRProgramBuilder& b); + // --- MeasureOp ------------------------------------------------------------ // /// Measures a single qubit into a single classical bit.