diff --git a/src/PETSc.jl b/src/PETSc.jl index d8782d40..0a463bb1 100644 --- a/src/PETSc.jl +++ b/src/PETSc.jl @@ -24,7 +24,7 @@ include("viewer.jl") include("options.jl") include("vec.jl") include("mat.jl") -# include("matshell.jl") +include("matshell.jl") # include("ksp.jl") # include("ref.jl") # include("pc.jl") diff --git a/src/matshell.jl b/src/matshell.jl index 98571bbd..08f77784 100644 --- a/src/matshell.jl +++ b/src/matshell.jl @@ -3,58 +3,87 @@ Create a `m×n` PETSc shell matrix object wrapping `obj`. -If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y, obj, x)`. +If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it +calls `mul!(y, obj, x)`. + This can be changed by defining `PETSc._mul!`. +# External Links +$(_doc_external("Mat/MATSHELL")) """ -mutable struct MatShell{T,A} <: AbstractMat{T} +mutable struct MatShell{PetscLib, PetscScalar, OType} <: + AbstractMat{PetscLib, PetscScalar} ptr::CMat - obj::A + obj::OType end +struct MatOp{PetscLib, PetscInt, Op} end + +function (::MatOp{PetscLib, PetscInt, LibPETSc.MATOP_MULT})( + M::CMat, + cx::CVec, + cy::CVec, +)::PetscInt where {PetscLib, PetscInt} + r_ctx = Ref{Ptr{Cvoid}}() + LibPETSc.MatShellGetContext(PetscLib, M, r_ctx) + ptr = r_ctx[] + mat = unsafe_pointer_to_objref(ptr) + + PetscScalar = getlib(PetscLib).PetscScalar + x = unsafe_localarray(WrapVec{PetscLib, PetscScalar}(cx); write = false) + y = unsafe_localarray(WrapVec{PetscLib, PetscScalar}(cy); read = false) -struct MatOp{T,Op} end + _mul!(y, mat, x) + Base.finalize(y) + Base.finalize(x) + return PetscInt(0) +end -function _mul!(y,mat::MatShell{T,F},x) where {T, F<:Function} +function _mul!( + y, + mat::MatShell{PetscLib, PetscScalar, F}, + x, +) where {PetscLib, PetscScalar, F <: Function} mat.obj(y, x) end -function _mul!(y,mat::MatShell{T},x) where {T} +function _mul!(y, mat::MatShell, x) where {T} LinearAlgebra.mul!(y, mat.obj, x) end -MatShell{T}(obj, m, n) where {T} = MatShell{T}(obj, MPI.COMM_SELF, m, n, m, n) - - -@for_libpetsc begin - function MatShell{$PetscScalar}(obj::A, comm::MPI.Comm, m, n, M, N) where {A} - mat = MatShell{$PetscScalar,A}(C_NULL, obj) - # we use the MatShell object itsel - ctx = pointer_from_objref(mat) - @chk ccall((:MatCreateShell, $libpetsc), PetscErrorCode, - (MPI.MPI_Comm,$PetscInt,$PetscInt,$PetscInt,$PetscInt,Ptr{Cvoid},Ptr{CMat}), - comm, m, n, M, N, ctx, mat) - - mulptr = @cfunction(MatOp{$PetscScalar, MATOP_MULT}(), $PetscInt, (CMat, CVec, CVec)) - @chk ccall((:MatShellSetOperation, $libpetsc), PetscErrorCode, (CMat, MatOperation, Ptr{Cvoid}), mat, MATOP_MULT, mulptr) - return mat - end - - function (::MatOp{$PetscScalar, MATOP_MULT})(M::CMat,cx::CVec,cy::CVec)::$PetscInt - r_ctx = Ref{Ptr{Cvoid}}() - @chk ccall((:MatShellGetContext, $libpetsc), PetscErrorCode, (CMat, Ptr{Ptr{Cvoid}}), M, r_ctx) - ptr = r_ctx[] - mat = unsafe_pointer_to_objref(ptr) - - x = unsafe_localarray($PetscScalar, cx; write=false) - y = unsafe_localarray($PetscScalar, cy; read=false) - - _mul!(y,mat,x) - - Base.finalize(y) - Base.finalize(x) - return $PetscInt(0) - end - +# We have to use the macro here because of the @cfunction +LibPETSc.@for_petsc function MatShell( + petsclib::$PetscLib, + obj::OType, + comm::MPI.Comm, + local_rows, + local_cols, + global_rows = LibPETSc.PETSC_DECIDE, + global_cols = LibPETSc.PETSC_DECIDE, +) where {OType} + mat = MatShell{$PetscLib, $PetscScalar, OType}(C_NULL, obj) + + # we use the MatShell object itself + ctx = pointer_from_objref(mat) + + LibPETSc.MatCreateShell( + petsclib, + comm, + local_rows, + local_cols, + global_rows, + global_cols, + pointer_from_objref(mat), + mat, + ) + + mulptr = @cfunction( + MatOp{$PetscLib, $PetscInt, LibPETSc.MATOP_MULT}(), + $PetscInt, + (CMat, CVec, CVec) + ) + LibPETSc.MatShellSetOperation(petsclib, mat, LibPETSc.MATOP_MULT, mulptr) + + return mat end diff --git a/src/vec.jl b/src/vec.jl index 6f58df70..b50e88b8 100644 --- a/src/vec.jl +++ b/src/vec.jl @@ -18,6 +18,11 @@ Base.eltype( ) where {PetscLib, PetscScalar} = PetscScalar Base.size(v::AbstractVec) = (length(v),) +mutable struct WrapVec{PetscLib, PetscScalar} <: + AbstractVec{PetscLib, PetscScalar} + ptr::CVec +end + """ VecSeq(petsclib, v::Vector) diff --git a/test/matshell.jl b/test/matshell.jl new file mode 100644 index 00000000..eb2c14ad --- /dev/null +++ b/test/matshell.jl @@ -0,0 +1,22 @@ +using Test +using PETSc +using MPI + +@testset "MatShell" begin + for petsclib in PETSc.petsclibs + PETSc.initialize(petsclib) + PetscScalar = petsclib.PetscScalar + + local_rows = 10 + local_cols = 5 + f!(x, y) = x .= [2y; 3y] + x_jl = collect + + matshell = + PETSc.MatShell(petsclib, f!, MPI.COMM_SELF, local_rows, local_cols) + x = PetscScalar.(collect(1:5)) + @test matshell * x == [2x; 3x] + + PETSc.finalize(petsclib) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a02d2207..28477ea8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,3 +2,4 @@ include("init.jl") include("options.jl") include("vec.jl") include("mat.jl") +include("matshell.jl")