Skip to content

Commit

Permalink
MutableArithmetics for IPM/HSD
Browse files Browse the repository at this point in the history
Helps with the performance of the BigFloat arithmetic. The change
shouldn't affect other arithmetics, but it's coded so it'd be easy to
extend it to another arithmetic like BigFloat, should one appear in
the ecosystem (for example a pure Julia big float type).

A script for measuring allocations; it uses my packages, which are hosted on
GitLab:
```
# Set the default BigFloat precision to 2^11 = 2048 bits for everything below.
setprecision(BigFloat, 2^11)

using
  PolynomialPassingThroughIntervals,
  PolynomialApproximation

# Error bound handed to the solver call below: half of the Float64 machine epsilon.
const max_error = 0.5 * eps(Float64)

# The even integers 0, 2, …, 16 as a `Vector{UInt8}`
# (presumably the monomial degrees used by the fit — confirm against the package).
const monomials_even = collect(UInt8, 0:2:16)

# Return the `n + 1` exact rational points `offset + 45k/n` for `k = 0, …, n`,
# i.e. an evenly spaced grid covering `[offset, offset + 45]`.
function degree_trig_domain(offset::Int, n::Int)
  points = Rational{BigInt}[]
  for k in zero(BigInt):n
    push!(points, offset + (k // n) * 45)
  end
  return points
end

# Build the grid `degree_trig_domain(offset, n)` and pass it, together with the
# rational `1//8`, to `perturbed`.
# NOTE(review): `perturbed` is not defined in this script; presumably it comes from
# one of the packages loaded above and `1//8` bounds the perturbation — confirm.
#
# `offset` is now forwarded to `degree_trig_domain`; previously it was silently
# ignored and a hard-coded `0` was used instead.
perturbed_dtd(offset::Int, n::Int) =
  perturbed(degree_trig_domain(offset, n), 1//8)

# Same as `perturbed_dtd`, but the number of subdivisions is given as a power of
# two: `exp` selects `2^exp` subdivisions.
function perturbed_dtd_pow2(offset::Int, exp::Int)
  subdivisions = 2^exp
  return perturbed_dtd(offset, subdivisions)
end

# Benchmark driver: time one call of
# `minimal_polynomial_passing_through_intervals_easy` on a perturbed grid with
# `2^exp` subdivisions and print detailed timing/allocation statistics.
#
# Returns whatever the timed call returns (`@timev` passes the value through).
function min_pol_cosd(exp::Int)
  dom = perturbed_dtd_pow2(0, exp)
  # One (0, 1) BigFloat pair per domain point.
  image = [(big"0.0", big"1.0") for _ in 1:length(dom)]
  # cosd evaluated in BigFloat arithmetic at every domain point.
  values = map(cosd ∘ BigFloat, dom)
  # `@timev` (lowercase) is the Base macro; the string labels the report.
  @timev "exp = $exp" minimal_polynomial_passing_through_intervals_easy(
    monomials_even, dom, image, values, max_error)
end

# Run the benchmark for 2^6 and then 2^9 subdivisions, printing an empty line
# between the two reports.
min_pol_cosd(6)
println()
min_pol_cosd(9)
```

On master (a0032b5):
```
exp = 6: 28.425369 seconds (177.57 M allocations: 27.151 GiB, 13.17% gc time, 22.09% compilation time)
elapsed time (ns):  28425368965
gc time (ns):       3744847370
bytes allocated:    29152935010
pool allocs:        177349665
non-pool GC allocs: 9273
malloc() calls:     117773
realloc() calls:    88512
free() calls:       116745
minor collections:  634
full collections:   0

exp = 9: 309.997183 seconds (1.39 G allocations: 217.327 GiB, 32.12% gc time)
elapsed time (ns):  309997182893
gc time (ns):       99577583042
bytes allocated:    233353489632
pool allocs:        1385706165
non-pool GC allocs: 59648
malloc() calls:     729278
realloc() calls:    553593
free() calls:       734173
minor collections:  5232
full collections:   3
```

After this commit:
```
exp = 6: 28.219849 seconds (159.69 M allocations: 24.350 GiB, 12.86% gc time, 22.40% compilation time)
elapsed time (ns):  28219848824
gc time (ns):       3628115620
bytes allocated:    26145903664
pool allocs:        159476551
non-pool GC allocs: 9290
malloc() calls:     118019
realloc() calls:    88752
free() calls:       116991
minor collections:  568
full collections:   0

exp = 9: 311.400227 seconds (1.26 G allocations: 197.588 GiB, 32.10% gc time)
elapsed time (ns):  311400226728
gc time (ns):       99971520894
bytes allocated:    212158223648
pool allocs:        1259460435
non-pool GC allocs: 60794
malloc() calls:     728027
realloc() calls:    553972
free() calls:       724639
minor collections:  4736
full collections:   3
```

There's no performance improvement for a single call above, but there's
less allocation, which should translate to better real-world
performance, I guess.
  • Loading branch information
nsajko committed Jan 10, 2023
1 parent a0032b5 commit 4273f5c
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 33 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QPSReader = "10f199a5-22af-520b-b891-7ce84a7b1bd0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
140 changes: 127 additions & 13 deletions src/IPM/HSD/HSD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,99 @@ mutable struct HSD{T, Tv, Tb, Ta, Tk} <: AbstractIPMOptimizer{T}

end

"""
    buffer_for_dot_product(::Type{V}) where {V <: AbstractVector{<:Real}}

Return the scratch buffer that `buffer_for` provides for a `LinearAlgebra.dot`
call on two vectors of type `V`.
"""
function buffer_for_dot_product(::Type{V}) where {V <: AbstractVector{<:Real}}
    return buffer_for(LinearAlgebra.dot, V, V)
end

"""
    buffer_for_dot_product(::Type{F}) where {F <: Real}

Convenience method taking the scalar type: return the dot-product buffer for
`Vector{F}` operands.
"""
function buffer_for_dot_product(::Type{F}) where {F <: Real}
    vector_type = Vector{F}
    return buffer_for_dot_product(vector_type)
end

"""
    buffered_dot_product_to!(buf, result, x, y)

Forward to `buffered_operate_to!(buf, result, LinearAlgebra.dot, x, y)` for
`BigFloat` vectors and return its result.

Presumably this stores `dot(x, y)` in `result`, using `buf` as scratch space —
confirm against the MutableArithmetics documentation.
"""
function buffered_dot_product_to!(
    buf::B, result::F, x::V, y::V
) where {B, F <: BigFloat, V <: AbstractVector{F}}
    return buffered_operate_to!(buf, result, LinearAlgebra.dot, x, y)
end

"""
    buffered_dot_product!!(buf, x, y)

Compute the dot product of the `BigFloat` vectors `x` and `y` with the help of the
scratch buffer `buf`. A fresh `zero(F)` is created to receive the result, so the
caller's data is not overwritten by this method itself.
"""
function buffered_dot_product!!(
    buf::B, x::V, y::V
) where {B, F <: BigFloat, V <: AbstractVector{F}}
    return buffered_dot_product_to!(buf, zero(F), x, y)
end

"""
    buffered_dot_product!!(::Nothing, x, y)

Fallback used when no buffer is available (the buffer argument is `nothing`):
simply return `dot(x, y)`.
"""
function buffered_dot_product!!(
    ::Nothing, x::V, y::V
) where {F <: Real, V <: AbstractVector{F}}
    return dot(x, y)
end

"""
    DotWeightedSumBuffer{F, DotBuffer}

Scratch storage for computing a weighted sum of dot products.

# Fields
- `tmp::F`: scalar holding the value of one dot product at a time.
- `dot::DotBuffer`: buffer returned by `buffer_for_dot_product(F)`.

Construct with `DotWeightedSumBuffer{F}()`; the second type parameter is derived
from the dot-product buffer.
"""
struct DotWeightedSumBuffer{F <: Real, DotBuffer}
    tmp::F
    dot::DotBuffer

    function DotWeightedSumBuffer{F}() where {F <: Real}
        scratch = buffer_for_dot_product(F)
        return new{F, typeof(scratch)}(zero(F), scratch)
    end
end

"""
    DotWeightedSumBufferDummy()

Placeholder buffer that carries no scratch storage. Its only field, `dot`, is
always `nothing`.
"""
struct DotWeightedSumBufferDummy
    dot::Nothing

    function DotWeightedSumBufferDummy()
        return new(nothing)
    end
end

"""
    buffer_for_dot_weighted_sum(::Type{F}) where {F <: BigFloat}

Return a new `DotWeightedSumBuffer{F}` for `BigFloat` arithmetic.
"""
function buffer_for_dot_weighted_sum(::Type{F}) where {F <: BigFloat}
    return DotWeightedSumBuffer{F}()
end

"""
    buffer_for_dot_weighted_sum(::Type{F}) where {F <: Real}

Generic fallback for real scalar types: return a `DotWeightedSumBufferDummy`,
which holds no scratch storage.
"""
function buffer_for_dot_weighted_sum(::Type{F}) where {F <: Real}
    return DotWeightedSumBufferDummy()
end

"""
    buffered_dot_weighted_sum_to_inner!(buf, sum, vecs, weights)

Accumulate the weighted dot products of the vector pairs in `vecs` (pair `i`
weighted by `weights[i]`) into `sum` and return the accumulated value.

`buf.tmp` receives each individual dot product and is then scaled by the weight;
`buf.dot` is the scratch buffer for the dot product itself.
"""
function buffered_dot_weighted_sum_to_inner!(
    buf::DotWeightedSumBuffer{F},
    sum::F,
    vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
    weights::NTuple{n, <:Real},
) where {n, F <: BigFloat}
    acc = zero!!(sum)

    for (pair, weight) in zip(vecs, weights)
        # Dot product of the current pair goes into the scratch scalar.
        buffered_dot_product_to!(buf.dot, buf.tmp, first(pair), last(pair))

        # The return value is not used, so this relies on `buf.tmp` being
        # updated in place before it is added below.
        mul!!(buf.tmp, weight)

        acc = add!!(acc, buf.tmp)
    end

    return acc
end

"""
    buffered_dot_weighted_sum_to!(buf, sum, vecs, weights)

Compute the weighted sum of the dot products of the vector pairs in `vecs`, with
integer `weights`, accumulating into `sum`. Returns the accumulated value.

Weights that fit into `Int8` are narrowed to `Int8` before the accumulation.
Weights outside that range are passed on unchanged, instead of raising an
`InexactError` from the `Int8` conversion.
"""
function buffered_dot_weighted_sum_to!(
    buf::DotWeightedSumBuffer{F},
    sum::F,
    vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
    weights::NTuple{n, Int},
) where {n, F <: BigFloat}
    # It seems like the specialization
    #     *(x::BigFloat, c::Int8)
    # could be more efficient than
    #     *(x::BigFloat, c::Int)
    # MPFR has separate functions for those, and Julia uses them,
    # there must be a good (performance) reason for that.
    fits_int8 = all(w -> typemin(Int8) <= w <= typemax(Int8), weights)

    if fits_int8
        return buffered_dot_weighted_sum_to_inner!(buf, sum, vecs, map(Int8, weights))
    end

    # Large weights cannot be narrowed losslessly; use them as they are.
    return buffered_dot_weighted_sum_to_inner!(buf, sum, vecs, weights)
end

"""
    buffered_dot_weighted_sum!!(buf::DotWeightedSumBuffer, vecs, weights)

Return the weighted sum of the dot products of the vector pairs in `vecs`, using
`buf` as scratch storage. The result is accumulated into a freshly created
`zero(F)`.
"""
function buffered_dot_weighted_sum!!(
    buf::DotWeightedSumBuffer{F},
    vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
    weights::NTuple{n, Int},
) where {n, F <: BigFloat}
    return buffered_dot_weighted_sum_to!(buf, zero(F), vecs, weights)
end

"""
    buffered_dot_weighted_sum!!(::DotWeightedSumBufferDummy, vecs, weights)

Unbuffered fallback: return `zero(F)` plus the sum of `weights[i] * dot(vecs[i]...)`,
added from left to right. The dummy buffer is not used.
"""
function buffered_dot_weighted_sum!!(
    buf::DotWeightedSumBufferDummy,
    vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}},
    weights::NTuple{n, Int},
) where {n, F <: Real}
    total = zero(F)
    for (pair, weight) in zip(vecs, weights)
        total = total + weight * dot(pair...)
    end
    return total
end

include("step.jl")


Expand Down Expand Up @@ -101,13 +194,20 @@ function compute_residuals!(hsd::HSD{T}
mul!(res.rd, transpose(dat.A), pt.y, -one(T), one(T))
@. res.rd += pt.zu .* dat.uflag - pt.zl .* dat.lflag

dot_buf = buffer_for_dot_weighted_sum(T)

# Gap residual
# rg = c'x - (b'y + l'zl - u'zu) + k
res.rg = pt.κ + (dot(dat.c, pt.x) - (
dot(dat.b, pt.y)
+ dot(dat.l .* dat.lflag, pt.zl)
- dot(dat.u .* dat.uflag, pt.zu)
))
res.rg = pt.κ + (buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.c, pt.x),
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu)),
(
1, -1, -1, 1))
)

# Residuals norm
res.rp_nrm = norm(res.rp, Inf)
Expand All @@ -117,12 +217,15 @@ function compute_residuals!(hsd::HSD{T}
res.rg_nrm = norm(res.rg, Inf)

# Compute primal and dual bounds
hsd.primal_objective = dot(dat.c, pt.x) / pt.τ + dat.c0
hsd.dual_objective = (
dot(dat.b, pt.y)
+ dot(dat.l .* dat.lflag, pt.zl)
- dot(dat.u .* dat.uflag, pt.zu)
) / pt.τ + dat.c0
hsd.primal_objective = buffered_dot_product!!(dot_buf.dot, dat.c, pt.x) / pt.τ + dat.c0
hsd.dual_objective = buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu)),
(
1, 1, -1)) / pt.τ + dat.c0

return nothing
end
Expand Down Expand Up @@ -168,12 +271,15 @@ function update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher
return nothing
end

dot_buf = buffer_for_dot_weighted_sum(T)

# Check for infeasibility certificates
if max(
norm(dat.A * pt.x, Inf),
norm((pt.x .- pt.xl) .* dat.lflag, Inf),
norm((pt.x .+ pt.xu) .* dat.uflag, Inf)
) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) < - ϵi * dot(dat.c, pt.x)
) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) <
-ϵi * buffered_dot_product!!(dot_buf.dot, dat.c, pt.x)
# Dual infeasible, i.e., primal unbounded
hsd.primal_status = Sln_InfeasibilityCertificate
hsd.solver_status = Trm_DualInfeasible
Expand All @@ -185,7 +291,15 @@ function update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher
norm(dat.l .* dat.lflag, Inf),
norm(dat.u .* dat.uflag, Inf),
norm(dat.b, Inf)
) / (max(one(T), norm(dat.c, Inf))) < (dot(dat.b, pt.y) + dot(dat.l .* dat.lflag, pt.zl)- dot(dat.u .* dat.uflag, pt.zu)) * ϵi
) / (max(one(T), norm(dat.c, Inf))) < buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.b, pt.y),
(dat.l .* dat.lflag, pt.zl),
(dat.u .* dat.uflag, pt.zu)),
(
1, 1, -1)) * ϵi

# Primal infeasible
hsd.dual_status = Sln_InfeasibilityCertificate
hsd.solver_status = Trm_PrimalInfeasible
Expand Down
56 changes: 36 additions & 20 deletions src/IPM/HSD/step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,22 @@ function compute_step!(hsd::HSD{T, Tv}, params::IPMOptions{T}) where{T, Tv<:Abst
ξ_ = @. (dat.c - ((pt.zl / pt.xl) * dat.l) * dat.lflag - ((pt.zu / pt.xu) * dat.u) * dat.uflag)
KKT.solve!(hx, hy, hsd.kkt, dat.b, ξ_)

dot_buf = buffer_for_dot_weighted_sum(T)

# Recover h0 = ρg + κ / τ - c'hx + b'hy - u'hz
# Some of the summands may take large values,
# so care must be taken for numerical stability
h0 = (
dot(dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag)
+ dot(dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag)
- dot((@. (c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx)
+ dot(b, hy)
+ pt.κ / pt.τ
+ hsd.regG
)
h0 = buffered_dot_weighted_sum!!(
dot_buf,
(
(dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag),
(dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag),
((@. (c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx),
(b, hy)),
(
1, 1, -1, 1)) +
pt.κ / pt.τ +
hsd.regG

# Affine-scaling direction
@timeit hsd.timer "Newton" solve_newton_system!(Δ, hsd, hx, hy, h0,
Expand Down Expand Up @@ -211,22 +216,33 @@ function solve_newton_system!(Δ::Point{T, Tv},
end
@timeit hsd.timer "KKT" KKT.solve!(Δ.x, Δ.y, hsd.kkt, ξp, ξd_)

dot_buf = buffer_for_dot_weighted_sum(T)

# II. Recover Δτ, Δx, Δy
# Compute Δτ
@timeit hsd.timer "ξg_" ξg_ = (ξg + ξtk / pt.τ
- dot((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag) # l'(Xl)^-1 * ξxzl
+ dot((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag)
- dot(((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag)
- dot(((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag) #
)
@timeit hsd.timer "ξg_" ξg_ = ξg + ξtk / pt.τ +
buffered_dot_weighted_sum!!(
dot_buf,
(
((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag), # l'(Xl)^-1 * ξxzl
((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag),
(((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag),
(((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag)),
(
-1, 1, -1, -1))

@timeit hsd.timer "Δτ" Δ.τ = (
ξg_
+ dot((@. (dat.c
+ ((pt.zl / pt.xl) * dat.l) * dat.lflag
+ ((pt.zu / pt.xu) * dat.u) * dat.uflag))
, Δ.x)
- dot(dat.b, Δ.y)
ξg_ +
buffered_dot_weighted_sum!!(
dot_buf,
(
((@. (dat.c
+ ((pt.zl / pt.xl) * dat.l) * dat.lflag
+ ((pt.zu / pt.xu) * dat.u) * dat.uflag))
, Δ.x),
(dat.b, Δ.y)),
(
1, -1))
) / h0


Expand Down
1 change: 1 addition & 0 deletions src/Tulip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Tulip

using LinearAlgebra
using Logging
using MutableArithmetics
using Printf
using SparseArrays
using TOML
Expand Down

0 comments on commit 4273f5c

Please sign in to comment.