Skip to content

Commit 05fca7c

Browse files
Add random_walk_pe (#273)
* Add randomwalk positional encoding * Add test randomwalkPE * export randomwalkPE * Fix degree Co-authored-by: Carlo Lucibello <[email protected]> * Fix initialization matrix Co-authored-by: Carlo Lucibello <[email protected]> * Change return * Add clearer adjacency matrix parameters Co-authored-by: Carlo Lucibello <[email protected]> * Fix dense_zeros_like * Rename function * Export correct function * Add new test compared with PyTorch * Add docstring --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 5b23e9c commit 05fca7c

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

src/GNNGraphs/GNNGraphs.jl

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ export add_nodes,
6262
set_edge_weight,
6363
to_bidirected,
6464
to_unidirected,
65+
random_walk_pe,
6566
# from Flux
6667
batch,
6768
unbatch,

src/GNNGraphs/transform.jl

+27
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,33 @@ function rand_edge_split(g::GNNGraph, frac; bidirected = is_bidirected(g))
694694
return g1, g2
695695
end
696696

697+
"""
698+
random_walk_pe(g, walk_length)
699+
700+
Return the random walk positional encoding from the paper [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875) of the given graph `g` and the length of the walk `walk_length` as a matrix of size `(walk_length, g.num_nodes)`.
701+
"""
702+
function random_walk_pe(g::GNNGraph, walk_length::Int)
703+
matrix = zeros(walk_length, g.num_nodes)
704+
adj = adjacency_matrix(g, Float32; dir = :out)
705+
matrix = dense_zeros_like(adj, Float32, (walk_length, g.num_nodes))
706+
deg = sum(adj, dims = 2) |> vec
707+
deg_inv = inv.(deg)
708+
deg_inv[isinf.(deg_inv)] .= 0
709+
RW = adj * Diagonal(deg_inv)
710+
out = RW
711+
matrix[1, :] .= diag(RW)
712+
for i in 2:walk_length
713+
out = out * RW
714+
matrix[i, :] .= diag(out)
715+
end
716+
return matrix
717+
end
718+
719+
dense_zeros_like(a::SparseMatrixCSC, T::Type, sz = size(a)) = zeros(T, sz)
720+
dense_zeros_like(a::AbstractArray, T::Type, sz = size(a)) = fill!(similar(a, T, sz), 0)
721+
dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
722+
dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
723+
697724
# """
698725
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
699726
# """

test/GNNGraphs/transform.jl

+22-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
@testset "add self-loops" begin
22
A = [1 1 0 0
3-
0 0 1 0
4-
0 0 0 1
5-
1 0 0 0]
3+
0 0 1 0
4+
0 0 0 1
5+
1 0 0 0]
66
A2 = [2 1 0 0
7-
0 1 1 0
8-
0 0 1 1
9-
1 0 0 1]
7+
0 1 1 0
8+
0 0 1 1
9+
1 0 0 1]
1010

1111
g = GNNGraph(A; graph_type = GRAPH_T)
1212
fg2 = add_self_loops(g)
@@ -18,7 +18,7 @@ end
1818

1919
@testset "batch" begin
2020
g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
21-
graph_type = GRAPH_T)
21+
graph_type = GRAPH_T)
2222
g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
2323
g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)
2424

@@ -44,7 +44,7 @@ end
4444
# Batch of batches
4545
g123123 = Flux.batch([g123, g123])
4646
@test g123123.graph_indicator ==
47-
[fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
47+
[fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
4848
@test g123123.num_graphs == 6
4949
end
5050

@@ -67,8 +67,8 @@ end
6767
c = 3
6868
ngraphs = 10
6969
gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n),
70-
graph_type = GRAPH_T)
71-
for _ in 1:ngraphs]
70+
graph_type = GRAPH_T)
71+
for _ in 1:ngraphs]
7272
gall = Flux.batch(gs)
7373
gs2 = Flux.unbatch(gall)
7474
@test gs2[1] == gs[1]
@@ -77,7 +77,7 @@ end
7777

7878
@testset "getgraph" begin
7979
g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
80-
graph_type = GRAPH_T)
80+
graph_type = GRAPH_T)
8181
g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
8282
g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)
8383
g = Flux.batch([g1, g2, g3])
@@ -268,3 +268,14 @@ end end
268268
@test nv(DG) == g.num_nodes
269269
@test ne(DG) == g.num_edges
270270
end
271+
272+
@testset "random_walk_pe" begin
273+
s = [1, 2, 2, 3]
274+
t = [2, 1, 3, 2]
275+
ndata = [-1, 0, 1]
276+
g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata)
277+
output = random_walk_pe(g, 3)
278+
@test output == [0.0 0.0 0.0
279+
0.5 1.0 0.5
280+
0.0 0.0 0.0]
281+
end

0 commit comments

Comments
 (0)