diff --git a/visualizations/animation.py b/visualizations/animation.py index 08529b1e..4ebf96ff 100644 --- a/visualizations/animation.py +++ b/visualizations/animation.py @@ -23,17 +23,19 @@ def build_mask(s: int, margin: float = 2., dtype=torch.float32): - mask = torch.zeros(1, 1, s, s, dtype=dtype) + mask = torch.ones(1, 1, s, s, dtype=dtype) c = (s - 1) / 2 t = (c - margin / 100. * c) ** 2 sig = 2. - for x in range(s): - for y in range(s): - r = (x - c) ** 2 + (y - c) ** 2 - if r > t: - mask[..., x, y] = np.exp((t - r) / sig ** 2) - else: - mask[..., x, y] = 1. + y, x = np.ogrid[:s, :s] + r = (x - c) ** 2 + (y - c) ** 2 + # r > t + outer_mask = ((t - r) / sig ** 2) + outer_mask = outer_mask ** (r > t) # To prevent overflow + outer_mask = (r > t) * np.exp(outer_mask) + # r <= t + inner_mask = (r <= t) + mask = mask * outer_mask + mask * inner_mask return mask