@@ -137,63 +137,100 @@ def _reorder_variables(sampleset: dimod.SampleSet,
137137
138138 @dimod .decorators .nonblocking_sample_method
139139 def sample (self , bqm : dimod .BinaryQuadraticModel , * ,
140- num_spin_reversal_transforms : int = 1 ,
140+ srts : typing .Optional [np .ndarray ] = None ,
141+ num_spin_reversal_transforms : typing .Optional [int ] = None ,
141142 ** kwargs ,
142143 ):
143144 """Sample from the binary quadratic model.
144145
145146 Args:
146147 bqm: Binary quadratic model to be sampled from.
147148
149+ srts: A boolean NumPy array with shape
150+ ``(num_spin_reversal_transforms, bqm.num_variables)``.
151+ True indicates a flip and False indicates no flip; applied to
152+ in the order given by ``bqm.variables``.
153+ If this is not specified as an input values are generated uniformly
154+ at random from the class pseudo-random number generator.
155+
148156 num_spin_reversal_transforms:
149157 Number of spin reversal transform runs.
150158 A value of ``0`` will not transform the problem.
151159 If you specify a nonzero value, each spin reversal transform
152160 will result in an independent run of the child sampler.
161+ If ``srts`` is set then ``num_spin_reversal_transforms``
162+ is inferred by the shape, otherwise the default is 1.
153163
154164 Returns:
155165 A sample set. Note that for a sampler that returns ``num_reads`` samples,
156166 the sample set will contain ``num_reads*num_spin_reversal_transforms`` samples.
157167
168+ Raises:
169+ ValueError: If ``srts`` is inconsistent with
170+ ``num_spin_reversal_transforms`` or the binary quadratic model.
171+
158172 Examples:
159- This example runs 100 spin reversals applied to one variable of a QUBO problem.
173+ This example runs 10 spin reversals applied to an unfrustrated chain
174+ of length 6.
175+
176+ Using the lowest energy (ground) state returned, you can define a
177+ special SRT that transforms all programmed couplers to be ferromagnetic
178+ (ground state to all 1).
160179
161180 >>> from dimod import ExactSolver
181+ >>> import numpy as np
162182 >>> from dwave.preprocessing.composites import SpinReversalTransformComposite
163183 >>> base_sampler = ExactSolver()
164184 >>> composed_sampler = SpinReversalTransformComposite(base_sampler)
165185 ...
166- >>> Q = {('a', 'a'): -1, ('b', 'b'): -1, ('a', 'b'): 2}
167- >>> response = composed_sampler.sample_qubo(Q,
168- ... num_spin_reversal_transforms=100)
169- >>> len(response)
170- 400
186+ >>> num_var = 6
187+ >>> num_spin_reversal_transforms = 10
188+ >>> J = {(i, i+1): np.random.choice([-1,1]) for i in range(num_var-1)}
189+ >>> h = {i: 0 for i in range(num_var)}
190+ >>> response = composed_sampler.sample_ising(h, J,
191+ ... num_spin_reversal_transforms=num_spin_reversal_transforms)
192+ >>> len(response) == 2**num_var * num_spin_reversal_transforms
193+ True
194+ >>> srts = np.array([[response.first.sample[i] != 1 for i in range(num_var)]])
195+ >>> response = composed_sampler.sample_ising(h, J,
196+ ... srts=srts, num_reads=1)
197+ >>> sum(response.record.num_occurrences) == 2**num_var
198+ True
171199 """
172200 sampler = self ._child
173201
202+ if srts is not None :
203+ nsrt , num_bqm_var = srts .shape
204+ if num_bqm_var != bqm .num_variables :
205+ raise ValueError ('srt shape is inconsistent with the bqm' )
206+ if num_spin_reversal_transforms is not None :
207+ if num_spin_reversal_transforms != nsrt :
208+ raise ValueError ('srt shape is inconsistent with num_spin_reversal_transforms' )
209+ else :
210+ num_spin_reversal_transforms = nsrt
211+ elif num_spin_reversal_transforms is None :
212+ num_spin_reversal_transforms = 1
213+
174214 # No SRTs, so just pass the problem through
175- if not num_spin_reversal_transforms or not bqm .num_variables :
215+ if num_spin_reversal_transforms == 0 or not bqm .num_variables :
176216 sampleset = sampler .sample (bqm , ** kwargs )
177217 # yield twice because we're using the @nonblocking_sample_method
178218 yield sampleset # this one signals done()-ness
179219 yield sampleset # this is the one actually used by the user
180220 return
181221
222+ if srts is None :
223+ srts = self .rng .random ((num_spin_reversal_transforms , bqm .num_variables )) > .5
224+
182225 # we'll be modifying the BQM, so make a copy
183226 bqm = bqm .copy ()
184227
185- # We maintain the Leap behavior that num_spin_reversal_transforms == 1
186- # corresponds to a single problem with randomly flipped variables.
187-
188- # Get the SRT matrix
189- SRT = self .rng .random ((num_spin_reversal_transforms , bqm .num_variables )) > .5
190-
191228 # Submit the problems
192229 samplesets : typing .List [dimod .SampleSet ] = []
193230 flipped = np .zeros (bqm .num_variables , dtype = bool ) # what variables are currently flipped
194231 for i in range (num_spin_reversal_transforms ):
195232 # determine what needs to be flipped
196- transform = flipped != SRT [i , :]
233+ transform = flipped != srts [i , :]
197234
198235 # apply the transform
199236 for v , flip in zip (bqm .variables , transform ):
@@ -213,10 +250,10 @@ def sample(self, bqm: dimod.BinaryQuadraticModel, *,
213250 # Undo the SRTs according to vartype
214251 if bqm .vartype is Vartype .BINARY :
215252 for i , sampleset in enumerate (samplesets ):
216- sampleset .record .sample [:, SRT [i , :]] = 1 - sampleset .record .sample [:, SRT [i , :]]
253+ sampleset .record .sample [:, srts [i , :]] = 1 - sampleset .record .sample [:, srts [i , :]]
217254 elif bqm .vartype is Vartype .SPIN :
218255 for i , sampleset in enumerate (samplesets ):
219- sampleset .record .sample [:, SRT [i , :]] *= - 1
256+ sampleset .record .sample [:, srts [i , :]] *= - 1
220257 else :
221258 raise RuntimeError ("unexpected vartype" )
222259 if num_spin_reversal_transforms == 1 :
0 commit comments