Skip to content

Commit 0d64da8

Browse files
authored
Additions needed for sweeping algorithms on trees
1 parent c0f0e8b commit 0d64da8

26 files changed

+2138
-30
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1212
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
13+
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
1314
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
1617
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
1718
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1819
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1920
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
21+
SparseArrayKit = "a9a3c162-d163-4c15-8926-b8794fbefed2"
2022
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2124
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
2225
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2326

src/ITensorNetworks.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Dictionaries
66
using DocStringExtensions
77
using Graphs
88
using Graphs.SimpleGraphs # AbstractSimpleGraph
9+
using IsApprox
910
using ITensors
1011
using ITensors.ContractionSequenceOptimization
1112
using ITensors.ITensorVisualizationCore
@@ -15,7 +16,9 @@ using Observers
1516
using Printf
1617
using Requires
1718
using SimpleTraits
19+
using SparseArrayKit
1820
using SplitApplyCombine
21+
using StaticArrays
1922
using Suppressor
2023
using TimerOutputs
2124

@@ -27,6 +30,7 @@ using ITensors:
2730
@timeit_debug,
2831
AbstractMPS,
2932
Algorithm,
33+
OneITensor,
3034
check_hascommoninds,
3135
commontags,
3236
orthocenter,
@@ -69,11 +73,20 @@ include("expect.jl")
6973
include("models.jl")
7074
include("tebd.jl")
7175
include("itensornetwork.jl")
76+
include("utility.jl")
7277
include("specialitensornetworks.jl")
7378
include("renameitensornetwork.jl")
7479
include("boundarymps.jl")
7580
include("beliefpropagation.jl")
76-
include(joinpath("treetensornetworks", "treetensornetwork.jl"))
81+
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
82+
# include(joinpath("treetensornetworks", "treetensornetwork.jl"))
83+
include(joinpath("treetensornetworks", "ttns.jl"))
84+
include(joinpath("treetensornetworks", "ttno.jl"))
85+
include(joinpath("treetensornetworks", "opsum_to_ttno.jl"))
86+
include(joinpath("treetensornetworks", "abstractprojttno.jl"))
87+
include(joinpath("treetensornetworks", "projttno.jl"))
88+
include(joinpath("treetensornetworks", "projttnosum.jl"))
89+
include(joinpath("treetensornetworks", "projttno_apply.jl"))
7790
# Compatibility of ITensor observer and Observers
7891
# TODO: Delete this
7992
include(joinpath("treetensornetworks", "solvers", "update_observer.jl"))

src/abstractindsnetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ edge_data_type(::Type{<:AbstractIndsNetwork{V,I}}) where {V,I} = Vector{I}
1919

2020
function uniqueinds(is::AbstractIndsNetwork, edge::AbstractEdge)
2121
inds = IndexSet(get(is, src(edge), Index[]))
22-
for ei in setdiff(incident_edges(is, src(edge)...), [edge])
22+
for ei in setdiff(incident_edges(is, src(edge)), [edge])
2323
inds = unioninds(inds, get(is, ei, Index[]))
2424
end
2525
return inds

src/abstractitensornetwork.jl

Lines changed: 174 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ end
4545
# Iteration
4646
#
4747

48+
# TODO: iteration
49+
50+
# TODO: different `map` functionalities as defined for ITensors.AbstractMPS
51+
52+
# TODO: broadcasting
53+
4854
function union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kwargs...)
4955
tn = ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
5056
# Add any new edges that are introduced during the union
@@ -104,6 +110,36 @@ end
104110
# Convenience wrapper
105111
itensors(tn::AbstractITensorNetwork) = Vector{ITensor}(tn)
106112

