Skip to content

Commit

Permalink
feat: SHAP values for binary classification [target: r2025.0] (#2635)
Browse files Browse the repository at this point in the history
* Add result container for shap predictions in df clsf

* Continue work on dispatching SHAP options

* Fix result shape checks

* SHAP for binary clsf works

* Fix result allocation for all resultsToEvaluate

* finalize binary classification SHAP values

* Update example

* Revert "Update example"

This reverts commit 375dacd.

* typo in comment

* Revert "Revert "Update example""

This reverts commit f5072c5.

* Align checking for configuration errors

---------

Co-authored-by: icfaust <[email protected]>
  • Loading branch information
ahuber21 and icfaust authored Sep 5, 2024
1 parent a637ba2 commit f13313a
Show file tree
Hide file tree
Showing 13 changed files with 345 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch

typedef algorithms::gbt::classification::prediction::Input InputType;
typedef algorithms::gbt::classification::prediction::Parameter ParameterType;
typedef typename super::ResultType ResultType;
typedef algorithms::gbt::classification::prediction::Result ResultType;

InputType input; /*!< %Input objects of the algorithm */

Expand Down Expand Up @@ -152,6 +152,12 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch
*/
virtual int getMethod() const DAAL_C11_OVERRIDE { return (int)method; }

/**
* Returns the structure that contains the result of model-based prediction
* \return Structure that contains the result of the model-based prediction
*/
ResultPtr getResult() { return ResultType::cast(_result); }

/**
* Returns a pointer to the newly allocated gradient boosted trees prediction algorithm with a copy of input objects
* and parameters of this gradient boosted trees prediction algorithm
Expand All @@ -164,7 +170,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch

services::Status allocateResult() DAAL_C11_OVERRIDE
{
services::Status s = _result->allocate<algorithmFPType>(&input, _par, 0);
services::Status s = getResult()->template allocate<algorithmFPType>(&input, _par, 0);
_res = _result.get();
return s;
}
Expand All @@ -173,6 +179,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch
{
_in = &input;
_ac = new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env);
_result.reset(new ResultType());
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@ enum Method
defaultDense = 0 /*!< Default method */
};

/**
* <a name="DAAL-ENUM-ALGORITHMS__GBT__CLASSIFICATION__PREDICTION__MODELINPUTID"></a>
* \brief Available identifiers of input models for making model-based prediction
*/
enum ModelInputId
{
model = algorithms::classifier::prediction::model, /*!< Trained gradient boosted trees model */
lastModelInputId = model
};

/**
* <a name="DAAL-ENUM-ALGORITHMS__GBT__CLASSIFICATION__PREDICTION__RESULTID"></a>
* \brief Available identifiers of the result for making model-based prediction
*/
enum ResultId
{
prediction = algorithms::classifier::prediction::prediction,
probabilities = algorithms::classifier::prediction::probabilities,
logProbabilities = algorithms::classifier::prediction::logProbabilities,
lastResultId = logProbabilities
};

/**
* <a name="DAAL-ENUM-ALGORITHMS__GBT__CLASSIFICATION__PREDICTION__RESULTTOCOMPUTEID"></a>
* Available identifiers to specify the result to compute - results are mutually exclusive
Expand Down Expand Up @@ -89,6 +111,62 @@ struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter
DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */
};
/* [Parameter source code] */

/**
* <a name="DAAL-CLASS-ALGORITHMS__GBT__CLASSIFICATION__RESULT"></a>
* \brief Provides interface for the result of model-based prediction
*/
class DAAL_EXPORT Result : public algorithms::classifier::prediction::Result
{
public:
DECLARE_SERIALIZABLE_CAST(Result)
Result();

/**
* Returns the result of model-based prediction
* \param[in] id Identifier of the result
* \return Result that corresponds to the given identifier
*/
data_management::NumericTablePtr get(ResultId id) const;

/**
* Sets the result of model-based prediction
* \param[in] id Identifier of the input object
* \param[in] value %Input object
*/
void set(ResultId id, const data_management::NumericTablePtr & value);

/**
* Allocates memory to store a partial result of model-based prediction
* \param[in] input %Input object
* \param[in] par %Parameter of the algorithm
* \param[in] method Algorithm method
* \return Status of allocation
*/
template <typename algorithmFPType>
DAAL_EXPORT services::Status allocate(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, const int method);

/**
* Checks the result of model-based prediction
* \param[in] input %Input object
* \param[in] par %Parameter of the algorithm
* \param[in] method Computation method
* \return Status of checking
*/
services::Status check(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, int method) const DAAL_C11_OVERRIDE;

protected:
using daal::algorithms::Result::check;

/** \private */
template <typename Archive, bool onDeserialize>
services::Status serialImpl(Archive * arch)
{
return daal::algorithms::Result::serialImpl<Archive, onDeserialize>(arch);
}
};

