Skip to content

Commit 49c5439

Browse files
committed
Add inv for SequentialTransform
1 parent 81f50c2 commit 49c5439

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/sequential.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,12 @@ struct SequentialTransform <: Transform
1111
transforms::Vector{Transform}
1212
end
1313

14-
# AbstractTrees interface
15-
AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform
16-
AbstractTrees.children(s::SequentialTransform) = s.transforms
17-
18-
Base.show(io::IO, s::SequentialTransform) =
19-
print(io, join(s.transforms, ""))
20-
21-
function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform)
22-
tree = AbstractTrees.repr_tree(s, context=io)
23-
print(io, tree[begin:end-1]) # remove \n at end
24-
end
25-
2614
isrevertible(s::SequentialTransform) = all(isrevertible, s.transforms)
2715

2816
isinvertible(s::SequentialTransform) = all(isinvertible, s.transforms)
2917

18+
Base.inv(s::SequentialTransform) = SequentialTransform([inv(t) for t in reverse(s.transforms)])
19+
3020
function apply(s::SequentialTransform, table)
3121
allcache = []
3222
current = table
@@ -80,3 +70,15 @@ Create a [`SequentialTransform`](@ref) transform with
8070
(t1::Identity, t2::Identity) = Identity()
8171
(t1::Transform, t2::Identity) = t1
8272
(t1::Identity, t2::Transform) = t2
73+
74+
# AbstractTrees interface
75+
AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform
76+
AbstractTrees.children(s::SequentialTransform) = s.transforms
77+
78+
Base.show(io::IO, s::SequentialTransform) =
79+
print(io, join(s.transforms, ""))
80+
81+
function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform)
82+
tree = AbstractTrees.repr_tree(s, context=io)
83+
print(io, tree[begin:end-1]) # remove \n at end
84+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ using Test
55
@test TransformsBase.isrevertible(Identity())
66
@test TransformsBase.isinvertible(Identity())
77
@test inv(Identity()) == Identity()
8+
@test inv(Identity() Identity()) == Identity()
89
@test (Identity() Identity()) == Identity()
910
end

0 commit comments

Comments
 (0)