From 4b0f5bacce35be5189316d46ff0e39e3ebc5ce94 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Mon, 2 Feb 2026 11:09:24 +0100 Subject: [PATCH 01/16] Committing my local changes to interfaces branch Adding my neural network response and psr implementation to cosipy interfaces Merged latest changes from Israels interfaces branch and applied my changes: Response implementation and PSR folding with EventData and PhotonList --- cosipy/data_io/EmCDSUnbinnedData.py | 2 +- cosipy/response/NNResponse.py | 375 ++++++++ cosipy/response/__init__.py | 2 + .../response/instrument_response_function.py | 64 +- cosipy/response/nnresponse_helper.py | 697 +++++++++++++++ cosipy/threeml/psr_fixed_ei.py | 822 +++++++++++++++++- 6 files changed, 1959 insertions(+), 3 deletions(-) create mode 100644 cosipy/response/NNResponse.py create mode 100644 cosipy/response/nnresponse_helper.py diff --git a/cosipy/data_io/EmCDSUnbinnedData.py b/cosipy/data_io/EmCDSUnbinnedData.py index 3316ab570..1c0d3e7f0 100644 --- a/cosipy/data_io/EmCDSUnbinnedData.py +++ b/cosipy/data_io/EmCDSUnbinnedData.py @@ -8,7 +8,7 @@ from numpy._typing import ArrayLike from scoords import SpacecraftFrame -from cosipy import UnBinnedData +from cosipy.data_io.UnBinnedData import UnBinnedData from cosipy.interfaces import EventWithEnergyInterface, EventDataInterface, EventDataWithEnergyInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface, EmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface, \ diff --git a/cosipy/response/NNResponse.py b/cosipy/response/NNResponse.py new file mode 100644 index 000000000..bfdcf5273 --- /dev/null +++ b/cosipy/response/NNResponse.py @@ -0,0 +1,375 @@ +import torch +from typing import List, Union +import torch.multiprocessing as mp +from .nnresponse_helper import * + +def cuda_cleanup_task(_) -> bool: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return True + +def update_worker_settings(args: Tuple[str, Union[int, CompileMode]]): + attr, value = args + global area_module + global density_module + + if attr == 'area_batch_size': + area_module.batch_size = value + elif attr == 'density_batch_size': + density_module.batch_size = value + elif attr == 'area_compile_mode': + area_module.compile_mode = value + elif attr == 'density_compile_mode': + density_module.compile_mode = value + +def init_worker(device_queue: mp.Queue, major_version: int, area_input: Dict, + density_input: Dict, area_batch_size: int, density_batch_size: int, + area_compile_mode: CompileMode, density_compile_mode: CompileMode): + global area_module + global density_module + global worker_device + + worker_device = torch.device(device_queue.get()) + if worker_device.type == 'cuda': + torch.cuda.set_device(worker_device) + + area_module = AreaApproximation(major_version, area_input, worker_device, area_batch_size, area_compile_mode) + density_module = DensityApproximation(major_version, density_input, worker_device, density_batch_size, density_compile_mode) + +def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + global area_module + context, indices = args + + sub_context = context[indices, :] + if torch.device(worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + + return area_module.evaluate_effective_area(sub_context) + +def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: + global density_module + context, source, indices = args + + sub_context = context[indices, :] + sub_source = source[indices, :] + if torch.device(worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + sub_source = sub_source.pin_memory() + + return density_module.evaluate_density(sub_context, sub_source) + +def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + global density_module + context, indices = args + + sub_context = context[indices, :] + if torch.device(worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + + return density_module.sample_density(sub_context) + +class DensityApproximation: + def __init__(self, major_version: int, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): + self._major_version = major_version + self._worker_device = worker_device + self._density_input = density_input + self._batch_size = batch_size + self._compile_mode = compile_mode + + self._setup_model() + + def _setup_model(self): + version_map = { + 1: UnpolarizedDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Density Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + self._expected_source_dim = self._model.source_dim + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + dim_source = source.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + elif dim_source != self._expected_source_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_source_dim} features, but source has {dim_source}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + list_source = [source[:, i] for i in range(dim_source)] + + return self._model.evaluate_density(*list_context, *list_source) + + def sample_density(self, context: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.sample_density(*list_context) + +class AreaApproximation: + def __init__(self, major_version: int, area_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): + self._major_version = major_version + self._worker_device = worker_device + self._area_input = area_input + self._batch_size = batch_size + self._compile_mode = compile_mode + + self._setup_model() + + def _setup_model(self): + version_map = { + 1: UnpolarizedAreaSphericalHarmonicsExpansion(self._area_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Effective Area Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + + def evaluate_effective_area(self, context: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.evaluate_effective_area(*list_context) + +class NNResponse: + def __init__(self, path_to_model: str, area_batch_size: int = 100_000, density_batch_size: int = 100_000, + devices: Optional[List[Union[str, int, torch.device]]] = None, + area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", density_compile_mode: CompileMode = "default"): + ckpt = torch.load(path_to_model, map_location=torch.device('cpu'), weights_only=False) + + required_keys = ['version', 'is_polarized', 'density_input', 'area_input'] + + for key in required_keys: + if key not in ckpt: + raise KeyError( + f"Invalid Checkpoint: Metadata key '{key}' not found in {path_to_model}. " + f"Ensure you saved the model as a dictionary, not just the state_dict." + ) + + self._version = ckpt['version'] + self._major_version = int(self._version.split('.')[0]) + self._is_polarized = ckpt['is_polarized'] + self._density_input = ckpt['density_input'] + self._area_input = ckpt['area_input'] + + self._pool = None + self._has_cuda = False + self._ctx = mp.get_context("spawn") + + self.area_batch_size = area_batch_size + self.density_batch_size = density_batch_size + self._area_compile_mode = area_compile_mode + self._density_compile_mode = density_compile_mode + + if devices is not None: + self.devices = devices + else: + self._devices = [] + + def __del__(self): + self.shutdown_compute_pool() + + @property + def devices(self) -> List[Union[str, int, torch.device]]: + return self._devices + + @devices.setter + def devices(self, value: List[Union[str, int, torch.device]]): + if not isinstance(value, list): + raise ValueError("devices must be a list of device identifiers") + self._devices = value + + @property + def is_polarized(self) -> bool: + return self._is_polarized + + @property + def area_batch_size(self) -> int: + return self._area_batch_size + + @area_batch_size.setter + def area_batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError("area_batch_size must be a positive integer") + self._area_batch_size = value + self._update_worker_config('area_batch_size', value) + + @property + def density_batch_size(self) -> int: + return self._density_batch_size + + @density_batch_size.setter + def density_batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError("density_batch_size must be a positive integer") + self._density_batch_size = value + self._update_worker_config('density_batch_size', value) + + @property + def area_compile_mode(self) -> CompileMode: return self._area_compile_mode + + @area_compile_mode.setter + def area_compile_mode(self, value: CompileMode): + self._area_compile_mode = value + self._update_worker_config('area_compile_mode', value) + + @property + def density_compile_mode(self) -> CompileMode: return self._density_compile_mode + + @density_compile_mode.setter + def density_compile_mode(self, value: CompileMode): + self._density_compile_mode = value + self._update_worker_config('density_compile_mode', value) + + def _update_worker_config(self, attr: str, value: Union[int, CompileMode]): + if self._pool is not None: + self._pool.map(update_worker_settings, [(attr, value)] * self._num_workers) + + def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): + active_devices = devices if devices is not None else self._devices + + if not active_devices: + raise RuntimeError("Cannot initialize pool: no devices provided as argument or set as fallback.") + + if self._pool: + print("Warning: Pool already initialized. Shutting down old pool first.") + self.shutdown_compute_pool() + + self._num_workers = len(active_devices) + self._has_cuda = any(torch.device(d).type == 'cuda' for d in active_devices) + + device_queue = self._ctx.Queue() + for d in active_devices: + device_queue.put(d) + + self._pool = self._ctx.Pool( + processes=self._num_workers, + initializer=init_worker, + initargs=(device_queue, self._major_version, self._area_input, self._density_input, + self._area_batch_size, self._density_batch_size, + self._area_compile_mode, self._density_compile_mode), + ) + + def clean_compute_pool(self): + if self._pool: + self._pool.map(cuda_cleanup_task, range(self._num_workers)) + + def shutdown_compute_pool(self): + if self._pool: + self._pool.close() + self._pool.join() + + self._num_workers = 0 + self._pool = None + self._has_cuda = None + + def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): + context.share_memory_() + #if self._has_cuda and not context.is_pinned(): + # context = context.pin_memory() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, idx) for idx in indices] + results = self._pool.map(sample_density_task, tasks) + + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() + + def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): + context.share_memory_() + #if self._has_cuda and not context.is_pinned(): + # context = context.pin_memory() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, idx) for idx in indices] + results = self._pool.map(evaluate_area_task, tasks) + + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): context.share_memory_() + if not source.is_shared(): source.share_memory_() + + #if self._has_cuda: + # if not context.is_pinned(): + # context = context.pin_memory() + # if not source.is_pinned(): + # source = source.pin_memory() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, source, idx) for idx in indices] + results = self._pool.map(evaluate_density_task, tasks) + + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() \ No newline at end of file diff --git a/cosipy/response/__init__.py b/cosipy/response/__init__.py index 271a8858c..aefd11d21 100644 --- a/cosipy/response/__init__.py +++ b/cosipy/response/__init__.py @@ -9,3 +9,5 @@ from .threeml_extended_source_response import * from .instrument_response import * from .rsp_to_arf_rmf import RspArfRmfConverter +from .NNResponse import NNResponse +from .instrument_response_function import * diff --git a/cosipy/response/instrument_response_function.py b/cosipy/response/instrument_response_function.py index 95962ed24..e12583fd3 100644 --- a/cosipy/response/instrument_response_function.py +++ b/cosipy/response/instrument_response_function.py @@ -1,6 +1,7 @@ import itertools from typing import Iterable, Tuple +import torch import numpy as np from astropy.coordinates import SkyCoord @@ -17,8 +18,69 @@ FarFieldSpectralInstrumentResponseFunctionInterface from cosipy.interfaces.photon_parameters import PhotonInterface, PhotonWithDirectionAndEnergyInSCFrameInterface, PhotonListWithDirectionInterface from cosipy.response import FullDetectorResponse +from cosipy.response.NNResponse import NNResponse from cosipy.util.iterables import itertools_batched - +from operator import attrgetter + +class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldInstrumentResponseFunctionInterface): + + photon_type = PhotonWithDirectionAndEnergyInSCFrameInterface + event_type = EmCDSEventInSCFrameInterface + + def __init__(self, response: NNResponse,): + if response.is_polarized: + raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") + self._response = response + + def effective_area_cm2(self, photons: Iterable[PhotonWithDirectionAndEnergyInSCFrameInterface]) -> Iterable[float]: + getter = attrgetter('direction_lon_radians', 'direction_lat_radians', 'energy_keV') + + raw_data = list(map(getter, photons)) + + if not raw_data: + return np.array([], dtype=np.float32) + + context = torch.tensor(raw_data, dtype=torch.float32) + context[:, 1] = np.pi/2 - context[:, 1] + + return self._response.evaluate_effective_area(context).numpy() + + def event_probability(self, query: Iterable[Tuple[PhotonWithDirectionAndEnergyInSCFrameInterface, EmCDSEventInSCFrameInterface]]) -> Iterable[float]: + context_list = [] + source_list = [] + + for photon, event in query: + context_list.append((photon.direction_lon_radians, photon.direction_lat_radians, photon.energy_keV)) + source_list.append((event.energy_keV, event.scattering_angle_rad, event.scattered_lon_rad_sc, event.scattered_lat_rad_sc)) + + if not context_list: + return np.array([], dtype=np.float32) + + context = torch.tensor(context_list, dtype=torch.float32) + source = torch.tensor(source_list, dtype=torch.float32) + + context[:, 1] = np.pi/2 - context[:, 1] + source[:, 3] = np.pi/2 - source[:, 3] + + return self._response.evaluate_density(context, source).numpy() + + def random_events(self, photons: Iterable[PhotonWithDirectionAndEnergyInSCFrameInterface]) -> Iterable[EventInterface]: + getter = attrgetter('direction_lon_radians', 'direction_lat_radians', 'energy_keV') + + raw_data = list(map(getter, photons)) + + if not raw_data: + return [] + + context = torch.tensor(raw_data, dtype=torch.float32) + context[:, 1] = np.pi/2 - context[:, 1] + + samples = self._response.sample_density(context).numpy() + samples[:, 3] = np.pi/2 - samples[:, 3] + + return [ + EmCDSEventInSCFrame(e, phi, lon, lat) for e, phi, lon, lat in samples + ] class UnpolarizedDC3InterpolatedFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): diff --git a/cosipy/response/nnresponse_helper.py b/cosipy/response/nnresponse_helper.py new file mode 100644 index 000000000..0339bf8dc --- /dev/null +++ b/cosipy/response/nnresponse_helper.py @@ -0,0 +1,697 @@ +import normflows as nf +import numpy as np +import torch +from torch import nn +import healpy as hp +import sphericart.torch +from typing import Protocol, Optional, Literal, List, Union, Tuple, Dict + +CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] + +def build_cmlp_diaggaussian_base(input_dim: int, output_dim: int, + hidden_dim: int, num_hidden_layers: int) -> nf.distributions.BaseDistribution: + context_encoder = BaseMLP(input_dim, output_dim, hidden_dim, num_hidden_layers) + return ConditionalDiagGaussian(shape=output_dim//2, context_encoder=context_encoder) + +def build_c_arqs_flow(base: nf.distributions.BaseDistribution, num_layers: int, + latent_dim: int, context_dim: int, num_bins: int, + num_hidden_units: int, num_residual_blocks: int) -> nf.ConditionalNormalizingFlow: + flows = [] + for _ in range(num_layers): + flows += [nf.flows.AutoregressiveRationalQuadraticSpline(num_input_channels = latent_dim, + num_blocks = num_residual_blocks, + num_hidden_channels = num_hidden_units, + num_bins = num_bins, + num_context_channels = context_dim)] + flows += [nf.flows.LULinearPermute(latent_dim)] + return nf.ConditionalNormalizingFlow(base, flows) + +class NNDensityInferenceWrapper(nn.Module): + def __init__(self, model: nn.Module): + super().__init__() + self._model = model + + def forward(self, + source: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + n_samples: Optional[int] = None, + mode: str = "inference") -> torch.Tensor: + if mode == "inference": + if context is None: + return torch.exp(self._model.log_prob(source)) + else: + return torch.exp(self._model.log_prob(source, context)) + elif mode == "sampling": + if context is None: + return self._model.sample(num_samples=n_samples)[0] + else: + return self._model.sample(num_samples=n_samples, context=context)[0] + +class ConditionalDiagGaussian(nf.distributions.BaseDistribution): + def __init__(self, shape: Union[int, List[int], Tuple[int, ...]], context_encoder: nn.Module): + super().__init__() + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, list): + shape = tuple(shape) + self.shape = shape + self.n_dim = len(shape) + self.d = np.prod(shape) + self.context_encoder = context_encoder + + def forward(self, num_samples: int=1, context: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, torch.Tensor]: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + eps = torch.randn( + (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device + ) + z = mean + torch.exp(log_scale) * eps + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1)) + ) + return z, log_p + + def log_prob(self, z: torch.Tensor, context: Optional[torch.Tensor]=None) -> torch.Tensor: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2), + list(range(1, self.n_dim + 1)), + ) + return log_p + +class BaseMLP(nn.Module): + def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int): + super().__init__() + + layers = [] + + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.ReLU()) + + for _ in range(num_hidden_layers): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.ReLU()) + + layers.append(nn.Linear(hidden_dim, output_dim)) + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + +class AreaModelProtocol(Protocol): + @property + def context_dim(self) -> int: ... + + @property + def compile_mode(self) -> CompileMode: ... + + @compile_mode.setter + def compile_mode(self, value: CompileMode): ... + + @property + def batch_size(self) -> int: ... + + @batch_size.setter + def batch_size(self, value: int): ... + + def evaluate_effective_area(self, *args: torch.Tensor) -> torch.Tensor: ... + +class UnpolarizedAreaSphericalHarmonicsExpansion(AreaModelProtocol): + def __init__(self, area_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "max-autotune-no-cudagraphs"): + self._worker_device = torch.device(worker_device) + + self._lmax = area_input['lmax'] + self._poly_degree = area_input['poly_degree'] + self._poly_coeffs = area_input['poly_coeffs'] + + self._conv_coeffs = self._convert_coefficients().to(self._worker_device) + self._sh_calculator = sphericart.torch.SphericalHarmonics(self._lmax) + + self._compile_mode = compile_mode + self._compiled_cache = {} + + self._update_horner_op() + + self._is_cuda = (self._worker_device.type == 'cuda') + self.batch_size = batch_size + + if self._is_cuda: + self._compute_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] + self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] + else: + self._compute_stream = None + self._transfer_stream = None + self._transfer_ready = None + self._compute_ready = None + + def _write_gpu_tensors(self): + self._gpu_inputs = [ + (torch.empty(self._batch_size, device=self._worker_device), + torch.empty(self._batch_size, device=self._worker_device), + torch.empty(self._batch_size, device=self._worker_device)) + for _ in range(2) + ] + self._gpu_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] + + @property + def context_dim(self) -> int: + return 3 + + @property + def compile_mode(self) -> CompileMode: + return self._compile_mode + + @compile_mode.setter + def compile_mode(self, value: CompileMode): + if value != self._compile_mode: + self._compile_mode = value + self._update_horner_op() + + def _update_horner_op(self): + if self._compile_mode is None: + self._horner_op = self._horner_eval + else: + if self._compile_mode not in self._compiled_cache: + self._compiled_cache[self._compile_mode] = torch.compile( + self._horner_eval, + mode=self._compile_mode + ) + self._horner_op = self._compiled_cache[self._compile_mode] + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError(f"Batch size must be a positive integer, got {value}") + self._batch_size = value + if self._is_cuda: + self._write_gpu_tensors() + + def _convert_coefficients(self) -> torch.Tensor: + num_sh = (self._lmax + 1)**2 + conv_coeffs = torch.zeros((num_sh, self._poly_degree + 1), dtype=torch.float64) + + for cnt, (l, m) in enumerate((l, m) for l in range(self._lmax + 1) for m in range(-l, l + 1)): + idx = hp.Alm.getidx(self._lmax, l, abs(m)) + if m == 0: + conv_coeffs[cnt] = self._poly_coeffs[0, :, idx] + else: + fac = np.sqrt(2) * (-1)**m + val = self._poly_coeffs[0, :, idx] if m > 0 else -self._poly_coeffs[1, :, idx] + conv_coeffs[cnt] = fac * val + return conv_coeffs.T + + def _horner_eval(self, x: torch.Tensor) -> torch.Tensor: + x_64 = x.to(torch.float64).unsqueeze(1) + result = self._conv_coeffs[0].expand(x.shape[0], -1).clone() + for i in range(1, self._conv_coeffs.size(0)): + result.mul_(x_64).add_(self._conv_coeffs[i]) + return result.to(torch.float32) + + def _compute_spherical_harmonics(self, dir_az: torch.Tensor, dir_polar: torch.Tensor) -> torch.Tensor: + sin_p = torch.sin(dir_polar) + xyz = torch.stack(( + sin_p * torch.cos(dir_az), + sin_p * torch.sin(dir_az), + torch.cos(dir_polar) + ), dim=-1) + return self._sh_calculator(xyz) + + @torch.inference_mode() + def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: + N = energy_keV.shape[0] + + ei_norm = (torch.log10(energy_keV) / 2 - 1).to(torch.float32) + result = torch.empty(N, dtype=torch.float32, device="cpu") + + def get_batch(start_idx): + end_idx = min(start_idx + self._batch_size, N) + return ( + ei_norm[start_idx:end_idx].to(self._worker_device), + dir_az[start_idx:end_idx].to(self._worker_device), + dir_polar[start_idx:end_idx].to(self._worker_device) + ) + + if self._is_cuda: + ei_norm = ei_norm.pin_memory() + result = result.pin_memory() + + def enqueue_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + self._gpu_inputs[slot_idx][0][:size].copy_(ei_norm[start_idx:end_idx], non_blocking=True) + self._gpu_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) + self._gpu_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) + + if self._is_cuda and (N > 0): + with torch.cuda.stream(self._transfer_stream): + enqueue_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + ei_b, az_b, pol_b = [t[:batch_len] for t in self._gpu_inputs[curr_idx]] + + poly_b = self._horner_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + + torch.sum(poly_b * ylm_b, dim=1, out=self._gpu_results[curr_idx][:batch_len]) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_transfer(next_idx, next_start) + + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + result[start:end].copy_(self._gpu_results[curr_idx][:batch_len], non_blocking=True) + else: + ei_b, az_b, pol_b = get_batch(start) + + poly_b = self._horner_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + result[start:end] = torch.sum(poly_b * ylm_b, dim=1) + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + + return torch.clamp(result, min=0) + +class DensityModelProtocol(Protocol): + @property + def context_dim(self) -> int: ... + + @property + def source_dim(self) -> int: ... + + @property + def compile_mode(self) -> CompileMode: ... + + @compile_mode.setter + def compile_mode(self, value: CompileMode): ... + + @property + def batch_size(self) -> int: ... + + @batch_size.setter + def batch_size(self, value: int): ... + + def sample_density(self, *args: torch.Tensor) -> torch.Tensor: ... + + def evaluate_density(self, *args: torch.Tensor) -> torch.Tensor: ... + +class UnpolarizedDensityCMLPDGaussianCARQSFlow(DensityModelProtocol): + def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "default"): + self._worker_device = torch.device(worker_device) + + self._snapshot = density_input["model_state_dict"] + self._bins = density_input["bins"] + self._hidden_units = density_input["hidden_units"] + self._residual_blocks = density_input["residual_blocks"] + self._total_layers = density_input["total_layers"] + self._context_size = density_input["context_size"] + self._latent_size = density_input["latent_size"] + self._mlp_hidden_units = density_input["mlp_hidden_units"] + self._mlp_hidden_layers = density_input["mlp_hidden_layers"] + + self._compile_mode = compile_mode + self._compiled_cache = {} + + self._eager_model = self._init_base_model() + self._update_model_op() + + self._is_cuda = (self._worker_device.type == 'cuda') + self.batch_size = batch_size + + if self._is_cuda: + self._compute_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] + self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] + else: + self._compute_stream = None + self._transfer_stream = None + self._transfer_ready = None + self._compute_ready = None + + def _write_gpu_tensors(self): + self._eval_inputs = [ + tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.source_dim + self.context_dim)) + for _ in range(2) + ] + self._eval_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] + + self._sample_inputs = [ + tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.context_dim)) + for _ in range(2) + ] + + self._sample_results = [ + (torch.empty((self._batch_size, self._latent_size), device=self._worker_device), + torch.empty(self._batch_size, dtype=torch.bool, device=self._worker_device)) + for _ in range(2) + ] + + @property + def context_dim(self) -> int: + return 3 + + @property + def source_dim(self) -> int: + return 4 + + @property + def compile_mode(self) -> CompileMode: + return self._compile_mode + + @compile_mode.setter + def compile_mode(self, value: CompileMode): + if value != self._compile_mode: + self._compile_mode = value + self._update_model_op() + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError(f"Batch size must be a positive integer, got {value}") + self._batch_size = value + if self._is_cuda: + self._write_gpu_tensors() + + def _build_model(self) -> nf.ConditionalNormalizingFlow: + base = build_cmlp_diaggaussian_base( + self._context_size, 2 * self._latent_size, self._mlp_hidden_units, self._mlp_hidden_layers + ) + return build_c_arqs_flow( + base, self._total_layers, self._latent_size, self._context_size, self._bins, self._hidden_units, self._residual_blocks + ) + + def _init_base_model(self) -> NNDensityInferenceWrapper: + model = self._build_model() + + model.load_state_dict(self._snapshot) + model = NNDensityInferenceWrapper(model) + model.eval() + model.to(self._worker_device) + + return model + + def _update_model_op(self): + if self._compile_mode is None: + self._model_op = self._eager_model + else: + if self._compile_mode not in self._compiled_cache: + self._compiled_cache[self._compile_mode] = torch.compile( + self._eager_model, + mode=self._compile_mode + ) + self._model_op = self._compiled_cache[self._compile_mode] + + @staticmethod + def _get_vector(phi_sc: torch.Tensor, theta_sc: torch.Tensor) -> torch.Tensor: + x = theta_sc[:, 0] * phi_sc[:, 1] + y = theta_sc[:, 0] * phi_sc[:, 0] + z = theta_sc[:, 1] + return torch.stack((x, y, z), dim=-1) + + def _convert_conventions(self, dir_az_sc: torch.Tensor, dir_polar_sc: torch.Tensor, + ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, + scatt_az_sc: torch.Tensor, scatt_polar_sc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + eps = em / ei - 1 + + source = self._get_vector(dir_az_sc, dir_polar_sc) + scatter = self._get_vector(scatt_az_sc, scatt_polar_sc) + + dot_product = torch.sum(source * scatter, dim=1) + phi_geo = torch.acos(torch.clamp(dot_product, -1.0, 1.0)) + theta = phi_geo - phi + + xaxis = torch.tensor([1., 0., 0.], device=source.device, dtype=source.dtype) + pz = -source + + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + proj_x = torch.sum(scatter * px, dim=1) + proj_y = torch.sum(scatter * py, dim=1) + + zeta = torch.atan2(proj_y, proj_x) + zeta = torch.where(zeta < 0, zeta + 2 * np.pi, zeta) + + return eps, theta, zeta + + def _inverse_transform_coordinates(self, samples: torch.Tensor, ei: torch.Tensor, + dir_az_sc: torch.Tensor, dir_pol_sc: torch.Tensor) -> torch.Tensor: + eps = -samples[:, 0] + phi = samples[:, 1] * np.pi + theta = (samples[:, 2] - 0.5) * (2 * np.pi) + zeta = samples[:, 3] * (2 * np.pi) + + em = ei * (eps + 1) + + phi_geo = theta + phi + scatter_phf = self._get_vector(torch.stack((torch.sin(zeta), torch.cos(zeta)), dim=1), + torch.stack((torch.sin(np.pi - phi_geo), torch.cos(np.pi - phi_geo)), dim=1)) + source_vec = self._get_vector(dir_az_sc, dir_pol_sc) + xaxis = torch.tensor([1., 0., 0.], device=self._worker_device, dtype=source_vec.dtype) + + pz = -source_vec + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + basis = torch.stack((px, py, pz), dim=2) + scatter_scf = torch.bmm(basis, scatter_phf.unsqueeze(-1)).squeeze(-1) + + psi_cds = torch.atan2(scatter_scf[:, 1], scatter_scf[:, 0]) + psi_cds = torch.where(psi_cds < 0, psi_cds + 2 * np.pi, psi_cds) + chi_cds = torch.acos(torch.clamp(scatter_scf[:, 2], -1.0, 1.0)) + + return torch.stack([em, phi, psi_cds, chi_cds], dim=1) + + def _transform_coordinates(self, dir_az: torch.Tensor, dir_pol: torch.Tensor, + ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, + scatt_az: torch.Tensor, scatt_pol: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) + scatt_az_sc = torch.stack((torch.sin(scatt_az), torch.cos(scatt_az)), dim=1) + scatt_pol_sc = torch.stack((torch.sin(scatt_pol), torch.cos(scatt_pol)), dim=1) + + eps_raw, theta_raw, zeta_raw = self._convert_conventions( + dir_az_sc, dir_pol_sc, ei, em, phi, scatt_az_sc, scatt_pol_sc + ) + + jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) + + ctx = torch.cat([ + (dir_az_sc + 1) / 2, + (dir_pol_sc[:, 1:] + 1) / 2, + (torch.log10(ei) / 2 - 1).unsqueeze(1) + ], dim=1) + + src = torch.cat([ + (-eps_raw).unsqueeze(1), + (phi / np.pi).unsqueeze(1), + (theta_raw / (2 * np.pi) + 0.5).unsqueeze(1), + (zeta_raw / (2 * np.pi)).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) + + @staticmethod + def _valid_samples(samples: torch.Tensor) -> torch.Tensor: + phi_geo_norm = samples[:, 1] + 2 * samples[:, 2] - 1.0 + valid_mask = (samples[:, 0] < 1.0) & \ + (samples[:, 1] > 0.0) & (samples[:, 1] <= 1.0) & \ + (samples[:, 2] >= 0.0) & (samples[:, 2] <= 1.0) & \ + (samples[:, 3] >= 0.0) & (samples[:, 3] <= 1.0) & \ + (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) + + return valid_mask + + @torch.inference_mode() + def sample_density(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: + N = dir_az.shape[0] + + result = torch.empty((N, self._latent_size), dtype=torch.float32, device="cpu") + failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") + + if self._is_cuda: + result, failed_mask = result.pin_memory(), failed_mask.pin_memory() + + def enqueue_sample_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + self._sample_inputs[slot_idx][0][:size].copy_(energy_keV[start_idx:end_idx], non_blocking=True) + self._sample_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) + self._sample_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) + + if self._is_cuda and N > 0: + with torch.cuda.stream(self._transfer_stream): + enqueue_sample_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + b_ei, b_az, b_pol = [t[:batch_len] for t in self._sample_inputs[curr_idx]] + + b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) + b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) + + b_ctx = torch.cat([ + (b_az_sc + 1) / 2, + (b_pol_sc[:, 1:] + 1) / 2, + (torch.log10(b_ei) / 2 - 1).unsqueeze(1) + ], dim=1).to(torch.float32) + + b_latent = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) + + self._sample_results[curr_idx][0][:batch_len] = self._inverse_transform_coordinates( + b_latent, b_ei, b_az_sc, b_pol_sc + ) + self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(b_latent) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_sample_transfer(next_idx, next_start) + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + + result[start:end].copy_(self._sample_results[curr_idx][0][:batch_len], non_blocking=True) + failed_mask[start:end].copy_(self._sample_results[curr_idx][1][:batch_len], non_blocking=True) + else: + b_ei = energy_keV[start:end].to(self._worker_device) + b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) + + b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) + b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) + b_ctx = torch.cat([ + (b_az_sc + 1) / 2, (b_pol_sc[:, 1:] + 1) / 2, + (torch.log10(b_ei) / 2 - 1).unsqueeze(1) + ], dim=1).to(torch.float32) + + b_samples = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) + result[start:end] = self._inverse_transform_coordinates(b_samples, b_ei, b_az_sc, b_pol_sc) + failed_mask[start:end] = ~self._valid_samples(b_samples) + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + + if torch.any(failed_mask): + result[failed_mask] = self.sample_density( + dir_az[failed_mask], dir_polar[failed_mask], energy_keV[failed_mask] + ) + + return result + + @torch.inference_mode() + def evaluate_density( + self, dir_az: torch.Tensor, dir_polar: torch.Tensor, + energy_keV: torch.Tensor, menergy_keV: torch.Tensor, + scatt_angle: torch.Tensor, scatt_az: torch.Tensor, + scatt_polar: torch.Tensor) -> torch.Tensor: + + N = dir_az.shape[0] + result = torch.empty(N, dtype=torch.float32, device="cpu") + + if self._is_cuda: + result = result.pin_memory() + + def enqueue_eval_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + tensors = [dir_az, dir_polar, energy_keV, menergy_keV, scatt_angle, scatt_az, scatt_polar] + for i in range(self.source_dim + self.context_dim): + self._eval_inputs[slot_idx][i][:size].copy_(tensors[i][start_idx:end_idx], non_blocking=True) + + if self._is_cuda and N > 0: + with torch.cuda.stream(self._transfer_stream): + enqueue_eval_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + ctx, src, jac = self._transform_coordinates(*[t[:batch_len] for t in self._eval_inputs[curr_idx]]) + + torch.mul(self._model_op(src, ctx, mode="inference"), jac, out=self._eval_results[curr_idx][:batch_len]) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_eval_transfer(next_idx, next_start) + + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + + result[start:end].copy_(self._eval_results[curr_idx][:batch_len], non_blocking=True) + else: + b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) + b_ei, b_em = energy_keV[start:end].to(self._worker_device), menergy_keV[start:end].to(self._worker_device) + b_phi = scatt_angle[start:end].to(self._worker_device) + b_s_az, b_s_pol = scatt_az[start:end].to(self._worker_device), scatt_polar[start:end].to(self._worker_device) + + ctx, src, jac = self._transform_coordinates(b_az, b_pol, b_ei, b_em, b_phi, b_s_az, b_s_pol) + result[start:end] = self._model_op(src, ctx, mode="inference") * jac + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + return result \ No newline at end of file diff --git a/cosipy/threeml/psr_fixed_ei.py b/cosipy/threeml/psr_fixed_ei.py index e06af95c4..edcbe6148 100644 --- a/cosipy/threeml/psr_fixed_ei.py +++ b/cosipy/threeml/psr_fixed_ei.py @@ -1,7 +1,13 @@ import copy -from typing import Optional, Iterable, Type +import os +import json +from typing import Optional, Iterable, Type, Tuple, List, Union, Dict, Any +from itertools import chain, repeat, islice +from operator import attrgetter +import torch import numpy as np +import h5py from astromodels import PointSource from astropy.coordinates import UnitSphericalRepresentation, CartesianRepresentation from astropy.units import Quantity @@ -17,6 +23,9 @@ from cosipy.response.photon_types import PhotonWithDirectionAndEnergyInSCFrame from astropy import units as u +import astropy.constants as c +from astropy.coordinates import SkyCoord +from astropy.time import Time class UnbinnedThreeMLPointSourceResponseTrapz(UnbinnedThreeMLSourceResponseInterface): @@ -210,3 +219,814 @@ def event_probability(self) -> Iterable[float]: self._update_cache() return self._event_prob + + +class UnbinnedThreeMLPointSourceResponseIRFAdaptive(UnbinnedThreeMLSourceResponseInterface): + + def __init__(self, + data: TimeTagEmCDSEventDataInSCFrameInterface, + irf: FarFieldInstrumentResponseFunctionInterface, + sc_history: SpacecraftHistory,): + + """ + Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features + Note that this assumes a smooth flux spectrum + + All IRF queries are cached and can be saved to / loaded from a file + + Parameters + ---------- + data + irf + sc_history + """ + + # Interface inputs + self._source = None + + # Other implementation inputs + self._data = data + self._irf = irf + self._sc_ori = sc_history + + # Default parameters for irf energy node placement + self._total_energy_nodes = (60, 500) + self._peak_nodes = (18, 12) + self._peak_widths = (0.04, 0.1) + self._energy_range = (100., 10_000.) + self._batch_size = 1_000_000 + + # Placeholder for node pool - stored as Tensors + self._width_tensor: Optional[torch.Tensor] = None + self._nodes_primary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_secondary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + self._nodes_bkg_1: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_bkg_2: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + self._nodes_bkg_3: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + + # Checks to avoid unecessary recomputations + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + # Cached values + self._irf_cache: Optional[torch.Tensor] = None # cm^2/rad/sr + self._irf_energy_node_cache: Optional[np.ndarray] = None # (Optional, only if full batch) + self._area_cache: Optional[np.ndarray] = None # cm^2*s*keV + self._area_energy_node_cache: Optional[np.ndarray] = None + self._exp_events: Optional[float] = None + self._exp_density: Optional[np.ndarray] = None + + # Precomputed spacecraft history + self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 + self._sc_ori_center = self._sc_ori.interp(self._mid_times) + + data_times = self._data.time + self._n_events = self._data.nevents + self._unique_mjds, self._inv_idx = np.unique(data_times.mjd, return_inverse=True) + unique_times_obj = Time(self._unique_mjds, format='mjd') + + self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) + + unique_ratio = np.interp(self._unique_mjds, + self._mid_times.mjd, + self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + + self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + + #wrong_order = np.where(((data_times[1:] - data_times[:-1]) <= 0))[0] + #data_times[wrong_order + 1] = data_times[wrong_order + 1] + 1 + #self._sc_ori_data = self._sc_ori.interp(data_times) + + #ratio = np.interp(self._data.time.mjd, + # self._mid_times.mjd, + # self._sc_ori.livetime.to_value(u.s)/self._sc_ori.intervals_duration.to_value(u.s)) + #self._livetime_ratio = ratio.astype(np.float32) + + @property + def event_type(self) -> Type[EventInterface]: + return TimeTagEmCDSEventInSCFrameInterface + + def set_integration_parameters(self, + total_energy_nodes: Tuple[int, int] = (60, 500), + peak_nodes: Tuple[int, int] = (18, 12), + peak_widths: Tuple[float, float] = (0.04, 0.1), + energy_range: Tuple[float, float] = (100., 10_000.), + batch_size: int = 1_000_000,): + + # Reset caches if parameters change + if (peak_nodes != self._peak_nodes + or + peak_widths != self._peak_widths + or + total_energy_nodes[0] != self._total_energy_nodes[0]): + self._irf_cache = None + self._irf_energy_node_cache = None + self._width_tensor = None + self._nodes_primary = None + self._nodes_secondary = None + self._nodes_bkg_1 = None + self._nodes_bkg_2 = None + self._nodes_bkg_3 = None + + if (total_energy_nodes[1] != self._total_energy_nodes[1]): + self._area_cache = None + self._area_energy_node_cache = None + + if (energy_range != self._energy_range): + self._irf_cache = None + self._irf_energy_node_cache = None + self._area_cache = None + self._area_energy_node_cache = None + + if total_energy_nodes[0] < (peak_nodes[0] + 2 * peak_nodes[1] + 3): + raise ValueError("To many nodes per peak compared to the total number or peaks!") + + if (total_energy_nodes[0] < 1) or (total_energy_nodes[1] < 1): + raise ValueError("The number of energy nodes must be at least 1.") + + if energy_range[0] >= energy_range[1]: + raise ValueError("The initial energy interval needs to be increasing!") + + if (batch_size < total_energy_nodes[0]) or (batch_size < total_energy_nodes[1]): + raise ValueError("The batch size cannot be smaller than the number of integration nodes.") + + self._total_energy_nodes = total_energy_nodes + self._peak_nodes = peak_nodes + self._peak_widths = peak_widths + self._energy_range = energy_range + self._batch_size = batch_size + + @staticmethod + def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: + x, w = np.polynomial.legendre.leggauss(degree) + return torch.as_tensor(x, dtype=torch.float32).unsqueeze(0), torch.as_tensor(w, dtype=torch.float32).unsqueeze(0) + + def _build_split_nodes(self, remaining: int, groups: int): + q, r = divmod(remaining, groups) + return [self._build_nodes(q + (1 if i < r else 0)) for i in range(groups)] + + def _init_node_pool(self): + self._width_tensor = torch.tensor([self._peak_widths[0], self._peak_widths[0], + self._peak_widths[1], self._peak_widths[1]], dtype=torch.float32) + + self._nodes_primary = self._build_nodes(self._peak_nodes[0]) + self._nodes_secondary = self._build_nodes(self._peak_nodes[1]) + + self._nodes_bkg_1 = self._build_nodes(self._total_energy_nodes[0] - self._peak_nodes[0]) + + self._nodes_bkg_2 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - self._peak_nodes[1], 2 + ) + + self._nodes_bkg_3 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - 2 * self._peak_nodes[1], 3 + ) + + @staticmethod + def _scale_nodes_exp(E1: torch.Tensor, E2: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + diff = E2 - E1 + + out_n = (nodes_u + 1).mul(0.5).pow(2).mul(diff).add(E1) + out_w = (nodes_u + 1).mul(0.5).mul(weights_u).mul(diff) + + return out_n, out_w + + @staticmethod + def _scale_nodes_center(E1: torch.Tensor, E2: torch.Tensor, EC: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask_left = (nodes_u < 0) + width_left = (EC - E1) + width_right = (E2 - EC) + + scale = torch.where(mask_left, width_left, width_right) + + out_n = nodes_u.pow(3).mul(scale).add(EC) + out_w = nodes_u.pow(2).mul(3).mul(weights_u).mul(scale) + + return out_n, out_w + + def _get_escape_peak(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor) -> torch.Tensor: + E2 = 511.0 / (1.0 + 511.0 / energy_m_keV - torch.cos(phi_rad)) + energy = energy_m_keV + 1022.0 - E2 + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > 1600.0) & (energy_m_keV < energy) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def _get_missing_energy_peak(self, phi_geo_rad: torch.Tensor, energy_m_keV: torch.Tensor, + phi_rad: torch.Tensor, inverse: bool = False) -> torch.Tensor: + cos_geo = torch.cos(phi_geo_rad) + cos_phi = torch.cos(phi_rad) + + if inverse: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV * (cos_geo - 1) * (-2044.0 - 5 * energy_m_keV + energy_m_keV * cos_geo + 4 * energy_m_keV * cos_phi)) + energy = 511.0 * (energy_m_keV - energy_m_keV * cos_geo + root) / denom + else: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV**2 * (cos_geo - 1) * (cos_phi - 1) * + ((1022.0 + energy_m_keV)**2 - energy_m_keV * (2044.0 + energy_m_keV) * cos_phi - 2 * energy_m_keV**2 * cos_geo * torch.sin(phi_rad/2)**2)) + energy = (energy_m_keV**2 * (1 - cos_geo - cos_phi + cos_phi * cos_geo) + root) / denom + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > energy_m_keV) & (energy_m_keV/energy - 1 < -0.2) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def init_cache(self): + self._update_cache() + + def clear_cache(self): + self._irf_cache = None + self._irf_energy_node_cache = None + self._area_cache = None + self._area_energy_node_cache = None + self._exp_events = None + self._exp_density = None + + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + def set_source(self, source: Source): + if not isinstance(source, PointSource): + raise TypeError("Please provide a PointSource!") + + self._source = source + + def copy(self) -> "ThreeMLSourceResponseInterface": + new_instance = copy.copy(self) + new_instance.clear_cache() + new_instance._source = None + + return new_instance + + @staticmethod + def _earth_occ(source_coord: SkyCoord, ori: SpacecraftHistory) -> np.ndarray: + gcrs_cart = ori.location.represent_as(CartesianRepresentation) + dist_earth_center = gcrs_cart.norm() + max_angle = np.pi*u.rad - np.arcsin(c.R_earth/dist_earth_center) + src_angle = source_coord.separation(ori.earth_zenith) + return (src_angle < max_angle).astype(np.float32) + + def _compute_area(self): + coord = self._source.position.sky_coord + n_energy = self._total_energy_nodes[1] + + log_E_min = np.log10(self._energy_range[0]) + log_E_max = np.log10(self._energy_range[1]) + + x, w = np.polynomial.legendre.leggauss(n_energy) + + scale = 0.5 * (log_E_max - log_E_min) + y_nodes = scale * x + 0.5 * (log_E_max + log_E_min) + self._area_energy_node_cache = 10**y_nodes + + e_w = (np.log(10) * self._area_energy_node_cache * (w * scale)).astype(np.float32).reshape(1, -1) + e_n = self._area_energy_node_cache.astype(np.float32) + + sc_coord_sph = self._sc_ori_center.get_target_in_sc_frame(coord) + earth_occ_index = self._earth_occ(coord, self._sc_ori_center) + + time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index + + lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) + lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + + n_time = len(lon_ph_rad) + batch_size_time = self._batch_size // n_energy + + total_area = np.zeros(n_energy, dtype=np.float64) + + for i in range(0, n_time, batch_size_time): + print("start_area_loop") + start = i + end = min(i + batch_size_time, n_time) + current_n_time = end - start + + batch_lons = np.repeat(lon_ph_rad[start:end], n_energy) + batch_lats = np.repeat(lat_ph_rad[start:end], n_energy) + batch_energies = np.tile(e_n, current_n_time) + + photons = [ + PhotonWithDirectionAndEnergyInSCFrame(l, b, e) + for l, b, e in zip(batch_lons, batch_lats, batch_energies) + ] + + print("start_area_calc") + eff_areas_flat = np.fromiter( + self._irf.effective_area_cm2(photons), + dtype=np.float32, + count=len(photons) + ) + print("stop_area_calc") + + eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) + t_w_batch = time_weights[start:end] + e_w_flat = e_w.ravel() + + total_area += np.einsum('ij,i,j->j', + eff_areas_grid, + t_w_batch, + e_w_flat) + + self._area_cache = total_area + + def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, + indices: torch.Tensor, mode: int, + sorted_peaks: torch.Tensor, delta: torch.Tensor): + + Emin, Emax = self._energy_range + + if mode == 1: + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=Emax) + + EC = sorted_peaks[:, 0] + + E1, E2, EC = [E.view(-1, 1) for E in (E1, E2, EC)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_1[0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, Emax, *self._nodes_bkg_1) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, Emax, *self._nodes_bkg_1, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + elif mode == 2: + center_peak = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=Emax) + + EC1 = sorted_peaks[:, 0] + EC2 = sorted_peaks[:, 1] + + E1, E2, E3, E4, EC1, EC2 = [E.view(-1, 1) for E in (E1, E2, E3, E4, EC1, EC2)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_2[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_2[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_2[0], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_2[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, Emax, *self._nodes_bkg_2[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E4, Emax, *self._nodes_bkg_2[1], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + elif mode == 3: + center_peak_1 = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + center_peak_2 = (sorted_peaks[:, 1] + sorted_peaks[:, 2]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak_1) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=center_peak_2) + E5 = (sorted_peaks[:, 2] - delta[:, 2]).clamp(min=E4) + E6 = (sorted_peaks[:, 2] + delta[:, 2]).clamp(max=Emax) + + EC1, EC2, EC3 = [sorted_peaks[:, i] for i in range(3)] + + E1, E2, E3, E4, E5, E6, EC1, EC2, EC3 = [E.view(-1, 1) for E in (E1, E2, E3, E4, E5, E6, EC1, EC2, EC3)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_3[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_3[0], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, E5, *self._nodes_bkg_3[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E4, E5, *self._nodes_bkg_3[1], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E5, E6, EC3, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E5, E6, EC3, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[2][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E6, Emax, *self._nodes_bkg_3[2]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E6, Emax, *self._nodes_bkg_3[2], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, + phi_geo_rad: torch.Tensor, phi_igeo_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + energy_m_keV = energy_m_keV.view(-1, 1) + phi_rad = phi_rad.view(-1, 1) + phi_geo_rad = phi_geo_rad.view(-1, 1) + phi_igeo_rad = phi_igeo_rad.view(-1, 1) + + batch_size = energy_m_keV.shape[0] + + nodes = torch.zeros((batch_size, self._total_energy_nodes[0]), dtype=torch.float32) + weights = torch.zeros_like(nodes) + + peaks = torch.zeros((batch_size, 4), dtype=torch.float32) + peaks[:, 0] = energy_m_keV.squeeze() + peaks[:, 1] = self._get_escape_peak(energy_m_keV, phi_rad).squeeze() + peaks[:, 2] = self._get_missing_energy_peak(phi_geo_rad, energy_m_keV, phi_rad).squeeze() + peaks[:, 3] = self._get_missing_energy_peak(phi_igeo_rad, energy_m_keV, phi_rad, inverse=True).squeeze() + + diffs = peaks * self._width_tensor[None, ...] + + n_peaks = torch.sum(~torch.isnan(peaks), dim=1) + + indices_1 = torch.where(n_peaks == 1)[0] + indices_2 = torch.where(n_peaks == 2)[0] + indices_3 = torch.where(n_peaks == 3)[0] + + if len(indices_1) > 0: + self._fill_nodes(nodes, weights, indices_1, 1, + peaks[indices_1, :1], diffs[indices_1, :1]) + + if len(indices_2) > 0: + p_sub = peaks[indices_2] + d_sub = diffs[indices_2] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 2) + d_comp = d_sub[mask].view(-1, 2) + self._fill_nodes(nodes, weights, indices_2, 2, p_comp, d_comp) + + if len(indices_3) > 0: + p_sub = peaks[indices_3] + d_sub = diffs[indices_3] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 3) + d_comp = d_sub[mask].view(-1, 3) + + p_sorted, idx = torch.sort(p_comp, dim=1) + d_sorted = torch.gather(d_comp, 1, idx) + + self._fill_nodes(nodes, weights, indices_3, 3, p_sorted, d_sorted) + + return nodes, weights + + def _get_CDS_coordinates(self, lon_src: np.ndarray, lat_src: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + getter = attrgetter('energy_keV', 'scattering_angle_rad', + 'scattered_lon_rad_sc', 'scattered_lat_rad_sc') + + dt = [('energy', 'f4'), ('phi', 'f4'), ('lon', 'f4'), ('lat', 'f4')] + + arr = np.fromiter(map(getter, self._data), dtype=dt, count=self._n_events) + + energy_m_keV = torch.from_numpy(arr['energy']) + phi_rad = torch.from_numpy(arr['phi']) + + lon_scat = arr['lon'] + lat_scat = arr['lat'] + + cos_geo = ( + np.cos(lat_src) * np.cos(lon_src) * np.cos(lat_scat) * np.cos(lon_scat) + + np.cos(lat_src) * np.sin(lon_src) * np.cos(lat_scat) * np.sin(lon_scat) + + np.sin(lat_src) * np.sin(lat_scat) + ) + cos_geo = np.clip(cos_geo, -1.0, 1.0) + + phi_geo_rad = torch.from_numpy(np.arccos(cos_geo)) + phi_igeo_rad = np.pi - phi_geo_rad + + return energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad + + def _compute_density(self): + coord = self._source.position.sky_coord + sc_coord_sph = self._sc_coord_sph_cache + #earth_occ_index = self._earth_occ(coord, self._sc_ori_data) + earth_occ_index = self._earth_occ(coord, self._sc_ori_unique)[self._inv_idx] + + lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) + lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + + energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(lon_ph_rad, lat_ph_rad) + + n_energy = self._total_energy_nodes[0] + batch_size_events = self._batch_size // n_energy + + self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch.float32) + + data_iter = iter(self._data) + + for i in range(0, self._n_events, batch_size_events): + print("start_density_loop") + start = i + end = min(i + batch_size_events, self._n_events) + current_n = end - start + + e_sl = energy_m_keV[start:end] + p_sl = phi_rad[start:end] + pg_sl = phi_geo_rad[start:end] + pig_sl = phi_igeo_rad[start:end] + + nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + + if batch_size_events >= self._n_events: + self._irf_energy_node_cache = nodes.numpy() + + batch_lons = np.repeat(lon_ph_rad[start:end], n_energy) + batch_lats = np.repeat(lat_ph_rad[start:end], n_energy) + + current_batch_events = list(islice(data_iter, current_n)) + + nodes_np_flat = nodes.numpy().ravel() + + photons = [ + PhotonWithDirectionAndEnergyInSCFrame(l, b, en) + for l, b, en in zip(batch_lons, batch_lats, nodes_np_flat) + ] + + expanded_events = chain.from_iterable(repeat(x, n_energy) for x in current_batch_events) + + event_pairs = list(zip(photons, expanded_events)) + + print("start_density_calc") + densities_flat = np.fromiter( + self._irf.event_probability(event_pairs), + dtype=np.float32, + count=len(photons) + ) + + eff_areas_flat = np.fromiter( + self._irf.effective_area_cm2(photons), + dtype=np.float32, + count=len(photons) + ) + print("stop_density_calc") + + res_block = torch.from_numpy(densities_flat * eff_areas_flat).reshape(current_n, n_energy) + + occ = torch.from_numpy(earth_occ_index[start:end]).unsqueeze(1) + live = torch.from_numpy(self._livetime_ratio[start:end]).unsqueeze(1) + + res_block *= occ * live * weights + + self._irf_cache[start:end] = res_block + + def _update_cache(self): + + if self._source is None: + raise RuntimeError("Call set_source() first.") + + source_coord = self._source.position.sky_coord + + if (self._sc_coord_sph_cache is None) or (source_coord != self._last_convolved_source_skycoord): + #self._sc_coord_sph_cache = self._sc_ori_data.get_target_in_sc_frame(source_coord) + self._sc_coord_sph_cache = self._sc_ori_unique.get_target_in_sc_frame(source_coord)[self._inv_idx] + + no_recalculation = ((source_coord == self._last_convolved_source_skycoord) + and + (self._irf_cache is not None) + and + (self._area_cache is not None)) + + area_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._area_cache is None)) + + pdf_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._irf_cache is None)) + + if no_recalculation: + return + else: + if area_recalculation: + print("start_area_computation") + self._compute_area() + print("stop_area_computation") + + if pdf_recalculation: + self._init_node_pool() + print("start_density_computation") + self._compute_density() + print("stop_density_computation") + + self._last_convolved_source_skycoord = source_coord.copy() + + def cache_to_file(self, filename: str): + with h5py.File(filename, 'w') as f: + f.attrs['total_energy_nodes'] = self._total_energy_nodes + f.attrs['peak_nodes'] = self._peak_nodes + f.attrs['peak_widths'] = self._peak_widths + f.attrs['energy_range'] = self._energy_range + f.attrs['batch_size'] = self._batch_size + + if self._irf_cache is not None: + f.create_dataset('irf_cache', data=self._irf_cache.numpy(), + compression='gzip', compression_opts=4) + + if self._irf_energy_node_cache is not None: + f.create_dataset('irf_energy_node_cache', data=self._irf_energy_node_cache, + compression='gzip') + + if self._area_cache is not None: + f.create_dataset('area_cache', data=self._area_cache, + compression='gzip') + + if self._area_energy_node_cache is not None: + f.create_dataset('area_energy_node_cache', data=self._area_energy_node_cache, + compression='gzip') + + if self._exp_events is not None: + f.create_dataset('exp_events', data=self._exp_events) + + if self._exp_density is not None: + f.create_dataset('exp_density', data=self._exp_density, + compression='gzip') + + if self._last_convolved_source_dict_number is not None: + json_str = json.dumps(self._last_convolved_source_dict_number) + f.attrs['last_convolved_source_dict_number'] = json_str + + if self._last_convolved_source_dict_density is not None: + json_str = json.dumps(self._last_convolved_source_dict_density) + f.attrs['last_convolved_source_dict_density'] = json_str + + if self._last_convolved_source_skycoord is not None: + sc = self._last_convolved_source_skycoord + f.attrs['last_convolved_lon_deg'] = sc.spherical.lon.deg + f.attrs['last_convolved_lat_deg'] = sc.spherical.lat.deg + f.attrs['last_convolved_frame'] = sc.frame.name + + def cache_from_file(self, filename: str): + if not os.path.exists(filename): + raise FileNotFoundError(f"Cache file {filename} not found.") + + with h5py.File(filename, 'r') as f: + self._total_energy_nodes = tuple(f.attrs['total_energy_nodes']) + self._peak_nodes = tuple(f.attrs['peak_nodes']) + self._peak_widths = tuple(f.attrs['peak_widths']) + self._energy_range = tuple(f.attrs['energy_range']) + self._batch_size = int(f.attrs['batch_size']) + + if 'irf_cache' in f: + self._irf_cache = torch.from_numpy(f['irf_cache'][:]) + else: + self._irf_cache = None + + if 'irf_energy_node_cache' in f: + self._irf_energy_node_cache = f['irf_energy_node_cache'][:] + else: + self._irf_energy_node_cache = None + + if 'area_cache' in f: + self._area_cache = f['area_cache'][:] + else: + self._area_cache = None + + if 'area_energy_node_cache' in f: + self._area_energy_node_cache = f['area_energy_node_cache'][:] + else: + self._area_energy_node_cache = None + + if 'exp_events' in f: + self._exp_events = float(f['exp_events'][()]) + else: + self._exp_events = None + + if 'exp_density' in f: + self._exp_density = f['exp_density'][:] + else: + self._exp_density = None + + if 'last_convolved_source_dict_number' in f.attrs: + self._last_convolved_source_dict_number = json.loads(f.attrs['last_convolved_source_dict_number']) + else: + self._last_convolved_source_dict_number = None + + if 'last_convolved_source_dict_density' in f.attrs: + self._last_convolved_source_dict_density = json.loads(f.attrs['last_convolved_source_dict_density']) + else: + self._last_convolved_source_dict_density = None + + if 'last_convolved_lon_deg' in f.attrs: + lon = f.attrs['last_convolved_lon_deg'] + lat = f.attrs['last_convolved_lat_deg'] + frame = f.attrs['last_convolved_frame'] + self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame) + else: + self._last_convolved_source_skycoord = None + + if self._irf_cache is not None: + self._init_node_pool() + + def expected_counts(self) -> float: + """ + Return the total expected counts. + """ + self._update_cache() + source_dict = self._source.to_dict() + + if (source_dict != self._last_convolved_source_dict_number) or (self._exp_events is None): + area = self._area_cache + flux = self._source(self._area_energy_node_cache) + self._exp_events = np.sum(area * flux, dtype=float) + + self._last_convolved_source_dict_number = source_dict + return self._exp_events + + def expectation_density(self) -> Iterable[float]: + """ + Return the expected number of counts density. This equals the event probabiliy times the number of events. + """ + + self._update_cache() + source_dict = self._source.to_dict() + + if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): + + self._exp_density = np.zeros(self._n_events, dtype=np.float64) + + if self._irf_energy_node_cache is not None: + print("Start full") + flux = self._source(self._irf_energy_node_cache) + flux = torch.as_tensor(flux, dtype=self._irf_cache.dtype)#.view(self._irf_cache.shape) + + self._exp_density[:] = torch.einsum('ij,ij->i', self._irf_cache, flux).numpy().astype(np.float64) + + else: + print("Start batched") + sc_coord_sph = self._sc_coord_sph_cache + + lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) + lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + + energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(lon_ph_rad, lat_ph_rad) + + n_energy = self._total_energy_nodes[0] + batch_size = self._batch_size // n_energy + + for i in range(0, self._n_events, batch_size): + end = min(i + batch_size, self._n_events) + + e_sl = energy_m_keV[i:end] + p_sl = phi_rad[i:end] + pg_sl = phi_geo_rad[i:end] + pig_sl = phi_igeo_rad[i:end] + + nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + + flux = torch.as_tensor(self._source(nodes.numpy()), dtype=self._irf_cache.dtype)#.view(nodes.shape) + + irf_slice = self._irf_cache[i:end] + + self._exp_density[i:end] = torch.einsum('ij,ij->i', irf_slice, flux).numpy().astype(np.float64) + + self._last_convolved_source_dict_density = source_dict + + #print(np.sum(self._exp_density <= 0)/self._n_events * 100) + print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) + return self._exp_density+1e-12 \ No newline at end of file From 658bc034a3467106534b5a5c0d1e706f84196f98 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Tue, 3 Feb 2026 11:35:41 +0100 Subject: [PATCH 02/16] Changed treatment of livetime ratio --- cosipy/threeml/psr_fixed_ei.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/cosipy/threeml/psr_fixed_ei.py b/cosipy/threeml/psr_fixed_ei.py index edcbe6148..dd814fc22 100644 --- a/cosipy/threeml/psr_fixed_ei.py +++ b/cosipy/threeml/psr_fixed_ei.py @@ -290,12 +290,20 @@ def __init__(self, self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) - unique_ratio = np.interp(self._unique_mjds, - self._mid_times.mjd, - self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) - + interval_ratios = (self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + bin_indices = np.searchsorted(self._sc_ori.obstime.mjd, self._unique_mjds) - 1 + + bin_indices = np.clip(bin_indices, 0, len(self._sc_ori.livetime) - 1) + unique_ratio = interval_ratios[bin_indices] + self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + #unique_ratio = np.interp(self._unique_mjds, + # self._mid_times.mjd, + # self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + # + #self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + #wrong_order = np.where(((data_times[1:] - data_times[:-1]) <= 0))[0] #data_times[wrong_order + 1] = data_times[wrong_order + 1] + 1 #self._sc_ori_data = self._sc_ori.interp(data_times) @@ -501,7 +509,6 @@ def _compute_area(self): total_area = np.zeros(n_energy, dtype=np.float64) for i in range(0, n_time, batch_size_time): - print("start_area_loop") start = i end = min(i + batch_size_time, n_time) current_n_time = end - start @@ -515,13 +522,11 @@ def _compute_area(self): for l, b, e in zip(batch_lons, batch_lats, batch_energies) ] - print("start_area_calc") eff_areas_flat = np.fromiter( self._irf.effective_area_cm2(photons), dtype=np.float32, count=len(photons) ) - print("stop_area_calc") eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) t_w_batch = time_weights[start:end] @@ -761,7 +766,6 @@ def _compute_density(self): data_iter = iter(self._data) for i in range(0, self._n_events, batch_size_events): - print("start_density_loop") start = i end = min(i + batch_size_events, self._n_events) current_n = end - start @@ -792,7 +796,6 @@ def _compute_density(self): event_pairs = list(zip(photons, expanded_events)) - print("start_density_calc") densities_flat = np.fromiter( self._irf.event_probability(event_pairs), dtype=np.float32, @@ -804,7 +807,6 @@ def _compute_density(self): dtype=np.float32, count=len(photons) ) - print("stop_density_calc") res_block = torch.from_numpy(densities_flat * eff_areas_flat).reshape(current_n, n_energy) @@ -844,15 +846,11 @@ def _update_cache(self): return else: if area_recalculation: - print("start_area_computation") self._compute_area() - print("stop_area_computation") if pdf_recalculation: self._init_node_pool() - print("start_density_computation") self._compute_density() - print("stop_density_computation") self._last_convolved_source_skycoord = source_coord.copy() @@ -991,14 +989,12 @@ def expectation_density(self) -> Iterable[float]: self._exp_density = np.zeros(self._n_events, dtype=np.float64) if self._irf_energy_node_cache is not None: - print("Start full") flux = self._source(self._irf_energy_node_cache) flux = torch.as_tensor(flux, dtype=self._irf_cache.dtype)#.view(self._irf_cache.shape) self._exp_density[:] = torch.einsum('ij,ij->i', self._irf_cache, flux).numpy().astype(np.float64) else: - print("Start batched") sc_coord_sph = self._sc_coord_sph_cache lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) @@ -1027,6 +1023,6 @@ def expectation_density(self) -> Iterable[float]: self._last_convolved_source_dict_density = source_dict - #print(np.sum(self._exp_density <= 0)/self._n_events * 100) - print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) + print(np.sum(self._exp_density <= 0)/self._n_events * 100) + #print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) return self._exp_density+1e-12 \ No newline at end of file From 1def04299a193d92ffd29eaf67b90857261ba15d Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 19 Feb 2026 10:40:08 +0100 Subject: [PATCH 03/16] Apply my latest changes on top of the rebase about EventData and PhotonList --- cosipy/response/NNResponse.py | 4 +- .../response/instrument_response_function.py | 97 ++++---- cosipy/response/nnresponse_helper.py | 4 +- cosipy/threeml/psr_fixed_ei.py | 225 +++++++++--------- 4 files changed, 177 insertions(+), 153 deletions(-) diff --git a/cosipy/response/NNResponse.py b/cosipy/response/NNResponse.py index bfdcf5273..dc4dc0a8f 100644 --- a/cosipy/response/NNResponse.py +++ b/cosipy/response/NNResponse.py @@ -79,7 +79,7 @@ def __init__(self, major_version: int, density_input: Dict, worker_device: Union self._setup_model() def _setup_model(self): - version_map = { + version_map: Dict[int, DensityModelProtocol] = { 1: UnpolarizedDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), } if self._major_version not in version_map: @@ -133,7 +133,7 @@ def __init__(self, major_version: int, area_input: Dict, worker_device: Union[st self._setup_model() def _setup_model(self): - version_map = { + version_map: Dict[int, AreaModelProtocol] = { 1: UnpolarizedAreaSphericalHarmonicsExpansion(self._area_input, self._worker_device, self._batch_size, self._compile_mode), } if self._major_version not in version_map: diff --git a/cosipy/response/instrument_response_function.py b/cosipy/response/instrument_response_function.py index e12583fd3..1624eac8c 100644 --- a/cosipy/response/instrument_response_function.py +++ b/cosipy/response/instrument_response_function.py @@ -12,75 +12,86 @@ from scoords import SpacecraftFrame from cosipy.interfaces import EventInterface +from cosipy.interfaces.photon_parameters import PhotonListWithDirectionInSCFrameInterface from cosipy.interfaces.data_interface import EmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface, EmCDSEventInSCFrameInterface from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface, \ FarFieldSpectralInstrumentResponseFunctionInterface -from cosipy.interfaces.photon_parameters import PhotonInterface, PhotonWithDirectionAndEnergyInSCFrameInterface, PhotonListWithDirectionInterface +from cosipy.interfaces.photon_parameters import PhotonInterface, PhotonWithDirectionAndEnergyInSCFrameInterface, PhotonListWithDirectionInterface, PhotonListWithDirectionAndEnergyInSCFrameInterface +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays from cosipy.response import FullDetectorResponse from cosipy.response.NNResponse import NNResponse -from cosipy.util.iterables import itertools_batched +from cosipy.util.iterables import itertools_batched, asarray from operator import attrgetter -class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldInstrumentResponseFunctionInterface): +class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): - photon_type = PhotonWithDirectionAndEnergyInSCFrameInterface - event_type = EmCDSEventInSCFrameInterface + event_data_type = EmCDSEventDataInSCFrameInterface + photon_list_type = PhotonListWithDirectionAndEnergyInSCFrameInterface def __init__(self, response: NNResponse,): if response.is_polarized: raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") self._response = response - def effective_area_cm2(self, photons: Iterable[PhotonWithDirectionAndEnergyInSCFrameInterface]) -> Iterable[float]: - getter = attrgetter('direction_lon_radians', 'direction_lat_radians', 'energy_keV') + @staticmethod + def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): + lon = asarray(photons.direction_lon_rad_sc, dtype=np.float32) + lat = asarray(photons.direction_lat_rad_sc, dtype=np.float32) + en = asarray(photons.energy_keV, dtype=np.float32) + + num_photons = lon.shape[0] + context = torch.empty((num_photons, 3), dtype=torch.float32) + + context[:, 0] = torch.from_numpy(lon) + context[:, 1] = torch.from_numpy(lat) + context[:, 2] = torch.from_numpy(en) - raw_data = list(map(getter, photons)) - - if not raw_data: - return np.array([], dtype=np.float32) + context[:, 1].mul_(-1).add_(np.pi/2) - context = torch.tensor(raw_data, dtype=torch.float32) - context[:, 1] = np.pi/2 - context[:, 1] - - return self._response.evaluate_effective_area(context).numpy() + return context - def event_probability(self, query: Iterable[Tuple[PhotonWithDirectionAndEnergyInSCFrameInterface, EmCDSEventInSCFrameInterface]]) -> Iterable[float]: - context_list = [] - source_list = [] - - for photon, event in query: - context_list.append((photon.direction_lon_radians, photon.direction_lat_radians, photon.energy_keV)) - source_list.append((event.energy_keV, event.scattering_angle_rad, event.scattered_lon_rad_sc, event.scattered_lat_rad_sc)) + @staticmethod + def _get_source(events: EmCDSEventDataInSCFrameInterface): + lon = asarray(events.scattered_lon_rad_sc, dtype=np.float32) + lat = asarray(events.scattered_lat_rad_sc, dtype=np.float32) + phi = asarray(events.scattering_angle_rad, dtype=np.float32) + en = asarray(events.energy_keV, dtype=np.float32) - if not context_list: - return np.array([], dtype=np.float32) + num_events = lon.shape[0] + source = torch.empty((num_events, 4), dtype=torch.float32) - context = torch.tensor(context_list, dtype=torch.float32) - source = torch.tensor(source_list, dtype=torch.float32) - - context[:, 1] = np.pi/2 - context[:, 1] - source[:, 3] = np.pi/2 - source[:, 3] + source[:, 0] = torch.from_numpy(en) + source[:, 1] = torch.from_numpy(phi) + source[:, 2] = torch.from_numpy(lon) + source[:, 3] = torch.from_numpy(lat) - return self._response.evaluate_density(context, source).numpy() - - def random_events(self, photons: Iterable[PhotonWithDirectionAndEnergyInSCFrameInterface]) -> Iterable[EventInterface]: - getter = attrgetter('direction_lon_radians', 'direction_lat_radians', 'energy_keV') + source[:, 3].mul_(-1).add_(np.pi/2) - raw_data = list(map(getter, photons)) + return source - if not raw_data: - return [] + def _effective_area_cm2(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> Iterable[float]: + context = self._get_context(photons) - context = torch.tensor(raw_data, dtype=torch.float32) - context[:, 1] = np.pi/2 - context[:, 1] + return np.asarray(self._response.evaluate_effective_area(context)) + + def _event_probability(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface, events: EmCDSEventDataInSCFrameInterface) -> Iterable[float]: + source = self._get_source(events) + context = self._get_context(photons) - samples = self._response.sample_density(context).numpy() - samples[:, 3] = np.pi/2 - samples[:, 3] + return np.asarray(self._response.evaluate_density(context, source)) + + def _random_events(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> EmCDSEventDataInSCFrameInterface: + context = self._get_context(photons) + samples = np.asarray(self._response.sample_density(context)) + samples[:, 3].mul_(-1).add_(np.pi/2) - return [ - EmCDSEventInSCFrame(e, phi, lon, lat) for e, phi, lon, lat in samples - ] + return EmCDSEventDataInSCFrameFromArrays( + samples[:, 0], # Energy + samples[:, 2], # Lon + samples[:, 3], # Lat + samples[:, 1] # Phi + ) class UnpolarizedDC3InterpolatedFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): diff --git a/cosipy/response/nnresponse_helper.py b/cosipy/response/nnresponse_helper.py index 0339bf8dc..52ad008d2 100644 --- a/cosipy/response/nnresponse_helper.py +++ b/cosipy/response/nnresponse_helper.py @@ -4,7 +4,7 @@ from torch import nn import healpy as hp import sphericart.torch -from typing import Protocol, Optional, Literal, List, Union, Tuple, Dict +from typing import Protocol, Optional, Literal, List, Union, Tuple, Dict, runtime_checkable CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] @@ -104,6 +104,7 @@ def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) +@runtime_checkable class AreaModelProtocol(Protocol): @property def context_dim(self) -> int: ... @@ -301,6 +302,7 @@ def enqueue_transfer(slot_idx, start_idx): return torch.clamp(result, min=0) +@runtime_checkable class DensityModelProtocol(Protocol): @property def context_dim(self) -> int: ... diff --git a/cosipy/threeml/psr_fixed_ei.py b/cosipy/threeml/psr_fixed_ei.py index dd814fc22..27dc5864c 100644 --- a/cosipy/threeml/psr_fixed_ei.py +++ b/cosipy/threeml/psr_fixed_ei.py @@ -15,12 +15,13 @@ from histpy import Axis from cosipy import SpacecraftHistory -from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventInSCFrame +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventInSCFrame, EmCDSEventDataInSCFrameFromArrays from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface, EventInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import EmCDSEventInSCFrameInterface, TimeTagEmCDSEventInSCFrameInterface from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface -from cosipy.response.photon_types import PhotonWithDirectionAndEnergyInSCFrame +from cosipy.response.photon_types import PhotonWithDirectionAndEnergyInSCFrame, PhotonListWithDirectionAndEnergyInSCFrame +from cosipy.util.iterables import asarray from astropy import units as u import astropy.constants as c @@ -277,7 +278,7 @@ def __init__(self, self._area_cache: Optional[np.ndarray] = None # cm^2*s*keV self._area_energy_node_cache: Optional[np.ndarray] = None self._exp_events: Optional[float] = None - self._exp_density: Optional[np.ndarray] = None + self._exp_density: Optional[torch.Tensor] = None # Precomputed spacecraft history self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 @@ -298,6 +299,16 @@ def __init__(self, self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) + self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) + + self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) + self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) + self._cos_lat_scatt = torch.cos(self._lat_scatt) + self._sin_lat_scatt = torch.sin(self._lat_scatt) + self._cos_lon_scatt = torch.cos(self._lon_scatt) + self._sin_lon_scatt = torch.sin(self._lon_scatt) + #unique_ratio = np.interp(self._unique_mjds, # self._mid_times.mjd, # self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) @@ -464,7 +475,7 @@ def set_source(self, source: Source): self._source = source - def copy(self) -> "ThreeMLSourceResponseInterface": + def copy(self) -> UnbinnedThreeMLSourceResponseInterface: new_instance = copy.copy(self) new_instance.clear_cache() new_instance._source = None @@ -500,48 +511,51 @@ def _compute_area(self): time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index - lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) - lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) n_time = len(lon_ph_rad) batch_size_time = self._batch_size // n_energy total_area = np.zeros(n_energy, dtype=np.float64) + + max_batch_total = n_energy * min(batch_size_time, n_time) + batch_lons_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_lats_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_energies_buffer = np.empty(max_batch_total, dtype=np.float32) for i in range(0, n_time, batch_size_time): start = i end = min(i + batch_size_time, n_time) current_n_time = end - start + current_total = current_n_time * n_energy + + #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lons_buffer[:current_total]) + #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lats_buffer[:current_total]) - batch_lons = np.repeat(lon_ph_rad[start:end], n_energy) - batch_lats = np.repeat(lat_ph_rad[start:end], n_energy) - batch_energies = np.tile(e_n, current_n_time) - - photons = [ - PhotonWithDirectionAndEnergyInSCFrame(l, b, e) - for l, b, e in zip(batch_lons, batch_lats, batch_energies) - ] - - eff_areas_flat = np.fromiter( - self._irf.effective_area_cm2(photons), - dtype=np.float32, - count=len(photons) - ) - + batch_lons_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lats_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + batch_energies_buffer[:current_total].reshape(current_n_time, n_energy)[:] = e_n + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lons_buffer[:current_total], + batch_lats_buffer[:current_total], + batch_energies_buffer[:current_total] + ) + + eff_areas_flat = asarray(self._irf._effective_area_cm2(photons), dtype=np.float32) eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) - t_w_batch = time_weights[start:end] - e_w_flat = e_w.ravel() total_area += np.einsum('ij,i,j->j', eff_areas_grid, - t_w_batch, - e_w_flat) + time_weights[start:end], + e_w.ravel()) self._area_cache = total_area def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, - indices: torch.Tensor, mode: int, - sorted_peaks: torch.Tensor, delta: torch.Tensor): + indices: torch.Tensor, mode: int, + sorted_peaks: torch.Tensor, delta: torch.Tensor): Emin, Emax = self._energy_range @@ -721,97 +735,97 @@ def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, return nodes, weights - def _get_CDS_coordinates(self, lon_src: np.ndarray, lat_src: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - getter = attrgetter('energy_keV', 'scattering_angle_rad', - 'scattered_lon_rad_sc', 'scattered_lat_rad_sc') - - dt = [('energy', 'f4'), ('phi', 'f4'), ('lon', 'f4'), ('lat', 'f4')] - - arr = np.fromiter(map(getter, self._data), dtype=dt, count=self._n_events) + def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + cos_lat_src = torch.cos(lat_src_rad) + sin_lat_src = torch.sin(lat_src_rad) + cos_lon_src = torch.cos(lon_src_rad) + sin_lon_src = torch.sin(lon_src_rad) - energy_m_keV = torch.from_numpy(arr['energy']) - phi_rad = torch.from_numpy(arr['phi']) - - lon_scat = arr['lon'] - lat_scat = arr['lat'] - cos_geo = ( - np.cos(lat_src) * np.cos(lon_src) * np.cos(lat_scat) * np.cos(lon_scat) + - np.cos(lat_src) * np.sin(lon_src) * np.cos(lat_scat) * np.sin(lon_scat) + - np.sin(lat_src) * np.sin(lat_scat) + cos_lat_src * cos_lon_src * self._cos_lat_scatt * self._cos_lon_scatt + + cos_lat_src * sin_lon_src * self._cos_lat_scatt * self._sin_lon_scatt + + sin_lat_src * self._sin_lat_scatt ) - cos_geo = np.clip(cos_geo, -1.0, 1.0) - phi_geo_rad = torch.from_numpy(np.arccos(cos_geo)) - phi_igeo_rad = np.pi - phi_geo_rad + cos_geo = torch.clip(cos_geo, -1.0, 1.0) + phi_geo_rad = torch.arccos(cos_geo) - return energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad + return phi_geo_rad, np.pi - phi_geo_rad def _compute_density(self): coord = self._source.position.sky_coord sc_coord_sph = self._sc_coord_sph_cache - #earth_occ_index = self._earth_occ(coord, self._sc_ori_data) earth_occ_index = self._earth_occ(coord, self._sc_ori_unique)[self._inv_idx] - lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) - lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(lon_ph_rad, lat_ph_rad) + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) n_energy = self._total_energy_nodes[0] batch_size_events = self._batch_size // n_energy self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch.float32) - - data_iter = iter(self._data) + + buffer_size = n_energy * min(batch_size_events, self._n_events) + batch_lon_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_energy_buffer = np.empty(buffer_size, dtype=np.float32) + batch_phi_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) for i in range(0, self._n_events, batch_size_events): start = i end = min(i + batch_size_events, self._n_events) current_n = end - start + current_total = current_n * n_energy - e_sl = energy_m_keV[start:end] - p_sl = phi_rad[start:end] + e_sl = self._energy_m_keV[start:end] + p_sl = self._phi_rad[start:end] pg_sl = phi_geo_rad[start:end] pig_sl = phi_igeo_rad[start:end] nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) if batch_size_events >= self._n_events: - self._irf_energy_node_cache = nodes.numpy() - - batch_lons = np.repeat(lon_ph_rad[start:end], n_energy) - batch_lats = np.repeat(lat_ph_rad[start:end], n_energy) - - current_batch_events = list(islice(data_iter, current_n)) + self._irf_energy_node_cache = np.asarray(nodes) - nodes_np_flat = nodes.numpy().ravel() - - photons = [ - PhotonWithDirectionAndEnergyInSCFrame(l, b, en) - for l, b, en in zip(batch_lons, batch_lats, nodes_np_flat) - ] - - expanded_events = chain.from_iterable(repeat(x, n_energy) for x in current_batch_events) - - event_pairs = list(zip(photons, expanded_events)) - - densities_flat = np.fromiter( - self._irf.event_probability(event_pairs), - dtype=np.float32, - count=len(photons) - ) - - eff_areas_flat = np.fromiter( - self._irf.effective_area_cm2(photons), - dtype=np.float32, - count=len(photons) + #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lon_src_buffer[:current_total]) + #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lat_src_buffer[:current_total]) + + batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + + batch_energy_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._energy_m_keV[start:end, np.newaxis]) + batch_lon_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lon_scatt[start:end, np.newaxis]) + batch_lat_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lat_scatt[start:end, np.newaxis]) + batch_phi_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._phi_rad[start:end, np.newaxis]) + + #np.repeat(np.asarray(self._energy_m_keV[start:end]), n_energy, out=batch_energy_buffer[:current_total]) + #np.repeat(np.asarray(self._lon_scatt[start:end]), n_energy, out=batch_lon_scatt_buffer[:current_total]) + #np.repeat(np.asarray(self._lat_scatt[start:end]), n_energy, out=batch_lat_scatt_buffer[:current_total]) + #np.repeat(np.asarray(self._phi_rad[start:end]), n_energy, out=batch_phi_buffer[:current_total]) + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lon_src_buffer[:current_total], + batch_lat_src_buffer[:current_total], + np.asarray(nodes).ravel() + ) + events = EmCDSEventDataInSCFrameFromArrays( + batch_energy_buffer[:current_total], + batch_lon_scatt_buffer[:current_total], + batch_lat_scatt_buffer[:current_total], + batch_phi_buffer[:current_total], ) + + eff_areas_flat = torch.as_tensor(asarray(self._irf._effective_area_cm2(photons), dtype=np.float32)) + densities_flat = torch.as_tensor(asarray(self._irf._event_probability(photons, events), dtype=np.float32)) - res_block = torch.from_numpy(densities_flat * eff_areas_flat).reshape(current_n, n_energy) + res_block = (densities_flat * eff_areas_flat).view(current_n, n_energy) - occ = torch.from_numpy(earth_occ_index[start:end]).unsqueeze(1) - live = torch.from_numpy(self._livetime_ratio[start:end]).unsqueeze(1) + occ = torch.as_tensor(earth_occ_index[start:end]).view(-1, 1) + live = torch.as_tensor(self._livetime_ratio[start:end]).view(-1, 1) res_block *= occ * live * weights @@ -882,7 +896,7 @@ def cache_to_file(self, filename: str): f.create_dataset('exp_events', data=self._exp_events) if self._exp_density is not None: - f.create_dataset('exp_density', data=self._exp_density, + f.create_dataset('exp_density', data=self._exp_density.numpy(), compression='gzip') if self._last_convolved_source_dict_number is not None: @@ -936,7 +950,7 @@ def cache_from_file(self, filename: str): self._exp_events = None if 'exp_density' in f: - self._exp_density = f['exp_density'][:] + self._exp_density = torch.from_numpy(f['exp_density'][:]) else: self._exp_density = None @@ -985,44 +999,41 @@ def expectation_density(self) -> Iterable[float]: source_dict = self._source.to_dict() if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): - - self._exp_density = np.zeros(self._n_events, dtype=np.float64) + self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) if self._irf_energy_node_cache is not None: - flux = self._source(self._irf_energy_node_cache) - flux = torch.as_tensor(flux, dtype=self._irf_cache.dtype)#.view(self._irf_cache.shape) + flux = torch.as_tensor(self._source(self._irf_energy_node_cache), dtype=self._irf_cache.dtype) - self._exp_density[:] = torch.einsum('ij,ij->i', self._irf_cache, flux).numpy().astype(np.float64) + torch.linalg.vecdot(self._irf_cache, flux, dim=1, out=self._exp_density) else: + n_energy = self._total_energy_nodes[0] + batch_size = self._batch_size // n_energy + sc_coord_sph = self._sc_coord_sph_cache - lon_ph_rad = sc_coord_sph.lon.rad.astype(np.float32) - lat_ph_rad = sc_coord_sph.lat.rad.astype(np.float32) + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - energy_m_keV, phi_rad, phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(lon_ph_rad, lat_ph_rad) - - n_energy = self._total_energy_nodes[0] - batch_size = self._batch_size // n_energy + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) for i in range(0, self._n_events, batch_size): end = min(i + batch_size, self._n_events) - e_sl = energy_m_keV[i:end] - p_sl = phi_rad[i:end] + e_sl = self._energy_m_keV[i:end] + p_sl = self._phi_rad[i:end] pg_sl = phi_geo_rad[i:end] pig_sl = phi_igeo_rad[i:end] nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) - flux = torch.as_tensor(self._source(nodes.numpy()), dtype=self._irf_cache.dtype)#.view(nodes.shape) - - irf_slice = self._irf_cache[i:end] - - self._exp_density[i:end] = torch.einsum('ij,ij->i', irf_slice, flux).numpy().astype(np.float64) + flux_batch = torch.as_tensor(self._source(np.asarray(nodes)), dtype=self._irf_cache.dtype) + + torch.linalg.vecdot(self._irf_cache[i:end], flux_batch, dim=1, out=self._exp_density[i:end]) self._last_convolved_source_dict_density = source_dict - print(np.sum(self._exp_density <= 0)/self._n_events * 100) + #print(self._data.time.unix[self._exp_density <= 0][:100]) + #print(np.sum(self._exp_density <= 0)/self._n_events * 100) #print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) - return self._exp_density+1e-12 \ No newline at end of file + return np.asarray(self._exp_density, dtype=np.float64)+1e-12 \ No newline at end of file From 6a8dde07b571d58e8ed031fe83a3c6da80bd990b Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 19 Feb 2026 15:00:15 +0100 Subject: [PATCH 04/16] Fixed bug in sampling --- .../response/instrument_response_function.py | 42 ++++++------------- cosipy/response/nnresponse_helper.py | 15 ++++--- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/cosipy/response/instrument_response_function.py b/cosipy/response/instrument_response_function.py index 1624eac8c..befe4fd11 100644 --- a/cosipy/response/instrument_response_function.py +++ b/cosipy/response/instrument_response_function.py @@ -36,39 +36,22 @@ def __init__(self, response: NNResponse,): @staticmethod def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): - lon = asarray(photons.direction_lon_rad_sc, dtype=np.float32) - lat = asarray(photons.direction_lat_rad_sc, dtype=np.float32) - en = asarray(photons.energy_keV, dtype=np.float32) - - num_photons = lon.shape[0] - context = torch.empty((num_photons, 3), dtype=torch.float32) - - context[:, 0] = torch.from_numpy(lon) - context[:, 1] = torch.from_numpy(lat) - context[:, 2] = torch.from_numpy(en) - - context[:, 1].mul_(-1).add_(np.pi/2) + lon = torch.as_tensor(asarray(photons.direction_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(photons.direction_lat_rad_sc, dtype=np.float32)) + en = torch.as_tensor(asarray(photons.energy_keV, dtype=np.float32)) - return context + lat = -lat + (np.pi / 2) + return torch.stack([lon, lat, en], dim=1) @staticmethod def _get_source(events: EmCDSEventDataInSCFrameInterface): - lon = asarray(events.scattered_lon_rad_sc, dtype=np.float32) - lat = asarray(events.scattered_lat_rad_sc, dtype=np.float32) - phi = asarray(events.scattering_angle_rad, dtype=np.float32) - en = asarray(events.energy_keV, dtype=np.float32) - - num_events = lon.shape[0] - source = torch.empty((num_events, 4), dtype=torch.float32) - - source[:, 0] = torch.from_numpy(en) - source[:, 1] = torch.from_numpy(phi) - source[:, 2] = torch.from_numpy(lon) - source[:, 3] = torch.from_numpy(lat) - - source[:, 3].mul_(-1).add_(np.pi/2) + lon = torch.as_tensor(asarray(events.scattered_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(events.scattered_lat_rad_sc, dtype=np.float32)) + phi = torch.as_tensor(asarray(events.scattering_angle_rad, dtype=np.float32)) + en = torch.as_tensor(asarray(events.energy_keV, dtype=np.float32)) - return source + lat = -lat + (np.pi / 2) + return torch.stack([en, phi, lon, lat], dim=1) def _effective_area_cm2(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> Iterable[float]: context = self._get_context(photons) @@ -83,8 +66,9 @@ def _event_probability(self, photons: PhotonListWithDirectionAndEnergyInSCFrameI def _random_events(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> EmCDSEventDataInSCFrameInterface: context = self._get_context(photons) - samples = np.asarray(self._response.sample_density(context)) + samples = self._response.sample_density(context) samples[:, 3].mul_(-1).add_(np.pi/2) + samples = np.asarray(samples) return EmCDSEventDataInSCFrameFromArrays( samples[:, 0], # Energy diff --git a/cosipy/response/nnresponse_helper.py b/cosipy/response/nnresponse_helper.py index 52ad008d2..f741559bb 100644 --- a/cosipy/response/nnresponse_helper.py +++ b/cosipy/response/nnresponse_helper.py @@ -340,6 +340,8 @@ def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.dev self._latent_size = density_input["latent_size"] self._mlp_hidden_units = density_input["mlp_hidden_units"] self._mlp_hidden_layers = density_input["mlp_hidden_layers"] + self._menergy_cuts = density_input["menergy_cuts"] + self._phi_cuts = density_input["phi_cuts"] self._compile_mode = compile_mode self._compiled_cache = {} @@ -534,14 +536,17 @@ def _transform_coordinates(self, dir_az: torch.Tensor, dir_pol: torch.Tensor, return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) - @staticmethod - def _valid_samples(samples: torch.Tensor) -> torch.Tensor: + def _valid_samples(self, ienergy: torch.Tensor, samples: torch.Tensor) -> torch.Tensor: phi_geo_norm = samples[:, 1] + 2 * samples[:, 2] - 1.0 valid_mask = (samples[:, 0] < 1.0) & \ (samples[:, 1] > 0.0) & (samples[:, 1] <= 1.0) & \ (samples[:, 2] >= 0.0) & (samples[:, 2] <= 1.0) & \ (samples[:, 3] >= 0.0) & (samples[:, 3] <= 1.0) & \ - (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) + (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) & \ + (samples[:, 0] <= (1 - self._menergy_cuts[0]/ienergy)) & \ + (samples[:, 0] >= (1 - self._menergy_cuts[1]/ienergy)) & \ + (samples[:, 1] >= self._phi_cuts[0]/np.pi) & \ + (samples[:, 1] <= self._phi_cuts[1]/np.pi) return valid_mask @@ -594,7 +599,7 @@ def enqueue_sample_transfer(slot_idx, start_idx): self._sample_results[curr_idx][0][:batch_len] = self._inverse_transform_coordinates( b_latent, b_ei, b_az_sc, b_pol_sc ) - self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(b_latent) + self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(b_ei, b_latent) self._compute_ready[curr_idx].record(self._compute_stream) @@ -621,7 +626,7 @@ def enqueue_sample_transfer(slot_idx, start_idx): b_samples = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) result[start:end] = self._inverse_transform_coordinates(b_samples, b_ei, b_az_sc, b_pol_sc) - failed_mask[start:end] = ~self._valid_samples(b_samples) + failed_mask[start:end] = ~self._valid_samples(b_ei, b_samples) if self._is_cuda: torch.cuda.synchronize(self._worker_device) From 803c46ee877139210b0c07fc75af2be7167e4b68 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 19 Feb 2026 16:01:32 +0100 Subject: [PATCH 05/16] Moved NN implementation to own files and added import safeguard --- cosipy/response/NNResponse.py | 12 +- cosipy/response/__init__.py | 4 +- .../response/instrument_response_function.py | 61 +- .../nn_instrument_response_function.py | 72 ++ ...sponse_helper.py => nn_response_helper.py} | 17 +- cosipy/threeml/optimized_unbinned_folding.py | 849 ++++++++++++++++++ cosipy/threeml/psr_fixed_ei.py | 835 +---------------- 7 files changed, 951 insertions(+), 899 deletions(-) create mode 100644 cosipy/response/nn_instrument_response_function.py rename cosipy/response/{nnresponse_helper.py => nn_response_helper.py} (99%) create mode 100644 cosipy/threeml/optimized_unbinned_folding.py diff --git a/cosipy/response/NNResponse.py b/cosipy/response/NNResponse.py index dc4dc0a8f..12ba434b5 100644 --- a/cosipy/response/NNResponse.py +++ b/cosipy/response/NNResponse.py @@ -1,7 +1,15 @@ -import torch from typing import List, Union + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch import torch.multiprocessing as mp -from .nnresponse_helper import * +from .nn_response_helper import * + def cuda_cleanup_task(_) -> bool: if torch.cuda.is_available(): diff --git a/cosipy/response/__init__.py b/cosipy/response/__init__.py index aefd11d21..ff9b83d0a 100644 --- a/cosipy/response/__init__.py +++ b/cosipy/response/__init__.py @@ -8,6 +8,4 @@ from .threeml_point_source_response import * from .threeml_extended_source_response import * from .instrument_response import * -from .rsp_to_arf_rmf import RspArfRmfConverter -from .NNResponse import NNResponse -from .instrument_response_function import * +from .rsp_to_arf_rmf import RspArfRmfConverter \ No newline at end of file diff --git a/cosipy/response/instrument_response_function.py b/cosipy/response/instrument_response_function.py index befe4fd11..95962ed24 100644 --- a/cosipy/response/instrument_response_function.py +++ b/cosipy/response/instrument_response_function.py @@ -1,7 +1,6 @@ import itertools from typing import Iterable, Tuple -import torch import numpy as np from astropy.coordinates import SkyCoord @@ -12,70 +11,14 @@ from scoords import SpacecraftFrame from cosipy.interfaces import EventInterface -from cosipy.interfaces.photon_parameters import PhotonListWithDirectionInSCFrameInterface from cosipy.interfaces.data_interface import EmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface, EmCDSEventInSCFrameInterface from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface, \ FarFieldSpectralInstrumentResponseFunctionInterface -from cosipy.interfaces.photon_parameters import PhotonInterface, PhotonWithDirectionAndEnergyInSCFrameInterface, PhotonListWithDirectionInterface, PhotonListWithDirectionAndEnergyInSCFrameInterface -from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays +from cosipy.interfaces.photon_parameters import PhotonInterface, PhotonWithDirectionAndEnergyInSCFrameInterface, PhotonListWithDirectionInterface from cosipy.response import FullDetectorResponse -from cosipy.response.NNResponse import NNResponse -from cosipy.util.iterables import itertools_batched, asarray -from operator import attrgetter +from cosipy.util.iterables import itertools_batched -class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): - - event_data_type = EmCDSEventDataInSCFrameInterface - photon_list_type = PhotonListWithDirectionAndEnergyInSCFrameInterface - - def __init__(self, response: NNResponse,): - if response.is_polarized: - raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") - self._response = response - - @staticmethod - def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): - lon = torch.as_tensor(asarray(photons.direction_lon_rad_sc, dtype=np.float32)) - lat = torch.as_tensor(asarray(photons.direction_lat_rad_sc, dtype=np.float32)) - en = torch.as_tensor(asarray(photons.energy_keV, dtype=np.float32)) - - lat = -lat + (np.pi / 2) - return torch.stack([lon, lat, en], dim=1) - - @staticmethod - def _get_source(events: EmCDSEventDataInSCFrameInterface): - lon = torch.as_tensor(asarray(events.scattered_lon_rad_sc, dtype=np.float32)) - lat = torch.as_tensor(asarray(events.scattered_lat_rad_sc, dtype=np.float32)) - phi = torch.as_tensor(asarray(events.scattering_angle_rad, dtype=np.float32)) - en = torch.as_tensor(asarray(events.energy_keV, dtype=np.float32)) - - lat = -lat + (np.pi / 2) - return torch.stack([en, phi, lon, lat], dim=1) - - def _effective_area_cm2(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> Iterable[float]: - context = self._get_context(photons) - - return np.asarray(self._response.evaluate_effective_area(context)) - - def _event_probability(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface, events: EmCDSEventDataInSCFrameInterface) -> Iterable[float]: - source = self._get_source(events) - context = self._get_context(photons) - - return np.asarray(self._response.evaluate_density(context, source)) - - def _random_events(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> EmCDSEventDataInSCFrameInterface: - context = self._get_context(photons) - samples = self._response.sample_density(context) - samples[:, 3].mul_(-1).add_(np.pi/2) - samples = np.asarray(samples) - - return EmCDSEventDataInSCFrameFromArrays( - samples[:, 0], # Energy - samples[:, 2], # Lon - samples[:, 3], # Lat - samples[:, 1] # Phi - ) class UnpolarizedDC3InterpolatedFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): diff --git a/cosipy/response/nn_instrument_response_function.py b/cosipy/response/nn_instrument_response_function.py new file mode 100644 index 000000000..35a3b981b --- /dev/null +++ b/cosipy/response/nn_instrument_response_function.py @@ -0,0 +1,72 @@ +from typing import Iterable + +import numpy as np + +from cosipy.interfaces.data_interface import EmCDSEventDataInSCFrameInterface +from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface +from cosipy.interfaces.photon_parameters import PhotonListWithDirectionAndEnergyInSCFrameInterface +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays +from cosipy.response.NNResponse import NNResponse +from cosipy.util.iterables import asarray + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch + + +class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): + + event_data_type = EmCDSEventDataInSCFrameInterface + photon_list_type = PhotonListWithDirectionAndEnergyInSCFrameInterface + + def __init__(self, response: NNResponse,): + if response.is_polarized: + raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") + self._response = response + + @staticmethod + def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): + lon = torch.as_tensor(asarray(photons.direction_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(photons.direction_lat_rad_sc, dtype=np.float32)) + en = torch.as_tensor(asarray(photons.energy_keV, dtype=np.float32)) + + lat = -lat + (np.pi / 2) + return torch.stack([lon, lat, en], dim=1) + + @staticmethod + def _get_source(events: EmCDSEventDataInSCFrameInterface): + lon = torch.as_tensor(asarray(events.scattered_lon_rad_sc, dtype=np.float32)) + lat = torch.as_tensor(asarray(events.scattered_lat_rad_sc, dtype=np.float32)) + phi = torch.as_tensor(asarray(events.scattering_angle_rad, dtype=np.float32)) + en = torch.as_tensor(asarray(events.energy_keV, dtype=np.float32)) + + lat = -lat + (np.pi / 2) + return torch.stack([en, phi, lon, lat], dim=1) + + def _effective_area_cm2(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> Iterable[float]: + context = self._get_context(photons) + + return np.asarray(self._response.evaluate_effective_area(context)) + + def _event_probability(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface, events: EmCDSEventDataInSCFrameInterface) -> Iterable[float]: + source = self._get_source(events) + context = self._get_context(photons) + + return np.asarray(self._response.evaluate_density(context, source)) + + def _random_events(self, photons: PhotonListWithDirectionAndEnergyInSCFrameInterface) -> EmCDSEventDataInSCFrameInterface: + context = self._get_context(photons) + samples = self._response.sample_density(context) + samples[:, 3].mul_(-1).add_(np.pi/2) + samples = np.asarray(samples) + + return EmCDSEventDataInSCFrameFromArrays( + samples[:, 0], # Energy + samples[:, 2], # Lon + samples[:, 3], # Lat + samples[:, 1] # Phi + ) diff --git a/cosipy/response/nnresponse_helper.py b/cosipy/response/nn_response_helper.py similarity index 99% rename from cosipy/response/nnresponse_helper.py rename to cosipy/response/nn_response_helper.py index f741559bb..120fcd422 100644 --- a/cosipy/response/nnresponse_helper.py +++ b/cosipy/response/nn_response_helper.py @@ -1,11 +1,20 @@ -import normflows as nf import numpy as np -import torch -from torch import nn import healpy as hp -import sphericart.torch + from typing import Protocol, Optional, Literal, List, Union, Tuple, Dict, runtime_checkable + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows", "sphericart.torch"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import sphericart.torch +from torch import nn +import normflows as nf +import torch + + CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] def build_cmlp_diaggaussian_base(input_dim: int, output_dim: int, diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py new file mode 100644 index 000000000..83216971a --- /dev/null +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -0,0 +1,849 @@ +import copy +import os +import json +from typing import Optional, Iterable, Type, Tuple, List + +import numpy as np +import h5py +from astromodels import PointSource +from astropy.coordinates import CartesianRepresentation +from executing import Source + +from cosipy import SpacecraftHistory +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays +from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface, EventInterface +from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface +from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface +from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface +from cosipy.response.photon_types import PhotonListWithDirectionAndEnergyInSCFrame +from cosipy.util.iterables import asarray + +from astropy import units as u +import astropy.constants as c +from astropy.coordinates import SkyCoord +from astropy.time import Time + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch + + +class UnbinnedThreeMLPointSourceResponseIRFAdaptive(UnbinnedThreeMLSourceResponseInterface): + + def __init__(self, + data: TimeTagEmCDSEventDataInSCFrameInterface, + irf: FarFieldInstrumentResponseFunctionInterface, + sc_history: SpacecraftHistory,): + + """ + Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features + Note that this assumes a smooth flux spectrum + + All IRF queries are cached and can be saved to / loaded from a file + + Parameters + ---------- + data + irf + sc_history + """ + + # Interface inputs + self._source = None + + # Other implementation inputs + self._data = data + self._irf = irf + self._sc_ori = sc_history + + # Default parameters for irf energy node placement + self._total_energy_nodes = (60, 500) + self._peak_nodes = (18, 12) + self._peak_widths = (0.04, 0.1) + self._energy_range = (100., 10_000.) + self._batch_size = 1_000_000 + + # Placeholder for node pool - stored as Tensors + self._width_tensor: Optional[torch.Tensor] = None + self._nodes_primary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_secondary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + self._nodes_bkg_1: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self._nodes_bkg_2: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + self._nodes_bkg_3: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + + # Checks to avoid unecessary recomputations + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + # Cached values + self._irf_cache: Optional[torch.Tensor] = None # cm^2/rad/sr + self._irf_energy_node_cache: Optional[np.ndarray] = None # (Optional, only if full batch) + self._area_cache: Optional[np.ndarray] = None # cm^2*s*keV + self._area_energy_node_cache: Optional[np.ndarray] = None + self._exp_events: Optional[float] = None + self._exp_density: Optional[torch.Tensor] = None + + # Precomputed spacecraft history + self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 + self._sc_ori_center = self._sc_ori.interp(self._mid_times) + + data_times = self._data.time + self._n_events = self._data.nevents + self._unique_mjds, self._inv_idx = np.unique(data_times.mjd, return_inverse=True) + unique_times_obj = Time(self._unique_mjds, format='mjd') + + self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) + + interval_ratios = (self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + bin_indices = np.searchsorted(self._sc_ori.obstime.mjd, self._unique_mjds) - 1 + + bin_indices = np.clip(bin_indices, 0, len(self._sc_ori.livetime) - 1) + unique_ratio = interval_ratios[bin_indices] + + self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + + self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) + self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) + + self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) + self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) + self._cos_lat_scatt = torch.cos(self._lat_scatt) + self._sin_lat_scatt = torch.sin(self._lat_scatt) + self._cos_lon_scatt = torch.cos(self._lon_scatt) + self._sin_lon_scatt = torch.sin(self._lon_scatt) + + #unique_ratio = np.interp(self._unique_mjds, + # self._mid_times.mjd, + # self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) + # + #self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) + + #wrong_order = np.where(((data_times[1:] - data_times[:-1]) <= 0))[0] + #data_times[wrong_order + 1] = data_times[wrong_order + 1] + 1 + #self._sc_ori_data = self._sc_ori.interp(data_times) + + #ratio = np.interp(self._data.time.mjd, + # self._mid_times.mjd, + # self._sc_ori.livetime.to_value(u.s)/self._sc_ori.intervals_duration.to_value(u.s)) + #self._livetime_ratio = ratio.astype(np.float32) + + @property + def event_type(self) -> Type[EventInterface]: + return TimeTagEmCDSEventInSCFrameInterface + + def set_integration_parameters(self, + total_energy_nodes: Tuple[int, int] = (60, 500), + peak_nodes: Tuple[int, int] = (18, 12), + peak_widths: Tuple[float, float] = (0.04, 0.1), + energy_range: Tuple[float, float] = (100., 10_000.), + batch_size: int = 1_000_000,): + + # Reset caches if parameters change + if (peak_nodes != self._peak_nodes + or + peak_widths != self._peak_widths + or + total_energy_nodes[0] != self._total_energy_nodes[0]): + self._irf_cache = None + self._irf_energy_node_cache = None + self._width_tensor = None + self._nodes_primary = None + self._nodes_secondary = None + self._nodes_bkg_1 = None + self._nodes_bkg_2 = None + self._nodes_bkg_3 = None + + if (total_energy_nodes[1] != self._total_energy_nodes[1]): + self._area_cache = None + self._area_energy_node_cache = None + + if (energy_range != self._energy_range): + self._irf_cache = None + self._irf_energy_node_cache = None + self._area_cache = None + self._area_energy_node_cache = None + + if total_energy_nodes[0] < (peak_nodes[0] + 2 * peak_nodes[1] + 3): + raise ValueError("To many nodes per peak compared to the total number or peaks!") + + if (total_energy_nodes[0] < 1) or (total_energy_nodes[1] < 1): + raise ValueError("The number of energy nodes must be at least 1.") + + if energy_range[0] >= energy_range[1]: + raise ValueError("The initial energy interval needs to be increasing!") + + if (batch_size < total_energy_nodes[0]) or (batch_size < total_energy_nodes[1]): + raise ValueError("The batch size cannot be smaller than the number of integration nodes.") + + self._total_energy_nodes = total_energy_nodes + self._peak_nodes = peak_nodes + self._peak_widths = peak_widths + self._energy_range = energy_range + self._batch_size = batch_size + + @staticmethod + def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: + x, w = np.polynomial.legendre.leggauss(degree) + return torch.as_tensor(x, dtype=torch.float32).unsqueeze(0), torch.as_tensor(w, dtype=torch.float32).unsqueeze(0) + + def _build_split_nodes(self, remaining: int, groups: int): + q, r = divmod(remaining, groups) + return [self._build_nodes(q + (1 if i < r else 0)) for i in range(groups)] + + def _init_node_pool(self): + self._width_tensor = torch.tensor([self._peak_widths[0], self._peak_widths[0], + self._peak_widths[1], self._peak_widths[1]], dtype=torch.float32) + + self._nodes_primary = self._build_nodes(self._peak_nodes[0]) + self._nodes_secondary = self._build_nodes(self._peak_nodes[1]) + + self._nodes_bkg_1 = self._build_nodes(self._total_energy_nodes[0] - self._peak_nodes[0]) + + self._nodes_bkg_2 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - self._peak_nodes[1], 2 + ) + + self._nodes_bkg_3 = self._build_split_nodes( + self._total_energy_nodes[0] - self._peak_nodes[0] - 2 * self._peak_nodes[1], 3 + ) + + @staticmethod + def _scale_nodes_exp(E1: torch.Tensor, E2: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + diff = E2 - E1 + + out_n = (nodes_u + 1).mul(0.5).pow(2).mul(diff).add(E1) + out_w = (nodes_u + 1).mul(0.5).mul(weights_u).mul(diff) + + return out_n, out_w + + @staticmethod + def _scale_nodes_center(E1: torch.Tensor, E2: torch.Tensor, EC: torch.Tensor, + nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask_left = (nodes_u < 0) + width_left = (EC - E1) + width_right = (E2 - EC) + + scale = torch.where(mask_left, width_left, width_right) + + out_n = nodes_u.pow(3).mul(scale).add(EC) + out_w = nodes_u.pow(2).mul(3).mul(weights_u).mul(scale) + + return out_n, out_w + + def _get_escape_peak(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor) -> torch.Tensor: + E2 = 511.0 / (1.0 + 511.0 / energy_m_keV - torch.cos(phi_rad)) + energy = energy_m_keV + 1022.0 - E2 + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > 1600.0) & (energy_m_keV < energy) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def _get_missing_energy_peak(self, phi_geo_rad: torch.Tensor, energy_m_keV: torch.Tensor, + phi_rad: torch.Tensor, inverse: bool = False) -> torch.Tensor: + cos_geo = torch.cos(phi_geo_rad) + cos_phi = torch.cos(phi_rad) + + if inverse: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV * (cos_geo - 1) * (-2044.0 - 5 * energy_m_keV + energy_m_keV * cos_geo + 4 * energy_m_keV * cos_phi)) + energy = 511.0 * (energy_m_keV - energy_m_keV * cos_geo + root) / denom + else: + denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) + root = torch.sqrt(energy_m_keV**2 * (cos_geo - 1) * (cos_phi - 1) * + ((1022.0 + energy_m_keV)**2 - energy_m_keV * (2044.0 + energy_m_keV) * cos_phi - 2 * energy_m_keV**2 * cos_geo * torch.sin(phi_rad/2)**2)) + energy = (energy_m_keV**2 * (1 - cos_geo - cos_phi + cos_phi * cos_geo) + root) / denom + + accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > energy_m_keV) & (energy_m_keV/energy - 1 < -0.2) + return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) + + def init_cache(self): + self._update_cache() + + def clear_cache(self): + self._irf_cache = None + self._irf_energy_node_cache = None + self._area_cache = None + self._area_energy_node_cache = None + self._exp_events = None + self._exp_density = None + + self._last_convolved_source_skycoord = None + self._last_convolved_source_dict_number = None + self._last_convolved_source_dict_density = None + self._sc_coord_sph_cache = None + + def set_source(self, source: Source): + if not isinstance(source, PointSource): + raise TypeError("Please provide a PointSource!") + + self._source = source + + def copy(self) -> UnbinnedThreeMLSourceResponseInterface: + new_instance = copy.copy(self) + new_instance.clear_cache() + new_instance._source = None + + return new_instance + + @staticmethod + def _earth_occ(source_coord: SkyCoord, ori: SpacecraftHistory) -> np.ndarray: + gcrs_cart = ori.location.represent_as(CartesianRepresentation) + dist_earth_center = gcrs_cart.norm() + max_angle = np.pi*u.rad - np.arcsin(c.R_earth/dist_earth_center) + src_angle = source_coord.separation(ori.earth_zenith) + return (src_angle < max_angle).astype(np.float32) + + def _compute_area(self): + coord = self._source.position.sky_coord + n_energy = self._total_energy_nodes[1] + + log_E_min = np.log10(self._energy_range[0]) + log_E_max = np.log10(self._energy_range[1]) + + x, w = np.polynomial.legendre.leggauss(n_energy) + + scale = 0.5 * (log_E_max - log_E_min) + y_nodes = scale * x + 0.5 * (log_E_max + log_E_min) + self._area_energy_node_cache = 10**y_nodes + + e_w = (np.log(10) * self._area_energy_node_cache * (w * scale)).astype(np.float32).reshape(1, -1) + e_n = self._area_energy_node_cache.astype(np.float32) + + sc_coord_sph = self._sc_ori_center.get_target_in_sc_frame(coord) + earth_occ_index = self._earth_occ(coord, self._sc_ori_center) + + time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + n_time = len(lon_ph_rad) + batch_size_time = self._batch_size // n_energy + + total_area = np.zeros(n_energy, dtype=np.float64) + + max_batch_total = n_energy * min(batch_size_time, n_time) + batch_lons_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_lats_buffer = np.empty(max_batch_total, dtype=np.float32) + batch_energies_buffer = np.empty(max_batch_total, dtype=np.float32) + + for i in range(0, n_time, batch_size_time): + start = i + end = min(i + batch_size_time, n_time) + current_n_time = end - start + current_total = current_n_time * n_energy + + #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lons_buffer[:current_total]) + #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lats_buffer[:current_total]) + + batch_lons_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lats_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + batch_energies_buffer[:current_total].reshape(current_n_time, n_energy)[:] = e_n + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lons_buffer[:current_total], + batch_lats_buffer[:current_total], + batch_energies_buffer[:current_total] + ) + + eff_areas_flat = asarray(self._irf._effective_area_cm2(photons), dtype=np.float32) + eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) + + total_area += np.einsum('ij,i,j->j', + eff_areas_grid, + time_weights[start:end], + e_w.ravel()) + + self._area_cache = total_area + + def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, + indices: torch.Tensor, mode: int, + sorted_peaks: torch.Tensor, delta: torch.Tensor): + + Emin, Emax = self._energy_range + + if mode == 1: + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=Emax) + + EC = sorted_peaks[:, 0] + + E1, E2, EC = [E.view(-1, 1) for E in (E1, E2, EC)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_1[0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, Emax, *self._nodes_bkg_1) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, Emax, *self._nodes_bkg_1, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + elif mode == 2: + center_peak = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=Emax) + + EC1 = sorted_peaks[:, 0] + EC2 = sorted_peaks[:, 1] + + E1, E2, E3, E4, EC1, EC2 = [E.view(-1, 1) for E in (E1, E2, E3, E4, EC1, EC2)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_2[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_2[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_2[0], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_2[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, Emax, *self._nodes_bkg_2[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E4, Emax, *self._nodes_bkg_2[1], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + elif mode == 3: + center_peak_1 = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 + center_peak_2 = (sorted_peaks[:, 1] + sorted_peaks[:, 2]) / 2 + + E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) + E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak_1) + E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) + E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=center_peak_2) + E5 = (sorted_peaks[:, 2] - delta[:, 2]).clamp(min=E4) + E6 = (sorted_peaks[:, 2] + delta[:, 2]).clamp(max=Emax) + + EC1, EC2, EC3 = [sorted_peaks[:, i] for i in range(3)] + + E1, E2, E3, E4, E5, E6, EC1, EC2, EC3 = [E.view(-1, 1) for E in (E1, E2, E3, E4, E5, E6, EC1, EC2, EC3)] + + c = 0 + w = self._nodes_primary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[0][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_3[0]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_3[0], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[1][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E4, E5, *self._nodes_bkg_3[1]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E4, E5, *self._nodes_bkg_3[1], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_secondary[0].shape[1] + n_res, w_res = self._scale_nodes_center(E5, E6, EC3, *self._nodes_secondary) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_center_inplace(E5, E6, EC3, *self._nodes_secondary, + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w + w = self._nodes_bkg_3[2][0].shape[1] + n_res, w_res = self._scale_nodes_exp(E6, Emax, *self._nodes_bkg_3[2]) + nodes_out[indices, c:c+w] = n_res + weights_out[indices, c:c+w] = w_res + #self._scale_nodes_exp_inplace(E6, Emax, *self._nodes_bkg_3[2], + # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + + def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, + phi_geo_rad: torch.Tensor, phi_igeo_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + energy_m_keV = energy_m_keV.view(-1, 1) + phi_rad = phi_rad.view(-1, 1) + phi_geo_rad = phi_geo_rad.view(-1, 1) + phi_igeo_rad = phi_igeo_rad.view(-1, 1) + + batch_size = energy_m_keV.shape[0] + + nodes = torch.zeros((batch_size, self._total_energy_nodes[0]), dtype=torch.float32) + weights = torch.zeros_like(nodes) + + peaks = torch.zeros((batch_size, 4), dtype=torch.float32) + peaks[:, 0] = energy_m_keV.squeeze() + peaks[:, 1] = self._get_escape_peak(energy_m_keV, phi_rad).squeeze() + peaks[:, 2] = self._get_missing_energy_peak(phi_geo_rad, energy_m_keV, phi_rad).squeeze() + peaks[:, 3] = self._get_missing_energy_peak(phi_igeo_rad, energy_m_keV, phi_rad, inverse=True).squeeze() + + diffs = peaks * self._width_tensor[None, ...] + + n_peaks = torch.sum(~torch.isnan(peaks), dim=1) + + indices_1 = torch.where(n_peaks == 1)[0] + indices_2 = torch.where(n_peaks == 2)[0] + indices_3 = torch.where(n_peaks == 3)[0] + + if len(indices_1) > 0: + self._fill_nodes(nodes, weights, indices_1, 1, + peaks[indices_1, :1], diffs[indices_1, :1]) + + if len(indices_2) > 0: + p_sub = peaks[indices_2] + d_sub = diffs[indices_2] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 2) + d_comp = d_sub[mask].view(-1, 2) + self._fill_nodes(nodes, weights, indices_2, 2, p_comp, d_comp) + + if len(indices_3) > 0: + p_sub = peaks[indices_3] + d_sub = diffs[indices_3] + mask = ~torch.isnan(p_sub) + p_comp = p_sub[mask].view(-1, 3) + d_comp = d_sub[mask].view(-1, 3) + + p_sorted, idx = torch.sort(p_comp, dim=1) + d_sorted = torch.gather(d_comp, 1, idx) + + self._fill_nodes(nodes, weights, indices_3, 3, p_sorted, d_sorted) + + return nodes, weights + + def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + cos_lat_src = torch.cos(lat_src_rad) + sin_lat_src = torch.sin(lat_src_rad) + cos_lon_src = torch.cos(lon_src_rad) + sin_lon_src = torch.sin(lon_src_rad) + + cos_geo = ( + cos_lat_src * cos_lon_src * self._cos_lat_scatt * self._cos_lon_scatt + + cos_lat_src * sin_lon_src * self._cos_lat_scatt * self._sin_lon_scatt + + sin_lat_src * self._sin_lat_scatt + ) + + cos_geo = torch.clip(cos_geo, -1.0, 1.0) + phi_geo_rad = torch.arccos(cos_geo) + + return phi_geo_rad, np.pi - phi_geo_rad + + def _compute_density(self): + coord = self._source.position.sky_coord + sc_coord_sph = self._sc_coord_sph_cache + earth_occ_index = self._earth_occ(coord, self._sc_ori_unique)[self._inv_idx] + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + + n_energy = self._total_energy_nodes[0] + batch_size_events = self._batch_size // n_energy + + self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch.float32) + + buffer_size = n_energy * min(batch_size_events, self._n_events) + batch_lon_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_src_buffer = np.empty(buffer_size, dtype=np.float32) + batch_energy_buffer = np.empty(buffer_size, dtype=np.float32) + batch_phi_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + + for i in range(0, self._n_events, batch_size_events): + start = i + end = min(i + batch_size_events, self._n_events) + current_n = end - start + current_total = current_n * n_energy + + e_sl = self._energy_m_keV[start:end] + p_sl = self._phi_rad[start:end] + pg_sl = phi_geo_rad[start:end] + pig_sl = phi_igeo_rad[start:end] + + nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + + if batch_size_events >= self._n_events: + self._irf_energy_node_cache = np.asarray(nodes) + + #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lon_src_buffer[:current_total]) + #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lat_src_buffer[:current_total]) + + batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] + batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] + + batch_energy_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._energy_m_keV[start:end, np.newaxis]) + batch_lon_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lon_scatt[start:end, np.newaxis]) + batch_lat_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lat_scatt[start:end, np.newaxis]) + batch_phi_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._phi_rad[start:end, np.newaxis]) + + #np.repeat(np.asarray(self._energy_m_keV[start:end]), n_energy, out=batch_energy_buffer[:current_total]) + #np.repeat(np.asarray(self._lon_scatt[start:end]), n_energy, out=batch_lon_scatt_buffer[:current_total]) + #np.repeat(np.asarray(self._lat_scatt[start:end]), n_energy, out=batch_lat_scatt_buffer[:current_total]) + #np.repeat(np.asarray(self._phi_rad[start:end]), n_energy, out=batch_phi_buffer[:current_total]) + + photons = PhotonListWithDirectionAndEnergyInSCFrame( + batch_lon_src_buffer[:current_total], + batch_lat_src_buffer[:current_total], + np.asarray(nodes).ravel() + ) + events = EmCDSEventDataInSCFrameFromArrays( + batch_energy_buffer[:current_total], + batch_lon_scatt_buffer[:current_total], + batch_lat_scatt_buffer[:current_total], + batch_phi_buffer[:current_total], + ) + + eff_areas_flat = torch.as_tensor(asarray(self._irf._effective_area_cm2(photons), dtype=np.float32)) + densities_flat = torch.as_tensor(asarray(self._irf._event_probability(photons, events), dtype=np.float32)) + + res_block = (densities_flat * eff_areas_flat).view(current_n, n_energy) + + occ = torch.as_tensor(earth_occ_index[start:end]).view(-1, 1) + live = torch.as_tensor(self._livetime_ratio[start:end]).view(-1, 1) + + res_block *= occ * live * weights + + self._irf_cache[start:end] = res_block + + def _update_cache(self): + + if self._source is None: + raise RuntimeError("Call set_source() first.") + + source_coord = self._source.position.sky_coord + + if (self._sc_coord_sph_cache is None) or (source_coord != self._last_convolved_source_skycoord): + #self._sc_coord_sph_cache = self._sc_ori_data.get_target_in_sc_frame(source_coord) + self._sc_coord_sph_cache = self._sc_ori_unique.get_target_in_sc_frame(source_coord)[self._inv_idx] + + no_recalculation = ((source_coord == self._last_convolved_source_skycoord) + and + (self._irf_cache is not None) + and + (self._area_cache is not None)) + + area_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._area_cache is None)) + + pdf_recalculation = ((source_coord != self._last_convolved_source_skycoord) + or + (self._irf_cache is None)) + + if no_recalculation: + return + else: + if area_recalculation: + self._compute_area() + + if pdf_recalculation: + self._init_node_pool() + self._compute_density() + + self._last_convolved_source_skycoord = source_coord.copy() + + def cache_to_file(self, filename: str): + with h5py.File(filename, 'w') as f: + f.attrs['total_energy_nodes'] = self._total_energy_nodes + f.attrs['peak_nodes'] = self._peak_nodes + f.attrs['peak_widths'] = self._peak_widths + f.attrs['energy_range'] = self._energy_range + f.attrs['batch_size'] = self._batch_size + + if self._irf_cache is not None: + f.create_dataset('irf_cache', data=self._irf_cache.numpy(), + compression='gzip', compression_opts=4) + + if self._irf_energy_node_cache is not None: + f.create_dataset('irf_energy_node_cache', data=self._irf_energy_node_cache, + compression='gzip') + + if self._area_cache is not None: + f.create_dataset('area_cache', data=self._area_cache, + compression='gzip') + + if self._area_energy_node_cache is not None: + f.create_dataset('area_energy_node_cache', data=self._area_energy_node_cache, + compression='gzip') + + if self._exp_events is not None: + f.create_dataset('exp_events', data=self._exp_events) + + if self._exp_density is not None: + f.create_dataset('exp_density', data=self._exp_density.numpy(), + compression='gzip') + + if self._last_convolved_source_dict_number is not None: + json_str = json.dumps(self._last_convolved_source_dict_number) + f.attrs['last_convolved_source_dict_number'] = json_str + + if self._last_convolved_source_dict_density is not None: + json_str = json.dumps(self._last_convolved_source_dict_density) + f.attrs['last_convolved_source_dict_density'] = json_str + + if self._last_convolved_source_skycoord is not None: + sc = self._last_convolved_source_skycoord + f.attrs['last_convolved_lon_deg'] = sc.spherical.lon.deg + f.attrs['last_convolved_lat_deg'] = sc.spherical.lat.deg + f.attrs['last_convolved_frame'] = sc.frame.name + + def cache_from_file(self, filename: str): + if not os.path.exists(filename): + raise FileNotFoundError(f"Cache file {filename} not found.") + + with h5py.File(filename, 'r') as f: + self._total_energy_nodes = tuple(f.attrs['total_energy_nodes']) + self._peak_nodes = tuple(f.attrs['peak_nodes']) + self._peak_widths = tuple(f.attrs['peak_widths']) + self._energy_range = tuple(f.attrs['energy_range']) + self._batch_size = int(f.attrs['batch_size']) + + if 'irf_cache' in f: + self._irf_cache = torch.from_numpy(f['irf_cache'][:]) + else: + self._irf_cache = None + + if 'irf_energy_node_cache' in f: + self._irf_energy_node_cache = f['irf_energy_node_cache'][:] + else: + self._irf_energy_node_cache = None + + if 'area_cache' in f: + self._area_cache = f['area_cache'][:] + else: + self._area_cache = None + + if 'area_energy_node_cache' in f: + self._area_energy_node_cache = f['area_energy_node_cache'][:] + else: + self._area_energy_node_cache = None + + if 'exp_events' in f: + self._exp_events = float(f['exp_events'][()]) + else: + self._exp_events = None + + if 'exp_density' in f: + self._exp_density = torch.from_numpy(f['exp_density'][:]) + else: + self._exp_density = None + + if 'last_convolved_source_dict_number' in f.attrs: + self._last_convolved_source_dict_number = json.loads(f.attrs['last_convolved_source_dict_number']) + else: + self._last_convolved_source_dict_number = None + + if 'last_convolved_source_dict_density' in f.attrs: + self._last_convolved_source_dict_density = json.loads(f.attrs['last_convolved_source_dict_density']) + else: + self._last_convolved_source_dict_density = None + + if 'last_convolved_lon_deg' in f.attrs: + lon = f.attrs['last_convolved_lon_deg'] + lat = f.attrs['last_convolved_lat_deg'] + frame = f.attrs['last_convolved_frame'] + self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame) + else: + self._last_convolved_source_skycoord = None + + if self._irf_cache is not None: + self._init_node_pool() + + def expected_counts(self) -> float: + """ + Return the total expected counts. + """ + self._update_cache() + source_dict = self._source.to_dict() + + if (source_dict != self._last_convolved_source_dict_number) or (self._exp_events is None): + area = self._area_cache + flux = self._source(self._area_energy_node_cache) + self._exp_events = np.sum(area * flux, dtype=float) + + self._last_convolved_source_dict_number = source_dict + return self._exp_events + + def expectation_density(self) -> Iterable[float]: + """ + Return the expected number of counts density. This equals the event probabiliy times the number of events. + """ + + self._update_cache() + source_dict = self._source.to_dict() + + if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): + self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) + + if self._irf_energy_node_cache is not None: + flux = torch.as_tensor(self._source(self._irf_energy_node_cache), dtype=self._irf_cache.dtype) + + torch.linalg.vecdot(self._irf_cache, flux, dim=1, out=self._exp_density) + + else: + n_energy = self._total_energy_nodes[0] + batch_size = self._batch_size // n_energy + + sc_coord_sph = self._sc_coord_sph_cache + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + + for i in range(0, self._n_events, batch_size): + end = min(i + batch_size, self._n_events) + + e_sl = self._energy_m_keV[i:end] + p_sl = self._phi_rad[i:end] + pg_sl = phi_geo_rad[i:end] + pig_sl = phi_igeo_rad[i:end] + + nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + + flux_batch = torch.as_tensor(self._source(np.asarray(nodes)), dtype=self._irf_cache.dtype) + + torch.linalg.vecdot(self._irf_cache[i:end], flux_batch, dim=1, out=self._exp_density[i:end]) + + self._last_convolved_source_dict_density = source_dict + + #print(self._data.time.unix[self._exp_density <= 0][:100]) + #print(np.sum(self._exp_density <= 0)/self._n_events * 100) + #print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) + return np.asarray(self._exp_density, dtype=np.float64)+1e-12 \ No newline at end of file diff --git a/cosipy/threeml/psr_fixed_ei.py b/cosipy/threeml/psr_fixed_ei.py index 27dc5864c..ded88e118 100644 --- a/cosipy/threeml/psr_fixed_ei.py +++ b/cosipy/threeml/psr_fixed_ei.py @@ -1,13 +1,7 @@ import copy -import os -import json -from typing import Optional, Iterable, Type, Tuple, List, Union, Dict, Any -from itertools import chain, repeat, islice -from operator import attrgetter +from typing import Optional, Iterable, Type -import torch import numpy as np -import h5py from astromodels import PointSource from astropy.coordinates import UnitSphericalRepresentation, CartesianRepresentation from astropy.units import Quantity @@ -15,18 +9,14 @@ from histpy import Axis from cosipy import SpacecraftHistory -from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventInSCFrame, EmCDSEventDataInSCFrameFromArrays +from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventInSCFrame from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface, EventInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import EmCDSEventInSCFrameInterface, TimeTagEmCDSEventInSCFrameInterface from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface -from cosipy.response.photon_types import PhotonWithDirectionAndEnergyInSCFrame, PhotonListWithDirectionAndEnergyInSCFrame -from cosipy.util.iterables import asarray +from cosipy.response.photon_types import PhotonWithDirectionAndEnergyInSCFrame from astropy import units as u -import astropy.constants as c -from astropy.coordinates import SkyCoord -from astropy.time import Time class UnbinnedThreeMLPointSourceResponseTrapz(UnbinnedThreeMLSourceResponseInterface): @@ -219,821 +209,4 @@ def event_probability(self) -> Iterable[float]: self._update_cache() - return self._event_prob - - -class UnbinnedThreeMLPointSourceResponseIRFAdaptive(UnbinnedThreeMLSourceResponseInterface): - - def __init__(self, - data: TimeTagEmCDSEventDataInSCFrameInterface, - irf: FarFieldInstrumentResponseFunctionInterface, - sc_history: SpacecraftHistory,): - - """ - Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features - Note that this assumes a smooth flux spectrum - - All IRF queries are cached and can be saved to / loaded from a file - - Parameters - ---------- - data - irf - sc_history - """ - - # Interface inputs - self._source = None - - # Other implementation inputs - self._data = data - self._irf = irf - self._sc_ori = sc_history - - # Default parameters for irf energy node placement - self._total_energy_nodes = (60, 500) - self._peak_nodes = (18, 12) - self._peak_widths = (0.04, 0.1) - self._energy_range = (100., 10_000.) - self._batch_size = 1_000_000 - - # Placeholder for node pool - stored as Tensors - self._width_tensor: Optional[torch.Tensor] = None - self._nodes_primary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - self._nodes_secondary: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - self._nodes_bkg_1: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - self._nodes_bkg_2: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None - self._nodes_bkg_3: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None - - # Checks to avoid unecessary recomputations - self._last_convolved_source_skycoord = None - self._last_convolved_source_dict_number = None - self._last_convolved_source_dict_density = None - self._sc_coord_sph_cache = None - - # Cached values - self._irf_cache: Optional[torch.Tensor] = None # cm^2/rad/sr - self._irf_energy_node_cache: Optional[np.ndarray] = None # (Optional, only if full batch) - self._area_cache: Optional[np.ndarray] = None # cm^2*s*keV - self._area_energy_node_cache: Optional[np.ndarray] = None - self._exp_events: Optional[float] = None - self._exp_density: Optional[torch.Tensor] = None - - # Precomputed spacecraft history - self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 - self._sc_ori_center = self._sc_ori.interp(self._mid_times) - - data_times = self._data.time - self._n_events = self._data.nevents - self._unique_mjds, self._inv_idx = np.unique(data_times.mjd, return_inverse=True) - unique_times_obj = Time(self._unique_mjds, format='mjd') - - self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) - - interval_ratios = (self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) - bin_indices = np.searchsorted(self._sc_ori.obstime.mjd, self._unique_mjds) - 1 - - bin_indices = np.clip(bin_indices, 0, len(self._sc_ori.livetime) - 1) - unique_ratio = interval_ratios[bin_indices] - - self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) - - self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) - self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) - - self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) - self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) - self._cos_lat_scatt = torch.cos(self._lat_scatt) - self._sin_lat_scatt = torch.sin(self._lat_scatt) - self._cos_lon_scatt = torch.cos(self._lon_scatt) - self._sin_lon_scatt = torch.sin(self._lon_scatt) - - #unique_ratio = np.interp(self._unique_mjds, - # self._mid_times.mjd, - # self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) - # - #self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) - - #wrong_order = np.where(((data_times[1:] - data_times[:-1]) <= 0))[0] - #data_times[wrong_order + 1] = data_times[wrong_order + 1] + 1 - #self._sc_ori_data = self._sc_ori.interp(data_times) - - #ratio = np.interp(self._data.time.mjd, - # self._mid_times.mjd, - # self._sc_ori.livetime.to_value(u.s)/self._sc_ori.intervals_duration.to_value(u.s)) - #self._livetime_ratio = ratio.astype(np.float32) - - @property - def event_type(self) -> Type[EventInterface]: - return TimeTagEmCDSEventInSCFrameInterface - - def set_integration_parameters(self, - total_energy_nodes: Tuple[int, int] = (60, 500), - peak_nodes: Tuple[int, int] = (18, 12), - peak_widths: Tuple[float, float] = (0.04, 0.1), - energy_range: Tuple[float, float] = (100., 10_000.), - batch_size: int = 1_000_000,): - - # Reset caches if parameters change - if (peak_nodes != self._peak_nodes - or - peak_widths != self._peak_widths - or - total_energy_nodes[0] != self._total_energy_nodes[0]): - self._irf_cache = None - self._irf_energy_node_cache = None - self._width_tensor = None - self._nodes_primary = None - self._nodes_secondary = None - self._nodes_bkg_1 = None - self._nodes_bkg_2 = None - self._nodes_bkg_3 = None - - if (total_energy_nodes[1] != self._total_energy_nodes[1]): - self._area_cache = None - self._area_energy_node_cache = None - - if (energy_range != self._energy_range): - self._irf_cache = None - self._irf_energy_node_cache = None - self._area_cache = None - self._area_energy_node_cache = None - - if total_energy_nodes[0] < (peak_nodes[0] + 2 * peak_nodes[1] + 3): - raise ValueError("To many nodes per peak compared to the total number or peaks!") - - if (total_energy_nodes[0] < 1) or (total_energy_nodes[1] < 1): - raise ValueError("The number of energy nodes must be at least 1.") - - if energy_range[0] >= energy_range[1]: - raise ValueError("The initial energy interval needs to be increasing!") - - if (batch_size < total_energy_nodes[0]) or (batch_size < total_energy_nodes[1]): - raise ValueError("The batch size cannot be smaller than the number of integration nodes.") - - self._total_energy_nodes = total_energy_nodes - self._peak_nodes = peak_nodes - self._peak_widths = peak_widths - self._energy_range = energy_range - self._batch_size = batch_size - - @staticmethod - def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: - x, w = np.polynomial.legendre.leggauss(degree) - return torch.as_tensor(x, dtype=torch.float32).unsqueeze(0), torch.as_tensor(w, dtype=torch.float32).unsqueeze(0) - - def _build_split_nodes(self, remaining: int, groups: int): - q, r = divmod(remaining, groups) - return [self._build_nodes(q + (1 if i < r else 0)) for i in range(groups)] - - def _init_node_pool(self): - self._width_tensor = torch.tensor([self._peak_widths[0], self._peak_widths[0], - self._peak_widths[1], self._peak_widths[1]], dtype=torch.float32) - - self._nodes_primary = self._build_nodes(self._peak_nodes[0]) - self._nodes_secondary = self._build_nodes(self._peak_nodes[1]) - - self._nodes_bkg_1 = self._build_nodes(self._total_energy_nodes[0] - self._peak_nodes[0]) - - self._nodes_bkg_2 = self._build_split_nodes( - self._total_energy_nodes[0] - self._peak_nodes[0] - self._peak_nodes[1], 2 - ) - - self._nodes_bkg_3 = self._build_split_nodes( - self._total_energy_nodes[0] - self._peak_nodes[0] - 2 * self._peak_nodes[1], 3 - ) - - @staticmethod - def _scale_nodes_exp(E1: torch.Tensor, E2: torch.Tensor, - nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - diff = E2 - E1 - - out_n = (nodes_u + 1).mul(0.5).pow(2).mul(diff).add(E1) - out_w = (nodes_u + 1).mul(0.5).mul(weights_u).mul(diff) - - return out_n, out_w - - @staticmethod - def _scale_nodes_center(E1: torch.Tensor, E2: torch.Tensor, EC: torch.Tensor, - nodes_u: torch.Tensor, weights_u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - mask_left = (nodes_u < 0) - width_left = (EC - E1) - width_right = (E2 - EC) - - scale = torch.where(mask_left, width_left, width_right) - - out_n = nodes_u.pow(3).mul(scale).add(EC) - out_w = nodes_u.pow(2).mul(3).mul(weights_u).mul(scale) - - return out_n, out_w - - def _get_escape_peak(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor) -> torch.Tensor: - E2 = 511.0 / (1.0 + 511.0 / energy_m_keV - torch.cos(phi_rad)) - energy = energy_m_keV + 1022.0 - E2 - - accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > 1600.0) & (energy_m_keV < energy) - return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) - - def _get_missing_energy_peak(self, phi_geo_rad: torch.Tensor, energy_m_keV: torch.Tensor, - phi_rad: torch.Tensor, inverse: bool = False) -> torch.Tensor: - cos_geo = torch.cos(phi_geo_rad) - cos_phi = torch.cos(phi_rad) - - if inverse: - denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) - root = torch.sqrt(energy_m_keV * (cos_geo - 1) * (-2044.0 - 5 * energy_m_keV + energy_m_keV * cos_geo + 4 * energy_m_keV * cos_phi)) - energy = 511.0 * (energy_m_keV - energy_m_keV * cos_geo + root) / denom - else: - denom = 2 * (-1 + cos_geo) * (-511.0 - energy_m_keV + energy_m_keV * cos_phi) - root = torch.sqrt(energy_m_keV**2 * (cos_geo - 1) * (cos_phi - 1) * - ((1022.0 + energy_m_keV)**2 - energy_m_keV * (2044.0 + energy_m_keV) * cos_phi - 2 * energy_m_keV**2 * cos_geo * torch.sin(phi_rad/2)**2)) - energy = (energy_m_keV**2 * (1 - cos_geo - cos_phi + cos_phi * cos_geo) + root) / denom - - accept = (energy < self._energy_range[1]) & (energy > self._energy_range[0]) & (energy > energy_m_keV) & (energy_m_keV/energy - 1 < -0.2) - return torch.where(accept, energy, torch.tensor(float('nan'), dtype=torch.float32)) - - def init_cache(self): - self._update_cache() - - def clear_cache(self): - self._irf_cache = None - self._irf_energy_node_cache = None - self._area_cache = None - self._area_energy_node_cache = None - self._exp_events = None - self._exp_density = None - - self._last_convolved_source_skycoord = None - self._last_convolved_source_dict_number = None - self._last_convolved_source_dict_density = None - self._sc_coord_sph_cache = None - - def set_source(self, source: Source): - if not isinstance(source, PointSource): - raise TypeError("Please provide a PointSource!") - - self._source = source - - def copy(self) -> UnbinnedThreeMLSourceResponseInterface: - new_instance = copy.copy(self) - new_instance.clear_cache() - new_instance._source = None - - return new_instance - - @staticmethod - def _earth_occ(source_coord: SkyCoord, ori: SpacecraftHistory) -> np.ndarray: - gcrs_cart = ori.location.represent_as(CartesianRepresentation) - dist_earth_center = gcrs_cart.norm() - max_angle = np.pi*u.rad - np.arcsin(c.R_earth/dist_earth_center) - src_angle = source_coord.separation(ori.earth_zenith) - return (src_angle < max_angle).astype(np.float32) - - def _compute_area(self): - coord = self._source.position.sky_coord - n_energy = self._total_energy_nodes[1] - - log_E_min = np.log10(self._energy_range[0]) - log_E_max = np.log10(self._energy_range[1]) - - x, w = np.polynomial.legendre.leggauss(n_energy) - - scale = 0.5 * (log_E_max - log_E_min) - y_nodes = scale * x + 0.5 * (log_E_max + log_E_min) - self._area_energy_node_cache = 10**y_nodes - - e_w = (np.log(10) * self._area_energy_node_cache * (w * scale)).astype(np.float32).reshape(1, -1) - e_n = self._area_energy_node_cache.astype(np.float32) - - sc_coord_sph = self._sc_ori_center.get_target_in_sc_frame(coord) - earth_occ_index = self._earth_occ(coord, self._sc_ori_center) - - time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index - - lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) - lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - - n_time = len(lon_ph_rad) - batch_size_time = self._batch_size // n_energy - - total_area = np.zeros(n_energy, dtype=np.float64) - - max_batch_total = n_energy * min(batch_size_time, n_time) - batch_lons_buffer = np.empty(max_batch_total, dtype=np.float32) - batch_lats_buffer = np.empty(max_batch_total, dtype=np.float32) - batch_energies_buffer = np.empty(max_batch_total, dtype=np.float32) - - for i in range(0, n_time, batch_size_time): - start = i - end = min(i + batch_size_time, n_time) - current_n_time = end - start - current_total = current_n_time * n_energy - - #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lons_buffer[:current_total]) - #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lats_buffer[:current_total]) - - batch_lons_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] - batch_lats_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] - batch_energies_buffer[:current_total].reshape(current_n_time, n_energy)[:] = e_n - - photons = PhotonListWithDirectionAndEnergyInSCFrame( - batch_lons_buffer[:current_total], - batch_lats_buffer[:current_total], - batch_energies_buffer[:current_total] - ) - - eff_areas_flat = asarray(self._irf._effective_area_cm2(photons), dtype=np.float32) - eff_areas_grid = eff_areas_flat.reshape(current_n_time, n_energy) - - total_area += np.einsum('ij,i,j->j', - eff_areas_grid, - time_weights[start:end], - e_w.ravel()) - - self._area_cache = total_area - - def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, - indices: torch.Tensor, mode: int, - sorted_peaks: torch.Tensor, delta: torch.Tensor): - - Emin, Emax = self._energy_range - - if mode == 1: - E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) - E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=Emax) - - EC = sorted_peaks[:, 0] - - E1, E2, EC = [E.view(-1, 1) for E in (E1, E2, EC)] - - c = 0 - w = self._nodes_primary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E1, E2, EC, *self._nodes_primary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_1[0].shape[1] - n_res, w_res = self._scale_nodes_exp(E2, Emax, *self._nodes_bkg_1) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, Emax, *self._nodes_bkg_1, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - - elif mode == 2: - center_peak = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 - - E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) - E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak) - E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) - E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=Emax) - - EC1 = sorted_peaks[:, 0] - EC2 = sorted_peaks[:, 1] - - E1, E2, E3, E4, EC1, EC2 = [E.view(-1, 1) for E in (E1, E2, E3, E4, EC1, EC2)] - - c = 0 - w = self._nodes_primary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_2[0][0].shape[1] - n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_2[0]) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_2[0], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_secondary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_2[1][0].shape[1] - n_res, w_res = self._scale_nodes_exp(E4, Emax, *self._nodes_bkg_2[1]) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E4, Emax, *self._nodes_bkg_2[1], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - - elif mode == 3: - center_peak_1 = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 - center_peak_2 = (sorted_peaks[:, 1] + sorted_peaks[:, 2]) / 2 - - E1 = (sorted_peaks[:, 0] - delta[:, 0]).clamp(min=Emin) - E3 = (sorted_peaks[:, 1] - delta[:, 1]).clamp(min=center_peak_1) - E2 = (sorted_peaks[:, 0] + delta[:, 0]).clamp(max=E3) - E4 = (sorted_peaks[:, 1] + delta[:, 1]).clamp(max=center_peak_2) - E5 = (sorted_peaks[:, 2] - delta[:, 2]).clamp(min=E4) - E6 = (sorted_peaks[:, 2] + delta[:, 2]).clamp(max=Emax) - - EC1, EC2, EC3 = [sorted_peaks[:, i] for i in range(3)] - - E1, E2, E3, E4, E5, E6, EC1, EC2, EC3 = [E.view(-1, 1) for E in (E1, E2, E3, E4, E5, E6, EC1, EC2, EC3)] - - c = 0 - w = self._nodes_primary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_3[0][0].shape[1] - n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_3[0]) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_3[0], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_secondary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_3[1][0].shape[1] - n_res, w_res = self._scale_nodes_exp(E4, E5, *self._nodes_bkg_3[1]) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E4, E5, *self._nodes_bkg_3[1], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_secondary[0].shape[1] - n_res, w_res = self._scale_nodes_center(E5, E6, EC3, *self._nodes_secondary) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E5, E6, EC3, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - c += w - w = self._nodes_bkg_3[2][0].shape[1] - n_res, w_res = self._scale_nodes_exp(E6, Emax, *self._nodes_bkg_3[2]) - nodes_out[indices, c:c+w] = n_res - weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E6, Emax, *self._nodes_bkg_3[2], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) - - def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, - phi_geo_rad: torch.Tensor, phi_igeo_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - - energy_m_keV = energy_m_keV.view(-1, 1) - phi_rad = phi_rad.view(-1, 1) - phi_geo_rad = phi_geo_rad.view(-1, 1) - phi_igeo_rad = phi_igeo_rad.view(-1, 1) - - batch_size = energy_m_keV.shape[0] - - nodes = torch.zeros((batch_size, self._total_energy_nodes[0]), dtype=torch.float32) - weights = torch.zeros_like(nodes) - - peaks = torch.zeros((batch_size, 4), dtype=torch.float32) - peaks[:, 0] = energy_m_keV.squeeze() - peaks[:, 1] = self._get_escape_peak(energy_m_keV, phi_rad).squeeze() - peaks[:, 2] = self._get_missing_energy_peak(phi_geo_rad, energy_m_keV, phi_rad).squeeze() - peaks[:, 3] = self._get_missing_energy_peak(phi_igeo_rad, energy_m_keV, phi_rad, inverse=True).squeeze() - - diffs = peaks * self._width_tensor[None, ...] - - n_peaks = torch.sum(~torch.isnan(peaks), dim=1) - - indices_1 = torch.where(n_peaks == 1)[0] - indices_2 = torch.where(n_peaks == 2)[0] - indices_3 = torch.where(n_peaks == 3)[0] - - if len(indices_1) > 0: - self._fill_nodes(nodes, weights, indices_1, 1, - peaks[indices_1, :1], diffs[indices_1, :1]) - - if len(indices_2) > 0: - p_sub = peaks[indices_2] - d_sub = diffs[indices_2] - mask = ~torch.isnan(p_sub) - p_comp = p_sub[mask].view(-1, 2) - d_comp = d_sub[mask].view(-1, 2) - self._fill_nodes(nodes, weights, indices_2, 2, p_comp, d_comp) - - if len(indices_3) > 0: - p_sub = peaks[indices_3] - d_sub = diffs[indices_3] - mask = ~torch.isnan(p_sub) - p_comp = p_sub[mask].view(-1, 3) - d_comp = d_sub[mask].view(-1, 3) - - p_sorted, idx = torch.sort(p_comp, dim=1) - d_sorted = torch.gather(d_comp, 1, idx) - - self._fill_nodes(nodes, weights, indices_3, 3, p_sorted, d_sorted) - - return nodes, weights - - def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - cos_lat_src = torch.cos(lat_src_rad) - sin_lat_src = torch.sin(lat_src_rad) - cos_lon_src = torch.cos(lon_src_rad) - sin_lon_src = torch.sin(lon_src_rad) - - cos_geo = ( - cos_lat_src * cos_lon_src * self._cos_lat_scatt * self._cos_lon_scatt + - cos_lat_src * sin_lon_src * self._cos_lat_scatt * self._sin_lon_scatt + - sin_lat_src * self._sin_lat_scatt - ) - - cos_geo = torch.clip(cos_geo, -1.0, 1.0) - phi_geo_rad = torch.arccos(cos_geo) - - return phi_geo_rad, np.pi - phi_geo_rad - - def _compute_density(self): - coord = self._source.position.sky_coord - sc_coord_sph = self._sc_coord_sph_cache - earth_occ_index = self._earth_occ(coord, self._sc_ori_unique)[self._inv_idx] - - lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) - lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - - phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) - - n_energy = self._total_energy_nodes[0] - batch_size_events = self._batch_size // n_energy - - self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch.float32) - - buffer_size = n_energy * min(batch_size_events, self._n_events) - batch_lon_src_buffer = np.empty(buffer_size, dtype=np.float32) - batch_lat_src_buffer = np.empty(buffer_size, dtype=np.float32) - batch_energy_buffer = np.empty(buffer_size, dtype=np.float32) - batch_phi_buffer = np.empty(buffer_size, dtype=np.float32) - batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) - batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) - - for i in range(0, self._n_events, batch_size_events): - start = i - end = min(i + batch_size_events, self._n_events) - current_n = end - start - current_total = current_n * n_energy - - e_sl = self._energy_m_keV[start:end] - p_sl = self._phi_rad[start:end] - pg_sl = phi_geo_rad[start:end] - pig_sl = phi_igeo_rad[start:end] - - nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) - - if batch_size_events >= self._n_events: - self._irf_energy_node_cache = np.asarray(nodes) - - #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lon_src_buffer[:current_total]) - #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lat_src_buffer[:current_total]) - - batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] - batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] - - batch_energy_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._energy_m_keV[start:end, np.newaxis]) - batch_lon_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lon_scatt[start:end, np.newaxis]) - batch_lat_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lat_scatt[start:end, np.newaxis]) - batch_phi_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._phi_rad[start:end, np.newaxis]) - - #np.repeat(np.asarray(self._energy_m_keV[start:end]), n_energy, out=batch_energy_buffer[:current_total]) - #np.repeat(np.asarray(self._lon_scatt[start:end]), n_energy, out=batch_lon_scatt_buffer[:current_total]) - #np.repeat(np.asarray(self._lat_scatt[start:end]), n_energy, out=batch_lat_scatt_buffer[:current_total]) - #np.repeat(np.asarray(self._phi_rad[start:end]), n_energy, out=batch_phi_buffer[:current_total]) - - photons = PhotonListWithDirectionAndEnergyInSCFrame( - batch_lon_src_buffer[:current_total], - batch_lat_src_buffer[:current_total], - np.asarray(nodes).ravel() - ) - events = EmCDSEventDataInSCFrameFromArrays( - batch_energy_buffer[:current_total], - batch_lon_scatt_buffer[:current_total], - batch_lat_scatt_buffer[:current_total], - batch_phi_buffer[:current_total], - ) - - eff_areas_flat = torch.as_tensor(asarray(self._irf._effective_area_cm2(photons), dtype=np.float32)) - densities_flat = torch.as_tensor(asarray(self._irf._event_probability(photons, events), dtype=np.float32)) - - res_block = (densities_flat * eff_areas_flat).view(current_n, n_energy) - - occ = torch.as_tensor(earth_occ_index[start:end]).view(-1, 1) - live = torch.as_tensor(self._livetime_ratio[start:end]).view(-1, 1) - - res_block *= occ * live * weights - - self._irf_cache[start:end] = res_block - - def _update_cache(self): - - if self._source is None: - raise RuntimeError("Call set_source() first.") - - source_coord = self._source.position.sky_coord - - if (self._sc_coord_sph_cache is None) or (source_coord != self._last_convolved_source_skycoord): - #self._sc_coord_sph_cache = self._sc_ori_data.get_target_in_sc_frame(source_coord) - self._sc_coord_sph_cache = self._sc_ori_unique.get_target_in_sc_frame(source_coord)[self._inv_idx] - - no_recalculation = ((source_coord == self._last_convolved_source_skycoord) - and - (self._irf_cache is not None) - and - (self._area_cache is not None)) - - area_recalculation = ((source_coord != self._last_convolved_source_skycoord) - or - (self._area_cache is None)) - - pdf_recalculation = ((source_coord != self._last_convolved_source_skycoord) - or - (self._irf_cache is None)) - - if no_recalculation: - return - else: - if area_recalculation: - self._compute_area() - - if pdf_recalculation: - self._init_node_pool() - self._compute_density() - - self._last_convolved_source_skycoord = source_coord.copy() - - def cache_to_file(self, filename: str): - with h5py.File(filename, 'w') as f: - f.attrs['total_energy_nodes'] = self._total_energy_nodes - f.attrs['peak_nodes'] = self._peak_nodes - f.attrs['peak_widths'] = self._peak_widths - f.attrs['energy_range'] = self._energy_range - f.attrs['batch_size'] = self._batch_size - - if self._irf_cache is not None: - f.create_dataset('irf_cache', data=self._irf_cache.numpy(), - compression='gzip', compression_opts=4) - - if self._irf_energy_node_cache is not None: - f.create_dataset('irf_energy_node_cache', data=self._irf_energy_node_cache, - compression='gzip') - - if self._area_cache is not None: - f.create_dataset('area_cache', data=self._area_cache, - compression='gzip') - - if self._area_energy_node_cache is not None: - f.create_dataset('area_energy_node_cache', data=self._area_energy_node_cache, - compression='gzip') - - if self._exp_events is not None: - f.create_dataset('exp_events', data=self._exp_events) - - if self._exp_density is not None: - f.create_dataset('exp_density', data=self._exp_density.numpy(), - compression='gzip') - - if self._last_convolved_source_dict_number is not None: - json_str = json.dumps(self._last_convolved_source_dict_number) - f.attrs['last_convolved_source_dict_number'] = json_str - - if self._last_convolved_source_dict_density is not None: - json_str = json.dumps(self._last_convolved_source_dict_density) - f.attrs['last_convolved_source_dict_density'] = json_str - - if self._last_convolved_source_skycoord is not None: - sc = self._last_convolved_source_skycoord - f.attrs['last_convolved_lon_deg'] = sc.spherical.lon.deg - f.attrs['last_convolved_lat_deg'] = sc.spherical.lat.deg - f.attrs['last_convolved_frame'] = sc.frame.name - - def cache_from_file(self, filename: str): - if not os.path.exists(filename): - raise FileNotFoundError(f"Cache file {filename} not found.") - - with h5py.File(filename, 'r') as f: - self._total_energy_nodes = tuple(f.attrs['total_energy_nodes']) - self._peak_nodes = tuple(f.attrs['peak_nodes']) - self._peak_widths = tuple(f.attrs['peak_widths']) - self._energy_range = tuple(f.attrs['energy_range']) - self._batch_size = int(f.attrs['batch_size']) - - if 'irf_cache' in f: - self._irf_cache = torch.from_numpy(f['irf_cache'][:]) - else: - self._irf_cache = None - - if 'irf_energy_node_cache' in f: - self._irf_energy_node_cache = f['irf_energy_node_cache'][:] - else: - self._irf_energy_node_cache = None - - if 'area_cache' in f: - self._area_cache = f['area_cache'][:] - else: - self._area_cache = None - - if 'area_energy_node_cache' in f: - self._area_energy_node_cache = f['area_energy_node_cache'][:] - else: - self._area_energy_node_cache = None - - if 'exp_events' in f: - self._exp_events = float(f['exp_events'][()]) - else: - self._exp_events = None - - if 'exp_density' in f: - self._exp_density = torch.from_numpy(f['exp_density'][:]) - else: - self._exp_density = None - - if 'last_convolved_source_dict_number' in f.attrs: - self._last_convolved_source_dict_number = json.loads(f.attrs['last_convolved_source_dict_number']) - else: - self._last_convolved_source_dict_number = None - - if 'last_convolved_source_dict_density' in f.attrs: - self._last_convolved_source_dict_density = json.loads(f.attrs['last_convolved_source_dict_density']) - else: - self._last_convolved_source_dict_density = None - - if 'last_convolved_lon_deg' in f.attrs: - lon = f.attrs['last_convolved_lon_deg'] - lat = f.attrs['last_convolved_lat_deg'] - frame = f.attrs['last_convolved_frame'] - self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame) - else: - self._last_convolved_source_skycoord = None - - if self._irf_cache is not None: - self._init_node_pool() - - def expected_counts(self) -> float: - """ - Return the total expected counts. - """ - self._update_cache() - source_dict = self._source.to_dict() - - if (source_dict != self._last_convolved_source_dict_number) or (self._exp_events is None): - area = self._area_cache - flux = self._source(self._area_energy_node_cache) - self._exp_events = np.sum(area * flux, dtype=float) - - self._last_convolved_source_dict_number = source_dict - return self._exp_events - - def expectation_density(self) -> Iterable[float]: - """ - Return the expected number of counts density. This equals the event probabiliy times the number of events. - """ - - self._update_cache() - source_dict = self._source.to_dict() - - if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): - self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) - - if self._irf_energy_node_cache is not None: - flux = torch.as_tensor(self._source(self._irf_energy_node_cache), dtype=self._irf_cache.dtype) - - torch.linalg.vecdot(self._irf_cache, flux, dim=1, out=self._exp_density) - - else: - n_energy = self._total_energy_nodes[0] - batch_size = self._batch_size // n_energy - - sc_coord_sph = self._sc_coord_sph_cache - - lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) - lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - - phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) - - for i in range(0, self._n_events, batch_size): - end = min(i + batch_size, self._n_events) - - e_sl = self._energy_m_keV[i:end] - p_sl = self._phi_rad[i:end] - pg_sl = phi_geo_rad[i:end] - pig_sl = phi_igeo_rad[i:end] - - nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) - - flux_batch = torch.as_tensor(self._source(np.asarray(nodes)), dtype=self._irf_cache.dtype) - - torch.linalg.vecdot(self._irf_cache[i:end], flux_batch, dim=1, out=self._exp_density[i:end]) - - self._last_convolved_source_dict_density = source_dict - - #print(self._data.time.unix[self._exp_density <= 0][:100]) - #print(np.sum(self._exp_density <= 0)/self._n_events * 100) - #print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) - return np.asarray(self._exp_density, dtype=np.float64)+1e-12 \ No newline at end of file + return self._event_prob \ No newline at end of file From 7f12bc30ff7576eabcbc8b5720d0ff062b441a8d Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 19 Feb 2026 17:12:21 +0100 Subject: [PATCH 06/16] Changed filenames nn -> nf --- cosipy/background_estimation/nf_unbinned_background.py | 0 cosipy/response/{NNResponse.py => NFResponse.py} | 4 ++-- ...ponse_function.py => nf_instrument_response_function.py} | 6 +++--- .../{nn_response_helper.py => nf_response_helper.py} | 0 4 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 cosipy/background_estimation/nf_unbinned_background.py rename cosipy/response/{NNResponse.py => NFResponse.py} (99%) rename cosipy/response/{nn_instrument_response_function.py => nf_instrument_response_function.py} (95%) rename cosipy/response/{nn_response_helper.py => nf_response_helper.py} (100%) diff --git a/cosipy/background_estimation/nf_unbinned_background.py b/cosipy/background_estimation/nf_unbinned_background.py new file mode 100644 index 000000000..e69de29bb diff --git a/cosipy/response/NNResponse.py b/cosipy/response/NFResponse.py similarity index 99% rename from cosipy/response/NNResponse.py rename to cosipy/response/NFResponse.py index 12ba434b5..910f95c0a 100644 --- a/cosipy/response/NNResponse.py +++ b/cosipy/response/NFResponse.py @@ -8,7 +8,7 @@ import torch import torch.multiprocessing as mp -from .nn_response_helper import * +from .nf_response_helper import * def cuda_cleanup_task(_) -> bool: @@ -163,7 +163,7 @@ def evaluate_effective_area(self, context: torch.Tensor) -> torch.Tensor: return self._model.evaluate_effective_area(*list_context) -class NNResponse: +class NFResponse: def __init__(self, path_to_model: str, area_batch_size: int = 100_000, density_batch_size: int = 100_000, devices: Optional[List[Union[str, int, torch.device]]] = None, area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", density_compile_mode: CompileMode = "default"): diff --git a/cosipy/response/nn_instrument_response_function.py b/cosipy/response/nf_instrument_response_function.py similarity index 95% rename from cosipy/response/nn_instrument_response_function.py rename to cosipy/response/nf_instrument_response_function.py index 35a3b981b..b672f70ee 100644 --- a/cosipy/response/nn_instrument_response_function.py +++ b/cosipy/response/nf_instrument_response_function.py @@ -6,7 +6,7 @@ from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface from cosipy.interfaces.photon_parameters import PhotonListWithDirectionAndEnergyInSCFrameInterface from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays -from cosipy.response.NNResponse import NNResponse +from cosipy.response.NFResponse import NFResponse from cosipy.util.iterables import asarray @@ -18,12 +18,12 @@ import torch -class UnpolarizedNNFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): +class UnpolarizedNFFarFieldInstrumentResponseFunction(FarFieldSpectralInstrumentResponseFunctionInterface): event_data_type = EmCDSEventDataInSCFrameInterface photon_list_type = PhotonListWithDirectionAndEnergyInSCFrameInterface - def __init__(self, response: NNResponse,): + def __init__(self, response: NFResponse,): if response.is_polarized: raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") self._response = response diff --git a/cosipy/response/nn_response_helper.py b/cosipy/response/nf_response_helper.py similarity index 100% rename from cosipy/response/nn_response_helper.py rename to cosipy/response/nf_response_helper.py From fc790ad93ec7caf6f25e5ac159c7dbde29cb3981 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Fri, 20 Feb 2026 22:02:17 +0100 Subject: [PATCH 07/16] Renamed files, separated parts in to modules to allow for greater overlap with background development, fixed batchsize and compile mode setter bug. --- cosipy/response/NFBase.py | 278 +++++++++++++++ .../{nf_response_helper.py => NFModels.py} | 10 +- cosipy/response/NFResponse.py | 328 +++--------------- cosipy/response/NFWorkerState.py | 3 + cosipy/threeml/optimized_unbinned_folding.py | 4 +- 5 files changed, 344 insertions(+), 279 deletions(-) create mode 100644 cosipy/response/NFBase.py rename cosipy/response/{nf_response_helper.py => NFModels.py} (99%) create mode 100644 cosipy/response/NFWorkerState.py diff --git a/cosipy/response/NFBase.py b/cosipy/response/NFBase.py new file mode 100644 index 000000000..06f37967f --- /dev/null +++ b/cosipy/response/NFBase.py @@ -0,0 +1,278 @@ +from typing import List, Union, Optional +from pathlib import Path +from abc import ABC, abstractmethod + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +import torch.multiprocessing as mp +from .NFModels import * +import cosipy.response.NFWorkerState as NFWorkerState + + +class DensityApproximation(ABC): + def __init__(self, major_version: int, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): + self._major_version = major_version + self._worker_device = worker_device + self._density_input = density_input + self._batch_size = batch_size + self._compile_mode = compile_mode + + self._model: DensityModel + self._expected_context_dim: int + self._expected_source_dim: int + + self._setup_model() + + @abstractmethod + def _setup_model(self): + ... + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + dim_source = source.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + elif dim_source != self._expected_source_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_source_dim} features, but source has {dim_source}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + list_source = [source[:, i] for i in range(dim_source)] + + return self._model.evaluate_density(*list_context, *list_source) + + def sample_density(self, context: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.sample_density(*list_context) + +def cuda_cleanup_task(_) -> bool: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return True + +def update_density_worker_settings(args: Tuple[str, Union[int, CompileMode]]): + attr, value = args + + if attr == 'density_batch_size': + NFWorkerState.density_module._model.batch_size = value + elif attr == 'density_compile_mode': + NFWorkerState.density_module._model.compile_mode = value + +def init_density_worker(device_queue: mp.Queue, major_version: int, + density_input: Dict, density_batch_size: int, + density_compile_mode: CompileMode, density_approximation: DensityApproximation): + + NFWorkerState.worker_device = torch.device(device_queue.get()) + if NFWorkerState.worker_device.type == 'cuda': + torch.cuda.set_device(NFWorkerState.worker_device) + + NFWorkerState.density_module = density_approximation(major_version, density_input, NFWorkerState.worker_device, density_batch_size, density_compile_mode) + +def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, source, indices = args + + sub_context = context[indices, :] + sub_source = source[indices, :] + if torch.device(NFWorkerState.worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + sub_source = sub_source.pin_memory() + + return NFWorkerState.density_module.evaluate_density(sub_context, sub_source) + +def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, indices = args + + sub_context = context[indices, :] + if torch.device(NFWorkerState.worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + + return NFWorkerState.density_module.sample_density(sub_context) + +class NFBase(): + def __init__(self, path_to_model: Union[str, Path], update_worker, pool_worker, density_batch_size: int = 100_000, + devices: Optional[List[Union[str, int, torch.device]]] = None, density_compile_mode: CompileMode = "default", + additional_required_keys: List[str] = None): + self._ckpt = torch.load(str(path_to_model), map_location=torch.device('cpu'), weights_only=False) + + required_keys = ['version', 'density_input'] + (additional_required_keys or []) + + for key in required_keys: + if key not in self._ckpt: + raise KeyError( + f"Invalid Checkpoint: Metadata key '{key}' not found in {str(path_to_model)}. " + f"Ensure you saved the model as a dictionary, not just the state_dict." + ) + + self._version = self._ckpt['version'] + self._major_version = int(self._version.split('.')[0]) + self._density_input = self._ckpt['density_input'] + + self._pool = None + self._has_cuda = False + self._num_workers = 0 + self._ctx = mp.get_context("spawn") + self._pool_worker = pool_worker + self._update_worker = update_worker + + self.density_batch_size = density_batch_size + self.density_compile_mode = density_compile_mode + + self._update_pool_arguments() + + if devices is not None: + self.devices = devices + else: + self._devices = [] + + def __del__(self): + self.shutdown_compute_pool() + + @property + def devices(self) -> List[Union[str, int, torch.device]]: + return self._devices + + @devices.setter + def devices(self, value: List[Union[str, int, torch.device]]): + if not isinstance(value, list): + raise ValueError("devices must be a list of device identifiers") + self._devices = value + + @property + def density_batch_size(self) -> int: + return self._density_batch_size + + @density_batch_size.setter + def density_batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError("density_batch_size must be a positive integer") + self._density_batch_size = value + self._update_pool_arguments() + self._update_worker_config('density_batch_size', value) + + @property + def density_compile_mode(self) -> CompileMode: return self._density_compile_mode + + @density_compile_mode.setter + def density_compile_mode(self, value: CompileMode): + self._density_compile_mode = value + self._update_pool_arguments() + self._update_worker_config('density_compile_mode', value) + + def _update_worker_config(self, attr: str, value: Union[int, CompileMode]): + if self._pool is not None: + self._pool.map(self._update_worker, [(attr, value)] * self._num_workers) + + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_density_input", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_density_compile_mode", None), + ] + + def clean_compute_pool(self): + if self._pool: + self._pool.map(cuda_cleanup_task, range(self._num_workers)) + + def shutdown_compute_pool(self): + if self._pool: + self._pool.close() + self._pool.join() + + self._num_workers = 0 + self._pool = None + self._has_cuda = None + + def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): + active_devices = devices if devices is not None else self._devices + + if not active_devices: + raise RuntimeError("Cannot initialize pool: no devices provided as argument or set as fallback.") + + if self._pool: + print("Warning: Pool already initialized. Shutting down old pool first.") + self.shutdown_compute_pool() + + self._num_workers = len(active_devices) + self._has_cuda = any(torch.device(d).type == 'cuda' for d in active_devices) + + device_queue = self._ctx.Queue() + for d in active_devices: + device_queue.put(d) + + self._pool = self._ctx.Pool( + processes=self._num_workers, + initializer=self._pool_worker, + initargs=(device_queue, *self._pool_arguments), + ) + + def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): + context.share_memory_() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, idx) for idx in indices] + results = self._pool.map(sample_density_task, tasks) + + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() + + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: + temp_pool = False + if self._pool is None: + target_devices = devices if devices is not None else self._devices + if not target_devices: + raise RuntimeError("No compute pool initialized and no devices provided/set.") + + self.init_compute_pool(target_devices) + temp_pool = True + + try: + if not context.is_shared(): context.share_memory_() + if not source.is_shared(): source.share_memory_() + + n_data = context.shape[0] + indices = torch.tensor_split(torch.arange(n_data), self._num_workers) + + tasks = [(context, source, idx) for idx in indices] + results = self._pool.map(evaluate_density_task, tasks) + + return torch.cat(results, dim=0) + + finally: + if temp_pool: + self.shutdown_compute_pool() diff --git a/cosipy/response/nf_response_helper.py b/cosipy/response/NFModels.py similarity index 99% rename from cosipy/response/nf_response_helper.py rename to cosipy/response/NFModels.py index 120fcd422..729149171 100644 --- a/cosipy/response/nf_response_helper.py +++ b/cosipy/response/NFModels.py @@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) @runtime_checkable -class AreaModelProtocol(Protocol): +class AreaModel(Protocol): @property def context_dim(self) -> int: ... @@ -132,7 +132,7 @@ def batch_size(self, value: int): ... def evaluate_effective_area(self, *args: torch.Tensor) -> torch.Tensor: ... -class UnpolarizedAreaSphericalHarmonicsExpansion(AreaModelProtocol): +class UnpolarizedAreaSphericalHarmonicsExpansion: def __init__(self, area_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode = "max-autotune-no-cudagraphs"): self._worker_device = torch.device(worker_device) @@ -312,7 +312,7 @@ def enqueue_transfer(slot_idx, start_idx): return torch.clamp(result, min=0) @runtime_checkable -class DensityModelProtocol(Protocol): +class DensityModel(Protocol): @property def context_dim(self) -> int: ... @@ -335,7 +335,7 @@ def sample_density(self, *args: torch.Tensor) -> torch.Tensor: ... def evaluate_density(self, *args: torch.Tensor) -> torch.Tensor: ... -class UnpolarizedDensityCMLPDGaussianCARQSFlow(DensityModelProtocol): +class UnpolarizedDensityCMLPDGaussianCARQSFlow: def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode = "default"): self._worker_device = torch.device(worker_device) @@ -413,7 +413,7 @@ def batch_size(self) -> int: return self._batch_size @batch_size.setter - def batch_size(self, value: int): + def batch_size(self, value: int): if not isinstance(value, int) or value <= 0: raise ValueError(f"Batch size must be a positive integer, got {value}") self._batch_size = value diff --git a/cosipy/response/NFResponse.py b/cosipy/response/NFResponse.py index 910f95c0a..4e192074f 100644 --- a/cosipy/response/NFResponse.py +++ b/cosipy/response/NFResponse.py @@ -1,4 +1,5 @@ -from typing import List, Union +from typing import List, Union, Optional +from pathlib import Path from importlib.util import find_spec @@ -8,86 +9,14 @@ import torch import torch.multiprocessing as mp -from .nf_response_helper import * +from .NFBase import * +import cosipy.response.NFWorkerState as NFWorkerState -def cuda_cleanup_task(_) -> bool: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return True - -def update_worker_settings(args: Tuple[str, Union[int, CompileMode]]): - attr, value = args - global area_module - global density_module - - if attr == 'area_batch_size': - area_module.batch_size = value - elif attr == 'density_batch_size': - density_module.batch_size = value - elif attr == 'area_compile_mode': - area_module.compile_mode = value - elif attr == 'density_compile_mode': - density_module.compile_mode = value - -def init_worker(device_queue: mp.Queue, major_version: int, area_input: Dict, - density_input: Dict, area_batch_size: int, density_batch_size: int, - area_compile_mode: CompileMode, density_compile_mode: CompileMode): - global area_module - global density_module - global worker_device - - worker_device = torch.device(device_queue.get()) - if worker_device.type == 'cuda': - torch.cuda.set_device(worker_device) - - area_module = AreaApproximation(major_version, area_input, worker_device, area_batch_size, area_compile_mode) - density_module = DensityApproximation(major_version, density_input, worker_device, density_batch_size, density_compile_mode) - -def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - global area_module - context, indices = args - - sub_context = context[indices, :] - if torch.device(worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() - - return area_module.evaluate_effective_area(sub_context) - -def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: - global density_module - context, source, indices = args - - sub_context = context[indices, :] - sub_source = source[indices, :] - if torch.device(worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() - sub_source = sub_source.pin_memory() - - return density_module.evaluate_density(sub_context, sub_source) - -def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - global density_module - context, indices = args - - sub_context = context[indices, :] - if torch.device(worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() - - return density_module.sample_density(sub_context) - -class DensityApproximation: - def __init__(self, major_version: int, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): - self._major_version = major_version - self._worker_device = worker_device - self._density_input = density_input - self._batch_size = batch_size - self._compile_mode = compile_mode - - self._setup_model() +class ResponseDensityApproximation(DensityApproximation): def _setup_model(self): - version_map: Dict[int, DensityModelProtocol] = { + version_map: Dict[int, DensityModel] = { 1: UnpolarizedDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), } if self._major_version not in version_map: @@ -96,39 +25,6 @@ def _setup_model(self): self._model = version_map[self._major_version] self._expected_context_dim = self._model.context_dim self._expected_source_dim = self._model.source_dim - - def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch.Tensor: - dim_context = context.shape[1] - dim_source = source.shape[1] - - if dim_context != self._expected_context_dim: - raise ValueError( - f"Feature mismatch: {type(self._model).__name__} expects " - f"{self._expected_context_dim} features, but context has {dim_context}." - ) - elif dim_source != self._expected_source_dim: - raise ValueError( - f"Feature mismatch: {type(self._model).__name__} expects " - f"{self._expected_source_dim} features, but source has {dim_source}." - ) - - list_context = [context[:, i] for i in range(dim_context)] - list_source = [source[:, i] for i in range(dim_source)] - - return self._model.evaluate_density(*list_context, *list_source) - - def sample_density(self, context: torch.Tensor) -> torch.Tensor: - dim_context = context.shape[1] - - if dim_context != self._expected_context_dim: - raise ValueError( - f"Feature mismatch: {type(self._model).__name__} expects " - f"{self._expected_context_dim} features, but context has {dim_context}." - ) - - list_context = [context[:, i] for i in range(dim_context)] - - return self._model.sample_density(*list_context) class AreaApproximation: def __init__(self, major_version: int, area_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): @@ -141,7 +37,7 @@ def __init__(self, major_version: int, area_input: Dict, worker_device: Union[st self._setup_model() def _setup_model(self): - version_map: Dict[int, AreaModelProtocol] = { + version_map: Dict[int, AreaModel] = { 1: UnpolarizedAreaSphericalHarmonicsExpansion(self._area_input, self._worker_device, self._batch_size, self._compile_mode), } if self._major_version not in version_map: @@ -163,53 +59,50 @@ def evaluate_effective_area(self, context: torch.Tensor) -> torch.Tensor: return self._model.evaluate_effective_area(*list_context) -class NFResponse: - def __init__(self, path_to_model: str, area_batch_size: int = 100_000, density_batch_size: int = 100_000, +def update_response_worker_settings(args: Tuple[str, Union[int, CompileMode]]): + update_density_worker_settings(args) + + attr, value = args + + if attr == 'area_batch_size': + NFWorkerState.area_module._model.batch_size = value + elif attr == 'area_compile_mode': + NFWorkerState.area_module._model.compile_mode = value + +def init_response_worker(device_queue: mp.Queue, major_version: int, area_input: Dict, + density_input: Dict, area_batch_size: int, density_batch_size: int, + area_compile_mode: CompileMode, density_compile_mode: CompileMode): + + init_density_worker(device_queue, major_version, + density_input, density_batch_size, + density_compile_mode, ResponseDensityApproximation) + + #NFWorkerState.density_module = ResponseDensityApproximation(major_version, density_input, NFWorkerState.worker_device, density_batch_size, density_compile_mode) + NFWorkerState.area_module = AreaApproximation(major_version, area_input, NFWorkerState.worker_device, area_batch_size, area_compile_mode) + +def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + context, indices = args + + sub_context = context[indices, :] + if torch.device(NFWorkerState.worker_device).type == 'cuda': + sub_context = sub_context.pin_memory() + + return NFWorkerState.area_module.evaluate_effective_area(sub_context) + +class NFResponse(NFBase): + def __init__(self, path_to_model: Union[str, Path], area_batch_size: int = 100_000, density_batch_size: int = 100_000, devices: Optional[List[Union[str, int, torch.device]]] = None, area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", density_compile_mode: CompileMode = "default"): - ckpt = torch.load(path_to_model, map_location=torch.device('cpu'), weights_only=False) - required_keys = ['version', 'is_polarized', 'density_input', 'area_input'] + super().__init__(path_to_model, update_response_worker_settings, init_response_worker, density_batch_size, devices, density_compile_mode, ['is_polarized', 'area_input']) - for key in required_keys: - if key not in ckpt: - raise KeyError( - f"Invalid Checkpoint: Metadata key '{key}' not found in {path_to_model}. " - f"Ensure you saved the model as a dictionary, not just the state_dict." - ) - - self._version = ckpt['version'] - self._major_version = int(self._version.split('.')[0]) - self._is_polarized = ckpt['is_polarized'] - self._density_input = ckpt['density_input'] - self._area_input = ckpt['area_input'] - - self._pool = None - self._has_cuda = False - self._ctx = mp.get_context("spawn") + self._is_polarized = self._ckpt['is_polarized'] + self._area_input = self._ckpt['area_input'] self.area_batch_size = area_batch_size - self.density_batch_size = density_batch_size - self._area_compile_mode = area_compile_mode - self._density_compile_mode = density_compile_mode + self.area_compile_mode = area_compile_mode - if devices is not None: - self.devices = devices - else: - self._devices = [] - - def __del__(self): - self.shutdown_compute_pool() - - @property - def devices(self) -> List[Union[str, int, torch.device]]: - return self._devices - - @devices.setter - def devices(self, value: List[Union[str, int, torch.device]]): - if not isinstance(value, list): - raise ValueError("devices must be a list of device identifiers") - self._devices = value + self._update_pool_arguments() @property def is_polarized(self) -> bool: @@ -224,103 +117,28 @@ def area_batch_size(self, value: int): if not isinstance(value, int) or value <= 0: raise ValueError("area_batch_size must be a positive integer") self._area_batch_size = value + self._update_pool_arguments() self._update_worker_config('area_batch_size', value) - @property - def density_batch_size(self) -> int: - return self._density_batch_size - - @density_batch_size.setter - def density_batch_size(self, value: int): - if not isinstance(value, int) or value <= 0: - raise ValueError("density_batch_size must be a positive integer") - self._density_batch_size = value - self._update_worker_config('density_batch_size', value) - @property def area_compile_mode(self) -> CompileMode: return self._area_compile_mode @area_compile_mode.setter def area_compile_mode(self, value: CompileMode): self._area_compile_mode = value + self._update_pool_arguments() self._update_worker_config('area_compile_mode', value) - @property - def density_compile_mode(self) -> CompileMode: return self._density_compile_mode - - @density_compile_mode.setter - def density_compile_mode(self, value: CompileMode): - self._density_compile_mode = value - self._update_worker_config('density_compile_mode', value) - - def _update_worker_config(self, attr: str, value: Union[int, CompileMode]): - if self._pool is not None: - self._pool.map(update_worker_settings, [(attr, value)] * self._num_workers) - - def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): - active_devices = devices if devices is not None else self._devices - - if not active_devices: - raise RuntimeError("Cannot initialize pool: no devices provided as argument or set as fallback.") - - if self._pool: - print("Warning: Pool already initialized. Shutting down old pool first.") - self.shutdown_compute_pool() - - self._num_workers = len(active_devices) - self._has_cuda = any(torch.device(d).type == 'cuda' for d in active_devices) - - device_queue = self._ctx.Queue() - for d in active_devices: - device_queue.put(d) - - self._pool = self._ctx.Pool( - processes=self._num_workers, - initializer=init_worker, - initargs=(device_queue, self._major_version, self._area_input, self._density_input, - self._area_batch_size, self._density_batch_size, - self._area_compile_mode, self._density_compile_mode), - ) - - def clean_compute_pool(self): - if self._pool: - self._pool.map(cuda_cleanup_task, range(self._num_workers)) - - def shutdown_compute_pool(self): - if self._pool: - self._pool.close() - self._pool.join() - - self._num_workers = 0 - self._pool = None - self._has_cuda = None - - def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: - temp_pool = False - if self._pool is None: - target_devices = devices if devices is not None else self._devices - if not target_devices: - raise RuntimeError("No compute pool initialized and no devices provided/set.") - self.init_compute_pool(target_devices) - temp_pool = True - - try: - if not context.is_shared(): - context.share_memory_() - #if self._has_cuda and not context.is_pinned(): - # context = context.pin_memory() - - n_data = context.shape[0] - indices = torch.tensor_split(torch.arange(n_data), self._num_workers) - - tasks = [(context, idx) for idx in indices] - results = self._pool.map(sample_density_task, tasks) - - return torch.cat(results, dim=0) - - finally: - if temp_pool: - self.shutdown_compute_pool() + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_area_input", None), + getattr(self, "_density_input", None), + getattr(self, "_area_batch_size", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_area_compile_mode", None), + getattr(self, "_density_compile_mode", None), + ] def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: temp_pool = False @@ -335,8 +153,6 @@ def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[ try: if not context.is_shared(): context.share_memory_() - #if self._has_cuda and not context.is_pinned(): - # context = context.pin_memory() n_data = context.shape[0] indices = torch.tensor_split(torch.arange(n_data), self._num_workers) @@ -346,38 +162,6 @@ def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[ return torch.cat(results, dim=0) - finally: - if temp_pool: - self.shutdown_compute_pool() - - def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: - temp_pool = False - if self._pool is None: - target_devices = devices if devices is not None else self._devices - if not target_devices: - raise RuntimeError("No compute pool initialized and no devices provided/set.") - - self.init_compute_pool(target_devices) - temp_pool = True - - try: - if not context.is_shared(): context.share_memory_() - if not source.is_shared(): source.share_memory_() - - #if self._has_cuda: - # if not context.is_pinned(): - # context = context.pin_memory() - # if not source.is_pinned(): - # source = source.pin_memory() - - n_data = context.shape[0] - indices = torch.tensor_split(torch.arange(n_data), self._num_workers) - - tasks = [(context, source, idx) for idx in indices] - results = self._pool.map(evaluate_density_task, tasks) - - return torch.cat(results, dim=0) - finally: if temp_pool: self.shutdown_compute_pool() \ No newline at end of file diff --git a/cosipy/response/NFWorkerState.py b/cosipy/response/NFWorkerState.py new file mode 100644 index 000000000..735b9f7b7 --- /dev/null +++ b/cosipy/response/NFWorkerState.py @@ -0,0 +1,3 @@ +worker_device = None +density_module = None +area_module = None \ No newline at end of file diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py index 83216971a..93be29fd1 100644 --- a/cosipy/threeml/optimized_unbinned_folding.py +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -14,7 +14,7 @@ from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface, EventInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface -from cosipy.interfaces.instrument_response_interface import FarFieldInstrumentResponseFunctionInterface +from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface from cosipy.response.photon_types import PhotonListWithDirectionAndEnergyInSCFrame from cosipy.util.iterables import asarray @@ -36,7 +36,7 @@ class UnbinnedThreeMLPointSourceResponseIRFAdaptive(UnbinnedThreeMLSourceRespons def __init__(self, data: TimeTagEmCDSEventDataInSCFrameInterface, - irf: FarFieldInstrumentResponseFunctionInterface, + irf: FarFieldSpectralInstrumentResponseFunctionInterface, sc_history: SpacecraftHistory,): """ From b9dfd01cc37e9f6bb34c401562edb1cfc30c193c Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Wed, 25 Feb 2026 18:54:54 +0100 Subject: [PATCH 08/16] Made the normalizing flows code more general to avoid repetition when used for background. --- cosipy/response/NFBase.py | 382 ++++++++++++++- cosipy/response/NFModels.py | 713 ---------------------------- cosipy/response/NFResponse.py | 5 +- cosipy/response/NFResponseModels.py | 317 +++++++++++++ 4 files changed, 697 insertions(+), 720 deletions(-) delete mode 100644 cosipy/response/NFModels.py create mode 100644 cosipy/response/NFResponseModels.py diff --git a/cosipy/response/NFBase.py b/cosipy/response/NFBase.py index 06f37967f..143edef5f 100644 --- a/cosipy/response/NFBase.py +++ b/cosipy/response/NFBase.py @@ -1,19 +1,392 @@ -from typing import List, Union, Optional +from typing import List, Union, Optional, Literal, Tuple, Dict, Optional, Callable from pathlib import Path from abc import ABC, abstractmethod +import numpy as np from importlib.util import find_spec -if find_spec("torch") is None: +if any(find_spec(pkg) is None for pkg in ["torch", "normflows"]): raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") import torch +from torch import nn import torch.multiprocessing as mp -from .NFModels import * +import normflows as nf import cosipy.response.NFWorkerState as NFWorkerState +CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] + +def build_cmlp_diaggaussian_base(input_dim: int, output_dim: int, + hidden_dim: int, num_hidden_layers: int) -> nf.distributions.BaseDistribution: + context_encoder = BaseMLP(input_dim, output_dim, hidden_dim, num_hidden_layers) + return ConditionalDiagGaussian(shape=output_dim//2, context_encoder=context_encoder) + +def build_c_arqs_flow(base: nf.distributions.BaseDistribution, num_layers: int, + latent_dim: int, context_dim: int, num_bins: int, + num_hidden_units: int, num_residual_blocks: int) -> nf.ConditionalNormalizingFlow: + flows = [] + for _ in range(num_layers): + flows += [nf.flows.AutoregressiveRationalQuadraticSpline(num_input_channels = latent_dim, + num_blocks = num_residual_blocks, + num_hidden_channels = num_hidden_units, + num_bins = num_bins, + num_context_channels = context_dim)] + flows += [nf.flows.LULinearPermute(latent_dim)] + return nf.ConditionalNormalizingFlow(base, flows) + +class NNDensityInferenceWrapper(nn.Module): + def __init__(self, model: nn.Module): + super().__init__() + self._model = model + + def forward(self, + source: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + n_samples: Optional[int] = None, + mode: str = "inference") -> torch.Tensor: + if mode == "inference": + if context is None: + return torch.exp(self._model.log_prob(source)) + else: + return torch.exp(self._model.log_prob(source, context)) + elif mode == "sampling": + if context is None: + return self._model.sample(num_samples=n_samples)[0] + else: + return self._model.sample(num_samples=n_samples, context=context)[0] + +class ConditionalDiagGaussian(nf.distributions.BaseDistribution): + def __init__(self, shape: Union[int, List[int], Tuple[int, ...]], context_encoder: nn.Module): + super().__init__() + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, list): + shape = tuple(shape) + self.shape = shape + self.n_dim = len(shape) + self.d = np.prod(shape) + self.context_encoder = context_encoder + + def forward(self, num_samples: int=1, context: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, torch.Tensor]: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + eps = torch.randn( + (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device + ) + z = mean + torch.exp(log_scale) * eps + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1)) + ) + return z, log_p + + def log_prob(self, z: torch.Tensor, context: Optional[torch.Tensor]=None) -> torch.Tensor: + encoder_output = self.context_encoder(context) + split_ind = encoder_output.shape[-1] // 2 + mean = encoder_output[..., :split_ind] + log_scale = encoder_output[..., split_ind:] + log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( + log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2), + list(range(1, self.n_dim + 1)), + ) + return log_p + +class BaseMLP(nn.Module): + def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int): + super().__init__() + + layers = [] + + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(nn.ReLU()) + + for _ in range(num_hidden_layers): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(nn.ReLU()) + + layers.append(nn.Linear(hidden_dim, output_dim)) + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class BaseModel(ABC): + + def __init__(self, compile_mode: CompileMode, batch_size: int, + worker_device: Union[str, int, torch.device], input: Dict): + self._worker_device = torch.device(worker_device) + + self._base_model = self._init_model(input) + + self._compile_mode = compile_mode + self._compiled_cache = {} + + self._update_model_op() + + self._is_cuda = (self._worker_device.type == 'cuda') + self.batch_size = batch_size + + if self._is_cuda: + self._compute_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_stream = torch.cuda.Stream(device=self._worker_device) + self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] + self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] + else: + self._compute_stream = None + self._transfer_stream = None + self._transfer_ready = None + self._compute_ready = None + + @abstractmethod + def _init_model(self, input: Dict) -> Union[nn.Module, Callable]: ... + + @property + @abstractmethod + def context_dim(self) -> int: ... + + @property + def compile_mode(self) -> CompileMode: + return self._compile_mode + + @compile_mode.setter + def compile_mode(self, value: CompileMode): + if value != self._compile_mode: + self._compile_mode = value + self._update_model_op() + + def _update_model_op(self): + if self._compile_mode is None: + self._model_op = self._base_model + else: + if self._compile_mode not in self._compiled_cache: + self._compiled_cache[self._compile_mode] = torch.compile( + self._base_model, + mode=self._compile_mode + ) + self._model_op = self._compiled_cache[self._compile_mode] + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + if not isinstance(value, int) or value <= 0: + raise ValueError(f"Batch size must be a positive integer, got {value}") + self._batch_size = value + if self._is_cuda: + self._write_gpu_tensors() + + @abstractmethod + def _write_gpu_tensors(self): ... + +class AreaModel(BaseModel): + @abstractmethod + def evaluate_effective_area(self, *args: torch.Tensor) -> torch.Tensor: ... + +class DensityModel(BaseModel): + @property + @abstractmethod + def source_dim(self) -> int: ... + + def _write_gpu_tensors(self): + self._eval_inputs = [ + tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.source_dim + self.context_dim)) + for _ in range(2) + ] + self._eval_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] + + self._sample_inputs = [ + tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.context_dim)) + for _ in range(2) + ] + + self._sample_results = [ + (torch.empty((self._batch_size, self.source_dim), device=self._worker_device), + torch.empty(self._batch_size, dtype=torch.bool, device=self._worker_device)) + for _ in range(2) + ] + + @torch.inference_mode() + def sample_density(self, *args: torch.Tensor) -> torch.Tensor: + N = args[0].shape[0] + + result = torch.empty((N, self.source_dim), dtype=torch.float32, device="cpu") + failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") + + if self._is_cuda: + result, failed_mask = result.pin_memory(), failed_mask.pin_memory() + + def enqueue_sample_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + for i in range(self.context_dim): + self._sample_inputs[slot_idx][i][:size].copy_(args[i][start_idx:end_idx], non_blocking=True) + #self._sample_inputs[slot_idx][0][:size].copy_(energy_keV[start_idx:end_idx], non_blocking=True) + #self._sample_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) + #self._sample_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) + + if self._is_cuda and N > 0: + with torch.cuda.stream(self._transfer_stream): + enqueue_sample_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + #b_ei, b_az, b_pol = [t[:batch_len] for t in self._sample_inputs[curr_idx]] + # + #b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) + #b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) + # + #b_ctx = torch.cat([ + # (b_az_sc + 1) / 2, + # (b_pol_sc[:, 1:] + 1) / 2, + # (torch.log10(b_ei) / 2 - 1).unsqueeze(1) + #], dim=1).to(torch.float32) + + b_ctx = [t[:batch_len] for t in self._sample_inputs[curr_idx]] + n_ctx = self._transform_context(*b_ctx) + + n_latent = self._model_op(context=n_ctx, mode="sampling", n_samples=batch_len) + + self._sample_results[curr_idx][0][:batch_len] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) + self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(*(n_latent.T), *b_ctx) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_sample_transfer(next_idx, next_start) + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + + result[start:end].copy_(self._sample_results[curr_idx][0][:batch_len], non_blocking=True) + failed_mask[start:end].copy_(self._sample_results[curr_idx][1][:batch_len], non_blocking=True) + else: + + #b_ei = energy_keV[start:end].to(self._worker_device) + #b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) + + #b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) + #b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) + #b_ctx = torch.cat([ + # (b_az_sc + 1) / 2, (b_pol_sc[:, 1:] + 1) / 2, + # (torch.log10(b_ei) / 2 - 1).unsqueeze(1) + #], dim=1).to(torch.float32) + + b_ctx = [t[start:end].to(self._worker_device) for t in args] + n_ctx = self._transform_context(*b_ctx) + + n_latent = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) + result[start:end] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) + failed_mask[start:end] = ~self._valid_samples(*(n_latent.T), *b_ctx) + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + + if torch.any(failed_mask): + result[failed_mask] = self.sample_density(*[t[failed_mask] for t in args]) + + return result + + @abstractmethod + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: ... + + @abstractmethod + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @torch.inference_mode() + def evaluate_density(self, *args: torch.Tensor) -> torch.Tensor: + + N = args[0].shape[0] + result = torch.empty(N, dtype=torch.float32, device="cpu") + + if self._is_cuda: + result = result.pin_memory() + + def enqueue_eval_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + for i in range(self.source_dim + self.context_dim): + self._eval_inputs[slot_idx][i][:size].copy_(args[i][start_idx:end_idx], non_blocking=True) + + if self._is_cuda and N > 0: + with torch.cuda.stream(self._transfer_stream): + enqueue_eval_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + ctx, src, jac = self._transform_coordinates(*[t[:batch_len] for t in self._eval_inputs[curr_idx]]) + + torch.mul(self._model_op(src, ctx, mode="inference"), jac, out=self._eval_results[curr_idx][:batch_len]) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_eval_transfer(next_idx, next_start) + + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + + result[start:end].copy_(self._eval_results[curr_idx][:batch_len], non_blocking=True) + else: + ctx, src, jac = self._transform_coordinates(*[t[start:end].to(self._worker_device) for t in args]) + result[start:end] = self._model_op(src, ctx, mode="inference") * jac + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + return result + +class RateModel(ABC): + def __init__(self, rate_input: Dict): + self._unpack_rate_input(rate_input) + + @property + @abstractmethod + def context_dim(self) -> int: ... + + @abstractmethod + def _unpack_rate_input(self, rate_input: Dict): ... + + @abstractmethod + def evaluate_rate(self, *args: torch.Tensor) -> torch.Tensor: ... + + class DensityApproximation(ABC): def __init__(self, major_version: int, density_input: Dict, worker_device: Union[str, int, torch.device], batch_size: int, compile_mode: CompileMode): self._major_version = major_version @@ -29,8 +402,7 @@ def __init__(self, major_version: int, density_input: Dict, worker_device: Union self._setup_model() @abstractmethod - def _setup_model(self): - ... + def _setup_model(self): ... def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch.Tensor: dim_context = context.shape[1] diff --git a/cosipy/response/NFModels.py b/cosipy/response/NFModels.py deleted file mode 100644 index 729149171..000000000 --- a/cosipy/response/NFModels.py +++ /dev/null @@ -1,713 +0,0 @@ -import numpy as np -import healpy as hp - -from typing import Protocol, Optional, Literal, List, Union, Tuple, Dict, runtime_checkable - - -from importlib.util import find_spec - -if any(find_spec(pkg) is None for pkg in ["torch", "normflows", "sphericart.torch"]): - raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") - -import sphericart.torch -from torch import nn -import normflows as nf -import torch - - -CompileMode = Optional[Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]] - -def build_cmlp_diaggaussian_base(input_dim: int, output_dim: int, - hidden_dim: int, num_hidden_layers: int) -> nf.distributions.BaseDistribution: - context_encoder = BaseMLP(input_dim, output_dim, hidden_dim, num_hidden_layers) - return ConditionalDiagGaussian(shape=output_dim//2, context_encoder=context_encoder) - -def build_c_arqs_flow(base: nf.distributions.BaseDistribution, num_layers: int, - latent_dim: int, context_dim: int, num_bins: int, - num_hidden_units: int, num_residual_blocks: int) -> nf.ConditionalNormalizingFlow: - flows = [] - for _ in range(num_layers): - flows += [nf.flows.AutoregressiveRationalQuadraticSpline(num_input_channels = latent_dim, - num_blocks = num_residual_blocks, - num_hidden_channels = num_hidden_units, - num_bins = num_bins, - num_context_channels = context_dim)] - flows += [nf.flows.LULinearPermute(latent_dim)] - return nf.ConditionalNormalizingFlow(base, flows) - -class NNDensityInferenceWrapper(nn.Module): - def __init__(self, model: nn.Module): - super().__init__() - self._model = model - - def forward(self, - source: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - n_samples: Optional[int] = None, - mode: str = "inference") -> torch.Tensor: - if mode == "inference": - if context is None: - return torch.exp(self._model.log_prob(source)) - else: - return torch.exp(self._model.log_prob(source, context)) - elif mode == "sampling": - if context is None: - return self._model.sample(num_samples=n_samples)[0] - else: - return self._model.sample(num_samples=n_samples, context=context)[0] - -class ConditionalDiagGaussian(nf.distributions.BaseDistribution): - def __init__(self, shape: Union[int, List[int], Tuple[int, ...]], context_encoder: nn.Module): - super().__init__() - if isinstance(shape, int): - shape = (shape,) - if isinstance(shape, list): - shape = tuple(shape) - self.shape = shape - self.n_dim = len(shape) - self.d = np.prod(shape) - self.context_encoder = context_encoder - - def forward(self, num_samples: int=1, context: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, torch.Tensor]: - encoder_output = self.context_encoder(context) - split_ind = encoder_output.shape[-1] // 2 - mean = encoder_output[..., :split_ind] - log_scale = encoder_output[..., split_ind:] - eps = torch.randn( - (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device - ) - z = mean + torch.exp(log_scale) * eps - log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( - log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1)) - ) - return z, log_p - - def log_prob(self, z: torch.Tensor, context: Optional[torch.Tensor]=None) -> torch.Tensor: - encoder_output = self.context_encoder(context) - split_ind = encoder_output.shape[-1] // 2 - mean = encoder_output[..., :split_ind] - log_scale = encoder_output[..., split_ind:] - log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum( - log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2), - list(range(1, self.n_dim + 1)), - ) - return log_p - -class BaseMLP(nn.Module): - def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int): - super().__init__() - - layers = [] - - layers.append(nn.Linear(input_dim, hidden_dim)) - layers.append(nn.ReLU()) - - for _ in range(num_hidden_layers): - layers.append(nn.Linear(hidden_dim, hidden_dim)) - layers.append(nn.ReLU()) - - layers.append(nn.Linear(hidden_dim, output_dim)) - - self.net = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net(x) - -@runtime_checkable -class AreaModel(Protocol): - @property - def context_dim(self) -> int: ... - - @property - def compile_mode(self) -> CompileMode: ... - - @compile_mode.setter - def compile_mode(self, value: CompileMode): ... - - @property - def batch_size(self) -> int: ... - - @batch_size.setter - def batch_size(self, value: int): ... - - def evaluate_effective_area(self, *args: torch.Tensor) -> torch.Tensor: ... - -class UnpolarizedAreaSphericalHarmonicsExpansion: - def __init__(self, area_input: Dict, worker_device: Union[str, int, torch.device], - batch_size: int, compile_mode: CompileMode = "max-autotune-no-cudagraphs"): - self._worker_device = torch.device(worker_device) - - self._lmax = area_input['lmax'] - self._poly_degree = area_input['poly_degree'] - self._poly_coeffs = area_input['poly_coeffs'] - - self._conv_coeffs = self._convert_coefficients().to(self._worker_device) - self._sh_calculator = sphericart.torch.SphericalHarmonics(self._lmax) - - self._compile_mode = compile_mode - self._compiled_cache = {} - - self._update_horner_op() - - self._is_cuda = (self._worker_device.type == 'cuda') - self.batch_size = batch_size - - if self._is_cuda: - self._compute_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] - self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] - else: - self._compute_stream = None - self._transfer_stream = None - self._transfer_ready = None - self._compute_ready = None - - def _write_gpu_tensors(self): - self._gpu_inputs = [ - (torch.empty(self._batch_size, device=self._worker_device), - torch.empty(self._batch_size, device=self._worker_device), - torch.empty(self._batch_size, device=self._worker_device)) - for _ in range(2) - ] - self._gpu_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] - - @property - def context_dim(self) -> int: - return 3 - - @property - def compile_mode(self) -> CompileMode: - return self._compile_mode - - @compile_mode.setter - def compile_mode(self, value: CompileMode): - if value != self._compile_mode: - self._compile_mode = value - self._update_horner_op() - - def _update_horner_op(self): - if self._compile_mode is None: - self._horner_op = self._horner_eval - else: - if self._compile_mode not in self._compiled_cache: - self._compiled_cache[self._compile_mode] = torch.compile( - self._horner_eval, - mode=self._compile_mode - ) - self._horner_op = self._compiled_cache[self._compile_mode] - - @property - def batch_size(self) -> int: - return self._batch_size - - @batch_size.setter - def batch_size(self, value: int): - if not isinstance(value, int) or value <= 0: - raise ValueError(f"Batch size must be a positive integer, got {value}") - self._batch_size = value - if self._is_cuda: - self._write_gpu_tensors() - - def _convert_coefficients(self) -> torch.Tensor: - num_sh = (self._lmax + 1)**2 - conv_coeffs = torch.zeros((num_sh, self._poly_degree + 1), dtype=torch.float64) - - for cnt, (l, m) in enumerate((l, m) for l in range(self._lmax + 1) for m in range(-l, l + 1)): - idx = hp.Alm.getidx(self._lmax, l, abs(m)) - if m == 0: - conv_coeffs[cnt] = self._poly_coeffs[0, :, idx] - else: - fac = np.sqrt(2) * (-1)**m - val = self._poly_coeffs[0, :, idx] if m > 0 else -self._poly_coeffs[1, :, idx] - conv_coeffs[cnt] = fac * val - return conv_coeffs.T - - def _horner_eval(self, x: torch.Tensor) -> torch.Tensor: - x_64 = x.to(torch.float64).unsqueeze(1) - result = self._conv_coeffs[0].expand(x.shape[0], -1).clone() - for i in range(1, self._conv_coeffs.size(0)): - result.mul_(x_64).add_(self._conv_coeffs[i]) - return result.to(torch.float32) - - def _compute_spherical_harmonics(self, dir_az: torch.Tensor, dir_polar: torch.Tensor) -> torch.Tensor: - sin_p = torch.sin(dir_polar) - xyz = torch.stack(( - sin_p * torch.cos(dir_az), - sin_p * torch.sin(dir_az), - torch.cos(dir_polar) - ), dim=-1) - return self._sh_calculator(xyz) - - @torch.inference_mode() - def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: - N = energy_keV.shape[0] - - ei_norm = (torch.log10(energy_keV) / 2 - 1).to(torch.float32) - result = torch.empty(N, dtype=torch.float32, device="cpu") - - def get_batch(start_idx): - end_idx = min(start_idx + self._batch_size, N) - return ( - ei_norm[start_idx:end_idx].to(self._worker_device), - dir_az[start_idx:end_idx].to(self._worker_device), - dir_polar[start_idx:end_idx].to(self._worker_device) - ) - - if self._is_cuda: - ei_norm = ei_norm.pin_memory() - result = result.pin_memory() - - def enqueue_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - self._gpu_inputs[slot_idx][0][:size].copy_(ei_norm[start_idx:end_idx], non_blocking=True) - self._gpu_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) - self._gpu_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) - - if self._is_cuda and (N > 0): - with torch.cuda.stream(self._transfer_stream): - enqueue_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - for i, start in enumerate(range(0, N, self._batch_size)): - curr_idx = i % 2 - next_idx = (i + 1) % 2 - end = min(start + self._batch_size, N) - batch_len = end - start - next_start = start + self._batch_size - - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - ei_b, az_b, pol_b = [t[:batch_len] for t in self._gpu_inputs[curr_idx]] - - poly_b = self._horner_op(ei_b) - ylm_b = self._compute_spherical_harmonics(az_b, pol_b) - - torch.sum(poly_b * ylm_b, dim=1, out=self._gpu_results[curr_idx][:batch_len]) - - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - enqueue_transfer(next_idx, next_start) - - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - result[start:end].copy_(self._gpu_results[curr_idx][:batch_len], non_blocking=True) - else: - ei_b, az_b, pol_b = get_batch(start) - - poly_b = self._horner_op(ei_b) - ylm_b = self._compute_spherical_harmonics(az_b, pol_b) - result[start:end] = torch.sum(poly_b * ylm_b, dim=1) - - if self._is_cuda: - torch.cuda.synchronize(self._worker_device) - - return torch.clamp(result, min=0) - -@runtime_checkable -class DensityModel(Protocol): - @property - def context_dim(self) -> int: ... - - @property - def source_dim(self) -> int: ... - - @property - def compile_mode(self) -> CompileMode: ... - - @compile_mode.setter - def compile_mode(self, value: CompileMode): ... - - @property - def batch_size(self) -> int: ... - - @batch_size.setter - def batch_size(self, value: int): ... - - def sample_density(self, *args: torch.Tensor) -> torch.Tensor: ... - - def evaluate_density(self, *args: torch.Tensor) -> torch.Tensor: ... - -class UnpolarizedDensityCMLPDGaussianCARQSFlow: - def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], - batch_size: int, compile_mode: CompileMode = "default"): - self._worker_device = torch.device(worker_device) - - self._snapshot = density_input["model_state_dict"] - self._bins = density_input["bins"] - self._hidden_units = density_input["hidden_units"] - self._residual_blocks = density_input["residual_blocks"] - self._total_layers = density_input["total_layers"] - self._context_size = density_input["context_size"] - self._latent_size = density_input["latent_size"] - self._mlp_hidden_units = density_input["mlp_hidden_units"] - self._mlp_hidden_layers = density_input["mlp_hidden_layers"] - self._menergy_cuts = density_input["menergy_cuts"] - self._phi_cuts = density_input["phi_cuts"] - - self._compile_mode = compile_mode - self._compiled_cache = {} - - self._eager_model = self._init_base_model() - self._update_model_op() - - self._is_cuda = (self._worker_device.type == 'cuda') - self.batch_size = batch_size - - if self._is_cuda: - self._compute_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] - self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] - else: - self._compute_stream = None - self._transfer_stream = None - self._transfer_ready = None - self._compute_ready = None - - def _write_gpu_tensors(self): - self._eval_inputs = [ - tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.source_dim + self.context_dim)) - for _ in range(2) - ] - self._eval_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] - - self._sample_inputs = [ - tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.context_dim)) - for _ in range(2) - ] - - self._sample_results = [ - (torch.empty((self._batch_size, self._latent_size), device=self._worker_device), - torch.empty(self._batch_size, dtype=torch.bool, device=self._worker_device)) - for _ in range(2) - ] - - @property - def context_dim(self) -> int: - return 3 - - @property - def source_dim(self) -> int: - return 4 - - @property - def compile_mode(self) -> CompileMode: - return self._compile_mode - - @compile_mode.setter - def compile_mode(self, value: CompileMode): - if value != self._compile_mode: - self._compile_mode = value - self._update_model_op() - - @property - def batch_size(self) -> int: - return self._batch_size - - @batch_size.setter - def batch_size(self, value: int): - if not isinstance(value, int) or value <= 0: - raise ValueError(f"Batch size must be a positive integer, got {value}") - self._batch_size = value - if self._is_cuda: - self._write_gpu_tensors() - - def _build_model(self) -> nf.ConditionalNormalizingFlow: - base = build_cmlp_diaggaussian_base( - self._context_size, 2 * self._latent_size, self._mlp_hidden_units, self._mlp_hidden_layers - ) - return build_c_arqs_flow( - base, self._total_layers, self._latent_size, self._context_size, self._bins, self._hidden_units, self._residual_blocks - ) - - def _init_base_model(self) -> NNDensityInferenceWrapper: - model = self._build_model() - - model.load_state_dict(self._snapshot) - model = NNDensityInferenceWrapper(model) - model.eval() - model.to(self._worker_device) - - return model - - def _update_model_op(self): - if self._compile_mode is None: - self._model_op = self._eager_model - else: - if self._compile_mode not in self._compiled_cache: - self._compiled_cache[self._compile_mode] = torch.compile( - self._eager_model, - mode=self._compile_mode - ) - self._model_op = self._compiled_cache[self._compile_mode] - - @staticmethod - def _get_vector(phi_sc: torch.Tensor, theta_sc: torch.Tensor) -> torch.Tensor: - x = theta_sc[:, 0] * phi_sc[:, 1] - y = theta_sc[:, 0] * phi_sc[:, 0] - z = theta_sc[:, 1] - return torch.stack((x, y, z), dim=-1) - - def _convert_conventions(self, dir_az_sc: torch.Tensor, dir_polar_sc: torch.Tensor, - ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, - scatt_az_sc: torch.Tensor, scatt_polar_sc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - eps = em / ei - 1 - - source = self._get_vector(dir_az_sc, dir_polar_sc) - scatter = self._get_vector(scatt_az_sc, scatt_polar_sc) - - dot_product = torch.sum(source * scatter, dim=1) - phi_geo = torch.acos(torch.clamp(dot_product, -1.0, 1.0)) - theta = phi_geo - phi - - xaxis = torch.tensor([1., 0., 0.], device=source.device, dtype=source.dtype) - pz = -source - - px = torch.linalg.cross(pz, xaxis.expand_as(pz)) - px = px / torch.linalg.norm(px, dim=1, keepdim=True) - - py = torch.linalg.cross(pz, px) - py = py / torch.linalg.norm(py, dim=1, keepdim=True) - - proj_x = torch.sum(scatter * px, dim=1) - proj_y = torch.sum(scatter * py, dim=1) - - zeta = torch.atan2(proj_y, proj_x) - zeta = torch.where(zeta < 0, zeta + 2 * np.pi, zeta) - - return eps, theta, zeta - - def _inverse_transform_coordinates(self, samples: torch.Tensor, ei: torch.Tensor, - dir_az_sc: torch.Tensor, dir_pol_sc: torch.Tensor) -> torch.Tensor: - eps = -samples[:, 0] - phi = samples[:, 1] * np.pi - theta = (samples[:, 2] - 0.5) * (2 * np.pi) - zeta = samples[:, 3] * (2 * np.pi) - - em = ei * (eps + 1) - - phi_geo = theta + phi - scatter_phf = self._get_vector(torch.stack((torch.sin(zeta), torch.cos(zeta)), dim=1), - torch.stack((torch.sin(np.pi - phi_geo), torch.cos(np.pi - phi_geo)), dim=1)) - source_vec = self._get_vector(dir_az_sc, dir_pol_sc) - xaxis = torch.tensor([1., 0., 0.], device=self._worker_device, dtype=source_vec.dtype) - - pz = -source_vec - px = torch.linalg.cross(pz, xaxis.expand_as(pz)) - px = px / torch.linalg.norm(px, dim=1, keepdim=True) - py = torch.linalg.cross(pz, px) - py = py / torch.linalg.norm(py, dim=1, keepdim=True) - - basis = torch.stack((px, py, pz), dim=2) - scatter_scf = torch.bmm(basis, scatter_phf.unsqueeze(-1)).squeeze(-1) - - psi_cds = torch.atan2(scatter_scf[:, 1], scatter_scf[:, 0]) - psi_cds = torch.where(psi_cds < 0, psi_cds + 2 * np.pi, psi_cds) - chi_cds = torch.acos(torch.clamp(scatter_scf[:, 2], -1.0, 1.0)) - - return torch.stack([em, phi, psi_cds, chi_cds], dim=1) - - def _transform_coordinates(self, dir_az: torch.Tensor, dir_pol: torch.Tensor, - ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, - scatt_az: torch.Tensor, scatt_pol: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) - dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) - scatt_az_sc = torch.stack((torch.sin(scatt_az), torch.cos(scatt_az)), dim=1) - scatt_pol_sc = torch.stack((torch.sin(scatt_pol), torch.cos(scatt_pol)), dim=1) - - eps_raw, theta_raw, zeta_raw = self._convert_conventions( - dir_az_sc, dir_pol_sc, ei, em, phi, scatt_az_sc, scatt_pol_sc - ) - - jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) - - ctx = torch.cat([ - (dir_az_sc + 1) / 2, - (dir_pol_sc[:, 1:] + 1) / 2, - (torch.log10(ei) / 2 - 1).unsqueeze(1) - ], dim=1) - - src = torch.cat([ - (-eps_raw).unsqueeze(1), - (phi / np.pi).unsqueeze(1), - (theta_raw / (2 * np.pi) + 0.5).unsqueeze(1), - (zeta_raw / (2 * np.pi)).unsqueeze(1) - ], dim=1) - - return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) - - def _valid_samples(self, ienergy: torch.Tensor, samples: torch.Tensor) -> torch.Tensor: - phi_geo_norm = samples[:, 1] + 2 * samples[:, 2] - 1.0 - valid_mask = (samples[:, 0] < 1.0) & \ - (samples[:, 1] > 0.0) & (samples[:, 1] <= 1.0) & \ - (samples[:, 2] >= 0.0) & (samples[:, 2] <= 1.0) & \ - (samples[:, 3] >= 0.0) & (samples[:, 3] <= 1.0) & \ - (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) & \ - (samples[:, 0] <= (1 - self._menergy_cuts[0]/ienergy)) & \ - (samples[:, 0] >= (1 - self._menergy_cuts[1]/ienergy)) & \ - (samples[:, 1] >= self._phi_cuts[0]/np.pi) & \ - (samples[:, 1] <= self._phi_cuts[1]/np.pi) - - return valid_mask - - @torch.inference_mode() - def sample_density(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: - N = dir_az.shape[0] - - result = torch.empty((N, self._latent_size), dtype=torch.float32, device="cpu") - failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") - - if self._is_cuda: - result, failed_mask = result.pin_memory(), failed_mask.pin_memory() - - def enqueue_sample_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - self._sample_inputs[slot_idx][0][:size].copy_(energy_keV[start_idx:end_idx], non_blocking=True) - self._sample_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) - self._sample_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) - - if self._is_cuda and N > 0: - with torch.cuda.stream(self._transfer_stream): - enqueue_sample_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - for i, start in enumerate(range(0, N, self._batch_size)): - curr_idx = i % 2 - next_idx = (i + 1) % 2 - end = min(start + self._batch_size, N) - batch_len = end - start - next_start = start + self._batch_size - - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - b_ei, b_az, b_pol = [t[:batch_len] for t in self._sample_inputs[curr_idx]] - - b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) - b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) - - b_ctx = torch.cat([ - (b_az_sc + 1) / 2, - (b_pol_sc[:, 1:] + 1) / 2, - (torch.log10(b_ei) / 2 - 1).unsqueeze(1) - ], dim=1).to(torch.float32) - - b_latent = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) - - self._sample_results[curr_idx][0][:batch_len] = self._inverse_transform_coordinates( - b_latent, b_ei, b_az_sc, b_pol_sc - ) - self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(b_ei, b_latent) - - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - enqueue_sample_transfer(next_idx, next_start) - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - - result[start:end].copy_(self._sample_results[curr_idx][0][:batch_len], non_blocking=True) - failed_mask[start:end].copy_(self._sample_results[curr_idx][1][:batch_len], non_blocking=True) - else: - b_ei = energy_keV[start:end].to(self._worker_device) - b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) - - b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) - b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) - b_ctx = torch.cat([ - (b_az_sc + 1) / 2, (b_pol_sc[:, 1:] + 1) / 2, - (torch.log10(b_ei) / 2 - 1).unsqueeze(1) - ], dim=1).to(torch.float32) - - b_samples = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) - result[start:end] = self._inverse_transform_coordinates(b_samples, b_ei, b_az_sc, b_pol_sc) - failed_mask[start:end] = ~self._valid_samples(b_ei, b_samples) - - if self._is_cuda: - torch.cuda.synchronize(self._worker_device) - - if torch.any(failed_mask): - result[failed_mask] = self.sample_density( - dir_az[failed_mask], dir_polar[failed_mask], energy_keV[failed_mask] - ) - - return result - - @torch.inference_mode() - def evaluate_density( - self, dir_az: torch.Tensor, dir_polar: torch.Tensor, - energy_keV: torch.Tensor, menergy_keV: torch.Tensor, - scatt_angle: torch.Tensor, scatt_az: torch.Tensor, - scatt_polar: torch.Tensor) -> torch.Tensor: - - N = dir_az.shape[0] - result = torch.empty(N, dtype=torch.float32, device="cpu") - - if self._is_cuda: - result = result.pin_memory() - - def enqueue_eval_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - tensors = [dir_az, dir_polar, energy_keV, menergy_keV, scatt_angle, scatt_az, scatt_polar] - for i in range(self.source_dim + self.context_dim): - self._eval_inputs[slot_idx][i][:size].copy_(tensors[i][start_idx:end_idx], non_blocking=True) - - if self._is_cuda and N > 0: - with torch.cuda.stream(self._transfer_stream): - enqueue_eval_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - for i, start in enumerate(range(0, N, self._batch_size)): - curr_idx = i % 2 - next_idx = (i + 1) % 2 - end = min(start + self._batch_size, N) - batch_len = end - start - next_start = start + self._batch_size - - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - ctx, src, jac = self._transform_coordinates(*[t[:batch_len] for t in self._eval_inputs[curr_idx]]) - - torch.mul(self._model_op(src, ctx, mode="inference"), jac, out=self._eval_results[curr_idx][:batch_len]) - - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - enqueue_eval_transfer(next_idx, next_start) - - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - - result[start:end].copy_(self._eval_results[curr_idx][:batch_len], non_blocking=True) - else: - b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) - b_ei, b_em = energy_keV[start:end].to(self._worker_device), menergy_keV[start:end].to(self._worker_device) - b_phi = scatt_angle[start:end].to(self._worker_device) - b_s_az, b_s_pol = scatt_az[start:end].to(self._worker_device), scatt_polar[start:end].to(self._worker_device) - - ctx, src, jac = self._transform_coordinates(b_az, b_pol, b_ei, b_em, b_phi, b_s_az, b_s_pol) - result[start:end] = self._model_op(src, ctx, mode="inference") * jac - - if self._is_cuda: - torch.cuda.synchronize(self._worker_device) - return result \ No newline at end of file diff --git a/cosipy/response/NFResponse.py b/cosipy/response/NFResponse.py index 4e192074f..a44daa032 100644 --- a/cosipy/response/NFResponse.py +++ b/cosipy/response/NFResponse.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional +from typing import List, Union, Optional, Dict, Tuple from pathlib import Path @@ -9,7 +9,8 @@ import torch import torch.multiprocessing as mp -from .NFBase import * +from .NFResponseModels import UnpolarizedDensityCMLPDGaussianCARQSFlow, UnpolarizedAreaSphericalHarmonicsExpansion +from .NFBase import DensityApproximation, CompileMode, NFBase, init_density_worker, update_density_worker_settings, AreaModel, DensityModel import cosipy.response.NFWorkerState as NFWorkerState diff --git a/cosipy/response/NFResponseModels.py b/cosipy/response/NFResponseModels.py new file mode 100644 index 000000000..94783128b --- /dev/null +++ b/cosipy/response/NFResponseModels.py @@ -0,0 +1,317 @@ +import numpy as np +import healpy as hp + +from typing import Union, Tuple, Dict + + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows", "sphericart.torch"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + + +from .NFBase import CompileMode, build_c_arqs_flow, build_cmlp_diaggaussian_base, NNDensityInferenceWrapper, AreaModel, DensityModel +import sphericart.torch +import normflows as nf +import torch + + +class UnpolarizedAreaSphericalHarmonicsExpansion(AreaModel): + def __init__(self, area_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode, batch_size, worker_device, area_input) + + def _init_model(self, input: Dict): + self._lmax = input['lmax'] + self._poly_degree = input['poly_degree'] + self._poly_coeffs = input['poly_coeffs'] + + self._conv_coeffs = self._convert_coefficients().to(self._worker_device) + self._sh_calculator = sphericart.torch.SphericalHarmonics(self._lmax) + + return self._horner_eval + + def _write_gpu_tensors(self): + self._gpu_inputs = [ + (torch.empty(self._batch_size, device=self._worker_device), + torch.empty(self._batch_size, device=self._worker_device), + torch.empty(self._batch_size, device=self._worker_device)) + for _ in range(2) + ] + self._gpu_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] + + @property + def context_dim(self) -> int: + return 3 + + def _convert_coefficients(self) -> torch.Tensor: + num_sh = (self._lmax + 1)**2 + conv_coeffs = torch.zeros((num_sh, self._poly_degree + 1), dtype=torch.float64) + + for cnt, (l, m) in enumerate((l, m) for l in range(self._lmax + 1) for m in range(-l, l + 1)): + idx = hp.Alm.getidx(self._lmax, l, abs(m)) + if m == 0: + conv_coeffs[cnt] = self._poly_coeffs[0, :, idx] + else: + fac = np.sqrt(2) * (-1)**m + val = self._poly_coeffs[0, :, idx] if m > 0 else -self._poly_coeffs[1, :, idx] + conv_coeffs[cnt] = fac * val + return conv_coeffs.T + + def _horner_eval(self, x: torch.Tensor) -> torch.Tensor: + x_64 = x.to(torch.float64).unsqueeze(1) + result = self._conv_coeffs[0].expand(x.shape[0], -1).clone() + for i in range(1, self._conv_coeffs.size(0)): + result.mul_(x_64).add_(self._conv_coeffs[i]) + return result.to(torch.float32) + + def _compute_spherical_harmonics(self, dir_az: torch.Tensor, dir_polar: torch.Tensor) -> torch.Tensor: + sin_p = torch.sin(dir_polar) + xyz = torch.stack(( + sin_p * torch.cos(dir_az), + sin_p * torch.sin(dir_az), + torch.cos(dir_polar) + ), dim=-1) + return self._sh_calculator(xyz) + + @torch.inference_mode() + def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: + N = energy_keV.shape[0] + + ei_norm = (torch.log10(energy_keV) / 2 - 1).to(torch.float32) + result = torch.empty(N, dtype=torch.float32, device="cpu") + + def get_batch(start_idx): + end_idx = min(start_idx + self._batch_size, N) + return ( + ei_norm[start_idx:end_idx].to(self._worker_device), + dir_az[start_idx:end_idx].to(self._worker_device), + dir_polar[start_idx:end_idx].to(self._worker_device) + ) + + if self._is_cuda: + ei_norm = ei_norm.pin_memory() + result = result.pin_memory() + + def enqueue_transfer(slot_idx, start_idx): + end_idx = min(start_idx + self._batch_size, N) + size = end_idx - start_idx + self._gpu_inputs[slot_idx][0][:size].copy_(ei_norm[start_idx:end_idx], non_blocking=True) + self._gpu_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) + self._gpu_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) + + if self._is_cuda and (N > 0): + with torch.cuda.stream(self._transfer_stream): + enqueue_transfer(0, 0) + self._transfer_ready[0].record(self._transfer_stream) + + for i, start in enumerate(range(0, N, self._batch_size)): + curr_idx = i % 2 + next_idx = (i + 1) % 2 + end = min(start + self._batch_size, N) + batch_len = end - start + next_start = start + self._batch_size + + if self._is_cuda: + with torch.cuda.stream(self._compute_stream): + self._compute_stream.wait_event(self._transfer_ready[curr_idx]) + + ei_b, az_b, pol_b = [t[:batch_len] for t in self._gpu_inputs[curr_idx]] + + poly_b = self._model_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + + torch.sum(poly_b * ylm_b, dim=1, out=self._gpu_results[curr_idx][:batch_len]) + + self._compute_ready[curr_idx].record(self._compute_stream) + + if next_start < N: + with torch.cuda.stream(self._transfer_stream): + enqueue_transfer(next_idx, next_start) + + self._transfer_ready[next_idx].record(self._transfer_stream) + + with torch.cuda.stream(self._transfer_stream): + self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + result[start:end].copy_(self._gpu_results[curr_idx][:batch_len], non_blocking=True) + else: + ei_b, az_b, pol_b = get_batch(start) + + poly_b = self._model_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + result[start:end] = torch.sum(poly_b * ylm_b, dim=1) + + if self._is_cuda: + torch.cuda.synchronize(self._worker_device) + + return torch.clamp(result, min=0) + +class UnpolarizedDensityCMLPDGaussianCARQSFlow(DensityModel): + def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "default"): + super().__init__(compile_mode, batch_size, worker_device, density_input) + + def _init_model(self, input: Dict): + self._snapshot = input["model_state_dict"] + self._bins = input["bins"] + self._hidden_units = input["hidden_units"] + self._residual_blocks = input["residual_blocks"] + self._total_layers = input["total_layers"] + self._context_size = input["context_size"] + self._mlp_hidden_units = input["mlp_hidden_units"] + self._mlp_hidden_layers = input["mlp_hidden_layers"] + self._menergy_cuts = input["menergy_cuts"] + self._phi_cuts = input["phi_cuts"] + + return self._load_model() + + @property + def context_dim(self) -> int: + return 3 + + @property + def source_dim(self) -> int: + return 4 + + def _build_model(self) -> nf.ConditionalNormalizingFlow: + base = build_cmlp_diaggaussian_base( + self._context_size, 2 * self.source_dim, self._mlp_hidden_units, self._mlp_hidden_layers + ) + return build_c_arqs_flow( + base, self._total_layers, self.source_dim, self._context_size, self._bins, self._hidden_units, self._residual_blocks + ) + + def _load_model(self) -> NNDensityInferenceWrapper: + model = self._build_model() + + model.load_state_dict(self._snapshot) + model = NNDensityInferenceWrapper(model) + model.eval() + model.to(self._worker_device) + + return model + + @staticmethod + def _get_vector(phi_sc: torch.Tensor, theta_sc: torch.Tensor) -> torch.Tensor: + x = theta_sc[:, 0] * phi_sc[:, 1] + y = theta_sc[:, 0] * phi_sc[:, 0] + z = theta_sc[:, 1] + return torch.stack((x, y, z), dim=-1) + + def _convert_conventions(self, dir_az_sc: torch.Tensor, dir_polar_sc: torch.Tensor, + ei: torch.Tensor, em: torch.Tensor, phi: torch.Tensor, + scatt_az_sc: torch.Tensor, scatt_polar_sc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + eps = em / ei - 1 + + source = self._get_vector(dir_az_sc, dir_polar_sc) + scatter = self._get_vector(scatt_az_sc, scatt_polar_sc) + + dot_product = torch.sum(source * scatter, dim=1) + phi_geo = torch.acos(torch.clamp(dot_product, -1.0, 1.0)) + theta = phi_geo - phi + + xaxis = torch.tensor([1., 0., 0.], device=source.device, dtype=source.dtype) + pz = -source + + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + proj_x = torch.sum(scatter * px, dim=1) + proj_y = torch.sum(scatter * py, dim=1) + + zeta = torch.atan2(proj_y, proj_x) + zeta = torch.where(zeta < 0, zeta + 2 * np.pi, zeta) + + return eps, theta, zeta + + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: + neps, nphi, ntheta, nzeta, dir_az, dir_pol, ei = args + + eps = -neps + phi = nphi * np.pi + theta = (ntheta - 0.5) * (2 * np.pi) + zeta = nzeta * (2 * np.pi) + + em = ei * (eps + 1) + + phi_geo = theta + phi + scatter_phf = self._get_vector(torch.stack((torch.sin(zeta), torch.cos(zeta)), dim=1), + torch.stack((torch.sin(np.pi - phi_geo), torch.cos(np.pi - phi_geo)), dim=1)) + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) + source_vec = self._get_vector(dir_az_sc, dir_pol_sc) + xaxis = torch.tensor([1., 0., 0.], device=self._worker_device, dtype=source_vec.dtype) + + pz = -source_vec + px = torch.linalg.cross(pz, xaxis.expand_as(pz)) + px = px / torch.linalg.norm(px, dim=1, keepdim=True) + py = torch.linalg.cross(pz, px) + py = py / torch.linalg.norm(py, dim=1, keepdim=True) + + basis = torch.stack((px, py, pz), dim=2) + scatter_scf = torch.bmm(basis, scatter_phf.unsqueeze(-1)).squeeze(-1) + + psi_cds = torch.atan2(scatter_scf[:, 1], scatter_scf[:, 0]) + psi_cds = torch.where(psi_cds < 0, psi_cds + 2 * np.pi, psi_cds) + chi_cds = torch.acos(torch.clamp(scatter_scf[:, 2], -1.0, 1.0)) + + return torch.stack([em, phi, psi_cds, chi_cds], dim=1) + + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dir_az, dir_pol, ei, em, phi, scatt_az, scatt_pol = args + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_sc = torch.stack((torch.sin(dir_pol), torch.cos(dir_pol)), dim=1) + scatt_az_sc = torch.stack((torch.sin(scatt_az), torch.cos(scatt_az)), dim=1) + scatt_pol_sc = torch.stack((torch.sin(scatt_pol), torch.cos(scatt_pol)), dim=1) + + eps_raw, theta_raw, zeta_raw = self._convert_conventions( + dir_az_sc, dir_pol_sc, ei, em, phi, scatt_az_sc, scatt_pol_sc + ) + + jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) + + ctx = self._transform_context(dir_az, dir_pol, ei) + + src = torch.cat([ + (-eps_raw).unsqueeze(1), + (phi / np.pi).unsqueeze(1), + (theta_raw / (2 * np.pi) + 0.5).unsqueeze(1), + (zeta_raw / (2 * np.pi)).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) + + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: + dir_az, dir_pol, ei = args + + dir_az_sc = torch.stack((torch.sin(dir_az), torch.cos(dir_az)), dim=1) + dir_pol_c = torch.cos(dir_pol).unsqueeze(1) + + ctx = torch.cat([ + (dir_az_sc + 1) / 2, + (dir_pol_c + 1) / 2, + (torch.log10(ei) / 2 - 1).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32) + + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: + neps, nphi, ntheta, nzeta, _, _, ei = args + + phi_geo_norm = nphi + 2 * ntheta - 1.0 + valid_mask = (neps < 1.0) & \ + (nphi > 0.0) & (nphi <= 1.0) & \ + (ntheta >= 0.0) & (ntheta <= 1.0) & \ + (nzeta >= 0.0) & (nzeta <= 1.0) & \ + (phi_geo_norm > 0.0) & (phi_geo_norm < 1.0) & \ + (neps <= (1 - self._menergy_cuts[0]/ei)) & \ + (neps >= (1 - self._menergy_cuts[1]/ei)) & \ + (nphi >= self._phi_cuts[0]/np.pi) & \ + (nphi <= self._phi_cuts[1]/np.pi) + + return valid_mask From 98ee5890351919e5a13df651623cebc40ff1e063 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Sun, 8 Mar 2026 14:35:17 +0100 Subject: [PATCH 09/16] Added background implementation. Since the effective area and normalizing flows models contain CPU syncs, the cuda streams provide no benefit and will be removed. This commit saves the current state. --- cosipy/background_estimation/NFBackground.py | 90 +++++++ .../NFBackgroundModels.py | 225 ++++++++++++++++++ cosipy/response/NFBase.py | 144 +++++++++-- cosipy/response/NFResponse.py | 43 +++- cosipy/response/NFResponseModels.py | 53 ++++- cosipy/response/NFWorkerState.py | 3 +- .../nf_instrument_response_function.py | 11 +- 7 files changed, 534 insertions(+), 35 deletions(-) create mode 100644 cosipy/background_estimation/NFBackground.py create mode 100644 cosipy/background_estimation/NFBackgroundModels.py diff --git a/cosipy/background_estimation/NFBackground.py b/cosipy/background_estimation/NFBackground.py new file mode 100644 index 000000000..a0477e9c4 --- /dev/null +++ b/cosipy/background_estimation/NFBackground.py @@ -0,0 +1,90 @@ +from typing import List, Union, Optional, Dict +from pathlib import Path + +from cosipy import SpacecraftHistory + + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +import torch +import torch.multiprocessing as mp +from cosipy.response.NFBase import NFBase, CompileMode, update_density_worker_settings, init_density_worker, DensityApproximation, DensityModel, RateModel +from .NFBackgroundModels import TotalBackgroundDensityCMLPDGaussianCARQSFlow, TotalDC4BackgroundRate + + +class BackgroundDensityApproximation(DensityApproximation): + + def _setup_model(self): + version_map: Dict[int, DensityModel] = { + 1: TotalBackgroundDensityCMLPDGaussianCARQSFlow(self._density_input, self._worker_device, self._batch_size, self._compile_mode), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Density Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + self._expected_source_dim = self._model.source_dim + +class BackgroundRateApproximation: + def __init__(self, major_version: int, rate_input: Dict): + self._major_version = major_version + self._rate_input = rate_input + + self._setup_model() + + def _setup_model(self): + version_map: Dict[int, RateModel] = { + 1: TotalDC4BackgroundRate(self._rate_input), + } + if self._major_version not in version_map: + raise ValueError(f"Unsupported major version {self._major_version} for Rate Approximation") + else: + self._model = version_map[self._major_version] + self._expected_context_dim = self._model.context_dim + + def evaluate_rate(self, context: torch.Tensor) -> torch.Tensor: + dim_context = context.shape[1] + + if dim_context != self._expected_context_dim: + raise ValueError( + f"Feature mismatch: {type(self._model).__name__} expects " + f"{self._expected_context_dim} features, but context has {dim_context}." + ) + + list_context = [context[:, i] for i in range(dim_context)] + + return self._model.evaluate_rate(*list_context) + +def init_background_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, + density_input: Dict, density_batch_size: int, + density_compile_mode: CompileMode): + + init_density_worker(device_queue, progress_queue, major_version, + density_input, density_batch_size, + density_compile_mode, BackgroundDensityApproximation) + +class NFBackground(NFBase): + def __init__(self, path_to_model: Union[str, Path], + devices: Optional[List[Union[str, int, torch.device]]] = None, + density_batch_size: int = 100_000, density_compile_mode: CompileMode = "default", show_progress: bool = True): + + super().__init__(path_to_model, update_density_worker_settings, init_background_worker, density_batch_size, devices, density_compile_mode, ['rate_input'], show_progress) + + self._rate_approximation = BackgroundRateApproximation(self._major_version, self._ckpt['rate_input']) + + self._update_pool_arguments() + + def _update_pool_arguments(self): + self._pool_arguments = [ + getattr(self, "_major_version", None), + getattr(self, "_density_input", None), + getattr(self, "_density_batch_size", None), + getattr(self, "_density_compile_mode", None), + ] + + def evaluate_rate(self, context: torch.Tensor) -> torch.Tensor: + return self._rate_approximation.evaluate_rate(context) + \ No newline at end of file diff --git a/cosipy/background_estimation/NFBackgroundModels.py b/cosipy/background_estimation/NFBackgroundModels.py new file mode 100644 index 000000000..0f1cd8314 --- /dev/null +++ b/cosipy/background_estimation/NFBackgroundModels.py @@ -0,0 +1,225 @@ +import numpy as np + +from typing import Union, Tuple, Dict + + +from importlib.util import find_spec + +if any(find_spec(pkg) is None for pkg in ["torch", "normflows"]): + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + + +from cosipy.response.NFBase import CompileMode, build_c_arqs_flow, build_cmlp_diaggaussian_base, NNDensityInferenceWrapper, DensityModel, RateModel +import normflows as nf +import torch + + +class TotalDC4BackgroundRate(RateModel): + @property + def context_dim(self) -> int: + return 1 + + def _unpack_rate_input(self, rate_input: Dict): + self._slew_duration = rate_input["slew_duration"] + self._obs_duration = rate_input["obs_duration"] + self._start_time = rate_input["start_time"] + + self._offset: float = rate_input["offset"] + self._slope: float = rate_input["slope"] + self._buildup_A: Tuple[float, float] = rate_input["buildup"][0] + self._buildup_T: Tuple[float, float] = rate_input["buildup"][1] + self._scale: float = rate_input["scale"] + self._cutoff_T: float = rate_input["cutoff"][0] + self._cutoff_A: Tuple[float, float, float] = rate_input["cutoff"][1] + self._cutoff_kappa: Tuple[float, float, float] = rate_input["cutoff"][2] + self._cutoff_mu: Tuple[float, float, float] = rate_input["cutoff"][3] + self._outlocs: torch.Tensor = rate_input["outlocs"] + self._saa_decay_A: Tuple[float, float] = rate_input["saa_decay"][0] + self._saa_decay_T: Tuple[float, float] = rate_input["saa_decay"][1] + + @staticmethod + def _buildup(time: torch.Tensor, A: float, T: float) -> torch.Tensor: + return A * (1 - torch.exp(-time * np.log(2) / T)) + + def _pointing_scale(self, time: torch.Tensor, scale: float, k0: float=1.0) -> torch.Tensor: + half_slew = self._slew_duration / 2.0 + full_cycle = 2 * (self._obs_duration + self._slew_duration) + rel_t = time % full_cycle + k = k0 / self._slew_duration + + t1 = self._obs_duration + half_slew + t2 = full_cycle - half_slew + + s1 = 1 / (1 + torch.exp(-k * (rel_t - t1))) + s2 = 1 / (1 + torch.exp(-k * (rel_t - t2))) + s0 = 1 / (1 + torch.exp(-k * (rel_t - (t2 - full_cycle)))) + + return scale * (s0 - s1 + s2) + + @staticmethod + def _von_mises(time: torch.Tensor, T: float, A: float, kappa: float, mu: float) -> torch.Tensor: + return A * torch.exp(kappa * torch.cos(2 * np.pi * (time - mu) / T)) + + def _base_cutoff(self, time, T: float, A: Tuple[float, float, float], + kappa: Tuple[float, float, float], mu: Tuple[float, float, float]) -> torch.Tensor: + return self._von_mises(time, T, A[0], kappa[0], mu[0]) + \ + self._von_mises(time, T, A[1], kappa[1], mu[1]) + \ + self._von_mises(time, T, A[2], kappa[2], mu[2]) + + def _orbital_period(self, time, scale: float, T: float, A: Tuple[float, float, float], + kappa: Tuple[float, float, float], mu: Tuple[float, float, float]) -> torch.Tensor: + sample_times = torch.linspace(0, T, 1000) + fmin = torch.min(self._base_cutoff(sample_times, T, A, kappa, mu)) + + fval = self._base_cutoff(time, T, A, kappa, mu) + + return fmin + (fval - fmin) * (1 + scale) + + @staticmethod + def _decay(time: torch.Tensor, A: float, T: float) -> torch.Tensor: + return A * torch.exp(-time * np.log(2) / T) + + def _saa_decay(self, time: torch.Tensor, A: Tuple[float, float], T: Tuple[float, float]) -> torch.Tensor: + exit_times = (self._outlocs - self._start_time)/60 + last_exit = exit_times[torch.searchsorted(exit_times, time, right=True) - 1] + + return self._decay(time - last_exit, A[0], T[0]) + self._decay(time - last_exit, A[1], T[1]) + + def evaluate_rate(self, *args: torch.Tensor) -> torch.Tensor: + time = (args[0] - self._start_time)/60 + rate = self._offset + self._slope * time + rate += self._buildup(time, self._buildup_A[0], self._buildup_T[0]) + rate += self._buildup(time, self._buildup_A[1], self._buildup_T[1]) + rate += self._orbital_period(time, self._pointing_scale(time, self._scale), + self._cutoff_T, self._cutoff_A, + self._cutoff_kappa, self._cutoff_mu) + rate += self._saa_decay(time, self._saa_decay_A, self._saa_decay_T) + + return rate + +class TotalBackgroundDensityCMLPDGaussianCARQSFlow(DensityModel): + def __init__(self, density_input: Dict, worker_device: Union[str, int, torch.device], + batch_size: int, compile_mode: CompileMode = "default"): + super().__init__(compile_mode, batch_size, worker_device, density_input) + + def _init_model(self, input: Dict): + self._snapshot = input["model_state_dict"] + self._bins = input["bins"] + self._hidden_units = input["hidden_units"] + self._residual_blocks = input["residual_blocks"] + self._total_layers = input["total_layers"] + self._context_size = input["context_size"] + self._mlp_hidden_units = input["mlp_hidden_units"] + self._mlp_hidden_layers = input["mlp_hidden_layers"] + self._menergy_cuts = input["menergy_cuts"] + self._phi_cuts = input["phi_cuts"] + + self._start_time: float = input["start_time"] + self._total_time: float = input["total_time"] + self._period: float = input["period"] + self._slew_duration: float = input["slew_duration"] + self._obs_duration: float = input["obs_duration"] + self._outlocs: torch.Tensor = input["outlocs"].to(self._worker_device) + + return self._load_model() + + @property + def context_dim(self) -> int: + return 1 + + @property + def source_dim(self) -> int: + return 4 + + def _build_model(self) -> nf.ConditionalNormalizingFlow: + base = build_cmlp_diaggaussian_base( + self._context_size, 2 * self.source_dim, self._mlp_hidden_units, self._mlp_hidden_layers + ) + return build_c_arqs_flow( + base, self._total_layers, self.source_dim, self._context_size, self._bins, self._hidden_units, self._residual_blocks + ) + + def _load_model(self) -> NNDensityInferenceWrapper: + model = self._build_model() + + model.load_state_dict(self._snapshot) + model = NNDensityInferenceWrapper(model) + model.eval() + model.to(self._worker_device) + + return model + + def _inverse_transform_coordinates(self, *args: torch.Tensor) -> torch.Tensor: + nem, nphi, npsi, nchi, _ = args + + em = 10 ** (2 * (nem + 1)) + phi = nphi * np.pi + az = npsi * 2 * np.pi + pol = torch.acos(2 * nchi - 1) + + return torch.stack([em, phi, az, pol], dim=1) + + def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + time, em, phi, scatt_az, scatt_pol = args + + jac = 1/(np.log(10) * em * 8*np.pi**2) + + ctx = self._transform_context(time) + + src = torch.cat([ + (torch.log10(em)/2 - 1).unsqueeze(1), + (phi / np.pi).unsqueeze(1), + (scatt_az / (2 * np.pi)).unsqueeze(1), + ((torch.cos(scatt_pol) + 1) / 2).unsqueeze(1) + ], dim=1) + + return ctx.to(torch.float32), src.to(torch.float32), jac.to(torch.float32) + + def _sigmoid_switch(self, t: torch.Tensor, k0: float=1.0) -> torch.Tensor: + half_slew = self._slew_duration / 2.0 + full_cycle = 2 * (self._obs_duration + self._slew_duration) + rel_t = (t - self._start_time) % full_cycle + k = k0 / self._slew_duration + + t1 = self._obs_duration + half_slew + t2 = full_cycle - half_slew + + s1 = 1 / (1 + torch.exp(-k * (rel_t - t1))) + s2 = 1 / (1 + torch.exp(-k * (rel_t - t2))) + s0 = 1 / (1 + torch.exp(-k * (rel_t - (t2 - full_cycle)))) + + return s0 - s1 + s2 + + def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: + time = args[0] + + last_exits = self._outlocs[torch.searchsorted(self._outlocs, time, right=True) - 1] + time_since_start = (time - self._start_time)/self._total_time + pointing_phase = self._sigmoid_switch(time, k0 = 1.0) + time_since_saa = (time - last_exits)/self._period + phase_c = (torch.cos((time - self._start_time)/self._period * 2 * np.pi) + 1) / 2 + phase_s = (torch.sin((time - self._start_time)/self._period * 2 * np.pi) + 1) / 2 + + ctx = torch.hstack([ + (time_since_start).unsqueeze(1), + (pointing_phase).unsqueeze(1), + (time_since_saa).unsqueeze(1), + (phase_c).unsqueeze(1), + (phase_s).unsqueeze(1) + ]) + + return ctx.to(torch.float32) + + def _valid_samples(self, *args: torch.Tensor) -> torch.Tensor: + nem, nphi, npsi, nchi, _ = args + + valid_mask = (nem >= 0.0) & \ + (nphi > 0.0) & (nphi <= 1.0) & \ + (npsi >= 0.0) & (npsi <= 1.0) & \ + (nchi >= 0.0) & (nchi <= 1.0) & \ + (nem >= (np.log10(self._menergy_cuts[0])/2 - 1)) & \ + (nem <= (np.log10(self._menergy_cuts[1])/2 - 1)) & \ + (nphi >= self._phi_cuts[0]/np.pi) & \ + (nphi <= self._phi_cuts[1]/np.pi) + + return valid_mask diff --git a/cosipy/response/NFBase.py b/cosipy/response/NFBase.py index 143edef5f..881798e7f 100644 --- a/cosipy/response/NFBase.py +++ b/cosipy/response/NFBase.py @@ -2,6 +2,9 @@ from pathlib import Path from abc import ABC, abstractmethod import numpy as np +from tqdm.auto import tqdm +import queue +import time from importlib.util import find_spec @@ -187,7 +190,7 @@ def _write_gpu_tensors(self): ... class AreaModel(BaseModel): @abstractmethod - def evaluate_effective_area(self, *args: torch.Tensor) -> torch.Tensor: ... + def evaluate_effective_area(self, *args: torch.Tensor, progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: ... class DensityModel(BaseModel): @property @@ -213,13 +216,16 @@ def _write_gpu_tensors(self): ] @torch.inference_mode() - def sample_density(self, *args: torch.Tensor) -> torch.Tensor: + def sample_density(self, *args: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: N = args[0].shape[0] result = torch.empty((N, self.source_dim), dtype=torch.float32, device="cpu") failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") if self._is_cuda: + num_batches = (N + self._batch_size - 1) // self._batch_size + pending_progress = [torch.cuda.Event() for _ in range(num_batches)] result, failed_mask = result.pin_memory(), failed_mask.pin_memory() def enqueue_sample_transfer(slot_idx, start_idx): @@ -278,6 +284,8 @@ def enqueue_sample_transfer(slot_idx, start_idx): result[start:end].copy_(self._sample_results[curr_idx][0][:batch_len], non_blocking=True) failed_mask[start:end].copy_(self._sample_results[curr_idx][1][:batch_len], non_blocking=True) + + pending_progress[i].record(self._transfer_stream) else: #b_ei = energy_keV[start:end].to(self._worker_device) @@ -296,12 +304,31 @@ def enqueue_sample_transfer(slot_idx, start_idx): n_latent = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) result[start:end] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) failed_mask[start:end] = ~self._valid_samples(*(n_latent.T), *b_ctx) + + if progress_callback is not None: + amount = batch_len - torch.sum(failed_mask[start:end]).item() + progress_callback(amount) if self._is_cuda: + if progress_callback is not None: + i = 0 + while i < len(pending_progress): + event = pending_progress[i] + if event.query(): + start = self._batch_size * i + end = min(start + self._batch_size, N) + num_failed = torch.sum(failed_mask[start:end]).item() + amount = (end - start) - num_failed + if amount > 0: + progress_callback(amount) + + i += 1 + else: + time.sleep(0.01) torch.cuda.synchronize(self._worker_device) if torch.any(failed_mask): - result[failed_mask] = self.sample_density(*[t[failed_mask] for t in args]) + result[failed_mask] = self.sample_density(*[t[failed_mask] for t in args], progress_callback=progress_callback) return result @@ -318,12 +345,15 @@ def _transform_context(self, *args: torch.Tensor) -> torch.Tensor: ... def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... @torch.inference_mode() - def evaluate_density(self, *args: torch.Tensor) -> torch.Tensor: + def evaluate_density(self, *args: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: N = args[0].shape[0] result = torch.empty(N, dtype=torch.float32, device="cpu") if self._is_cuda: + num_batches = (N + self._batch_size - 1) // self._batch_size + pending_progress = [torch.cuda.Event() for _ in range(num_batches)] result = result.pin_memory() def enqueue_eval_transfer(slot_idx, start_idx): @@ -364,11 +394,30 @@ def enqueue_eval_transfer(slot_idx, start_idx): self._transfer_stream.wait_event(self._compute_ready[curr_idx]) result[start:end].copy_(self._eval_results[curr_idx][:batch_len], non_blocking=True) + + pending_progress[i].record(self._transfer_stream) else: ctx, src, jac = self._transform_coordinates(*[t[start:end].to(self._worker_device) for t in args]) result[start:end] = self._model_op(src, ctx, mode="inference") * jac + + if progress_callback is not None: + progress_callback(batch_len) if self._is_cuda: + if progress_callback is not None: + i = 0 + while i < len(pending_progress): + event = pending_progress[i] + if event.query(): + start = self._batch_size * i + end = min(start + self._batch_size, N) + amount = (end - start) + if amount > 0: + progress_callback(amount) + + i += 1 + else: + time.sleep(0.01) torch.cuda.synchronize(self._worker_device) return result @@ -404,7 +453,8 @@ def __init__(self, major_version: int, density_input: Dict, worker_device: Union @abstractmethod def _setup_model(self): ... - def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch.Tensor: + def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: dim_context = context.shape[1] dim_source = source.shape[1] @@ -422,9 +472,10 @@ def evaluate_density(self, context: torch.Tensor, source: torch.Tensor) -> torch list_context = [context[:, i] for i in range(dim_context)] list_source = [source[:, i] for i in range(dim_source)] - return self._model.evaluate_density(*list_context, *list_source) + return self._model.evaluate_density(*list_context, *list_source, progress_callback=progress_callback) - def sample_density(self, context: torch.Tensor) -> torch.Tensor: + def sample_density(self, context: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: dim_context = context.shape[1] if dim_context != self._expected_context_dim: @@ -435,7 +486,7 @@ def sample_density(self, context: torch.Tensor) -> torch.Tensor: list_context = [context[:, i] for i in range(dim_context)] - return self._model.sample_density(*list_context) + return self._model.sample_density(*list_context, progress_callback=progress_callback) def cuda_cleanup_task(_) -> bool: if torch.cuda.is_available(): @@ -450,11 +501,12 @@ def update_density_worker_settings(args: Tuple[str, Union[int, CompileMode]]): elif attr == 'density_compile_mode': NFWorkerState.density_module._model.compile_mode = value -def init_density_worker(device_queue: mp.Queue, major_version: int, +def init_density_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, density_input: Dict, density_batch_size: int, density_compile_mode: CompileMode, density_approximation: DensityApproximation): NFWorkerState.worker_device = torch.device(device_queue.get()) + NFWorkerState.progress_queue = progress_queue if NFWorkerState.worker_device.type == 'cuda': torch.cuda.set_device(NFWorkerState.worker_device) @@ -469,7 +521,8 @@ def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) sub_context = sub_context.pin_memory() sub_source = sub_source.pin_memory() - return NFWorkerState.density_module.evaluate_density(sub_context, sub_source) + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.density_module.evaluate_density(sub_context, sub_source, progress_callback=cb) def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: context, indices = args @@ -477,13 +530,14 @@ def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor sub_context = context[indices, :] if torch.device(NFWorkerState.worker_device).type == 'cuda': sub_context = sub_context.pin_memory() - - return NFWorkerState.density_module.sample_density(sub_context) + + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.density_module.sample_density(sub_context, progress_callback=cb) class NFBase(): def __init__(self, path_to_model: Union[str, Path], update_worker, pool_worker, density_batch_size: int = 100_000, devices: Optional[List[Union[str, int, torch.device]]] = None, density_compile_mode: CompileMode = "default", - additional_required_keys: List[str] = None): + additional_required_keys: List[str] = None, show_progress: bool = True): self._ckpt = torch.load(str(path_to_model), map_location=torch.device('cpu'), weights_only=False) required_keys = ['version', 'density_input'] + (additional_required_keys or []) @@ -505,6 +559,8 @@ def __init__(self, path_to_model: Union[str, Path], update_worker, pool_worker, self._ctx = mp.get_context("spawn") self._pool_worker = pool_worker self._update_worker = update_worker + self.show_progress = show_progress + self._progress_queue = None self.density_batch_size = density_batch_size self.density_compile_mode = density_compile_mode @@ -518,6 +574,16 @@ def __init__(self, path_to_model: Union[str, Path], update_worker, pool_worker, def __del__(self): self.shutdown_compute_pool() + + @property + def show_progress(self) -> bool: + return self._show_progress + + @show_progress.setter + def show_progress(self, value: bool): + if not isinstance(value, bool): + raise ValueError("show_progress must be a boolean") + self._show_progress = value @property def devices(self) -> List[Union[str, int, torch.device]]: @@ -550,6 +616,9 @@ def density_compile_mode(self, value: CompileMode): self._update_pool_arguments() self._update_worker_config('density_compile_mode', value) + @property + def active_pool(self) -> bool: return self._pool is not None + def _update_worker_config(self, attr: str, value: Union[int, CompileMode]): if self._pool is not None: self._pool.map(self._update_worker, [(attr, value)] * self._num_workers) @@ -574,6 +643,10 @@ def shutdown_compute_pool(self): self._num_workers = 0 self._pool = None self._has_cuda = None + + self._progress_queue.close() + self._progress_queue.join_thread() + self._progress_queue = None def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): active_devices = devices if devices is not None else self._devices @@ -591,11 +664,12 @@ def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device] device_queue = self._ctx.Queue() for d in active_devices: device_queue.put(d) + self._progress_queue = self._ctx.Queue() self._pool = self._ctx.Pool( processes=self._num_workers, initializer=self._pool_worker, - initargs=(device_queue, *self._pool_arguments), + initargs=(device_queue, self._progress_queue,*self._pool_arguments), ) def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str, int, torch.device]]]=None) -> torch.Tensor: @@ -615,9 +689,28 @@ def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str indices = torch.tensor_split(torch.arange(n_data), self._num_workers) tasks = [(context, idx) for idx in indices] - results = self._pool.map(sample_density_task, tasks) + async_result = self._pool.map_async(sample_density_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Sampling the density", unit="calls", leave=False) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() return torch.cat(results, dim=0) + + #results = self._pool.map(sample_density_task, tasks) + #return torch.cat(results, dim=0) finally: if temp_pool: @@ -641,9 +734,28 @@ def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: indices = torch.tensor_split(torch.arange(n_data), self._num_workers) tasks = [(context, source, idx) for idx in indices] - results = self._pool.map(evaluate_density_task, tasks) + async_result = self._pool.map_async(evaluate_density_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the density", unit="calls", leave=False) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() return torch.cat(results, dim=0) + + #results = self._pool.map(evaluate_density_task, tasks) + #return torch.cat(results, dim=0) finally: if temp_pool: diff --git a/cosipy/response/NFResponse.py b/cosipy/response/NFResponse.py index a44daa032..a18363773 100644 --- a/cosipy/response/NFResponse.py +++ b/cosipy/response/NFResponse.py @@ -1,5 +1,7 @@ -from typing import List, Union, Optional, Dict, Tuple +from typing import List, Union, Optional, Dict, Tuple, Callable from pathlib import Path +from tqdm.auto import tqdm +import queue from importlib.util import find_spec @@ -47,7 +49,8 @@ def _setup_model(self): self._model = version_map[self._major_version] self._expected_context_dim = self._model.context_dim - def evaluate_effective_area(self, context: torch.Tensor) -> torch.Tensor: + def evaluate_effective_area(self, context: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: dim_context = context.shape[1] if dim_context != self._expected_context_dim: @@ -58,7 +61,7 @@ def evaluate_effective_area(self, context: torch.Tensor) -> torch.Tensor: list_context = [context[:, i] for i in range(dim_context)] - return self._model.evaluate_effective_area(*list_context) + return self._model.evaluate_effective_area(*list_context, progress_callback=progress_callback) def update_response_worker_settings(args: Tuple[str, Union[int, CompileMode]]): update_density_worker_settings(args) @@ -70,11 +73,11 @@ def update_response_worker_settings(args: Tuple[str, Union[int, CompileMode]]): elif attr == 'area_compile_mode': NFWorkerState.area_module._model.compile_mode = value -def init_response_worker(device_queue: mp.Queue, major_version: int, area_input: Dict, +def init_response_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major_version: int, area_input: Dict, density_input: Dict, area_batch_size: int, density_batch_size: int, area_compile_mode: CompileMode, density_compile_mode: CompileMode): - init_density_worker(device_queue, major_version, + init_density_worker(device_queue, progress_queue, major_version, density_input, density_batch_size, density_compile_mode, ResponseDensityApproximation) @@ -88,14 +91,15 @@ def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: if torch.device(NFWorkerState.worker_device).type == 'cuda': sub_context = sub_context.pin_memory() - return NFWorkerState.area_module.evaluate_effective_area(sub_context) + cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None + return NFWorkerState.area_module.evaluate_effective_area(sub_context, progress_callback=cb) class NFResponse(NFBase): def __init__(self, path_to_model: Union[str, Path], area_batch_size: int = 100_000, density_batch_size: int = 100_000, - devices: Optional[List[Union[str, int, torch.device]]] = None, - area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", density_compile_mode: CompileMode = "default"): + devices: Optional[List[Union[str, int, torch.device]]] = None, area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", + density_compile_mode: CompileMode = "default", show_progress: bool = True): - super().__init__(path_to_model, update_response_worker_settings, init_response_worker, density_batch_size, devices, density_compile_mode, ['is_polarized', 'area_input']) + super().__init__(path_to_model, update_response_worker_settings, init_response_worker, density_batch_size, devices, density_compile_mode, ['is_polarized', 'area_input'], show_progress) self._is_polarized = self._ckpt['is_polarized'] self._area_input = self._ckpt['area_input'] @@ -159,9 +163,28 @@ def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[ indices = torch.tensor_split(torch.arange(n_data), self._num_workers) tasks = [(context, idx) for idx in indices] - results = self._pool.map(evaluate_area_task, tasks) + async_result = self._pool.map_async(evaluate_area_task, tasks) + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the effective area", unit="calls", leave=False) as pbar: + while not async_result.ready(): + try: + while True: + completed = self._progress_queue.get_nowait() + pbar.update(completed) + except queue.Empty: + async_result.wait(timeout=0.1) + + while not self._progress_queue.empty(): + try: + pbar.update(self._progress_queue.get_nowait()) + except queue.Empty: + break + + results = async_result.get() return torch.cat(results, dim=0) + + #results = self._pool.map(evaluate_area_task, tasks) + #return torch.cat(results, dim=0) finally: if temp_pool: diff --git a/cosipy/response/NFResponseModels.py b/cosipy/response/NFResponseModels.py index 94783128b..50bd293a5 100644 --- a/cosipy/response/NFResponseModels.py +++ b/cosipy/response/NFResponseModels.py @@ -1,7 +1,8 @@ import numpy as np import healpy as hp +import time -from typing import Union, Tuple, Dict +from typing import Union, Tuple, Dict, Optional, Callable from importlib.util import find_spec @@ -75,7 +76,8 @@ def _compute_spherical_harmonics(self, dir_az: torch.Tensor, dir_polar: torch.Te return self._sh_calculator(xyz) @torch.inference_mode() - def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor) -> torch.Tensor: + def evaluate_effective_area(self, dir_az: torch.Tensor, dir_polar: torch.Tensor, energy_keV: torch.Tensor, + progress_callback: Optional[Callable[[int], None]] = None) -> torch.Tensor: N = energy_keV.shape[0] ei_norm = (torch.log10(energy_keV) / 2 - 1).to(torch.float32) @@ -90,6 +92,8 @@ def get_batch(start_idx): ) if self._is_cuda: + num_batches = (N + self._batch_size - 1) // self._batch_size + pending_progress = [torch.cuda.Event() for _ in range(num_batches)] ei_norm = ei_norm.pin_memory() result = result.pin_memory() @@ -104,8 +108,10 @@ def enqueue_transfer(slot_idx, start_idx): with torch.cuda.stream(self._transfer_stream): enqueue_transfer(0, 0) self._transfer_ready[0].record(self._transfer_stream) - + + torch.cuda.set_sync_debug_mode(1) for i, start in enumerate(range(0, N, self._batch_size)): + print(f"Loop: {i}", flush=True) curr_idx = i % 2 next_idx = (i + 1) % 2 end = min(start + self._batch_size, N) @@ -117,31 +123,61 @@ def enqueue_transfer(slot_idx, start_idx): self._compute_stream.wait_event(self._transfer_ready[curr_idx]) ei_b, az_b, pol_b = [t[:batch_len] for t in self._gpu_inputs[curr_idx]] - + t0 = time.time() poly_b = self._model_op(ei_b) + t1 = time.time() ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + t2 = time.time() torch.sum(poly_b * ylm_b, dim=1, out=self._gpu_results[curr_idx][:batch_len]) - + t3 = time.time() self._compute_ready[curr_idx].record(self._compute_stream) if next_start < N: with torch.cuda.stream(self._transfer_stream): + t4 = time.time() enqueue_transfer(next_idx, next_start) - + t5 = time.time() self._transfer_ready[next_idx].record(self._transfer_stream) with torch.cuda.stream(self._transfer_stream): self._transfer_stream.wait_event(self._compute_ready[curr_idx]) + t6 = time.time() result[start:end].copy_(self._gpu_results[curr_idx][:batch_len], non_blocking=True) + t7 = time.time() + pending_progress[i].record(self._transfer_stream) + if i < 5 and self._worker_device == torch.device("cuda:0"): + print(f"Run: {i} on device {self._worker_device}", flush=True) + print(f"model_op: {t1-t0:.5f}s", flush=True) + print(f"spherical_harmonics: {t2-t1:.5f}s", flush=True) + print(f"sum: {t3-t2:.5f}s", flush=True) + print(f"enqueue_transfer: {t5-t4:.5f}s", flush=True) + print(f"result copy: {t7-t6:.5f}s", flush=True) else: ei_b, az_b, pol_b = get_batch(start) poly_b = self._model_op(ei_b) ylm_b = self._compute_spherical_harmonics(az_b, pol_b) result[start:end] = torch.sum(poly_b * ylm_b, dim=1) - + + if progress_callback is not None: + progress_callback(batch_len) + print("Done", flush=True) if self._is_cuda: + if progress_callback is not None: + i = 0 + while i < len(pending_progress): + event = pending_progress[i] + if event.query(): + start = self._batch_size * i + end = min(start + self._batch_size, N) + amount = (end - start) + if amount > 0: + progress_callback(amount) + + i += 1 + else: + time.sleep(0.01) torch.cuda.synchronize(self._worker_device) return torch.clamp(result, min=0) @@ -274,6 +310,9 @@ def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, tor ) jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) + for i in range(torch.sum(torch.isinf(jac))): + print("You found a rare event where phi_geo is 0 or 180 degrees!", flush=True) + jac[torch.isinf(jac)] = 0.0 ctx = self._transform_context(dir_az, dir_pol, ei) diff --git a/cosipy/response/NFWorkerState.py b/cosipy/response/NFWorkerState.py index 735b9f7b7..6e28c0e83 100644 --- a/cosipy/response/NFWorkerState.py +++ b/cosipy/response/NFWorkerState.py @@ -1,3 +1,4 @@ worker_device = None density_module = None -area_module = None \ No newline at end of file +area_module = None +progress_queue = None \ No newline at end of file diff --git a/cosipy/response/nf_instrument_response_function.py b/cosipy/response/nf_instrument_response_function.py index b672f70ee..01fc46aec 100644 --- a/cosipy/response/nf_instrument_response_function.py +++ b/cosipy/response/nf_instrument_response_function.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Optional, List, Union import numpy as np @@ -28,6 +28,15 @@ def __init__(self, response: NFResponse,): raise ValueError("The provided NNResponse is polarized, but UnpolarizedNNFarFieldInstrumentResponseFunction only supports unpolarized responses.") self._response = response + def init_compute_pool(self, devices: Optional[List[Union[str, int, torch.device]]]=None): + self._response.init_compute_pool(devices) + + def shutdown_compute_pool(self): + self._response.shutdown_compute_pool() + + @property + def active_pool(self) -> bool: return self._response.active_pool + @staticmethod def _get_context(photons: PhotonListWithDirectionAndEnergyInSCFrameInterface): lon = torch.as_tensor(asarray(photons.direction_lon_rad_sc, dtype=np.float32)) From 2194e6e78b9af6ad209537ec3835e5679eddb801 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Tue, 10 Mar 2026 22:54:03 +0100 Subject: [PATCH 10/16] Changes before cosipy PR. --- .../nf_unbinned_background.py | 118 +++++++ cosipy/interfaces/expectation_interface.py | 2 +- .../interfaces/source_response_interface.py | 25 +- cosipy/response/NFBase.py | 229 +----------- cosipy/response/NFResponse.py | 10 +- cosipy/response/NFResponseModels.py | 105 +----- cosipy/threeml/optimized_unbinned_folding.py | 331 +++++++++++------- cosipy/threeml/unbinned_model_folding.py | 70 +++- pyproject.toml | 2 +- 9 files changed, 441 insertions(+), 451 deletions(-) diff --git a/cosipy/background_estimation/nf_unbinned_background.py b/cosipy/background_estimation/nf_unbinned_background.py index e69de29bb..b581c5bba 100644 --- a/cosipy/background_estimation/nf_unbinned_background.py +++ b/cosipy/background_estimation/nf_unbinned_background.py @@ -0,0 +1,118 @@ +from typing import Dict, Iterable, Type, Optional + +from astropy import units as u +import numpy as np + +from cosipy import SpacecraftHistory +from cosipy.interfaces.event import EventInterface +from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface +from cosipy.data_io.EmCDSUnbinnedData import TimeTagEmCDSEventInSCFrameInterface +from cosipy.interfaces.background_interface import BackgroundDensityInterface +from cosipy.util.iterables import asarray + +from importlib.util import find_spec + +if find_spec("torch") is None: + raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") + +from cosipy.background_estimation.NFBackground import NFBackground +import torch + +class FreeNormNFUnbinnedBackground(BackgroundDensityInterface): + + def __init__(self, + model: NFBackground, + data: TimeTagEmCDSEventDataInSCFrameInterface, + sc_history: SpacecraftHistory, + label: str = "bkg_norm"): + + self._expected_counts = None + self._expectation_density = None + self._model = model + self._data = data + self._sc_history = sc_history + + self._accum_livetime = self._sc_history.cumulative_livetime().to_value(u.s) + + self._norm = 1 + self._label = label + self._offset: Optional[float] = 1e-12 + + @property + def event_type(self) -> Type[EventInterface]: + return TimeTagEmCDSEventInSCFrameInterface + + @property + def offset(self) -> Optional[float]: + return self._offset + + @offset.setter + def offset(self, offset: Optional[float]): + if (offset is not None) and (offset < 0): + raise ValueError("The offset cannot be negative.") + self._offset = offset + + @property + def norm(self) -> u.Quantity: + self._update_cache(counts_only=True) + return u.Quantity(self._norm * self._expected_counts/self._accum_livetime, u.Hz) + + @norm.setter + def norm(self, norm: u.Quantity): + self._update_cache(counts_only=True) + self._norm = norm.to_value(u.Hz) * self._accum_livetime/self._expected_counts + + def set_parameters(self, **parameters: u.Quantity) -> None: + self.norm = parameters[self._label] + + @property + def parameters(self) -> Dict[str, u.Quantity]: + return {self._label: self.norm} + + def _integrate_rate(self) -> float: + mid_times = torch.as_tensor((self._sc_history.obstime[:-1] + (self._sc_history.obstime[1:] - self._sc_history.obstime[:-1]) / 2).utc.unix).view(-1, 1) + rate = self._model.evaluate_rate(mid_times) + return torch.sum(rate * torch.as_tensor(self._sc_history.livetime)).item() + + def _compute_density(self): + self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) + self._phi_rad = torch.as_tensor(asarray(self._data.scattering_angle_rad, dtype=np.float32)) + self._lon_scatt = torch.as_tensor(asarray(self._data.scattered_lon_rad_sc, dtype=np.float32)) + self._lat_scatt = torch.as_tensor(asarray(self._data.scattered_lat_rad_sc, dtype=np.float32)) + source = torch.stack((self._energy_m_keV, self._phi_rad, self._lon_scatt, np.pi/2 - self._lat_scatt), dim=1) + + time = torch.as_tensor(self._data.time.utc.unix).view(-1, 1) + if torch.any((time < self._sc_history.tstart.utc.unix) | (time > self._sc_history.tstop.utc.unix)): + raise ValueError("Input times are outside the spacecraft history range") + interval_ratios = torch.as_tensor(self._sc_history.livetime.to_value(u.s) / self._sc_history.intervals_duration.to_value(u.s)) + factor = torch.searchsorted(torch.as_tensor(self._sc_history.obstime.utc.unix), time.view(-1), right=True) - 1 + + return np.asarray(self._model.evaluate_density(time, source) * self._model.evaluate_rate(time) * interval_ratios[factor], dtype=np.float64) + + def _update_cache(self, counts_only=False): + if self._expected_counts is None: + self._expected_counts = self._integrate_rate() + + if (self._expectation_density is None) and (not counts_only): + active_pool = self._model.active_pool + if not active_pool: + self._model.init_compute_pool() + self._expectation_density = self._compute_density() + if not active_pool: + self._model.shutdown_compute_pool() + + def expected_counts(self) -> float: + self._update_cache() + + return self._expected_counts * self._norm + + def expectation_density(self) -> Iterable[float]: + self._update_cache() + + result = self._expectation_density * self._norm + + if self._offset is not None: + return result + self._offset + else: + return result + \ No newline at end of file diff --git a/cosipy/interfaces/expectation_interface.py b/cosipy/interfaces/expectation_interface.py index 0ee8be283..c01657923 100644 --- a/cosipy/interfaces/expectation_interface.py +++ b/cosipy/interfaces/expectation_interface.py @@ -117,7 +117,7 @@ def __init__(self, *expectations:Tuple[ExpectationDensityInterface, None], vecto Parameters ---------- expectations: Other ExpectationDensityInterface implementations - vectorize: It True (default), it will first cache all the individual expectations on numpy arrays, and then it will sum + vectorize: If True (default), it will first cache all the individual expectations on numpy arrays, and then it will sum them up using numpy's method. The output will also be a numpy. If False, it will query one element from each expectation object a time and sum them up. The output in this case is an Generator. """ diff --git a/cosipy/interfaces/source_response_interface.py b/cosipy/interfaces/source_response_interface.py index ba2e488b9..0984aea9d 100644 --- a/cosipy/interfaces/source_response_interface.py +++ b/cosipy/interfaces/source_response_interface.py @@ -1,6 +1,7 @@ -from typing import Protocol, runtime_checkable +from typing import Protocol, runtime_checkable, Union from astromodels import Model from astromodels.sources import Source +from pathlib import Path from .expectation_interface import BinnedExpectationInterface, ExpectationDensityInterface @@ -64,6 +65,28 @@ class UnbinnedThreeMLSourceResponseInterface(ThreeMLSourceResponseInterface, Exp No new methods. Just the inherited ones. """ +@runtime_checkable +class CachedUnbinnedThreeMLSourceResponseInterface(UnbinnedThreeMLSourceResponseInterface, Protocol): + """ + Guaranteeing that the source response can be cached to and loaded from a file. + """ + + def cache_to_file(self, filename: Union[str, Path]): + """ + Saves the calculated response cache to the specified HDF5 file. + The implementation has to make sure that the source is handled correctly. + """ + + def cache_from_file(self, filename: Union[str, Path]): + """Loads the response cache from the specified HDF5 file.""" + + def init_cache(self): + """ + Initialize the response cache that can be saved to file. + This way there is no need to call expected_counts() or expectation_density() to initialize the cache. + Make sure that repeated calls don't lead to unnecessary recomputations. + """ + @runtime_checkable class BinnedThreeMLSourceResponseInterface(ThreeMLSourceResponseInterface, BinnedExpectationInterface, Protocol): """ diff --git a/cosipy/response/NFBase.py b/cosipy/response/NFBase.py index 881798e7f..979437abe 100644 --- a/cosipy/response/NFBase.py +++ b/cosipy/response/NFBase.py @@ -4,7 +4,6 @@ import numpy as np from tqdm.auto import tqdm import queue -import time from importlib.util import find_spec @@ -133,17 +132,6 @@ def __init__(self, compile_mode: CompileMode, batch_size: int, self._is_cuda = (self._worker_device.type == 'cuda') self.batch_size = batch_size - - if self._is_cuda: - self._compute_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_stream = torch.cuda.Stream(device=self._worker_device) - self._transfer_ready = [torch.cuda.Event(), torch.cuda.Event()] - self._compute_ready = [torch.cuda.Event(), torch.cuda.Event()] - else: - self._compute_stream = None - self._transfer_stream = None - self._transfer_ready = None - self._compute_ready = None @abstractmethod def _init_model(self, input: Dict) -> Union[nn.Module, Callable]: ... @@ -182,11 +170,6 @@ def batch_size(self, value: int): if not isinstance(value, int) or value <= 0: raise ValueError(f"Batch size must be a positive integer, got {value}") self._batch_size = value - if self._is_cuda: - self._write_gpu_tensors() - - @abstractmethod - def _write_gpu_tensors(self): ... class AreaModel(BaseModel): @abstractmethod @@ -196,24 +179,6 @@ class DensityModel(BaseModel): @property @abstractmethod def source_dim(self) -> int: ... - - def _write_gpu_tensors(self): - self._eval_inputs = [ - tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.source_dim + self.context_dim)) - for _ in range(2) - ] - self._eval_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] - - self._sample_inputs = [ - tuple(torch.empty(self._batch_size, device=self._worker_device) for _ in range(self.context_dim)) - for _ in range(2) - ] - - self._sample_results = [ - (torch.empty((self._batch_size, self.source_dim), device=self._worker_device), - torch.empty(self._batch_size, dtype=torch.bool, device=self._worker_device)) - for _ in range(2) - ] @torch.inference_mode() def sample_density(self, *args: torch.Tensor, @@ -223,109 +188,20 @@ def sample_density(self, *args: torch.Tensor, result = torch.empty((N, self.source_dim), dtype=torch.float32, device="cpu") failed_mask = torch.zeros(N, dtype=torch.bool, device="cpu") - if self._is_cuda: - num_batches = (N + self._batch_size - 1) // self._batch_size - pending_progress = [torch.cuda.Event() for _ in range(num_batches)] - result, failed_mask = result.pin_memory(), failed_mask.pin_memory() - - def enqueue_sample_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - for i in range(self.context_dim): - self._sample_inputs[slot_idx][i][:size].copy_(args[i][start_idx:end_idx], non_blocking=True) - #self._sample_inputs[slot_idx][0][:size].copy_(energy_keV[start_idx:end_idx], non_blocking=True) - #self._sample_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) - #self._sample_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) - - if self._is_cuda and N > 0: - with torch.cuda.stream(self._transfer_stream): - enqueue_sample_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - for i, start in enumerate(range(0, N, self._batch_size)): - curr_idx = i % 2 - next_idx = (i + 1) % 2 + for start in range(0, N, self._batch_size): end = min(start + self._batch_size, N) batch_len = end - start - next_start = start + self._batch_size - - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - #b_ei, b_az, b_pol = [t[:batch_len] for t in self._sample_inputs[curr_idx]] - # - #b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) - #b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) - # - #b_ctx = torch.cat([ - # (b_az_sc + 1) / 2, - # (b_pol_sc[:, 1:] + 1) / 2, - # (torch.log10(b_ei) / 2 - 1).unsqueeze(1) - #], dim=1).to(torch.float32) - - b_ctx = [t[:batch_len] for t in self._sample_inputs[curr_idx]] - n_ctx = self._transform_context(*b_ctx) - - n_latent = self._model_op(context=n_ctx, mode="sampling", n_samples=batch_len) - - self._sample_results[curr_idx][0][:batch_len] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) - self._sample_results[curr_idx][1][:batch_len] = ~self._valid_samples(*(n_latent.T), *b_ctx) - - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - enqueue_sample_transfer(next_idx, next_start) - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - - result[start:end].copy_(self._sample_results[curr_idx][0][:batch_len], non_blocking=True) - failed_mask[start:end].copy_(self._sample_results[curr_idx][1][:batch_len], non_blocking=True) - - pending_progress[i].record(self._transfer_stream) - else: - - #b_ei = energy_keV[start:end].to(self._worker_device) - #b_az, b_pol = dir_az[start:end].to(self._worker_device), dir_polar[start:end].to(self._worker_device) - - #b_az_sc = torch.stack((torch.sin(b_az), torch.cos(b_az)), dim=1) - #b_pol_sc = torch.stack((torch.sin(b_pol), torch.cos(b_pol)), dim=1) - #b_ctx = torch.cat([ - # (b_az_sc + 1) / 2, (b_pol_sc[:, 1:] + 1) / 2, - # (torch.log10(b_ei) / 2 - 1).unsqueeze(1) - #], dim=1).to(torch.float32) - - b_ctx = [t[start:end].to(self._worker_device) for t in args] - n_ctx = self._transform_context(*b_ctx) - n_latent = self._model_op(context=b_ctx, mode="sampling", n_samples=batch_len) - result[start:end] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) - failed_mask[start:end] = ~self._valid_samples(*(n_latent.T), *b_ctx) - - if progress_callback is not None: - amount = batch_len - torch.sum(failed_mask[start:end]).item() - progress_callback(amount) - - if self._is_cuda: + b_ctx = [t[start:end].to(self._worker_device) for t in args] + n_ctx = self._transform_context(*b_ctx) + + n_latent = self._model_op(context=n_ctx, mode="sampling", n_samples=batch_len) + result[start:end] = self._inverse_transform_coordinates(*(n_latent.T), *b_ctx) + failed_mask[start:end] = ~self._valid_samples(*(n_latent.T), *b_ctx) + if progress_callback is not None: - i = 0 - while i < len(pending_progress): - event = pending_progress[i] - if event.query(): - start = self._batch_size * i - end = min(start + self._batch_size, N) - num_failed = torch.sum(failed_mask[start:end]).item() - amount = (end - start) - num_failed - if amount > 0: - progress_callback(amount) - - i += 1 - else: - time.sleep(0.01) - torch.cuda.synchronize(self._worker_device) + amount = batch_len - torch.sum(failed_mask[start:end]).item() + progress_callback(amount) if torch.any(failed_mask): result[failed_mask] = self.sample_density(*[t[failed_mask] for t in args], progress_callback=progress_callback) @@ -351,74 +227,16 @@ def evaluate_density(self, *args: torch.Tensor, N = args[0].shape[0] result = torch.empty(N, dtype=torch.float32, device="cpu") - if self._is_cuda: - num_batches = (N + self._batch_size - 1) // self._batch_size - pending_progress = [torch.cuda.Event() for _ in range(num_batches)] - result = result.pin_memory() - - def enqueue_eval_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - for i in range(self.source_dim + self.context_dim): - self._eval_inputs[slot_idx][i][:size].copy_(args[i][start_idx:end_idx], non_blocking=True) - - if self._is_cuda and N > 0: - with torch.cuda.stream(self._transfer_stream): - enqueue_eval_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - for i, start in enumerate(range(0, N, self._batch_size)): - curr_idx = i % 2 - next_idx = (i + 1) % 2 + for start in range(0, N, self._batch_size): end = min(start + self._batch_size, N) batch_len = end - start - next_start = start + self._batch_size - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - ctx, src, jac = self._transform_coordinates(*[t[:batch_len] for t in self._eval_inputs[curr_idx]]) - - torch.mul(self._model_op(src, ctx, mode="inference"), jac, out=self._eval_results[curr_idx][:batch_len]) - - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - enqueue_eval_transfer(next_idx, next_start) - - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - - result[start:end].copy_(self._eval_results[curr_idx][:batch_len], non_blocking=True) - - pending_progress[i].record(self._transfer_stream) - else: - ctx, src, jac = self._transform_coordinates(*[t[start:end].to(self._worker_device) for t in args]) - result[start:end] = self._model_op(src, ctx, mode="inference") * jac - - if progress_callback is not None: - progress_callback(batch_len) - - if self._is_cuda: + ctx, src, jac = self._transform_coordinates(*[t[start:end].to(self._worker_device) for t in args]) + result[start:end] = self._model_op(src, ctx, mode="inference") * jac + if progress_callback is not None: - i = 0 - while i < len(pending_progress): - event = pending_progress[i] - if event.query(): - start = self._batch_size * i - end = min(start + self._batch_size, N) - amount = (end - start) - if amount > 0: - progress_callback(amount) - - i += 1 - else: - time.sleep(0.01) - torch.cuda.synchronize(self._worker_device) + progress_callback(batch_len) + return result class RateModel(ABC): @@ -517,9 +335,6 @@ def evaluate_density_task(args: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) sub_context = context[indices, :] sub_source = source[indices, :] - if torch.device(NFWorkerState.worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() - sub_source = sub_source.pin_memory() cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None return NFWorkerState.density_module.evaluate_density(sub_context, sub_source, progress_callback=cb) @@ -528,8 +343,6 @@ def sample_density_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor context, indices = args sub_context = context[indices, :] - if torch.device(NFWorkerState.worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None return NFWorkerState.density_module.sample_density(sub_context, progress_callback=cb) @@ -691,7 +504,7 @@ def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str tasks = [(context, idx) for idx in indices] async_result = self._pool.map_async(sample_density_task, tasks) - with tqdm(total=n_data, disable=(not self.show_progress), desc="Sampling the density", unit="calls", leave=False) as pbar: + with tqdm(total=n_data, disable=(not self.show_progress), desc="Sampling the density", unit="calls", leave=False, smoothing=0.20) as pbar: while not async_result.ready(): try: while True: @@ -708,9 +521,6 @@ def sample_density(self, context: torch.Tensor, devices: Optional[List[Union[str results = async_result.get() return torch.cat(results, dim=0) - - #results = self._pool.map(sample_density_task, tasks) - #return torch.cat(results, dim=0) finally: if temp_pool: @@ -736,7 +546,7 @@ def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: tasks = [(context, source, idx) for idx in indices] async_result = self._pool.map_async(evaluate_density_task, tasks) - with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the density", unit="calls", leave=False) as pbar: + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the density", unit="calls", leave=False, smoothing=0.20) as pbar: while not async_result.ready(): try: while True: @@ -753,9 +563,6 @@ def evaluate_density(self, context: torch.Tensor, source: torch.Tensor, devices: results = async_result.get() return torch.cat(results, dim=0) - - #results = self._pool.map(evaluate_density_task, tasks) - #return torch.cat(results, dim=0) finally: if temp_pool: diff --git a/cosipy/response/NFResponse.py b/cosipy/response/NFResponse.py index a18363773..12514b9ea 100644 --- a/cosipy/response/NFResponse.py +++ b/cosipy/response/NFResponse.py @@ -81,21 +81,18 @@ def init_response_worker(device_queue: mp.Queue, progress_queue: mp.Queue, major density_input, density_batch_size, density_compile_mode, ResponseDensityApproximation) - #NFWorkerState.density_module = ResponseDensityApproximation(major_version, density_input, NFWorkerState.worker_device, density_batch_size, density_compile_mode) NFWorkerState.area_module = AreaApproximation(major_version, area_input, NFWorkerState.worker_device, area_batch_size, area_compile_mode) def evaluate_area_task(args: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: context, indices = args sub_context = context[indices, :] - if torch.device(NFWorkerState.worker_device).type == 'cuda': - sub_context = sub_context.pin_memory() cb = lambda n: NFWorkerState.progress_queue.put(n) if hasattr(NFWorkerState, 'progress_queue') else None return NFWorkerState.area_module.evaluate_effective_area(sub_context, progress_callback=cb) class NFResponse(NFBase): - def __init__(self, path_to_model: Union[str, Path], area_batch_size: int = 100_000, density_batch_size: int = 100_000, + def __init__(self, path_to_model: Union[str, Path], area_batch_size: int = 300_000, density_batch_size: int = 100_000, devices: Optional[List[Union[str, int, torch.device]]] = None, area_compile_mode: CompileMode = "max-autotune-no-cudagraphs", density_compile_mode: CompileMode = "default", show_progress: bool = True): @@ -165,7 +162,7 @@ def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[ tasks = [(context, idx) for idx in indices] async_result = self._pool.map_async(evaluate_area_task, tasks) - with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the effective area", unit="calls", leave=False) as pbar: + with tqdm(total=n_data, disable=(not self.show_progress), desc="Evaluating the effective area", unit="calls", leave=False, smoothing=0.20) as pbar: while not async_result.ready(): try: while True: @@ -182,9 +179,6 @@ def evaluate_effective_area(self, context: torch.Tensor, devices: Optional[List[ results = async_result.get() return torch.cat(results, dim=0) - - #results = self._pool.map(evaluate_area_task, tasks) - #return torch.cat(results, dim=0) finally: if temp_pool: diff --git a/cosipy/response/NFResponseModels.py b/cosipy/response/NFResponseModels.py index 50bd293a5..5626938f3 100644 --- a/cosipy/response/NFResponseModels.py +++ b/cosipy/response/NFResponseModels.py @@ -1,6 +1,5 @@ import numpy as np import healpy as hp -import time from typing import Union, Tuple, Dict, Optional, Callable @@ -10,7 +9,6 @@ if any(find_spec(pkg) is None for pkg in ["torch", "normflows", "sphericart.torch"]): raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") - from .NFBase import CompileMode, build_c_arqs_flow, build_cmlp_diaggaussian_base, NNDensityInferenceWrapper, AreaModel, DensityModel import sphericart.torch import normflows as nf @@ -32,15 +30,6 @@ def _init_model(self, input: Dict): return self._horner_eval - def _write_gpu_tensors(self): - self._gpu_inputs = [ - (torch.empty(self._batch_size, device=self._worker_device), - torch.empty(self._batch_size, device=self._worker_device), - torch.empty(self._batch_size, device=self._worker_device)) - for _ in range(2) - ] - self._gpu_results = [torch.empty(self._batch_size, device=self._worker_device) for _ in range(2)] - @property def context_dim(self) -> int: return 3 @@ -90,95 +79,19 @@ def get_batch(start_idx): dir_az[start_idx:end_idx].to(self._worker_device), dir_polar[start_idx:end_idx].to(self._worker_device) ) - - if self._is_cuda: - num_batches = (N + self._batch_size - 1) // self._batch_size - pending_progress = [torch.cuda.Event() for _ in range(num_batches)] - ei_norm = ei_norm.pin_memory() - result = result.pin_memory() - - def enqueue_transfer(slot_idx, start_idx): - end_idx = min(start_idx + self._batch_size, N) - size = end_idx - start_idx - self._gpu_inputs[slot_idx][0][:size].copy_(ei_norm[start_idx:end_idx], non_blocking=True) - self._gpu_inputs[slot_idx][1][:size].copy_(dir_az[start_idx:end_idx], non_blocking=True) - self._gpu_inputs[slot_idx][2][:size].copy_(dir_polar[start_idx:end_idx], non_blocking=True) - if self._is_cuda and (N > 0): - with torch.cuda.stream(self._transfer_stream): - enqueue_transfer(0, 0) - self._transfer_ready[0].record(self._transfer_stream) - - torch.cuda.set_sync_debug_mode(1) - for i, start in enumerate(range(0, N, self._batch_size)): - print(f"Loop: {i}", flush=True) - curr_idx = i % 2 - next_idx = (i + 1) % 2 + for start in range(0, N, self._batch_size): end = min(start + self._batch_size, N) batch_len = end - start - next_start = start + self._batch_size - if self._is_cuda: - with torch.cuda.stream(self._compute_stream): - self._compute_stream.wait_event(self._transfer_ready[curr_idx]) - - ei_b, az_b, pol_b = [t[:batch_len] for t in self._gpu_inputs[curr_idx]] - t0 = time.time() - poly_b = self._model_op(ei_b) - t1 = time.time() - ylm_b = self._compute_spherical_harmonics(az_b, pol_b) - t2 = time.time() - - torch.sum(poly_b * ylm_b, dim=1, out=self._gpu_results[curr_idx][:batch_len]) - t3 = time.time() - self._compute_ready[curr_idx].record(self._compute_stream) - - if next_start < N: - with torch.cuda.stream(self._transfer_stream): - t4 = time.time() - enqueue_transfer(next_idx, next_start) - t5 = time.time() - self._transfer_ready[next_idx].record(self._transfer_stream) - - with torch.cuda.stream(self._transfer_stream): - self._transfer_stream.wait_event(self._compute_ready[curr_idx]) - t6 = time.time() - result[start:end].copy_(self._gpu_results[curr_idx][:batch_len], non_blocking=True) - t7 = time.time() - pending_progress[i].record(self._transfer_stream) - if i < 5 and self._worker_device == torch.device("cuda:0"): - print(f"Run: {i} on device {self._worker_device}", flush=True) - print(f"model_op: {t1-t0:.5f}s", flush=True) - print(f"spherical_harmonics: {t2-t1:.5f}s", flush=True) - print(f"sum: {t3-t2:.5f}s", flush=True) - print(f"enqueue_transfer: {t5-t4:.5f}s", flush=True) - print(f"result copy: {t7-t6:.5f}s", flush=True) - else: - ei_b, az_b, pol_b = get_batch(start) - - poly_b = self._model_op(ei_b) - ylm_b = self._compute_spherical_harmonics(az_b, pol_b) - result[start:end] = torch.sum(poly_b * ylm_b, dim=1) - - if progress_callback is not None: - progress_callback(batch_len) - print("Done", flush=True) - if self._is_cuda: + ei_b, az_b, pol_b = get_batch(start) + + poly_b = self._model_op(ei_b) + ylm_b = self._compute_spherical_harmonics(az_b, pol_b) + result[start:end] = torch.sum(poly_b * ylm_b, dim=1) + if progress_callback is not None: - i = 0 - while i < len(pending_progress): - event = pending_progress[i] - if event.query(): - start = self._batch_size * i - end = min(start + self._batch_size, N) - amount = (end - start) - if amount > 0: - progress_callback(amount) - - i += 1 - else: - time.sleep(0.01) - torch.cuda.synchronize(self._worker_device) + progress_callback(batch_len) return torch.clamp(result, min=0) @@ -310,8 +223,6 @@ def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, tor ) jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) - for i in range(torch.sum(torch.isinf(jac))): - print("You found a rare event where phi_geo is 0 or 180 degrees!", flush=True) jac[torch.isinf(jac)] = 0.0 ctx = self._transform_context(dir_az, dir_pol, ei) diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py index 93be29fd1..9b7872314 100644 --- a/cosipy/threeml/optimized_unbinned_folding.py +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -1,17 +1,21 @@ import copy import os import json -from typing import Optional, Iterable, Type, Tuple, List +from typing import Optional, Iterable, Type, Tuple, List, Union +from pathlib import Path +from tqdm.auto import tqdm import numpy as np import h5py from astromodels import PointSource from astropy.coordinates import CartesianRepresentation from executing import Source +from scoords import SpacecraftFrame from cosipy import SpacecraftHistory +from cosipy.interfaces.source_response_interface import CachedUnbinnedThreeMLSourceResponseInterface from cosipy.data_io.EmCDSUnbinnedData import EmCDSEventDataInSCFrameFromArrays -from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface, EventInterface +from cosipy.interfaces import EventInterface from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface from cosipy.interfaces.instrument_response_interface import FarFieldSpectralInstrumentResponseFunctionInterface @@ -30,14 +34,16 @@ raise RuntimeError("Install cosipy with [ml] optional package to use this feature.") import torch +from cosipy.response.nf_instrument_response_function import UnpolarizedNFFarFieldInstrumentResponseFunction -class UnbinnedThreeMLPointSourceResponseIRFAdaptive(UnbinnedThreeMLSourceResponseInterface): +class UnbinnedThreeMLPointSourceResponseIRFAdaptive(CachedUnbinnedThreeMLSourceResponseInterface): def __init__(self, data: TimeTagEmCDSEventDataInSCFrameInterface, irf: FarFieldSpectralInstrumentResponseFunctionInterface, - sc_history: SpacecraftHistory,): + sc_history: SpacecraftHistory, + show_progress: bool = True): """ Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features @@ -59,6 +65,7 @@ def __init__(self, self._data = data self._irf = irf self._sc_ori = sc_history + self.show_progress = show_progress # Default parameters for irf energy node placement self._total_energy_nodes = (60, 500) @@ -66,6 +73,7 @@ def __init__(self, self._peak_widths = (0.04, 0.1) self._energy_range = (100., 10_000.) self._batch_size = 1_000_000 + self._offset: Optional[float] = 1e-12 # Placeholder for node pool - stored as Tensors self._width_tensor: Optional[torch.Tensor] = None @@ -90,23 +98,34 @@ def __init__(self, self._exp_events: Optional[float] = None self._exp_density: Optional[torch.Tensor] = None - # Precomputed spacecraft history + # Precomputed spacecraft history - Midpoint self._mid_times = self._sc_ori.obstime[:-1] + (self._sc_ori.obstime[1:] - self._sc_ori.obstime[:-1]) / 2 self._sc_ori_center = self._sc_ori.interp(self._mid_times) + # Precomputed spacecraft history - Simpson + # t_edges = self._sc_ori.obstime + # t_mids = t_edges[:-1] + (t_edges[1:] - t_edges[:-1]) / 2 + # all_t, inv_indices = np.unique(Time([t_edges, t_mids]), return_inverse=True) + # all_t = Time(all_t) + # self._sc_ori_simpson = self._sc_ori.interp(all_t) + # edge_indices = inv_indices[:len(t_edges)] + # mid_indices = inv_indices[len(t_edges):] + # livetime = self._sc_ori.livetime.to_value(u.s) + # self._unique_time_weights = np.zeros(len(all_t), dtype=np.float32) + # np.add.at(self._unique_time_weights, edge_indices[:-1], livetime / 6.0) + # np.add.at(self._unique_time_weights, edge_indices[1:], livetime / 6.0) + # np.add.at(self._unique_time_weights, mid_indices, 4.0 * livetime / 6.0) + data_times = self._data.time self._n_events = self._data.nevents - self._unique_mjds, self._inv_idx = np.unique(data_times.mjd, return_inverse=True) - unique_times_obj = Time(self._unique_mjds, format='mjd') - + self._unique_unix, self._inv_idx = np.unique(data_times.utc.unix, return_inverse=True) + unique_times_obj = Time(self._unique_unix, format='unix', scale='utc') self._sc_ori_unique = self._sc_ori.interp(unique_times_obj) interval_ratios = (self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) - bin_indices = np.searchsorted(self._sc_ori.obstime.mjd, self._unique_mjds) - 1 - + bin_indices = np.searchsorted(self._sc_ori.obstime.utc.unix, self._unique_unix, side="right") - 1 bin_indices = np.clip(bin_indices, 0, len(self._sc_ori.livetime) - 1) unique_ratio = interval_ratios[bin_indices] - self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) self._energy_m_keV = torch.as_tensor(asarray(self._data.energy_keV, dtype=np.float32)) @@ -119,74 +138,105 @@ def __init__(self, self._cos_lon_scatt = torch.cos(self._lon_scatt) self._sin_lon_scatt = torch.sin(self._lon_scatt) - #unique_ratio = np.interp(self._unique_mjds, - # self._mid_times.mjd, - # self._sc_ori.livetime.to_value(u.s) / self._sc_ori.intervals_duration.to_value(u.s)) - # - #self._livetime_ratio = unique_ratio[self._inv_idx].astype(np.float32) - - #wrong_order = np.where(((data_times[1:] - data_times[:-1]) <= 0))[0] - #data_times[wrong_order + 1] = data_times[wrong_order + 1] + 1 - #self._sc_ori_data = self._sc_ori.interp(data_times) - - #ratio = np.interp(self._data.time.mjd, - # self._mid_times.mjd, - # self._sc_ori.livetime.to_value(u.s)/self._sc_ori.intervals_duration.to_value(u.s)) - #self._livetime_ratio = ratio.astype(np.float32) - @property def event_type(self) -> Type[EventInterface]: return TimeTagEmCDSEventInSCFrameInterface + @property + def total_energy_nodes(self) -> Tuple[int, int]: return self._total_energy_nodes + @total_energy_nodes.setter + def total_energy_nodes(self, val): self.set_integration_parameters(total_energy_nodes=val) + + @property + def peak_nodes(self) -> Tuple[int, int]: return self._peak_nodes + @peak_nodes.setter + def peak_nodes(self, val): self.set_integration_parameters(peak_nodes=val) + + @property + def peak_widths(self) -> Tuple[float, float]: return self._peak_widths + @peak_widths.setter + def peak_widths(self, val): self.set_integration_parameters(peak_widths=val) + + @property + def energy_range(self) -> Tuple[float, float]: return self._energy_range + @energy_range.setter + def energy_range(self, val): self.set_integration_parameters(energy_range=val) + + @property + def batch_size(self) -> int: return self._batch_size + @batch_size.setter + def batch_size(self, val): self.set_integration_parameters(batch_size=val) + + @property + def offset(self) -> Optional[float]: return self._offset + @offset.setter + def offset(self, val): self.set_integration_parameters(offset=val) + + @property + def show_progress(self) -> bool: return self._show_progress + + @show_progress.setter + def show_progress(self, value: bool): + if not isinstance(value, bool): + raise ValueError("show_progress must be a boolean") + self._show_progress = value + def set_integration_parameters(self, - total_energy_nodes: Tuple[int, int] = (60, 500), - peak_nodes: Tuple[int, int] = (18, 12), - peak_widths: Tuple[float, float] = (0.04, 0.1), - energy_range: Tuple[float, float] = (100., 10_000.), - batch_size: int = 1_000_000,): - - # Reset caches if parameters change - if (peak_nodes != self._peak_nodes - or - peak_widths != self._peak_widths - or - total_energy_nodes[0] != self._total_energy_nodes[0]): - self._irf_cache = None - self._irf_energy_node_cache = None - self._width_tensor = None - self._nodes_primary = None - self._nodes_secondary = None - self._nodes_bkg_1 = None - self._nodes_bkg_2 = None - self._nodes_bkg_3 = None - - if (total_energy_nodes[1] != self._total_energy_nodes[1]): - self._area_cache = None - self._area_energy_node_cache = None - - if (energy_range != self._energy_range): - self._irf_cache = None - self._irf_energy_node_cache = None - self._area_cache = None - self._area_energy_node_cache = None - - if total_energy_nodes[0] < (peak_nodes[0] + 2 * peak_nodes[1] + 3): - raise ValueError("To many nodes per peak compared to the total number or peaks!") - - if (total_energy_nodes[0] < 1) or (total_energy_nodes[1] < 1): + total_energy_nodes: Optional[Tuple[int, int]] = None, + peak_nodes: Optional[Tuple[int, int]] = None, + peak_widths: Optional[Tuple[float, float]] = None, + energy_range: Optional[Tuple[float, float]] = None, + batch_size: Optional[int] = None, + offset: Optional[float] = -1.0): + + new_total = total_energy_nodes or self._total_energy_nodes + new_peak_nodes = peak_nodes or self._peak_nodes + new_peak_widths = peak_widths or self._peak_widths + new_range = energy_range or self._energy_range + new_batch = batch_size if batch_size is not None else self._batch_size + new_offset = offset if offset != -1.0 else self._offset + + irf_affected = ( + new_peak_nodes != self._peak_nodes or + new_peak_widths != self._peak_widths or + new_total[0] != self._total_energy_nodes[0] or + new_range != self._energy_range + ) + + area_affected = ( + new_total[1] != self._total_energy_nodes[1] or + new_range != self._energy_range + ) + + if irf_affected: + self._irf_cache = self._irf_energy_node_cache = self._width_tensor = None + self._nodes_primary = self._nodes_secondary = None + self._nodes_bkg_1 = self._nodes_bkg_2 = self._nodes_bkg_3 = None + + if area_affected: + self._area_cache = self._area_energy_node_cache = None + + if new_total[0] < (new_peak_nodes[0] + 2 * new_peak_nodes[1] + 3): + raise ValueError("Too many nodes per peak compared to the total number or peaks!") + + if any(n < 1 for n in new_total): raise ValueError("The number of energy nodes must be at least 1.") - if energy_range[0] >= energy_range[1]: + if new_range[0] >= new_range[1]: raise ValueError("The initial energy interval needs to be increasing!") - if (batch_size < total_energy_nodes[0]) or (batch_size < total_energy_nodes[1]): + if new_batch < max(new_total): raise ValueError("The batch size cannot be smaller than the number of integration nodes.") + + if (new_offset is not None) and (new_offset < 0): + raise ValueError("The offset cannot be negative.") - self._total_energy_nodes = total_energy_nodes - self._peak_nodes = peak_nodes - self._peak_widths = peak_widths - self._energy_range = energy_range - self._batch_size = batch_size + self._total_energy_nodes = new_total + self._peak_nodes = new_peak_nodes + self._peak_widths = new_peak_widths + self._energy_range = new_range + self._batch_size = new_batch + self._offset = new_offset @staticmethod def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -285,7 +335,7 @@ def set_source(self, source: Source): self._source = source - def copy(self) -> UnbinnedThreeMLSourceResponseInterface: + def copy(self) -> CachedUnbinnedThreeMLSourceResponseInterface: new_instance = copy.copy(self) new_instance.clear_cache() new_instance._source = None @@ -294,11 +344,18 @@ def copy(self) -> UnbinnedThreeMLSourceResponseInterface: @staticmethod def _earth_occ(source_coord: SkyCoord, ori: SpacecraftHistory) -> np.ndarray: - gcrs_cart = ori.location.represent_as(CartesianRepresentation) - dist_earth_center = gcrs_cart.norm() - max_angle = np.pi*u.rad - np.arcsin(c.R_earth/dist_earth_center) + dist_earth_center = ori.location.spherical.distance.km + max_angle = np.pi - np.arcsin(c.R_earth.to(u.km).value/dist_earth_center) src_angle = source_coord.separation(ori.earth_zenith) - return (src_angle < max_angle).astype(np.float32) + return (src_angle.to(u.rad).value < max_angle).astype(np.float32) + + @staticmethod + def _get_target_in_sc_frame(source_coord: SkyCoord, ori: SpacecraftHistory) -> SkyCoord: + src_in_sc_frame = SkyCoord(np.dot(ori.attitude.rot.inv().as_matrix(), source_coord.transform_to(ori.attitude.frame).cartesian.xyz.value), + representation_type = 'cartesian', frame = SpacecraftFrame()) + + src_in_sc_frame.representation_type = 'spherical' + return src_in_sc_frame def _compute_area(self): coord = self._source.position.sky_coord @@ -316,10 +373,17 @@ def _compute_area(self): e_w = (np.log(10) * self._area_energy_node_cache * (w * scale)).astype(np.float32).reshape(1, -1) e_n = self._area_energy_node_cache.astype(np.float32) - sc_coord_sph = self._sc_ori_center.get_target_in_sc_frame(coord) + # Midpoint + sc_coord_sph = self._get_target_in_sc_frame(coord, self._sc_ori_center) earth_occ_index = self._earth_occ(coord, self._sc_ori_center) - - time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index + + combined_time_weights = (self._sc_ori.livetime.to_value(u.s)).astype(np.float32) * earth_occ_index + + # Simpson + # sc_coord_sph = self._sc_ori_simpson.get_target_in_sc_frame(coord) + # earth_occ_index = self._earth_occ(coord, self._sc_ori_simpson) + + # combined_time_weights = (self._unique_time_weights * earth_occ_index).astype(np.float32) lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) @@ -333,15 +397,16 @@ def _compute_area(self): batch_lons_buffer = np.empty(max_batch_total, dtype=np.float32) batch_lats_buffer = np.empty(max_batch_total, dtype=np.float32) batch_energies_buffer = np.empty(max_batch_total, dtype=np.float32) - - for i in range(0, n_time, batch_size_time): + + for i in tqdm(range(0, n_time, batch_size_time), + disable=(not self.show_progress), + desc="Caching the effective area", + smoothing=0.2, + leave=False): start = i end = min(i + batch_size_time, n_time) current_n_time = end - start current_total = current_n_time * n_energy - - #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lons_buffer[:current_total]) - #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lats_buffer[:current_total]) batch_lons_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] batch_lats_buffer[:current_total].reshape(current_n_time, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] @@ -358,7 +423,7 @@ def _compute_area(self): total_area += np.einsum('ij,i,j->j', eff_areas_grid, - time_weights[start:end], + combined_time_weights[start:end], e_w.ravel()) self._area_cache = total_area @@ -382,15 +447,12 @@ def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, n_res, w_res = self._scale_nodes_center(E1, E2, EC, *self._nodes_primary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_1[0].shape[1] n_res, w_res = self._scale_nodes_exp(E2, Emax, *self._nodes_bkg_1) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, Emax, *self._nodes_bkg_1, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) elif mode == 2: center_peak = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 @@ -410,29 +472,24 @@ def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_2[0][0].shape[1] n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_2[0]) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_2[0], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_secondary[0].shape[1] n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_2[1][0].shape[1] n_res, w_res = self._scale_nodes_exp(E4, Emax, *self._nodes_bkg_2[1]) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E4, Emax, *self._nodes_bkg_2[1], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) elif mode == 3: center_peak_1 = (sorted_peaks[:, 0] + sorted_peaks[:, 1]) / 2 @@ -454,43 +511,37 @@ def _fill_nodes(self, nodes_out: torch.Tensor, weights_out: torch.Tensor, n_res, w_res = self._scale_nodes_center(E1, E2, EC1, *self._nodes_primary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E1, E2, EC1, *self._nodes_primary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_3[0][0].shape[1] n_res, w_res = self._scale_nodes_exp(E2, E3, *self._nodes_bkg_3[0]) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E2, E3, *self._nodes_bkg_3[0], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_secondary[0].shape[1] n_res, w_res = self._scale_nodes_center(E3, E4, EC2, *self._nodes_secondary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E3, E4, EC2, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_3[1][0].shape[1] n_res, w_res = self._scale_nodes_exp(E4, E5, *self._nodes_bkg_3[1]) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E4, E5, *self._nodes_bkg_3[1], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_secondary[0].shape[1] n_res, w_res = self._scale_nodes_center(E5, E6, EC3, *self._nodes_secondary) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_center_inplace(E5, E6, EC3, *self._nodes_secondary, - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + c += w w = self._nodes_bkg_3[2][0].shape[1] n_res, w_res = self._scale_nodes_exp(E6, Emax, *self._nodes_bkg_3[2]) nodes_out[indices, c:c+w] = n_res weights_out[indices, c:c+w] = w_res - #self._scale_nodes_exp_inplace(E6, Emax, *self._nodes_bkg_3[2], - # nodes_out[indices, c:c+w], weights_out[indices, c:c+w]) + def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, phi_geo_rad: torch.Tensor, phi_igeo_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -545,7 +596,7 @@ def _get_nodes(self, energy_m_keV: torch.Tensor, phi_rad: torch.Tensor, return nodes, weights - def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: cos_lat_src = torch.cos(lat_src_rad) sin_lat_src = torch.sin(lat_src_rad) cos_lon_src = torch.cos(lon_src_rad) @@ -584,8 +635,12 @@ def _compute_density(self): batch_phi_buffer = np.empty(buffer_size, dtype=np.float32) batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) - - for i in range(0, self._n_events, batch_size_events): + + for i in tqdm(range(0, self._n_events, batch_size_events), + disable=(not self.show_progress), + desc="Caching the response", + smoothing=0.2, + leave=False): start = i end = min(i + batch_size_events, self._n_events) current_n = end - start @@ -600,9 +655,6 @@ def _compute_density(self): if batch_size_events >= self._n_events: self._irf_energy_node_cache = np.asarray(nodes) - - #np.repeat(lon_ph_rad[start:end], n_energy, out=batch_lon_src_buffer[:current_total]) - #np.repeat(lat_ph_rad[start:end], n_energy, out=batch_lat_src_buffer[:current_total]) batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] @@ -612,16 +664,12 @@ def _compute_density(self): batch_lat_scatt_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._lat_scatt[start:end, np.newaxis]) batch_phi_buffer[:current_total].reshape(current_n, n_energy)[:] = np.asarray(self._phi_rad[start:end, np.newaxis]) - #np.repeat(np.asarray(self._energy_m_keV[start:end]), n_energy, out=batch_energy_buffer[:current_total]) - #np.repeat(np.asarray(self._lon_scatt[start:end]), n_energy, out=batch_lon_scatt_buffer[:current_total]) - #np.repeat(np.asarray(self._lat_scatt[start:end]), n_energy, out=batch_lat_scatt_buffer[:current_total]) - #np.repeat(np.asarray(self._phi_rad[start:end]), n_energy, out=batch_phi_buffer[:current_total]) - photons = PhotonListWithDirectionAndEnergyInSCFrame( batch_lon_src_buffer[:current_total], batch_lat_src_buffer[:current_total], np.asarray(nodes).ravel() ) + events = EmCDSEventDataInSCFrameFromArrays( batch_energy_buffer[:current_total], batch_lon_scatt_buffer[:current_total], @@ -649,8 +697,7 @@ def _update_cache(self): source_coord = self._source.position.sky_coord if (self._sc_coord_sph_cache is None) or (source_coord != self._last_convolved_source_skycoord): - #self._sc_coord_sph_cache = self._sc_ori_data.get_target_in_sc_frame(source_coord) - self._sc_coord_sph_cache = self._sc_ori_unique.get_target_in_sc_frame(source_coord)[self._inv_idx] + self._sc_coord_sph_cache = self._get_target_in_sc_frame(source_coord, self._sc_ori_unique)[self._inv_idx] no_recalculation = ((source_coord == self._last_convolved_source_skycoord) and @@ -669,6 +716,12 @@ def _update_cache(self): if no_recalculation: return else: + active_pool = True + if isinstance(self._irf, UnpolarizedNFFarFieldInstrumentResponseFunction): + active_pool = self._irf.active_pool + if not active_pool: + self._irf.init_compute_pool() + if area_recalculation: self._compute_area() @@ -676,16 +729,22 @@ def _update_cache(self): self._init_node_pool() self._compute_density() + if not active_pool: + self._irf.shutdown_compute_pool() + self._last_convolved_source_skycoord = source_coord.copy() - def cache_to_file(self, filename: str): - with h5py.File(filename, 'w') as f: + def cache_to_file(self, filename: Union[str, Path]): + with h5py.File(str(filename), 'w') as f: f.attrs['total_energy_nodes'] = self._total_energy_nodes f.attrs['peak_nodes'] = self._peak_nodes f.attrs['peak_widths'] = self._peak_widths f.attrs['energy_range'] = self._energy_range f.attrs['batch_size'] = self._batch_size + if self._offset is not None: + f.attrs['offset'] = self._offset + if self._irf_cache is not None: f.create_dataset('irf_cache', data=self._irf_cache.numpy(), compression='gzip', compression_opts=4) @@ -722,18 +781,25 @@ def cache_to_file(self, filename: str): f.attrs['last_convolved_lon_deg'] = sc.spherical.lon.deg f.attrs['last_convolved_lat_deg'] = sc.spherical.lat.deg f.attrs['last_convolved_frame'] = sc.frame.name + if hasattr(sc, 'equinox'): + f.attrs['last_convolved_equinox'] = sc.equinox.value - def cache_from_file(self, filename: str): - if not os.path.exists(filename): - raise FileNotFoundError(f"Cache file {filename} not found.") + def cache_from_file(self, filename: Union[str, Path]): + if not os.path.exists(str(filename)): + raise FileNotFoundError(f"Cache file {str(filename)} not found.") - with h5py.File(filename, 'r') as f: + with h5py.File(str(filename), 'r') as f: self._total_energy_nodes = tuple(f.attrs['total_energy_nodes']) self._peak_nodes = tuple(f.attrs['peak_nodes']) self._peak_widths = tuple(f.attrs['peak_widths']) self._energy_range = tuple(f.attrs['energy_range']) self._batch_size = int(f.attrs['batch_size']) + if 'offset' in f.attrs: + self._offset = f.attrs['offset'] + else: + self._offset = None + if 'irf_cache' in f: self._irf_cache = torch.from_numpy(f['irf_cache'][:]) else: @@ -778,7 +844,9 @@ def cache_from_file(self, filename: str): lon = f.attrs['last_convolved_lon_deg'] lat = f.attrs['last_convolved_lat_deg'] frame = f.attrs['last_convolved_frame'] - self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame) + equinox = f.attrs.get('last_convolved_equinox', None) + + self._last_convolved_source_skycoord = SkyCoord(lon, lat, unit='deg', frame=frame, equinox=equinox) else: self._last_convolved_source_skycoord = None @@ -807,7 +875,6 @@ def expectation_density(self) -> Iterable[float]: self._update_cache() source_dict = self._source.to_dict() - if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) @@ -843,7 +910,9 @@ def expectation_density(self) -> Iterable[float]: self._last_convolved_source_dict_density = source_dict - #print(self._data.time.unix[self._exp_density <= 0][:100]) - #print(np.sum(self._exp_density <= 0)/self._n_events * 100) - #print(self.expected_counts() - np.sum(np.log(self._exp_density+1e-12))) - return np.asarray(self._exp_density, dtype=np.float64)+1e-12 \ No newline at end of file + result = np.asarray(self._exp_density, dtype=np.float64) + + if self._offset is not None: + return result + self._offset + else: + return result \ No newline at end of file diff --git a/cosipy/threeml/unbinned_model_folding.py b/cosipy/threeml/unbinned_model_folding.py index a843565f8..912db7f46 100644 --- a/cosipy/threeml/unbinned_model_folding.py +++ b/cosipy/threeml/unbinned_model_folding.py @@ -1,11 +1,14 @@ import itertools -from typing import Optional, Iterable +from typing import Optional, Iterable, Union import numpy as np from astromodels import Model, PointSource, ExtendedSource +from pathlib import Path from cosipy.interfaces import UnbinnedThreeMLModelFoldingInterface, UnbinnedThreeMLSourceResponseInterface +from cosipy.interfaces.source_response_interface import CachedUnbinnedThreeMLSourceResponseInterface from cosipy.response.threeml_response import ThreeMLModelFoldingCacheSourceResponsesMixin +from cosipy.util.iterables import asarray class UnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFoldingInterface, ThreeMLModelFoldingCacheSourceResponsesMixin): @@ -62,3 +65,68 @@ def expectation_density(self) -> Iterable[float]: self._cache_source_responses() return [sum(expectations) for expectations in zip(*(s.expectation_density() for s in self._source_responses.values()))] + + +class CachedUnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFolding): + def __init__(self, + point_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, + extended_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, + vectorize:bool = True): + + super().__init__(point_source_response=point_source_response, + extended_source_response=extended_source_response) + + self._base_filename = "_source_response_cache.h5" + self._vectorize = vectorize + + def init_cache(self): + """ + Forces the creation of response objects for each source in the model. + """ + self._cache_source_responses() + + for response in self._source_responses.values(): + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + response.init_cache() + + def save_caches(self, directory: Union[str, Path], cache_only: Optional[Iterable[str]] = None): + """Saves only the responses that implement the cache interface.""" + self.init_cache() + + dir_path = Path(directory) + dir_path.mkdir(parents=True, exist_ok=True) + + for name, response in self._source_responses.items(): + if (cache_only is not None) and (name not in set(cache_only)): + continue + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + filepath = dir_path / f"{name}{self._base_filename}" + response.cache_to_file(filepath) + + def load_caches(self, directory: Union[str, Path], load_only: Optional[Iterable[str]] = None): + """Loads available cache files into compatible response objects.""" + self._cache_source_responses() + + dir_path = Path(directory) + for name, response in self._source_responses.items(): + if (load_only is not None) and (name not in set(load_only)): + continue + if isinstance(response, CachedUnbinnedThreeMLSourceResponseInterface): + filepath = dir_path / f"{name}{self._base_filename}" + if filepath.exists(): + response.cache_from_file(filepath) + + def _expectation_density_gen(self) -> Iterable[float]: + for exdensity in zip(*[ex.expectation_density() for ex in self._source_responses.values()]): + yield sum(exdensity) + + def expectation_density(self) -> Iterable[float]: + self._cache_source_responses() + if self._vectorize: + if not self._source_responses: + return np.array([], dtype=np.float64) + else: + densities = [asarray(ex.expectation_density(), np.float64) for ex in self._source_responses.values()] + return np.add.reduce(densities) + else: + return self._expectation_density_gen() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ce56e513c..93beb69fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ [project.optional-dependencies] # Machine learning stuff -ML = ["torch", "torch_geometric"] +ML = ["torch", "torch_geometric", "normflows", "sphericart[torch]"] [project.urls] Homepage = "https://github.com/cositools/cosipy" From 42cc1d0405fbaa6a4ae3b422859c49f8ad72cd5c Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Wed, 11 Mar 2026 14:04:34 +0100 Subject: [PATCH 11/16] Fixed small bug where due to limited precision the jacobian could be negative and not just infinite --- cosipy/response/NFResponseModels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosipy/response/NFResponseModels.py b/cosipy/response/NFResponseModels.py index 5626938f3..1162f1f5c 100644 --- a/cosipy/response/NFResponseModels.py +++ b/cosipy/response/NFResponseModels.py @@ -223,7 +223,7 @@ def _transform_coordinates(self, *args: torch.Tensor) -> Tuple[torch.Tensor, tor ) jac = 1.0 / (ei * torch.sin(theta_raw + phi) * 4 * np.pi**3) - jac[torch.isinf(jac)] = 0.0 + jac[torch.isinf(jac) | (jac < 0)] = 0.0 ctx = self._transform_context(dir_az, dir_pol, ei) From 1ad7910bcd16cc39b5d95ebe432f7288ba481af1 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 12 Mar 2026 14:46:11 +0100 Subject: [PATCH 12/16] Bug fix. Source should receive flat array. --- cosipy/threeml/optimized_unbinned_folding.py | 65 +++++++++++++++----- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py index 9b7872314..d92c03b3e 100644 --- a/cosipy/threeml/optimized_unbinned_folding.py +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -43,19 +43,14 @@ def __init__(self, data: TimeTagEmCDSEventDataInSCFrameInterface, irf: FarFieldSpectralInstrumentResponseFunctionInterface, sc_history: SpacecraftHistory, - show_progress: bool = True): + show_progress: bool = True, + force_energy_node_caching: bool = False): """ Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features Note that this assumes a smooth flux spectrum All IRF queries are cached and can be saved to / loaded from a file - - Parameters - ---------- - data - irf - sc_history """ # Interface inputs @@ -66,6 +61,7 @@ def __init__(self, self._irf = irf self._sc_ori = sc_history self.show_progress = show_progress + self.force_energy_node_caching = force_energy_node_caching # Default parameters for irf energy node placement self._total_energy_nodes = (60, 500) @@ -142,6 +138,14 @@ def __init__(self, def event_type(self) -> Type[EventInterface]: return TimeTagEmCDSEventInSCFrameInterface + @property + def force_energy_node_caching(self) -> bool: return self._force_energy_node_caching + @force_energy_node_caching.setter + def force_energy_node_caching(self, val): + if not isinstance(val, bool): + raise ValueError("force_energy_node_caching must be a boolean") + self._force_energy_node_caching = val + @property def total_energy_nodes(self) -> Tuple[int, int]: return self._total_energy_nodes @total_energy_nodes.setter @@ -174,12 +178,11 @@ def offset(self, val): self.set_integration_parameters(offset=val) @property def show_progress(self) -> bool: return self._show_progress - @show_progress.setter - def show_progress(self, value: bool): - if not isinstance(value, bool): + def show_progress(self, val: bool): + if not isinstance(val, bool): raise ValueError("show_progress must be a boolean") - self._show_progress = value + self._show_progress = val def set_integration_parameters(self, total_energy_nodes: Optional[Tuple[int, int]] = None, @@ -613,6 +616,15 @@ def _get_CDS_coordinates(self, lon_src_rad: torch.Tensor, lat_src_rad: torch.Ten return phi_geo_rad, np.pi - phi_geo_rad + def _compute_nodes(self): + sc_coord_sph = self._sc_coord_sph_cache + + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + self._irf_energy_node_cache = np.asarray(self._get_nodes(self._energy_m_keV, self._phi_rad, phi_geo_rad, phi_igeo_rad)[0]) + def _compute_density(self): coord = self._source.position.sky_coord sc_coord_sph = self._sc_coord_sph_cache @@ -636,6 +648,9 @@ def _compute_density(self): batch_lon_scatt_buffer = np.empty(buffer_size, dtype=np.float32) batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) + if (batch_size_events < self._n_events) & (self._force_energy_node_caching): + self._irf_energy_node_cache = np.empty((self._n_events, n_energy), dtype=np.float32) + for i in tqdm(range(0, self._n_events, batch_size_events), disable=(not self.show_progress), desc="Caching the response", @@ -654,7 +669,9 @@ def _compute_density(self): nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) if batch_size_events >= self._n_events: - self._irf_energy_node_cache = np.asarray(nodes) + self._irf_energy_node_cache = np.asarray(nodes) + else: + self._irf_energy_node_cache[start:end] = np.asarray(nodes) batch_lon_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lon_ph_rad[start:end, np.newaxis] batch_lat_src_buffer[:current_total].reshape(current_n, n_energy)[:] = lat_ph_rad[start:end, np.newaxis] @@ -722,6 +739,9 @@ def _update_cache(self): if not active_pool: self._irf.init_compute_pool() + if source_coord != self._last_convolved_source_skycoord: + self._irf_energy_node_cache = None + if area_recalculation: self._compute_area() @@ -733,6 +753,13 @@ def _update_cache(self): self._irf.shutdown_compute_pool() self._last_convolved_source_skycoord = source_coord.copy() + + node_caching = (self.force_energy_node_caching + and + self._irf_energy_node_cache is None) + + if node_caching: + self._compute_nodes() def cache_to_file(self, filename: Union[str, Path]): with h5py.File(str(filename), 'w') as f: @@ -741,6 +768,8 @@ def cache_to_file(self, filename: Union[str, Path]): f.attrs['peak_widths'] = self._peak_widths f.attrs['energy_range'] = self._energy_range f.attrs['batch_size'] = self._batch_size + f.attrs['show_progress'] = self._show_progress + f.attrs['force_energy_node_caching'] = self._force_energy_node_caching if self._offset is not None: f.attrs['offset'] = self._offset @@ -794,6 +823,8 @@ def cache_from_file(self, filename: Union[str, Path]): self._peak_widths = tuple(f.attrs['peak_widths']) self._energy_range = tuple(f.attrs['energy_range']) self._batch_size = int(f.attrs['batch_size']) + self._show_progress = bool(f.attrs['show_progress']) + self._force_energy_node_caching = bool(f.attrs['force_energy_node_caching']) if 'offset' in f.attrs: self._offset = f.attrs['offset'] @@ -879,7 +910,10 @@ def expectation_density(self) -> Iterable[float]: self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) if self._irf_energy_node_cache is not None: - flux = torch.as_tensor(self._source(self._irf_energy_node_cache), dtype=self._irf_cache.dtype) + flux = torch.as_tensor( + self._source(self._irf_energy_node_cache.ravel()), + dtype=self._irf_cache.dtype + ).view(self._irf_energy_node_cache.shape) torch.linalg.vecdot(self._irf_cache, flux, dim=1, out=self._exp_density) @@ -904,7 +938,10 @@ def expectation_density(self) -> Iterable[float]: nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) - flux_batch = torch.as_tensor(self._source(np.asarray(nodes)), dtype=self._irf_cache.dtype) + flux_batch = torch.as_tensor( + self._source(np.asarray(nodes).ravel()), + dtype=self._irf_cache.dtype + ).view(nodes.shape) torch.linalg.vecdot(self._irf_cache[i:end], flux_batch, dim=1, out=self._exp_density[i:end]) From afb088e90dc4de35830b4c8a69e8f6764ea64a4f Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Fri, 13 Mar 2026 14:15:50 +0100 Subject: [PATCH 13/16] Added crab tutorial with normalizing flows --- cosipy/background_estimation/NFBackground.py | 4 +- .../example_crab_fit_normalizing_flows.ipynb | 503 ++++++++++++++++++ 2 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb diff --git a/cosipy/background_estimation/NFBackground.py b/cosipy/background_estimation/NFBackground.py index a0477e9c4..5187e8de3 100644 --- a/cosipy/background_estimation/NFBackground.py +++ b/cosipy/background_estimation/NFBackground.py @@ -67,9 +67,9 @@ def init_background_worker(device_queue: mp.Queue, progress_queue: mp.Queue, maj density_compile_mode, BackgroundDensityApproximation) class NFBackground(NFBase): - def __init__(self, path_to_model: Union[str, Path], + def __init__(self, path_to_model: Union[str, Path], density_batch_size: int = 100_000, devices: Optional[List[Union[str, int, torch.device]]] = None, - density_batch_size: int = 100_000, density_compile_mode: CompileMode = "default", show_progress: bool = True): + density_compile_mode: CompileMode = "default", show_progress: bool = True): super().__init__(path_to_model, update_density_worker_settings, init_background_worker, density_batch_size, devices, density_compile_mode, ['rate_input'], show_progress) diff --git a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb new file mode 100644 index 000000000..f66265170 --- /dev/null +++ b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb @@ -0,0 +1,503 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8e42b9b4-766d-4818-b5db-f3d6c56a9563", + "metadata": { + "deletable": true, + "editable": true, + "frozen": false + }, + "source": [ + "# Fitting the Crab Spectrum with the Neural-Network Response and Background Approximation" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c02898da-7753-46b6-b291-612cb953dc36", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "This notebook provides an overview of the different new classes introduced to allow for a truly unbinned spectral analysis of continuum point sources. The user will mainly interact with:\n", + "\n", + "1. `NFResponse` and `UnpolarizedNFFarFieldInstrumentResponseFunction`: Provide the **C-A-RQS + Spherical Harmonics Expansion** approximation of the response.\n", + "2. `NFBackground` and `FreeNormNFUnbinnedBackground`: Provide the **C-A-RQS + Analytical Rate Model** approximation of the simulated background (currently total DC4).\n", + "3. `UnbinnedThreeMLPointSourceResponseIRFAdaptive`: Perform the folding of the response with the flux model in a more efficient way.\n", + "4. `CachedUnbinnedThreeMLModelFolding`: Adds the capability to save and load the cache of `UnbinnedThreeMLPointSourceResponseIRFAdaptive`.\n", + "\n", + "### Inner Workings\n", + "\n", + "For a comprehensive description of the underlying architecture or more detailed comparison plots and benchmarks, please for now refer to my (Pascal J.) thesis, \"Development of an Efficient Response Description for the COSI MeV 𝜸-Ray Telescope.\" Here, I will just provide a very basic explanation of the code structure.\n", + "\n", + "#### Flow-Based Models\n", + "\n", + "The building blocks managing the approximations are `NFResponse` and `NFBackground` (both of type `NFBase`). During setup, both require a checkpoint file (e.g., `unpolarized_nfresponse_v1-00.pt`) which contains all the necessary information, such as the neural network weights and hyperparameters, the version number to choose the correct model, or the coefficients for the effective area or background rate model.\n", + "\n", + "Another important task is the creation of compute pools. All PyTorch-related inference tasks are managed by a separate process for each chosen device. These processes need to be started or stopped, as shown in the example below.\n", + "\n", + "All queries, such as evaluating the effective area or density as well as sampling, are passed through an `Approximation` object (which chooses the correct model and prepares the input) to a `Model` object (which handles the actual inference on the PyTorch device). For the response $R = A_\\mathrm{eff}(\\nu\\lambda, E_i) \\cdot P(E_m, \\phi, \\psi\\chi|\\nu\\lambda, E_i)$, this includes:\n", + "\n", + "1. The effective area $A_\\mathrm{eff}$, modeled with a spherical harmonics expansion. The latter are evaluated using the PyTorch backend of the `sphericart` library.\n", + "2. The probability density $P(E_m, \\phi, \\psi\\chi|\\nu\\lambda, E_i)$, modeled with conditional-autoregressive-rational quadratic spline flows. The library used is called `normflows`.\n", + "\n", + "The background model $B = R(t) \\cdot P(E_m, \\phi, \\psi\\chi|t)$ follows a very similar structure and therefore uses most of the same code. While the rate $R(t)$ is always computed on the CPU, the density $P$ is also modeled using `normflows`.\n", + "\n", + "#### Folding\n", + "\n", + "For the unbinned analysis, we need to calculate the total number of events we expect, $N$, and the expectation density, $\\mathrm{d}N/\\mathrm{d}t\\mathrm{d}E_m\\mathrm{d}\\phi\\mathrm{d}\\psi\\chi$. The latter is especially difficult to compute, since it requires folding the response with the flux model and therefore computing a numerical integral with enough precision for every event in the analysis. The background model $B$, on the other hand, is already the expectation density and requires no integration.\n", + "\n", + "The folding can be performed using `UnbinnedThreeMLPointSourceResponseTrapz`. However, since $R$ has a low inference rate, `UnbinnedThreeMLPointSourceResponseIRFAdaptive` provides an optimized implementation. It uses the fact that many flux models are \"well-behaved\" (low order in a Taylor series). This means one can significantly reduce the integration error by concentrating most Gauss-Legendre integration nodes at peaks of the response and distributing the remaining ones in between. \n", + "\n", + "The exact parameters can be tuned by the user, but it probably won't be able to account for very complex spectral shapes. For each integration node, the response is evaluated once and cached, which takes most of the time. This way, during optimization, only the flux model changes and the fitting is fast.\n", + "\n", + "Using `CachedUnbinnedThreeMLModelFolding` instead of `UnbinnedThreeMLModelFolding` adds the capability to save this cache and other parameters of `UnbinnedThreeMLPointSourceResponseIRFAdaptive`. These files can be shared with other scientists and loaded during another session to skip the expensive cache initialization.\n", + "\n", + "### Hardware Requirements\n", + "\n", + "It is highly recommended to use a GPU, preferably by NVIDIA, for the inference. A CPU, even something like an AMD Threadripper, will only allow you to analyze a few orbits in a reasonable time. Also, note that the unbinned analysis requires you to have enough RAM, which increases as the number of events you include in the analysis increases. The batch size can be decreased to limit the maximum consumption." + ] + }, + { + "cell_type": "markdown", + "id": "35a07884-ca61-461f-bcde-96ec572d7f9d", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "markdown", + "id": "87d50f2d-6d5a-48b7-9fe5-1c63014c80e0", + "metadata": {}, + "source": [ + "### Basic Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d2076d1-f84b-4971-b4e6-611844dcdef6", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from cosipy.util import fetch_wasabi_file\n", + "from cosipy.spacecraftfile import SpacecraftHistory\n", + "from astropy.time import Time\n", + "\n", + "import astropy.units as u\n", + "\n", + "from threeML import Band, PointSource, Model, JointLikelihood, DataList\n", + "from astromodels import Parameter\n", + "\n", + "from cosipy.threeml.unbinned_model_folding import CachedUnbinnedThreeMLModelFolding\n", + "from cosipy.statistics import UnbinnedLikelihood\n", + "from cosipy.interfaces import ThreeMLPluginInterface\n", + "from cosipy.interfaces.expectation_interface import SumExpectationDensity\n", + "\n", + "from cosipy.event_selection.time_selection import TimeSelector\n", + "from cosipy.data_io.EmCDSUnbinnedData import TimeTagEmCDSEventDataInSCFrameFromDC3Fits\n", + "\n", + "from cosipy.response.NFResponse import NFResponse\n", + "from cosipy.response.nf_instrument_response_function import UnpolarizedNFFarFieldInstrumentResponseFunction\n", + "\n", + "from cosipy.background_estimation.NFBackground import NFBackground\n", + "from cosipy.background_estimation.nf_unbinned_background import FreeNormNFUnbinnedBackground\n", + "\n", + "from cosipy.threeml.optimized_unbinned_folding import UnbinnedThreeMLPointSourceResponseIRFAdaptive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ec4a3dc-5b5e-4587-ba36-fc5e024e0682", + "metadata": {}, + "outputs": [], + "source": [ + "data_path = Path(\"./data_files/\")\n", + "\n", + "crab_data_path = data_path / \"crab_standard_3months_unbinned_data_filtered_with_SAAcut.fits.gz\"\n", + "fetch_wasabi_file('COSI-SMEX/DC3/Data/Sources/crab_standard_3months_unbinned_data_filtered_with_SAAcut.fits.gz', \n", + " output=str(crab_data_path))\n", + "\n", + "sc_orientation_path = data_path / \"DC3_final_530km_3_month_with_slew_15sbins_GalacticEarth_SAA.ori\"\n", + "fetch_wasabi_file('COSI-SMEX/DC3/Data/Orientation/DC3_final_530km_3_month_with_slew_15sbins_GalacticEarth_SAA.ori', \n", + " output=str(sc_orientation_path))\n", + "\n", + "bkg_data_path = data_path / \"Total_DC4_BG_3months_unbinned_data_filtered_with_SAAcut_withSAAbck.fits.gz\"\n", + "fetch_wasabi_file('COSI-SMEX/DC4/Data/Backgrounds/Total_DC4_BG_3months_unbinned_data_filtered_with_SAAcut_withSAAbck.fits.gz', \n", + " output=str(bkg_data_path))\n", + "\n", + "rsp_path = data_path / \"unpolarized_nfresponse_v1-00.pt\"\n", + "bkg_path = data_path / \"nfbackground_v1-01.pt\"" + ] + }, + { + "cell_type": "markdown", + "id": "68bdf301-5315-48b7-af07-d340c52e8780", + "metadata": {}, + "source": [ + "For this tutorial we only use one day of data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b9f01c9-8bd8-4e69-b964-71ef99032452", + "metadata": {}, + "outputs": [], + "source": [ + "tstart = Time(\"2028-03-15 00:00:00.000\")\n", + "tstop = Time(\"2028-03-15 02:00:00.000\")\n", + "sc_orientation = SpacecraftHistory.open(sc_orientation_path)\n", + "sc_orientation = sc_orientation.select_interval(tstart, tstop)" + ] + }, + { + "cell_type": "markdown", + "id": "cd8dbcc7-016d-411a-b6de-f0958aa6ae5c", + "metadata": {}, + "source": [ + "The time selection is very slow (see issue https://github.com/cositools/cosipy/issues/504). Consider using your own implementation for now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a6b82a7-9c78-4315-97bd-50d83ed09f42", + "metadata": {}, + "outputs": [], + "source": [ + "data_file = [crab_data_path, bkg_data_path]\n", + "selector = TimeSelector(tstart = sc_orientation.tstart, tstop = sc_orientation.tstop)\n", + "data = TimeTagEmCDSEventDataInSCFrameFromDC3Fits(data_file, selection=selector)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b974b15-e507-4ab9-9299-2b735f7d6b7f", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"This analysis uses {data.nevents} Events\")" + ] + }, + { + "cell_type": "markdown", + "id": "e676fa4e-d0a5-4b58-b309-19bb17dc95c3", + "metadata": {}, + "source": [ + "`NFResponse` and `NFBackground`\n", + "- devices: Optional default devices. For CPU, choose `[\"cpu\"]`.\n", + " - Default devices cause the pool to be initialized and closed automatically for each inference. This can produce unnecessary overhead when not used with functions like `UnbinnedThreeMLPointSourceResponseIRFAdaptive`, which handles this management internally.\n", + " - You can also manually manage the pools using `init_compute_pool`, `clean_compute_pool`, and `shutdown_compute_pool`.\n", + "- compile_mode: Optional from a list of options. If issues arise, choose `None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25eca2eb-c55d-4f02-a61c-db9e2034a706", + "metadata": {}, + "outputs": [], + "source": [ + "rsp = NFResponse(\n", + " path_to_model=rsp_path,\n", + " area_batch_size=300_000,\n", + " density_batch_size=100_000, \n", + " devices=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\"],\n", + " area_compile_mode=\"max-autotune-no-cudagraphs\",\n", + " density_compile_mode=\"default\",\n", + " show_progress=True)\n", + "\n", + "irf = UnpolarizedNFFarFieldInstrumentResponseFunction(rsp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c8d0df3-813c-4eed-8316-d604d80dcbcb", + "metadata": {}, + "outputs": [], + "source": [ + "bkg_model = NFBackground(\n", + " path_to_model=bkg_path,\n", + " density_batch_size=100_000,\n", + " devices=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\"],\n", + " density_compile_mode=\"default\",\n", + " show_progress=True)\n", + "\n", + "bkg = FreeNormNFUnbinnedBackground(\n", + " model=bkg_model, \n", + " data=data, \n", + " sc_history=sc_orientation, \n", + " label=\"bkg_norm\")" + ] + }, + { + "cell_type": "markdown", + "id": "4904469c-285e-4a3a-8770-3dbb2bc0857b", + "metadata": {}, + "source": [ + "The next step takes some time, as it needs to interpolate the `SpacecraftHistory` (again issue https://github.com/cositools/cosipy/issues/504)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2dc234db-30f6-44e4-b9e6-22519b177b15", + "metadata": {}, + "outputs": [], + "source": [ + "psr = UnbinnedThreeMLPointSourceResponseIRFAdaptive(\n", + " data=data, \n", + " irf=irf, \n", + " sc_history=sc_orientation, \n", + " show_progress=True, \n", + " force_energy_node_caching=True) # Saves the energy nodes even when the batch size is too small" + ] + }, + { + "cell_type": "markdown", + "id": "e614fff3-9729-4e01-b8aa-2b7f08b1c89a", + "metadata": {}, + "source": [ + "The default `Band` or `Powerlaw` are very slow (see discussion https://github.com/cositools/cosipy/discussions/492). Consider implementing your own version with numpy or torch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "390f3822-e071-4add-b08e-833566ee7514", + "metadata": {}, + "outputs": [], + "source": [ + "l = 184.56\n", + "b = -5.78\n", + "\n", + "alpha = -1.99\n", + "beta = -2.32\n", + "xp = 531. * u.keV * (alpha + 2)\n", + "piv = 500. * u.keV\n", + "K = 3.07e-5 / u.cm / u.cm / u.s / u.keV\n", + "\n", + "spectrum = Band()\n", + "\n", + "spectrum.alpha.min_value = -2.14\n", + "spectrum.alpha.max_value = 3.0\n", + "spectrum.beta.min_value = -5.0\n", + "spectrum.beta.max_value = -2.15\n", + "spectrum.xp.min_value = 1.0\n", + "spectrum.alpha.delta = 0.01\n", + "spectrum.beta.delta = 0.01\n", + "\n", + "spectrum.alpha.value = alpha\n", + "spectrum.beta.value = beta\n", + "spectrum.xp.value = xp.value\n", + "spectrum.K.value = K.value\n", + "spectrum.piv.value = piv.value\n", + "\n", + "spectrum.xp.unit = xp.unit\n", + "spectrum.K.unit = K.unit\n", + "spectrum.piv.unit = piv.unit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3de7c067-204b-491f-9bea-4e44929dcfb9", + "metadata": {}, + "outputs": [], + "source": [ + "source = PointSource(\"Crab\",\n", + " l=l,\n", + " b=b,\n", + " spectral_shape=spectrum)\n", + "model = Model(source)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e52ef27-a5e7-41d7-87fd-3a82836bc131", + "metadata": {}, + "outputs": [], + "source": [ + "response = CachedUnbinnedThreeMLModelFolding(psr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01771aa6-36b1-46fb-a331-4dffff30ac91", + "metadata": {}, + "outputs": [], + "source": [ + "expectation_density = SumExpectationDensity(response, bkg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a5fbeee-e1dc-4c7b-993b-80b65f5217f4", + "metadata": {}, + "outputs": [], + "source": [ + "like_fun = UnbinnedLikelihood(expectation_density)\n", + "cosi = ThreeMLPluginInterface('cosi', like_fun, response, bkg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13e8ad22-f5e4-41cd-88fd-a04be7ad8200", + "metadata": {}, + "outputs": [], + "source": [ + "bkg_norm = bkg.norm.to_value(u.Hz)\n", + "\n", + "cosi.bkg_parameter['bkg_norm'] = Parameter(\"bkg_norm\",\n", + " bkg_norm,\n", + " unit = u.Hz,\n", + " min_value=0,\n", + " max_value=100,\n", + " delta=0.05,\n", + " )\n", + "print(f\"The average background flux is {bkg_norm:.2f} Hz\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88ea41cb-867a-41e1-840a-a328f316d3c2", + "metadata": {}, + "outputs": [], + "source": [ + "plugins = DataList(cosi)\n", + "like = JointLikelihood(model, plugins, verbose=True) # You can disable debugging" + ] + }, + { + "cell_type": "markdown", + "id": "9bcad107-7f24-45c9-838c-bf5c672607dd", + "metadata": {}, + "source": [ + "### Initializing the Cache" + ] + }, + { + "cell_type": "markdown", + "id": "e5483878-ee8b-44c7-b2cd-49abfc07d0f5", + "metadata": {}, + "source": [ + "Here you could load the cache" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c87b4579-25dd-4e6d-be92-635d131f99d5", + "metadata": {}, + "outputs": [], + "source": [ + "# response.load_caches(data_path / \"Crab_tutorial\")" + ] + }, + { + "cell_type": "markdown", + "id": "ba814a8b-ad8a-41ef-ad44-c8794fd574f7", + "metadata": {}, + "source": [ + "The cache is initialized, which takes some time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37ab7f53-e474-4699-b7a6-5e4472ee3309", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Data Events: {data.nevents}\\nExpected Events: {expectation_density.expected_counts():.2f}\\nRelative Deviation {100 * (expectation_density.expected_counts()/data.nevents - 1):.3f} %\")" + ] + }, + { + "cell_type": "markdown", + "id": "28caa6c7-8775-4f4c-83ca-b27eb2918a90", + "metadata": {}, + "source": [ + "Now you could save the cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f64a3895-49ca-4e47-bc40-b41106ee270f", + "metadata": {}, + "outputs": [], + "source": [ + "response.save_caches(data_path / \"Crab_tutorial\")" + ] + }, + { + "cell_type": "markdown", + "id": "859f6453-b4be-4e1c-97a4-743f851b566d", + "metadata": {}, + "source": [ + "### Fitting" + ] + }, + { + "cell_type": "markdown", + "id": "4ea6ed61-8e10-49b8-a8d4-878f34b0b53d", + "metadata": {}, + "source": [ + "If the fit fails with \"Current minimum stored after fit ... and current ... do not correspond!\" simply rerun this cell" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d0f6e15-c963-4116-911c-f6ca11fa4292", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "like.fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "132119fa-6035-4a8d-a23d-550bdee2744f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 18b5a714016a60ca4549b841071032b2c9b7ed21 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Fri, 13 Mar 2026 15:54:51 +0100 Subject: [PATCH 14/16] Added response and background checkpoint to wasabi --- .../crab/example_crab_fit_normalizing_flows.ipynb | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb index f66265170..b11562582 100644 --- a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb +++ b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb @@ -131,7 +131,12 @@ " output=str(bkg_data_path))\n", "\n", "rsp_path = data_path / \"unpolarized_nfresponse_v1-00.pt\"\n", - "bkg_path = data_path / \"nfbackground_v1-01.pt\"" + "fetch_wasabi_file('cosi-pipeline-public/COSI-SMEX/DC4/Data/Responses/unpolarized_nfresponse_v1-00.pt',\n", + " checksum = 'a4f9a7842f2a7345f604da32a155803f', output=str(rsp_path))\n", + "\n", + "bkg_path = data_path / \"nfbackground_v1-01.pt\"\n", + "fetch_wasabi_file('cosi-pipeline-public/COSI-SMEX/DC4/Data/Responses/nfbackground_v1-01.pt',\n", + " checksum = '52a4fb024930e18430bff80db882d3d5', output=str(bkg_path))" ] }, { From dcffa20215054e88e18353740b1cc3c618da2860 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Sun, 15 Mar 2026 22:48:28 +0100 Subject: [PATCH 15/16] Added option to compare the best fit and injected spectrum to tutorial notebook --- .../example_crab_fit_normalizing_flows.ipynb | 78 ++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb index b11562582..a7ebeb039 100644 --- a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb +++ b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb @@ -88,6 +88,9 @@ "from astropy.time import Time\n", "\n", "import astropy.units as u\n", + "from copy import deepcopy\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "from threeML import Band, PointSource, Model, JointLikelihood, DataList\n", "from astromodels import Parameter\n", @@ -270,7 +273,7 @@ "id": "e614fff3-9729-4e01-b8aa-2b7f08b1c89a", "metadata": {}, "source": [ - "The default `Band` or `Powerlaw` are very slow (see discussion https://github.com/cositools/cosipy/discussions/492). Consider implementing your own version with numpy or torch." + "The default `Band`, `Powerlaw` and `PointSource` are very slow (see discussion https://github.com/cositools/cosipy/discussions/492). Consider implementing your own version with torch." ] }, { @@ -310,6 +313,16 @@ "spectrum.piv.unit = piv.unit" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f3af0aa", + "metadata": {}, + "outputs": [], + "source": [ + "spectrum_inj = deepcopy(spectrum)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -460,7 +473,7 @@ "id": "4ea6ed61-8e10-49b8-a8d4-878f34b0b53d", "metadata": {}, "source": [ - "If the fit fails with \"Current minimum stored after fit ... and current ... do not correspond!\" simply rerun this cell" + "If the fit fails with \"Current minimum stored after fit ... and current ... do not correspond!\" simply rerun this cell (sometimes even that does not help)." ] }, { @@ -475,13 +488,72 @@ "like.fit()" ] }, + { + "cell_type": "markdown", + "id": "8eb8a59e", + "metadata": {}, + "source": [ + "Now we can plot the result and compare it with the injected spectrum." + ] + }, { "cell_type": "code", "execution_count": null, "id": "132119fa-6035-4a8d-a23d-550bdee2744f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "results = like.results\n", + "\n", + "parameters = {par.name: results.get_variates(par.path)\n", + " for par in results.optimized_model[\"Crab\"].parameters.values()\n", + " if par.free}\n", + "\n", + "results_err = results.propagate(results.optimized_model[\"Crab\"].spectrum.main.shape.evaluate_at, **parameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f68e51c9", + "metadata": {}, + "outputs": [], + "source": [ + "energy = np.geomspace(100*u.keV, 10*u.MeV).to_value(u.keV)\n", + "\n", + "flux_lo = np.zeros_like(energy)\n", + "flux_median = np.zeros_like(energy)\n", + "flux_hi = np.zeros_like(energy)\n", + "flux_inj = np.zeros_like(energy)\n", + "\n", + "for i, e in enumerate(energy):\n", + " flux = results_err(e)\n", + " flux_median[i] = flux.median\n", + " flux_lo[i], flux_hi[i] = flux.equal_tail_interval(cl=0.68)\n", + " flux_inj[i] = spectrum_inj.evaluate_at(e)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7500f4f7", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize = (9, 6))\n", + "\n", + "ax.plot(energy, energy**2 * flux_median, label = \"Best fit\")\n", + "ax.fill_between(energy, energy**2 * flux_lo, energy*energy*flux_hi, alpha = .5, label = \"Best fit (errors)\")\n", + "ax.plot(energy, energy**2 * flux_inj, color = 'black', ls = \":\", label = \"Injected\")\n", + "\n", + "ax.semilogx()\n", + "ax.semilogy()\n", + "\n", + "ax.set_xlabel(\"Energy [keV]\")\n", + "ax.set_ylabel(r\"$E^2 \\frac{\\mathrm{d}N}{\\mathrm{d}E}$ [keV cm$^{-2}$ s$^{-1}$]\")\n", + "\n", + "ax.legend();" + ] } ], "metadata": { From 77265f01d58bef0dfea898a53103cee3270cb301 Mon Sep 17 00:00:00 2001 From: pjanowsk Date: Thu, 19 Mar 2026 21:40:27 +0100 Subject: [PATCH 16/16] Fixed iminuit not converging by forcing float64. Replaced vectorized with batch_size for the unbinned_model_folding. Updated the tutorial. --- cosipy/threeml/optimized_unbinned_folding.py | 140 +++++++++++++----- cosipy/threeml/unbinned_model_folding.py | 51 +++---- .../example_crab_fit_normalizing_flows.ipynb | 7 +- 3 files changed, 134 insertions(+), 64 deletions(-) diff --git a/cosipy/threeml/optimized_unbinned_folding.py b/cosipy/threeml/optimized_unbinned_folding.py index d92c03b3e..3fcf7e070 100644 --- a/cosipy/threeml/optimized_unbinned_folding.py +++ b/cosipy/threeml/optimized_unbinned_folding.py @@ -27,6 +27,10 @@ from astropy.coordinates import SkyCoord from astropy.time import Time +import logging + +logger = logging.getLogger(__name__) + from importlib.util import find_spec @@ -44,7 +48,8 @@ def __init__(self, irf: FarFieldSpectralInstrumentResponseFunctionInterface, sc_history: SpacecraftHistory, show_progress: bool = True, - force_energy_node_caching: bool = False): + force_energy_node_caching: bool = False, + reduce_memory: bool = True): """ Will fold the IRF with the point source spectrum by evaluating the IRF at Ei positions adaptively chosen based on characteristic IRF features @@ -68,7 +73,8 @@ def __init__(self, self._peak_nodes = (18, 12) self._peak_widths = (0.04, 0.1) self._energy_range = (100., 10_000.) - self._batch_size = 1_000_000 + self._cache_batch_size = 1_000_000 + self._integration_batch_size = 1_000_000 self._offset: Optional[float] = 1e-12 # Placeholder for node pool - stored as Tensors @@ -134,6 +140,9 @@ def __init__(self, self._cos_lon_scatt = torch.cos(self._lon_scatt) self._sin_lon_scatt = torch.sin(self._lon_scatt) + # Also runs _check_memory_savings + self.reduce_memory = reduce_memory + @property def event_type(self) -> Type[EventInterface]: return TimeTagEmCDSEventInSCFrameInterface @@ -167,9 +176,14 @@ def energy_range(self) -> Tuple[float, float]: return self._energy_range def energy_range(self, val): self.set_integration_parameters(energy_range=val) @property - def batch_size(self) -> int: return self._batch_size - @batch_size.setter - def batch_size(self, val): self.set_integration_parameters(batch_size=val) + def cache_batch_size(self) -> int: return self._cache_batch_size + @cache_batch_size.setter + def cache_batch_size(self, val): self.set_integration_parameters(cache_batch_size=val) + + @property + def integration_batch_size(self) -> int: return self._integration_batch_size + @integration_batch_size.setter + def integration_batch_size(self, val): self.set_integration_parameters(integration_batch_size=val) @property def offset(self) -> Optional[float]: return self._offset @@ -184,19 +198,44 @@ def show_progress(self, val: bool): raise ValueError("show_progress must be a boolean") self._show_progress = val + def _check_memory_savings(self): + inefficient = (self._integration_batch_size > (self._n_events * self._total_energy_nodes[0]/2)) + if inefficient & self._reduce_memory: + logger.warning(f"Since integration_batch_size is too large reduce_memory will increase the memory usage! Disable it if this behavior is not desired.") + @property + def reduce_memory(self) -> bool: return self._reduce_memory + @reduce_memory.setter + def reduce_memory(self, val: bool): + if not isinstance(val, bool): + raise ValueError("reduce_memory must be a boolean") + self._reduce_memory = val + self._check_memory_savings() + if not val: + if self._irf_cache is not None: + self._irf_cache = torch.as_tensor(self._irf_cache, dtype=torch.float64) + if self._irf_energy_node_cache is not None: + self._irf_energy_node_cache = np.asarray(self._irf_energy_node_cache, dtype=np.float64) + else: + if self._irf_cache is not None: + self._irf_cache = torch.as_tensor(self._irf_cache, dtype=torch.float32) + if self._irf_energy_node_cache is not None: + self._irf_energy_node_cache = np.asarray(self._irf_energy_node_cache, dtype=np.float32) + def set_integration_parameters(self, total_energy_nodes: Optional[Tuple[int, int]] = None, peak_nodes: Optional[Tuple[int, int]] = None, peak_widths: Optional[Tuple[float, float]] = None, energy_range: Optional[Tuple[float, float]] = None, - batch_size: Optional[int] = None, + cache_batch_size: Optional[int] = -1, + integration_batch_size: Optional[int] = -1, offset: Optional[float] = -1.0): new_total = total_energy_nodes or self._total_energy_nodes new_peak_nodes = peak_nodes or self._peak_nodes new_peak_widths = peak_widths or self._peak_widths new_range = energy_range or self._energy_range - new_batch = batch_size if batch_size is not None else self._batch_size + new_cache_batch = cache_batch_size if cache_batch_size != -1 else self._cache_batch_size + new_integration_batch = integration_batch_size if integration_batch_size != -1 else self._integration_batch_size new_offset = offset if offset != -1.0 else self._offset irf_affected = ( @@ -228,8 +267,11 @@ def set_integration_parameters(self, if new_range[0] >= new_range[1]: raise ValueError("The initial energy interval needs to be increasing!") - if new_batch < max(new_total): - raise ValueError("The batch size cannot be smaller than the number of integration nodes.") + if (new_cache_batch is not None) and (new_cache_batch < max(new_total)): + raise ValueError("The cache batch size cannot be smaller than the number of integration nodes.") + + if (new_integration_batch is not None) and (new_integration_batch < max(new_total)): + raise ValueError("The integration batch size cannot be smaller than the number of integration nodes.") if (new_offset is not None) and (new_offset < 0): raise ValueError("The offset cannot be negative.") @@ -238,8 +280,10 @@ def set_integration_parameters(self, self._peak_nodes = new_peak_nodes self._peak_widths = new_peak_widths self._energy_range = new_range - self._batch_size = new_batch + self._cache_batch_size = new_cache_batch if new_cache_batch is not None else (self._n_events * max(new_total)) + self._integration_batch_size = new_integration_batch if new_integration_batch is not None else (self._n_events * max(new_total)) self._offset = new_offset + self._check_memory_savings() @staticmethod def _build_nodes(degree: int) -> Tuple[torch.Tensor, torch.Tensor]: @@ -392,7 +436,7 @@ def _compute_area(self): lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) n_time = len(lon_ph_rad) - batch_size_time = self._batch_size // n_energy + batch_size_time = self._cache_batch_size // n_energy total_area = np.zeros(n_energy, dtype=np.float64) @@ -623,7 +667,8 @@ def _compute_nodes(self): lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) - self._irf_energy_node_cache = np.asarray(self._get_nodes(self._energy_m_keV, self._phi_rad, phi_geo_rad, phi_igeo_rad)[0]) + np_memory_dtype = np.float32 if self._reduce_memory else np.float64 + self._irf_energy_node_cache = np.asarray(self._get_nodes(self._energy_m_keV, self._phi_rad, phi_geo_rad, phi_igeo_rad)[0], dtype=np_memory_dtype) def _compute_density(self): coord = self._source.position.sky_coord @@ -636,9 +681,12 @@ def _compute_density(self): phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) n_energy = self._total_energy_nodes[0] - batch_size_events = self._batch_size // n_energy + batch_size_events = self._cache_batch_size // n_energy + + np_memory_dtype = np.float32 if self._reduce_memory else np.float64 + torch_memory_dtype = torch.float32 if self._reduce_memory else torch.float64 - self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch.float32) + self._irf_cache = torch.zeros((self._n_events, n_energy), dtype=torch_memory_dtype) buffer_size = n_energy * min(batch_size_events, self._n_events) batch_lon_src_buffer = np.empty(buffer_size, dtype=np.float32) @@ -649,7 +697,7 @@ def _compute_density(self): batch_lat_scatt_buffer = np.empty(buffer_size, dtype=np.float32) if (batch_size_events < self._n_events) & (self._force_energy_node_caching): - self._irf_energy_node_cache = np.empty((self._n_events, n_energy), dtype=np.float32) + self._irf_energy_node_cache = np.zeros((self._n_events, n_energy), dtype=np_memory_dtype) for i in tqdm(range(0, self._n_events, batch_size_events), disable=(not self.show_progress), @@ -669,7 +717,7 @@ def _compute_density(self): nodes, weights = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) if batch_size_events >= self._n_events: - self._irf_energy_node_cache = np.asarray(nodes) + self._irf_energy_node_cache = np.asarray(nodes, dtype=np_memory_dtype) else: self._irf_energy_node_cache[start:end] = np.asarray(nodes) @@ -767,9 +815,11 @@ def cache_to_file(self, filename: Union[str, Path]): f.attrs['peak_nodes'] = self._peak_nodes f.attrs['peak_widths'] = self._peak_widths f.attrs['energy_range'] = self._energy_range - f.attrs['batch_size'] = self._batch_size + f.attrs['cache_batch_size'] = self._cache_batch_size + f.attrs['integration_batch_size'] = self._integration_batch_size f.attrs['show_progress'] = self._show_progress f.attrs['force_energy_node_caching'] = self._force_energy_node_caching + f.attrs['reduce_memory'] = self._reduce_memory if self._offset is not None: f.attrs['offset'] = self._offset @@ -822,9 +872,11 @@ def cache_from_file(self, filename: Union[str, Path]): self._peak_nodes = tuple(f.attrs['peak_nodes']) self._peak_widths = tuple(f.attrs['peak_widths']) self._energy_range = tuple(f.attrs['energy_range']) - self._batch_size = int(f.attrs['batch_size']) + self._cache_batch_size = int(f.attrs['cache_batch_size']) + self._integration_batch_size = int(f.attrs['integration_batch_size']) self._show_progress = bool(f.attrs['show_progress']) self._force_energy_node_caching = bool(f.attrs['force_energy_node_caching']) + self._reduce_memory = bool(f.attrs['reduce_memory']) if 'offset' in f.attrs: self._offset = f.attrs['offset'] @@ -907,43 +959,55 @@ def expectation_density(self) -> Iterable[float]: self._update_cache() source_dict = self._source.to_dict() if (source_dict != self._last_convolved_source_dict_density) or (self._exp_density is None): - self._exp_density = torch.zeros(self._n_events, dtype=self._irf_cache.dtype) + self._exp_density = torch.zeros(self._n_events, dtype=torch.float64) + + n_energy = self._total_energy_nodes[0] + batch_size = self._integration_batch_size // n_energy - if self._irf_energy_node_cache is not None: + if (self._irf_energy_node_cache is not None) & (batch_size >= self._n_events): flux = torch.as_tensor( - self._source(self._irf_energy_node_cache.ravel()), - dtype=self._irf_cache.dtype + self._source( + np.asarray(self._irf_energy_node_cache, dtype=np.float64).ravel() + ), + dtype=torch.float64 ).view(self._irf_energy_node_cache.shape) + + cache = torch.as_tensor(self._irf_cache, dtype=torch.float64) - torch.linalg.vecdot(self._irf_cache, flux, dim=1, out=self._exp_density) + torch.linalg.vecdot(cache, flux, dim=1, out=self._exp_density) else: - n_energy = self._total_energy_nodes[0] - batch_size = self._batch_size // n_energy - - sc_coord_sph = self._sc_coord_sph_cache + if self._irf_energy_node_cache is None: + sc_coord_sph = self._sc_coord_sph_cache - lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) - lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) + lon_ph_rad = asarray(sc_coord_sph.lon.rad, dtype=np.float32) + lat_ph_rad = asarray(sc_coord_sph.lat.rad, dtype=np.float32) - phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) + phi_geo_rad, phi_igeo_rad = self._get_CDS_coordinates(torch.as_tensor(lon_ph_rad), torch.as_tensor(lat_ph_rad)) for i in range(0, self._n_events, batch_size): end = min(i + batch_size, self._n_events) - e_sl = self._energy_m_keV[i:end] - p_sl = self._phi_rad[i:end] - pg_sl = phi_geo_rad[i:end] - pig_sl = phi_igeo_rad[i:end] + if self._irf_energy_node_cache is None: + e_sl = self._energy_m_keV[i:end] + p_sl = self._phi_rad[i:end] + pg_sl = phi_geo_rad[i:end] + pig_sl = phi_igeo_rad[i:end] - nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + nodes, _ = self._get_nodes(e_sl, p_sl, pg_sl, pig_sl) + else: + nodes = self._irf_energy_node_cache[i:end] + + nodes = np.asarray(nodes, dtype=np.float64) flux_batch = torch.as_tensor( - self._source(np.asarray(nodes).ravel()), - dtype=self._irf_cache.dtype + self._source(nodes.ravel()), + dtype=torch.float64 ).view(nodes.shape) - torch.linalg.vecdot(self._irf_cache[i:end], flux_batch, dim=1, out=self._exp_density[i:end]) + cache = torch.as_tensor(self._irf_cache[i:end], dtype=torch.float64) + + torch.linalg.vecdot(cache, flux_batch, dim=1, out=self._exp_density[i:end]) self._last_convolved_source_dict_density = source_dict diff --git a/cosipy/threeml/unbinned_model_folding.py b/cosipy/threeml/unbinned_model_folding.py index 912db7f46..c6fc5e62d 100644 --- a/cosipy/threeml/unbinned_model_folding.py +++ b/cosipy/threeml/unbinned_model_folding.py @@ -9,12 +9,14 @@ from cosipy.interfaces.source_response_interface import CachedUnbinnedThreeMLSourceResponseInterface from cosipy.response.threeml_response import ThreeMLModelFoldingCacheSourceResponsesMixin from cosipy.util.iterables import asarray +from cosipy.util.iterables import itertools_batched class UnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFoldingInterface, ThreeMLModelFoldingCacheSourceResponsesMixin): def __init__(self, point_source_response = UnbinnedThreeMLSourceResponseInterface, - extended_source_response: UnbinnedThreeMLSourceResponseInterface = None): + extended_source_response: UnbinnedThreeMLSourceResponseInterface = None, + batch_size: Optional[int] = None): # Interface inputs self._model = None @@ -22,6 +24,7 @@ def __init__(self, # Implementation inputs self._psr = point_source_response self._esr = extended_source_response + self._batch_size = batch_size if (self._psr is not None) and (self._esr is not None) and self._psr.event_type != self._esr.event_type: raise RuntimeError("Point and Extended Source Response must handle the same event type") @@ -57,27 +60,40 @@ def expected_counts(self) -> float: return sum(s.expected_counts() for s in self._source_responses.values()) + def _expectation_density_batched_gen(self, sources: list) -> Iterable[float]: + batched_sources = [itertools_batched(s, self._batch_size) for s in sources] + + for chunks in zip(*batched_sources): + densities = [asarray(c, dtype=np.float64) for c in chunks] + + yield from np.add.reduce(densities) + def expectation_density(self) -> Iterable[float]: - """ - Sum of expectation density - """ - self._cache_source_responses() - - return [sum(expectations) for expectations in zip(*(s.expectation_density() for s in self._source_responses.values()))] + + if not self._source_responses: + return np.array([], dtype=np.float64) + + sources = [ex.expectation_density() for ex in self._source_responses.values()] + + if (self._batch_size is None) or all(hasattr(s, "__len__") for s in sources): + densities = [asarray(s, dtype=np.float64) for s in sources] + return np.add.reduce(densities) + else: + return self._expectation_density_batched_gen(sources) class CachedUnbinnedThreeMLModelFolding(UnbinnedThreeMLModelFolding): def __init__(self, point_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, extended_source_response: Optional[UnbinnedThreeMLSourceResponseInterface] = None, - vectorize:bool = True): + batch_size: Optional[int] = None): super().__init__(point_source_response=point_source_response, - extended_source_response=extended_source_response) + extended_source_response=extended_source_response, + batch_size=batch_size) self._base_filename = "_source_response_cache.h5" - self._vectorize = vectorize def init_cache(self): """ @@ -115,18 +131,3 @@ def load_caches(self, directory: Union[str, Path], load_only: Optional[Iterable[ filepath = dir_path / f"{name}{self._base_filename}" if filepath.exists(): response.cache_from_file(filepath) - - def _expectation_density_gen(self) -> Iterable[float]: - for exdensity in zip(*[ex.expectation_density() for ex in self._source_responses.values()]): - yield sum(exdensity) - - def expectation_density(self) -> Iterable[float]: - self._cache_source_responses() - if self._vectorize: - if not self._source_responses: - return np.array([], dtype=np.float64) - else: - densities = [asarray(ex.expectation_density(), np.float64) for ex in self._source_responses.values()] - return np.add.reduce(densities) - else: - return self._expectation_density_gen() \ No newline at end of file diff --git a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb index a7ebeb039..79a1a21dd 100644 --- a/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb +++ b/docs/tutorials/spectral_fits/continuum_fit/crab/example_crab_fit_normalizing_flows.ipynb @@ -265,7 +265,12 @@ " irf=irf, \n", " sc_history=sc_orientation, \n", " show_progress=True, \n", - " force_energy_node_caching=True) # Saves the energy nodes even when the batch size is too small" + " force_energy_node_caching=True, # Saves the energy nodes even when the batch size is too small\n", + " reduce_memory=True) # May reduce the peak memory consumption (saves the cache as float32). Look out for warnings\n", + "\n", + "# There are several other options that can be set, for example:\n", + "\n", + "psr.integration_batch_size = 100_000" ] }, {