typedef services::SharedPtr<Result> ResultPtr;
} // namespace interface2

/**
Expand Down Expand Up @@ -124,7 +202,7 @@ class DAAL_EXPORT Input : public classifier::prediction::Input
* \param[in] id Identifier of the input Model object
* \return %Input object that corresponds to the given identifier
*/
gbt::classification::ModelPtr get(classifier::prediction::ModelInputId id) const;
gbt::classification::ModelPtr get(ModelInputId id) const;

/**
* Sets the input NumericTable object in the prediction stage of the classification algorithm
Expand All @@ -138,7 +216,7 @@ class DAAL_EXPORT Input : public classifier::prediction::Input
* \param[in] id Identifier of the input object
* \param[in] ptr Pointer to the input object
*/
void set(classifier::prediction::ModelInputId id, const gbt::classification::ModelPtr & ptr);
void set(ModelInputId id, const gbt::classification::ModelPtr & ptr);

/**
* Checks the correctness of the input object
Expand All @@ -151,6 +229,8 @@ class DAAL_EXPORT Input : public classifier::prediction::Input

} // namespace interface1
using interface2::Parameter;
using interface2::Result;
using interface2::ResultPtr;
using interface1::Input;
} // namespace prediction
/** @} */
Expand Down
3 changes: 2 additions & 1 deletion cpp/daal/include/services/error_indexes.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ enum ErrorID
// GBT error: -30000..-30099
ErrorGbtIncorrectNumberOfTrees = -30000, /*!< Number of trees in the model is not consistent with the number of classes */
ErrorGbtPredictIncorrectNumberOfIterations = -30001, /*!< Number of iterations value in GBT parameter is not consistent with the model */
ErrorGbtPredictShapOptions = -30002, /*< For SHAP values, calculate either contributions or interactions, not both */
ErrorGbtPredictShapOptions = -30002, /*!< For SHAP values, calculate either contributions or interactions, not both */
ErrorGbtPredictShapMulticlassNotSupported = -30003, /*!< For classification, SHAP values currently only support binary classification */

// Data management errors: -80001..
ErrorUserAllocatedMemory = -80001, /*!< Couldn't free memory allocated by user */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ services::Status BatchContainer<algorithmFPType, method, cpu>::compute()
result->get(classifier::prediction::probabilities).get() :
nullptr);

const bool predShapContributions = par->resultsToCompute & shapContributions;
const bool predShapInteractions = par->resultsToCompute & shapInteractions;
__DAAL_CALL_KERNEL(env, internal::PredictKernel, __DAAL_KERNEL_ARGUMENTS(algorithmFPType, method), compute,
daal::services::internal::hostApp(*input), a, m, r, prob, par->nClasses, par->nIterations);
daal::services::internal::hostApp(*input), a, m, r, prob, par->nClasses, par->nIterations, predShapContributions,
predShapInteractions);
}

} // namespace interface2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ public:
* \param m The model for which to run prediction
* \param nIterations Number of iterations
* \param pHostApp HostAppInterface
* \param predShapContributions Predict SHAP contributions
* \param predShapInteractions Predict SHAP interactions
* \return services::Status
*/
services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp);
services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp,
bool predShapContributions, bool predShapInteractions);

protected:
/**
Expand Down Expand Up @@ -126,9 +129,12 @@ public:
* \param nClasses Number of data classes
* \param nIterations Number of iterations
* \param pHostApp HostAppInterface
* \param predShapContributions Predict SHAP contributions
* \param predShapInteractions Predict SHAP interactions
* \return services::Status
*/
services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp);
services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp,
bool predShapContributions, bool predShapInteractions);

