|
1 | 1 | @testset "add self-loops" begin
|
2 | 2 | A = [1 1 0 0
|
3 |
| - 0 0 1 0 |
4 |
| - 0 0 0 1 |
5 |
| - 1 0 0 0] |
| 3 | + 0 0 1 0 |
| 4 | + 0 0 0 1 |
| 5 | + 1 0 0 0] |
6 | 6 | A2 = [2 1 0 0
|
7 |
| - 0 1 1 0 |
8 |
| - 0 0 1 1 |
9 |
| - 1 0 0 1] |
| 7 | + 0 1 1 0 |
| 8 | + 0 0 1 1 |
| 9 | + 1 0 0 1] |
10 | 10 |
|
11 | 11 | g = GNNGraph(A; graph_type = GRAPH_T)
|
12 | 12 | fg2 = add_self_loops(g)
|
|
18 | 18 |
|
19 | 19 | @testset "batch" begin
|
20 | 20 | g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
|
21 |
| - graph_type = GRAPH_T) |
| 21 | + graph_type = GRAPH_T) |
22 | 22 | g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
|
23 | 23 | g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)
|
24 | 24 |
|
|
44 | 44 | # Batch of batches
|
45 | 45 | g123123 = Flux.batch([g123, g123])
|
46 | 46 | @test g123123.graph_indicator ==
|
47 |
| - [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] |
| 47 | + [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] |
48 | 48 | @test g123123.num_graphs == 6
|
49 | 49 | end
|
50 | 50 |
|
|
67 | 67 | c = 3
|
68 | 68 | ngraphs = 10
|
69 | 69 | gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n),
|
70 |
| - graph_type = GRAPH_T) |
71 |
| - for _ in 1:ngraphs] |
| 70 | + graph_type = GRAPH_T) |
| 71 | + for _ in 1:ngraphs] |
72 | 72 | gall = Flux.batch(gs)
|
73 | 73 | gs2 = Flux.unbatch(gall)
|
74 | 74 | @test gs2[1] == gs[1]
|
|
77 | 77 |
|
78 | 78 | @testset "getgraph" begin
|
79 | 79 | g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
|
80 |
| - graph_type = GRAPH_T) |
| 80 | + graph_type = GRAPH_T) |
81 | 81 | g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
|
82 | 82 | g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)
|
83 | 83 | g = Flux.batch([g1, g2, g3])
|
@@ -268,3 +268,14 @@ end end
|
268 | 268 | @test nv(DG) == g.num_nodes
|
269 | 269 | @test ne(DG) == g.num_edges
|
270 | 270 | end
|
| 271 | + |
| 272 | +@testset "random_walk_pe" begin |
| 273 | + s = [1, 2, 2, 3] |
| 274 | + t = [2, 1, 3, 2] |
| 275 | + ndata = [-1, 0, 1] |
| 276 | + g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata) |
| 277 | + output = random_walk_pe(g, 3) |
| 278 | + @test output == [0.0 0.0 0.0 |
| 279 | + 0.5 1.0 0.5 |
| 280 | + 0.0 0.0 0.0] |
| 281 | +end |
0 commit comments