Skip to content

Commit a3768b6

Browse files
authored
Merge pull request #283 from AzamatB/ab-kmeans-broadcasting-bugfix
Fix broadcasting bug in `repick_unused_centers()` for K-means
2 parents 24a30ae + 5f280eb commit a3768b6

File tree

1 file changed

+26
-34
lines changed

1 file changed

+26
-34
lines changed

src/kmeans.jl

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
164164
update_centers!(X, weights, assignments, to_update, centers, wcounts)
165165

166166
if !isempty(unused)
167-
repick_unused_centers(X, costs, centers, unused, distance, rng)
167+
repick_unused_centers!(centers, unused, X, costs, distance, rng)
168168
to_update[unused] .= true
169169
end
170170

@@ -211,18 +211,16 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
211211
wcounts, objv, t, converged)
212212
end
213213

214-
#
215-
# Updates assignments, costs, and counts based on
216-
# an updated (squared) distance matrix
217-
#
214+
# Update point assignments, costs, and cluster counts based on
215+
# an updated (squared) distance matrix
218216
function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k x n)
219217
is_init::Bool, # in: whether it is the initial run
220218
assignments::Vector{Int}, # out: assignment vector (n)
221219
costs::Vector{<:Real}, # out: costs of the resultant assignment (n)
222220
counts::Vector{Int}, # out: # of points assigned to each cluster (k)
223221
to_update::Vector{Bool}, # out: whether a center needs update (k)
224-
unused::Vector{Int} # out: list of centers with no points assigned
225-
)
222+
unused::Vector{Int}, # out: list of centers with no points assigned
223+
)
226224
k, n = size(dmat)
227225

228226
# re-initialize the counting vector
@@ -272,17 +270,15 @@ function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k
272270
end
273271
end
274272

275-
#
276-
# Update centers based on updated assignments
277-
#
278-
# (specific to the case where points are not weighted)
279-
#
273+
# Update cluster centers and weights to match updated assignments
274+
# (non-weighted points case)
280275
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
281276
weights::Nothing, # in: point weights
282277
assignments::Vector{Int}, # in: assignments (n)
283278
to_update::Vector{Bool}, # in: whether a center needs update (k)
284279
centers::AbstractMatrix{<:AbstractFloat}, # out: updated centers (d x k)
285-
wcounts::Vector{Int}) # out: updated cluster weights (k)
280+
wcounts::Vector{Int}, # out: updated cluster weights (k)
281+
)
286282
d, n = size(X)
287283
k = size(centers, 2)
288284

@@ -318,18 +314,15 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d
318314
end
319315
end
320316

321-
#
322-
# Update centers based on updated assignments
323-
#
324-
# (specific to the case where points are weighted)
325-
#
317+
# Update cluster centers and weights to match updated assignments
318+
# (weighted points case)
326319
function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
327320
weights::Vector{W}, # in: point weights (n)
328321
assignments::Vector{Int}, # in: assignments (n)
329322
to_update::Vector{Bool}, # in: whether a center needs update (k)
330323
centers::AbstractMatrix{<:Real}, # out: updated centers (d x k)
331-
wcounts::Vector{W} # out: updated cluster weights (k)
332-
) where W<:Real
324+
wcounts::Vector{W}, # out: updated cluster weights (k)
325+
) where W<:Real
333326
d, n = size(X)
334327
k = size(centers, 2)
335328

@@ -368,26 +361,25 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
368361
end
369362

370363

371-
#
372-
# Re-picks centers that have no points assigned to them.
373-
#
374-
function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
375-
costs::Vector{<:Real}, # in: the current assignment costs (n)
376-
centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
377-
unused::Vector{Int}, # in: indices of centers to be updated
378-
distance::SemiMetric, # in: function to calculate the distance with
379-
rng::AbstractRNG) # in: RNG object
364+
# Re-pick centers that have no points assigned to them.
365+
function repick_unused_centers!(centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
366+
unused::Vector{Int}, # in: indices of centers to be updated (k)
367+
X::AbstractMatrix{<:Real}, # in: the data matrix (d x n)
368+
costs::Vector{<:Real}, # in: the current assignment costs (n)
369+
distance::SemiMetric, # in: function to calculate the distance with
370+
rng::AbstractRNG,
371+
)
380372
# pick new centers using a scheme like kmeans++
381373
ds = similar(costs)
382-
tcosts = copy(costs)
374+
tcosts = copy(costs) # temporary costs used as sampling weights
383375
n = size(X, 2)
384376

385377
for i in unused
378+
# select a random point as a new center
386379
j = wsample(rng, 1:n, tcosts)
387-
tcosts[j] = 0
388-
v = view(X, :, j)
389-
centers[:, i] = v
380+
centers[:, i] = v = view(X, :, j)
390381
colwise!(distance, ds, v, X)
391-
tcosts = min(tcosts, ds)
382+
ds[j] = 0 # calculated distance might be not exactly zero
383+
tcosts .= min.(tcosts, ds)
392384
end
393385
end

0 commit comments

Comments
 (0)