@@ -105,43 +105,34 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
105
105
return matrix[:, index]
106
106
end
107
107
108
- function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
108
+ function _topk_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
109
109
if sortby === nothing
110
110
return sort (matrix, dims = 2 ; rev)[:, 1 : k]
111
111
else
112
112
return _sort_col (matrix; rev, sortby)[:, 1 : k]
113
113
end
114
114
end
115
115
116
- function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
117
- return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
118
- end
119
-
120
- function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
116
+ function _topk_batch (matrices:: AbstractArray , k:: Int ; rev:: Bool = true ,
121
117
sortby = nothing )
122
- tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
123
- number_graphs)
124
- sorted_matrix = _sort_batch (eachslice (tensor_matrix, dims = 3 ), k; rev, sortby)
118
+ sorted_matrix = map (x -> _topk_matrix (x, k; rev, sortby), matrices)
125
119
return reduce (hcat, sorted_matrix)
126
120
end
127
121
128
- function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
129
- sortby = nothing )
130
- if number_graphs == 1
131
- return _sort_matrix (matrix, k; rev, sortby)
132
- else
133
- return _topk_batch (matrix, number_graphs, k; rev, sortby)
134
- end
135
- end
136
-
137
122
"""
138
123
topk_nodes(g, feat, k; rev = true, sortby = nothing)
139
124
140
125
Graph-wise top-k on node features `feat` according to the `sortby` feature index.
141
126
"""
142
127
function topk_nodes (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
143
- matrix = getproperty (g. ndata, feat)
144
- return _topk (matrix, g. num_graphs, k; rev, sortby)
128
+ if g. num_graphs == 1
129
+ matrix = getproperty (g. ndata, feat)
130
+ return _topk_matrix (matrix, k; rev, sortby)
131
+ else
132
+ graphs = [getgraph (g, i) for i in 1 : (g. num_graphs)]
133
+ matrices = map (graph -> getproperty (graph. ndata, feat), graphs)
134
+ return _topk_batch (matrices, k; rev, sortby)
135
+ end
145
136
end
146
137
147
138
"""
150
141
Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
151
142
"""
152
143
function topk_edges (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
153
- matrix = getproperty (g. edata, feat)
154
- return _topk (matrix, g. num_graphs, k; rev, sortby)
144
+ if g. num_graphs == 1
145
+ matrix = getproperty (g. edata, feat)
146
+ return _topk_matrix (matrix, k; rev, sortby)
147
+ else
148
+ graphs = [getgraph (g, i) for i in 1 : (g. num_graphs)]
149
+ matrices = map (graph -> getproperty (graph. edata, feat), graphs)
150
+ return _topk_batch (matrices, k; rev, sortby)
151
+ end
155
152
end
0 commit comments