Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pints/_mcmc/_mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(self, x0, sigma0=None):

# Set initial state
self._running = False
self._ready_for_tell = False

# Current point and proposed point
self._current = None
Expand Down Expand Up @@ -225,9 +224,8 @@ def tell(self, reply):
""" See :meth:`pints.SingleChainMCMC.tell()`. """

# Check if we had a proposal
if not self._ready_for_tell:
if self._proposed is None:
raise RuntimeError('Tell called before proposal was set.')
self._ready_for_tell = False

Comment on lines -228 to 229
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, probably not getting something but don't quite get why we'd remove these checks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more in line with what we have in the other MCMC methods, with the exception of the ones that do more complicated stuff in ask() than just set a proposal.

So no very strong reason! Just that it's possible to let users ask() as often as they like in this case (and always get the same result), and it lets us remove a variable...

We might want to disallow this everywhere instead though? In which case we'd have to update the documentation for SingleChainMCMC and MultiChainMCMC a bit to make this explicit...

# Unpack reply
fx, log_gradient = reply
Expand Down
61 changes: 37 additions & 24 deletions pints/tests/test_mcmc_mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class TestMALAMCMC(unittest.TestCase):
Tests the basic methods of the MALA MCMC routine.
"""

def test_method(self):
def test_short_run(self):
# Test a short run with MALA

# Create log pdf
log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]])
Expand All @@ -30,9 +31,6 @@ def test_method(self):
sigma = [[3, 0], [0, 3]]
mcmc = pints.MALAMCMC(x0, sigma)

# This method needs sensitivities
self.assertTrue(mcmc.needs_sensitivities())

# Perform short run
chain = []
for i in range(100):
Expand All @@ -47,41 +45,49 @@ def test_method(self):
chain = np.array(chain)
self.assertEqual(chain.shape[0], 50)
self.assertEqual(chain.shape[1], len(x0))
self.assertTrue(mcmc.acceptance_rate() >= 0.0 and
mcmc.acceptance_rate() <= 1.0)
self.assertTrue(0 <= mcmc.acceptance_rate() <= 1.0)

def test_needs_sensitivities(self):
# This method needs sensitivities

mcmc._proposed = [1, 3]
self.assertRaises(RuntimeError, mcmc.tell, (fx, gr))
mcmc = pints.MALAMCMC(np.array([2, 2]))
self.assertTrue(mcmc.needs_sensitivities())

def test_logging(self):
# Test logging includes name and custom fields.

log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]])
x0 = [np.array([2, 2]), np.array([8, 8])]

mcmc = pints.MCMCSampling(log_pdf, 2, x0, method=pints.MALAMCMC)
mcmc = pints.MCMCController(log_pdf, 2, x0, method=pints.MALAMCMC)
mcmc.set_max_iterations(5)
with StreamCapture() as c:
mcmc.run()
text = c.text()

self.assertIn('Metropolis-Adjusted Langevin Algorithm (MALA)',
text)
self.assertIn('Metropolis-Adjusted Langevin Algorithm (MALA)', text)
self.assertIn(' Accept.', text)

def test_flow(self):
# Test the ask-and-tell flow

log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]])
x0 = np.array([2, 2])

# Test initial proposal is first point
mcmc = pints.MALAMCMC(x0)
self.assertTrue(np.all(mcmc.ask() == mcmc._x0))

# Repeated asks
self.assertRaises(RuntimeError, mcmc.ask)

# Tell without ask
self.assertTrue(np.all(mcmc.ask() == x0))

# Repeated asks return same point
self.assertTrue(np.all(mcmc.ask() == x0))
self.assertTrue(np.all(mcmc.ask() == x0))
self.assertTrue(np.all(mcmc.ask() == x0))
for i in range(5):
mcmc.tell(log_pdf.evaluateS1(mcmc.ask()))
x1 = mcmc.ask()
self.assertTrue(np.all(mcmc.ask() == x1))

# Tell without ask should fail
mcmc = pints.MALAMCMC(x0)
self.assertRaises(RuntimeError, mcmc.tell, 0)

Expand All @@ -101,8 +107,8 @@ def test_flow(self):
mcmc._running = True
self.assertRaises(RuntimeError, mcmc._initialise)

def test_set_hyper_parameters(self):
# Tests the parameter interface for this sampler.
def test_hyper_parameters(self):
# Tests the hyper parameter interface for this sampler.

x0 = np.array([2, 2])
mcmc = pints.MALAMCMC(x0)
Expand All @@ -113,17 +119,24 @@ def test_set_hyper_parameters(self):
self.assertTrue(np.array_equal(mcmc.epsilon(),
0.2 * np.diag(mcmc._sigma0)))

mcmc = pints.MALAMCMC(np.array([2, 2]))
self.assertEqual(mcmc.n_hyper_parameters(), 1)
mcmc.set_hyper_parameters([[3, 2]])
self.assertTrue(np.array_equal(mcmc.epsilon(), [3, 2]))
mcmc.set_hyper_parameters([[5, 5]])
self.assertTrue(np.array_equal(mcmc.epsilon(), [5, 5]))

mcmc._step_size = 5
mcmc._scale_vector = np.array([3, 7])
mcmc._epsilon = None
def test_epsilon(self):
# Test the epsilon methods

mcmc = pints.MALAMCMC(np.array([2, 2]), np.array([3, 3]))
mcmc.set_epsilon()
self.assertTrue(np.array_equal(mcmc.epsilon(), [15, 35]))
x = mcmc.epsilon()
self.assertAlmostEqual(x[0], 0.6)
self.assertAlmostEqual(x[1], 0.6)
mcmc.set_epsilon([0.4, 0.5])
self.assertTrue(np.array_equal(mcmc.epsilon(), [0.4, 0.5]))
self.assertTrue(np.all(mcmc.epsilon() == [0.4, 0.5]))

self.assertRaises(ValueError, mcmc.set_epsilon, 3.0)
self.assertRaises(ValueError, mcmc.set_epsilon, [-2.0, 1])

Expand Down