Skip to content

Commit 79ed589

Browse files
authored
Merge pull request #154 from jackraymond/feature/specify_SRT
Add srts option to SpinReversalComposite
2 parents f853a5c + 3f3e3a9 commit 79ed589

File tree

3 files changed

+109
-25
lines changed

3 files changed

+109
-25
lines changed

dwave/preprocessing/composites/spin_reversal_transform.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,63 +137,100 @@ def _reorder_variables(sampleset: dimod.SampleSet,
137137

138138
@dimod.decorators.nonblocking_sample_method
139139
def sample(self, bqm: dimod.BinaryQuadraticModel, *,
140-
num_spin_reversal_transforms: int = 1,
140+
srts: typing.Optional[np.ndarray] = None,
141+
num_spin_reversal_transforms: typing.Optional[int] = None,
141142
**kwargs,
142143
):
143144
"""Sample from the binary quadratic model.
144145
145146
Args:
146147
bqm: Binary quadratic model to be sampled from.
147148
149+
srts: A boolean NumPy array with shape
150+
``(num_spin_reversal_transforms, bqm.num_variables)``.
151+
True indicates a flip and False indicates no flip; applied to
152+
in the order given by ``bqm.variables``.
153+
If this is not specified as an input values are generated uniformly
154+
at random from the class pseudo-random number generator.
155+
148156
num_spin_reversal_transforms:
149157
Number of spin reversal transform runs.
150158
A value of ``0`` will not transform the problem.
151159
If you specify a nonzero value, each spin reversal transform
152160
will result in an independent run of the child sampler.
161+
If ``srts`` is set then ``num_spin_reversal_transforms``
162+
is inferred by the shape, otherwise the default is 1.
153163
154164
Returns:
155165
A sample set. Note that for a sampler that returns ``num_reads`` samples,
156166
the sample set will contain ``num_reads*num_spin_reversal_transforms`` samples.
157167
168+
Raises:
169+
ValueError: If ``srts`` is inconsistent with
170+
``num_spin_reversal_transforms`` or the binary quadratic model.
171+
158172
Examples:
159-
This example runs 100 spin reversals applied to one variable of a QUBO problem.
173+
This example runs 10 spin reversals applied to an unfrustrated chain
174+
of length 6.
175+
176+
Using the lowest energy (ground) state returned, you can define a
177+
special SRT that transforms all programmed couplers to be ferromagnetic
178+
(ground state to all 1).
160179
161180
>>> from dimod import ExactSolver
181+
>>> import numpy as np
162182
>>> from dwave.preprocessing.composites import SpinReversalTransformComposite
163183
>>> base_sampler = ExactSolver()
164184
>>> composed_sampler = SpinReversalTransformComposite(base_sampler)
165185
...
166-
>>> Q = {('a', 'a'): -1, ('b', 'b'): -1, ('a', 'b'): 2}
167-
>>> response = composed_sampler.sample_qubo(Q,
168-
... num_spin_reversal_transforms=100)
169-
>>> len(response)
170-
400
186+
>>> num_var = 6
187+
>>> num_spin_reversal_transforms = 10
188+
>>> J = {(i, i+1): np.random.choice([-1,1]) for i in range(num_var-1)}
189+
>>> h = {i: 0 for i in range(num_var)}
190+
>>> response = composed_sampler.sample_ising(h, J,
191+
... num_spin_reversal_transforms=num_spin_reversal_transforms)
192+
>>> len(response) == 2**num_var * num_spin_reversal_transforms
193+
True
194+
>>> srts = np.array([[response.first.sample[i] != 1 for i in range(num_var)]])
195+
>>> response = composed_sampler.sample_ising(h, J,
196+
... srts=srts, num_reads=1)
197+
>>> sum(response.record.num_occurrences) == 2**num_var
198+
True
171199
"""
172200
sampler = self._child
173201

202+
if srts is not None:
203+
nsrt, num_bqm_var = srts.shape
204+
if num_bqm_var != bqm.num_variables:
205+
raise ValueError('srt shape is inconsistent with the bqm')
206+
if num_spin_reversal_transforms is not None:
207+
if num_spin_reversal_transforms != nsrt:
208+
raise ValueError('srt shape is inconsistent with num_spin_reversal_transforms')
209+
else:
210+
num_spin_reversal_transforms = nsrt
211+
elif num_spin_reversal_transforms is None:
212+
num_spin_reversal_transforms = 1
213+
174214
# No SRTs, so just pass the problem through
175-
if not num_spin_reversal_transforms or not bqm.num_variables:
215+
if num_spin_reversal_transforms == 0 or not bqm.num_variables:
176216
sampleset = sampler.sample(bqm, **kwargs)
177217
# yield twice because we're using the @nonblocking_sample_method
178218
yield sampleset # this one signals done()-ness
179219
yield sampleset # this is the one actually used by the user
180220
return
181221

222+
if srts is None:
223+
srts = self.rng.random((num_spin_reversal_transforms, bqm.num_variables)) > .5
224+
182225
# we'll be modifying the BQM, so make a copy
183226
bqm = bqm.copy()
184227

185-
# We maintain the Leap behavior that num_spin_reversal_transforms == 1
186-
# corresponds to a single problem with randomly flipped variables.
187-
188-
# Get the SRT matrix
189-
SRT = self.rng.random((num_spin_reversal_transforms, bqm.num_variables)) > .5
190-
191228
# Submit the problems
192229
samplesets: typing.List[dimod.SampleSet] = []
193230
flipped = np.zeros(bqm.num_variables, dtype=bool) # what variables are currently flipped
194231
for i in range(num_spin_reversal_transforms):
195232
# determine what needs to be flipped
196-
transform = flipped != SRT[i, :]
233+
transform = flipped != srts[i, :]
197234

198235
# apply the transform
199236
for v, flip in zip(bqm.variables, transform):
@@ -213,10 +250,10 @@ def sample(self, bqm: dimod.BinaryQuadraticModel, *,
213250
# Undo the SRTs according to vartype
214251
if bqm.vartype is Vartype.BINARY:
215252
for i, sampleset in enumerate(samplesets):
216-
sampleset.record.sample[:, SRT[i, :]] = 1 - sampleset.record.sample[:, SRT[i, :]]
253+
sampleset.record.sample[:, srts[i, :]] = 1 - sampleset.record.sample[:, srts[i, :]]
217254
elif bqm.vartype is Vartype.SPIN:
218255
for i, sampleset in enumerate(samplesets):
219-
sampleset.record.sample[:, SRT[i, :]] *= -1
256+
sampleset.record.sample[:, srts[i, :]] *= -1
220257
else:
221258
raise RuntimeError("unexpected vartype")
222259
if num_spin_reversal_transforms == 1:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
prelude: >
3+
Add srts option to SpinReversalTransformComposite
4+
features:
5+
- |
6+
Allow non-random specification of spin reversal transforms

tests/test_spin_reversal_transform.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from dwave.preprocessing.composites import SpinReversalTransformComposite
2121

22-
2322
@dimod.testing.load_sampler_bqm_tests(SpinReversalTransformComposite(dimod.ExactSolver()))
2423
class TestSpinTransformComposite(unittest.TestCase):
2524
def test_instantiation(self):
@@ -45,7 +44,7 @@ def test_empty_bqm_composition(self):
4544

4645
sampler = SpinReversalTransformComposite(dimod.RandomSampler())
4746
bqm = dimod.BinaryQuadraticModel('SPIN')
48-
sampleset = sampler.sample(bqm, num_spin_reversals=1)
47+
sampleset = sampler.sample(bqm, num_spin_reversal_transforms=1)
4948
self.assertEqual(len(sampleset.variables), 0)
5049

5150
def test_concatenation_stripping(self):
@@ -73,12 +72,13 @@ def test_sampleset_size(self):
7372
sampler = SpinReversalTransformComposite(dimod.RandomSampler())
7473
for num_spin_reversal_transforms in [1, 2]:
7574
for num_reads in [1, 3]:
76-
sampleset = sampler.sample_ising(
77-
{0: 1}, {},
78-
num_spin_reversal_transforms=num_spin_reversal_transforms,
75+
with self.subTest(f'{num_reads} {num_spin_reversal_transforms}'):
76+
sampleset = sampler.sample_ising(
77+
{0: 1}, {},
78+
num_spin_reversal_transforms=num_spin_reversal_transforms,
7979
num_reads=num_reads)
80-
self.assertTrue(sum(sampleset.record.num_occurrences) ==
81-
num_reads*num_spin_reversal_transforms)
80+
self.assertTrue(sum(sampleset.record.num_occurrences) ==
81+
num_reads*num_spin_reversal_transforms)
8282

8383
def test_empty(self):
8484
# Check that empty BQMs are handled
@@ -151,7 +151,6 @@ def sample(self, bqm):
151151
ss1 = SpinReversalTransformComposite(Sampler(), seed=42).sample(bqm)
152152
ss2 = SpinReversalTransformComposite(Sampler(), seed=42).sample(bqm)
153153
ss3 = SpinReversalTransformComposite(Sampler(), seed=35).sample(bqm)
154-
155154
self.assertTrue((ss1.record == ss2.record).all())
156155
self.assertFalse((ss1.record == ss3.record).all())
157156

@@ -230,3 +229,45 @@ def sample(bqm):
230229
self.assertTrue(hasattr(sampleset,'info'))
231230
self.assertEqual(sampleset.info, {'has_some': True})
232231

232+
def test_srts_argument(self):
233+
# All 1 ground state
234+
class Sampler:
235+
def sample(self, bqm):
236+
return dimod.SampleSet.from_samples_bqm([-1] * bqm.num_variables, bqm)
237+
num_var = 10
238+
bqm = dimod.BinaryQuadraticModel(
239+
{i: -1 for i in range(num_var)}, {}, 0, 'SPIN')
240+
241+
sampler = Sampler()
242+
ss = sampler.sample(bqm)
243+
samples = ss.record.sample
244+
sampler = SpinReversalTransformComposite(sampler)
245+
srts = np.zeros(shape=(1, num_var), dtype=bool) #
246+
ss = sampler.sample(bqm, srts=srts)
247+
self.assertTrue(np.all(ss.record.sample == samples),
248+
"Neutral srts leaves result unpermuted.")
249+
250+
ss = sampler.sample(bqm, srts=np.empty(shape=(0, num_var)))
251+
self.assertTrue(np.all(ss.record.sample == samples),
252+
"Empty srts also allows pass through "
253+
"(just like num_spin_reversals=0).")
254+
255+
srts = np.ones(shape=(1, num_var), dtype=bool)
256+
ss = sampler.sample(bqm, srts=srts)
257+
self.assertTrue(np.all(ss.record.sample == -samples),
258+
"Flip-all srts inverts the order")
259+
260+
with self.subTest('srt shape'):
261+
num_spin_reversal_transforms = 3
262+
srts = np.unique(np.random.random(size=(num_spin_reversal_transforms, num_var)) > 0.5, axis=0)
263+
ss = sampler.sample(bqm, srts=srts)
264+
self.assertEqual(np.sum(ss.record.num_occurrences), srts.shape[0],
265+
"Apply 3 srtss")
266+
self.assertTrue(np.all(srts == (ss.record.sample==1)))
267+
268+
with self.assertRaises(ValueError):
269+
# Inconsistent arguments
270+
ss = sampler.sample(bqm, srts=srts, num_spin_reversal_transforms=num_spin_reversal_transforms+1)
271+
272+
273+

0 commit comments

Comments
 (0)