diff --git a/src/atoms/IndexAtom.jl b/src/atoms/IndexAtom.jl index 9196c7a93..49162e000 100644 --- a/src/atoms/IndexAtom.jl +++ b/src/atoms/IndexAtom.jl @@ -109,11 +109,17 @@ function Base.getindex(x::AbstractExpr, rows::AbstractVector{<:Real}, col::Real) return getindex(x, rows, col:col) end -function Base.getindex(x::AbstractExpr, I::AbstractMatrix{Bool}) +function Base.getindex( + x::AbstractExpr, + I::Union{AbstractMatrix{Bool},<:BitMatrix}, +) return [xi for (xi, ii) in zip(x, I) if ii] end -function Base.getindex(x::AbstractExpr, I::AbstractVector{Bool}) +function Base.getindex( + x::AbstractExpr, + I::Union{<:AbstractVector{Bool},<:BitVector}, +) return [xi for (xi, ii) in zip(x, I) if ii] end diff --git a/test/test_atoms.jl b/test/test_atoms.jl index 2fa320be3..6e38b06a0 100644 --- a/test/test_atoms.jl +++ b/test/test_atoms.jl @@ -517,9 +517,40 @@ function test_IndexAtom() _test_atom(target) do context return Variable(2, 2)[:, 2] end + # Base.getindex(x::AbstractExpr, I::AbstractVector{Bool}) y = [true, false, true] x = Variable(3) - @test string(x[y]) == string([x[1], x[3]]) + z = x[y] + @test string(z) == string([x[1], x[3]]) + @test z isa Vector{Convex.IndexAtom} + @test length(z) == 2 + Convex.set_value!(x, [1, 2, 3]) + @test Convex.evaluate.(z) == [1, 3] + # Base.getindex(x::AbstractExpr, I::AbstractMatrix{Bool}) + y = [true false; true true] + x = Variable(2, 2) + z = x[y] + @test z isa Vector{Convex.IndexAtom} + @test length(z) == 3 + Convex.set_value!(x, [1 3; 2 4]) + @test Convex.evaluate.(z) == [1, 2, 4] + # Base.getindex(x::AbstractExpr, I::BitVector) + y = BitVector([true, false, true]) + x = Variable(3) + z = x[y] + @test string(z) == string([x[1], x[3]]) + @test z isa Vector{Convex.IndexAtom} + @test length(z) == 2 + Convex.set_value!(x, [1, 2, 3]) + @test Convex.evaluate.(z) == [1, 3] + # Base.getindex(x::AbstractExpr, I::BitMatrix) + y = BitMatrix([true false; true true]) + x = Variable(2, 2) + z = x[y] + @test z isa Vector{Convex.IndexAtom} + @test length(z) == 3 + Convex.set_value!(x, [1 3; 2 4]) + @test Convex.evaluate.(z) == [1, 2, 4] return end