Skip to content

Commit

Permalink
Don't use ccalls where wrappers already exist (fmpz edition) (#1913)
Browse files Browse the repository at this point in the history
Also add some tests.

Co-authored-by: Max Horn <[email protected]>
  • Loading branch information
lgoettgens and fingolfin authored Oct 25, 2024
1 parent 2a54dbe commit 246a850
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 98 deletions.
20 changes: 14 additions & 6 deletions src/flint/fmpq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,31 +661,39 @@ end
#
################################################################################

remove(a::QQFieldElem, b::Integer) = remove(a, ZZRingElem(b))
function remove(a::QQFieldElem, b::IntegerUnion)
b <= 1 && error("Factor <= 1")
a == 0 && error("Not yet implemented")
remove!(deepcopy(a), ZZ(b))
end

valuation(a::QQFieldElem, b::Integer) = valuation(a, ZZRingElem(b))
function valuation(a::QQFieldElem, b::IntegerUnion)
b <= 1 && error("Factor <= 1")
a == 0 && error("Not yet implemented")
valuation!(deepcopy(a), ZZ(b))
end

function remove!(a::QQFieldElem, b::ZZRingElem)
nr = _num_ptr(a)
vn = ccall((:fmpz_remove, libflint), Clong, (Ptr{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), nr, nr, b)
vn, nr = remove!(nr, b)
#QQFieldElem's are simplified: either num OR den will be non-trivial
if !is_zero(vn)
return vn, a
end
nr = _den_ptr(a)
vn = ccall((:fmpz_remove, libflint), Clong, (Ptr{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), nr, nr, b)
vn, nr = remove!(nr, b)
return -vn, a
end

function valuation!(a::QQFieldElem, b::ZZRingElem)
nr = _num_ptr(a)
vn = ccall((:fmpz_remove, libflint), Clong, (Ptr{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), nr, nr, b)
vn, nr = remove!(nr, b)
#QQFieldElem's are simplified: either num OR den will be non-trivial
if !is_zero(vn)
return vn
end
nr = _den_ptr(a)
vn = ccall((:fmpz_remove, libflint), Clong, (Ptr{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), nr, nr, b)
vn, nr = remove!(nr, b)
return -vn
end

Expand Down
106 changes: 34 additions & 72 deletions src/flint/fmpz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ zero(::Type{ZZRingElem}) = ZZRingElem(0)
Return the sign of $a$, i.e. $+1$, $0$ or $-1$.
"""
sign(a::ZZRingElem) = ZZRingElem(ccall((:fmpz_sgn, libflint), Cint, (Ref{ZZRingElem},), a))
sign(a::ZZRingElemOrPtr) = ZZRingElem(@ccall libflint.fmpz_sgn(a::Ref{ZZRingElem})::Cint)

sign(::Type{Int}, a::ZZRingElemOrPtr) = Int(ccall((:fmpz_sgn, libflint), Cint, (Ref{ZZRingElem},), a))
sign(::Type{Int}, a::ZZRingElemOrPtr) = Int(@ccall libflint.fmpz_sgn(a::Ref{ZZRingElem})::Cint)

Base.signbit(a::ZZRingElemOrPtr) = signbit(sign(Int, a))

Expand Down Expand Up @@ -416,15 +416,11 @@ end

function divexact(x::ZZRingElem, y::ZZRingElem; check::Bool=true)
iszero(y) && throw(DivideError())
z = ZZRingElem()
if check
r = ZZRingElem()
ccall((:fmpz_tdiv_qr, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, r, x, y)
r != 0 && throw(ArgumentError("Not an exact division"))
z, r = tdivrem(x, y)
is_zero(r) || throw(ArgumentError("Not an exact division"))
else
ccall((:fmpz_divexact, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, x, y)
z = divexact!(ZZRingElem(), x, y)
end
return z
end
Expand Down Expand Up @@ -483,11 +479,7 @@ function is_divisible_by(x::Integer, y::ZZRingElem)
end

function rem(x::ZZRingElem, c::ZZRingElem)
iszero(c) && throw(DivideError())
q = ZZRingElem()
r = ZZRingElem()
ccall((:fmpz_tdiv_qr, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), q, r, x, c)
q, r = Base.divrem(x, c)
return r
end

Expand Down Expand Up @@ -676,22 +668,12 @@ Base.divrem(x::ZZRingElem, y::Int) = (Base.div(x, y), Base.rem(x, y))
###############################################################################

function divrem(x::ZZRingElem, y::ZZRingElem)
iszero(y) && throw(DivideError())
z1 = ZZRingElem()
z2 = ZZRingElem()
ccall((:fmpz_fdiv_qr, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z1, z2, x, y)
z1, z2
return fdivrem(x, y)
end

# N.B. Base.divrem differs from Nemo.divrem
function Base.divrem(x::ZZRingElem, y::ZZRingElem)
iszero(y) && throw(DivideError())
z1 = ZZRingElem()
z2 = ZZRingElem()
ccall((:fmpz_tdiv_qr, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z1, z2, x, y)
z1, z2
return tdivrem(x, y)
end

function tdivrem(x::ZZRingElem, y::ZZRingElem)
Expand Down Expand Up @@ -722,12 +704,8 @@ function cdivrem(x::ZZRingElem, y::ZZRingElem)
end

function ntdivrem(x::ZZRingElem, y::ZZRingElem)
iszero(y) && throw(DivideError())
z1 = ZZRingElem()
z2 = ZZRingElem()
ccall((:fmpz_ndiv_qr, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z1, z2, x, y)
return z1, z2
# ties are the only possible remainders
return ndivrem(x, y)
end

function nfdivrem(a::ZZRingElem, b::ZZRingElem)
Expand Down Expand Up @@ -1317,13 +1295,11 @@ always be non-negative and will be zero iff all inputs are zero.
"""
function gcd(x::ZZRingElem, y::ZZRingElem, z::ZZRingElem...)
d = ZZRingElem()
ccall((:fmpz_gcd, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), d, x, y)
d = gcd!(d, x, y)
length(z) == 0 && return d

for ix in 1:length(z)
ccall((:fmpz_gcd, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), d, d, z[ix])
for zi in z
d = gcd!(d, zi)
end
return d
end
Expand All @@ -1343,12 +1319,10 @@ function gcd(x::Vector{ZZRingElem})
end

z = ZZRingElem()
ccall((:fmpz_gcd, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, x[1], x[2])
z = gcd!(z, x[1], x[2])

for i in 3:length(x)
ccall((:fmpz_gcd, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, z, x[i])
z = gcd!(z, x[i])
if isone(z)
return z
end
Expand All @@ -1365,13 +1339,11 @@ always be non-negative and will be zero if any input is zero.
"""
function lcm(x::ZZRingElem, y::ZZRingElem, z::ZZRingElem...)
m = ZZRingElem()
ccall((:fmpz_lcm, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), m, x, y)
m = lcm!(m, x, y)
length(z) == 0 && return m

for ix in 1:length(z)
ccall((:fmpz_lcm, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), m, m, z[ix])
for zi in z
m = lcm!(m, zi)
end
return m
end
Expand All @@ -1390,12 +1362,10 @@ function lcm(x::Vector{ZZRingElem})
end

z = ZZRingElem()
ccall((:fmpz_lcm, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, x[1], x[2])
z = lcm!(z, x[1], x[2])

for i in 3:length(x)
ccall((:fmpz_lcm, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, z, x[i])
z = lcm!(z, x[i])
end

return z
Expand Down Expand Up @@ -1477,7 +1447,7 @@ julia> isqrt(ZZ(13))
```
"""
function isqrt(x::ZZRingElem)
x < 0 && throw(DomainError(x, "Argument must be non-negative"))
is_negative(x) && throw(DomainError(x, "Argument must be non-negative"))
z = ZZRingElem()
ccall((:fmpz_sqrt, libflint), Nothing, (Ref{ZZRingElem}, Ref{ZZRingElem}), z, x)
return z
Expand All @@ -1498,7 +1468,7 @@ julia> isqrtrem(ZZ(13))
```
"""
function isqrtrem(x::ZZRingElem)
x < 0 && throw(DomainError(x, "Argument must be non-negative"))
is_negative(x) && throw(DomainError(x, "Argument must be non-negative"))
s = ZZRingElem()
r = ZZRingElem()
ccall((:fmpz_sqrtrem, libflint), Nothing,
Expand All @@ -1507,20 +1477,16 @@ function isqrtrem(x::ZZRingElem)
end

function Base.sqrt(x::ZZRingElem; check=true)
x < 0 && throw(DomainError(x, "Argument must be non-negative"))
is_negative(x) && throw(DomainError(x, "Argument must be non-negative"))
if check
for i = 1:length(sqrt_moduli)
res = mod(x, sqrt_moduli[i])
!(res in sqrt_residues[i]) && error("Not a square")
end
s = ZZRingElem()
r = ZZRingElem()
ccall((:fmpz_sqrtrem, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), s, r, x)
s, r = isqrtrem(x)
!iszero(r) && error("Not a square")
else
s = ZZRingElem()
ccall((:fmpz_sqrt, libflint), Nothing, (Ref{ZZRingElem}, Ref{ZZRingElem}), s, x)
s = isqrt(x)
end
return s
end
Expand All @@ -1529,7 +1495,7 @@ is_square(x::ZZRingElem) = Bool(ccall((:fmpz_is_square, libflint), Cint,
(Ref{ZZRingElem},), x))

function is_square_with_sqrt(x::ZZRingElem)
if x < 0
if is_negative(x)
return false, zero(ZZRingElem)
end
for i = 1:length(sqrt_moduli)
Expand All @@ -1538,10 +1504,7 @@ function is_square_with_sqrt(x::ZZRingElem)
return false, zero(ZZRingElem)
end
end
s = ZZRingElem()
r = ZZRingElem()
ccall((:fmpz_sqrtrem, libflint), Nothing,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), s, r, x)
s, r = isqrtrem(x)
if !iszero(r)
return false, zero(ZZRingElem)
end
Expand Down Expand Up @@ -1745,18 +1708,18 @@ function next_prime(x::Int, proved::Bool = true)
return x < 2 ? 2 : Int(next_prime(x % UInt, proved))
end

function remove!(a::ZZRingElem, b::ZZRingElem)
v = ccall((:fmpz_remove, libflint), Clong, (Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), a, a, b)
return v, a
function remove!(z::ZZRingElemOrPtr, a::ZZRingElemOrPtr, b::ZZRingElemOrPtr)
v = ccall((:fmpz_remove, libflint), Clong, (Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, a, b)
return Int(v), z
end

remove!(a::ZZRingElemOrPtr, b::ZZRingElemOrPtr) = remove!(a, a, b)

function remove(x::ZZRingElem, y::ZZRingElem)
iszero(y) && throw(DivideError())
y <= 1 && error("Factor <= 1")
z = ZZRingElem()
num = ccall((:fmpz_remove, libflint), Int,
(Ref{ZZRingElem}, Ref{ZZRingElem}, Ref{ZZRingElem}), z, x, y)
return num, z
return remove!(z, x, y)
end

remove(x::ZZRingElem, y::Integer) = remove(x, ZZRingElem(y))
Expand Down Expand Up @@ -2344,8 +2307,7 @@ julia> nbits(ZZ(12))
4
```
"""
nbits(x::ZZRingElem) = iszero(x) ? 0 : Int(ccall((:fmpz_bits, libflint), Clong,
(Ref{ZZRingElem},), x))
nbits(x::ZZRingElemOrPtr) = Int(@ccall libflint.fmpz_bits(x::Ref{ZZRingElem})::Culong)

@doc raw"""
nbits(a::Integer) -> Int
Expand Down
30 changes: 13 additions & 17 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ end
@boundscheck _checkbounds(A, i, j)
GC.@preserve A begin
m = mat_entry_ptr(A, i, j)
fl = ccall((:fmpz_sgn, libflint), Int, (Ptr{ZZRingElem},), m)
return isone(fl)
return is_positive(m)
end
end

Expand Down Expand Up @@ -769,7 +768,7 @@ function hadamard_bound2(M::ZZMatrix)
zero!(r)
M_ptr = mat_entry_ptr(M, i, 1)
for j in 1:n
ccall((:fmpz_addmul, libflint), Cvoid, (Ref{ZZRingElem}, Ptr{ZZRingElem}, Ptr{ZZRingElem}), r, M_ptr, M_ptr)
addmul!(r, M_ptr, M_ptr)
M_ptr += sizeof(ZZRingElem)
end
if iszero(r)
Expand All @@ -794,7 +793,7 @@ function maximum(::typeof(nbits), M::ZZMatrix)
#this is not going through the "correct" order of the rows, but
#for this is does not matter
if !iszero(unsafe_load(reinterpret(Ptr{Int}, M_ptr)))
mx = max(mx, ccall((:fmpz_bits, libflint), Culong, (Ptr{ZZRingElem},), M_ptr))
mx = max(mx, nbits(M_ptr))
end
M_ptr += sizeof(ZZRingElem)
end
Expand Down Expand Up @@ -1395,7 +1394,7 @@ function AbstractAlgebra.add_row!(A::ZZMatrix, s::ZZRingElem, i::Int, j::Int)
i_ptr = mat_entry_ptr(A, i, 1)
j_ptr = mat_entry_ptr(A, j, 1)
for k = 1:ncols(A)
ccall((:fmpz_addmul, libflint), Cvoid, (Ptr{ZZRingElem}, Ref{ZZRingElem}, Ptr{ZZRingElem}), i_ptr, s, j_ptr)
addmul!(i_ptr, s, j_ptr)
i_ptr += sizeof(ZZRingElem)
j_ptr += sizeof(ZZRingElem)
end
Expand All @@ -1409,7 +1408,7 @@ function AbstractAlgebra.add_column!(A::ZZMatrix, s::ZZRingElem, i::Int, j::Int)
for k = 1:nrows(A)
i_ptr = mat_entry_ptr(A, k, i)
j_ptr = mat_entry_ptr(A, k, j)
ccall((:fmpz_addmul, libflint), Cvoid, (Ptr{ZZRingElem}, Ref{ZZRingElem}, Ptr{ZZRingElem}), i_ptr, s, j_ptr)
addmul!(i_ptr, s, j_ptr)
end
end
end
Expand Down Expand Up @@ -1628,18 +1627,16 @@ function _solve_triu_left(U::ZZMatrix, b::ZZMatrix)
tmp_p += sizeof(ZZRingElem)
end
for j = 1:n
ccall((:fmpz_zero, libflint), Cvoid, (Ref{ZZRingElem}, ), s)
zero!(s)

tmp_p = mat_entry_ptr(tmp, 1, 1)
for k = 1:j-1
U_p = mat_entry_ptr(U, k, j)
ccall((:fmpz_addmul, libflint), Cvoid, (Ref{ZZRingElem}, Ptr{ZZRingElem}, Ptr{ZZRingElem}), s, U_p, tmp_p)
addmul!(s, U_p, tmp_p)
tmp_p += sizeof(ZZRingElem)
end
ccall((:fmpz_sub, libflint), Cvoid,
(Ref{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), s, mat_entry_ptr(b, i, j), s)
ccall((:fmpz_divexact, libflint), Cvoid,
(Ptr{ZZRingElem}, Ref{ZZRingElem}, Ptr{ZZRingElem}), mat_entry_ptr(tmp, 1, j), s, mat_entry_ptr(U, j, j))
sub!(s, mat_entry_ptr(b, i, j), s)
divexact!(mat_entry_ptr(tmp, 1, j), s, mat_entry_ptr(U, j, j))
end
tmp_p = mat_entry_ptr(tmp, 1, 1)
X_p = mat_entry_ptr(X, i, 1)
Expand Down Expand Up @@ -1669,21 +1666,20 @@ function _solve_triu(U::ZZMatrix, b::ZZMatrix)
tmp_ptr += sizeof(ZZRingElem)
end
for j = n:-1:1
ccall((:fmpz_zero, libflint), Cvoid, (Ref{ZZRingElem}, ), s)
zero!(s)
tmp_ptr = mat_entry_ptr(tmp, 1, j+1)
for k = j + 1:n
U_ptr = mat_entry_ptr(U, j, k)
ccall((:fmpz_addmul, libflint), Cvoid, (Ref{ZZRingElem}, Ptr{ZZRingElem}, Ptr{ZZRingElem}), s, U_ptr, tmp_ptr)
mul!(s, U_ptr, tmp_ptr)
tmp_ptr += sizeof(ZZRingElem)
# s = addmul!(s, U[j, k], tmp[k])
end
b_ptr = mat_entry_ptr(b, j, i)
ccall((:fmpz_sub, libflint), Cvoid, (Ref{ZZRingElem}, Ptr{ZZRingElem}, Ref{ZZRingElem}), s, b_ptr, s)
sub!(s, b_ptr, s)
# s = b[j, i] - s
tmp_ptr = mat_entry_ptr(tmp, 1, j)
U_ptr = mat_entry_ptr(U, j, j)
ccall((:fmpz_divexact, libflint), Cvoid, (Ptr{ZZRingElem}, Ref{ZZRingElem}, Ptr{ZZRingElem}), tmp_ptr, s, U_ptr)

divexact!(tmp_ptr, s, U_ptr)
# tmp[j] = divexact(s, U[j,j])
end
tmp_ptr = mat_entry_ptr(tmp, 1, 1)
Expand Down
2 changes: 1 addition & 1 deletion src/flint/fmpz_mpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ function (::Type{Fac{ZZMPolyRingElem}})(fac::fmpz_mpoly_factor, preserve_input::
ccall((:fmpz_mpoly_factor_get_constant_fmpz, libflint), Nothing,
(Ref{ZZRingElem}, Ref{fmpz_mpoly_factor}),
c, fac)
sgnc = ccall((:fmpz_sgn, libflint), Cint, (Ref{ZZRingElem},), c)
sgnc = sign(Int, c)
if sgnc != 0
G = fmpz_factor()
ccall((:fmpz_factor, libflint), Nothing, (Ref{fmpz_factor}, Ref{ZZRingElem}), G, c)
Expand Down
Loading

0 comments on commit 246a850

Please sign in to comment.