diff --git a/cosipy/background_estimation/NFBackground.py b/cosipy/background_estimation/NFBackground.py new file mode 100644 index 000000000..5187e8de3 --- /dev/null +++ b/cosipy/background_estimation/NFBackground.py @@ -0,0 +1,90 @@ +from typing import List, Union, Optional, Dict +from pathlib import Path + +from cosipy import SpacecraftHistory + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +import torch.multiprocessing as mp +from cosipy.response.NFBase import NFBase, CompileMode, update_density_worker_settings, init_density_worker, DensityApproximation, DensityModel, RateModel +from .NFBackgroundModels import TotalBackgroundDensityCMLPDGaussianCARQSFlow, TotalDC4BackgroundRate + + +class BackgroundDensityApproximation(DensityApproximation): + + def _setup_model(self): + version_map: Dict[int, DensityModel] = { + 1: TotalBackgroundDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Density Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + self._expected_source_dim = self._model.source_dim + +class BackgroundRateApproximation: + def __init__(self, major_version: int, rate_input: Dict): + self._major_version = major_version + self._rate_input = rate_input + + self._setup_model() + + def _setup_model(self): + version_map: Dict[int, RateModel] = { + 1: TotalDC4BackgroundRate(self._rate_input), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Rate Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + + def evaluate_rate(self, context: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.evaluate_rate(*list_context) + +def init_background_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, + density_input: Dict, density_batch_size: int, + density_compile_mode: CompileMode): + + init_density_worker(device_queue, progress_queue, major_version, + density_input, density_batch_size, + density_compile_mode, BackgroundDensityApproximation) + +class NFBackground(NFBase): + def __init__(self, path_to_model: Union[str, Path], density_batch_size: int = 100_000, + devices: Optional[List[Union[str, int, torch.device]]] = None, + density_compile_mode: CompileMode = "default", show_progress: bool = True): + + super().__init__(path_to_model, update_density_worker_settings, init_background_worker, density_batch_size, devices, density_compile_mode, ['rate_input'], show_progress) + + self._rate_approximation = BackgroundRateApproximation(self._major_version, self._ckpt['rate_input']) + + self._update_pool_arguments() + + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_density_input", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_density_compile_mode", None), + ] + + def evaluate_rate(self, context: torch.Tensor) -> torch.Tensor: + return self._rate_approximation.evaluate_rate(context) + \ No newline at end of file diff --git a/cosipy/background_estimation/NFBackgroundModels.py b/cosipy/background_estimation/NFBackgroundModels.py new file mode 100644 index 000000000..0f1cd8314 --- /dev/null +++ b/cosipy/background_estimation/NFBackgroundModels.py @@ -0,0 +1,225 @@ +import numpy as np + +from typing import Union, Tuple, Dict + + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + + +from cosipy.response.NFBase import CompileMode, build_c_arqs_flow, build_cmlp_diaggaussian_base, NNDensityInferenceWrapper, DensityModel, RateModel +import normflows as nf +import torch + + +class TotalDC4BackgroundRate(RateModel): + @property + def context_dim(self) -> int: + return 1 + + def _unpack_rate_input(self, rate_input: Dict): + self._slew_duration = rate_input["slew_duration"] + self._obs_duration = rate_input["obs_duration"] + self._start_time = rate_input["start_time"] + + self._offset: float = rate_input["offset"] + self._slope: float = rate_input["slope"] + self._buildup_A: Tuple[float, float] = rate_input["buildup"][0] + self._buildup_T: Tuple[float, float] = rate_input["buildup"][1] + self._scale: float = rate_input["scale"] + self._cutoff_T: float = rate_input["cutoff"][0] + self._cutoff_A: Tuple[float, float, float] = rate_input["cutoff"][1] + self._cutoff_kappa: Tuple[float, float, float] = rate_input["cutoff"][2] + self._cutoff_mu: Tuple[float, float, float] = rate_input["cutoff"][3] + self._outlocs: torch.Tensor = rate_input["outlocs"] + self._saa_decay_A: Tuple[float, float] = rate_input["saa_decay"][0] + self._saa_decay_T: Tuple[float, float] = rate_input["saa_decay"][1] + + @staticmethod + def _buildup(time: torch.Tensor, A: float, T: float) -> torch.Tensor: + return A * (1 - torch.exp(-time * np.log(2) / T)) + + def _pointing_scale(self, time: torch.Tensor, scale: float, k0: float=1.0) -> torch.Tensor: + half_slew = self._slew_duration / 2.0 + full_cycle = 2 * (self._obs_duration + self._slew_duration) + rel_t = time % full_cycle + k = k0 / self._slew_duration + + t1 = self._obs_duration + half_slew + t2 = full_cycle - half_slew + + s1 = 1 / (1 + torch.exp(-k * (rel_t - t1))) + s2 = 1 / (1 + torch.exp(-k * (rel_t - t2))) + s0 = 1 / (1 + torch.exp(-k * (rel_t - (t2 - full_cycle)))) + + return scale * (s0 - s1 + s2) + + @staticmethod + def _von_mises(time: torch.Tensor, T: float, A: float, kappa: float, mu: float) -> torch.Tensor: + return A * torch.exp(kappa * torch.cos(2 * np.pi * (time - mu) / T)) + + def _base_cutoff(self, time, T: float, A: Tuple[float, float, float], + kappa: Tuple[float, float, float], mu: Tuple[float, float, float]) -> torch.Tensor: + return self._von_mises(time, T, A[0], kappa[0], mu[0]) + \ + self._von_mises(time, T, A[1], kappa[1], mu[1]) + \ + self._von_mises(time, T, A[2], kappa[2], mu[2]) + + def _orbital_period(self, time, scale: float, T: float, A: Tuple[float, float, float], + kappa: Tuple[float, float, float], mu: Tuple[float, float, float]) -> torch.Tensor: + sample_times = torch.linspace(0, T, 1000) + fmin = torch.min(self._base_cutoff(sample_times, T, A, kappa, mu)) + + fval = self._base_cutoff(time, T, A, kappa, mu) + + return fmin + (fval - fmin) * (1 + scale) + + @staticmethod + def _decay(time: torch.Tensor, A: float, T: float) -> torch.Tensor: + return A * torch.exp(-time * np.log(2) / T) + + def _saa_decay(self, time: torch.Tensor, A: Tuple[float, float], T: Tuple[float, float]) -> torch.Tensor: + exit_times = (self._outlocs - self._start_time)/60 + last_exit = exit_times[torch.searchsorted(exit_times, time, right=True) - 1] + + return self._decay(time - last_exit, A[0], T[0]) + self._decay(time - last_exit, A[1], T[1]) + + def evaluate_rate(self, *args: torch.Tensor) -> torch.Tensor: + time = (args[0] - self._start_time)/60 + rate = self._offset + self._slope * time + rate += self._buildup(time, self._buildup_A[0], self._buildup_T[0]) + rate += self._buildup(time, self._buildup_A[1], self._buildup_T[1]) + rate += self._orbital_period(time, self._pointing_scale(time, self._scale), + self._cutoff_T, self._cutoff_A, + self._cutoff_kappa, self._cutoff_mu) + rate += self._saa_decay(time, self._saa_decay_A, self._saa_decay_T) + + return rate + +class TotalBackgroundDensityCMLPDGaussianCARQSFlow(DensityModel): + def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "default"): + super().__init__(compile_mode, batch_size, worker_device, density_input) + + def _init_model(self, input: Dict): + self._snapshot = input["model_state_dict"] + self._bins = input["bins"] + self._hidden_units = input["hidden_units"] + self._residual_blocks = input["residual_blocks"] + self._total_layers = input["total_layers"] + self._context_size = input["context_size"] + self._mlp_hidden_units = input["mlp_hidden_units"] + self._mlp_hidden_layers = input["mlp_hidden_layers"] + self._menergy_cuts = input["menergy_cuts"] + self._phi_cuts = input["phi_cuts"] + + self._start_time: float = input["start_time"] + self._total_time: float = input["total_time"] + self._period: float = input["period"] + self._slew_duration: float = input["slew_duration"] + self._obs_duration: float = input["obs_duration"] + self._outlocs: torch.Tensor = input["outlocs"].to(self._worker_device) + + return self._load_model() + + @property + def context_dim(self) -> int: + return 1 + + @property + def source_dim(self) -> int: + return 4 + + def _build_model(self) -> nf.ConditionalNormalizingFlow: + base = build_cmlp_diaggaussian_base( + self._context_size, 2 * self.source_dim, self._mlp_hidden_units, self._mlp_hidden_layers + ) + return build_c_arqs_flow( + base, self._total_layers, self.source_dim, self._context_size, self._bins, self._hidden_units, self._residual_blocks + ) + + def _load_model(self) -> NNDensityInferenceWrapper: + model = self._build_model() + + model.load_state_dict(self._snapshot) + model = NNDensityInferenceWrapper(model) + model.eval() + model.to(self._worker_device) + + return model + + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: + nem, nphi, npsi, nchi, _ = args + + em = 10 ** (2 * (nem + 1)) + phi = nphi * np.pi + az = npsi * 2 * np.pi + pol = torch.acos(2 * nchi - 1) + + return torch.stack([em, phi, az, pol], dim=1) + + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + time, em, phi, scatt_az, scatt_pol = args + + jac = 1/(np.log(10) * em * 8*np.pi**2) + + ctx = self._transform_context(time) + + src = torch.cat([ + (torch.log10(em)/2 - 1).unsqueeze(1), + (phi / np.pi).unsqueeze(1), + (scatt_az / (2 * np.pi)).unsqueeze(1), + ((torch.cos(scatt_pol) + 1) / 2).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) + + def _sigmoid_switch(self, t: torch.Tensor, k0: float=1.0) -> torch.Tensor: + half_slew = self._slew_duration / 2.0 + full_cycle = 2 * (self._obs_duration + self._slew_duration) + rel_t = (t - self._start_time) % full_cycle + k = k0 / self._slew_duration + + t1 = self._obs_duration + half_slew + t2 = full_cycle - half_slew + + s1 = 1 / (1 + torch.exp(-k * (rel_t - t1))) + s2 = 1 / (1 + torch.exp(-k * (rel_t - t2))) + s0 = 1 / (1 + torch.exp(-k * (rel_t - (t2 - full_cycle)))) + + return s0 - s1 + s2 + + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: + time = args[0] + + last_exits = self._outlocs[torch.searchsorted(self._outlocs, time, right=True) - 1] + time_since_start = (time - self._start_time)/self._total_time + pointing_phase = self._sigmoid_switch(time, k0 = 1.0) + time_since_saa = (time - last_exits)/self._period + phase_c = (torch.cos((time - self._start_time)/self._period * 2 * np.pi) + 1) / 2 + phase_s = (torch.sin((time - self._start_time)/self._period * 2 * np.pi) + 1) / 2 + + ctx = torch.hstack([ + (time_since_start).unsqueeze(1), + (pointing_phase).unsqueeze(1), + (time_since_saa).unsqueeze(1), + (phase_c).unsqueeze(1), + (phase_s).unsqueeze(1) + ]) + + return ctx.to(torch.float32) + + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: + nem, nphi, npsi, nchi, _ = args + + valid_mask = (nem >= 0.0) & \ + (nphi > 0.0) & (nphi <= 1.0) & \ + (npsi >= 0.0) & (npsi <= 1.0) & \ + (nchi >= 0.0) & (nchi <= 1.0) & \ + (nem >= (np.log10(self._menergy_cuts[0])/2 - 1)) & \ + (nem <= (np.log10(self._menergy_cuts[1])/2 - 1)) & \ + (nphi >= self._phi_cuts[0]/np.pi) & \ + (nphi <= self._phi_cuts[1]/np.pi) + + return valid_mask diff --git a/cosipy/background_estimation/nf_unbinned_background.py b/cosipy/background_estimation/nf_unbinned_background.py new file mode 100644 index 000000000..b581c5bba --- /dev/null +++ b/cosipy/background_estimation/nf_unbinned_background.py @@ -0,0 +1,118 @@ +from typing import Dict, Iterable, Type, Optional + +from astropy import units as u +import numpy as np + +from cosipy import SpacecraftHistory +from cosipy.interfaces.event import EventInterface +from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface +from cosipy.data_io.EmCDSUnbinnedData import TimeTagEmCDSEventInSCFrameInterface +from cosipy.interfaces.background_interface import BackgroundDensityInterface +from cosipy.util.iterables import asarray + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +from cosipy.background_estimation.NFBackground import NFBackground +import torch + +class FreeNormNFUnbinnedBackground(BackgroundDensityInterface): + + def __init__(self, + model: NFBackground, + data: TimeTagEmCDSEventDataInSCFrameInterface, + sc_history: SpacecraftHistory, + label: str = "bkg_norm"): + + self._expected_counts = None + self._expectation_density = None + self._model = model + self._data = data + self._sc_history = sc_history + + self._accum_livetime = self._sc_history.cumulative_livetime().to_value(u.s) + + self._norm = 1 + self._label = label + self._offset: Optional[float] = 1e-12 + + @property + def event_type(self) -> Type[EventInterface]: + return TimeTagEmCDSEventInSCFrameInterface + + @property + def offset(self) -> Optional[float]: + return self._offset + + @offset.setter + def offset(self, offset: Optional[float]): + if (offset is not None) and (offset < 0): + raise ValueError("The offset cannot be negative.") + self._offset = offset + + @property + def norm(self) -> u.Quantity: + self._update_cache(counts_only=True) + return u.Quantity(self._norm * self._expected_counts/self._accum_livetime, u.Hz) + + @norm.setter + def norm(self, norm: u.Quantity): + self._update_cache(counts_only=True) + self._norm = norm.to_value(u.Hz) * self._accum_livetime/self._expected_counts + + def set_parameters(self, **parameters: u.Quantity) -> None: + self.norm = parameters[self._label] + + @property + def parameters(self) -> Dict[str, u.Quantity]: + return {self._label: self.norm} + + def _integrate_rate(self) -> float: + mid_times = torch.as_tensor((self._sc_history.obstime[:-1] + (self._sc_history.obstime[1:] - self._sc_history.obstime[:-1]) / 2).utc.unix).view(-1, 1) + rate = self._model.evaluate_rate(mid_times) + return torch.sum(rate * torch.as_tensor(self._sc_history.livetime)).item() + + def _compute_density(self): + self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) + self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) + self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) + self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) + source = torch.stack((self._energy_m_keV, self._phi_rad, self._lon_scatt, np.pi/2 - self._lat_scatt), dim=1) + + time = torch.as_tensor(self._data.time.utc.unix).view(-1, 1) + if torch.any((time < self._sc_history.tstart.utc.unix) | (time > self._sc_history.tstop.utc.unix)): + raise ValueError("Input times are outside the spacecraft history range") + interval_ratios = torch.as_tensor(self._sc_history.livetime.to_value(u.s) / self._sc_history.intervals_duration.to_value(u.s)) + factor = torch.searchsorted(torch.as_tensor(self._sc_history.obstime.utc.unix), time.view(-1), right=True) - 1 + + return np.asarray(self._model.evaluate_density(time, source) * self._model.evaluate_rate(time) * interval_ratios[factor], dtype=np.float64) + + def _update_cache(self, counts_only=False): + if self._expected_counts is None: + self._expected_counts = self._integrate_rate() + + if (self._expectation_density is None) and (not counts_only): + active_pool = self._model.active_pool + if not active_pool: + self._model.init_compute_pool() + self._expectation_density = self._compute_density() + if not active_pool: + self._model.shutdown_compute_pool() + + def expected_counts(self) -> float: + self._update_cache() + + return self._expected_counts * self._norm + + def expectation_density(self) -> Iterable[float]: + self._update_cache() + + result = self._expectation_density * self._norm + + if self._offset is not None: + return result + self._offset + else: + return result + \ No newline at end of file diff --git a/cosipy/data_io/EmCDSUnbinnedData.py b/cosipy/data_io/EmCDSUnbinnedData.py index 3316ab570..1c0d3e7f0 100644 --- a/cosipy/data_io/EmCDSUnbinnedData.py +++ b/cosipy/data_io/EmCDSUnbinnedData.py @@ -8,7 +8,7 @@ from numpy._typing import ArrayLike from scoords import SpacecraftFrame -from cosipy import UnBinnedData +from cosipy.data_io.UnBinnedData import UnBinnedData from cosipy.interfaces import EventWithEnergyInterface, EventDataInterface, EventDataWithEnergyInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface, EmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface, \ diff --git a/cosipy/interfaces/expectation_interface.py b/cosipy/interfaces/expectation_interface.py index 0ee8be283..c01657923 100644 --- a/cosipy/interfaces/expectation_interface.py +++ b/cosipy/interfaces/expectation_interface.py @@ -117,7 +117,7 @@ def __init__(self, *expectations:Tuple[ExpectationDensityInterface, None], vecto Parameters ---------- expectations: Other ExpectationDensityInterface implementations - vectorize: It True (default), it will first cache all the individual expectations on numpy arrays, and then it will sum + vectorize: If True (default), it will first cache all the individual expectations on numpy arrays, and then it will sum them up using numpy's method. The output will also be a numpy. If False, it will query one element from each expectation object a time and sum them up. The output in this case is an Generator. """ diff --git a/cosipy/interfaces/source_response_interface.py b/cosipy/interfaces/source_response_interface.py index ba2e488b9..0984aea9d 100644 --- a/cosipy/interfaces/source_response_interface.py +++ b/cosipy/interfaces/source_response_interface.py @@ -1,6 +1,7 @@ -from typing import Protocol, runtime_checkable +from typing import Protocol, runtime_checkable, Union from astromodels import Model from astromodels.sources import Source +from pathlib import Path from .expectation_interface import BinnedExpectationInterface, ExpectationDensityInterface @@ -64,6 +65,28 @@ class UnbinnedThreeMLSourceResponseInterface(ThreeMLSourceResponseInterface, Exp No new methods. Just the inherited ones. """ +@runtime_checkable +class CachedUnbinnedThreeMLSourceResponseInterface(UnbinnedThreeMLSourceResponseInterface, Protocol): + """ + Guaranteeing that the source response can be cached to and loaded from a file. + """ + + def cache_to_file(self, filename: Union[str, Path]): + """ + Saves the calculated response cache to the specified HDF5 file. + The implementation has to make sure that the source is handled correctly. + """ + + def cache_from_file(self, filename: Union[str, Path]): + """Loads the response cache from the specified HDF5 file.""" + + def init_cache(self): + """ + Initialize the response cache that can be saved to file. + This way there is no need to call expected_counts() or expectation_density() to initialize the cache. + Make sure that repeated calls don't lead to unnecessary recomputations. + """ + @runtime_checkable class BinnedThreeMLSourceResponseInterface(ThreeMLSourceResponseInterface, BinnedExpectationInterface, Protocol): """ diff --git a/cosipy/response/NFBase.py b/cosipy/response/NFBase.py new file mode 100644 index 000000000..979437abe --- /dev/null +++ b/cosipy/response/NFBase.py @@ -0,0 +1,569 @@ +from typing import List, Union, Optional, Literal, Tuple, Dict, Optional, Callable +from pathlib import Path +from abc import ABC, abstractmethod +import numpy as np +from tqdm.auto import tqdm +import queue + + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +from torch import nn +import torch.multiprocessing as mp +import normflows as nf +import cosipy.response.NFWorkerState as NFWorkerState + + +CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] + +def build_cmlp_diaggaussian_base(input_dim: int, output_dim: int, + hidden_dim: int, num_hidden_layers: int) -> nf.distributions.BaseDistribution: + context_encoder = BaseMLP(input_dim, output_dim, hidden_dim, num_hidden_layers) + return ConditionalDiagGaussian(shape=output_dim//2, context_encoder=context_encoder) + +def build_c_arqs_flow(base: nf.distributions.BaseDistribution, num_layers: int, + latent_dim: int, context_dim: int, num_bins: int, + num_hidden_units: int, num_residual_blocks: int) -> nf.ConditionalNormalizingFlow: + flows = [] + for _ in range(num_layers): + flows += [nf.flows.AutoregressiveRationalQuadraticSpline(num_input_channels = latent_dim, + num_blocks = num_residual_blocks, + num_hidden_channels = num_hidden_units, + num_bins = num_bins, + num_context_channels = context_dim)] + flows += [nf.flows.LULinearPermute(latent_dim)] + return nf.ConditionalNormalizingFlow(base, flows) + +class NNDensityInferenceWrapper(nn.Module): + def __init__(self, model: nn.Module): + super().__init__() + self._model = model + + def forward(self, + source: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + n_samples: Optional[int] = None, + mode: str = "inference") -> torch.Tensor: + if mode == "inference": + if context is None: + return torch.exp(self._model.log_prob(source)) + else: + return torch.exp(self._model.log_prob(source, context)) + elif mode == "sampling": + if context is None: + return self._model.sample(num_samples=n_samples)[0] + else: + return self._model.sample(num_samples=n_samples, context=context)[0] + +class ConditionalDiagGaussian(nf.distributions.BaseDistribution): + def __init__(self, shape: Union[int, List[int], Tuple[int, ...]], context_encoder: nn.Module): + super().__init__() + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, list): + shape = tuple(shape) + self.shape = shape + self.n_dim = len(shape) + self.d = np.prod(shape) + self.context_encoder = context_encoder + + def forward(self, num_samples: int=1, context: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, torch.Tensor]: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + eps = torch.randn( + (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device + ) + z = mean + torch.exp(log_scale) * eps + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1)) + ) + return z, log_p + + def log_prob(self, z: torch.Tensor, context: Optional[torch.Tensor]=None) -> torch.Tensor: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2), + list(range(1, self.n_dim + 1)), + ) + return log_p + +class BaseMLP(nn.Module): + def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int): + super().__init__() + + layers = [] + + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.ReLU()) + + for _ in range(num_hidden_layers): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.ReLU()) + + layers.append(nn.Linear(hidden_dim, output_dim)) + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class BaseModel(ABC): + + def __init__(self, compile_mode: CompileMode, batch_size: int, + worker_device: Union[str, int, torch.device], input: Dict): + self._worker_device = torch.device(worker_device) + + self._base_model = self._init_model(input) + + self._compile_mode = compile_mode + self._compiled_cache = {} + + self._update_model_op() + + self._is_cuda = (self._worker_device.type == 'cuda') + self.batch_size = batch_size + + @abstractmethod + def _init_model(self, input: Dict) -> Union[nn.Module, Callable]: ... + + @property + @abstractmethod + def context_dim(self) -> int: ... + + @property + def compile_mode(self) -> CompileMode: + return self._compile_mode + + @compile_mode.setter + def compile_mode(self, value: CompileMode): + if value != self._compile_mode: + self._compile_mode = value + self._update_model_op() + + def _update_model_op(self): + if self._compile_mode is None: + self._model_op = self._base_model + else: + if self._compile_mode not in self._compiled_cache: + self._compiled_cache[self._compile_mode] = torch.compile( + self._base_model, + mode=self._compile_mode + ) + self._model_op = self._compiled_cache[self._compile_mode] + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError(f"Batch size must be a positive integer, got {value}") + self._batch_size = value + +class AreaModel(BaseModel): + @abstractmethod + def evaluate_effective_area(self, *args: torch.Tensor, progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: ... + +class DensityModel(BaseModel): + @property + @abstractmethod + def source_dim(self) -> int: ... + + @torch.inference_mode() + def sample_density(self, *args: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + N = args[0].shape[0] + + result = torch.empty((N, self.source_dim), dtype=torch.float32, device="cpu") + failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") + + for start in range(0, N, self._batch_size): + end = min(start + self._batch_size, N) + batch_len = end - start + + b_ctx = [t[start:end].to(self._worker_device) for t in args] + n_ctx = self._transform_context(*b_ctx) + + n_latent = self._model_op(context=n_ctx, mode="sampling", n_samples=batch_len) + result[start:end] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) + failed_mask[start:end] = ~self._valid_samples(*(n_latent.T), *b_ctx) + + if progress_callback is not None: + amount = batch_len - torch.sum(failed_mask[start:end]).item() + progress_callback(amount) + + if torch.any(failed_mask): + result[failed_mask] = self.sample_density(*[t[failed_mask] for t in args], progress_callback=progress_callback) + + return result + + @abstractmethod + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @torch.inference_mode() + def evaluate_density(self, *args: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + + N = args[0].shape[0] + result = torch.empty(N, dtype=torch.float32, device="cpu") + + for start in range(0, N, self._batch_size): + end = min(start + self._batch_size, N) + batch_len = end - start + + ctx, src, jac = self._transform_coordinates(*[t[start:end].to(self._worker_device) for t in args]) + result[start:end] = self._model_op(src, ctx, mode="inference") * jac + + if progress_callback is not None: + progress_callback(batch_len) + + return result + +class RateModel(ABC): + def __init__(self, rate_input: Dict): + self._unpack_rate_input(rate_input) + + @property + @abstractmethod + def context_dim(self) -> int: ... + + @abstractmethod + def _unpack_rate_input(self, rate_input: Dict): ... + + @abstractmethod + def evaluate_rate(self, *args: torch.Tensor) -> torch.Tensor: ... + + +class DensityApproximation(ABC): + def __init__(self, major_version: int, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): + self._major_version = major_version + self._worker_device = worker_device + self._density_input = density_input + self._batch_size = batch_size + self._compile_mode = compile_mode + + self._model: DensityModel + self._expected_context_dim: int + self._expected_source_dim: int + + self._setup_model() + + @abstractmethod + def _setup_model(self): ... + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + dim_context = context.shape[1] + dim_source = source.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + elif dim_source != self._expected_source_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_source_dim} features, but source has {dim_source}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + list_source = [source[:, i] for i in range(dim_source)] + + return self._model.evaluate_density(*list_context, *list_source, progress_callback=progress_callback) + + def sample_density(self, context: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.sample_density(*list_context, progress_callback=progress_callback) + +def cuda_cleanup_task(_) -> bool: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return True + +def update_density_worker_settings(args: Tuple[str, Union[int, CompileMode]]): + attr, value = args + + if attr == 'density_batch_size': + NFWorkerState.density_module._model.batch_size = value + elif attr == 'density_compile_mode': + NFWorkerState.density_module._model.compile_mode = value + +def init_density_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, + density_input: Dict, density_batch_size: int, + density_compile_mode: CompileMode, density_approximation: DensityApproximation): + + NFWorkerState.worker_device = torch.device(device_queue.get()) + NFWorkerState.progress_queue = progress_queue + if NFWorkerState.worker_device.type == 'cuda': + torch.cuda.set_device(NFWorkerState.worker_device) + + NFWorkerState.density_module = density_approximation(major_version, density_input, NFWorkerState.worker_device, density_batch_size, density_compile_mode) + +def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, source, indices = args + + sub_context = context[indices, :] + sub_source = source[indices, :] + + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.density_module.evaluate_density(sub_context, sub_source, progress_callback=cb) + +def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, indices = args + + sub_context = context[indices, :] + + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.density_module.sample_density(sub_context, progress_callback=cb) + +class NFBase(): + def __init__(self, path_to_model: Union[str, Path], update_worker, pool_worker, density_batch_size: int = 100_000, + devices: Optional[List[Union[str, int, torch.device]]] = None, density_compile_mode: CompileMode = "default", + additional_required_keys: List[str] = None, show_progress: bool = True): + self._ckpt = torch.load(str(path_to_model), map_location=torch.device('cpu'), weights_only=False) + + required_keys = ['version', 'density_input'] + (additional_required_keys or []) + + for key in required_keys: + if key not in self._ckpt: + raise KeyError( + f"Invalid Checkpoint: Metadata key '{key}' not found in {str(path_to_model)}. " + f"Ensure you saved the model as a dictionary, not just the state_dict." + ) + + self._version = self._ckpt['version'] + self._major_version = int(self._version.split('.')[0]) + self._density_input = self._ckpt['density_input'] + + self._pool = None + self._has_cuda = False + self._num_workers = 0 + self._ctx = mp.get_context("spawn") + self._pool_worker = pool_worker + self._update_worker = update_worker + self.show_progress = show_progress + self._progress_queue = None + + self.density_batch_size = density_batch_size + self.density_compile_mode = density_compile_mode + + self._update_pool_arguments() + + if devices is not None: + self.devices = devices + else: + self._devices = [] + + def __del__(self): + self.shutdown_compute_pool() + + @property + def show_progress(self) -> bool: + return self._show_progress + + @show_progress.setter + def show_progress(self, value: bool): + if not isinstance(value, bool): + raise ValueError("show_progress must be a boolean") + self._show_progress = value + + @property + def devices(self) -> List[Union[str, int, torch.device]]: + return self._devices + + @devices.setter + def devices(self, value: List[Union[str, int, torch.device]]): + if not isinstance(value, list): + raise ValueError("devices must be a list of device identifiers") + self._devices = value + + @property + def density_batch_size(self) -> int: + return self._density_batch_size + + @density_batch_size.setter + def density_batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError("density_batch_size must be a positive integer") + self._density_batch_size = value + self._update_pool_arguments() + self._update_worker_config('density_batch_size', value) + + @property + def density_compile_mode(self) -> CompileMode: return self._density_compile_mode + + @density_compile_mode.setter + def density_compile_mode(self, value: CompileMode): + self._density_compile_mode = value + self._update_pool_arguments() + self._update_worker_config('density_compile_mode', value) + + @property + def active_pool(self) -> bool: return self._pool is not None + + def _update_worker_config(self, attr: str, value: Union[int, CompileMode]): + if self._pool is not None: + self._pool.map(self._update_worker, [(attr, value)] * self._num_workers) + + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_density_input", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_density_compile_mode", None), + ] + + def clean_compute_pool(self): + if self._pool: + self._pool.map(cuda_cleanup_task, range(self._num_workers)) + + def shutdown_compute_pool(self): + if self._pool: + self._pool.close() + self._pool.join() + + self._num_workers = 0 + self._pool = None + self._has_cuda = None + + self._progress_queue.close() + self._progress_queue.join_thread() + self._progress_queue = None + + def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): + active_devices = devices if devices is not None else self._devices + + if not active_devices: + raise RuntimeError("Cannot initialize pool: no devices provided as argument or set as fallback.") + + if self._pool: + print("Warning: Pool already initialized. Shutting down old pool first.") + self.shutdown_compute_pool() + + self._num_workers = len(active_devices) + self._has_cuda = any(torch.device(d).type == 'cuda' for d in active_devices) + + device_queue = self._ctx.Queue() + for d in active_devices: + device_queue.put(d) + self._progress_queue = self._ctx.Queue() + + self._pool = self._ctx.Pool( + processes=self._num_workers, + initializer=self._pool_worker, + initargs=(device_queue, self._progress_queue,*self._pool_arguments), + ) + + def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): + context.share_memory_() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, idx) for idx in indices] + + async_result = self._pool.map_async(sample_density_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Sampling the density", unit="calls", leave=False, smoothing=0.20) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): context.share_memory_() + if not source.is_shared(): source.share_memory_() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, source, idx) for idx in indices] + + async_result = self._pool.map_async(evaluate_density_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the density", unit="calls", leave=False, smoothing=0.20) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() diff --git a/cosipy/response/NFResponse.py b/cosipy/response/NFResponse.py new file mode 100644 index 000000000..12514b9ea --- /dev/null +++ b/cosipy/response/NFResponse.py @@ -0,0 +1,185 @@ +from typing import List, Union, Optional, Dict, Tuple, Callable +from pathlib import Path +from tqdm.auto import tqdm +import queue + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +import torch.multiprocessing as mp +from .NFResponseModels import UnpolarizedDensityCMLPDGaussianCARQSFlow, UnpolarizedAreaSphericalHarmonicsExpansion +from .NFBase import DensityApproximation, CompileMode, NFBase, init_density_worker, update_density_worker_settings, AreaModel, DensityModel +import cosipy.response.NFWorkerState as NFWorkerState + + +class ResponseDensityApproximation(DensityApproximation): + + def _setup_model(self): + version_map: Dict[int, DensityModel] = { + 1: UnpolarizedDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Density Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + self._expected_source_dim = self._model.source_dim + +class AreaApproximation: + def __init__(self, major_version: int, area_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): + self._major_version = major_version + self._worker_device = worker_device + self._area_input = area_input + self._batch_size = batch_size + self._compile_mode = compile_mode + + self._setup_model() + + def _setup_model(self): + version_map: Dict[int, AreaModel] = { + 1: UnpolarizedAreaSphericalHarmonicsExpansion(self._area_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Effective Area Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + + def evaluate_effective_area(self, context: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.evaluate_effective_area(*list_context, progress_callback=progress_callback) + +def update_response_worker_settings(args: Tuple[str, Union[int, CompileMode]]): + update_density_worker_settings(args) + + attr, value = args + + if attr == 'area_batch_size': + NFWorkerState.area_module._model.batch_size = value + elif attr == 'area_compile_mode': + NFWorkerState.area_module._model.compile_mode = value + +def init_response_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, area_input: Dict, + density_input: Dict, area_batch_size: int, density_batch_size: int, + area_compile_mode: CompileMode, density_compile_mode: CompileMode): + + init_density_worker(device_queue, progress_queue, major_version, + density_input, density_batch_size, + density_compile_mode, ResponseDensityApproximation) + + NFWorkerState.area_module = AreaApproximation(major_version, area_input, NFWorkerState.worker_device, area_batch_size, area_compile_mode) + +def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, indices = args + + sub_context = context[indices, :] + + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.area_module.evaluate_effective_area(sub_context, progress_callback=cb) + +class NFResponse(NFBase): + def __init__(self, path_to_model: Union[str, Path], area_batch_size: int = 300_000, density_batch_size: int = 100_000, + devices: Optional[List[Union[str, int, torch.device]]] = None, area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", + density_compile_mode: CompileMode = "default", show_progress: bool = True): + + super().__init__(path_to_model, update_response_worker_settings, init_response_worker, density_batch_size, devices, density_compile_mode, ['is_polarized', 'area_input'], show_progress) + + self._is_polarized = self._ckpt['is_polarized'] + self._area_input = self._ckpt['area_input'] + + self.area_batch_size = area_batch_size + self.area_compile_mode = area_compile_mode + + self._update_pool_arguments() + + @property + def is_polarized(self) -> bool: + return self._is_polarized + + @property + def area_batch_size(self) -> int: + return self._area_batch_size + + @area_batch_size.setter + def area_batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError("area_batch_size must be a positive integer") + self._area_batch_size = value + self._update_pool_arguments() + self._update_worker_config('area_batch_size', value) + + @property + def area_compile_mode(self) -> CompileMode: return self._area_compile_mode + + @area_compile_mode.setter + def area_compile_mode(self, value: CompileMode): + self._area_compile_mode = value + self._update_pool_arguments() + self._update_worker_config('area_compile_mode', value) + + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_area_input", None), + getattr(self, "_density_input", None), + getattr(self, "_area_batch_size", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_area_compile_mode", None), + getattr(self, "_density_compile_mode", None), + ] + + def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): + context.share_memory_() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, idx) for idx in indices] + + async_result = self._pool.map_async(evaluate_area_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the effective area", unit="calls", leave=False, smoothing=0.20) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() \ No newline at end of file diff --git a/cosipy/response/NFResponseModels.py b/cosipy/response/NFResponseModels.py new file mode 100644 index 000000000..1162f1f5c --- /dev/null +++ b/cosipy/response/NFResponseModels.py @@ -0,0 +1,267 @@ +import numpy as np +import healpy as hp + +from typing import Union, Tuple, Dict, Optional, Callable + + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows", "sphericart.torch"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +from .NFBase import CompileMode, build_c_arqs_flow, build_cmlp_diaggaussian_base, NNDensityInferenceWrapper, AreaModel, DensityModel +import sphericart.torch +import normflows as nf +import torch + + +class UnpolarizedAreaSphericalHarmonicsExpansion(AreaModel): + def __init__(self, area_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode, batch_size, worker_device, area_input) + + def _init_model(self, input: Dict): + self._lmax = input['lmax'] + self._poly_degree = input['poly_degree'] + self._poly_coeffs = input['poly_coeffs'] + + self._conv_coeffs = self._convert_coefficients().to(self._worker_device) + self._sh_calculator = sphericart.torch.SphericalHarmonics(self._lmax) + + return self._horner_eval + + @property + def context_dim(self) -> int: + return 3 + + def _convert_coefficients(self) -> torch.Tensor: + num_sh = (self._lmax + 1)**2 + conv_coeffs = torch.zeros((num_sh, self._poly_degree + 1), dtype=torch.float64) + + for cnt, (l, m) in enumerate((l, m) for l in range(self._lmax + 1) for m in range(-l, l + 1)): + idx = hp.Alm.getidx(self._lmax, l, abs(m)) + if m == 0: + conv_coeffs[cnt] = self._poly_coeffs[0, :, idx] + else: + fac = np.sqrt(2) * (-1)**m + val = self._poly_coeffs[0, :, idx] if m > 0 else -self._poly_coeffs[1, :, idx] + conv_coeffs[cnt] = fac * val + return conv_coeffs.T + + def _horner_eval(self, x: torch.Tensor) -> torch.Tensor: + x_64 = x.to(torch.float64).unsqueeze(1) + result = self._conv_coeffs[0].expand(x.shape[0], -1).clone() + for i in range(1, self._conv_coeffs.size(0)): + result.mul_(x_64).add_(self._conv_coeffs[i]) + return result.to(torch.float32) + + def _compute_spherical_harmonics(self, dir_az: torch.Tensor, dir_polar: torch.Tensor) -> torch.Tensor: + sin_p = torch.sin(dir_polar) + xyz = torch.stack(( + sin_p * torch.cos(dir_az), + sin_p * torch.sin(dir_az), + torch.cos(dir_polar) + ), dim=-1) + return self._sh_calculator(xyz) + + @torch.inference_mode() + def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: + N = energy_keV.shape[0] + + ei_norm = (torch.log10(energy_keV) / 2 - 1).to(torch.float32) + result = torch.empty(N, dtype=torch.float32, device="cpu") + + def get_batch(start_idx): + end_idx = min(start_idx + self._batch_size, N) + return ( + ei_norm[start_idx:end_idx].to(self._worker_device), + dir_az[start_idx:end_idx].to(self._worker_device), + dir_polar[start_idx:end_idx].to(self._worker_device) + ) + + for start in range(0, N, self._batch_size): + end = min(start + self._batch_size, N) + batch_len = end - start + + ei_b, az_b, pol_b = get_batch(start) + + poly_b = self._model_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + result[start:end] = torch.sum(poly_b * ylm_b, dim=1) + + if progress_callback is not None: + progress_callback(batch_len) + + return torch.clamp(result, min=0) + +class UnpolarizedDensityCMLPDGaussianCARQSFlow(DensityModel): + def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "default"): + super().__init__(compile_mode, batch_size, worker_device, density_input) + + def _init_model(self, input: Dict): + self._snapshot = input["model_state_dict"] + self._bins = input["bins"] + self._hidden_units = input["hidden_units"] + self._residual_blocks = input["residual_blocks"] + self._total_layers = input["total_layers"] + self._context_size = input["context_size"] + self._mlp_hidden_units = input["mlp_hidden_units"] + self._mlp_hidden_layers = input["mlp_hidden_layers"] + self._menergy_cuts = input["menergy_cuts"] + self._phi_cuts = input["phi_cuts"] + + return self._load_model() + + @property + def context_dim(self) -> int: + return 3 + + @property + def source_dim(self) -> int: + return 4 + + def _build_model(self) -> nf.ConditionalNormalizingFlow: + base = build_cmlp_diaggaussian_base( + self._context_size, 2 * self.source_dim, self._mlp_hidden_units, self._mlp_hidden_layers + ) + return build_c_arqs_flow( + base, self._total_layers, self.source_dim, self._context_size, self._bins, self._hidden_units, self._residual_blocks + ) + + def _load_model(self) -> NNDensityInferenceWrapper: + model = self._build_model() + + model.load_state_dict(self._snapshot) + model = NNDensityInferenceWrapper(model) + model.eval() + model.to(self._worker_device) + + return model + + @staticmethod + def _get_vector(phi_sc: torch.Tensor, theta_sc: torch.Tensor) -> torch.Tensor: + x = theta_sc[:, 0] * phi_sc[:, 1] + y = theta_sc[:, 0] * phi_sc[:, 0] + z = theta_sc[:, 1] + return torch.stack((x, y, z), dim=-1) + + def _convert_conventions(self, dir_az_sc: torch.Tensor, dir_polar_sc: torch.Tensor, + ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, + scatt_az_sc: torch.Tensor, scatt_polar_sc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + eps = em / ei - 1 + + source = self._get_vector(dir_az_sc, dir_polar_sc) + scatter = self._get_vector(scatt_az_sc, scatt_polar_sc) + + dot_product = torch.sum(source * scatter, dim=1) + phi_geo = torch.acos(torch.clamp(dot_product, -1.0, 1.0)) + theta = phi_geo - phi + + xaxis = torch.tensor([1., 0., 0.], device=source.device, dtype=source.dtype) + pz = -source + + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + proj_x = torch.sum(scatter * px, dim=1) + proj_y = torch.sum(scatter * py, dim=1) + + zeta = torch.atan2(proj_y, proj_x) + zeta = torch.where(zeta < 0, zeta + 2 * np.pi, zeta) + + return eps, theta, zeta + + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: + neps, nphi, ntheta, nzeta, dir_az, dir_pol, ei = args + + eps = -neps + phi = nphi * np.pi + theta = (ntheta - 0.5) * (2 * np.pi) + zeta = nzeta * (2 * np.pi) + + em = ei * (eps + 1) + + phi_geo = theta + phi + scatter_phf = self._get_vector(torch.stack((torch.sin(zeta), torch.cos(zeta)), dim=1), + torch.stack((torch.sin(np.pi - phi_geo), torch.cos(np.pi - phi_geo)), dim=1)) + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) + source_vec = self._get_vector(dir_az_sc, dir_pol_sc) + xaxis = torch.tensor([1., 0., 0.], device=self._worker_device, dtype=source_vec.dtype) + + pz = -source_vec + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + basis = torch.stack((px, py, pz), dim=2) + scatter_scf = torch.bmm(basis, scatter_phf.unsqueeze(-1)).squeeze(-1) + + psi_cds = torch.atan2(scatter_scf[:, 1], scatter_scf[:, 0]) + psi_cds = torch.where(psi_cds < 0, psi_cds + 2 * np.pi, psi_cds) + chi_cds = torch.acos(torch.clamp(scatter_scf[:, 2], -1.0, 1.0)) + + return torch.stack([em, phi, psi_cds, chi_cds], dim=1) + + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dir_az, dir_pol, ei, em, phi, scatt_az, scatt_pol = args + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) + scatt_az_sc = torch.stack((torch.sin(scatt_az), torch.cos(scatt_az)), dim=1) + scatt_pol_sc = torch.stack((torch.sin(scatt_pol), torch.cos(scatt_pol)), dim=1) + + eps_raw, theta_raw, zeta_raw = self._convert_conventions( + dir_az_sc, dir_pol_sc, ei, em, phi, scatt_az_sc, scatt_pol_sc + ) + + jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) + jac[torch.isinf(jac) | (jac < 0)] = 0.0 + + ctx = self._transform_context(dir_az, dir_pol, ei) + + src = torch.cat([ + (-eps_raw).unsqueeze(1), + (phi / np.pi).unsqueeze(1), + (theta_raw / (2 * np.pi) + 0.5).unsqueeze(1), + (zeta_raw / (2 * np.pi)).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) + + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: + dir_az, dir_pol, ei = args + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_c = torch.cos(dir_pol).unsqueeze(1) + + ctx = torch.cat([ + (dir_az_sc + 1) / 2, + (dir_pol_c + 1) / 2, + (torch.log10(ei) / 2 - 1).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32) + + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: + neps, nphi, ntheta, nzeta, _, _, ei = args + + phi_geo_norm = nphi + 2 * ntheta - 1.0 + valid_mask = (neps < 1.0) & \ + (nphi > 0.0) & (nphi <= 1.0) & \ + (ntheta >= 0.0) & (ntheta <= 1.0) & \ + (nzeta >= 0.0) & (nzeta <= 1.0) & \ + (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) & \ + (neps <= (1 - self._menergy_cuts[0]/ei)) & \ + (neps >= (1 - self._menergy_cuts[1]/ei)) & \ + (nphi >= self._phi_cuts[0]/np.pi) & \ + (nphi <= self._phi_cuts[1]/np.pi) + + return valid_mask diff --git a/cosipy/response/NFWorkerState.py b/cosipy/response/NFWorkerState.py new file mode 100644 index 000000000..6e28c0e83 --- /dev/null +++ b/cosipy/response/NFWorkerState.py @@ -0,0 +1,4 @@ +worker_device = None +density_module = None +area_module = None +progress_queue = None \ No newline at end of file diff --git a/cosipy/response/__init__.py b/cosipy/response/__init__.py index 271a8858c..ff9b83d0a 100644 --- a/cosipy/response/__init__.py +++ b/cosipy/response/__init__.py @@ -8,4 +8,4 @@ from .threeml_point_source_response import * from .threeml_extended_source_response import * from .instrument_response import * -from .rsp_to_arf_rmf import RspArfRmfConverter +from .rsp_to_arf_rmf import RspArfRmfConverter \ No newline at end of file diff --git a/cosipy/response/nf_instrument_response_function.py b/cosipy/response/nf_instrument_response_function.py new file mode 100644 index 000000000..01fc46aec --- /dev/null +++ b/cosipy/response/nf_instrument_response_function.py @@ -0,0 +1,81 @@ +from typing import Iterable, Optional, List, Union + +import numpy as np + +from cosipy.interfaces.data_interface import EmCDSEventDataInSCFrameInterface +from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface +from cosipy.interfaces.photon_parameters import PhotonListWithDirectionAndEnergyInSCFrameInterface +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays +from cosipy.response.NFResponse import NFResponse +from cosipy.util.iterables import asarray + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch + + +class UnpolarizedNFFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): + + event_data_type = EmCDSEventDataInSCFrameInterface + photon_list_type = PhotonListWithDirectionAndEnergyInSCFrameInterface + + def __init__(self, response: NFResponse,): + if response.is_polarized: + raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") + self._response = response + + def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): + self._response.init_compute_pool(devices) + + def shutdown_compute_pool(self): + self._response.shutdown_compute_pool() + + @property + def active_pool(self) -> bool: return self._response.active_pool + + @staticmethod + def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): + lon = torch.as_tensor(asarray(photons.direction_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(photons.direction_lat_rad_sc, dtype=np.float32)) + en = torch.as_tensor(asarray(photons.energy_keV, dtype=np.float32)) + + lat = -lat + (np.pi / 2) + return torch.stack([lon, lat, en], dim=1) + + @staticmethod + def _get_source(events: EmCDSEventDataInSCFrameInterface): + lon = torch.as_tensor(asarray(events.scattered_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(events.scattered_lat_rad_sc, dtype=np.float32)) + phi = torch.as_tensor(asarray(events.scattering_angle_rad, dtype=np.float32)) + en = torch.as_tensor(asarray(events.energy_keV, dtype=np.float32)) + + lat = -lat + (np.pi / 2) + return torch.stack([en, phi, lon, lat], dim=1) + + def _effective_area_cm2(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> Iterable[float]: + context = self._get_context(photons) + + return np.asarray(self._response.evaluate_effective_area(context)) + + def _event_probability(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface, events: EmCDSEventDataInSCFrameInterface) -> Iterable[float]: + source = self._get_source(events) + context = self._get_context(photons) + + return np.asarray(self._response.evaluate_density(context, source)) + + def _random_events(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> EmCDSEventDataInSCFrameInterface: + context = self._get_context(photons) + samples = self._response.sample_density(context) + samples[:, 3].mul_(-1).add_(np.pi/2) + samples = np.asarray(samples) + + return EmCDSEventDataInSCFrameFromArrays( + samples[:, 0], # Energy + samples[:, 2], # Lon + samples[:, 3], # Lat + samples[:, 1] # Phi + ) diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py new file mode 100644 index 000000000..3fcf7e070 --- /dev/null +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -0,0 +1,1019 @@ +import copy +import os +import json +from typing import Optional, Iterable, Type, Tuple, List, Union +from pathlib import Path +from tqdm.auto import tqdm + +import numpy as np +import h5py +from astromodels import PointSource +from astropy.coordinates import CartesianRepresentation +from executing import Source +from scoords import SpacecraftFrame + +from cosipy import SpacecraftHistory +from cosipy.interfaces.source_response_interface import CachedUnbinnedThreeMLSourceResponseInterface +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays +from cosipy.interfaces import EventInterface +from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface +from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface +from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface +from cosipy.response.photon_types import PhotonListWithDirectionAndEnergyInSCFrame +from cosipy.util.iterables import asarray + +from astropy import units as u +import astropy.constants as c +from astropy.coordinates import SkyCoord +from astropy.time import Time + +import logging + +logger = logging.getLogger(__name__) + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +from cosipy.response.nf_instrument_response_function import UnpolarizedNFFarFieldInstrumentResponseFunction + + +class UnbinnedThreeMLPointSourceResponseIRFAdaptive(CachedUnbinnedThreeMLSourceResponseInterface): + + def __init__(self, + data: TimeTagEmCDSEventDataInSCFrameInterface, + irf: FarFieldSpectralInstrumentResponseFunctionInterface, + sc_history: SpacecraftHistory, + show_progress: bool = True, + force_energy_node_caching: bool = False, + reduce_memory: bool = True): + + """ + Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features + Note that this assumes a smooth flux spectrum + + All IRF queries are cached and can be saved to / loaded from a file + """ + + # Interface inputs + self._source = None + + # Other implementation inputs + self._data = data + self._irf = irf + self._sc_ori = sc_history + self.show_progress = show_progress + self.force_energy_node_caching = force_energy_node_caching + + # Default parameters for irf energy node placement + self._total_energy_nodes = (60, 500) + self._peak_nodes = (18, 12) + self._peak_widths = (0.04, 0.1) + self._energy_range = (100., 10_000.) + self._cache_batch_size = 1_000_000 + self._integration_batch_size = 1_000_000 + self._offset: Optional[float] = 1e-12 + + # Placeholder for node pool - stored as Tensors + self._width_tensor: Optional[torch.Tensor] = None + self._nodes_primary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_secondary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + self._nodes_bkg_1: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_bkg_2: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + self._nodes_bkg_3: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + + # Checks to avoid unecessary recomputations + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + # Cached values + self._irf_cache: Optional[torch.Tensor] = None # cm^2/rad/sr + self._irf_energy_node_cache: Optional[np.ndarray] = None # (Optional, only if full batch) + self._area_cache: Optional[np.ndarray] = None # cm^2*s*keV + self._area_energy_node_cache: Optional[np.ndarray] = None + self._exp_events: Optional[float] = None + self._exp_density: Optional[torch.Tensor] = None + + # Precomputed spacecraft history - Midpoint + self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 + self._sc_ori_center = self._sc_ori.interp(self._mid_times) + + # Precomputed spacecraft history - Simpson + # t_edges = self._sc_ori.obstime + # t_mids = t_edges[:-1] + (t_edges[1:] - t_edges[:-1]) / 2 + # all_t, inv_indices = np.unique(Time([t_edges, t_mids]), return_inverse=True) + # all_t = Time(all_t) + # self._sc_ori_simpson = self._sc_ori.interp(all_t) + # edge_indices = inv_indices[:len(t_edges)] + # mid_indices = inv_indices[len(t_edges):] + # livetime = self._sc_ori.livetime.to_value(u.s) + # self._unique_time_weights = np.zeros(len(all_t), dtype=np.float32) + # np.add.at(self._unique_time_weights, edge_indices[:-1], livetime / 6.0) + # np.add.at(self._unique_time_weights, edge_indices[1:], livetime / 6.0) + # np.add.at(self._unique_time_weights, mid_indices, 4.0 * livetime / 6.0) + + data_times = self._data.time + self._n_events = self._data.nevents + self._unique_unix, self._inv_idx = np.unique(data_times.utc.unix, return_inverse=True) + unique_times_obj = Time(self._unique_unix, format='unix', scale='utc') + self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) + + interval_ratios = (self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + bin_indices = np.searchsorted(self._sc_ori.obstime.utc.unix, self._unique_unix, side="right") - 1 + bin_indices = np.clip(bin_indices, 0, len(self._sc_ori.livetime) - 1) + unique_ratio = interval_ratios[bin_indices] + self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + + self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) + self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) + + self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) + self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) + self._cos_lat_scatt = torch.cos(self._lat_scatt) + self._sin_lat_scatt = torch.sin(self._lat_scatt) + self._cos_lon_scatt = torch.cos(self._lon_scatt) + self._sin_lon_scatt = torch.sin(self._lon_scatt) + + # Also runs _check_memory_savings + self.reduce_memory = reduce_memory + + @property + def event_type(self) -> Type[EventInterface]: + return TimeTagEmCDSEventInSCFrameInterface + + @property + def force_energy_node_caching(self) -> bool: return self._force_energy_node_caching + @force_energy_node_caching.setter + def force_energy_node_caching(self, val): + if not isinstance(val, bool): + raise ValueError("force_energy_node_caching must be a boolean") + self._force_energy_node_caching = val + + @property + def total_energy_nodes(self) -> Tuple[int, int]: return self._total_energy_nodes + @total_energy_nodes.setter + def total_energy_nodes(self, val): self.set_integration_parameters(total_energy_nodes=val) + + @property + def peak_nodes(self) -> Tuple[int, int]: return self._peak_nodes + @peak_nodes.setter + def peak_nodes(self, val): self.set_integration_parameters(peak_nodes=val) + + @property + def peak_widths(self) -> Tuple[float, float]: return self._peak_widths + @peak_widths.setter + def peak_widths(self, val): self.set_integration_parameters(peak_widths=val) + + @property + def energy_range(self) -> Tuple[float, float]: return self._energy_range + @energy_range.setter + def energy_range(self, val): self.set_integration_parameters(energy_range=val) + + @property + def cache_batch_size(self) -> int: return self._cache_batch_size + @cache_batch_size.setter + def cache_batch_size(self, val): self.set_integration_parameters(cache_batch_size=val) + + @property + def integration_batch_size(self) -> int: return self._integration_batch_size + @integration_batch_size.setter + def integration_batch_size(self, val): self.set_integration_parameters(integration_batch_size=val) + + @property + def offset(self) -> Optional[float]: return self._offset + @offset.setter + def offset(self, val): self.set_integration_parameters(offset=val) + + @property + def show_progress(self) -> bool: return self._show_progress + @show_progress.setter + def show_progress(self, val: bool): + if not isinstance(val, bool): + raise ValueError("show_progress must be a boolean") + self._show_progress = val + + def _check_memory_savings(self): + inefficient = (self._integration_batch_size > (self._n_events * self._total_energy_nodes[0]/2)) + if inefficient & self._reduce_memory: + logger.warning(f"Since integration_batch_size is too large reduce_memory will increase the memory usage! Disable it if this behavior is not desired.") + @property + def reduce_memory(self) -> bool: return self._reduce_memory + @reduce_memory.setter + def reduce_memory(self, val: bool): + if not isinstance(val, bool): + raise ValueError("reduce_memory must be a boolean") + self._reduce_memory = val + self._check_memory_savings() + if not val: + if self._irf_cache is not None: + self._irf_cache = torch.as_tensor(self._irf_cache, dtype=torch.float64) + if self._irf_energy_node_cache is not None: + self._irf_energy_node_cache = np.asarray(self._irf_energy_node_cache, dtype=np.float64) + else: + if self._irf_cache is not None: + self._irf_cache = torch.as_tensor(self._irf_cache, dtype=torch.float32) + if self._irf_energy_node_cache is not None: + self._irf_energy_node_cache = np.asarray(self._irf_energy_node_cache, dtype=np.float32) + + def set_integration_parameters(self, + total_energy_nodes: Optional[Tuple[int, int]] = None, + peak_nodes: Optional[Tuple[int, int]] = None, + peak_widths: Optional[Tuple[float, float]] = None, + energy_range: Optional[Tuple[float, float]] = None, + cache_batch_size: Optional[int] = -1, + integration_batch_size: Optional[int] = -1, + offset: Optional[float] = -1.0): + + new_total = total_energy_nodes or self._total_energy_nodes + new_peak_nodes = peak_nodes or self._peak_nodes + new_peak_widths = peak_widths or self._peak_widths + new_range = energy_range or self._energy_range + new_cache_batch = cache_batch_size if cache_batch_size != -1 else self._cache_batch_size + new_integration_batch = integration_batch_size if integration_batch_size != -1 else self._integration_batch_size + new_offset = offset if offset != -1.0 else self._offset + + irf_affected = ( + new_peak_nodes != self._peak_nodes or + new_peak_widths != self._peak_widths or + new_total[0] != self._total_energy_nodes[0] or + new_range != self._energy_range + ) + + area_affected = ( + new_total[1] != self._total_energy_nodes[1] or + new_range != self._energy_range + ) + + if irf_affected: + self._irf_cache = self._irf_energy_node_cache = self._width_tensor = None + self._nodes_primary = self._nodes_secondary = None + self._nodes_bkg_1 = self._nodes_bkg_2 = self._nodes_bkg_3 = None + + if area_affected: + self._area_cache = self._area_energy_node_cache = None + + if new_total[0] < (new_peak_nodes[0] + 2 * new_peak_nodes[1] + 3): + raise ValueError("Too many nodes per peak compared to the total number or peaks!") + + if any(n < 1 for n in new_total): + raise ValueError("The number of energy nodes must be at least 1.") + + if new_range[0] >= new_range[1]: + raise ValueError("The initial energy interval needs to be increasing!") + + if (new_cache_batch is not None) and (new_cache_batch < max(new_total)): + raise ValueError("The cache batch size cannot be smaller than the number of integration nodes.") + + if (new_integration_batch is not None) and (new_integration_batch < max(new_total)): + raise ValueError("The integration batch size cannot be smaller than the number of integration nodes.") + + if (new_offset is not None) and (new_offset < 0): + raise ValueError("The offset cannot be negative.") + + self._total_energy_nodes = new_total + self._peak_nodes = new_peak_nodes + self._peak_widths = new_peak_widths + self._energy_range = new_range + self._cache_batch_size = new_cache_batch if new_cache_batch is not None else (self._n_events * max(new_total)) + self._integration_batch_size = new_integration_batch if new_integration_batch is not None else (self._n_events * max(new_total)) + self._offset = new_offset + self._check_memory_savings() + + @staticmethod + def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: + x, w = np.polynomial.legendre.leggauss(degree) + return torch.as_tensor(x, dtype=torch.float32).unsqueeze(0), torch.as_tensor(w, dtype=torch.float32).unsqueeze(0) + + def _build_split_nodes(self, remaining: int, groups: int): + q, r = divmod(remaining, groups) + return [self._build_nodes(q + (1 if i < r else 0)) for i in range(groups)] + + def _init_node_pool(self): + self._width_tensor = torch.tensor([self._peak_widths[0], self._peak_widths[0], + self._peak_widths[1], self._peak_widths[1]], dtype=torch.float32) + + self._nodes_primary = self._build_nodes(self._peak_nodes[0]) + self._nodes_secondary = self._build_nodes(self._peak_nodes[1]) + + self._nodes_bkg_1 = self._build_nodes(self._total_energy_nodes[0] - self._peak_nodes[0]) + + self._nodes_bkg_2 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - self._peak_nodes[1], 2 + ) + + self._nodes_bkg_3 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - 2 * self._peak_nodes[1], 3 + ) + + @staticmethod + def _scale_nodes_exp(E1: torch.Tensor, E2: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + diff = E2 - E1 + + out_n = (nodes_u + 1).mul(0.5).pow(2).mul(diff).add(E1) + out_w = (nodes_u + 1).mul(0.5).mul(weights_u).mul(diff) + + return out_n, out_w + + @staticmethod + def _scale_nodes_center(E1: torch.Tensor, E2: torch.Tensor, EC: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask_left = (nodes_u < 0) + width_left = (EC - E1) + width_right = (E2 - EC) + + scale = torch.where(mask_left, width_left, width_right) + + out_n = nodes_u.pow(3).mul(scale).add(EC) + out_w = nodes_u.pow(2).mul(3).mul(weights_u).mul(scale) + + return out_n, out_w + + def _get_escape_peak(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor) -> torch.Tensor: + E2 = 511.0 / (1.0 + 511.0 / energy_m_keV - torch.cos(phi_rad)) + energy = energy_m_keV + 1022.0 - E2 + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > 1600.0) & (energy_m_keV < energy) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def _get_missing_energy_peak(self, phi_geo_rad: torch.Tensor, energy_m_keV: torch.Tensor, + phi_rad: torch.Tensor, inverse: bool = False) -> torch.Tensor: + cos_geo = torch.cos(phi_geo_rad) + cos_phi = torch.cos(phi_rad) + + if inverse: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV * (cos_geo - 1) * (-2044.0 - 5 * energy_m_keV + energy_m_keV * cos_geo + 4 * energy_m_keV * cos_phi)) + energy = 511.0 * (energy_m_keV - energy_m_keV * cos_geo + root) / denom + else: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV**2 * (cos_geo - 1) * (cos_phi - 1) * + ((1022.0 + energy_m_keV)**2 - energy_m_keV * (2044.0 + energy_m_keV) * cos_phi - 2 * energy_m_keV**2 * cos_geo * torch.sin(phi_rad/2)**2)) + energy = (energy_m_keV**2 * (1 - cos_geo - cos_phi + cos_phi * cos_geo) + root) / denom + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > energy_m_keV) & (energy_m_keV/energy - 1 < -0.2) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def init_cache(self): + self._update_cache() + + def clear_cache(self): + self._irf_cache = None + self._irf_energy_node_cache = None + self._area_cache = None + self._area_energy_node_cache = None + self._exp_events = None + self._exp_density = None + + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + def set_source(self, source: Source): + if not isinstance(source, PointSource): + raise TypeError("Please provide a PointSource!") + + self._source = source + + def copy(self) -> CachedUnbinnedThreeMLSourceResponseInterface: + new_instance = copy.copy(self) + new_instance.clear_cache() + new_instance._source = None + + return new_instance + + @staticmethod + def _earth_occ(source_coord: SkyCoord, ori: SpacecraftHistory) -> np.ndarray: + dist_earth_center = ori.location.spherical.distance.km + max_angle = np.pi - np.arcsin(c.R_earth.to(u.km).value/dist_earth_center) + src_angle = source_coord.separation(ori.earth_zenith) + return (src_angle.to(u.rad).value < max_angle).astype(np.float32) + + @staticmethod + def _get_target_in_sc_frame(source_coord: SkyCoord, ori: SpacecraftHistory) -> SkyCoord: + src_in_sc_frame = SkyCoord(np.dot(ori.attitude.rot.inv().as_matrix(), source_coord.transform_to(ori.attitude.frame).cartesian.xyz.value), + representation_type = 'cartesian', frame = SpacecraftFrame()) + + src_in_sc_frame.representation_type = 'spherical' + return src_in_sc_frame + + def _compute_area(self): + coord = self._source.position.sky_coord + n_energy = self._total_energy_nodes[1] + + log_E_min = np.log10(self._energy_range[0]) + log_E_max = np.log10(self._energy_range[1]) + + x, w = np.polynomial.legendre.leggauss(n_energy) + + scale = 0.5 * (log_E_max - log_E_min) + y_nodes = scale * x + 0.5 * (log_E_max + log_E_min) + self._area_energy_node_cache = 10**y_nodes + + e_w = (np.log(10) * self._area_energy_node_cache * (w * scale)).astype(np.float32).reshape(1, -1) + e_n = self._area_energy_node_cache.astype(np.float32) + + # Midpoint + sc_coord_sph = self._get_target_in_sc_frame(coord, self._sc_ori_center) + earth_occ_index = self._earth_occ(coord, self._sc_ori_center) + + combined_time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index + + # Simpson + # sc_coord_sph = self._sc_ori_simpson.get_target_in_sc_frame(coord) + # earth_occ_index = self._earth_occ(coord, self._sc_ori_simpson) + + # combined_time_weights = (self._unique_time_weights * earth_occ_index).astype(np.float32) + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + n_time = len(lon_ph_rad) + batch_size_time = self._cache_batch_size // n_energy + + total_area = np.zeros(n_energy, dtype=np.float64) + + max_batch_total = n_energy * min(batch_size_time, n_time) + batch_lons_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_lats_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_energies_buffer = np.empty(max_batch_total, dtype=np.float32) + + for i in tqdm(range(0, n_time, batch_size_time), + disable=(not self.show_progress), + desc="Caching the effective area", + smoothing=0.2, + leave=False): + start = i + end = min(i + batch_size_time, n_time) + current_n_time = end - start + current_total = current_n_time * n_energy + + batch_lons_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lats_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + batch_energies_buffer[:current_total].reshape(current_n_time, n_energy)[:] = e_n + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lons_buffer[:current_total], + batch_lats_buffer[:current_total], + batch_energies_buffer[:current_total] + ) + + eff_areas_flat = asarray(self._irf._effective_area_cm2(photons), dtype=np.float32) + eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) + + total_area += np.einsum('ij,i,j->j', + eff_areas_grid, + combined_time_weights[start:end], + e_w.ravel()) + + self._area_cache = total_area + + def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, + indices: torch.Tensor, mode: int, + sorted_peaks: torch.Tensor, delta: torch.Tensor): + + Emin, Emax = self._energy_range + + if mode == 1: + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=Emax) + + EC = sorted_peaks[:, 0] + + E1, E2, EC = [E.view(-1, 1) for E in (E1, E2, EC)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_1[0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, Emax, *self._nodes_bkg_1) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + elif mode == 2: + center_peak = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=Emax) + + EC1 = sorted_peaks[:, 0] + EC2 = sorted_peaks[:, 1] + + E1, E2, E3, E4, EC1, EC2 = [E.view(-1, 1) for E in (E1, E2, E3, E4, EC1, EC2)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_2[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_2[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_2[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, Emax, *self._nodes_bkg_2[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + elif mode == 3: + center_peak_1 = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + center_peak_2 = (sorted_peaks[:, 1] + sorted_peaks[:, 2]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak_1) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=center_peak_2) + E5 = (sorted_peaks[:, 2] - delta[:, 2]).clamp(min=E4) + E6 = (sorted_peaks[:, 2] + delta[:, 2]).clamp(max=Emax) + + EC1, EC2, EC3 = [sorted_peaks[:, i] for i in range(3)] + + E1, E2, E3, E4, E5, E6, EC1, EC2, EC3 = [E.view(-1, 1) for E in (E1, E2, E3, E4, E5, E6, EC1, EC2, EC3)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_3[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_3[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_3[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, E5, *self._nodes_bkg_3[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E5, E6, EC3, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + c += w + w = self._nodes_bkg_3[2][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E6, Emax, *self._nodes_bkg_3[2]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + + + def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, + phi_geo_rad: torch.Tensor, phi_igeo_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + energy_m_keV = energy_m_keV.view(-1, 1) + phi_rad = phi_rad.view(-1, 1) + phi_geo_rad = phi_geo_rad.view(-1, 1) + phi_igeo_rad = phi_igeo_rad.view(-1, 1) + + batch_size = energy_m_keV.shape[0] + + nodes = torch.zeros((batch_size, self._total_energy_nodes[0]), dtype=torch.float32) + weights = torch.zeros_like(nodes) + + peaks = torch.zeros((batch_size, 4), dtype=torch.float32) + peaks[:, 0] = energy_m_keV.squeeze() + peaks[:, 1] = self._get_escape_peak(energy_m_keV, phi_rad).squeeze() + peaks[:, 2] = self._get_missing_energy_peak(phi_geo_rad, energy_m_keV, phi_rad).squeeze() + peaks[:, 3] = self._get_missing_energy_peak(phi_igeo_rad, energy_m_keV, phi_rad, inverse=True).squeeze() + + diffs = peaks * self._width_tensor[None, ...] + + n_peaks = torch.sum(~torch.isnan(peaks), dim=1) + + indices_1 = torch.where(n_peaks == 1)[0] + indices_2 = torch.where(n_peaks == 2)[0] + indices_3 = torch.where(n_peaks == 3)[0] + + if len(indices_1) > 0: + self._fill_nodes(nodes, weights, indices_1, 1, + peaks[indices_1, :1], diffs[indices_1, :1]) + + if len(indices_2) > 0: + p_sub = peaks[indices_2] + d_sub = diffs[indices_2] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 2) + d_comp = d_sub[mask].view(-1, 2) + self._fill_nodes(nodes, weights, indices_2, 2, p_comp, d_comp) + + if len(indices_3) > 0: + p_sub = peaks[indices_3] + d_sub = diffs[indices_3] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 3) + d_comp = d_sub[mask].view(-1, 3) + + p_sorted, idx = torch.sort(p_comp, dim=1) + d_sorted = torch.gather(d_comp, 1, idx) + + self._fill_nodes(nodes, weights, indices_3, 3, p_sorted, d_sorted) + + return nodes, weights + + def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + cos_lat_src = torch.cos(lat_src_rad) + sin_lat_src = torch.sin(lat_src_rad) + cos_lon_src = torch.cos(lon_src_rad) + sin_lon_src = torch.sin(lon_src_rad) + + cos_geo = ( + cos_lat_src * cos_lon_src * self._cos_lat_scatt * self._cos_lon_scatt + + cos_lat_src * sin_lon_src * self._cos_lat_scatt * self._sin_lon_scatt + + sin_lat_src * self._sin_lat_scatt + ) + + cos_geo = torch.clip(cos_geo, -1.0, 1.0) + phi_geo_rad = torch.arccos(cos_geo) + + return phi_geo_rad, np.pi - phi_geo_rad + + def _compute_nodes(self): + sc_coord_sph = self._sc_coord_sph_cache + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + np_memory_dtype = np.float32 if self._reduce_memory else np.float64 + self._irf_energy_node_cache = np.asarray(self._get_nodes(self._energy_m_keV, self._phi_rad, phi_geo_rad, phi_igeo_rad)[0], dtype=np_memory_dtype) + + def _compute_density(self): + coord = self._source.position.sky_coord + sc_coord_sph = self._sc_coord_sph_cache + earth_occ_index = self._earth_occ(coord, self._sc_ori_unique)[self._inv_idx] + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + + n_energy = self._total_energy_nodes[0] + batch_size_events = self._cache_batch_size // n_energy + + np_memory_dtype = np.float32 if self._reduce_memory else np.float64 + torch_memory_dtype = torch.float32 if self._reduce_memory else torch.float64 + + self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch_memory_dtype) + + buffer_size = n_energy * min(batch_size_events, self._n_events) + batch_lon_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_energy_buffer = np.empty(buffer_size, dtype=np.float32) + batch_phi_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + + if (batch_size_events < self._n_events) & (self._force_energy_node_caching): + self._irf_energy_node_cache = np.zeros((self._n_events, n_energy), dtype=np_memory_dtype) + + for i in tqdm(range(0, self._n_events, batch_size_events), + disable=(not self.show_progress), + desc="Caching the response", + smoothing=0.2, + leave=False): + start = i + end = min(i + batch_size_events, self._n_events) + current_n = end - start + current_total = current_n * n_energy + + e_sl = self._energy_m_keV[start:end] + p_sl = self._phi_rad[start:end] + pg_sl = phi_geo_rad[start:end] + pig_sl = phi_igeo_rad[start:end] + + nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + + if batch_size_events >= self._n_events: + self._irf_energy_node_cache = np.asarray(nodes, dtype=np_memory_dtype) + else: + self._irf_energy_node_cache[start:end] = np.asarray(nodes) + + batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + + batch_energy_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._energy_m_keV[start:end, np.newaxis]) + batch_lon_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lon_scatt[start:end, np.newaxis]) + batch_lat_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lat_scatt[start:end, np.newaxis]) + batch_phi_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._phi_rad[start:end, np.newaxis]) + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lon_src_buffer[:current_total], + batch_lat_src_buffer[:current_total], + np.asarray(nodes).ravel() + ) + + events = EmCDSEventDataInSCFrameFromArrays( + batch_energy_buffer[:current_total], + batch_lon_scatt_buffer[:current_total], + batch_lat_scatt_buffer[:current_total], + batch_phi_buffer[:current_total], + ) + + eff_areas_flat = torch.as_tensor(asarray(self._irf._effective_area_cm2(photons), dtype=np.float32)) + densities_flat = torch.as_tensor(asarray(self._irf._event_probability(photons, events), dtype=np.float32)) + + res_block = (densities_flat * eff_areas_flat).view(current_n, n_energy) + + occ = torch.as_tensor(earth_occ_index[start:end]).view(-1, 1) + live = torch.as_tensor(self._livetime_ratio[start:end]).view(-1, 1) + + res_block *= occ * live * weights + + self._irf_cache[start:end] = res_block + + def _update_cache(self): + + if self._source is None: + raise RuntimeError("Call set_source() first.") + + source_coord = self._source.position.sky_coord + + if (self._sc_coord_sph_cache is None) or (source_coord != self._last_convolved_source_skycoord): + self._sc_coord_sph_cache = self._get_target_in_sc_frame(source_coord, self._sc_ori_unique)[self._inv_idx] + + no_recalculation = ((source_coord == self._last_convolved_source_skycoord) + and + (self._irf_cache is not None) + and + (self._area_cache is not None)) + + area_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._area_cache is None)) + + pdf_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._irf_cache is None)) + + if no_recalculation: + return + else: + active_pool = True + if isinstance(self._irf, UnpolarizedNFFarFieldInstrumentResponseFunction): + active_pool = self._irf.active_pool + if not active_pool: + self._irf.init_compute_pool() + + if source_coord != self._last_convolved_source_skycoord: + self._irf_energy_node_cache = None + + if area_recalculation: + self._compute_area() + + if pdf_recalculation: + self._init_node_pool() + self._compute_density() + + if not active_pool: + self._irf.shutdown_compute_pool() + + self._last_convolved_source_skycoord = source_coord.copy() + + node_caching = (self.force_energy_node_caching + and + self._irf_energy_node_cache is None) + + if node_caching: + self._compute_nodes() + + def cache_to_file(self, filename: Union[str, Path]): + with h5py.File(str(filename), 'w') as f: + f.attrs['total_energy_nodes'] = self._total_energy_nodes + f.attrs['peak_nodes'] = self._peak_nodes + f.attrs['peak_widths'] = self._peak_widths + f.attrs['energy_range'] = self._energy_range + f.attrs['cache_batch_size'] = self._cache_batch_size + f.attrs['integration_batch_size'] = self._integration_batch_size + f.attrs['show_progress'] = self._show_progress + f.attrs['force_energy_node_caching'] = self._force_energy_node_caching + f.attrs['reduce_memory'] = self._reduce_memory + + if self._offset is not None: + f.attrs['offset'] = self._offset + + if self._irf_cache is not None: + f.create_dataset('irf_cache', data=self._irf_cache.numpy(), + compression='gzip', compression_opts=4) + + if self._irf_energy_node_cache is not None: + f.create_dataset('irf_energy_node_cache', data=self._irf_energy_node_cache, + compression='gzip') + + if self._area_cache is not None: + f.create_dataset('area_cache', data=self._area_cache, + compression='gzip') + + if self._area_energy_node_cache is not None: + f.create_dataset('area_energy_node_cache', data=self._area_energy_node_cache, + compression='gzip') + + if self._exp_events is not None: + f.create_dataset('exp_events', data=self._exp_events) + + if self._exp_density is not None: + f.create_dataset('exp_density', data=self._exp_density.numpy(), + compression='gzip') + + if self._last_convolved_source_dict_number is not None: + json_str = json.dumps(self._last_convolved_source_dict_number) + f.attrs['last_convolved_source_dict_number'] = json_str + + if self._last_convolved_source_dict_density is not None: + json_str = json.dumps(self._last_convolved_source_dict_density) + f.attrs['last_convolved_source_dict_density'] = json_str + + if self._last_convolved_source_skycoord is not None: + sc = self._last_convolved_source_skycoord + f.attrs['last_convolved_lon_deg'] = sc.spherical.lon.deg + f.attrs['last_convolved_lat_deg'] = sc.spherical.lat.deg + f.attrs['last_convolved_frame'] = sc.frame.name + if hasattr(sc, 'equinox'): + f.attrs['last_convolved_equinox'] = sc.equinox.value + + def cache_from_file(self, filename: Union[str, Path]): + if not os.path.exists(str(filename)): + raise FileNotFoundError(f"Cache file {str(filename)} not found.") + + with h5py.File(str(filename), 'r') as f: + self._total_energy_nodes = tuple(f.attrs['total_energy_nodes']) + self._peak_nodes = tuple(f.attrs['peak_nodes']) + self._peak_widths = tuple(f.attrs['peak_widths']) + self._energy_range = tuple(f.attrs['energy_range']) + self._cache_batch_size = int(f.attrs['cache_batch_size']) + self._integration_batch_size = int(f.attrs['integration_batch_size']) + self._show_progress = bool(f.attrs['show_progress']) + self._force_energy_node_caching = bool(f.attrs['force_energy_node_caching']) + self._reduce_memory = bool(f.attrs['reduce_memory']) + + if 'offset' in f.attrs: + self._offset = f.attrs['offset'] + else: + self._offset = None + + if 'irf_cache' in f: + self._irf_cache = torch.from_numpy(f['irf_cache'][:]) + else: + self._irf_cache = None + + if 'irf_energy_node_cache' in f: + self._irf_energy_node_cache = f['irf_energy_node_cache'][:] + else: + self._irf_energy_node_cache = None + + if 'area_cache' in f: + self._area_cache = f['area_cache'][:] + else: + self._area_cache = None + + if 'area_energy_node_cache' in f: + self._area_energy_node_cache = f['area_energy_node_cache'][:] + else: + self._area_energy_node_cache = None + + if 'exp_events' in f: + self._exp_events = float(f['exp_events'][()]) + else: + self._exp_events = None + + if 'exp_density' in f: + self._exp_density = torch.from_numpy(f['exp_density'][:]) + else: + self._exp_density = None + + if 'last_convolved_source_dict_number' in f.attrs: + self._last_convolved_source_dict_number = json.loads(f.attrs['last_convolved_source_dict_number']) + else: + self._last_convolved_source_dict_number = None + + if 'last_convolved_source_dict_density' in f.attrs: + self._last_convolved_source_dict_density = json.loads(f.attrs['last_convolved_source_dict_density']) + else: + self._last_convolved_source_dict_density = None + + if 'last_convolved_lon_deg' in f.attrs: + lon = f.attrs['last_convolved_lon_deg'] + lat = f.attrs['last_convolved_lat_deg'] + frame = f.attrs['last_convolved_frame'] + equinox = f.attrs.get('last_convolved_equinox', None) + + self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame, equinox=equinox) + else: + self._last_convolved_source_skycoord = None + + if self._irf_cache is not None: + self._init_node_pool() + + def expected_counts(self) -> float: + """ + Return the total expected counts. + """ + self._update_cache() + source_dict = self._source.to_dict() + + if (source_dict != self._last_convolved_source_dict_number) or (self._exp_events is None): + area = self._area_cache + flux = self._source(self._area_energy_node_cache) + self._exp_events = np.sum(area * flux, dtype=float) + + self._last_convolved_source_dict_number = source_dict + return self._exp_events + + def expectation_density(self) -> Iterable[float]: + """ + Return the expected number of counts density. This equals the event probabiliy times the number of events. + """ + + self._update_cache() + source_dict = self._source.to_dict() + if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): + self._exp_density = torch.zeros(self._n_events, dtype=torch.float64) + + n_energy = self._total_energy_nodes[0] + batch_size = self._integration_batch_size // n_energy + + if (self._irf_energy_node_cache is not None) & (batch_size >= self._n_events): + flux = torch.as_tensor( + self._source( + np.asarray(self._irf_energy_node_cache, dtype=np.float64).ravel() + ), + dtype=torch.float64 + ).view(self._irf_energy_node_cache.shape) + + cache = torch.as_tensor(self._irf_cache, dtype=torch.float64) + + torch.linalg.vecdot(cache, flux, dim=1, out=self._exp_density) + + else: + if self._irf_energy_node_cache is None: + sc_coord_sph = self._sc_coord_sph_cache + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + + for i in range(0, self._n_events, batch_size): + end = min(i + batch_size, self._n_events) + + if self._irf_energy_node_cache is None: + e_sl = self._energy_m_keV[i:end] + p_sl = self._phi_rad[i:end] + pg_sl = phi_geo_rad[i:end] + pig_sl = phi_igeo_rad[i:end] + + nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + else: + nodes = self._irf_energy_node_cache[i:end] + + nodes = np.asarray(nodes, dtype=np.float64) + + flux_batch = torch.as_tensor( + self._source(nodes.ravel()), + dtype=torch.float64 + ).view(nodes.shape) + + cache = torch.as_tensor(self._irf_cache[i:end], dtype=torch.float64) + + torch.linalg.vecdot(cache, flux_batch, dim=1, out=self._exp_density[i:end]) + + self._last_convolved_source_dict_density = source_dict + + result = np.asarray(self._exp_density, dtype=np.float64) + + if self._offset is not None: + return result + self._offset + else: + return result \ No newline at end of file diff --git a/cosipy/threeml/psr_fixed_ei.py b/cosipy/threeml/psr_fixed_ei.py index e06af95c4..ded88e118 100644 --- a/cosipy/threeml/psr_fixed_ei.py +++ b/cosipy/threeml/psr_fixed_ei.py @@ -209,4 +209,4 @@ def event_probability(self) -> Iterable[float]: self._update_cache() - return self._event_prob + return self._event_prob \ No newline at end of file diff --git a/cosipy/threeml/unbinned_model_folding.py b/cosipy/threeml/unbinned_model_folding.py index a843565f8..c6fc5e62d 100644 --- a/cosipy/threeml/unbinned_model_folding.py +++ b/cosipy/threeml/unbinned_model_folding.py @@ -1,17 +1,22 @@ import itertools -from typing import Optional, Iterable +from typing import Optional, Iterable, Union import numpy as np from astromodels import Model, PointSource, ExtendedSource +from pathlib import Path from cosipy.interfaces import UnbinnedThreeMLModelFoldingInterface, UnbinnedThreeMLSourceResponseInterface +from cosipy.interfaces.source_response_interface import CachedUnbinnedThreeMLSourceResponseInterface from cosipy.response.threeml_response import ThreeMLModelFoldingCacheSourceResponsesMixin +from cosipy.util.iterables import asarray +from cosipy.util.iterables import itertools_batched class UnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFoldingInterface, ThreeMLModelFoldingCacheSourceResponsesMixin): def __init__(self, point_source_response = UnbinnedThreeMLSourceResponseInterface, - extended_source_response: UnbinnedThreeMLSourceResponseInterface = None): + extended_source_response: UnbinnedThreeMLSourceResponseInterface = None, + batch_size: Optional[int] = None): # Interface inputs self._model = None @@ -19,6 +24,7 @@ def __init__(self, # Implementation inputs self._psr = point_source_response self._esr = extended_source_response + self._batch_size = batch_size if (self._psr is not None) and (self._esr is not None) and self._psr.event_type != self._esr.event_type: raise RuntimeError("Point and Extended Source Response must handle the same event type") @@ -54,11 +60,74 @@ def expected_counts(self) -> float: return sum(s.expected_counts() for s in self._source_responses.values()) + def _expectation_density_batched_gen(self, sources: list) -> Iterable[float]: + batched_sources = [itertools_batched(s, self._batch_size) for s in sources] + + for chunks in zip(*batched_sources): + densities = [asarray(c, dtype=np.float64) for c in chunks] + + yield from np.add.reduce(densities) + def expectation_density(self) -> Iterable[float]: + self._cache_source_responses() + + if not self._source_responses: + return np.array([], dtype=np.float64) + + sources = [ex.expectation_density() for ex in self._source_responses.values()] + + if (self._batch_size is None) or all(hasattr(s, "__len__") for s in sources): + densities = [asarray(s, dtype=np.float64) for s in sources] + return np.add.reduce(densities) + else: + return self._expectation_density_batched_gen(sources) + + +class CachedUnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFolding): + def __init__(self, + point_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, + extended_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, + batch_size: Optional[int] = None): + + super().__init__(point_source_response=point_source_response, + extended_source_response=extended_source_response, + batch_size=batch_size) + + self._base_filename = "_source_response_cache.h5" + + def init_cache(self): """ - Sum of expectation density + Forces the creation of response objects for each source in the model. """ - + self._cache_source_responses() + + for response in self._source_responses.values(): + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + response.init_cache() + + def save_caches(self, directory: Union[str, Path], cache_only: Optional[Iterable[str]] = None): + """Saves only the responses that implement the cache interface.""" + self.init_cache() + + dir_path = Path(directory) + dir_path.mkdir(parents=True, exist_ok=True) + + for name, response in self._source_responses.items(): + if (cache_only is not None) and (name not in set(cache_only)): + continue + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + filepath = dir_path / f"{name}{self._base_filename}" + response.cache_to_file(filepath) + + def load_caches(self, directory: Union[str, Path], load_only: Optional[Iterable[str]] = None): + """Loads available cache files into compatible response objects.""" self._cache_source_responses() - return [sum(expectations) for expectations in zip(*(s.expectation_density() for s in self._source_responses.values()))] + dir_path = Path(directory) + for name, response in self._source_responses.items(): + if (load_only is not None) and (name not in set(load_only)): + continue + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + filepath = dir_path / f"{name}{self._base_filename}" + if filepath.exists(): + response.cache_from_file(filepath) diff --git a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb new file mode 100644 index 000000000..79a1a21dd --- /dev/null +++ b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb @@ -0,0 +1,585 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8e42b9b4-766d-4818-b5db-f3d6c56a9563", + "metadata": { + "deletable": true, + "editable": true, + "frozen": false + }, + "source": [ + "# Fitting the Crab Spectrum with the Neural-Network Response and Background Approximation" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c02898da-7753-46b6-b291-612cb953dc36", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "This notebook provides an overview of the different new classes introduced to allow for a truly unbinned spectral analysis of continuum point sources. The user will mainly interact with:\n", + "\n", + "1. `NFResponse` and `UnpolarizedNFFarFieldInstrumentResponseFunction`: Provide the **C-A-RQS + Spherical Harmonics Expansion** approximation of the response.\n", + "2. `NFBackground` and `FreeNormNFUnbinnedBackground`: Provide the **C-A-RQS + Analytical Rate Model** approximation of the simulated background (currently total DC4).\n", + "3. `UnbinnedThreeMLPointSourceResponseIRFAdaptive`: Perform the folding of the response with the flux model in a more efficient way.\n", + "4. `CachedUnbinnedThreeMLModelFolding`: Adds the capability to save and load the cache of `UnbinnedThreeMLPointSourceResponseIRFAdaptive`.\n", + "\n", + "### Inner Workings\n", + "\n", + "For a comprehensive description of the underlying architecture or more detailed comparison plots and benchmarks, please for now refer to my (Pascal J.) thesis, \"Development of an Efficient Response Description for the COSI MeV 𝜸-Ray Telescope.\" Here, I will just provide a very basic explanation of the code structure.\n", + "\n", + "#### Flow-Based Models\n", + "\n", + "The building blocks managing the approximations are `NFResponse` and `NFBackground` (both of type `NFBase`). During setup, both require a checkpoint file (e.g., `unpolarized_nfresponse_v1-00.pt`) which contains all the necessary information, such as the neural network weights and hyperparameters, the version number to choose the correct model, or the coefficients for the effective area or background rate model.\n", + "\n", + "Another important task is the creation of compute pools. All PyTorch-related inference tasks are managed by a separate process for each chosen device. These processes need to be started or stopped, as shown in the example below.\n", + "\n", + "All queries, such as evaluating the effective area or density as well as sampling, are passed through an `Approximation` object (which chooses the correct model and prepares the input) to a `Model` object (which handles the actual inference on the PyTorch device). For the response $R = A_\\mathrm{eff}(\\nu\\lambda, E_i) \\cdot P(E_m, \\phi, \\psi\\chi|\\nu\\lambda, E_i)$, this includes:\n", + "\n", + "1. The effective area $A_\\mathrm{eff}$, modeled with a spherical harmonics expansion. The latter are evaluated using the PyTorch backend of the `sphericart` library.\n", + "2. The probability density $P(E_m, \\phi, \\psi\\chi|\\nu\\lambda, E_i)$, modeled with conditional-autoregressive-rational quadratic spline flows. The library used is called `normflows`.\n", + "\n", + "The background model $B = R(t) \\cdot P(E_m, \\phi, \\psi\\chi|t)$ follows a very similar structure and therefore uses most of the same code. While the rate $R(t)$ is always computed on the CPU, the density $P$ is also modeled using `normflows`.\n", + "\n", + "#### Folding\n", + "\n", + "For the unbinned analysis, we need to calculate the total number of events we expect, $N$, and the expectation density, $\\mathrm{d}N/\\mathrm{d}t\\mathrm{d}E_m\\mathrm{d}\\phi\\mathrm{d}\\psi\\chi$. The latter is especially difficult to compute, since it requires folding the response with the flux model and therefore computing a numerical integral with enough precision for every event in the analysis. The background model $B$, on the other hand, is already the expectation density and requires no integration.\n", + "\n", + "The folding can be performed using `UnbinnedThreeMLPointSourceResponseTrapz`. However, since $R$ has a low inference rate, `UnbinnedThreeMLPointSourceResponseIRFAdaptive` provides an optimized implementation. It uses the fact that many flux models are \"well-behaved\" (low order in a Taylor series). This means one can significantly reduce the integration error by concentrating most Gauss-Legendre integration nodes at peaks of the response and distributing the remaining ones in between. \n", + "\n", + "The exact parameters can be tuned by the user, but it probably won't be able to account for very complex spectral shapes. For each integration node, the response is evaluated once and cached, which takes most of the time. This way, during optimization, only the flux model changes and the fitting is fast.\n", + "\n", + "Using `CachedUnbinnedThreeMLModelFolding` instead of `UnbinnedThreeMLModelFolding` adds the capability to save this cache and other parameters of `UnbinnedThreeMLPointSourceResponseIRFAdaptive`. These files can be shared with other scientists and loaded during another session to skip the expensive cache initialization.\n", + "\n", + "### Hardware Requirements\n", + "\n", + "It is highly recommended to use a GPU, preferably by NVIDIA, for the inference. A CPU, even something like an AMD Threadripper, will only allow you to analyze a few orbits in a reasonable time. Also, note that the unbinned analysis requires you to have enough RAM, which increases as the number of events you include in the analysis increases. The batch size can be decreased to limit the maximum consumption." + ] + }, + { + "cell_type": "markdown", + "id": "35a07884-ca61-461f-bcde-96ec572d7f9d", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "markdown", + "id": "87d50f2d-6d5a-48b7-9fe5-1c63014c80e0", + "metadata": {}, + "source": [ + "### Basic Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d2076d1-f84b-4971-b4e6-611844dcdef6", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from cosipy.util import fetch_wasabi_file\n", + "from cosipy.spacecraftfile import SpacecraftHistory\n", + "from astropy.time import Time\n", + "\n", + "import astropy.units as u\n", + "from copy import deepcopy\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from threeML import Band, PointSource, Model, JointLikelihood, DataList\n", + "from astromodels import Parameter\n", + "\n", + "from cosipy.threeml.unbinned_model_folding import CachedUnbinnedThreeMLModelFolding\n", + "from cosipy.statistics import UnbinnedLikelihood\n", + "from cosipy.interfaces import ThreeMLPluginInterface\n", + "from cosipy.interfaces.expectation_interface import SumExpectationDensity\n", + "\n", + "from cosipy.event_selection.time_selection import TimeSelector\n", + "from cosipy.data_io.EmCDSUnbinnedData import TimeTagEmCDSEventDataInSCFrameFromDC3Fits\n", + "\n", + "from cosipy.response.NFResponse import NFResponse\n", + "from cosipy.response.nf_instrument_response_function import UnpolarizedNFFarFieldInstrumentResponseFunction\n", + "\n", + "from cosipy.background_estimation.NFBackground import NFBackground\n", + "from cosipy.background_estimation.nf_unbinned_background import FreeNormNFUnbinnedBackground\n", + "\n", + "from cosipy.threeml.optimized_unbinned_folding import UnbinnedThreeMLPointSourceResponseIRFAdaptive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ec4a3dc-5b5e-4587-ba36-fc5e024e0682", + "metadata": {}, + "outputs": [], + "source": [ + "data_path = Path(\"./data_files/\")\n", + "\n", + "crab_data_path = data_path / \"crab_standard_3months_unbinned_data_filtered_with_SAAcut.fits.gz\"\n", + "fetch_wasabi_file('COSI-SMEX/DC3/Data/Sources/crab_standard_3months_unbinned_data_filtered_with_SAAcut.fits.gz', \n", + " output=str(crab_data_path))\n", + "\n", + "sc_orientation_path = data_path / \"DC3_final_530km_3_month_with_slew_15sbins_GalacticEarth_SAA.ori\"\n", + "fetch_wasabi_file('COSI-SMEX/DC3/Data/Orientation/DC3_final_530km_3_month_with_slew_15sbins_GalacticEarth_SAA.ori', \n", + " output=str(sc_orientation_path))\n", + "\n", + "bkg_data_path = data_path / \"Total_DC4_BG_3months_unbinned_data_filtered_with_SAAcut_withSAAbck.fits.gz\"\n", + "fetch_wasabi_file('COSI-SMEX/DC4/Data/Backgrounds/Total_DC4_BG_3months_unbinned_data_filtered_with_SAAcut_withSAAbck.fits.gz', \n", + " output=str(bkg_data_path))\n", + "\n", + "rsp_path = data_path / \"unpolarized_nfresponse_v1-00.pt\"\n", + "fetch_wasabi_file('cosi-pipeline-public/COSI-SMEX/DC4/Data/Responses/unpolarized_nfresponse_v1-00.pt',\n", + " checksum = 'a4f9a7842f2a7345f604da32a155803f', output=str(rsp_path))\n", + "\n", + "bkg_path = data_path / \"nfbackground_v1-01.pt\"\n", + "fetch_wasabi_file('cosi-pipeline-public/COSI-SMEX/DC4/Data/Responses/nfbackground_v1-01.pt',\n", + " checksum = '52a4fb024930e18430bff80db882d3d5', output=str(bkg_path))" + ] + }, + { + "cell_type": "markdown", + "id": "68bdf301-5315-48b7-af07-d340c52e8780", + "metadata": {}, + "source": [ + "For this tutorial we only use one day of data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b9f01c9-8bd8-4e69-b964-71ef99032452", + "metadata": {}, + "outputs": [], + "source": [ + "tstart = Time(\"2028-03-15 00:00:00.000\")\n", + "tstop = Time(\"2028-03-15 02:00:00.000\")\n", + "sc_orientation = SpacecraftHistory.open(sc_orientation_path)\n", + "sc_orientation = sc_orientation.select_interval(tstart, tstop)" + ] + }, + { + "cell_type": "markdown", + "id": "cd8dbcc7-016d-411a-b6de-f0958aa6ae5c", + "metadata": {}, + "source": [ + "The time selection is very slow (see issue https://github.com/cositools/cosipy/issues/504). Consider using your own implementation for now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a6b82a7-9c78-4315-97bd-50d83ed09f42", + "metadata": {}, + "outputs": [], + "source": [ + "data_file = [crab_data_path, bkg_data_path]\n", + "selector = TimeSelector(tstart = sc_orientation.tstart, tstop = sc_orientation.tstop)\n", + "data = TimeTagEmCDSEventDataInSCFrameFromDC3Fits(data_file, selection=selector)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b974b15-e507-4ab9-9299-2b735f7d6b7f", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"This analysis uses {data.nevents} Events\")" + ] + }, + { + "cell_type": "markdown", + "id": "e676fa4e-d0a5-4b58-b309-19bb17dc95c3", + "metadata": {}, + "source": [ + "`NFResponse` and `NFBackground`\n", + "- devices: Optional default devices. For CPU, choose `[\"cpu\"]`.\n", + " - Default devices cause the pool to be initialized and closed automatically for each inference. This can produce unnecessary overhead when not used with functions like `UnbinnedThreeMLPointSourceResponseIRFAdaptive`, which handles this management internally.\n", + " - You can also manually manage the pools using `init_compute_pool`, `clean_compute_pool`, and `shutdown_compute_pool`.\n", + "- compile_mode: Optional from a list of options. If issues arise, choose `None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25eca2eb-c55d-4f02-a61c-db9e2034a706", + "metadata": {}, + "outputs": [], + "source": [ + "rsp = NFResponse(\n", + " path_to_model=rsp_path,\n", + " area_batch_size=300_000,\n", + " density_batch_size=100_000, \n", + " devices=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\"],\n", + " area_compile_mode=\"max-autotune-no-cudagraphs\",\n", + " density_compile_mode=\"default\",\n", + " show_progress=True)\n", + "\n", + "irf = UnpolarizedNFFarFieldInstrumentResponseFunction(rsp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c8d0df3-813c-4eed-8316-d604d80dcbcb", + "metadata": {}, + "outputs": [], + "source": [ + "bkg_model = NFBackground(\n", + " path_to_model=bkg_path,\n", + " density_batch_size=100_000,\n", + " devices=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\"],\n", + " density_compile_mode=\"default\",\n", + " show_progress=True)\n", + "\n", + "bkg = FreeNormNFUnbinnedBackground(\n", + " model=bkg_model, \n", + " data=data, \n", + " sc_history=sc_orientation, \n", + " label=\"bkg_norm\")" + ] + }, + { + "cell_type": "markdown", + "id": "4904469c-285e-4a3a-8770-3dbb2bc0857b", + "metadata": {}, + "source": [ + "The next step takes some time, as it needs to interpolate the `SpacecraftHistory` (again issue https://github.com/cositools/cosipy/issues/504)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2dc234db-30f6-44e4-b9e6-22519b177b15", + "metadata": {}, + "outputs": [], + "source": [ + "psr = UnbinnedThreeMLPointSourceResponseIRFAdaptive(\n", + " data=data, \n", + " irf=irf, \n", + " sc_history=sc_orientation, \n", + " show_progress=True, \n", + " force_energy_node_caching=True, # Saves the energy nodes even when the batch size is too small\n", + " reduce_memory=True) # May reduce the peak memory consumption (saves the cache as float32). Look out for warnings\n", + "\n", + "# There are several other options that can be set, for example:\n", + "\n", + "psr.integration_batch_size = 100_000" + ] + }, + { + "cell_type": "markdown", + "id": "e614fff3-9729-4e01-b8aa-2b7f08b1c89a", + "metadata": {}, + "source": [ + "The default `Band`, `Powerlaw` and `PointSource` are very slow (see discussion https://github.com/cositools/cosipy/discussions/492). Consider implementing your own version with torch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "390f3822-e071-4add-b08e-833566ee7514", + "metadata": {}, + "outputs": [], + "source": [ + "l = 184.56\n", + "b = -5.78\n", + "\n", + "alpha = -1.99\n", + "beta = -2.32\n", + "xp = 531. * u.keV * (alpha + 2)\n", + "piv = 500. * u.keV\n", + "K = 3.07e-5 / u.cm / u.cm / u.s / u.keV\n", + "\n", + "spectrum = Band()\n", + "\n", + "spectrum.alpha.min_value = -2.14\n", + "spectrum.alpha.max_value = 3.0\n", + "spectrum.beta.min_value = -5.0\n", + "spectrum.beta.max_value = -2.15\n", + "spectrum.xp.min_value = 1.0\n", + "spectrum.alpha.delta = 0.01\n", + "spectrum.beta.delta = 0.01\n", + "\n", + "spectrum.alpha.value = alpha\n", + "spectrum.beta.value = beta\n", + "spectrum.xp.value = xp.value\n", + "spectrum.K.value = K.value\n", + "spectrum.piv.value = piv.value\n", + "\n", + "spectrum.xp.unit = xp.unit\n", + "spectrum.K.unit = K.unit\n", + "spectrum.piv.unit = piv.unit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f3af0aa", + "metadata": {}, + "outputs": [], + "source": [ + "spectrum_inj = deepcopy(spectrum)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3de7c067-204b-491f-9bea-4e44929dcfb9", + "metadata": {}, + "outputs": [], + "source": [ + "source = PointSource(\"Crab\",\n", + " l=l,\n", + " b=b,\n", + " spectral_shape=spectrum)\n", + "model = Model(source)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e52ef27-a5e7-41d7-87fd-3a82836bc131", + "metadata": {}, + "outputs": [], + "source": [ + "response = CachedUnbinnedThreeMLModelFolding(psr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01771aa6-36b1-46fb-a331-4dffff30ac91", + "metadata": {}, + "outputs": [], + "source": [ + "expectation_density = SumExpectationDensity(response, bkg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a5fbeee-e1dc-4c7b-993b-80b65f5217f4", + "metadata": {}, + "outputs": [], + "source": [ + "like_fun = UnbinnedLikelihood(expectation_density)\n", + "cosi = ThreeMLPluginInterface('cosi', like_fun, response, bkg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13e8ad22-f5e4-41cd-88fd-a04be7ad8200", + "metadata": {}, + "outputs": [], + "source": [ + "bkg_norm = bkg.norm.to_value(u.Hz)\n", + "\n", + "cosi.bkg_parameter['bkg_norm'] = Parameter(\"bkg_norm\",\n", + " bkg_norm,\n", + " unit = u.Hz,\n", + " min_value=0,\n", + " max_value=100,\n", + " delta=0.05,\n", + " )\n", + "print(f\"The average background flux is {bkg_norm:.2f} Hz\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88ea41cb-867a-41e1-840a-a328f316d3c2", + "metadata": {}, + "outputs": [], + "source": [ + "plugins = DataList(cosi)\n", + "like = JointLikelihood(model, plugins, verbose=True) # You can disable debugging" + ] + }, + { + "cell_type": "markdown", + "id": "9bcad107-7f24-45c9-838c-bf5c672607dd", + "metadata": {}, + "source": [ + "### Initializing the Cache" + ] + }, + { + "cell_type": "markdown", + "id": "e5483878-ee8b-44c7-b2cd-49abfc07d0f5", + "metadata": {}, + "source": [ + "Here you could load the cache" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c87b4579-25dd-4e6d-be92-635d131f99d5", + "metadata": {}, + "outputs": [], + "source": [ + "# response.load_caches(data_path / \"Crab_tutorial\")" + ] + }, + { + "cell_type": "markdown", + "id": "ba814a8b-ad8a-41ef-ad44-c8794fd574f7", + "metadata": {}, + "source": [ + "The cache is initialized, which takes some time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37ab7f53-e474-4699-b7a6-5e4472ee3309", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Data Events: {data.nevents}\\nExpected Events: {expectation_density.expected_counts():.2f}\\nRelative Deviation {100 * (expectation_density.expected_counts()/data.nevents - 1):.3f} %\")" + ] + }, + { + "cell_type": "markdown", + "id": "28caa6c7-8775-4f4c-83ca-b27eb2918a90", + "metadata": {}, + "source": [ + "Now you could save the cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f64a3895-49ca-4e47-bc40-b41106ee270f", + "metadata": {}, + "outputs": [], + "source": [ + "response.save_caches(data_path / \"Crab_tutorial\")" + ] + }, + { + "cell_type": "markdown", + "id": "859f6453-b4be-4e1c-97a4-743f851b566d", + "metadata": {}, + "source": [ + "### Fitting" + ] + }, + { + "cell_type": "markdown", + "id": "4ea6ed61-8e10-49b8-a8d4-878f34b0b53d", + "metadata": {}, + "source": [ + "If the fit fails with \"Current minimum stored after fit ... and current ... do not correspond!\" simply rerun this cell (sometimes even that does not help)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d0f6e15-c963-4116-911c-f6ca11fa4292", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "like.fit()" + ] + }, + { + "cell_type": "markdown", + "id": "8eb8a59e", + "metadata": {}, + "source": [ + "Now we can plot the result and compare it with the injected spectrum." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "132119fa-6035-4a8d-a23d-550bdee2744f", + "metadata": {}, + "outputs": [], + "source": [ + "results = like.results\n", + "\n", + "parameters = {par.name: results.get_variates(par.path)\n", + " for par in results.optimized_model[\"Crab\"].parameters.values()\n", + " if par.free}\n", + "\n", + "results_err = results.propagate(results.optimized_model[\"Crab\"].spectrum.main.shape.evaluate_at, **parameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f68e51c9", + "metadata": {}, + "outputs": [], + "source": [ + "energy = np.geomspace(100*u.keV, 10*u.MeV).to_value(u.keV)\n", + "\n", + "flux_lo = np.zeros_like(energy)\n", + "flux_median = np.zeros_like(energy)\n", + "flux_hi = np.zeros_like(energy)\n", + "flux_inj = np.zeros_like(energy)\n", + "\n", + "for i, e in enumerate(energy):\n", + " flux = results_err(e)\n", + " flux_median[i] = flux.median\n", + " flux_lo[i], flux_hi[i] = flux.equal_tail_interval(cl=0.68)\n", + " flux_inj[i] = spectrum_inj.evaluate_at(e)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7500f4f7", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize = (9, 6))\n", + "\n", + "ax.plot(energy, energy**2 * flux_median, label = \"Best fit\")\n", + "ax.fill_between(energy, energy**2 * flux_lo, energy*energy*flux_hi, alpha = .5, label = \"Best fit (errors)\")\n", + "ax.plot(energy, energy**2 * flux_inj, color = 'black', ls = \":\", label = \"Injected\")\n", + "\n", + "ax.semilogx()\n", + "ax.semilogy()\n", + "\n", + "ax.set_xlabel(\"Energy [keV]\")\n", + "ax.set_ylabel(r\"$E^2 \\frac{\\mathrm{d}N}{\\mathrm{d}E}$ [keV cm$^{-2}$ s$^{-1}$]\")\n", + "\n", + "ax.legend();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index ce56e513c..93beb69fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ [project.optional-dependencies] # Machine learning stuff -ML = ["torch", "torch_geometric"] +ML = ["torch", "torch_geometric", "normflows", "sphericart[torch]"] [project.urls] Homepage = "https://github.com/cositools/cosipy"