Fix to device in brownian utils.
lxuechen committed Jul 8, 2020
1 parent 495f0ef commit 0a0f548
Showing 4 changed files with 51 additions and 22 deletions.
38 changes: 27 additions & 11 deletions tests/test_brownian_tree.py
@@ -30,37 +30,38 @@
 torch.set_default_dtype(torch.float64)
 
 D = 3
-BATCH_SIZE = 16384
+SMALL_BATCH_SIZE = 16
+LARGE_BATCH_SIZE = 16384
 REPS = 3
 ALPHA = 0.001
 
 
 class TestBrownianTree(TorchTestCase):
 
-    def _setUp(self, device=None):
+    def _setUp(self, batch_size, device=None):
         t0, t1 = torch.tensor([0., 1.]).to(device)
-        w0 = torch.zeros(BATCH_SIZE, D).to(device=device)
-        w1 = torch.randn(BATCH_SIZE, D).to(device=device)
+        w0 = torch.zeros(batch_size, D).to(device=device)
+        w1 = torch.randn(batch_size, D).to(device=device)
         t = torch.rand([]).to(device)
 
         self.t = t
         self.bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, entropy=0)
 
     def test_basic_cpu(self):
-        self._setUp(device=torch.device('cpu'))
+        self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cpu'))
         sample = self.bm(self.t)
-        self.assertEqual(sample.size(), (BATCH_SIZE, D))
+        self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
 
     def test_basic_gpu(self):
         if not torch.cuda.is_available():
             self.skipTest(reason='CUDA not available.')
 
-        self._setUp(device=torch.device('cuda'))
+        self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cuda'))
         sample = self.bm(self.t)
-        self.assertEqual(sample.size(), (BATCH_SIZE, D))
+        self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
 
     def test_determinism(self):
-        self._setUp()
+        self._setUp(batch_size=SMALL_BATCH_SIZE)
         vals = [self.bm(self.t) for _ in range(REPS)]
         for val in vals[1:]:
             self.tensorAssertAllClose(val, vals[0])
@@ -73,8 +74,8 @@ def test_normality(self):
         for _ in range(REPS):
             w0_, w1_ = 0.0, npr.randn()
             # Use the same endpoint for the batch, so samples from same dist.
-            w0 = torch.tensor(w0_).repeat(BATCH_SIZE)
-            w1 = torch.tensor(w1_).repeat(BATCH_SIZE)
+            w0 = torch.tensor(w0_).repeat(LARGE_BATCH_SIZE)
+            w1 = torch.tensor(w1_).repeat(LARGE_BATCH_SIZE)
             bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)
 
             for _ in range(REPS):
@@ -89,6 +90,21 @@ def test_normality(self):
                 _, pval = kstest(samples_, ref_dist.cdf)
                 self.assertGreaterEqual(pval, ALPHA)
 
+    def test_to(self):
+        if not torch.cuda.is_available():
+            self.skipTest(reason='CUDA not available.')
+
+        self._setUp(batch_size=SMALL_BATCH_SIZE)
+        cache = self.bm.get_cache()
+        old = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
+
+        gpu = torch.device('cuda')
+        self.bm.to(gpu)
+        cache = self.bm.get_cache()
+        new = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
+        self.assertTrue(str(new.device).startswith('cuda'))
+        self.tensorAssertAllClose(old, new.cpu())
+
 
 if __name__ == '__main__':
     unittest.main()
11 changes: 7 additions & 4 deletions torchsde/brownian/brownian_path.py
@@ -134,10 +134,7 @@ def __repr__(self):
         )
 
     def to(self, *args, **kwargs):
-        ws_new = blist.blist()
-        for w in self._ws:
-            ws_new.append(w.to(*args, **kwargs))
-        self._ws = ws_new
+        self._ws = utils.blist_to(self._ws, *args, **kwargs)
 
     @property
     def dtype(self):
@@ -153,3 +150,9 @@ def size(self):

     def __len__(self):
         return len(self._ts)
+
+    def get_cache(self):
+        return {
+            'ts': self._ts,
+            'ws': self._ws,
+        }
20 changes: 13 additions & 7 deletions torchsde/brownian/brownian_tree.py
@@ -138,9 +138,9 @@ def __repr__(self):
         )
 
     def to(self, *args, **kwargs):
-        self._ws_prev = _list_to(self._ws_prev, *args, **kwargs)
-        self._ws_post = _list_to(self._ws_post, *args, **kwargs)
-        self._ws = _list_to(self._ws, *args, **kwargs)
+        self._ws_prev = utils.blist_to(self._ws_prev, *args, **kwargs)
+        self._ws_post = utils.blist_to(self._ws_post, *args, **kwargs)
+        self._ws = utils.blist_to(self._ws, *args, **kwargs)
 
     @property
     def dtype(self):
@@ -157,6 +157,16 @@ def size(self):
     def __len__(self):
         return len(self._ts) + len(self._ts_prev) + len(self._ts_post)
 
+    def get_cache(self):
+        return {
+            'ts_prev': self._ts_prev,
+            'ts': self._ts,
+            'ts_post': self._ts_post,
+            'ws_prev': self._ws_prev,
+            'ws': self._ws,
+            'ws_post': self._ws_post
+        }
+
 
 def _binary_search(t0, t1, w0, w1, t, parent, tol):
     seedv, seedl, seedr = parent.spawn(3)
@@ -211,7 +221,3 @@ def _create_cache(t0, t1, w0, w1, entropy, pool_size, k):
         seeds = new_seeds
 
     return ts, ws, seeds
-
-
-def _list_to(l, *args, **kwargs):
-    return [li.to(*args, **kwargs) for li in l]
4 changes: 4 additions & 0 deletions torchsde/brownian/utils.py
@@ -94,3 +94,7 @@ def brownian_bridge(t0: float, t1: float, w0, w1, t: float, seed=None):

 def is_scalar(x):
     return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1)
+
+
+def blist_to(l, *args, **kwargs):
+    return blist.blist([li.to(*args, **kwargs) for li in l])
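A minimal usage sketch of the APIs this commit touches, mirroring the new test_to above: build a small BrownianTree, snapshot the cached increments with get_cache(), move the object with .to() (which now routes through utils.blist_to), and check the values survive the device transfer. The import path and the CUDA target are assumptions for illustration, not part of the commit.

import torch
from torchsde.brownian.brownian_tree import BrownianTree

# Small Brownian tree on CPU, mirroring _setUp in the test above.
t0, t1 = torch.tensor([0., 1.])
w0 = torch.zeros(16, 3)
w1 = torch.randn(16, 3)
bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, entropy=0)

# Snapshot every cached increment before changing devices.
cache = bm.get_cache()
old = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)

if torch.cuda.is_available():
    # .to() now moves the cached blists via utils.blist_to.
    bm.to(torch.device('cuda'))
    cache = bm.get_cache()
    new = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
    assert str(new.device).startswith('cuda')
    assert torch.allclose(old, new.cpu())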
