diff --git a/src/permutations.jl b/src/permutations.jl index 93e1541..edab08a 100644 --- a/src/permutations.jl +++ b/src/permutations.jl @@ -15,44 +15,26 @@ struct Permutations{T} length::Int end -function has_repeats(state::Vector{Int}) - # This can be safely marked inbounds because of the type restriction in the signature. - # If the type restriction is ever loosened, please check safety of the `@inbounds`. - @inbounds for outer in eachindex(state) - for inner in (outer+1):lastindex(state) - if state[outer] == state[inner] - return true - end - end - end - return false -end - -function increment!(state::Vector{Int}, min::Int, max::Int) - # All array indexing can be marked inbounds because of the type restriction in the signature. - # If the type restriction is ever loosened, please check safety of the `@inbounds`. - @inbounds state[end] += 1 - i = lastindex(state) - @inbounds while i > firstindex(state) && state[i] > max - state[i] = min - state[i-1] += 1 - i -= 1 - end -end - -function next_permutation!(state::Vector{Int}, min::Int, max::Int) - while true - increment!(state, min, max) - has_repeats(state) || break - end -end - -function Base.iterate(p::Permutations, state::Vector{Int}=fill(firstindex(p.data), p.length)) - next_permutation!(state, firstindex(p.data), lastindex(p.data)) - if first(state) > lastindex(p.data) - return nothing +# The following code basically implements `permutations` in terms of `multiset_permutations` as +# +# permutations(a, t::Integer=length(a)) = Iterators.map( +# indices -> [a[i] for i in indices], +# multiset_permutations(eachindex(a), t)) +# +# with the difference that we can also define `eltype(::Permutations)`, which is used in some tests. + +function Base.iterate(p::Permutations, state=nothing) + if state === nothing + mp = multiset_permutations(collect(eachindex(p.data)), p.length) + it = iterate(mp) + if it === nothing return nothing end + else + mp, mp_state = state + it = iterate(mp, mp_state) + if it === nothing return nothing end end - [p.data[i] for i in state], state + indices, mp_state = it + return [p.data[i] for i in indices], (mp=mp, mp_state=mp_state) end function Base.length(p::Permutations)