Skip to content

Commit

Permalink
fixup MKLSparseMatrix: struct that need manual destroy() call
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov committed Sep 13, 2024
1 parent 0ea0b85 commit 4918149
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 84 deletions.
87 changes: 45 additions & 42 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ function mv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_d
) where T
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_mvI}(), typeof(A),
transA, alpha, hA, descr, x, beta, y)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return y
end
Expand All @@ -48,10 +48,10 @@ function mm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_d
columns = size(C, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldB = stride(B, 2)
ldC = stride(C, 2)
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_mmI}(), typeof(A),
transA, alpha, hA, descr, dense_layout, B, columns, ldB, beta, C, ldC)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return C
end
Expand All @@ -61,12 +61,12 @@ function spmm(transA::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatrix{
check_trans(transA)
check_mat_op_sizes(nothing, A, transA, B, 'N')
Cout = Ref{sparse_matrix_t}()
hA = create_handle(A)
hB = create_handle(B)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_spmmI}(), typeof(A),
transA, hA, hB, Cout)
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
check_status(res)
return MKLSparseMatrix(Cout[])
end
Expand All @@ -79,12 +79,12 @@ function spmmd!(transa::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatri
check_trans(transa)
check_mat_op_sizes(C, A, transa, B, 'N')
ldC = stride(C, 2)
hA = create_handle(A)
hB = create_handle(B)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_T_spmmdI}(), typeof(A),
transa, hA, hB, dense_layout, C, ldC)
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
check_status(res)
return C
end
Expand All @@ -96,15 +96,16 @@ function sp2m(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
check_trans(transB)
check_mat_op_sizes(nothing, A, transA, B, transB)
Cout = Ref{sparse_matrix_t}()
hA = create_handle(A)
hB = create_handle(B)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
transA, descrA, hA, transB, descrB, hB,
SPARSE_STAGE_FULL_MULT, Cout)
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
check_status(res)
return MKLSparseMatrix(Cout[])
# NOTE: we are guessing what is the storage format of C
return MKLSparseMatrix{typeof(A)}(Cout[])
end

# C := opA(A) * opB(B), where C is sparse, in-place version
Expand All @@ -117,37 +118,38 @@ function sp2m!(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
check_trans(transA)
check_trans(transB)
check_mat_op_sizes(C, A, transA, B, transB)
hA = create_handle(A)
hB = create_handle(B)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
if check_nzpattern
# pre-multiply A * B to get the number of nonzeros per column in the result
CptnOut = Ref{sparse_matrix_t}()
res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
transA, descrA, hA, transB, descrB, hB,
SPARSE_STAGE_NNZ_COUNT, CptnOut)
check_status(res)
hCptn = MKLSparseMatrix{typeof(A)}(CptnOut[])
try
# check if C has the same per-column nonzeros as the result
_C = extract_data(typeof(C), CptnOut[])
_C = extract_data(hCptn)
_Cnnz = _C.major_starts[end] - 1
nnz(C) == _Cnnz || error(lazy"Number of nonzeros in the destination matrix ($(nnz(C))) does not match the result ($(_Cnnz))")
C.colptr == _C.major_starts || error("Nonzeros structure of the destination matrix does not match the result")
catch e
# destroy handles to A and B if the pattern check fails,
# otherwise reuse them at the actual multiplication
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
rethrow(e)
finally
(CptnOut[] != C_NULL) && destroy_handle(typeof(A), CptnOut[])
destroy(hCptn)
end
# FIXME rowval not checked
end
# FIXME the optimal way would be to create the MKLSparse handle to C reusing its arrays
# and do SPARSE_STAGE_FINALIZE_MULT to directly write to the C.nzval
# but that causes segfaults when the handle is destroyed
# (also the partial mkl_sparse_copy(C) workaround to reuse the nz structure segfaults)
#hC = create_handle(C)
#hC = MKLSparseMatrix(C)
#hC_ref = Ref(hC)
#res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
# transA, descrA, hA, transB, descrB, hB,
Expand All @@ -158,12 +160,13 @@ function sp2m!(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
transA, descrA, hA, transB, descrB, hB,
SPARSE_STAGE_FULL_MULT, hCopy_ref)
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
check_status(res)
if hCopy_ref[] != C_NULL
copy!(C, hCopy_ref[]; check_nzpattern)
destroy_handle(typeof(C), hCopy_ref[])
hCopy = MKLSparseMatrix{typeof(A)}(hCopy_ref[])
copy!(C, hCopy; check_nzpattern)
destroy(hCopy)
end
return C
end
Expand All @@ -178,17 +181,17 @@ function sp2md!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descrA::matr
check_trans(transB)
check_mat_op_sizes(C, A, transA, B, transB)
ldC = stride(C, 2)
hA = create_handle(A)
hB = create_handle(B)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_T_sp2mdI}(), typeof(A),
transA, descrA, hA, transB, descrB, hB,
alpha, beta,
C, dense_layout, ldC)
if res != SPARSE_STATUS_SUCCESS
@show transA descrA transB descrB
end
destroy_handle(typeof(A), hA)
destroy_handle(typeof(B), hB)
destroy(hA)
destroy(hB)
check_status(res)
return C
end
Expand All @@ -199,10 +202,10 @@ end
function syrk(transA::Char, A::AbstractSparseMatrix{T}) where T
check_trans(transA)
Cout = Ref{sparse_matrix_t}()
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_syrkI}(), typeof(A),
transA, hA, Cout)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return MKLSparseMatrix(Cout[])
end
Expand All @@ -217,10 +220,10 @@ function syrkd!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, beta::T,
check_trans(transA)
check_mat_op_sizes(C, A, transA, A, transA == 'N' ? 'T' : 'N'; dense_layout)
ldC = stride(C, 2)
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_syrkdI}(), typeof(A),
transA, hA, alpha, beta, C, dense_layout, ldC)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return C
end
Expand All @@ -241,11 +244,11 @@ function syprd!(transA::Char, alpha::T, A::AbstractSparseMatrix{T},
check_result_rows = false, dense_layout = dense_layout_C)
ldB = stride(B, 2)
ldC = stride(C, 2)
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_syprdI}(), typeof(A),
transA, hA, B, dense_layout_B, ldB,
alpha, beta, C, dense_layout_C, ldC)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return C
end
Expand All @@ -257,10 +260,10 @@ function trsv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix
checksquare(A)
check_trans(transA)
check_mat_op_sizes(y, A, transA, x, 'N')
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_trsvI}(), typeof(A),
transA, alpha, hA, descr, x, y)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return y
end
Expand All @@ -276,10 +279,10 @@ function trsm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix
columns = size(Y, dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR ? 2 : 1)
ldX = stride(X, 2)
ldY = stride(Y, 2)
hA = create_handle(A)
hA = MKLSparseMatrix(A)
res = mkl_call(Val{:mkl_sparse_T_trsmI}(), typeof(A),
transA, alpha, hA, descr, dense_layout, X, columns, ldX, Y, ldY)
destroy_handle(typeof(A), hA)
destroy(hA)
check_status(res)
return Y
end
83 changes: 41 additions & 42 deletions src/mklsparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,47 +89,63 @@ lazypermutedims(descr::matrix_descr) = matrix_descr(
descr.mode == SPARSE_FILL_MODE_LOWER ? SPARSE_FILL_MODE_UPPER : descr.mode,
descr.diag)

