Skip to content

Commit a590c47

Browse files
committed
Refactor EFT and resonant limit tasks.
1 parent 98c7f56 commit a590c47

18 files changed

Lines changed: 128 additions & 215 deletions

dhi/tasks/combine.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -779,48 +779,6 @@ def store_parts(self):
779779
return parts
780780

781781

782-
class MultiDatacardTransposedTask(MultiDatacardTask):
783-
784-
exclude_params_index = {"datacard_names", "datacard_order"}
785-
786-
datacard_names = None
787-
datacard_order = None
788-
group_duplicate_cards = False
789-
790-
@classmethod
791-
def extract_info_from_datacard_path(cls, datacard):
792-
return os.path.splitext(os.path.basename(datacard).rsplit("_", 1)[-1])[0]
793-
794-
def __init__(self, *args, **kwargs):
795-
super(MultiDatacardTransposedTask, self).__init__(*args, **kwargs)
796-
797-
# create a map of datacard info strings to lists of cards that contain it
798-
self.multi_datacards_transposed = OrderedDict()
799-
seen = set()
800-
for datacards in self.multi_datacards:
801-
contains_duplicate = any(datacard in seen for datacard in datacards)
802-
groups = OrderedDict()
803-
for datacard in datacards:
804-
# extract the info string from the basename
805-
info = self.extract_info_from_datacard_path(datacard)
806-
807-
if not self.group_duplicate_cards:
808-
# when not grouping, just add the card
809-
self.multi_datacards_transposed.setdefault(info, [[]])[0].append(datacard)
810-
elif not contains_duplicate:
811-
# when the sequence contains only unseen cards, just add the card
812-
self.multi_datacards_transposed.setdefault(info, []).append([datacard])
813-
else:
814-
# add it to groups for the current sequence
815-
groups.setdefault(info, []).append(datacard)
816-
seen.add(datacard)
817-
818-
# add groups if any
819-
if groups:
820-
for info, group in groups.items():
821-
self.multi_datacards_transposed.setdefault(info, []).append(group)
822-
823-
824782
class ParameterValuesTask(AnalysisTask):
825783

