Updated the closures with @closure to avoid boxing #924
base: master
Conversation
```diff
@@ -57,7 +57,7 @@ multioutput = chain isa AbstractArray
 strategy = NeuralPDE.GridTraining(dx)
 integral = NeuralPDE.get_numeric_integral(strategy, indvars, multioutput, chain, derivative)

-_pde_loss_function = NeuralPDE.build_loss_function(eq, indvars, depvars, phi, derivative,
+_pde_loss_function = @closure NeuralPDE.build_loss_function(eq, indvars, depvars, phi, derivative,
```
this isn't a closure
I added @closure here under the assumption that it would be necessary for capturing variables, but I see that this isn't forming a closure.
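For reference, a minimal sketch of the distinction (hypothetical names, not code from this PR): `@closure` from FastClosures.jl wraps a closure definition, i.e. an anonymous function or do-block that captures local variables, in a `let` block; a plain function call whose result is assigned, as on this line, is not a closure definition at all.

```julia
using FastClosures

# A real closure: the anonymous function captures the locals `data` and `scale`,
# which is the kind of expression `@closure` is designed to wrap in a `let` block.
# (`data`, `scale`, and `make_scaled_loss` are illustrative names only.)
function make_scaled_loss(data)
    scale = sum(abs, data)
    return @closure θ -> sum(abs2, θ .- data) / scale
end

# Not a closure definition: this merely assigns the result of a function call,
# which is why the reviewer notes `@closure` does not belong here.
# _pde_loss_function = NeuralPDE.build_loss_function(eq, indvars, depvars, ...)
```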
```diff
@@ -105,14 +114,21 @@ res = solve(prob, BFGS(); maxiters = 100, callback)
 phi = discretization.phi

-# Analysis
+# Analysis with closure
```
why mention this?
Can you show a concrete improvement from this? What led to it? Can you show which functions had inference issues, along with a flamegraph?
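One way such evidence is typically gathered (a sketch with a placeholder loss function and parameters, not measurements from this PR): benchmark the generated loss function on both branches, check it with `@code_warntype`, and record a flamegraph of repeated evaluations.

```julia
using BenchmarkTools, InteractiveUtils, ProfileView

# Placeholder stand-ins; in practice use the loss function and parameters
# produced by the discretization on master and on this branch.
loss_fn = θ -> sum(abs2, θ)
θ0 = rand(128)

@code_warntype loss_fn(θ0)   # look for `Any` / `Core.Box` among captured variables
@btime $loss_fn($θ0)         # compare timings before and after the change

# Flamegraph of repeated loss evaluations (ProfileView's @profview):
@profview for _ in 1:10_000
    loss_fn(θ0)
end
```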
```diff
@@ -16,7 +16,7 @@ steps:
   # Don't run Buildkite if the commit message includes the text [skip tests]
   if: build.message !~ /\[skip tests\]/

-  - label: "Documentation"
+  - label: "Documentation"p
```
- label: "Documentation"p | |
- label: "Documentation" |
```diff
@@ -17,6 +17,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
```
what is this?
It was used to search for a keyword across all files, automating the search with a script.
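For context, a sketch of the kind of one-off search script described here (the keyword and path pattern are placeholders):

```julia
using Glob

keyword = "closure"
for file in glob("src/*.jl")                     # files matched by a Glob pattern
    for (i, line) in enumerate(readlines(file))  # report every matching line
        occursin(keyword, line) && println("$file:$i: $line")
    end
end
```

Since this is a development-time script rather than something NeuralPDE itself calls, the same loop could also be written with Base's `walkdir`, which would avoid adding Glob to the package's `[deps]`.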
```diff
@@ -38,6 +39,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
```
why is this needed?
The build kept failing with `KeyError: key "SparseArrays" not found`; declaring SparseArrays as a dependency resolved the error.
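If the failure really was an undeclared dependency on the SparseArrays standard library (an assumption based on the error message), the usual fix is to record it in the project's `[deps]` via Pkg rather than by editing Project.toml by hand:

```julia
using Pkg
Pkg.activate(".")          # the NeuralPDE project environment
Pkg.add("SparseArrays")    # writes the stdlib's UUID into [deps] in Project.toml
```

If SparseArrays is only needed by the tests, it can be declared in the test environment instead of the package's `[deps]`.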
```diff
@@ -168,7 +168,7 @@ end

 MiniMaxAdaptiveLoss(args...; kwargs...) = MiniMaxAdaptiveLoss{Float64}(args...; kwargs...)

-function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
+@closure function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
```
in the wrong spot
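For clarity, a heavily simplified sketch (not NeuralPDE's actual implementation) of where the annotation would be expected to go: on the closure returned by the generator, which is what captures local state, rather than on the top-level method definition.

```julia
using FastClosures

function generate_adaptive_loss_sketch(weights)   # hypothetical generator
    # `@closure` annotates the returned closure, which captures `weights`,
    # not the `function generate_adaptive_loss_sketch ...` definition itself.
    return @closure (pde_losses, bc_losses) -> vcat(pde_losses .* weights, bc_losses)
end
```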
```diff
@@ -28,7 +28,7 @@ end

 NonAdaptiveLoss(; kwargs...) = NonAdaptiveLoss{Float64}(; kwargs...)

-function generate_adaptive_loss_function(::PINNRepresentation, ::NonAdaptiveLoss, _, __)
+@closure function generate_adaptive_loss_function(::PINNRepresentation, ::NonAdaptiveLoss, _, __)
```
wrong spot
This PR refactors the codebase to annotate closures with the `@closure` macro (provided by FastClosures.jl, not built into Julia). The intent is to improve performance by avoiding the boxing of captured variables, and to make variable capture explicit at each closure definition.
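For background, a minimal self-contained illustration (not taken from the PR) of the boxing issue that `@closure` works around (JuliaLang/julia#15276): when a captured variable has more than one assignment, Julia's lowering typically puts it in a `Core.Box`, which defeats type inference; `let`-capturing it through `@closure` keeps it concretely typed.

```julia
using FastClosures

function boxed(flag)
    x = flag ? 1.0 : 0.0
    x += 1.0                  # second assignment: captured `x` is typically boxed
    return () -> 2x           # return value tends to infer as `Any`
end

function unboxed(flag)
    x = flag ? 1.0 : 0.0
    x += 1.0
    return @closure () -> 2x  # `let x = x` capture keeps the closure concretely typed
end
```

Comparing `@code_warntype boxed(true)` with `@code_warntype unboxed(true)` is one form of the concrete evidence the reviewers ask for above.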