
Extend verify_grad to complex gradient #1367


Draft · educhesne wants to merge 3 commits into main

Conversation

@educhesne (Contributor) commented Apr 14, 2025

Description

Extend verify_grad to complex gradients, following the holomorphic gradient convention (as in JAX).
The decision on which convention to follow (JAX-like or torch-like) has not been made yet; see issue #1366.
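A minimal sketch of the intended usage (hypothetical: the complex-valued test point assumes this PR's changes; `verify_grad` itself is the existing `pytensor.gradient.verify_grad`):

```python
import numpy as np
from pytensor.gradient import verify_grad

rng = np.random.default_rng(42)

# Hypothetical once this PR lands: verify the gradient of a holomorphic
# function at complex points.  Under the holomorphic (JAX-like) convention,
# the symbolic gradient of sum(z**2) should match the numeric derivative 2*z.
z0 = (rng.normal(size=3) + 1j * rng.normal(size=3)).astype("complex128")
verify_grad(lambda z: (z**2).sum(), [z0], rng=rng)
```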

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1367.org.readthedocs.build/en/1367/

@codecov (bot) commented Apr 14, 2025

Codecov Report

Attention: Patch coverage is 80.00000% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.05%. Comparing base (f1514eb) to head (806fad1).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
pytensor/gradient.py 80.00% 1 Missing and 3 partials ⚠️

❌ Your patch check has failed because the patch coverage (80.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@           Coverage Diff           @@
##             main    #1367   +/-   ##
=======================================
  Coverage   82.05%   82.05%           
=======================================
  Files         203      203           
  Lines       48863    48875   +12     
  Branches     8695     8696    +1     
=======================================
+ Hits        40093    40103   +10     
- Misses       6619     6620    +1     
- Partials     2151     2152    +1     
Files with missing lines Coverage Δ
pytensor/gradient.py 78.62% <80.00%> (+0.07%) ⬆️

@ricardoV94 (Member) commented:

The link in the original issue argues strongly for the non-JAX approach. Why did JAX go this way? Do the arguments they make hold for higher-order auto-diff?

Also, do you want to modify one Op in this PR to show it working for complex gradients?

@educhesne (Contributor, Author) commented:

I'm afraid I've raised an issue which is beyond my knowledge... I'll try to summarize what I've understood so far though.

In the case of a real-valued function, jax and torch are equivalent (up to a conjugate).
However, when the function is complex-valued, torch cannot compute the derivative: torch's chain rule assumes that $\frac{\partial L}{\partial z^\star} = \left(\frac{\partial L}{\partial z}\right)^\star$, which is true when $L$ is real-valued but not in general.
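A quick illustration of the torch convention (a sketch assuming a recent PyTorch; for $L = |z|^2 = x^2 + y^2$ the returned gradient is $\frac{\partial L}{\partial x} + i\frac{\partial L}{\partial y} = 2z$):

```python
import torch

# For a real-valued loss L(z), torch's z.grad is dL/dx + i*dL/dy = 2*dL/dz*.
z = torch.tensor(3.0 + 4.0j, requires_grad=True)
L = (z * z.conj()).real    # L = |z|^2 = x^2 + y^2, real-valued
L.backward()
print(z.grad)              # tensor(6.+8.j): 2*z, the conjugate of JAX's result
```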

The internals of jax go way over my head, but the doc suggests that it internally computes the full derivative of $f(x+iy) = u(x,y) + i v(x,y)$ as if it were an $\mathbb R^2 \to \mathbb R^2$ function, that is to say the $2 \times 2$ real matrix of partial derivatives of $u$ and $v$, and returns $\frac{\partial u}{\partial x} - i\frac{\partial u}{\partial y}$ as the gradient.
Whenever $f$ is real-valued ($v=0$), this expression equals $2\frac{\partial f}{\partial z}$ (i.e. the conjugate of what torch returns), and whenever $f$ is holomorphic it equals $\frac{\partial f}{\partial z}$. So jax clearly has a broader scope than torch.
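Both claims are easy to check directly (a sketch against current JAX, evaluated at $z = 3+4i$):

```python
import jax
import jax.numpy as jnp

z = 3.0 + 4.0j

# Real-valued f = |z|^2: grad returns du/dx - i*du/dy = 2x - 2iy = 2*df/dz,
# the conjugate of the torch result above.
print(jax.grad(lambda z: (z * jnp.conj(z)).real)(z))   # (6-8j)

# Holomorphic f: the same convention reduces to the complex derivative f'(z).
print(jax.grad(jnp.sin, holomorphic=True)(z))          # == jnp.cos(z)
```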

In order to deal with higher-order auto-diff I reckon we need differentiation both with respect to $z$ and with respect to $z^\star$: even if the function is real-valued, its gradient is complex-valued and not necessarily holomorphic (I don't know how jax handles that...).

Regarding pytensor, I think that if the grads arguments of L_op contained the derivatives with respect to both $z$ and $z^\star$ (and L_op returned both as well), then it would be possible to compute the chain rule in the general complex case; a numeric sketch of the two derivatives is given below.
I realize that it is much more complicated than I thought initially...
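For what it's worth, both Wirtinger derivatives can be estimated numerically from the $\mathbb R^2 \to \mathbb R^2$ view, which is essentially what a complex-aware verify_grad would need to do (a hypothetical sketch, not PyTensor API):

```python
import numpy as np

def wirtinger_fd(f, z, eps=1e-6):
    """Central-difference estimates of df/dz and df/dz* at a complex point z."""
    dfdx = (f(z + eps) - f(z - eps)) / (2 * eps)            # d/dx
    dfdy = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)  # d/dy
    dfdz = 0.5 * (dfdx - 1j * dfdy)       # d/dz  := (d/dx - i d/dy) / 2
    dfdzbar = 0.5 * (dfdx + 1j * dfdy)    # d/dz* := (d/dx + i d/dy) / 2
    return dfdz, dfdzbar

# f(z) = z * conj(z) = |z|^2: df/dz = conj(z), df/dz* = z.
print(wirtinger_fd(lambda z: z * np.conj(z), 3 + 4j))   # ((3-4j), (3+4j))
```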

@ricardoV94 (Member)