protected:
/** Dispatcher type for template dispatching */
Expand Down Expand Up @@ -408,8 +414,12 @@ protected:
//////////////////////////////////////////////////////////////////////////////////////////
template <typename algorithmFPType, CpuType cpu>
services::Status PredictBinaryClassificationTask<algorithmFPType, cpu>::run(const gbt::classification::internal::ModelImpl * m, size_t nIterations,
services::HostAppIface * pHostApp)
services::HostAppIface * pHostApp, bool predShapContributions,
bool predShapInteractions)
{
// assert we're not requesting both contributions and interactions
DAAL_ASSERT(!(predShapContributions && predShapInteractions));

DAAL_ASSERT(!nIterations || nIterations <= m->size());
DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data));
const auto nTreesTotal = (nIterations ? nIterations : m->size());
Expand Down Expand Up @@ -442,7 +452,7 @@ services::Status PredictBinaryClassificationTask<algorithmFPType, cpu>::run(cons
TArray<algorithmFPType, cpu> expValPtr(nRows);
algorithmFPType * expVal = expValPtr.get();
DAAL_CHECK_MALLOC(expVal);
s = super::runInternal(pHostApp, this->_res, margin, false, false);
s = super::runInternal(pHostApp, this->_res, margin, predShapContributions, predShapInteractions);
if (!s) return s;

auto nBlocks = daal::threader_get_threads_number();
Expand Down Expand Up @@ -474,7 +484,7 @@ services::Status PredictBinaryClassificationTask<algorithmFPType, cpu>::run(cons
algorithmFPType * expVal = expValPtr.get();
NumericTablePtr expNT = HomogenNumericTableCPU<algorithmFPType, cpu>::create(expVal, 1, nRows, &s);
DAAL_CHECK_MALLOC(expVal);
s = super::runInternal(pHostApp, expNT.get(), margin, false, false);
s = super::runInternal(pHostApp, expNT.get(), margin, predShapContributions, predShapInteractions);
if (!s) return s;

auto nBlocks = daal::threader_get_threads_number();
Expand All @@ -497,18 +507,22 @@ services::Status PredictBinaryClassificationTask<algorithmFPType, cpu>::run(cons
DAAL_CHECK_BLOCK_STATUS(resBD);
const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) };
algorithmFPType * res = resBD.get();
s = super::runInternal(pHostApp, this->_res, margin, false, false);
s = super::runInternal(pHostApp, this->_res, margin, predShapContributions, predShapInteractions);
if (!s) return s;

typedef services::internal::SignBit<algorithmFPType, cpu> SignBit;

PRAGMA_IVDEP
for (size_t iRow = 0; iRow < nRows; ++iRow)
// for SHAP values, the score from runInternal is what we need
if (!(predShapContributions || predShapInteractions))
{
// probability is a sigmoid(f) hence sign(f) can be checked
const algorithmFPType initial = res[iRow];
const int sign = SignBit::get(initial);
res[iRow] = label[sign];
// convert the score to a class label
typedef services::internal::SignBit<algorithmFPType, cpu> SignBit;
PRAGMA_IVDEP
for (size_t iRow = 0; iRow < nRows; ++iRow)
{
// probability is a sigmoid(f) hence sign(f) can be checked
const algorithmFPType initial = res[iRow];
const int sign = SignBit::get(initial);
res[iRow] = label[sign];
}
}
}
return s;
Expand Down Expand Up @@ -714,23 +728,28 @@ size_t PredictMulticlassTask<algorithmFPType, cpu>::getMaxClass(const algorithmF
template <typename algorithmFPType, prediction::Method method, CpuType cpu>
services::Status PredictKernel<algorithmFPType, method, cpu>::compute(services::HostAppIface * pHostApp, const NumericTable * x,
const classification::Model * m, NumericTable * r, NumericTable * prob,
size_t nClasses, size_t nIterations)
size_t nClasses, size_t nIterations, bool predShapContributions,
bool predShapInteractions)
{
const daal::algorithms::gbt::classification::internal::ModelImpl * pModel =
static_cast<const daal::algorithms::gbt::classification::internal::ModelImpl *>(m);
if (nClasses == 2)
{
PredictBinaryClassificationTask<algorithmFPType, cpu> task(x, r, prob);
return task.run(pModel, nIterations, pHostApp);
return task.run(pModel, nIterations, pHostApp, predShapContributions, predShapInteractions);
}
PredictMulticlassTask<algorithmFPType, cpu> task(x, r, prob);
return task.run(pModel, nClasses, nIterations, pHostApp);
return task.run(pModel, nClasses, nIterations, pHostApp, predShapContributions, predShapInteractions);
}

template <typename algorithmFPType, CpuType cpu>
services::Status PredictMulticlassTask<algorithmFPType, cpu>::run(const gbt::classification::internal::ModelImpl * m, size_t nClasses,
size_t nIterations, services::HostAppIface * pHostApp)
size_t nIterations, services::HostAppIface * pHostApp, bool predShapContributions,
bool predShapInteractions)
{
// assert we're not requesting both contributions and interactions
DAAL_ASSERT(!(predShapContributions && predShapInteractions));

DAAL_ASSERT(!nIterations || nClasses * nIterations <= m->size());
const auto nTreesTotal = (nIterations ? nIterations * nClasses : m->size());
DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ class PredictKernel : public daal::algorithms::Kernel
* \param prob[out] Prediction class probabilities
* \param nClasses[in] Number of classes in gradient boosted trees algorithm parameter
* \param nIterations[in] Number of iterations to predict in gradient boosted trees algorithm parameter
* \param predShapContributions[in] Predict SHAP contributions
* \param predShapInteractions[in] Predict SHAP interactions
*/
services::Status compute(services::HostAppIface * pHostApp, const NumericTable * a, const classification::Model * m, NumericTable * r,
NumericTable * prob, size_t nClasses, size_t nIterations);
NumericTable * prob, size_t nClasses, size_t nIterations, bool predShapContributions, bool predShapInteractions);
};

} // namespace internal
Expand Down
Loading

0 comments on commit f13313a

Please sign in to comment.