Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 7 additions & 216 deletions crates/transpiler/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,30 +296,26 @@ where
}

/// This is the internal structure for the Python CommutationChecker class
/// It handles the actual commutation checking, cache management, and library
/// lookups. It's not meant to be a public facing Python object though and only used
/// It handles the actual commutation checking, and library lookups. It's
/// not meant to be a public facing Python object though and only used
/// internally by the Python class.
#[pyclass(module = "qiskit._accelerate.commutation_checker")]
pub struct CommutationChecker {
library: CommutationLibrary,
cache_max_entries: usize,
cache: HashMap<(String, String), CommutationCacheEntry>,
current_cache_entries: usize,
#[pyo3(get)]
gates: Option<HashSet<String>>,
}

#[pymethods]
impl CommutationChecker {
#[pyo3(signature = (standard_gate_commutations=None, cache_max_entries=1_000_000, gates=None))]
#[pyo3(signature = (standard_gate_commutations=None, gates=None))]
#[new]
fn py_new(
standard_gate_commutations: Option<Bound<PyAny>>,
cache_max_entries: usize,
gates: Option<HashSet<String>>,
) -> Self {
let library = CommutationLibrary::py_new(standard_gate_commutations);
CommutationChecker::new(Some(library), cache_max_entries, gates)
CommutationChecker::new(Some(library), gates)
}

#[pyo3(signature=(op1, op2, max_num_qubits=None, approximation_degree=1., matrix_max_num_qubits=3))]
Expand Down Expand Up @@ -388,52 +384,18 @@ impl CommutationChecker {
)?)
}

/// Return the current number of cache entries
fn num_cached_entries(&self) -> usize {
self.current_cache_entries
}

/// Clear the cache
fn clear_cached_commutations(&mut self) {
self.clear_cache()
}

fn __getstate__(&self, py: Python) -> PyResult<Py<PyDict>> {
let out_dict = PyDict::new(py);
out_dict.set_item("cache_max_entries", self.cache_max_entries)?;
out_dict.set_item("current_cache_entries", self.current_cache_entries)?;
let cache_dict = PyDict::new(py);
for (key, value) in &self.cache {
cache_dict.set_item(key, commutation_entry_to_pydict(py, value)?)?;
}
out_dict.set_item("cache", cache_dict)?;
out_dict.set_item("library", self.library.library.clone().into_pyobject(py)?)?;
out_dict.set_item("gates", self.gates.clone())?;
Ok(out_dict.unbind())
}

fn __setstate__(&mut self, py: Python, state: Py<PyAny>) -> PyResult<()> {
let dict_state = state.cast_bound::<PyDict>(py)?;
self.cache_max_entries = dict_state
.get_item("cache_max_entries")?
.unwrap()
.extract()?;
self.current_cache_entries = dict_state
.get_item("current_cache_entries")?
.unwrap()
.extract()?;
self.library = CommutationLibrary {
library: dict_state.get_item("library")?.unwrap().extract()?,
};
let raw_cache: Bound<PyDict> = dict_state.get_item("cache")?.unwrap().extract()?;
self.cache = HashMap::with_capacity(raw_cache.len());
for (key, value) in raw_cache.iter() {
let value_dict: &Bound<PyDict> = value.cast()?;
self.cache.insert(
key.extract()?,
commutation_cache_entry_from_pydict(value_dict)?,
);
}
self.gates = dict_state.get_item("gates")?.unwrap().extract()?;
Ok(())
}
Expand All @@ -444,21 +406,13 @@ impl CommutationChecker {
///
/// # Arguments
///
/// - `library`: An optional existing [CommutationLibrary] with cached entries.
/// - `cache_max_entries`: The maximum size of the cache.
/// - `library`: An optional existing [CommutationLibrary].
/// - `gates`: An optional set of gates (by name) to check commutations for. If `None`,
/// commutation is cached and checked for all gates.
pub fn new(
library: Option<CommutationLibrary>,
cache_max_entries: usize,
gates: Option<HashSet<String>>,
) -> Self {
/// commutation is checked for all gates.
pub fn new(library: Option<CommutationLibrary>, gates: Option<HashSet<String>>) -> Self {
// Initialize sets before they are used in the commutation checker
CommutationChecker {
library: library.unwrap_or(CommutationLibrary { library: None }),
cache: HashMap::new(),
cache_max_entries,
current_cache_entries: 0,
gates,
}
}
Expand Down Expand Up @@ -586,39 +540,6 @@ impl CommutationChecker {
(qargs1, qargs2)
};

// For our cache to work correctly, we require the gate's definition to only depend on the
// ``params`` attribute. This cannot be guaranteed for custom gates, so we only check
// the cache for
// * gates we know are in the cache (SUPPORTED_OPS), or
// * standard gates with float params (otherwise we cannot cache them)
let is_cachable = |op: &OperationRef, params: &[Param]| {
if let OperationRef::StandardGate(gate) = op {
SUPPORTED_OP[(*gate) as usize]
|| params.iter().all(|p| matches!(p, Param::Float(_)))
} else {
false
}
};
let check_cache =
is_cachable(first_op, first_params) && is_cachable(second_op, second_params);

if !check_cache {
// The arguments are sorted, so if first_qargs.len() > matrix_max_num_qubits, then
// second_qargs.len() > matrix_max_num_qubits as well.
if second_qargs.len() > matrix_max_num_qubits as usize {
return Ok(false);
}
return self.commute_matmul(
first_op,
first_params,
first_qargs,
second_op,
second_params,
second_qargs,
tol,
);
}

// Query commutation library
let relative_placement = get_relative_placement(first_qargs, second_qargs);
if let Some(is_commuting) =
Expand All @@ -628,19 +549,6 @@ impl CommutationChecker {
return Ok(is_commuting);
}

// Query cache
let key1 = hashable_params(first_params)?;
let key2 = hashable_params(second_params)?;
if let Some(commutation_dict) = self
.cache
.get(&(first_op.name().to_string(), second_op.name().to_string()))
{
let hashes = (key1.clone(), key2.clone());
if let Some(commutation) = commutation_dict.get(&(relative_placement.clone(), hashes)) {
return Ok(*commutation);
}
}

if second_qargs.len() > matrix_max_num_qubits as usize {
return Ok(false);
}
Expand All @@ -656,25 +564,6 @@ impl CommutationChecker {
tol,
)?;

// TODO: implement a LRU cache for this
if self.current_cache_entries >= self.cache_max_entries {
self.clear_cache();
}
// Cache results from is_commuting
self.cache
.entry((first_op.name().to_string(), second_op.name().to_string()))
.and_modify(|entries| {
let key = (relative_placement.clone(), (key1.clone(), key2.clone()));
entries.insert(key, is_commuting);
self.current_cache_entries += 1;
})
.or_insert_with(|| {
let mut entries = HashMap::with_capacity(1);
let key = (relative_placement, (key1, key2));
entries.insert(key, is_commuting);
self.current_cache_entries += 1;
entries
});
Ok(is_commuting)
}

Expand Down Expand Up @@ -773,11 +662,6 @@ impl CommutationChecker {
let matrix_tol = tol;
Ok(phase.abs() <= tol && (1.0 - fid).abs() <= matrix_tol)
}

fn clear_cache(&mut self) {
self.cache.clear();
self.current_cache_entries = 0;
}
}

