@@ -164,7 +164,7 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
164
164
update_centers! (X, weights, assignments, to_update, centers, wcounts)
165
165
166
166
if ! isempty (unused)
167
- repick_unused_centers (X, costs, centers, unused , distance, rng)
167
+ repick_unused_centers! (centers, unused, X, costs , distance, rng)
168
168
to_update[unused] .= true
169
169
end
170
170
@@ -211,18 +211,16 @@ function _kmeans!(X::AbstractMatrix{<:Real}, # in: data matrix (d
211
211
wcounts, objv, t, converged)
212
212
end
213
213
214
- #
215
- # Updates assignments, costs, and counts based on
216
- # an updated (squared) distance matrix
217
- #
214
+ # Update point assignments, costs, and cluster counts based on
215
+ # an updated (squared) distance matrix
218
216
function update_assignments! (dmat:: Matrix{<:Real} , # in: distance matrix (k x n)
219
217
is_init:: Bool , # in: whether it is the initial run
220
218
assignments:: Vector{Int} , # out: assignment vector (n)
221
219
costs:: Vector{<:Real} , # out: costs of the resultant assignment (n)
222
220
counts:: Vector{Int} , # out: # of points assigned to each cluster (k)
223
221
to_update:: Vector{Bool} , # out: whether a center needs update (k)
224
- unused:: Vector{Int} # out: list of centers with no points assigned
225
- )
222
+ unused:: Vector{Int} , # out: list of centers with no points assigned
223
+ )
226
224
k, n = size (dmat)
227
225
228
226
# re-initialize the counting vector
@@ -272,17 +270,15 @@ function update_assignments!(dmat::Matrix{<:Real}, # in: distance matrix (k
272
270
end
273
271
end
274
272
275
- #
276
- # Update centers based on updated assignments
277
- #
278
- # (specific to the case where points are not weighted)
279
- #
273
+ # Update cluster centers and weights to match updated assignments
274
+ # (non-weighted points case)
280
275
function update_centers! (X:: AbstractMatrix{<:Real} , # in: data matrix (d x n)
281
276
weights:: Nothing , # in: point weights
282
277
assignments:: Vector{Int} , # in: assignments (n)
283
278
to_update:: Vector{Bool} , # in: whether a center needs update (k)
284
279
centers:: AbstractMatrix{<:AbstractFloat} , # out: updated centers (d x k)
285
- wcounts:: Vector{Int} ) # out: updated cluster weights (k)
280
+ wcounts:: Vector{Int} , # out: updated cluster weights (k)
281
+ )
286
282
d, n = size (X)
287
283
k = size (centers, 2 )
288
284
@@ -318,18 +314,15 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d
318
314
end
319
315
end
320
316
321
- #
322
- # Update centers based on updated assignments
323
- #
324
- # (specific to the case where points are weighted)
325
- #
317
+ # Update cluster centers and weights to match updated assignments
318
+ # (weighted points case)
326
319
function update_centers! (X:: AbstractMatrix{<:Real} , # in: data matrix (d x n)
327
320
weights:: Vector{W} , # in: point weights (n)
328
321
assignments:: Vector{Int} , # in: assignments (n)
329
322
to_update:: Vector{Bool} , # in: whether a center needs update (k)
330
323
centers:: AbstractMatrix{<:Real} , # out: updated centers (d x k)
331
- wcounts:: Vector{W} # out: updated cluster weights (k)
332
- ) where W<: Real
324
+ wcounts:: Vector{W} , # out: updated cluster weights (k)
325
+ ) where W<: Real
333
326
d, n = size (X)
334
327
k = size (centers, 2 )
335
328
@@ -368,26 +361,25 @@ function update_centers!(X::AbstractMatrix{<:Real}, # in: data matrix (d x n)
368
361
end
369
362
370
363
371
# Re-pick centers that have no points assigned to them.
#
# For every index in `unused`, a data point (column of `X`) is drawn at random
# using the temporary costs as sampling weights (a scheme like kmeans++), and
# the corresponding column of `centers` is overwritten with that point.
# `centers` is modified in place; `costs` itself is left untouched because the
# sampling weights are kept in a private copy.
function repick_unused_centers!(centers::AbstractMatrix{<:AbstractFloat}, # out: the centers (d x k)
                                unused::Vector{Int},                      # in: indices of the centers to be re-picked
                                X::AbstractMatrix{<:Real},                # in: the data matrix (d x n)
                                costs::Vector{<:Real},                    # in: the current assignment costs (n)
                                distance::SemiMetric,                     # in: function to calculate the distance with
                                rng::AbstractRNG,                         # in: RNG object
                                )
    # pick new centers using a scheme like kmeans++
    ds = similar(costs)   # buffer for the distances computed by `colwise!`
    tcosts = copy(costs)  # temporary costs used as sampling weights
    n = size(X, 2)

    for i in unused
        # select a random point as a new center
        j = wsample(rng, 1:n, tcosts)
        v = view(X, :, j)
        centers[:, i] = v

        # distances between the new center and the data points
        colwise!(distance, ds, v, X)
        # The computed distance of the point to itself might not be exactly
        # zero; force it so that (assuming non-negative costs) the weight of
        # point `j` drops to zero in the update below.
        ds[j] = 0

        # Elementwise minimum written into the existing buffer. A plain
        # `min(tcosts, ds)` would compare the two vectors as a whole rather
        # than entry by entry.
        tcosts .= min.(tcosts, ds)
    end
end
0 commit comments