
feat: custom modifiers for pure scalar functions with jax and sympy #2579

Open · wants to merge 4 commits into main from custom_mods_sympy_jax
Conversation

@lukasheinrich (Contributor) commented Mar 17, 2025

Description

This is a PR to add a "pure function" modifier as an external / optional contribution under the contrib directory.

This should support modifiers of the form

{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'theta_0 + theta_1'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'sqrt(theta_0)'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu + theta_0'}},

etc.

This is built on top of sympy and jax, and thus requires pyhf.set_backend('jax') as well as possibly a sympy extra (or we assume sympy is externally installed for parsing the expressions).

The supported formulas may involve either pre-existing parameters (e.g. a mu coming from a normfactor) or new parameters that are created on demand.
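The parsing step can be illustrated with a short sketch (not the PR's actual implementation): the formula string is parsed with sympy, its free symbols define the parameter names, and the expression is lambdified into a callable. The compile_formula helper below is hypothetical, and the numpy backend is used only for brevity; the PR targets jax.

```python
import sympy
import sympy.parsing.sympy_parser as parser

def compile_formula(formula):
    # Hypothetical helper: parse the formula string, collect its free
    # symbols (the candidate model parameters), and build a callable.
    expr = parser.parse_expr(formula)
    symbols = sorted(expr.free_symbols, key=lambda s: s.name)
    func = sympy.lambdify(symbols, expr, modules="numpy")
    return [s.name for s in symbols], func

names, func = compile_formula("mu**2 + theta_0")
print(names)           # ['mu', 'theta_0']
print(func(2.0, 1.0))  # 5.0
```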

Basic Usage

spec = {
        'channels': [
            {
                'name': 'channel1',
                'samples': [
                    {
                        'name': 'signal1',
                        'data': [10,10,10],
                        'modifiers': [
                            {'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu**2'}},
                        ],
                    },
                ],
            }
        ]
}

import jax.numpy as jnp
import numpy as np
import pyhf
from pyhf.contrib.extended_modifiers import purefunc

pyhf.set_backend('jax')  # the purefunc modifier requires the jax backend
modifier_set = purefunc.enable()  # register the custom modifier set
m = pyhf.Model(
    spec,
    modifier_set=modifier_set,
    poi_name='mu',
    validate=False,  # the custom modifier is not part of the pyhf schema
)
pars = np.array(m.config.suggested_init())
pars[m.config.par_slice('mu')] = 2.0  # set the formula parameter mu
pars = jnp.array(pars)
m.expected_actualdata(pars)

Note that this relies heavily on jax.jit, so care should be taken (by us/the user) that the model is properly jitted before fitting.
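As a rough illustration of that point (a sketch, not the PR's code): wrapping the full model evaluation in jax.jit means the first call pays the tracing/compilation cost, and subsequent calls with same-shaped inputs reuse the compiled function. The expected function below is a toy stand-in for m.expected_actualdata.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for m.expected_actualdata: any pure function of the parameters.
@jax.jit
def expected(pars):
    return 10.0 * pars ** 2  # e.g. a nominal rate of 10 scaled by mu**2

out = expected(jnp.array([2.0]))  # first call traces and compiles
out = expected(jnp.array([3.0]))  # later calls reuse the compiled code
print(out)
```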

Tagging @mswiatlo @alexander-held @matthewfeickert @kratsg @nhartman94 @malin-horstmann

@lukasheinrich force-pushed the custom_mods_sympy_jax branch from 3900dbc to 1eedbe4 on Mar 17, 2025 at 10:50
@alexander-held (Member) left a comment:


From what I've seen so far this works well, very nice feature!

I think the discovery of parameters to be created does not currently handle nesting, where new custom functions are functions of existing custom functions. With the example from the PR body, consider a setup like this:

{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu**2'}},
{'name': 'interpscale2', 'type': 'purefunc', 'data': {'formula': 'interpscale1**2'}},

which shows via print(m.config.par_order) that interpscale1 becomes a parameter of the model (but it should not be). I think this could either be a design decision, with users having to fully resolve all intermediate expressions themselves (which should result in the same functionality, I think), or it should somehow work automatically.

I was trying this out while thinking about cyclic dependencies but then noticed they do not cause problems because each function creates its own parameters instead of carrying through the dependency.

@lukasheinrich force-pushed the custom_mods_sympy_jax branch from 0b2b712 to 83645d6 on Mar 17, 2025 at 14:04
@@ -0,0 +1,162 @@
import sympy.parsing.sympy_parser as parser
import sympy
A reviewer (Member) commented:

I'll review this properly later today, but we'll need to add sympy to the library dependencies, given that it is currently only brought in by PyTorch.

Another reviewer commented:

It may be useful to have a look at https://github.com/patrick-kidger/sympy2jax as well, as it automatically allows creating transformable (jit, grad, ...) JAX expressions from sympy expressions.

@kratsg (Contributor) commented Mar 17, 2025

This shouldn't really depend only on jax to make it work: sympy allows you to parse out to different "backends" as needed, although I think for the numpy case one needs to swap to numexpr just for that part (which is really not able to symbolically parse, just symbolically evaluate). For example, if you convert to pytensor, then it's a quick step to switch to pytorch/jax/tensorflow: https://github.com/scipp-atlas/pyhs3/blob/207fffae43b69528c27b26684955787e29206120/src/pyhs3/parse.py#L42-L88

Looks good, but addressing @alexander-held's point about the nested dependencies requires building a tree and resolving dependencies in order (which takes a bit more logic). That might need to be held off to a separate PR. In the meantime, you can always expand out the function calls yourself, or do a "pre-parsing" that expands out all nested parameter definitions first, to normalize/flatten things.
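The suggested pre-parsing/flattening pass could look roughly like the following sketch (hypothetical code, using sympy substitution): expressions from earlier purefunc modifiers are substituted into later formulas, so that only genuine model parameters remain as free symbols.

```python
import sympy
import sympy.parsing.sympy_parser as parser

# Hypothetical flattening pass: expand earlier purefunc definitions into
# later formulas before parameter discovery runs.
defined = {"interpscale1": parser.parse_expr("mu**2")}

expr = parser.parse_expr("interpscale1**2")
flattened = expr.subs({sympy.Symbol(name): e for name, e in defined.items()})

print(flattened)               # mu**4
print(flattened.free_symbols)  # {mu}
```

With this in place, interpscale1 would no longer show up in m.config.par_order, since the flattened expression only references mu.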

@alexander-held (Member) commented:

In #1991 we had a conversation about pyhf.experimental vs pyhf.contrib; perhaps this one should go into the former, to clearly signal its current status?

Labels: feat/enhancement (New feature or request) · Status: Review in progress · 5 participants