Skip to content

Commit eec3a46

Browse files
committed
Fix batch case and reorder
1 parent e10c4a9 commit eec3a46

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

src/utils.jl

+19-22
Original file line numberDiff line numberDiff line change
@@ -105,43 +105,34 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
105105
return matrix[:, index]
106106
end
107107

108-
function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
108+
function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
109109
if sortby === nothing
110110
return sort(matrix, dims = 2; rev)[:, 1:k]
111111
else
112112
return _sort_col(matrix; rev, sortby)[:, 1:k]
113113
end
114114
end
115115

116-
function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing)
117-
return map(x -> _sort_matrix(x, k; rev, sortby), matrices)
118-
end
119-
120-
function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
116+
function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true,
121117
sortby = nothing)
122-
tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs,
123-
number_graphs)
124-
sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby)
118+
sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby), matrices)
125119
return reduce(hcat, sorted_matrix)
126120
end
127121

128-
function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
129-
sortby = nothing)
130-
if number_graphs == 1
131-
return _sort_matrix(matrix, k; rev, sortby)
132-
else
133-
return _topk_batch(matrix, number_graphs, k; rev, sortby)
134-
end
135-
end
136-
137122
"""
138123
topk_nodes(g, feat, k; rev = true, sortby = nothing)
139124
140125
Graph-wise top-k on node features `feat` according to the `sortby` feature index.
141126
"""
142127
function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
143-
matrix = getproperty(g.ndata, feat)
144-
return _topk(matrix, g.num_graphs, k; rev, sortby)
128+
if g.num_graphs == 1
129+
matrix = getproperty(g.ndata, feat)
130+
return _topk_matrix(matrix, k; rev, sortby)
131+
else
132+
graphs = [getgraph(g, i) for i in 1:(g.num_graphs)]
133+
matrices = map(graph -> getproperty(graph.ndata, feat), graphs)
134+
return _topk_batch(matrices, k; rev, sortby)
135+
end
145136
end
146137

147138
"""
@@ -150,6 +141,12 @@ end
150141
Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
151142
"""
152143
function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
153-
matrix = getproperty(g.edata, feat)
154-
return _topk(matrix, g.num_graphs, k; rev, sortby)
144+
if g.num_graphs == 1
145+
matrix = getproperty(g.edata, feat)
146+
return _topk_matrix(matrix, k; rev, sortby)
147+
else
148+
graphs = [getgraph(g, i) for i in 1:(g.num_graphs)]
149+
matrices = map(graph -> getproperty(graph.edata, feat), graphs)
150+
return _topk_batch(matrices, k; rev, sortby)
151+
end
155152
end

0 commit comments

Comments
 (0)