Skip to content

Commit 62d30e0

Browse files
authored
AbstractBeliefPropagationCache (#217)
1 parent 309c3f6 commit 62d30e0

13 files changed

+350
-256
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.11.26"
4+
version = "0.11.27"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ITensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include("edge_sequences.jl")
2828
include("formnetworks/abstractformnetwork.jl")
2929
include("formnetworks/bilinearformnetwork.jl")
3030
include("formnetworks/quadraticformnetwork.jl")
31+
include("caches/abstractbeliefpropagationcache.jl")
3132
include("caches/beliefpropagationcache.jl")
3233
include("contraction_tree_to_graph.jl")
3334
include("gauging.jl")
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
using Graphs: IsDirected
2+
using SplitApplyCombine: group
3+
using LinearAlgebra: diag, dot
4+
using ITensors: dir
5+
using ITensorMPS: ITensorMPS
6+
using NamedGraphs.PartitionedGraphs:
7+
PartitionedGraphs,
8+
PartitionedGraph,
9+
PartitionVertex,
10+
boundary_partitionedges,
11+
partitionvertices,
12+
partitionedges,
13+
unpartitioned_graph
14+
using SimpleTraits: SimpleTraits, Not, @traitfn
15+
using NDTensors: NDTensors
16+
17+
abstract type AbstractBeliefPropagationCache end
18+
19+
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
20+
sequence = optimal_contraction_sequence(contract_list)
21+
updated_messages = contract(contract_list; sequence, kwargs...)
22+
message_norm = norm(updated_messages)
23+
if normalize && !iszero(message_norm)
24+
updated_messages /= message_norm
25+
end
26+
return ITensor[updated_messages]
27+
end
28+
29+
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
30+
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
31+
lhs, rhs = contract(message_a), contract(message_b)
32+
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
33+
return 1 - f
34+
end
35+
36+
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
37+
default_messages(ptn::PartitionedGraph) = Dictionary()
38+
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
39+
@traitfn function default_bp_maxiter(g::::IsDirected)
40+
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
41+
end
42+
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
43+
function default_partitioned_vertices(f::AbstractFormNetwork)
44+
return group(v -> original_state_vertex(f, v), vertices(f))
45+
end
46+
47+
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
48+
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
49+
function default_message(
50+
bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
51+
)
52+
return not_implemented()
53+
end
54+
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
55+
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
56+
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
57+
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
58+
return not_implemented()
59+
end
60+
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
61+
return not_implemented()
62+
end
63+
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
64+
return not_implemented()
65+
end
66+
function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...)
67+
return not_implemented()
68+
end
69+
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...)
70+
return not_implemented()
71+
end
72+
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
73+
partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()
74+
75+
function default_edge_sequence(
76+
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
77+
)
78+
return default_edge_sequence(Algorithm(alg), bpc)
79+
end
80+
function default_bp_maxiter(
81+
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
82+
)
83+
return default_bp_maxiter(Algorithm(alg), bpc)
84+
end
85+
function default_message_update_kwargs(
86+
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
87+
)
88+
return default_message_update_kwargs(Algorithm(alg), bpc)
89+
end
90+
91+
function tensornetwork(bpc::AbstractBeliefPropagationCache)
92+
return unpartitioned_graph(partitioned_tensornetwork(bpc))
93+
end
94+
95+
function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
96+
return ITensor[tensornetwork(bpc)[v] for v in verts]
97+
end
98+
99+
function factors(
100+
bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex}
101+
)
102+
return factors(bpc, vertices(bpc, partition_verts))
103+
end
104+
105+
function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex)
106+
return factors(bpc, [partition_vertex])
107+
end
108+
109+
function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...)
110+
return map(pv -> region_scalar(bpc, pv; kwargs...), pvs)
111+
end
112+
113+
function edge_scalars(
114+
bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
115+
)
116+
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
117+
end
118+
119+
function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache)
120+
return vertex_scalars(bpc), edge_scalars(bpc)
121+
end
122+
123+
function incoming_messages(
124+
bpc::AbstractBeliefPropagationCache,
125+
partition_vertices::Vector{<:PartitionVertex};
126+
ignore_edges=(),
127+
)
128+
bpes = boundary_partitionedges(bpc, partition_vertices; dir=:in)
129+
ms = messages(bpc, setdiff(bpes, ignore_edges))
130+
return reduce(vcat, ms; init=ITensor[])
131+
end
132+
133+
function incoming_messages(
134+
bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
135+
)
136+
return incoming_messages(bpc, [partition_vertex]; kwargs...)
137+
end
138+
139+
#Forward from partitioned graph
140+
for f in [
141+
:(PartitionedGraphs.partitioned_graph),
142+
:(PartitionedGraphs.partitionedge),
143+
:(PartitionedGraphs.partitionvertices),
144+
:(PartitionedGraphs.vertices),
145+
:(PartitionedGraphs.boundary_partitionedges),
146+
:(ITensorMPS.linkinds),
147+
]
148+
@eval begin
149+
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
150+
return $f(partitioned_tensornetwork(bpc), args...; kwargs...)
151+
end
152+
end
153+
end
154+
155+
NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))
156+
157+
"""
158+
Update the tensornetwork inside the cache
159+
"""
160+
function update_factors(bpc::AbstractBeliefPropagationCache, factors)
161+
bpc = copy(bpc)
162+
tn = tensornetwork(bpc)
163+
for vertex in eachindex(factors)
164+
# TODO: Add a check that this preserves the graph structure.
165+
setindex_preserve_graph!(tn, factors[vertex], vertex)
166+
end
167+
return bpc
168+
end
169+
170+
function update_factor(bpc, vertex, factor)
171+
return update_factors(bpc, Dictionary([vertex], [factor]))
172+
end
173+
174+
function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
175+
mts = messages(bpc)
176+
return get(() -> default_message(bpc, edge; kwargs...), mts, edge)
177+
end
178+
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
179+
return map(edge -> message(bpc, edge; kwargs...), edges)
180+
end
181+
function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
182+
bpc = copy(bpc)
183+
ms = messages(bpc)
184+
set!(ms, pe, message)
185+
return bpc
186+
end
187+
188+
"""
189+
Compute message tensor as product of incoming mts and local state
190+
"""
191+
function updated_message(
192+
bpc::AbstractBeliefPropagationCache,
193+
edge::PartitionEdge;
194+
message_update_function=default_message_update,
195+
message_update_function_kwargs=(;),
196+
)
197+
vertex = src(edge)
198+
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
199+
state = factors(bpc, vertex)
200+
201+
return message_update_function(
202+
ITensor[incoming_ms; state]; message_update_function_kwargs...
203+
)
204+
end
205+
206+
function update(
207+
alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
208+
)
209+
return set_message(bpc, edge, updated_message(bpc, edge; kwargs...))
210+
end
211+
212+
"""
213+
Do a sequential update of the message tensors on `edges`
214+
"""
215+
function update(
216+
alg::Algorithm,
217+
bpc::AbstractBeliefPropagationCache,
218+
edges::Vector;
219+
(update_diff!)=nothing,
220+
kwargs...,
221+
)
222+
bpc = copy(bpc)
223+
for e in edges
224+
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
225+
bpc = update(alg, bpc, e; kwargs...)
226+
if !isnothing(update_diff!)
227+
update_diff![] += message_diff(message(bpc, e), prev_message)
228+
end
229+
end
230+
return bpc
231+
end
232+
233+
"""
234+
Do parallel updates between groups of edges of all message tensors
235+
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
236+
mts relevant to that group.
237+
"""
238+
function update(
239+
alg::Algorithm,
240+
bpc::AbstractBeliefPropagationCache,
241+
edge_groups::Vector{<:Vector{<:PartitionEdge}};
242+
kwargs...,
243+
)
244+
new_mts = copy(messages(bpc))
245+
for edges in edge_groups
246+
bpc_t = update(alg, bpc, edges; kwargs...)
247+
for e in edges
248+
new_mts[e] = message(bpc_t, e)
249+
end
250+
end
251+
return set_messages(bpc, new_mts)
252+
end
253+
254+
"""
255+
More generic interface for update, with default params
256+
"""
257+
function update(
258+
alg::Algorithm,
259+
bpc::AbstractBeliefPropagationCache;
260+
edges=default_edge_sequence(alg, bpc),
261+
maxiter=default_bp_maxiter(alg, bpc),
262+
message_update_kwargs=default_message_update_kwargs(alg, bpc),
263+
tol=nothing,
264+
verbose=false,
265+
)
266+
compute_error = !isnothing(tol)
267+
if isnothing(maxiter)
268+
error("You need to specify a number of iterations for BP!")
269+
end
270+
for i in 1:maxiter
271+
diff = compute_error ? Ref(0.0) : nothing
272+
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
273+
if compute_error && (diff.x / length(edges)) <= tol
274+
if verbose
275+
println("BP converged to desired precision after $i iterations.")
276+
end
277+
break
278+
end
279+
end
280+
return bpc
281+
end
282+
283+
function update(
284+
bpc::AbstractBeliefPropagationCache;
285+
alg::String=default_message_update_alg(bpc),
286+
kwargs...,
287+
)
288+
return update(Algorithm(alg), bpc; kwargs...)
289+
end

0 commit comments

Comments
 (0)