"""
MKLSparseMatrix{S}
A wrapper around the handle of a MKLSparse matrix
created from the Julia sparse matrix of type `S`.
"""
struct MKLSparseMatrix{S <: AbstractSparseMatrix}
handle::sparse_matrix_t
end

Base.unsafe_convert(::Type{sparse_matrix_t}, A::MKLSparseMatrix) = A.handle

# create sparse_matrix_t handle for the SparseMKL representation of a given sparse matrix
# the created SparseMKL matrix handle has to be disposed by calling destroy_handle()
function create_handle(A::SparseMatrixCOO; index_base = SPARSE_INDEX_BASE_ONE)
function MKLSparseMatrix(A::SparseMatrixCOO; index_base = SPARSE_INDEX_BASE_ONE)
ref = Ref{sparse_matrix_t}()
res = mkl_call(Val{:mkl_sparse_T_create_SI}(), typeof(A),
ref, index_base, A.m, A.n, nnz(A), A.rows, A.cols, A.vals,
log=Val{false}())
check_status(res)
return ref[]
return MKLSparseMatrix{typeof(A)}(ref[])
end

function create_handle(A::SparseMatrixCSR; index_base = SPARSE_INDEX_BASE_ONE)
function MKLSparseMatrix(A::SparseMatrixCSR; index_base = SPARSE_INDEX_BASE_ONE)
ref = Ref{sparse_matrix_t}()
res = mkl_call(Val{:mkl_sparse_T_create_SI}(), typeof(A),
ref, index_base, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval,
log=Val{false}())
check_status(res)
return ref[]
return MKLSparseMatrix{typeof(A)}(ref[])
end

function create_handle(A::SparseMatrixCSC; index_base = SPARSE_INDEX_BASE_ONE)
function MKLSparseMatrix(A::SparseMatrixCSC; index_base = SPARSE_INDEX_BASE_ONE)
ref = Ref{sparse_matrix_t}()
# SparseMatrixCSC is fixed to 1-based indexing, passing SPARSE_INDEX_BASE_ZERO is most likely an error
res = mkl_call(Val{:mkl_sparse_T_create_SI}(), typeof(A),
ref, index_base, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval,
log=Val{false}())
check_status(res)
return ref[]
return MKLSparseMatrix{typeof(A)}(ref[])
end

function destroy_handle(::Type{S}, handle::sparse_matrix_t) where S <: AbstractSparseMatrix
res = mkl_call(Val{:mkl_sparse_destroyI}(), S, handle)
check_status(res)
return res
function destroy(A::MKLSparseMatrix{S}) where S
if A.handle != C_NULL
res = mkl_call(Val{:mkl_sparse_destroyI}(), S, A.handle)
check_status(res)
return res
else
return SPARSE_STATUS_NOT_INITIALIZED
end
end

