Skip to content

Commit

Permalink
Add fallback for MOI.AbstractSymmetricMatrixSet{Triangle,Square} (#3424)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jun 27, 2023
1 parent 0b3ee89 commit 0f123b8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 51 deletions.
7 changes: 2 additions & 5 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ end

function build_constraint(
_error::Function,
x::Matrix,
x::AbstractMatrix,
set::MOI.AbstractVectorSet,
)
return _error(
Expand All @@ -793,10 +793,7 @@ end
function build_constraint(
_error::Function,
::Matrix,
T::Union{
MOI.PositiveSemidefiniteConeSquare,
MOI.PositiveSemidefiniteConeTriangle,
},
T::MOI.PositiveSemidefiniteConeTriangle,
)
return _error("instead of `$(T)`, use `JuMP.PSDCone()`.")
end
Expand Down
100 changes: 63 additions & 37 deletions src/sd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,6 @@ triangular part of the matrix is constrained to belong to the
"""
struct PSDCone end

function build_constraint(
_error::Function,
f::AbstractMatrix{<:AbstractJuMPScalar},
::Nonnegatives,
extra::PSDCone,
)
return build_constraint(_error, f, extra)
end

function build_constraint(
_error::Function,
f::AbstractMatrix{<:AbstractJuMPScalar},
::Nonpositives,
extra::PSDCone,
)
new_f = _MA.operate!!(*, -1, f)
return build_constraint(_error, new_f, extra)
end

"""
SymmetricMatrixShape
Expand All @@ -167,6 +148,7 @@ lower-left triangular part given row by row).
struct SymmetricMatrixShape <: AbstractShape
side_dimension::Int
end

function reshape_vector(
vectorized_form::Vector{T},
shape::SymmetricMatrixShape,
Expand All @@ -181,12 +163,14 @@ function reshape_vector(
end
return LinearAlgebra.Symmetric(matrix)
end

function reshape_set(
::MOI.PositiveSemidefiniteConeTriangle,
::SymmetricMatrixShape,
)
return PSDCone()
end

function vectorize(matrix, ::SymmetricMatrixShape)
n = LinearAlgebra.checksquare(matrix)
return [matrix[i, j] for j in 1:n for i in 1:j]
Expand Down Expand Up @@ -414,12 +398,7 @@ function build_variable(
)
n = _square_side(_error, variables)
set = MOI.PositiveSemidefiniteConeTriangle(n)
shape = SymmetricMatrixShape(n)
return VariablesConstrainedOnCreation(
_vectorize_variables(_error, variables),
set,
shape,
)
return build_variable(_error, variables, set)
end

function value(
Expand Down Expand Up @@ -476,12 +455,7 @@ function build_constraint(
::PSDCone,
) where {V<:AbstractJuMPScalar,M<:AbstractMatrix{V}}
n = LinearAlgebra.checksquare(Q)
shape = SymmetricMatrixShape(n)
return VectorConstraint(
vectorize(Q, shape),
MOI.PositiveSemidefiniteConeTriangle(n),
shape,
)
return build_constraint(_error, Q, MOI.PositiveSemidefiniteConeTriangle(n))
end

"""
Expand Down Expand Up @@ -511,12 +485,7 @@ function build_constraint(
::PSDCone,
)
n = LinearAlgebra.checksquare(Q)
shape = SquareMatrixShape(n)
return VectorConstraint(
vectorize(Q, shape),
MOI.PositiveSemidefiniteConeSquare(n),
shape,
)
return build_constraint(_error, Q, MOI.PositiveSemidefiniteConeSquare(n))
end

"""
Expand Down Expand Up @@ -742,3 +711,60 @@ function build_constraint(_error::Function, ::AbstractMatrix, ::Zeros)
"`LinearAlgebra.Symmetric` or `LinearAlgebra.Hermitian`.",
)
end

function build_constraint(
_error::Function,
Q::LinearAlgebra.Symmetric{V,M},
set::MOI.AbstractSymmetricMatrixSetTriangle,
) where {V<:AbstractJuMPScalar,M<:AbstractMatrix{V}}
n = LinearAlgebra.checksquare(Q)
shape = SymmetricMatrixShape(n)
return VectorConstraint(vectorize(Q, shape), set, shape)
end

function build_constraint(
_error::Function,
Q::AbstractMatrix{<:AbstractJuMPScalar},
set::MOI.AbstractSymmetricMatrixSetSquare,
)
n = LinearAlgebra.checksquare(Q)
shape = SquareMatrixShape(n)
return VectorConstraint(vectorize(Q, shape), set, shape)
end

function build_constraint(
_error::Function,
f::AbstractMatrix{<:AbstractJuMPScalar},
::Nonnegatives,
extra::Union{
MOI.AbstractSymmetricMatrixSetTriangle,
MOI.AbstractSymmetricMatrixSetSquare,
PSDCone,
},
)
return build_constraint(_error, f, extra)
end

function build_constraint(
_error::Function,
f::AbstractMatrix{<:AbstractJuMPScalar},
::Nonpositives,
extra::Union{
MOI.AbstractSymmetricMatrixSetTriangle,
MOI.AbstractSymmetricMatrixSetSquare,
PSDCone,
},
)
new_f = _MA.operate!!(*, -1, f)
return build_constraint(_error, new_f, extra)
end

function build_variable(
_error::Function,
variables::Matrix{<:AbstractVariable},
set::MOI.AbstractSymmetricMatrixSetTriangle,
)
n = _square_side(_error, variables)
x = _vectorize_variables(_error, variables)
return VariablesConstrainedOnCreation(x, set, SymmetricMatrixShape(n))
end
9 changes: 0 additions & 9 deletions test/test_constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,15 +652,6 @@ function test_extension_PSD_constraint_errors(
)
model = ModelType()
@variable(model, X[1:2, 1:2])
err = ErrorException(
"In `@constraint(model, X in MOI.PositiveSemidefiniteConeSquare(2))`:" *
" instead of `MathOptInterface.PositiveSemidefiniteConeSquare(2)`," *
" use `JuMP.PSDCone()`.",
)
@test_throws_strip(
err,
@constraint(model, X in MOI.PositiveSemidefiniteConeSquare(2))
)
err = ErrorException(
"In `@constraint(model, X in MOI.PositiveSemidefiniteConeTriangle(2))`:" *
" instead of `MathOptInterface.PositiveSemidefiniteConeTriangle(2)`," *
Expand Down

0 comments on commit 0f123b8

Please sign in to comment.