Redistribute intervals fix #855

Merged 5 commits on Aug 20, 2025

src/toast/intervals.py (15 additions, 2 deletions)

@@ -94,8 +94,16 @@ def __init__(self, timestamps, intervals=None, timespans=None, samplespans=None)
if np.isclose(timespans[i][1], timespans[i + 1][0], rtol=1e-12):
# Force nearly equal timestamps to match
timespans[i][1] = timespans[i + 1][0]
# Check that the intervals are sorted and disjoint
if timespans[i][1] > timespans[i + 1][0]:
raise RuntimeError("Timespans must be sorted and disjoint")
t1 = timespans[i][1]
t2 = timespans[i + 1][0]
dt = t1 - t2
ts = np.median(np.diff(timestamps))
msg = f"Timespans must be sorted and disjoint"
msg += f" but {t1} - {t2} = {dt} s = {dt / ts} samples)"
raise RuntimeError(msg)
# Map interval times into sample indices
indices, times = self._find_indices(timespans)
self.data = np.array(
[
@@ -186,7 +194,12 @@ def __len__(self):
return len(self.data)

def __repr__(self):
return self.data.__repr__()
s = "<IntervalList [\n"
for ival in self.data:
s += f" {ival.start:15.3f} - {ival.stop:15.3f} ({ival.first:9} - {ival.last:9}),\n"
s += "]>"
# return self.data.__repr__()
return s

def __eq__(self, other):
if len(self.data) != len(other):
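Note on the new __repr__: each interval is rendered on its own line with its start/stop times and first/last sample indices. A minimal sketch of the formatting, using a hypothetical namedtuple in place of the real interval records (NumPy recarray rows with the same field names):

    from collections import namedtuple

    # Hypothetical stand-in for the entries of IntervalList.data
    Ival = namedtuple("Ival", "start stop first last")
    data = [
        Ival(1577836800.0, 1577836860.0, 0, 599),
        Ival(1577836870.0, 1577836930.0, 700, 1299),
    ]

    s = "<IntervalList [\n"
    for ival in data:
        s += f"    {ival.start:15.3f} - {ival.stop:15.3f} ({ival.first:9} - {ival.last:9}),\n"
    s += "]>"
    print(s)
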
src/toast/observation_dist.py (16 additions, 48 deletions)

@@ -419,17 +419,14 @@ def global_interval_times(dist, intervals_manager, name, join=False):
(list): List of tuples on the root process, and None on other processes.

"""
ilist = [(x.start, x.stop, x.first, x.last) for x in intervals_manager[name]]
pstart = intervals_manager[name].timestamps[0]
ilist = [(x.start, x.stop) for x in intervals_manager[name]]
all_ilist = None
if dist.comm_row is None:
all_ilist = [(ilist, dist.samps[dist.comm.group_rank].n_elem, pstart)]
all_ilist = [ilist]
else:
# Gather across the process row
if dist.comm_col_rank == 0:
all_ilist = dist.comm_row.gather(
(ilist, dist.samps[dist.comm.group_rank].n_elem, pstart), root=0
)
all_ilist = dist.comm_row.gather(ilist, root=0)
del ilist

glist = None
@@ -438,51 +435,22 @@
# the rank zero process of the observation is also the process with rank
# zero along both the rows and columns.
glist = list()

prev = None
cur = None
last_pn = None
global_off = 0
for pdata, pn, pstrt in all_ilist:
last_start = 0
last_stop = 0
for pdata in all_ilist:
if len(pdata) == 0:
continue
for start, stop, first, last in pdata:
cur = [
float(start),
float(stop),
int(global_off + first),
int(global_off + last),
]
if prev is None:
# First global interval
prev = cur
for start, stop in pdata:
# Avoid adding the same time span twice
if (
np.isclose(start, last_start, rtol=1e-12) and
np.isclose(stop, last_stop, rtol=1e-12)
):
continue
if last_pn is not None:
# We are on a later process's data, see if we had any continuation
# of interval from last process.
if prev[3] == global_off:
# The previous interval ended at the final sample on that
# process. This means that the timestamp of the final sample
# was artificially truncated.
if cur[2] == 0 and join:
# The first interval on this process starts at sample zero,
# and we are joining intervals across the process boundary.
prev[1] = cur[1]
prev[3] = cur[3]
continue
else:
# We are keeping any break between processes, but use
# the first start time on this process as the stop time of
# the interval from the previous process.
prev[1] = pstrt
glist.append((prev[0], prev[1]))
prev = cur
last_pn = pn
global_off += pn

# Add final interval
if prev is not None:
glist.append((prev[0], prev[1]))
glist.append((start, stop))
last_start = start
last_stop = stop

return glist


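The rewritten gather exchanges plain (start, stop) tuples and drops spans that appear on two neighboring processes. A minimal sketch of the deduplication step, with hypothetical per-process lists in which the middle span is shared across a process boundary:

    import numpy as np

    all_ilist = [
        [(0.0, 10.0), (12.0, 20.0)],
        [(12.0, 20.0), (25.0, 30.0)],
    ]

    glist = []
    last_start = 0
    last_stop = 0
    for pdata in all_ilist:
        for start, stop in pdata:
            # Avoid adding the same time span twice
            if np.isclose(start, last_start, rtol=1e-12) and np.isclose(
                stop, last_stop, rtol=1e-12
            ):
                continue
            glist.append((start, stop))
            last_start = start
            last_stop = stop

    print(glist)  # [(0.0, 10.0), (12.0, 20.0), (25.0, 30.0)]
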
src/toast/ops/hwpss_model.py (5 additions, 2 deletions)

@@ -690,8 +690,11 @@ def _cut_outliers(self, obs, det_mag):
all_dets = dets
all_mag = mag
else:
all_dets = flatten(obs.comm_col.gather(dets, root=0))
all_mag = np.array(flatten(obs.comm_col.gather(mag, root=0)))
all_dets = obs.comm_col.gather(dets, root=0)
all_mag = obs.comm_col.gather(mag, root=0)
if obs.comm_col.rank == 0:
all_dets = list(flatten(all_dets))
all_mag = np.array(list(flatten(all_mag)))

# One process does the trivial calculation
all_flags = None
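The change above matters because comm.gather(..., root=0) returns None on every rank except the root, so flattening has to wait until after the gather and run only on rank zero. A sketch without MPI, with a hypothetical one-level flatten() standing in for toast's helper:

    from itertools import chain

    import numpy as np

    def flatten(list_of_lists):
        # Hypothetical one-level flatten, standing in for toast's helper
        return chain.from_iterable(list_of_lists)

    # What gather(..., root=0) might deliver on the root rank
    all_dets = [["d00", "d01"], ["d10"], ["d11", "d12"]]
    all_mag = [[1.0, 2.0], [3.0], [4.0, 5.0]]

    # Only the root, where the gathered objects are not None, flattens
    all_dets = list(flatten(all_dets))
    all_mag = np.array(list(flatten(all_mag)))
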
src/toast/ops/pixels_wcs.py (31 additions, 16 deletions)

@@ -23,6 +23,18 @@
from .operator import Operator


def unwrap_together(x, y, period=2 * np.pi * u.rad):
"""Unwrap x in place, applying the same branch corrections to y."""
for i in range(1, len(x)):
# Shift x[i] by whole periods toward x[i - 1], keeping y[i] on
# the same branch as x[i]
while np.abs(x[i] - x[i - 1]) > np.abs(x[i] + period - x[i - 1]):
x[i] += period
y[i] += period
while np.abs(x[i] - x[i - 1]) > np.abs(x[i] - period - x[i - 1]):
x[i] -= period
y[i] -= period
return


@trait_docs
class PixelsWCS(Operator):
"""Operator which generates detector pixel indices defined on a flat projection.
@@ -439,26 +451,29 @@ def _exec(self, data, detectors=None, **kwargs):
)
minmax = minmax.T
# Compact observations on both sides of the zero meridian
# can confuse this calculation. Use np.unwrap() to find
# the most compact longitude range.
lonmin = np.amin(np.unwrap(minmax[0]))
lonmax = np.amax(np.unwrap(minmax[1]))
if lonmax < lonmin:
lonmax += 2 * np.pi
# can confuse this calculation. We must unwrap the per-observation
# limits for the most compact longitude range.
unwrap_together(minmax[0], minmax[1])
lonmin = np.amin(minmax[0])
lonmax = np.amax(minmax[1])
latmin = np.amin(minmax[2])
latmax = np.amax(minmax[3])
if data.comm.comm_world is not None:
# Zero meridian concern applies across processes
all_lonmin = data.comm.comm_world.allgather(lonmin.to_value(u.radian))
all_lonmax = data.comm.comm_world.allgather(lonmax.to_value(u.radian))
all_latmin = data.comm.comm_world.allgather(latmin.to_value(u.radian))
all_latmax = data.comm.comm_world.allgather(latmax.to_value(u.radian))
lonmin = np.amin(np.unwrap(all_lonmin)) * u.radian
lonmax = np.amax(np.unwrap(all_lonmax)) * u.radian
if lonmax < lonmin:
lonmax += 2 * np.pi
latmin = np.amin(all_latmin) * u.radian
latmax = np.amax(all_latmax) * u.radian
def gather(x):
return (
data.comm.comm_world.allgather(x.to_value(u.radian)) * u.radian
)

all_lonmin = gather(lonmin)
all_lonmax = gather(lonmax)
all_latmin = gather(latmin)
all_latmax = gather(latmax)
unwrap_together(all_lonmin, all_lonmax)
lonmin = np.amin(all_lonmin)
lonmax = np.amax(all_lonmax)
latmin = np.amin(all_latmin)
latmax = np.amax(all_latmax)
self.bounds = (
lonmin.to(u.degree),
lonmax.to(u.degree),
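A quick check of unwrap_together with unitless radians (dropping the astropy units for brevity): an observation sitting just below the zero meridian is shifted onto the same branch as one just above it, so the gathered bounds stay compact:

    import numpy as np

    def unwrap_together(x, y, period=2 * np.pi):
        # Same logic as the helper above, for plain float arrays
        for i in range(1, len(x)):
            while np.abs(x[i] - x[i - 1]) > np.abs(x[i] + period - x[i - 1]):
                x[i] += period
                y[i] += period
            while np.abs(x[i] - x[i - 1]) > np.abs(x[i] - period - x[i - 1]):
                x[i] -= period
                y[i] -= period

    lonmin = np.array([0.1, 6.2])  # second observation just below 2*pi
    lonmax = np.array([0.3, 6.4])
    unwrap_together(lonmin, lonmax)
    print(lonmin.min(), lonmax.max())  # ~-0.083 to 0.3, not 0.1 to 6.4
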
src/toast/spt3g/spt3g_export.py (15 additions, 2 deletions)

@@ -10,7 +10,7 @@
from astropy import units as u

from ..instrument import GroundSite, SpaceSite
from ..intervals import IntervalList
from ..intervals import IntervalList, interval_dtype
from ..timing import function_timer
from ..utils import Environment, Logger, object_fullname
from .spt3g_utils import (
@@ -167,7 +167,20 @@ def export_intervals(obs, name, iframe):
(G3Object): The best container available.

"""
overlap = iframe & obs.intervals[name]
overlap = []
frame = iframe.data[0]
for ival in obs.intervals[name]:
if frame.start <= ival.stop and frame.stop >= ival.start:
overlap.append((
ival.start,
ival.stop,
max(frame.first, ival.first),
min(frame.last, ival.last),
))
overlap = IntervalList(
iframe.timestamps,
intervals=np.array(overlap, dtype=interval_dtype).view(np.recarray),
)

out = None
try:
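The replacement overlap logic keeps any interval whose time span touches the frame and clamps the sample indices to the frame's range. A sketch of a single comparison with hypothetical values:

    # Hypothetical frame covering samples 100-199, and a science interval
    # that extends past the end of the frame
    frame_start, frame_stop, frame_first, frame_last = 10.0, 20.0, 100, 199
    ival_start, ival_stop, ival_first, ival_last = 15.0, 30.0, 150, 299

    if frame_start <= ival_stop and frame_stop >= ival_start:
        clipped = (
            ival_start,
            ival_stop,
            max(frame_first, ival_first),  # 150
            min(frame_last, ival_last),    # 199
        )
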
src/toast/tests/helpers/ground.py (6 additions, 2 deletions)

@@ -93,6 +93,7 @@ def create_ground_data(
turnarounds_invalid=False,
single_group=False,
flagged_pixels=True,
schedule_hours=2,
):
"""Create a data object with a simple ground sim.

@@ -137,6 +138,7 @@
if tdir is None:
tdir = tempfile.mkdtemp()

sch_hours = f"{int(schedule_hours):02d}"
sch_file = os.path.join(tdir, "ground_schedule.txt")
run_scheduler(
opts=[
@@ -155,7 +157,7 @@
"--start",
"2020-01-01 00:00:00",
"--stop",
"2020-01-01 06:00:00",
f"2020-01-01 {sch_hours}:00:00",
"--out",
sch_file,
]
@@ -218,6 +220,7 @@ def create_overdistributed_data(
freqs=None,
turnarounds_invalid=False,
single_group=False,
schedule_hours=2,
):
"""Create a data object with more detectors than processes.

@@ -264,6 +267,7 @@
if tdir is None:
tdir = tempfile.mkdtemp()

sch_hours = f"{int(schedule_hours):02d}"
sch_file = os.path.join(tdir, "ground_schedule.txt")
run_scheduler(
opts=[
@@ -282,7 +286,7 @@
"--start",
"2020-01-01 00:00:00",
"--stop",
"2020-01-01 06:00:00",
f"2020-01-01 {sch_hours}:00:00",
"--out",
sch_file,
]
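With the new schedule_hours argument the simulated schedule defaults to two hours instead of six, and tests needing longer schedules can opt in. A hypothetical call (mpicomm stands for the communicator the helper normally takes); note that only the hour field of the stop time is substituted, so the value must stay below 24:

    # Hypothetical usage in a test case
    data = create_ground_data(mpicomm, schedule_hours=4)
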
src/toast/tests/ops_hwpss_model.py (4 additions, 4 deletions)

@@ -106,7 +106,7 @@ def create_test_data(self, testdir):
weights,
skyfile,
"input_sky_dist",
map_key="input_sky",
map_key=map_key,
fwhm=30.0 * u.arcmin,
lmax=3 * pixels.nside,
I_scale=0.001,
@@ -274,7 +274,7 @@ def test_hwpss_basic(self):
os.makedirs(testdir)

data, tod_rms, coeff = self.create_test_data(testdir)
n_harmonics = len(coeff) // 4
n_harmonics = len(coeff[data.obs[0].name]) // 4

# Add random DC level
for ob in data.obs:
@@ -331,7 +331,7 @@ def test_hwpss_relcal_fixed(self):
os.makedirs(testdir)

data, tod_rms, coeff = self.create_test_data(testdir)
n_harmonics = len(coeff) // 4
n_harmonics = len(coeff[data.obs[0].name]) // 4

# Apply a random inverse relative calibration
np.random.seed(123456)
@@ -400,7 +400,7 @@ def test_hwpss_relcal_continuous(self):
os.makedirs(testdir)

data, tod_rms, coeff = self.create_test_data(testdir)
n_harmonics = len(coeff) // 4
n_harmonics = len(coeff[data.obs[0].name]) // 4

# Apply a random inverse relative calibration that is time-varying
np.random.seed(123456)
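These test updates reflect that create_test_data now returns the template coefficients keyed by observation name, so the harmonic count comes from the first observation's entry (four coefficients per harmonic, per the division above). A sketch, assuming a flat per-observation coefficient array:

    import numpy as np

    # Hypothetical coefficients for one observation: four per harmonic
    coeff = {"obs_0": np.zeros(16)}
    n_harmonics = len(coeff["obs_0"]) // 4  # -> 4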