# extract the Intel MKL's sparse matrix A information assuming its storage type is S
# the returned arrays are internal to MKL representation of A, their lifetime is limited by A
# "major_" refers to the major axis (rows for CSR, columns for CSC)
# "minor_" refers to the minor axis (columns for CSR, rows for CSC)
function extract_data(::Type{S}, ref::sparse_matrix_t) where {S <: AbstractSparseMatrix{Tv, Ti}} where {Tv, Ti}
function extract_data(ref::MKLSparseMatrix{S}) where {S <: AbstractSparseMatrix{Tv, Ti}} where {Tv, Ti}
IT = ifelse(BlasInt === Int64 && Ti === Int32, BlasInt, Ti)
index_base = Ref{sparse_index_base_t}()
nrows = Ref{IT}(0)
Expand Down Expand Up @@ -170,58 +186,41 @@ end

# copy the non-zero values from the MKL Sparse matrix A into the sparse matrix B
# A and B should have the same non-zero pattern
function Base.copy!(B::SparseMatrixCSC{Tv, Ti}, A::sparse_matrix_t;
check_nzpattern::Bool = true) where {Tv, Ti}
_A = extract_data(typeof(B), A)
function Base.copy!(B::S, A::MKLSparseMatrix{S};
check_nzpattern::Bool = true) where {S <: SparseMatrixCSC}
_A = extract_data(A)
Ti = eltype(B.rowval)
length(_A.nzval) == nnz(B) || error(lazy"Number of nonzeros in the source ($(length(_A.nzval))) does not match the destination matrix ($(nnz(B)))")
size(B) == _A.size || throw(DimensionMismatch(lazy"Size of the source $(_A.size) does not match the destination $(size(B))"))
if check_nzpattern
B.colptr == _A.major_starts || error("Source and destination colptr do not match")
rowval_match = _A.index_base == SPARSE_INDEX_BASE_ZERO ?
all((a, b) -> a + 1 == b, zip(_A.minor_val, B.rowval)) : # convert to 1-based
all((a, b) -> a + one(Ti) == b, zip(_A.minor_val, B.rowval)) : # convert to 1-based
_A.minor_val == B.rowval
rowval_match || error("Source and destination rowval do not match")
end
(pointer(B.nzval) != pointer(_A.nzval)) && copy!(B.nzval, _A.nzval)
return B
end

"""
MKLSparseMatrix
A wrapper around a MKLSparse matrix handle.
"""
mutable struct MKLSparseMatrix
handle::sparse_matrix_t
end

function MKLSparseMatrix(A::AbstractSparseMatrix; index_base = SPARSE_INDEX_BASE_ONE)
obj = MKLSparseMatrix(create_handle(A; index_base))
finalizer(mkl_function(Val{:mkl_sparse_destroyI}(), typeof(A)), obj)
return obj
end

Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle

extract_data(::Type{S}, A::MKLSparseMatrix) where {S <: AbstractSparseMatrix} = extract_data(S, A.handle)

function Base.convert(::Type{S}, A::MKLSparseMatrix) where {S <: SparseMatrixCSC{Tv, Ti}} where {Tv, Ti}
_A = extract_data(S, A)
function Base.convert(::Type{S}, A::MKLSparseMatrix{S}) where {S <: SparseMatrixCSC}
_A = extract_data(A)
Ti = eltype(_A.minor_val)
rowval = _A.index_base == SPARSE_INDEX_BASE_ZERO ?
_A.minor_val .+ one(Ti) : # convert to 1-based (rowval is copied)
copy(_A.minor_val)
return S(_A.size..., copy(_A.major_starts), rowval, copy(_A.nzval))
end

# converter for the default SparseMatrixCSC storage type
Base.convert(::Type{SparseMatrixCSC}, A::MKLSparseMatrix) =
convert(SparseMatrixCSC{Float64, BlasInt}, A)
Base.convert(::Type{SparseMatrixCSC}, A::MKLSparseMatrix{SparseMatrixCSC{Tv, Ti}}) where {Tv, Ti} =
convert(SparseMatrixCSC{Tv, Ti}, A)

function Base.convert(::Type{S}, A::MKLSparseMatrix) where {S <: SparseMatrixCSR{Tv, Ti}} where {Tv, Ti}
_A = extract_data(S, A)
function Base.convert(::Type{S}, A::MKLSparseMatrix{S}) where {S <: SparseMatrixCSR}
_A = extract_data(A)
# not converting the col indices depending on index_base
return S(_A.size..., copy(_A.major_starts), copy(_A.minor_val), copy(_A.nzval))
end

Base.copy!(B::SparseMatrixCSC, A::MKLSparseMatrix; check_nzpattern::Bool = true) =
copy!(B, A.handle, check_nzpattern=check_nzpattern)
Base.convert(::Type{SparseMatrixCSR}, A::MKLSparseMatrix{SparseMatrixCSR{Tv, Ti}}) where {Tv, Ti} =
convert(SparseMatrixCSR{Tv, Ti}, A)

0 comments on commit 4918149

Please sign in to comment.