diff --git a/src/operators.jl b/src/operators.jl index e5d4e0c..dd27226 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -92,6 +92,74 @@ function MA.operate_to!( return output end +""" + _lowest_term_idx(p::Polynomial) + +Return the index of the lowest term in `p` according to its monomial ordering. +""" +_lowest_term_idx(p::Polynomial{V, M, T}) where {V, M <: Reverse, T} = lastindex(p.x) +_lowest_term_idx(p::Polynomial) = firstindex(p.x) + +""" + _insert_constant_term!(p::Polynomial) + +Insert a constant (degree 0) term into polynomial `p` at the appropriate position for the +monomial ordering of `p`. Assume that a constant term does not already exists. +""" +function _insert_constant_term!(p::Polynomial{V, M, T}) where {V, M <: Reverse, T} + push!(MP.coefficients(p), zero(T)) + push!(MP.monomials(p).Z, zeros(Int, length(MP.variables(p)))) + return p +end + +function _insert_constant_term!(p::Polynomial{V, M, T}) where {V, M, T} + insert!(MP.coefficients(p), 1, zero(T)) + insert!(MP.monomials(p).Z, 1, zeros(Int, length(MP.variables(p)))) + return p +end + +function MA.operate!(op::Union{typeof(+), typeof(-)}, p::Polynomial{V, M, T}, x::T) where {V, M, T} + c_idx = _lowest_term_idx(p) + if MP.nterms(p) == 0 || !MP.isconstant(MP.terms(p)[c_idx]) + _insert_constant_term!(p) + c_idx = _lowest_term_idx(p) + end + coeffs = MP.coefficients(p) + coeffs[c_idx] = op(coeffs[c_idx], x) + if iszero(coeffs[c_idx]) + deleteat!(coeffs, c_idx) + deleteat!(MP.monomials(p), c_idx) + end + return p +end + +function MA.operate!(op::Union{typeof(+), typeof(-)}, p::Polynomial{V, M, T}, x::Variable{V, M}) where {V, M, T} + vars = MP.variables(p) + idx = searchsortedfirst(vars, x; rev = true) + monos = MP.monomials(p) + if idx > length(vars) || !isequal(vars[idx], x) + for mono in monos + insert!(MP.exponents(mono), idx, 0) + end + insert!(vars, idx, x) + end + mono = Monomial{V, M}(vars, zeros(Int, length(vars))) + mono.z[idx] = 1 + idx = searchsortedfirst(monos, mono) + coeffs = MP.coefficients(p) + N = MP.nterms(p) + if idx > N || !isequal(monos[idx], mono) + insert!(monos.Z, idx, MP.exponents(mono)) + insert!(coeffs, idx, zero(T)) + end + coeffs[idx] = op(coeffs[idx], one(T)) + if iszero(coeffs[idx]) + deleteat!(coeffs, idx) + deleteat!(monos, idx) + end + return p +end + function MA.operate!( op::Union{typeof(+),typeof(-)}, p::Polynomial, diff --git a/test/mutable_arithmetics.jl b/test/mutable_arithmetics.jl index d2a908c..5529013 100644 --- a/test/mutable_arithmetics.jl +++ b/test/mutable_arithmetics.jl @@ -35,3 +35,68 @@ using DynamicPolynomials @test q == x + y + 1 end end + +@testset "Fast path cases: $ord" for ord in [ + InverseLexOrder, + LexOrder, + Graded{InverseLexOrder}, + Graded{LexOrder}, + Reverse{InverseLexOrder}, + Reverse{LexOrder}, + Graded{Reverse{InverseLexOrder}}, + Graded{Reverse{LexOrder}}, + Reverse{Graded{InverseLexOrder}}, + Reverse{Graded{LexOrder}}, + ] + @polyvar x y z monomial_order=ord + + # allocation tests vary between Julia versions, so they're upper bounds + @testset "Polynomial + constant" begin + poly = 2 * x^2 + 3 * x * y + z * y^2 + poly2 = copy(poly) + result = poly + 2 + MA.operate!(+, poly, 2) + @test isequal(poly, result) + # down from 576 using the generic method + @test (@allocated MA.operate!(+, poly2, 2)) <= 272 + # subsequent additions don't allocate + @test (@allocated MA.operate!(+, poly2, 2)) == 0 + + # also test `-` + poly = 2 * x^2 + 3 * x * y + z * y^2 + poly2 = copy(poly) + result = poly - 2 + MA.operate!(-, poly, 2) + @test isequal(poly, result) + # down from 576 using the generic method + @test (@allocated MA.operate!(-, poly2, 2)) <= 272 + # subsequent additions don't allocate + @test (@allocated MA.operate!(-, poly2, 2)) == 0 + end + + @testset "Polynomial + Variable" begin + poly = 2 * x^2 + 3 * x * y + z * y^2 + poly2 = copy(poly) + result = poly + x + MA.operate!(+, poly, x) + @test isequal(poly, result) + # down from 18752 using the generic method + # 368 or 304 depending on ordering, more for different version + # pre is especially bad + @test (@allocated MA.operate!(+, poly2, x)) <= 400 + # down from 1904 using the generic method + @test (@allocated MA.operate!(+, poly2, x)) <= 144 + + # also test `-` + poly = 2 * x^2 + 3 * x * y + z * y^2 + poly2 = copy(poly) + result = poly - x + MA.operate!(-, poly, x) + @test isequal(poly, result) + # down from 18752 using the generic method + # 368 or 304 depending on ordering + @test (@allocated MA.operate!(-, poly2, x)) <= 400 + # down from 1904 using the generic method + @test (@allocated MA.operate!(-, poly2, x)) <= 144 + end +end