113+
#
114+
# Promotion and conversion
115+
#
116+
117+
function LinearAlgebra.promote_leaf_eltypes(tn::AbstractITensorNetwork)
118+
return LinearAlgebra.promote_leaf_eltypes(itensors(tn))
119+
end
120+
121+
function ITensors.promote_itensor_eltype(tn::AbstractITensorNetwork)
122+
return LinearAlgebra.promote_leaf_eltypes(tn)
123+
end
124+
125+
ITensors.scalartype(tn::AbstractITensorNetwork) = LinearAlgebra.promote_leaf_eltypes(tn)
126+
127+
# TODO: eltype(::AbstractITensorNetwork) (cannot behave the same as eltype(::ITensors.AbstractMPS))
128+
129+
# TODO: mimic ITensors.AbstractMPS implementation using map
130+
function ITensors.convert_leaf_eltype(eltype::Type, tn::AbstractITensorNetwork)
131+
tn = copy(tn)
132+
vertex_data(tn) .= convert_eltype.(Ref(eltype), vertex_data(tn))
133+
return tn
134+
end
135+
136+
# TODO: mimic ITensors.AbstractMPS implementation using map
137+
function NDTensors.convert_scalartype(eltype::Type{<:Number}, tn::AbstractITensorNetwork)
138+
tn = copy(tn)
139+
vertex_data(tn) .= ITensors.adapt.(Ref(eltype), vertex_data(tn))
140+
return tn
141+
end
142+
107143
#
108144
# Conversion to Graphs
109145
#
@@ -185,11 +221,13 @@ end
185221
function replaceinds(tn::AbstractITensorNetwork, is_is′::Pair{<:IndsNetwork,<:IndsNetwork})
186222
tn = copy(tn)
187223
is, is′ = is_is′
188-
# TODO: Check that `is` and `is′` have the same vertices and edges.
224+
@assert underlying_graph(is) == underlying_graph(is′)
189225
for v in vertices(is)
226+
isassigned(is, v) || continue
190227
setindex_preserve_graph!(tn, replaceinds(tn[v], is[v] => is′[v]), v)
191228
end
192229
for e in edges(is)
230+
isassigned(is, e) || continue
193231
for v in (src(e), dst(e))
194232
setindex_preserve_graph!(tn, replaceinds(tn[v], is[e] => is′[e]), v)
195233
end
@@ -208,7 +246,7 @@ const map_inds_label_functions = [
208246
:setprime,
209247
:noprime,
210248
:replaceprime,
211-
:swapprime,
249+
# :swapprime, # TODO: add @test_broken as a reminder
212250
:addtags,
213251
:removetags,
214252
:replacetags,
@@ -227,6 +265,24 @@ for f in map_inds_label_functions
227265
function $f(n::Union{IndsNetwork,AbstractITensorNetwork}, args...; kwargs...)
228266
return map_inds($f, n, args...; kwargs...)
229267
end
268+
269+
function $f(
270+
ffilter::typeof(linkinds),
271+
n::Union{IndsNetwork,AbstractITensorNetwork},
272+
args...;
273+
kwargs...,
274+
)
275+
return map_inds($f, n, args...; sites=[], kwargs...)
276+
end
277+
278+
function $f(
279+
ffilter::typeof(siteinds),
280+
n::Union{IndsNetwork,AbstractITensorNetwork},
281+
args...;
282+
kwargs...,
283+
)
284+
return map_inds($f, n, args...; links=[], kwargs...)
285+
end
230286
end
231287
end
232288

@@ -402,12 +458,19 @@ function factorize(
402458
return factorize(tn, edgetype(tn)(edge); kwargs...)
403459
end
404460

405-
# For ambiguity error
461+
# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
406462
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
407-
tn = factorize(tn, edge; kwargs...)
408-
# TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
409-
new_vertex = only(neighbors(tn, src(edge)) neighbors(tn, dst(edge)))
410-
return contract(tn, new_vertex => dst(edge))
463+
# tn = factorize(tn, edge; kwargs...)
464+
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
465+
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
466+
# return contract(tn, new_vertex => dst(edge))
467+
tn = copy(tn)
468+
left_inds = uniqueinds(tn, edge)
469+
ltags = tags(tn, edge)
470+
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
471+
tn[src(edge)] = X
472+
tn[dst(edge)] *= Y
473+
return tn
411474
end
412475

413476
function orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
@@ -429,6 +492,25 @@ function orthogonalize(ψ::AbstractITensorNetwork, source_vertex)
429492
return ψ
430493
end
431494

495+
# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
496+
function _truncate_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
497+
tn = copy(tn)
498+
left_inds = uniqueinds(tn, edge)
499+
ltags = tags(tn, edge)
500+
U, S, V = svd(tn[src(edge)], left_inds; lefttags=ltags, ortho="left", kwargs...)
501+
tn[src(edge)] = U
502+
tn[dst(edge)] *= (S * V)
503+
return tn
504+
end
505+
506+
function truncate(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
507+
return _truncate_edge(tn, edge; kwargs...)
508+
end
509+
510+
function truncate(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
511+
return truncate(tn, edgetype(tn)(edge); kwargs...)
512+
end
513+
432514
function Base.:*(c::Number, ψ::AbstractITensorNetwork)
433515
v₁ = first(vertices(ψ))
434516
= copy(ψ)
@@ -572,6 +654,91 @@ function visualize(
572654
return visualize(Vector{ITensor}(tn), args...; vertex_labels, kwargs...)
573655
end
574656

657+
#
658+
# Link dimensions
659+
#
660+
661+
function maxlinkdim(tn::AbstractITensorNetwork)
662+
md = 1
663+
for e in edges(tn)
664+
md = max(md, linkdim(tn, e))
665+
end
666+
return md
667+
end
668+
669+
function linkdim(tn::AbstractITensorNetwork, edge::Pair)
670+
return linkdim(tn, edgetype(tn)(edge))
671+
end
672+
673+
function linkdim(tn::AbstractITensorNetwork{V}, edge::AbstractEdge{V}) where {V}
674+
ls = linkinds(tn, edge)
675+
return prod([isnothing(l) ? 1 : dim(l) for l in ls])
676+
end
677+
678+
function linkdims(tn::AbstractITensorNetwork{V}) where {V}
679+
ld = DataGraph{V,Any,Int}(copy(underlying_graph(tn)))
680+
for e in edges(ld)
681+
ld[e] = linkdim(tn, e)
682+
end
683+
return ld
684+
end
685+
686+
#
687+
# Common index checking
688+
#
689+
690+
function hascommoninds(
691+
::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
692+
) where {V}
693+
for v in vertices(A)
694+
!hascommoninds(siteinds(A, v), siteinds(B, v)) && return false
695+
end
696+
return true
697+
end
698+
699+
function check_hascommoninds(
700+
::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
701+
) where {V}
702+
N = nv(A)
703+
if nv(B) N
704+
throw(
705+
DimensionMismatch(
706+
"$(typeof(A)) and $(typeof(B)) have mismatched number of vertices $N and $(nv(B))."
707+
),
708+
)
709+
end
710+
for v in vertices(A)
711+
!hascommoninds(siteinds(A, v), siteinds(B, v)) && error(
712+
"$(typeof(A)) A and $(typeof(B)) B must share site indices. On vertex $v, A has site indices $(siteinds(A, v)) while B has site indices $(siteinds(B, v)).",
713+
)
714+
end
715+
return nothing
716+
end
717+
718+
function hassameinds(
719+
::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
720+
) where {V}
721+
nv(A) nv(B) && return false
722+
for v in vertices(A)
723+
!ITensors.hassameinds(siteinds(A, v), siteinds(B, v)) && return false
724+
end
725+
return true
726+
end
727+
728+
#
729+
# Site combiners
730+
#
731+
732+
# TODO: will be broken, fix this
733+
function site_combiners(tn::AbstractITensorNetwork{V}) where {V}
734+
Cs = DataGraph{V,ITensor}(copy(underlying_graph(tn)))
735+
for v in vertices(tn)
736+
s = siteinds(tn, v)
737+
Cs[v] = combiner(s; tags=commontags(s))
738+
end
739+
return Cs
740+
end
741+
575742
## # TODO: should this make sure that internal indices
576743
## # don't clash?
577744
## function hvncat(

src/expect.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@ function expect(
55
maxdim=nothing,
66
ortho=false,
77
sequence=nothing,
8+
sites=vertices(ψ),
89
)
910
s = siteinds(ψ)
10-
res = Dictionary(vertices(ψ), Vector{Float64}(undef, nv(ψ)))
11+
ElT = promote_itensor_eltype(ψ)
12+
# ElT = ishermitian(ITensors.op(op, s[sites[1]])) ? real(ElT) : ElT
13+
res = Dictionary(sites, Vector{ElT}(undef, length(sites)))
1114
if isnothing(sequence)
1215
sequence = contraction_sequence(inner_network(ψ, ψ; flatten=true))
1316
end
1417
normψ² = norm_sqr(ψ; sequence)
15-
for v in vertices(ψ)
18+
for v in sites
1619
O = ITensor(Op(op, v), s)
1720
= apply(O, ψ; cutoff, maxdim, ortho)
1821
res[v] = contract_inner(ψ, Oψ; sequence) / normψ²

0 commit comments

Comments
 (0)