/// A pre-check status.
Expand Down Expand Up @@ -1085,104 +969,11 @@ impl<'a, 'py> FromPyObject<'a, 'py> for CommutationLibraryEntry {
}
}

type CacheKey = (
SmallVec<[Option<Qubit>; 2]>,
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
);

type CommutationCacheEntry = HashMap<CacheKey, bool>;

fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> PyResult<Py<PyDict>> {
let out_dict = PyDict::new(py);
for (k, v) in entry.iter() {
let qubits = PyTuple::new(py, k.0.iter().map(|q| q.map(|t| t.0)))?;
let params0 = PyTuple::new(py, k.1.0.iter().map(|pk| pk.0))?;
let params1 = PyTuple::new(py, k.1.1.iter().map(|pk| pk.0))?;
out_dict.set_item(
PyTuple::new(py, [qubits, PyTuple::new(py, [params0, params1])?])?,
PyBool::new(py, *v),
)?;
}
Ok(out_dict.unbind())
}

fn commutation_cache_entry_from_pydict(dict: &Bound<PyDict>) -> PyResult<CommutationCacheEntry> {
let mut ret = hashbrown::HashMap::with_capacity(dict.len());
for (k, v) in dict {
let raw_key: CacheKeyRaw = k.extract()?;
let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect();
let params0: SmallVec<_> = raw_key.1.0;
let params1: SmallVec<_> = raw_key.1.1;
let v: bool = v.extract()?;
ret.insert((qubits, (params0, params1)), v);
}
Ok(ret)
}

type CacheKeyRaw = (
SmallVec<[Option<u32>; 2]>,
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
);

/// This newtype wraps a f64 to make it hashable so we can cache parameterized gates
/// based on the parameter value (assuming it's a float angle). However, Rust doesn't do
/// this by default and there are edge cases to track around it's usage. The biggest one
/// is this does not work with f64::NAN, f64::INFINITY, or f64::NEG_INFINITY
/// If you try to use these values with this type they will not work as expected.
/// This should only be used with the cache hashmap's keys and not used beyond that.
#[derive(Debug, Copy, Clone, PartialEq, FromPyObject)]
struct ParameterKey(f64);

impl ParameterKey {
fn key(&self) -> u64 {
// If we get a -0 the to_bits() return is not equivalent to 0
// because -0 has the sign bit set we'd be hashing 9223372036854775808
// and be storing it separately from 0. So this normalizes all 0s to
// be represented by 0
if self.0 == 0. { 0 } else { self.0.to_bits() }
}
}

