Skip to content

Commit 18bad9f

Browse files
authored
Extend MPS solvers to trees (#44)
1 parent 0d64da8 commit 18bad9f

20 files changed

+1288
-783
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1212
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
1313
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
14+
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1415
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"

src/ITensorNetworks.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using IsApprox
1010
using ITensors
1111
using ITensors.ContractionSequenceOptimization
1212
using ITensors.ITensorVisualizationCore
13+
using IterTools
1314
using KrylovKit: KrylovKit
1415
using NamedGraphs
1516
using Observers
@@ -87,6 +88,8 @@ include(joinpath("treetensornetworks", "abstractprojttno.jl"))
8788
include(joinpath("treetensornetworks", "projttno.jl"))
8889
include(joinpath("treetensornetworks", "projttnosum.jl"))
8990
include(joinpath("treetensornetworks", "projttno_apply.jl"))
91+
# Compatibility of ITensors.MPS/MPO with tree sweeping routines
92+
include(joinpath("treetensornetworks", "solvers", "tree_patch.jl"))
9093
# Compatibility of ITensor observer and Observers
9194
# TODO: Delete this
9295
include(joinpath("treetensornetworks", "solvers", "update_observer.jl"))
@@ -103,10 +106,11 @@ include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
103106
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
104107
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
105108
include(joinpath("treetensornetworks", "solvers", "projmpo_apply.jl"))
106-
include(joinpath("treetensornetworks", "solvers", "contract_mpo_mps.jl"))
109+
include(joinpath("treetensornetworks", "solvers", "contract_operator_state.jl"))
107110
include(joinpath("treetensornetworks", "solvers", "projmps2.jl"))
108111
include(joinpath("treetensornetworks", "solvers", "projmpo_mps2.jl"))
109112
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
113+
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
110114

111115
include("exports.jl")
112116

src/treetensornetworks/solvers/contract_mpo_mps.jl

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
function contract_solver(; kwargs...)
2+
function solver(PH, t, psi; kws...)
3+
v = ITensor(1.0)
4+
for j in sites(PH)
5+
v *= PH.psi0[j]
6+
end
7+
Hpsi0 = contract(PH, v)
8+
return Hpsi0, nothing
9+
end
10+
return solver
11+
end
12+
13+
function ITensors.contract(
14+
::ITensors.Algorithm"fit",
15+
A::IsTreeOperator,
16+
psi0::ST;
17+
init_state=psi0,
18+
nsweeps=1,
19+
kwargs...,
20+
)::ST where {ST<:IsTreeState}
21+
n = nv(A)
22+
n != nv(psi0) && throw(
23+
DimensionMismatch("Number of sites operator ($n) and state ($(nv(psi0))) do not match"),
24+
)
25+
if n == 1
26+
v = only(vertices(psi0))
27+
return ST([A[v] * psi0[v]])
28+
end
29+
30+
check_hascommoninds(siteinds, A, psi0)
31+
32+
# In case A and psi0 have the same link indices
33+
A = sim(linkinds, A)
34+
35+
# Fix site and link inds of init_state
36+
init_state = deepcopy(init_state)
37+
init_state = sim(linkinds, init_state)
38+
for v in vertices(psi0)
39+
replaceinds!(
40+
init_state[v], siteinds(init_state, v), uniqueinds(siteinds(A, v), siteinds(psi0, v))
41+
)
42+
end
43+
44+
t = Inf
45+
reverse_step = false
46+
PH = proj_operator_apply(psi0, A)
47+
psi = tdvp(
48+
contract_solver(; kwargs...), PH, t, init_state; nsweeps, reverse_step, kwargs...
49+
)
50+
51+
return psi
52+
end
53+
54+
# extra ITensors overloads for tree tensor networks
55+
function ITensors.contract(A::TTNO, ψ::TTNS; alg="fit", kwargs...)
56+
return contract(ITensors.Algorithm(alg), A, ψ; kwargs...)
57+
end
58+
59+
function ITensors.apply(A::TTNO, ψ::TTNS; kwargs...)
60+
= contract(A, ψ; kwargs...)
61+
return replaceprime(Aψ, 1 => 0)
62+
end

src/treetensornetworks/solvers/dmrg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ function eigsolve_solver(; kwargs...)
1616
return solver
1717
end
1818

19-
function dmrg(H, psi0::MPS; kwargs...)
19+
function dmrg(H, psi0::IsTreeState; kwargs...)
2020
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
2121
reverse_step = false
2222
psi = tdvp(eigsolve_solver(; kwargs...), H, t, psi0; reverse_step, kwargs...)
2323
return psi
2424
end
2525

2626
# Alias for DMRG
27-
function eigsolve(H, psi0::MPS; kwargs...)
27+
function eigsolve(H, psi0::IsTreeState; kwargs...)
2828
return dmrg(H, psi0; kwargs...)
2929
end

src/treetensornetworks/solvers/dmrg_x.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function dmrg_x_solver(PH, t, psi0; kwargs...)
77
return U_max, nothing
88
end
99

10-
function dmrg_x(PH, psi0::MPS; reverse_step=false, kwargs...)
10+
function dmrg_x(PH, psi0::IsTreeState; reverse_step=false, kwargs...)
1111
t = Inf
1212
psi = tdvp(dmrg_x_solver, PH, t, psi0; reverse_step, kwargs...)
1313
return psi

src/treetensornetworks/solvers/projmpo_mps2.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,5 @@ end
4545
contract(P::ProjMPO_MPS2, v::ITensor) = contract(P.PH, v)
4646

4747
proj_mps(P::ProjMPO_MPS2) = [proj_mps(m) for m in P.Ms]
48+
49+
underlying_graph(P::ProjMPO_MPS2) = chain_lattice_graph(length(P.PH.H)) # tree patch

src/treetensornetworks/solvers/solver_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct TimeDependentSum{S,T}
2828
f::Vector{S}
2929
H0::T
3030
end
31-
TimeDependentSum(f::Vector, H0::ProjMPOSum) = TimeDependentSum(f, H0.pm)
31+
TimeDependentSum(f::Vector, H0::IsTreeProjOperatorSum) = TimeDependentSum(f, H0.pm)
3232
Base.length(H::TimeDependentSum) = length(H.f)
3333

3434
function Base.:*(c::Number, H::TimeDependentSum)

src/treetensornetworks/solvers/tdvp.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ function tdvp_solver(; kwargs...)
4343
end
4444
end
4545

46-
function tdvp(H, t::Number, psi0::MPS; kwargs...)
46+
function tdvp(H, t::Number, psi0::IsTreeState; kwargs...)
4747
return tdvp(tdvp_solver(; kwargs...), H, t, psi0; kwargs...)
4848
end
4949

50-
function tdvp(t::Number, H, psi0::MPS; kwargs...)
50+
function tdvp(t::Number, H, psi0::IsTreeState; kwargs...)
5151
return tdvp(H, t, psi0; kwargs...)
5252
end
5353

54-
function tdvp(H, psi0::MPS, t::Number; kwargs...)
54+
function tdvp(H, psi0::IsTreeState, t::Number; kwargs...)
5555
return tdvp(H, t, psi0; kwargs...)
5656
end

src/treetensornetworks/solvers/tdvp_generic.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function process_sweeps(; kwargs...)
4242
return (; maxdim, mindim, cutoff, noise)
4343
end
4444

45-
function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
45+
function tdvp(solver, PH, t::Number, psi0::IsTreeState; kwargs...)
4646
reverse_step = get(kwargs, :reverse_step, true)
4747

4848
nsweeps = _tdvp_compute_nsweeps(t; kwargs...)
@@ -124,37 +124,37 @@ function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
124124
end
125125

126126
"""
127-
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
128-
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
127+
tdvp(H::MPS,psi0::MPO,t::Number; kwargs...)
128+
tdvp(H::TTNS,psi0::TTNO,t::Number; kwargs...)
129129
130130
Use the time dependent variational principle (TDVP) algorithm
131131
to compute `exp(t*H)*psi0` using an efficient algorithm based
132-
on alternating optimization of the MPS tensors and local Krylov
132+
on alternating optimization of the state tensors and local Krylov
133133
exponentiation of H.
134134
135135
Returns:
136-
* `psi::MPS` - time-evolved MPS
136+
* `psi` - time-evolved state
137137
138138
Optional keyword arguments:
139139
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
140140
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
141141
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
142142
"""
143-
function tdvp(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
143+
function tdvp(solver, H::IsTreeOperator, t::Number, psi0::IsTreeState; kwargs...)
144144
check_hascommoninds(siteinds, H, psi0)
145145
check_hascommoninds(siteinds, H, psi0')
146146
# Permute the indices to have a better memory layout
147147
# and minimize permutations
148148
H = ITensors.permute(H, (linkind, siteinds, linkind))
149-
PH = ProjMPO(H)
149+
PH = proj_operator(H)
150150
return tdvp(solver, PH, t, psi0; kwargs...)
151151
end
152152

153-
function tdvp(solver, t::Number, H, psi0::MPS; kwargs...)
153+
function tdvp(solver, t::Number, H, psi0::IsTreeState; kwargs...)
154154
return tdvp(solver, H, t, psi0; kwargs...)
155155
end
156156

157-
function tdvp(solver, H, psi0::MPS, t::Number; kwargs...)
157+
function tdvp(solver, H, psi0::IsTreeState, t::Number; kwargs...)
158158
return tdvp(solver, H, t, psi0; kwargs...)
159159
end
160160

@@ -177,12 +177,14 @@ each step of the algorithm when optimizing the MPS.
177177
Returns:
178178
* `psi::MPS` - time-evolved MPS
179179
"""
180-
function tdvp(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
180+
function tdvp(
181+
solver, Hs::Vector{<:IsTreeOperator}, t::Number, psi0::IsTreeState; kwargs...
182+
)
181183
for H in Hs
182184
check_hascommoninds(siteinds, H, psi0)
183185
check_hascommoninds(siteinds, H, psi0')
184186
end
185187
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
186-
PHs = ProjMPOSum(Hs)
188+
PHs = proj_operator_sum(Hs)
187189
return tdvp(solver, PHs, t, psi0; kwargs...)
188190
end

0 commit comments

Comments
 (0)