Skip to content

Commit

Permalink
Merge pull request #237 from xtalax/recursiveunwrap
Browse files Browse the repository at this point in the history
Recursively unwrap equations
  • Loading branch information
xtalax authored Jan 28, 2023
2 parents afe88f3 + 47fdca7 commit 0508657
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 32 deletions.
4 changes: 0 additions & 4 deletions src/MOL_discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::Method
############################
alleqs = []
bceqs = []
# * We wamt to do this in 2 passes
# * First parse the system and BCs, replacing with DiscreteVariables and DiscreteDerivatives
# * periodic parameters get type info on whether they are periodic or not, and if they join up to any other parameters
# * Then we can do the actual discretization by recursively indexing in to the DiscreteVariables

# Create discretized space and variables, this is called `s` throughout
s = DiscreteSpace(v, discretization)
Expand Down
23 changes: 23 additions & 0 deletions src/MOL_symbolic_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ function split_terms(eq::Equation)
return vcat(lhs, rhs)
end

function _split_terms(term)
S = Symbolics
SU = SymbolicUtils
# TODO: Update this to be exclusive of derivatives and depvars rather than inclusive of +-/*
if S.istree(term) && ((operation(term) == +) | (operation(term) == -) | (operation(term) == *) | (operation(term) == /))
return mapreduce(_split_terms, vcat, SU.arguments(term))
else
return [term]
end
end

# Additional handling to get around limitations in rules
# Splits out derivatives from containing math expressions for ingestion by the rules
function _split_terms(term, x̄)
Expand Down Expand Up @@ -244,3 +255,15 @@ function ex2term(term, v)
name = Symbol("" * string(term) * "")
return setname(similarterm(exdv, rename(operation(exdv), name), arguments(exdv)), name)
end

safe_unwrap(x) = x isa Num ? unwrap(x) : x

function recursive_unwrap(ex)
if !istree(ex)
return safe_unwrap(ex)
end