826784
parameter_values = ModelParameters(

dhi/tasks/eft.py

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Tasks related to EFT benchmarks and scans.
55
"""
66

7-
from collections import OrderedDict
7+
import re
8+
from collections import OrderedDict, defaultdict
89

910
import law
1011
import luigi
@@ -13,7 +14,6 @@
1314
from dhi.tasks.remote import HTCondorWorkflow
1415
from dhi.tasks.combine import (
1516
MultiDatacardTask,
16-
MultiDatacardTransposedTask,
1717
POITask,
1818
POIPlotTask,
1919
CombineCommandTask,
@@ -24,28 +24,44 @@
2424
from dhi.config import br_hh
2525

2626

27-
class EFTBase(POITask, MultiDatacardTransposedTask):
27+
class EFTBase(POITask):
2828

29+
datacard_pattern = luigi.Parameter(
30+
default=r"^.*_([^_]+)\.txt$",
31+
description="a regular expression with a single match group that is supposed to point to "
32+
"the benchmark name in the datacard path; default: ^.*_([^_]+)\\.txt$",
33+
)
2934
hh_model = law.NO_STR
3035
allow_empty_hh_model = True
3136

3237
poi = "r_gghh"
3338

3439
@classmethod
35-
def modify_param_values(cls, params):
36-
params = POITask.modify_param_values.__func__.__get__(cls)(params)
37-
params = MultiDatacardTransposedTask.modify_param_values.__func__.__get__(cls)(params)
38-
return params
40+
def _group_datacards(cls, datacards, cre):
41+
groups = defaultdict(list)
42+
for datacard in datacards:
43+
m = cre.match(datacard)
44+
if not m:
45+
raise Exception(
46+
f"no benchmark value could be extracted from datacard '{datacard}' "
47+
f"with pattern '{cre.pattern}'",
48+
)
49+
groups[m.group(1)].append(datacard)
50+
51+
return OrderedDict([
52+
(bm, sorted(groups[bm]))
53+
for bm in sort_eft_benchmark_names(groups.keys())
54+
])
3955

4056
def __init__(self, *args, **kwargs):
4157
super(EFTBase, self).__init__(*args, **kwargs)
4258

43-
# sort EFT datacards according to benchmark names
44-
names = sort_eft_benchmark_names(self.multi_datacards_transposed.keys())
45-
self.benchmark_datacards = OrderedDict(
46-
(name, self.multi_datacards_transposed[name])
47-
for name in names
48-
)
59+
# group datacards into a dictionary benchmark -> [cards]
60+
self.benchmark_datacards = self.group_datacards()
61+
62+
def group_datacards(self):
63+
cre = re.compile(self.datacard_pattern)
64+
return self._group_datacards(self.datacards, cre)
4965

5066
@property
5167
def other_pois(self):
@@ -62,11 +78,10 @@ class EFTBenchmarkLimits(EFTBase, CombineCommandTask, law.LocalWorkflow, HTCondo
6278
run_command_in_tmp = True
6379

6480
def create_branch_map(self):
65-
branch_map = []
66-
for name, cards in self.benchmark_datacards.items():
67-
for _cards in cards:
68-
branch_map.append({"benchmark": name, "cards": _cards})
69-
return branch_map
81+
return [
82+
{"benchmark": benchmark, "cards": cards}
83+
for benchmark, cards in self.benchmark_datacards.items()
84+
]
7085

7186
def workflow_requires(self):
7287
reqs = super(EFTBenchmarkLimits, self).workflow_requires()
@@ -273,45 +288,21 @@ def run(self):
273288
)
274289

275290

276-
class PlotMultipleEFTBenchmarkLimits(PlotEFTBenchmarkLimits):
277-
278-
datacard_names = MultiDatacardTask.datacard_names
279-
datacard_order = MultiDatacardTask.datacard_order
280-
group_duplicate_cards = True
291+
class PlotMultipleEFTBenchmarkLimits(PlotEFTBenchmarkLimits, MultiDatacardTask):
281292

282293
default_plot_function = "dhi.plots.eft.plot_multi_benchmark_limits"
283294

284-
def __init__(self, *args, **kwargs):
285-
super(PlotMultipleEFTBenchmarkLimits, self).__init__(*args, **kwargs)
286-
287-
# check that each mass point has the same amount of cards
288-
n_entries = {len(cards) for cards in self.benchmark_datacards.values()}
289-
if len(n_entries) != 1:
290-
raise Exception("founds different amount of entries in input datacards: {}".format(
291-
",".join(map(str, n_entries)),
292-
))
293-
self.n_entries = list(n_entries)[0]
294-
295-
# the lengths of names and order indices must match multi_datacards when given
296-
if self.datacard_names and len(self.datacard_names) != self.n_entries:
297-
raise Exception("found {} entries in datacard_names whereas {} are expected".format(
298-
len(self.datacard_names), self.n_entries,
299-
))
300-
if self.datacard_order and len(self.datacard_order) != self.n_entries:
301-
raise Exception("found {} entries in datacard_order whereas {} are expected".format(
302-
len(self.datacard_order), self.n_entries,
303-
))
295+
def group_datacards(self):
296+
cre = re.compile(self.datacard_pattern)
297+
return [
298+
self._group_datacards(datacards, cre)
299+
for datacards in self.multi_datacards
300+
]
304301

305302
def requires(self):
306303
return [
307-
MergeEFTBenchmarkLimits.req(
308-
self,
309-
multi_datacards=tuple(
310-
tuple(cards[i])
311-
for cards in self.benchmark_datacards.values()
312-
),
313-
)
314-
for i in range(self.n_entries)
304+
MergeEFTBenchmarkLimits.req(self, datacards=tuple(sum(groups.values(), [])))
305+
for groups in self.benchmark_datacards
315306
]
316307

317308
def output(self):

dhi/tasks/resonant.py

Lines changed: 42 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Tasks related to upper limits on resonant scenarios.
55
"""
66

7-
from collections import OrderedDict
7+
import re
8+
from collections import OrderedDict, defaultdict
89

910
import law
1011
import luigi
@@ -13,7 +14,6 @@
1314
from dhi.tasks.remote import HTCondorWorkflow
1415
from dhi.tasks.combine import (
1516
MultiDatacardTask,
16-
MultiDatacardTransposedTask,
1717
POITask,
1818
POIPlotTask,
1919
CombineCommandTask,
@@ -23,35 +23,45 @@
2323
from dhi.config import br_hh
2424

2525

26-
class ResonantBase(POITask, MultiDatacardTransposedTask):
26+
class ResonantBase(POITask):
2727

28+
datacard_pattern = luigi.Parameter(
29+
default=r"^.*_(\d+)\.txt$",
30+
description="a regular expression with a single match group that is supposed to point to "
31+
"the resonance mass value in the datacard path; default: ^.*_(\\d+)\\.txt$",
32+
)
2833
hh_model = law.NO_STR
2934
allow_empty_hh_model = True
3035

3136
poi = "r_xhh"
3237
scan_parameter = "mhh"
3338

3439
@classmethod
35-
def modify_param_values(cls, params):
36-
params = POITask.modify_param_values.__func__.__get__(cls)(params)
37-
params = MultiDatacardTransposedTask.modify_param_values.__func__.__get__(cls)(params)
38-
return params
40+
def _group_datacards(cls, datacards, cre):
41+
groups = defaultdict(list)
42+
for datacard in datacards:
43+
m = cre.match(datacard)
44+
if not m:
45+
raise Exception(
46+
f"no resonance mass could be extracted from datacard '{datacard}' "
47+
f"with pattern '{cre.pattern}'",
48+
)
49+
groups[int(m.group(1))].append(datacard)
50+
51+
return OrderedDict([
52+
(mass, sorted(groups[mass]))
53+
for mass in sorted(groups)
54+
])
3955

4056
def __init__(self, *args, **kwargs):
4157
super(ResonantBase, self).__init__(*args, **kwargs)
4258

43-
# convert keys in multi_datacards_transposed to integers and store them as resonant cards
44-
pairs = []
45-
for info, datacards in self.multi_datacards_transposed.items():
46-
try:
47-
mass = int(info)
48-
except:
49-
raise Exception(
50-
"datacards contain a mass point '{}' which cannot be interpreted as an "
51-
"integer".format(info),
52-
)
53-
pairs.append((mass, datacards))
54-
self.resonant_datacards = OrderedDict(sorted(pairs, key=lambda pair: pair[0]))
59+
# group datacards into a dictionary mass -> [cards]
60+
self.resonant_datacards = self.group_datacards()
61+
62+
def group_datacards(self):
63+
cre = re.compile(self.datacard_pattern)
64+
return self._group_datacards(self.datacards, cre)
5565

5666
@property
5767
def other_pois(self):
@@ -68,11 +78,10 @@ class ResonantLimits(ResonantBase, CombineCommandTask, law.LocalWorkflow, HTCond
6878
run_command_in_tmp = True
6979

7080
def create_branch_map(self):
71-
branch_map = []
72-
for mass, cards in self.resonant_datacards.items():
73-
for _cards in cards:
74-
branch_map.append({"mass": mass, "cards": _cards})
75-
return branch_map
81+
return [
82+
{"mass": mass, "cards": cards}
83+
for mass, cards in self.resonant_datacards.items()
84+
]
7685

7786
def workflow_requires(self):
7887
reqs = super(ResonantLimits, self).workflow_requires()
@@ -288,45 +297,21 @@ def run(self):
288297
)
289298

290299

291-
class PlotMultipleResonantLimits(PlotResonantLimits):
292-
293-
datacard_names = MultiDatacardTask.datacard_names
294-
datacard_order = MultiDatacardTask.datacard_order
295-
group_duplicate_cards = True
300+
class PlotMultipleResonantLimits(PlotResonantLimits, MultiDatacardTask):
296301

297302
default_plot_function = "dhi.plots.limits.plot_limit_scans"
298303

299-
def __init__(self, *args, **kwargs):
300-
super(PlotMultipleResonantLimits, self).__init__(*args, **kwargs)
301-
302-
# check that each mass point has the same amount of cards
303-
n_entries = {len(cards) for cards in self.resonant_datacards.values()}
304-
if len(n_entries) != 1:
305-
raise Exception("founds different amount of entries in input datacards: {}".format(
306-
",".join(map(str, n_entries)),
307-
))
308-
self.n_entries = list(n_entries)[0]
309-
310-
# the lengths of names and order indices must match multi_datacards when given
311-
if self.datacard_names and len(self.datacard_names) != self.n_entries:
312-
raise Exception("found {} entries in datacard_names whereas {} are expected".format(
313-
len(self.datacard_names), self.n_entries,
314-
))
315-
if self.datacard_order and len(self.datacard_order) != self.n_entries:
316-
raise Exception("found {} entries in datacard_order whereas {} are expected".format(
317-
len(self.datacard_order), self.n_entries,
318-
))
304+
def group_datacards(self):
305+
cre = re.compile(self.datacard_pattern)
306+
return [
307+
self._group_datacards(datacards, cre)
308+
for datacards in self.multi_datacards
309+
]
319310

320311
def requires(self):
321312
return [
322-
MergeResonantLimits.req(
323-
self,
324-
multi_datacards=tuple(
325-
tuple(cards[i])
326-
for cards in self.resonant_datacards.values()
327-
),
328-
)
329-
for i in range(self.n_entries)
313+
MergeResonantLimits.req(self, datacards=tuple(sum(groups.values(), [])))
314+
for groups in self.resonant_datacards
330315
]
331316

332317
def output(self):

dhi/tasks/snapshot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import luigi
1111

1212
from dhi.tasks.remote import HTCondorWorkflow
13+
from dhi.tasks.base import AnalysisTask
1314
from dhi.tasks.combine import (
1415
CombineCommandTask,
1516
POITask,
@@ -82,7 +83,7 @@ def build_command(self, fallback_level):
8283
return cmd
8384

8485

85-
class SnapshotUser(object):
86+
class SnapshotUser(AnalysisTask):
8687

8788
use_snapshot = luigi.BoolParameter(
8889
default=False,

dhi/tasks/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def requires(self):
413413
if self.check_enabled("eft_benchmark_limits"):
414414
reqs["eft_benchmark_limits"] = PlotEFTBenchmarkLimits.req(
415415
self,
416-
multi_datacards=(eft_bm_cards,),
416+
datacards=eft_bm_cards,
417417
unblinded=True,
418418
xsec="fb",
419419
y_log=True,
@@ -459,7 +459,7 @@ def requires(self):
459459
if self.check_enabled("resonant_limits"):
460460
reqs["resonant_limits"] = PlotResonantLimits.req(
461461
self,
462-
multi_datacards=(res_cards,),
462+
datacards=res_cards,
463463
unblinded=False,
464464
xsec="fb",
465465
y_log=True,

docs/content/snippets/createworkspace_param_tab.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ The `CreateWorkspace` task takes the combined datacard and the PhysicsModel as i
22

33
<div class="dhi_parameter_table">
44

5-
--8<-- "content/snippets/parameters.md@-2,20,19,34"
5+
--8<-- "content/snippets/parameters.md@-2,20,19,34,98,99"
66

77
</div>

docs/content/snippets/eftbenchmarklimits_param_tab.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ The `EFTBenchmarkLimits` task computes the limits of each benchmark datacard.
22

33
<div class="dhi_parameter_table">
44

5-
--8<-- "content/snippets/parameters.md@-2,21,48,34,17,18,47"
5+
--8<-- "content/snippets/parameters.md@-2,20,101,48,34,17,18,47"
66

77
</div>

0 commit comments

Comments
 (0)