Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: custom modifiers for pure scalar functions with jax and sympy #2579

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
162 changes: 162 additions & 0 deletions src/pyhf/contrib/extended_modifiers/purefunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import sympy.parsing.sympy_parser as parser
import sympy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll review this properly later today, but we'll need to add sympy to the library dependencies, given that it is currently only being brought in by PyTorch.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be useful to have a look at https://github.com/patrick-kidger/sympy2jax as well, as it automatically creates transformable (jit, grad, ...) JAX expressions from sympy expressions.

from pyhf.parameters import ParamViewer
import jax.numpy as jnp
import jax


def create_modifiers():
    """Create the builder/applicator class pair for the ``purefunc`` modifier.

    The modifier multiplies a sample's rate in a channel by a scalar function
    of free parameters. The function is given as a formula string, parsed with
    sympy and compiled to a JAX callable with ``sympy.lambdify``.

    Returns:
        tuple: ``(PureFunctionModifierBuilder, PureFunctionModifierApplicator)``.
    """

    class PureFunctionModifierBuilder:
        """Collect formulas per modifier/sample/channel and compile them."""

        is_shared = True

        def __init__(self, pdfconfig):
            """Store the model configuration and initialise empty builder state."""
            self.config = pdfconfig
            self.required_parsets = {}
            self.builder_data = {'local': {}, 'global': {'symbols': set()}}
            self.encountered_expressions = {}

        def collect(self, thismod, nom):
            """Return a per-bin mask: truthiness of ``thismod`` repeated ``len(nom)`` times."""
            return {'mask': [bool(thismod)] * len(nom)}

        def require_symbols_as_scalars(self, symbols):
            """Return a parameter-set specification for every name in ``symbols``.

            Each symbol becomes one shared, unconstrained scalar parameter with
            initial value 1.0 and bounds (0, 10).
            """
            return {
                name: [
                    {
                        'paramset_type': 'unconstrained',
                        'n_parameters': 1,
                        'is_shared': True,
                        'inits': (1.0,),
                        'bounds': ((0, 10),),
                        'is_scalar': True,
                        'fixed': False,
                    }
                ]
                for name in symbols
            }

        def append(self, key, channel, sample, thismod, defined_samp):
            """Record mask and parsed formula for one (modifier, sample, channel).

            Args:
                key: Modifier key under which data is stored in ``builder_data['local']``.
                channel: Channel name; used to look up the bin count and to
                    index the per-channel parsed expression.
                sample: Sample name.
                thismod: Modifier specification containing ``['data']['formula']``,
                    or ``None`` when the sample has no such modifier in this channel.
                defined_samp: Sample specification containing ``['data']``, or a
                    falsy value if the sample is not defined in this channel.
            """
            sample_data = self.builder_data['local'].setdefault(key, {}).setdefault(
                sample, {}
            )
            mask = sample_data.setdefault('data', {'mask': []})['mask']

            # Fall back to an all-zero nominal of the right length when the
            # sample is absent from this channel.
            nom = (
                defined_samp['data']
                if defined_samp
                else [0.0] * self.config.channel_nbins[channel]
            )
            mask.extend(self.collect(thismod, nom)['mask'])

            parsed = None
            if thismod is not None:
                # NOTE(review): sympy's parse_expr evaluates the string it is
                # given; formulas from untrusted workspaces should be
                # validated before reaching this point.
                parsed = parser.parse_expr(thismod['data']['formula'])
                known_symbols = self.builder_data['global'].setdefault(
                    'symbols', set()
                )
                # NOTE(review): encountered_expressions is never filled in
                # this class, so this filter currently lets everything through.
                known_symbols.update(
                    symbol
                    for symbol in parsed.free_symbols
                    if symbol not in self.encountered_expressions
                )
            sample_data.setdefault('channels', {}).setdefault(channel, {})[
                'parsed'
            ] = parsed

        def finalize(self):
            """Compile all collected formulas and return the builder data.

            The symbol names are sorted so that the argument order of the
            compiled functions is deterministic (set iteration order is not)
            and can be reused verbatim by the applicator via
            ``builder_data['global']['symbol_names']``.

            Returns:
                dict: ``self.builder_data`` with a ``'jaxfunc'`` callable added
                to every channel entry.
            """
            symbol_names = sorted(
                str(symbol) for symbol in self.builder_data['global']['symbols']
            )

            self.required_parsets = self.require_symbols_as_scalars(symbol_names)
            self.builder_data['global']['symbol_names'] = symbol_names

            for samples in self.builder_data['local'].values():
                for samplespec in samples.values():
                    for channelspec in samplespec['channels'].values():
                        parsed = channelspec['parsed']
                        if parsed is None:
                            # No formula here: neutral multiplicative factor.
                            channelspec['jaxfunc'] = lambda *args: 1.0
                        else:
                            channelspec['jaxfunc'] = sympy.lambdify(
                                symbol_names, parsed, 'jax'
                            )
            return self.builder_data

    class PureFunctionModifierApplicator:
        """Evaluate the compiled formulas for a given parameter tensor."""

        op_code = 'multiplication'
        name = 'purefunc'

        def __init__(
            self, modifiers=None, pdfconfig=None, builder_data=None, batch_size=None
        ):
            """Set up the parameter viewer and the JAX evaluation function.

            Args:
                modifiers: Iterable of ``(name, type)`` pairs.
                pdfconfig: Model configuration providing ``npars``, ``par_map``,
                    ``channels``, ``samples`` and ``channel_nbins``.
                builder_data: Output of ``PureFunctionModifierBuilder.finalize``.
                batch_size: Batch size, or ``None`` for non-batched evaluation.
            """
            self.builder_data = builder_data
            self.batch_size = batch_size
            self.pdfconfig = pdfconfig

            # Use the exact ordering the functions were compiled with; derive
            # a sorted one only if the builder data does not carry it.
            global_data = builder_data['global']
            symbol_names = global_data.get('symbol_names')
            if symbol_names is None:
                symbol_names = sorted(str(x) for x in global_data['symbols'])
            self.inputs = list(symbol_names)

            self.keys = [f'{mtype}/{m}' for m, mtype in modifiers]
            self.modifiers = [m for m, _ in modifiers]

            parfield_shape = (
                (self.batch_size, pdfconfig.npars)
                if self.batch_size
                else (pdfconfig.npars,)
            )

            self.param_viewer = ParamViewer(
                parfield_shape, pdfconfig.par_map, self.inputs
            )
            self.create_jax_eval()

        def create_jax_eval(self):
            """Define ``self.jaxeval``, mapping parameter values to modifier factors.

            The resulting array is indexed as (modifier, sample, bin), with the
            bins of all channels concatenated in ``pdfconfig.channels`` order.
            """

            def sample_factors(sample_data, pars):
                # One scalar factor per channel, broadcast over that
                # channel's bins, then concatenated across channels.
                return jnp.concatenate(
                    [
                        sample_data['channels'][c]['jaxfunc'](*pars)
                        * jnp.ones(self.pdfconfig.channel_nbins[c])
                        for c in self.pdfconfig.channels
                    ]
                )

            def eval_func(pars):
                local = self.builder_data['local']
                return jnp.array(
                    [
                        [
                            sample_factors(local[m][s], pars)
                            for s in self.pdfconfig.samples
                        ]
                        for m in self.keys
                    ]
                )

            self.jaxeval = eval_func

        def apply_nonbatched(self, pars):
            """Evaluate for a single parameter point, inserting a length-1 axis at position 2."""
            return jnp.expand_dims(self.jaxeval(pars), 2)

        def apply_batched(self, pars):
            """Evaluate for a batch by mapping over axis 1 of ``pars``.

            The mapped axis is placed at position 2 of the output.
            """
            # assumes pars is laid out as (parameter, batch) - TODO confirm
            return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars)

        def apply(self, pars):
            """Return the modifier factors for ``pars``.

            Returns ``None`` when the parameter viewer selects no parameters.
            """
            if not self.param_viewer.index_selection:
                return
            par_selection = self.param_viewer.get(pars)
            if self.batch_size is None:
                return self.apply_nonbatched(par_selection)
            return self.apply_batched(par_selection)

    return PureFunctionModifierBuilder, PureFunctionModifierApplicator


from pyhf.modifiers import histfactory_set


def enable():
    """Return a modifier set containing the histfactory modifiers plus ``purefunc``.

    The result is a new dict: every entry of ``histfactory_set`` followed by
    the builder/applicator pair from ``create_modifiers``, keyed by the
    applicator's ``name``.
    """
    purefunc_builder, purefunc_applicator = create_modifiers()
    return {
        **histfactory_set,
        purefunc_applicator.name: (purefunc_builder, purefunc_applicator),
    }
Loading