op = operation(ex)
args = arguments(ex)
return safe_unwrap(op(recursive_unwrap.(args)))
end
13 changes: 0 additions & 13 deletions src/MOL_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@ Get a unit `CartesianIndex` in dimension `j` of length `N`.
"""
unitindex(N, j) = CartesianIndex(ntuple(i -> i == j, N))

function _split_terms(term)
S = Symbolics
SU = SymbolicUtils
# TODO: Update this to be exclusive of derivatives and depvars rather than inclusive of +-/*
if S.istree(term) && ((operation(term) == +) | (operation(term) == -) | (operation(term) == *) | (operation(term) == /))
return mapreduce(_split_terms, vcat, SU.arguments(term))
else
return [term]
end
end

@inline clip(II::CartesianIndex{M}, j, N) where {M} = II[j] > N ? II - unitindices(M)[j] : II

remove(args, t) = filter(x -> t === nothing || !isequal(safe_unwrap(x), safe_unwrap(t)), args)
Expand Down Expand Up @@ -63,8 +52,6 @@ function generate_coordinates(i::Int, stencil_x, dummy_x,
return stencil_x
end

safe_unwrap(x) = x isa Num ? x.val : x

function _get_gridloc(s, ut, is...)
u = Sym{SymbolicUtils.FnType{Tuple, Real}}(nameof(operation(ut)))
u = operation(s.ū[findfirst(isequal(u), operation.(s.ū))])
Expand Down
1 change: 0 additions & 1 deletion src/discretization/discretize_vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ function DiscreteSpace(vars, discretization::MOLFiniteDifference{G}) where {G}
end
end


return DiscreteSpace{nspace,length(depvars),G}(vars, Dict(depvarsdisc), axies, grid, Dict(dxs), Dict(Iaxies), Dict(Igrid))
end

Expand Down
2 changes: 1 addition & 1 deletion src/discretization/schemes/WENO/WENO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function weno(II::CartesianIndex, s::DiscreteSpace, wenoscheme::WENOScheme, bs,
hp = wp1 * hp1 + wp2 * hp2 + wp3 * hp3
hm = wm1 * hm1 + wm2 * hm2 + wm3 * hm3

return (hp - hm) / dx
return recursive_unwrap((hp - hm) / dx)
end

function weno(II::CartesianIndex, s::DiscreteSpace, b, jx, u, dx::AbstractVector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ufunc is a function that returns the correct discretization indexed at Itap, it
"""
function central_difference(D::DerivativeOperator{T,N,Wind,DX}, II, s, bs, jx, u, ufunc) where {T,N,Wind,DX<:Number}
j, x = jx
ndims(u, s) == 0 && return Num(0)
ndims(u, s) == 0 && return 0
# unit index in direction of the derivative
I1 = unitindex(ndims(u, s), j)
# offset is important due to boundary proximity
Expand All @@ -23,13 +23,13 @@ function central_difference(D::DerivativeOperator{T,N,Wind,DX}, II, s, bs, jx, u
Itap = [bwrap(II + i * I1, bs, s, jx) for i in half_range(D.stencil_length)]
end
# Tap points of the stencil, this uses boundary_point_count as this is equal to half the stencil size, which is what we want.
return dot(weights, ufunc(u, Itap, x))
return recursive_unwrap(dot(weights, ufunc(u, Itap, x)))
end

function central_difference(D::DerivativeOperator{T,N,Wind,DX}, II, s, bs, jx, u, ufunc) where {T,N,Wind,DX<:AbstractVector}
j, x = jx
@assert length(bs) == 0 "Interface boundary conditions are not yet supported for nonuniform dx dimensions, such as $x, please post an issue to https://github.com/SciML/MethodOfLines.jl if you need this functionality."
ndims(u, s) == 0 && return Num(0)
ndims(u, s) == 0 && return 0
# unit index in direction of the derivative
I1 = unitindex(ndims(u, s), j)
# offset is important due to boundary proximity
Expand All @@ -48,7 +48,7 @@ function central_difference(D::DerivativeOperator{T,N,Wind,DX}, II, s, bs, jx, u
end
# Tap points of the stencil, this uses boundary_point_count as this is equal to half the stencil size, which is what we want.

return dot(weights, ufunc(u, Itap, x))
return recursive_unwrap(dot(weights, ufunc(u, Itap, x)))
end

"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ function half_offset_centered_difference(D, II, s, bs, jx, u, ufunc)
ndims(u, s) == 0 && return Num(0)
j, x = jx
weights, Itap = get_half_offset_weights_and_stencil(D, II, s, bs, u, jx)
return dot(weights, ufunc(u, Itap, x))
return recursive_unwrap(dot(weights, ufunc(u, Itap, x)))
end
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function cartesian_nonlinear_laplacian(expr, II, derivweights, s::DiscreteSpace,
# Based on the paper https://web.mit.edu/braatzgroup/analysis_of_finite_difference_discretization_schemes_for_diffusion_in_spheres_with_variable_diffusivity.pdf
# See scheme 1, namely the term without the 1/r dependence. See also #354 and #371 in DiffEqOperators, the previous home of this package.
N = ndims(u, s)
N == 0 && return Num(0)
N == 0 && return 0
jx = j, x = (x2i(s, u, x), x)

D_inner = derivweights.halfoffsetmap[1][Differential(x)]
Expand Down Expand Up @@ -64,7 +64,7 @@ function cartesian_nonlinear_laplacian(expr, II, derivweights, s::DiscreteSpace,


interpolated_expr = map(interp_weights_and_stencil) do (weights, stencil)
Num(substitute(substitute(expr, map_vars_to_interpolated(stencil, weights)), map_params_to_interpolated(stencil, weights)))
substitute(substitute(expr, map_vars_to_interpolated(stencil, weights)), map_params_to_interpolated(stencil, weights))
end

# multiply the inner finite difference by the interpolated expression, and finally take the outer finite difference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function spherical_diffusion(innerexpr, II, derivweights, s, bs, depvars, r, u)
D_1_u = central_difference(D_1, II, s, bs, (s.x2i[r], r), u, ufunc_u)
# See scheme 1 in appendix A of the paper

return exprhere*(D_1_u/Num(substitute(r, _rsubs(r, II))) + cartesian_nonlinear_laplacian(innerexpr, II, derivweights, s, bs, depvars, r, u))
return exprhere*(D_1_u/substitute(r, _rsubs(r, II)) + cartesian_nonlinear_laplacian(innerexpr, II, derivweights, s, bs, depvars, r, u))
end

@inline function generate_spherical_diffusion_rules(II::CartesianIndex, s::DiscreteSpace, depvars, derivweights::DifferentialDiscretizer, bcmap, indexmap, terms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ in the direction of `x`
function upwind_difference(d::Int, II::CartesianIndex, s::DiscreteSpace, bs, derivweights, jx, u, ufunc, ispositive)
j, x = jx
# return if this is an ODE
ndims(u, s) == 0 && return Num(0)
ndims(u, s) == 0 && return 0
D = !ispositive ? derivweights.windmap[1][Differential(x)^d] : derivweights.windmap[2][Differential(x)^d]
#@show D.stencil_coefs, D.stencil_length, D.boundary_stencil_length, D.boundary_point_count
# unit index in direction of the derivative
weights, Itap = _upwind_difference(D, II, s, bs, ispositive, u, jx)
return dot(weights, ufunc(u, Itap, x))
return recursive_unwrap(dot(weights, ufunc(u, Itap, x)))
end

function upwind_difference(expr, d::Int, II::CartesianIndex, s::DiscreteSpace, bs, depvars, derivweights, (j, x), u, central_ufunc, indexmap)
Expand Down
1 change: 0 additions & 1 deletion src/system_parsing/bcs/parse_boundaries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ function parse_bcs(bcs, v::VariableMap, orders)
depvar_ops = v.depvar_ops

# Create some rules to match which bundary/variable a bc concerns
# * Assume that the term of the condition is applied additively and has no multiplier/divisor/power etc.
u0 = []
bceqs = []
## BC matching rules, returns the variable and parameter the bc concerns
Expand Down
4 changes: 2 additions & 2 deletions test/pde_systems/burgers_eq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using DomainSets
using StableRNGs
#using Plots

@test_broken begin #@testset "Inviscid Burgers equation, 1D, upwind, u(0, x) ~ x" begin
@testset "Inviscid Burgers equation, 1D, upwind, u(0, x) ~ x" begin
@parameters x t
@variables u(..)
Dx = Differential(x)
Expand Down Expand Up @@ -95,7 +95,7 @@ end
end
end

@test_broken begin #@testset "Inviscid Burgers equation, 1D, u(0, x) ~ x, Non-Uniform" begin
@testset "Inviscid Burgers equation, 1D, u(0, x) ~ x, Non-Uniform" begin
@parameters x t
@variables u(..)
Dx = Differential(x)
Expand Down

0 comments on commit 0508657

Please sign in to comment.