Skip to content

Commit

Permalink
add STIR's CudaRelativeDifferencePrior
Browse files Browse the repository at this point in the history
  • Loading branch information
KrisThielemans committed Jul 8, 2024
1 parent 6c10a0d commit 0c8e283
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 19 deletions.
4 changes: 3 additions & 1 deletion cmake/config.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ SIRF_HAS_NiftyPET = @NiftyPET_BOOL_STR@

SIRF_HAS_Parallelproj = @Parallelproj_BOOL_STR@

SIRF_HAS_SPM = @SPM_BOOL_STR@
STIR_WITH_CUDA = @STIR_WITH_CUDA_BOOL_STR@

SIRF_HAS_SPM = @SPM_BOOL_STR@
10 changes: 9 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ if (DISABLE_STIR)
message(STATUS "STIR support disabled.")
set(NiftyPET_BOOL_STR "0" PARENT_SCOPE)
set(Parallelproj_BOOL_STR "0" PARENT_SCOPE)
set(STIR_WITH_CUDA_BOOL_STR "0" PARENT_SCOPE)
else()
find_package(STIR REQUIRED)
message(STATUS "STIR version found: ${STIR_VERSION}")
Expand All @@ -80,13 +81,20 @@ else()
MESSAGE(STATUS "STIR not built with Parallelproj.")
set(Parallelproj_BOOL_STR "0")
endif()
if (STIR_WITH_CUDA)
set(STIR_WITH_CUDA_BOOL_STR "1")
message(STATUS "STIR was built with CUDA, corresponding functionality will be enabled.")
else()
MESSAGE(STATUS "STIR not built with CUDA.")
set(STIR_WITH_CUDA_BOOL_STR "0")
endif()
ADD_SUBDIRECTORY(xSTIR)
set(SIRF_BUILT_WITH_STIR TRUE PARENT_SCOPE)
set(STIR_VERSION ${STIR_VERSION} PARENT_SCOPE)
endif()
set(NiftyPET_BOOL_STR ${NiftyPET_BOOL_STR} PARENT_SCOPE)
set(Parallelproj_BOOL_STR ${Parallelproj_BOOL_STR} PARENT_SCOPE)

set (STIR_WITH_CUDA_BOOL_STR ${STIR_WITH_CUDA_BOOL_STR} PARENT_SCOPE)


