SDE support for GaussAdjoint #945

Open
wants to merge 8 commits into
base: master
from

Conversation

acoh64
Contributor

@acoh64 acoh64 commented Nov 30, 2023

Not complete yet. Currently works for gradients with respect to the initial condition but not the parameters.

@frankschae
Member

You will need to add up the vjp contributions from the drift and the diffusion in the integration.
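
A minimal sketch of that accumulation, not the code in this PR. It assumes a hypothetical scalar SDE `du = p[1]*u dt + p[2]*u dW`, hand-written VJPs with respect to `p`, and adjoint values `λs` that are already available at the saved time points. The names `drift_vjp_p`, `diffusion_vjp_p`, and `accumulate_param_grad` are illustrative only, and the left-point evaluation and sign convention are assumptions that depend on the SDE interpretation and the direction of integration.

```julia
# Hypothetical scalar SDE: du = p[1]*u dt + p[2]*u dW.
# VJP of the drift f(u, p, t) = p[1]*u with respect to p, applied to λ.
drift_vjp_p(λ, u, p, t) = [λ * u, zero(λ)]

# VJP of the diffusion g(u, p, t) = p[2]*u with respect to p, applied to λ.
diffusion_vjp_p(λ, u, p, t) = [zero(λ), λ * u]

# Sum both contributions over the saved steps:
# dt weights the drift term, dW weights the diffusion term.
function accumulate_param_grad(λs, us, ts, Ws, p)
    grad = zeros(length(p))
    for i in 1:(length(ts) - 1)
        dt = ts[i + 1] - ts[i]
        dW = Ws[i + 1] - Ws[i]
        grad .+= drift_vjp_p(λs[i], us[i], p, ts[i]) .* dt
        grad .+= diffusion_vjp_p(λs[i], us[i], p, ts[i]) .* dW
    end
    return grad
end
```

Dropping the second `grad .+=` line reproduces the symptom of a gradient that ignores every parameter that only enters through the diffusion.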

@acoh64
Contributor Author

acoh64 commented Dec 14, 2023

Just added a new commit where I include the diffusion vjp (I think?), but it still doesn't seem to work.

@ChrisRackauckas
Member

What's left here?

@acoh64
Contributor Author

acoh64 commented Jan 4, 2024

> What's left here?

@frankschae Could you take a quick look? I think the basics are in there, but still debugging why results are off

if sensealg.autojacvec isa ZygoteVJP
    if W === nothing
        _dy, back = Zygote.pullback(y, p) do u, p
            vec(g(u, p, t))
Member

I think you want to compute the following:
vec(g(u, p, t)*dW)

(sorry for the slow responses. I'm catching up rn -- got a bad cold.)

Contributor Author

I tried this too (see most recent commit) but don't seem to get the right results

Member

Did you check if it's the correct dW that's extracted?

Contributor Author

I guess I am a bit confused because $dW$ should be a function of $t$, right? The SDE solution type only stores $W$ as a function of time. Should I use $dW_i = W_{t_{i+1}} - W_{t_i}$, or am I misunderstanding something?

Member

Yes, in your current version, you assume dW to be constant. If you use several Gauss points for the integral (how many do you use?), you should probably compute dW appropriately for those times.

Member

They do have an order -- EM (Euler-Maruyama) is 0.5 strong / 1 weak 😅

$dW_i = W_{t_{i+1}} - W_{t_i}$
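
As a small illustration of the increment formula, assuming the saved Brownian path is available as plain vectors (the values of `ts` and `Ws` below are made up and only stand in for what a solution would store):

```julia
# Hypothetical saved Brownian path, standing in for the W values stored with the solution.
ts = [0.0, 0.1, 0.25, 0.5]
Ws = [0.0, 0.3, -0.1, 0.4]

# dW_i = W_{t_{i+1}} - W_{t_i}: one increment per step, so one entry fewer than Ws.
dWs = diff(Ws)

# Matching step sizes, useful when the drift and diffusion terms are summed together.
dts = diff(ts)
```

The increments telescope, so `sum(dWs)` equals `Ws[end] - Ws[1]`, which is a cheap consistency check on the extracted values.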

Contributor Author

Ok thanks! Should these be the $t_{i}$'s from the forward pass? Then in the backwards pass, should the $t$'s that the solver stops at be identical to the forward pass?

Member

I think for now the best thing to do to check the implementation is to make sure it only uses $t$ values that were encountered in the forward pass -- in general I think it's not strictly necessary because we can draw from a Brownian bridge... but this likely increases the variance of the estimate.
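
For reference, a sketch of what drawing from a Brownian bridge means for a scalar standard Brownian motion: conditioned on $W_{t_0} = w_0$ and $W_{t_1} = w_1$, the value at an intermediate time $s$ is Gaussian with the mean and variance computed below. The helper names `bridge_moments` and `sample_bridge` are hypothetical and are not the bridge function of any package.

```julia
using Random

# Mean and variance of W(s) given W(t0) = w0 and W(t1) = w1,
# for a scalar standard Brownian motion and t0 <= s <= t1.
function bridge_moments(t0, w0, t1, w1, s)
    (t0 < t1 && t0 <= s <= t1) || throw(ArgumentError("need t0 <= s <= t1 and t0 < t1"))
    θ = (s - t0) / (t1 - t0)
    mu = w0 + θ * (w1 - w0)
    variance = (t1 - s) * (s - t0) / (t1 - t0)
    return mu, variance
end

# Draw one value of W(s) from the conditional distribution.
function sample_bridge(rng::AbstractRNG, t0, w0, t1, w1, s)
    mu, variance = bridge_moments(t0, w0, t1, w1, s)
    return mu + sqrt(variance) * randn(rng)
end
```

The variance vanishes at both endpoints and is largest at the midpoint, so every newly sampled time adds randomness that was not present in the forward pass.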

Contributor Author

that makes sense, thanks. I will try to implement this now

Contributor Author

I can force the solver to stop at the $t$ values from the forward pass, but I don't think it is possible to ensure that the Gauss points for computing the integral are always at these points. Should I just interpolate or will I have to use a Brownian bridge?

@acoh64
Contributor Author

acoh64 commented Feb 11, 2024

@ChrisRackauckas, is there a way to access all the dWs from the forward solve to build an interpolant for evaluating it at the Gauss points during the integral calculation? I think currently only the final dW is stored in the solution struct. Thanks!

@ChrisRackauckas
Member

It should be there if you did save_noise and save_everystep

@acoh64
Contributor Author

acoh64 commented Feb 14, 2024

> It should be there if you did save_noise and save_everystep

Ah I see, this is what is stored in sol.W.u. I will try to build a linear interpolant with this

@ChrisRackauckas
Member

IIRC dW already has an interpolation on it.

@acoh64
Contributor Author

acoh64 commented Feb 14, 2024

There only seems to be interpolation for W, dW is just a float

@ChrisRackauckas
Member

dW is the difference of the W's. You wouldn't ever want to interpolate that.
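
One way to read this, sketched with plain vectors standing in for the saved times and path values: evaluate $W$ at the two times of interest and subtract, rather than interpolating the increments themselves. The helpers `interp_W` and `increment_W` are hypothetical, and the piecewise-linear evaluation is an assumption; between saved points it only gives the conditional mean of the path, not a sample from it.

```julia
# Piecewise-linear evaluation of a saved path at time s.
# ts must be sorted and have the same length as Ws.
function interp_W(ts, Ws, s)
    ts[1] <= s <= ts[end] || throw(ArgumentError("s is outside the saved time span"))
    # Index of the interval [ts[i], ts[i+1]] that contains s.
    i = clamp(searchsortedlast(ts, s), 1, length(ts) - 1)
    θ = (s - ts[i]) / (ts[i + 1] - ts[i])
    return Ws[i] + θ * (Ws[i + 1] - Ws[i])
end

# Increment of W over [a, b], taken as a difference of path values.
increment_W(ts, Ws, a, b) = interp_W(ts, Ws, b) - interp_W(ts, Ws, a)
```

When `a` and `b` are both saved time points, `increment_W` reduces to the usual $W_{t_{i+1}} - W_{t_i}$.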

@acoh64
Contributor Author

acoh64 commented Feb 14, 2024

Should I use the bridge function in W then? Although I am not exactly sure what all the arguments to it are since it is different than what I have seen before.

@ChrisRackauckas
Member

@frankschae do you think you can find time for this?

3 participants