diff --git a/olmoearth_pretrain/train/masking.py b/olmoearth_pretrain/train/masking.py index e0a8f7e54..c4bb12a90 100644 --- a/olmoearth_pretrain/train/masking.py +++ b/olmoearth_pretrain/train/masking.py @@ -1917,7 +1917,9 @@ def apply_mask( else: use_random_masking = False not_missing_t = torch.argwhere(missing_per_time)[:, 0] - not_missing_t = not_missing_t[torch.randperm(len(not_missing_t))] + not_missing_t = not_missing_t[ + torch.randperm(len(not_missing_t), device=not_missing_t.device) + ] num_encode = math.ceil(len(not_missing_t) * self.encode_ratio) encode_timestamps = not_missing_t[:num_encode] decode_timestamps = not_missing_t[num_encode:]