|
| 1 | +using Graphs: IsDirected |
| 2 | +using SplitApplyCombine: group |
| 3 | +using LinearAlgebra: diag, dot |
| 4 | +using ITensors: dir |
| 5 | +using ITensorMPS: ITensorMPS |
| 6 | +using NamedGraphs.PartitionedGraphs: |
| 7 | + PartitionedGraphs, |
| 8 | + PartitionedGraph, |
| 9 | + PartitionVertex, |
| 10 | + boundary_partitionedges, |
| 11 | + partitionvertices, |
| 12 | + partitionedges, |
| 13 | + unpartitioned_graph |
| 14 | +using SimpleTraits: SimpleTraits, Not, @traitfn |
| 15 | +using NDTensors: NDTensors |
| 16 | + |
| 17 | +abstract type AbstractBeliefPropagationCache end |
| 18 | + |
| 19 | +function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) |
| 20 | + sequence = optimal_contraction_sequence(contract_list) |
| 21 | + updated_messages = contract(contract_list; sequence, kwargs...) |
| 22 | + message_norm = norm(updated_messages) |
| 23 | + if normalize && !iszero(message_norm) |
| 24 | + updated_messages /= message_norm |
| 25 | + end |
| 26 | + return ITensor[updated_messages] |
| 27 | +end |
| 28 | + |
| 29 | +#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages |
| 30 | +function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) |
| 31 | + lhs, rhs = contract(message_a), contract(message_b) |
| 32 | + f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs))) |
| 33 | + return 1 - f |
| 34 | +end |
| 35 | + |
| 36 | +default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e] |
| 37 | +default_messages(ptn::PartitionedGraph) = Dictionary() |
| 38 | +@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing |
| 39 | +@traitfn function default_bp_maxiter(g::::IsDirected) |
| 40 | + return default_bp_maxiter(undirected_graph(underlying_graph(g))) |
| 41 | +end |
| 42 | +default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) |
| 43 | +function default_partitioned_vertices(f::AbstractFormNetwork) |
| 44 | + return group(v -> original_state_vertex(f, v), vertices(f)) |
| 45 | +end |
| 46 | + |
| 47 | +partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 48 | +messages(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 49 | +function default_message( |
| 50 | + bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs... |
| 51 | +) |
| 52 | + return not_implemented() |
| 53 | +end |
| 54 | +default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 55 | +Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 56 | +default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 57 | +function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache) |
| 58 | + return not_implemented() |
| 59 | +end |
| 60 | +function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache) |
| 61 | + return not_implemented() |
| 62 | +end |
| 63 | +function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...) |
| 64 | + return not_implemented() |
| 65 | +end |
| 66 | +function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...) |
| 67 | + return not_implemented() |
| 68 | +end |
| 69 | +function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...) |
| 70 | + return not_implemented() |
| 71 | +end |
| 72 | +partitions(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 73 | +partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented() |
| 74 | + |
| 75 | +function default_edge_sequence( |
| 76 | + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) |
| 77 | +) |
| 78 | + return default_edge_sequence(Algorithm(alg), bpc) |
| 79 | +end |
| 80 | +function default_bp_maxiter( |
| 81 | + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) |
| 82 | +) |
| 83 | + return default_bp_maxiter(Algorithm(alg), bpc) |
| 84 | +end |
| 85 | +function default_message_update_kwargs( |
| 86 | + bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc) |
| 87 | +) |
| 88 | + return default_message_update_kwargs(Algorithm(alg), bpc) |
| 89 | +end |
| 90 | + |
| 91 | +function tensornetwork(bpc::AbstractBeliefPropagationCache) |
| 92 | + return unpartitioned_graph(partitioned_tensornetwork(bpc)) |
| 93 | +end |
| 94 | + |
| 95 | +function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) |
| 96 | + return ITensor[tensornetwork(bpc)[v] for v in verts] |
| 97 | +end |
| 98 | + |
| 99 | +function factors( |
| 100 | + bpc::AbstractBeliefPropagationCache, partition_verts::Vector{<:PartitionVertex} |
| 101 | +) |
| 102 | + return factors(bpc, vertices(bpc, partition_verts)) |
| 103 | +end |
| 104 | + |
| 105 | +function factors(bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex) |
| 106 | + return factors(bpc, [partition_vertex]) |
| 107 | +end |
| 108 | + |
| 109 | +function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...) |
| 110 | + return map(pv -> region_scalar(bpc, pv; kwargs...), pvs) |
| 111 | +end |
| 112 | + |
| 113 | +function edge_scalars( |
| 114 | + bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs... |
| 115 | +) |
| 116 | + return map(pe -> region_scalar(bpc, pe; kwargs...), pes) |
| 117 | +end |
| 118 | + |
| 119 | +function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache) |
| 120 | + return vertex_scalars(bpc), edge_scalars(bpc) |
| 121 | +end |
| 122 | + |
| 123 | +function incoming_messages( |
| 124 | + bpc::AbstractBeliefPropagationCache, |
| 125 | + partition_vertices::Vector{<:PartitionVertex}; |
| 126 | + ignore_edges=(), |
| 127 | +) |
| 128 | + bpes = boundary_partitionedges(bpc, partition_vertices; dir=:in) |
| 129 | + ms = messages(bpc, setdiff(bpes, ignore_edges)) |
| 130 | + return reduce(vcat, ms; init=ITensor[]) |
| 131 | +end |
| 132 | + |
| 133 | +function incoming_messages( |
| 134 | + bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... |
| 135 | +) |
| 136 | + return incoming_messages(bpc, [partition_vertex]; kwargs...) |
| 137 | +end |
| 138 | + |
| 139 | +#Forward from partitioned graph |
| 140 | +for f in [ |
| 141 | + :(PartitionedGraphs.partitioned_graph), |
| 142 | + :(PartitionedGraphs.partitionedge), |
| 143 | + :(PartitionedGraphs.partitionvertices), |
| 144 | + :(PartitionedGraphs.vertices), |
| 145 | + :(PartitionedGraphs.boundary_partitionedges), |
| 146 | + :(ITensorMPS.linkinds), |
| 147 | +] |
| 148 | + @eval begin |
| 149 | + function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) |
| 150 | + return $f(partitioned_tensornetwork(bpc), args...; kwargs...) |
| 151 | + end |
| 152 | + end |
| 153 | +end |
| 154 | + |
| 155 | +NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) |
| 156 | + |
| 157 | +""" |
| 158 | +Update the tensornetwork inside the cache |
| 159 | +""" |
| 160 | +function update_factors(bpc::AbstractBeliefPropagationCache, factors) |
| 161 | + bpc = copy(bpc) |
| 162 | + tn = tensornetwork(bpc) |
| 163 | + for vertex in eachindex(factors) |
| 164 | + # TODO: Add a check that this preserves the graph structure. |
| 165 | + setindex_preserve_graph!(tn, factors[vertex], vertex) |
| 166 | + end |
| 167 | + return bpc |
| 168 | +end |
| 169 | + |
| 170 | +function update_factor(bpc, vertex, factor) |
| 171 | + return update_factors(bpc, Dictionary([vertex], [factor])) |
| 172 | +end |
| 173 | + |
| 174 | +function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) |
| 175 | + mts = messages(bpc) |
| 176 | + return get(() -> default_message(bpc, edge; kwargs...), mts, edge) |
| 177 | +end |
| 178 | +function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) |
| 179 | + return map(edge -> message(bpc, edge; kwargs...), edges) |
| 180 | +end |
| 181 | +function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message) |
| 182 | + bpc = copy(bpc) |
| 183 | + ms = messages(bpc) |
| 184 | + set!(ms, pe, message) |
| 185 | + return bpc |
| 186 | +end |
| 187 | + |
| 188 | +""" |
| 189 | +Compute message tensor as product of incoming mts and local state |
| 190 | +""" |
| 191 | +function updated_message( |
| 192 | + bpc::AbstractBeliefPropagationCache, |
| 193 | + edge::PartitionEdge; |
| 194 | + message_update_function=default_message_update, |
| 195 | + message_update_function_kwargs=(;), |
| 196 | +) |
| 197 | + vertex = src(edge) |
| 198 | + incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) |
| 199 | + state = factors(bpc, vertex) |
| 200 | + |
| 201 | + return message_update_function( |
| 202 | + ITensor[incoming_ms; state]; message_update_function_kwargs... |
| 203 | + ) |
| 204 | +end |
| 205 | + |
| 206 | +function update( |
| 207 | + alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs... |
| 208 | +) |
| 209 | + return set_message(bpc, edge, updated_message(bpc, edge; kwargs...)) |
| 210 | +end |
| 211 | + |
| 212 | +""" |
| 213 | +Do a sequential update of the message tensors on `edges` |
| 214 | +""" |
| 215 | +function update( |
| 216 | + alg::Algorithm, |
| 217 | + bpc::AbstractBeliefPropagationCache, |
| 218 | + edges::Vector; |
| 219 | + (update_diff!)=nothing, |
| 220 | + kwargs..., |
| 221 | +) |
| 222 | + bpc = copy(bpc) |
| 223 | + for e in edges |
| 224 | + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing |
| 225 | + bpc = update(alg, bpc, e; kwargs...) |
| 226 | + if !isnothing(update_diff!) |
| 227 | + update_diff![] += message_diff(message(bpc, e), prev_message) |
| 228 | + end |
| 229 | + end |
| 230 | + return bpc |
| 231 | +end |
| 232 | + |
| 233 | +""" |
| 234 | +Do parallel updates between groups of edges of all message tensors |
| 235 | +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the |
| 236 | +mts relevant to that group. |
| 237 | +""" |
| 238 | +function update( |
| 239 | + alg::Algorithm, |
| 240 | + bpc::AbstractBeliefPropagationCache, |
| 241 | + edge_groups::Vector{<:Vector{<:PartitionEdge}}; |
| 242 | + kwargs..., |
| 243 | +) |
| 244 | + new_mts = copy(messages(bpc)) |
| 245 | + for edges in edge_groups |
| 246 | + bpc_t = update(alg, bpc, edges; kwargs...) |
| 247 | + for e in edges |
| 248 | + new_mts[e] = message(bpc_t, e) |
| 249 | + end |
| 250 | + end |
| 251 | + return set_messages(bpc, new_mts) |
| 252 | +end |
| 253 | + |
| 254 | +""" |
| 255 | +More generic interface for update, with default params |
| 256 | +""" |
| 257 | +function update( |
| 258 | + alg::Algorithm, |
| 259 | + bpc::AbstractBeliefPropagationCache; |
| 260 | + edges=default_edge_sequence(alg, bpc), |
| 261 | + maxiter=default_bp_maxiter(alg, bpc), |
| 262 | + message_update_kwargs=default_message_update_kwargs(alg, bpc), |
| 263 | + tol=nothing, |
| 264 | + verbose=false, |
| 265 | +) |
| 266 | + compute_error = !isnothing(tol) |
| 267 | + if isnothing(maxiter) |
| 268 | + error("You need to specify a number of iterations for BP!") |
| 269 | + end |
| 270 | + for i in 1:maxiter |
| 271 | + diff = compute_error ? Ref(0.0) : nothing |
| 272 | + bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...) |
| 273 | + if compute_error && (diff.x / length(edges)) <= tol |
| 274 | + if verbose |
| 275 | + println("BP converged to desired precision after $i iterations.") |
| 276 | + end |
| 277 | + break |
| 278 | + end |
| 279 | + end |
| 280 | + return bpc |
| 281 | +end |
| 282 | + |
| 283 | +function update( |
| 284 | + bpc::AbstractBeliefPropagationCache; |
| 285 | + alg::String=default_message_update_alg(bpc), |
| 286 | + kwargs..., |
| 287 | +) |
| 288 | + return update(Algorithm(alg), bpc; kwargs...) |
| 289 | +end |
0 commit comments