Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/rail/core/common_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,6 @@ def set_param_defaults(**kwargs: Any) -> None: # pragma: no cover
set_param_default = SharedParams.set_param_default

set_param_defaults = SharedParams.set_param_defaults

TOMOGRAPHY_NONE = -1
TOMOGRAPHY_ALL = -2
13 changes: 8 additions & 5 deletions src/rail/estimation/algos/naive_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ceci.config import StageParameter as Param

from rail.core.data import QPHandle, TableHandle, TableLike
from rail.core.common_params import SharedParams
from rail.core.common_params import SharedParams, TOMOGRAPHY_ALL, TOMOGRAPHY_NONE
from rail.estimation.informer import PzInformer
from rail.estimation.summarizer import PZSummarizer

Expand Down Expand Up @@ -151,17 +151,17 @@ class NaiveStackMaskedSummarizer(NaiveStackSummarizer):
interactive_function = "naive_stack_masked_summarizer"
config_options = NaiveStackSummarizer.config_options.copy()
config_options.update(
selected_bin=Param(int, -1, msg="bin to use"),
selected_bin=Param(int, TOMOGRAPHY_NONE, msg=f"bin to use, or {TOMOGRAPHY_ALL} for all bins >=0 or {TOMOGRAPHY_NONE} for no masking"),
)
inputs = [("input", QPHandle), ("tomography_bins", TableHandle)]
outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

def _setup_iterator(self) -> Generator:
selected_bin = self.config.selected_bin
if self.config.tomography_bins in ["none", None]:
selected_bin = -1
selected_bin = TOMOGRAPHY_NONE

if selected_bin == -1:
if selected_bin == TOMOGRAPHY_NONE:
itrs = [self.input_iterator("input")]
else:
itrs = [
Expand All @@ -179,7 +179,10 @@ def _setup_iterator(self) -> Generator:
pz_data = d
first = False
else:
mask = d["class_id"] == self.config.selected_bin
if selected_bin == TOMOGRAPHY_ALL:
mask = d["class_id"] >= 0
else:
mask = d["class_id"] == selected_bin
if mask is None:
mask = np.ones(
pz_data.npdf, # pylint: disable=possibly-used-before-assignment
Expand Down
13 changes: 8 additions & 5 deletions src/rail/estimation/algos/point_est_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ceci.config import StageParameter as Param

from rail.core.data import QPHandle, TableHandle, TableLike
from rail.core.common_params import SharedParams
from rail.core.common_params import SharedParams, TOMOGRAPHY_ALL, TOMOGRAPHY_NONE
from rail.estimation.informer import PzInformer
from rail.estimation.summarizer import PZSummarizer

Expand Down Expand Up @@ -121,17 +121,17 @@ class PointEstHistMaskedSummarizer(PointEstHistSummarizer):
interactive_function = "point_est_hist_masked_summarizer"
config_options = PointEstHistSummarizer.config_options.copy()
config_options.update(
selected_bin=Param(int, -1, msg="bin to use"),
selected_bin=Param(int, TOMOGRAPHY_NONE, msg=f"bin to use, or {TOMOGRAPHY_ALL} for all bins >=0 or {TOMOGRAPHY_NONE} for no masking")
)
inputs = [("input", QPHandle), ("tomography_bins", TableHandle)]
outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

def _setup_iterator(self) -> Generator:
selected_bin = self.config.selected_bin
if self.config.tomography_bins in ["none", None]:
selected_bin = -1
selected_bin = TOMOGRAPHY_NONE

if selected_bin == -1:
if selected_bin == TOMOGRAPHY_NONE:
itrs = [self.input_iterator("input")]
else:
itrs = [
Expand All @@ -149,7 +149,10 @@ def _setup_iterator(self) -> Generator:
pz_data = d
first = False
else:
mask = d["class_id"] == self.config.selected_bin
if selected_bin == TOMOGRAPHY_ALL:
mask = d["class_id"] >= 0
else:
mask = d["class_id"] == selected_bin
if mask is None:
mask = np.ones(
pz_data.npdf, # pylint: disable=possibly-used-before-assignment
Expand Down
9 changes: 8 additions & 1 deletion tests/estimation/test_summarizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rail.core.stage import RailStage
from rail.estimation.algos import naive_stack, point_est_hist, var_inf
from rail.utils.path_utils import RAILDIR
from rail.core.common_params import TOMOGRAPHY_ALL, TOMOGRAPHY_NONE

testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5")
tomobins = os.path.join(RAILDIR, "rail/examples_data/testdata/output_tomo.hdf5")
Expand Down Expand Up @@ -48,8 +49,10 @@ def one_mask_algo(
# tomo_bins = DS.read_file("tomo_bins", TableHandle, tomobins)
test_data = QPHandle("test_data", path=testdata)
tomo_bins = TableHandle("tomo_bins", path=tomobins)
summary_kwargs = summary_kwargs.copy()
selected_bin = summary_kwargs.pop("selected_bin", 1)

summarizer = summarizer_class.make_stage(name=key, selected_bin=1, **summary_kwargs)
summarizer = summarizer_class.make_stage(name=key, selected_bin=selected_bin, **summary_kwargs)
summary_ens = summarizer.summarize(test_data, tomo_bins)
os.remove(
summarizer.get_output(summarizer.get_aliased_tag("output"), final_name=True)
Expand Down Expand Up @@ -126,6 +129,8 @@ def test_naive_stack_masked() -> None:
)
summarizer_class = naive_stack.NaiveStackMaskedSummarizer
_ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict)
_ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict | {"selected_bin": TOMOGRAPHY_ALL})
_ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict | {"selected_bin": TOMOGRAPHY_NONE})
_ = one_algo("NaiveStack", summarizer_class, summary_config_dict)


Expand All @@ -138,4 +143,6 @@ def test_point_estimate_hist_masked() -> None:
)
summarizer_class = point_est_hist.PointEstHistMaskedSummarizer
_ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict)
_ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict | {"selected_bin": TOMOGRAPHY_ALL})
_ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict | {"selected_bin": TOMOGRAPHY_NONE})
_ = one_algo("PointEstimateHist", summarizer_class, summary_config_dict)
Loading