-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: custom modifiers for pure scalar functions with jax and sympy #2579
base: main
Are you sure you want to change the base?
Conversation
3900dbc
to
1eedbe4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I've seen so far this works well, very nice feature!
I think the discovery of parameters to be created does not handle nesting at the moment where new custom functions are functions of existing custom functions. With the example from the PR body, consider a setup like this:
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu**2'}},
{'name': 'interpscale2', 'type': 'purefunc', 'data': {'formula': 'interpscale1**2'}},
which shows via print(m.config.par_order)
that interpscale1
becomes a parameter of the model (but it should not be). I think this could either be a design decision and users would have to fully resolve all intermediate expressions (that should result in the same functionality I think) or should somehow work automatically.
I was trying this out while thinking about cyclic dependencies but then noticed they do not cause problems because each function creates its own parameters instead of carrying through the dependency.
0b2b712
to
83645d6
Compare
for more information, see https://pre-commit.ci
@@ -0,0 +1,162 @@ | |||
import sympy.parsing.sympy_parser as parser | |||
import sympy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll review this properly later today, but we'll need to add sympy
to the library dependencies given that it is only bring brought in by PyTorch at the time being.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may be useful to have a look at https://github.com/patrick-kidger/sympy2jax aswell, as it automatically allows to create transformable (jit, grad, ...) JAX expressions from sympy expressions.
This shouldn't really only depend on Looks good, but to address @alexander-held's point about the nested dependencies requires building a tree and resolving dependencies in order (which takes a bit more logic). That might need to be held off into a separate PR. In the meantime, you can always just expand out the function calls yourself, or do a "pre-parsing" that expands out all nested parameter definitions first, to normalize/flatten things. |
In #1991 we had a conversation about |
Description
This is a PR to add a "pure function" modifier as an external / optional contribution under the
contrib
directory.This should support modifiers or the form
etc
This is build on top of
sympy
andjax
and thus requirespyhf.set_backend('jax')
as well as possibly a sympy extra (or we assume it's externally installed for the parsing of the expressionsThe supported formulas involve either the pre-existing parameters (e.g. a
mu
coming from anormfactor
or new parameters that can be added.Basic Usage
Note that this will heavily rely on
jax.jit
so care should be taking (by us/the user) that the model is properly jitted before fittingTagging @mswiatlo @alexander-held @matthewfeickert @kratsg @nhartman94 @malin-horstmann