Skip to content

Commit

Permalink
Issue JuliaStats#64: added n_init to kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
lbollar authored and lollar committed Sep 30, 2016
1 parent ea03689 commit 775eefc
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 74 deletions.
165 changes: 94 additions & 71 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ const _kmeans_default_init = :kmpp
const _kmeans_default_maxiter = 100
const _kmeans_default_tol = 1.0e-6
const _kmeans_default_display = :none
const _kmeans_default_n_init = 10

function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T};
weights=nothing,
maxiter::Integer=_kmeans_default_maxiter,
tol::Real=_kmeans_default_tol,
display::Symbol=_kmeans_default_display)


m, n = size(X)
m2, k = size(centers)
Expand All @@ -43,18 +45,37 @@ function kmeans(X::Matrix, k::Int;
weights=nothing,
init=_kmeans_default_init,
maxiter::Integer=_kmeans_default_maxiter,
n_init::Integer=_kmeans_default_n_init,
tol::Real=_kmeans_default_tol,
display::Symbol=_kmeans_default_display)


m, n = size(X)
(2 <= k < n) || error("k must have 2 <= k < n.")
iseeds = initseeds(init, X, k)
centers = copyseeds(X, iseeds)
kmeans!(X, centers;
weights=weights,
maxiter=maxiter,
tol=tol,
display=display)
# Validate the restart count up front: the loop below must execute at least
# once for there to be any result to return.
n_init > 0 || error("n_init must be greater than 0")

# Run the clustering `n_init` times, obtaining seeds from `initseeds` on every
# pass, and keep the run whose `totalcost` is lowest.
lowestcost::Float64 = Inf
local bestresult::KmeansResult

for i = 1:n_init

    iseeds = initseeds(init, X, k)
    centers = copyseeds(X, iseeds)
    result = kmeans!(X, centers;
                     weights=weights,
                     maxiter=maxiter,
                     tol=tol,
                     display=display)

    # Always accept the first run so that `bestresult` is assigned even when
    # no cost compares strictly below `Inf` (a NaN or infinite total cost);
    # otherwise the `return` below would raise an UndefVarError. A NaN
    # incumbent is also replaced, because `x < NaN` is false for every `x`
    # and a later finite-cost run could never displace it.
    if i == 1 || isnan(lowestcost) || result.totalcost < lowestcost
        lowestcost = result.totalcost
        bestresult = result
    end

end

return bestresult

end

#### Core implementation
Expand All @@ -72,86 +93,88 @@ function _kmeans!{T<:AbstractFloat}(
tol::Real, # in: tolerance of change at convergence
displevel::Int) # in: the level of display

# initialize

k = size(centers, 2)
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
unused = Int[]
num_affected::Int = k # number of centers, to which the distances need to be recomputed

dmat = pairwise(SqEuclidean(), centers, x)
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
objv = w == nothing ? sum(costs) : dot(w, costs)

# main loop
t = 0
converged = false
if displevel >= 2
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
println("-------------------------------------------------------------")
@printf("%7d %18.6e\n", t, objv)
end

while !converged && t < maxiter
t = t + 1
# initialize

# update (affected) centers
k = size(centers, 2)
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
unused = Int[]
num_affected::Int = k # number of centers, to which the distances need to be recomputed

update_centers!(x, w, assignments, to_update, centers, cweights)
dmat = pairwise(SqEuclidean(), centers, x)
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
objv = w == nothing ? sum(costs) : dot(w, costs)

if !isempty(unused)
repick_unused_centers(x, costs, centers, unused)
end
# main loop
t = 0
converged = false
if displevel >= 2
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
println("-------------------------------------------------------------")
@printf("%7d %18.6e\n", t, objv)
end

# update pairwise distance matrix
while !converged && t < maxiter
t = t + 1

if !isempty(unused)
to_update[unused] = true
end
# update (affected) centers

if t == 1 || num_affected > 0.75 * k
pairwise!(dmat, SqEuclidean(), centers, x)
else
# if only a small subset is affected, only compute for that subset
affected_inds = find(to_update)
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
dmat[affected_inds, :] = dmat_p
end
update_centers!(x, w, assignments, to_update, centers, cweights)

# update assignments
if !isempty(unused)
repick_unused_centers(x, costs, centers, unused)
end

update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
num_affected = sum(to_update) + length(unused)
# update pairwise distance matrix

# compute change of objective and determine convergence
if !isempty(unused)
to_update[unused] = true
end

prev_objv = objv
objv = w == nothing ? sum(costs) : dot(w, costs)
objv_change = objv - prev_objv
if t == 1 || num_affected > 0.75 * k
pairwise!(dmat, SqEuclidean(), centers, x)
else
# if only a small subset is affected, only compute for that subset
affected_inds = find(to_update)
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
dmat[affected_inds, :] = dmat_p
end

if objv_change > tol
warn("The objective value changes towards an opposite direction")
end
# update assignments

if abs(objv_change) < tol
converged = true
end
update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
num_affected = sum(to_update) + length(unused)

# display iteration information (if asked)
# compute change of objective and determine convergence

if displevel >= 2
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
end
end
prev_objv = objv
objv = w == nothing ? sum(costs) : dot(w, costs)
objv_change = objv - prev_objv

if displevel >= 1
if converged
println("K-means converged with $t iterations (objv = $objv)")
else
println("K-means terminated without convergence after $t iterations (objv = $objv)")
end
end
if objv_change > tol
warn("The objective value changes towards an opposite direction")
end

if abs(objv_change) < tol
converged = true
end

# display iteration information (if asked)

if displevel >= 2
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
end
end

if displevel >= 1
if converged
println("K-means converged with $t iterations (objv = $objv)")
else
println("K-means terminated without convergence after $t iterations (objv = $objv)")
end
end

return KmeansResult(centers, assignments, costs, counts, cweights,
@compat(Float64(objv)), t, converged)
Expand Down
6 changes: 3 additions & 3 deletions test/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ k = 10
x = rand(m, n)

# non-weighted
r = kmeans(x, k; maxiter=50)
r = kmeans(x, k; maxiter=50, n_init=2)
@test isa(r, KmeansResult{Float64})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand All @@ -24,7 +24,7 @@ r = kmeans(x, k; maxiter=50)
@test_approx_eq sum(r.costs) r.totalcost

# non-weighted (float32)
r = kmeans(@compat(map(Float32, x)), k; maxiter=50)
r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2)
@test isa(r, KmeansResult{Float32})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand All @@ -37,7 +37,7 @@ r = kmeans(@compat(map(Float32, x)), k; maxiter=50)

# weighted
w = rand(n)
r = kmeans(x, k; maxiter=50, weights=w)
r = kmeans(x, k; maxiter=50, weights=w, n_init=2)
@test isa(r, KmeansResult{Float64})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand Down

0 comments on commit 775eefc

Please sign in to comment.