From 184da75d6d0424357e18eab059419bb4fc3add1b Mon Sep 17 00:00:00 2001 From: Greg Neustroev Date: Mon, 1 Jul 2024 14:52:40 +0200 Subject: [PATCH] Add support for Dirac weights --- src/weight_fitting.jl | 20 ++++++++++++++++++-- test/test-weight-fitting.jl | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/weight_fitting.jl b/src/weight_fitting.jl index 0df9b3a..6aa1879 100644 --- a/src/weight_fitting.jl +++ b/src/weight_fitting.jl @@ -62,13 +62,27 @@ end """ project_onto_nonnegative_orthant(vector) -Projects `vector` onto the nonnegative_orthant. This projection is trivial: +Projects `vector` onto the nonnegative orthant. This projection is trivial: replace negative components of the vector with zeros. """ function project_onto_nonnegative_orthant(vector::Vector{Float64}) return max.(vector, 0.0) end +""" + project_onto_standard_basis(vector) + +Projects `vector` onto the standard basis. This projection is trivial: +replace all components of the vector with zeros, except for the largest one, +which is replaced with one. +""" +function project_onto_standard_basis(vector::Vector{Float64}) + i = argmax(vector) + result = zeros(size(vector)) + result[i] = 1.0 + return result +end + """ projected_subgradient_descent!(x; gradient, projection, niters, rtol, learning_rate, adaptive_grad) @@ -166,7 +180,9 @@ function fit_rep_period_weights!( args..., ) # Determine the appropriate projection method - if weight_type == :convex + if weight_type == :dirac + projection = project_onto_standard_basis + elseif weight_type == :convex projection = project_onto_simplex elseif weight_type == :conical projection = project_onto_nonnegative_orthant diff --git a/test/test-weight-fitting.jl b/test/test-weight-fitting.jl index 17498cb..c6a6bc8 100644 --- a/test/test-weight-fitting.jl +++ b/test/test-weight-fitting.jl @@ -15,6 +15,23 @@ TulipaClustering.project_onto_simplex(x) ≈ [0.0, 1.0] end end + + @testset "Make sure that projection onto standard basis works correctly" begin + @test begin + x = [1.0, 10.0] + TulipaClustering.project_onto_standard_basis(x) == [0.0, 1.0] + end + + @test begin + x = [10.0, 1.0] + TulipaClustering.project_onto_standard_basis(x) == [1.0, 0.0] + end + + @test begin + x = [-2.0, 1.0] + TulipaClustering.project_onto_standard_basis(x) == [0.0, 1.0] + end + end end @testset "Subgradient descent" begin @@ -42,6 +59,23 @@ end return DBInterface.execute(con, "SELECT * FROM profiles") |> DataFrame end + @testset "Make sure that weight fitting works correctly for Dirac weights" begin + @test begin + clustering_data = get_data() + split_into_periods!(clustering_data; period_duration = 24 * 7) + clustering_result = find_representative_periods( + clustering_data, + 10; + drop_incomplete_last_period = false, + method = :k_means, + distance = SqEuclidean(), + init = :kmcen, + ) + TulipaClustering.fit_rep_period_weights!(clustering_result; weight_type = :dirac, niters = 5) + all(sum(clustering_result.weight_matrix[1:(end - 1), :], dims = 2) .== 1.0) + end + end + @testset "Make sure that weight fitting works correctly for convex weights" begin @test begin clustering_data = get_data()