##########################################################################
Expand Down
4 changes: 4 additions & 0 deletions src/xSTIR/cSTIR/cstir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ void* cSTIR_newObject(const char* name)
return NEW_OBJECT_HANDLE(LogPrior3DF);
if (sirf::iequals(name, "RelativeDifferencePrior"))
return NEW_OBJECT_HANDLE(RDPrior3DF);
#ifdef STIR_WITH_CUDA
if (sirf::iequals(name, "CudaRelativeDifferencePrior"))
return NEW_OBJECT_HANDLE(CudaRDPrior3DF);
#endif
if (sirf::iequals(name, "PLSPrior"))
return NEW_OBJECT_HANDLE(PLSPrior3DF);
if (sirf::iequals(name, "TruncateToCylindricalFOVImageProcessor"))
Expand Down
10 changes: 6 additions & 4 deletions src/xSTIR/cSTIR/cstir_p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,9 +703,11 @@ void*
sirf::cSTIR_setRelativeDifferencePriorParameter
(DataHandle* hp, const char* name, const DataHandle* hv)
{
auto& prior = objectFromHandle<xSTIR_RelativeDifferencePrior3DF>(hp);
if (sirf::iequals(name, "only_2D"))
prior.only2D(dataFromHandle<int>((void*)hv));
auto& prior = objectFromHandle<stir::RelativeDifferencePrior<float>>(hp);
if (sirf::iequals(name, "only_2D")) {
auto& xrdp = objectFromHandle<xSTIR_RelativeDifferencePrior3DF>(hp);
xrdp.only2D(dataFromHandle<int>((void*)hv));
}
else if (sirf::iequals(name, "kappa")) {
auto& id = objectFromHandle<STIRImageData>(hv);
prior.set_kappa_sptr(id.data_sptr());
Expand All @@ -723,7 +725,7 @@ void*
sirf::cSTIR_RelativeDifferencePriorParameter
(DataHandle* hp, const char* name)
{
auto& prior = objectFromHandle<xSTIR_RelativeDifferencePrior3DF >(hp);
auto& prior = objectFromHandle<stir::RelativeDifferencePrior<float>>(hp);
if (sirf::iequals(name, "kappa")) {
auto sptr_im = std::make_shared<STIRImageData>(*prior.get_kappa_sptr());
return newObjectHandle(sptr_im);
Expand Down
12 changes: 9 additions & 3 deletions src/xSTIR/cSTIR/include/sirf/STIR/stir_types.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
SyneRBI Synergistic Image Reconstruction Framework (SIRF)
Copyright 2015 - 2019 Rutherford Appleton Laboratory STFC
Copyright 2019 - 2020 University College London
Copyright 2015 - 2024 Rutherford Appleton Laboratory STFC
Copyright 2019 - 2024 University College London
This is software developed for the Collaborative Computational
Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR)
Expand Down Expand Up @@ -64,6 +64,9 @@ limitations under the License.
#include "stir/recon_buildblock/QuadraticPrior.h"
#include "stir/recon_buildblock/LogcoshPrior.h"
#include "stir/recon_buildblock/RelativeDifferencePrior.h"
#ifdef STIR_WITH_CUDA
#include "stir/recon_buildblock/CUDA/CudaRelativeDifferencePrior.h"
#endif
#include "stir/SegmentBySinogram.h"
#include "stir/Shape/Box3D.h"
#include "stir/Shape/Ellipsoid.h"
Expand Down Expand Up @@ -116,7 +119,10 @@ namespace sirf {
typedef stir::QuadraticPrior<float> QuadPrior3DF;
typedef stir::LogcoshPrior<float> LogPrior3DF;
typedef stir::RelativeDifferencePrior<float> RDPrior3DF;
typedef stir::PLSPrior<float> PLSPrior3DF;
#ifdef STIR_WITH_CUDA
typedef stir::CudaRelativeDifferencePrior<float> CudaRDPrior3DF;
#endif
typedef stir::PLSPrior<float> PLSPrior3DF;
typedef stir::DataProcessor<Image3DF> DataProcessor3DF;
typedef stir::TruncateToCylindricalFOVImageProcessor<float> CylindricFilter3DF;

Expand Down
19 changes: 19 additions & 0 deletions src/xSTIR/pSTIR/STIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import sirf.STIR_params as parms
from sirf.config import SIRF_HAS_NiftyPET
from sirf.config import SIRF_HAS_Parallelproj
from sirf.config import STIR_WITH_CUDA

if sys.version_info[0] >= 3 and sys.version_info[1] >= 4:
ABC = abc.ABC
Expand Down Expand Up @@ -2529,6 +2530,24 @@ def get_kappa(self):
check_status(image.handle)
return image

if STIR_WITH_CUDA:
class CudaRelativeDifferencePrior(RelativeDifferencePrior):
r"""Class for Relative Difference Prior using CUDA computations
Identical to RelativeDifferencePrior, but using STIR's CUDA implementation.
"""

def __init__(self):
"""init."""
self.name = 'CudaRelativeDifferencePrior'
self.handle = pystir.cSTIR_newObject(self.name)
check_status(self.handle)

def __del__(self):
"""del."""
if self.handle is not None:
pyiutil.deleteDataHandle(self.handle)


class PLSPrior(Prior):
r"""Class for Parallel Level Sets prior.
Expand Down
24 changes: 14 additions & 10 deletions src/xSTIR/pSTIR/tests/tests_qp_lc_rdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
{licence}
"""
import sirf.STIR
import sirf.config
from sirf.Utilities import runner, RE_PYEXT, __license__, examples_data_path, pTest
__version__ = "2.0.0"
__version__ = "2.1.0"
__author__ = "Imraj Singh, Evgueni Ovtchinnikov, Kris Thielemans"


def Hessian_test(test, prior, x, eps=1e-3):
"""Checks that grad(x + dx) - grad(x) is close to H(x)*dx
Expand Down Expand Up @@ -52,18 +53,21 @@ def test_main(rec=False, verb=False, throw=True):

_ = sirf.STIR.MessageRedirector(warn=None)

im_thorax = sirf.STIR.ImageData(examples_data_path('PET') + '/thorax_single_slice/emission.hv')
im_1 = im_thorax.get_uniform_copy(1)
im_2 = im_thorax.get_uniform_copy(2)
im_0 = sirf.STIR.ImageData(examples_data_path('PET') + '/brain/emission.hv')
im_1 = im_0.get_uniform_copy(1)
im_2 = im_0.get_uniform_copy(2)

for im in [im_thorax, im_1, im_2]:
for im in [im_0, im_1, im_2]:
for penalisation_factor in [0,1,4]:
for kappa in [True, False]:
for prior in [sirf.STIR.QuadraticPrior(), sirf.STIR.LogcoshPrior(), sirf.STIR.RelativeDifferencePrior()]:
priors = [sirf.STIR.QuadraticPrior(), sirf.STIR.LogcoshPrior(), sirf.STIR.RelativeDifferencePrior()]
if sirf.config.STIR_WITH_CUDA:
priors.append(sirf.STIR.CudaRelativeDifferencePrior())
for prior in priors:
if kappa:
prior.set_kappa(im_thorax)
prior.set_kappa(im_0)
# Check if the kappa is stored/returned correctly
test.check_if_equal_within_tolerance(im_thorax.norm(),prior.get_kappa().norm())
test.check_if_equal_within_tolerance(im_0.norm(),prior.get_kappa().norm())

prior.set_penalisation_factor(penalisation_factor)
prior.set_up(im)
Expand All @@ -78,7 +82,7 @@ def test_main(rec=False, verb=False, throw=True):
if kappa:
# check if multiplying kappa and dividing penalisation factor gives same result
prior.set_penalisation_factor(penalisation_factor/4)
prior.set_kappa(im_thorax*2)
prior.set_kappa(im_0*2)
prior.set_up(im)
test.check_if_equal_within_tolerance(prior.get_gradient(im).norm(), grad_norm)

Expand Down

0 comments on commit 0c8e283

Please sign in to comment.