impl std::hash::Hash for ParameterKey {
fn hash<H>(&self, state: &mut H)
where
H: std::hash::Hasher,
{
self.key().hash(state)
}
}

impl Eq for ParameterKey {}

fn hashable_params(params: &[Param]) -> Result<SmallVec<[ParameterKey; 3]>, CommutationError> {
params
.iter()
.map(|x| {
if let Param::Float(x) = x {
// NaN and Infinity (negative or positive) are not valid
// parameter values and our hacks to store parameters in
// the cache HashMap don't take these into account. So return
// an error to Python if we encounter these values.
if x.is_nan() || x.is_infinite() {
Err(CommutationError::HashingNaN)
} else {
Ok(ParameterKey(*x))
}
} else {
Err(CommutationError::HashingParameter)
}
})
.collect()
}

#[pyfunction]
pub fn get_standard_commutation_checker() -> CommutationChecker {
let library = standard_gates_commutations::get_commutation_library();
CommutationChecker {
library,
cache_max_entries: 1_000_000,
cache: HashMap::new(),
current_cache_entries: 0,
gates: None,
}
}
Expand Down
23 changes: 16 additions & 7 deletions qiskit/circuit/commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from qiskit.circuit.operation import Operation
from qiskit.circuit import Qubit
from qiskit._accelerate.commutation_checker import CommutationChecker as RustChecker
from qiskit.utils import deprecate_arg, deprecate_func


class CommutationChecker:
Expand All @@ -45,18 +46,18 @@ class CommutationChecker:
gates with free parameters (such as :class:`.RXGate` with a :class:`.ParameterExpression` as
angle). Otherwise, a matrix-based check is performed, where two operations are said to
commute, if the average gate fidelity of performing the commutation is above a certain threshold
(see ``approximation_degree``). The result of this commutation is then added to the
cached lookup table.
(see ``approximation_degree``).
"""

@deprecate_arg("cache_max_entries", since="2.5.0", removal_timeline="in Qiskit 3.0")
def __init__(
self,
standard_gate_commutations: dict | None = None,
cache_max_entries: int = 10**6,
Comment thread
Cryoris marked this conversation as resolved.
*,
gates: set[str] | None = None,
):
self.cc = RustChecker(standard_gate_commutations, cache_max_entries, gates)
self.cc = RustChecker(standard_gate_commutations, gates)

def commute_nodes(
self,
Expand Down Expand Up @@ -118,13 +119,21 @@ def commute(
matrix_max_num_qubits,
)

@deprecate_func(since="2.5", removal_timeline="in Qiskit 3.0")
def num_cached_entries(self):
"""Returns number of cached entries"""
return self.cc.num_cached_entries()
"""Returns number of cached entries

This method will always return 0 because there is no longer an
internal cache.
"""
return 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a backward compatible API change, since the commutation checker does give the same results, just might take longer, but we are dropping functionality user's might've relied upon. I think it would be nice to point this out more explicitly in the release notes, what do you think?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can reword the release note to make it more explicit. Do you want me to add an other note to call it out more?

Functionally the caching only did anything in a very specific case, you had a gate pair of StandardGate that had all float parameters (which includes no params), and at least one gate we didn't account for in the library or another mechanism. You then were repeating the same checks with these gates repeatedly. Outside of this specific case we never cached anything. So I'm not sure how people could be relying on it in practice.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewording is good enough, just to point out users would have to do their own caching if they relied on that 👍🏻 I also don't think it's a commonly used thing, but the past has shown that every minuscule feature is used somehow by someone 😄


@deprecate_func(since="2.5", removal_timeline="in Qiskit 3.0")
def clear_cached_commutations(self):
"""Clears the dictionary holding cached commutations"""
self.cc.clear_cached_commutations()
"""Clears the dictionary holding cached commutations

This method is a no-op as there is no longer an internal cache
"""

def check_commutation_entries(
self,
Expand Down
15 changes: 15 additions & 0 deletions releasenotes/notes/remove_commutation_cache-545d7a087023f735.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
deprecations_circuits:
- The ``cache_max_entries`` argument on the the :class:`.CommutationChecker` class's
constructor is deprecated and will be removed in Qiskit 3.0.0. This argument no longer has
any effect because the :class:`.CommutationChecker` no longer maintains an internal cache
of commutation relationships between gates as it is no longer necessary.
- The :meth:`.CommutationChecker.clear_cached_commutations` method is deprecated and will be
removed in Qiskit 3.0.0. This method no longer has any effect because the
internal cache was removed from the :class:`.CommutationChecker` class as
it was no longer necessary so there is nothing to clear.
- The :meth:`.CommutationChecker.num_cached_entries` method is deprecated
and will be removed in Qiskit 3.0.0. Since the removal of the internal
cache from the :class:`.CommutationChecker` this method always returns 0
because there are no internally cached entries in a :class:`.CommutationChecker`
instance.
Loading