From f13313a87a2abc551104d77a28c3a701fe345101 Mon Sep 17 00:00:00 2001 From: Andreas Huber <9201869+ahuber21@users.noreply.github.com> Date: Thu, 5 Sep 2024 11:56:21 +0200 Subject: [PATCH] feat: SHAP values for binary classification [target: r2025.0] (#2635) * Add result container for shap predictions in df clsf * Continue work on dispatching SHAP options * Fix result shape checks * SHAP for binary clsf works * Fix result allocation for all resultsToEvaluate * finalize binary classification SHAP values * Update example * Revert "Update example" This reverts commit 375dacd4aff2d9c4c6ad8bff9cfd88066cf74fe8. * typo in comment * Revert "Revert "Update example"" This reverts commit f5072c5f3c4869d81f9772f8b547b1ab768645aa. * Align checking for configuration errors --------- Co-authored-by: icfaust --- .../gbt_classification_predict.h | 11 +- .../gbt_classification_predict_types.h | 84 +++++++++++++- cpp/daal/include/services/error_indexes.h | 3 +- .../gbt_classification_predict_container.h | 5 +- ...ication_predict_dense_default_batch_impl.i | 55 ++++++--- .../gbt_classification_predict_kernel.h | 4 +- .../gbt_classification_predict_result_fpt.cpp | 96 ++++++++++++++++ .../gbt_classification_predict_types.cpp | 104 ++++++++++++++++-- ...ression_predict_dense_default_batch_impl.i | 2 +- .../gbt_regression_predict_types.cpp | 6 +- cpp/daal/src/algorithms/dtrees/gbt/treeshap.h | 25 +++-- cpp/daal/src/services/error_handling.cpp | 1 + .../gbt_cls_traversed_model_builder.cpp | 3 +- 13 files changed, 345 insertions(+), 54 deletions(-) create mode 100644 cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_result_fpt.cpp diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict.h index 42c6e1118d4..fa3e2041ae2 100644 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict.h @@ -107,7 +107,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch typedef algorithms::gbt::classification::prediction::Input InputType; typedef algorithms::gbt::classification::prediction::Parameter ParameterType; - typedef typename super::ResultType ResultType; + typedef algorithms::gbt::classification::prediction::Result ResultType; InputType input; /*!< %Input objects of the algorithm */ @@ -152,6 +152,12 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch */ virtual int getMethod() const DAAL_C11_OVERRIDE { return (int)method; } + /** + * Returns the structure that contains the result of model-based prediction + * \return Structure that contains the result of the model-based prediction + */ + ResultPtr getResult() { return ResultType::cast(_result); } + /** * Returns a pointer to the newly allocated gradient boosted trees prediction algorithm with a copy of input objects * and parameters of this gradient boosted trees prediction algorithm @@ -164,7 +170,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch services::Status allocateResult() DAAL_C11_OVERRIDE { - services::Status s = _result->allocate(&input, _par, 0); + services::Status s = getResult()->template allocate(&input, _par, 0); _res = _result.get(); return s; } @@ -173,6 +179,7 @@ class DAAL_EXPORT Batch : public classifier::prediction::Batch { _in = &input; _ac = new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env); + _result.reset(new ResultType()); } private: diff --git a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h index 58aeea58a3a..d10d48211d2 100755 --- a/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h +++ b/cpp/daal/include/algorithms/gradient_boosted_trees/gbt_classification_predict_types.h @@ -56,6 +56,28 @@ enum Method defaultDense = 0 /*!< Default method */ }; +/** + * + * \brief Available identifiers of input models for making model-based prediction + */ +enum ModelInputId +{ + model = algorithms::classifier::prediction::model, /*!< Trained gradient boosted trees model */ + lastModelInputId = model +}; + +/** + * + * \brief Available identifiers of the result for making model-based prediction + */ +enum ResultId +{ + prediction = algorithms::classifier::prediction::prediction, + probabilities = algorithms::classifier::prediction::probabilities, + logProbabilities = algorithms::classifier::prediction::logProbabilities, + lastResultId = logProbabilities +}; + /** * * Available identifiers to specify the result to compute - results are mutually exclusive @@ -89,6 +111,62 @@ struct DAAL_EXPORT Parameter : public daal::algorithms::classifier::Parameter DAAL_UINT64 resultsToCompute; /*!< 64 bit integer flag that indicates the results to compute */ }; /* [Parameter source code] */ + +/** + * + * \brief Provides interface for the result of model-based prediction + */ +class DAAL_EXPORT Result : public algorithms::classifier::prediction::Result +{ +public: + DECLARE_SERIALIZABLE_CAST(Result) + Result(); + + /** + * Returns the result of model-based prediction + * \param[in] id Identifier of the result + * \return Result that corresponds to the given identifier + */ + data_management::NumericTablePtr get(ResultId id) const; + + /** + * Sets the result of model-based prediction + * \param[in] id Identifier of the input object + * \param[in] value %Input object + */ + void set(ResultId id, const data_management::NumericTablePtr & value); + + /** + * Allocates memory to store a partial result of model-based prediction + * \param[in] input %Input object + * \param[in] par %Parameter of the algorithm + * \param[in] method Algorithm method + * \return Status of allocation + */ + template + DAAL_EXPORT services::Status allocate(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, const int method); + + /** + * Checks the result of model-based prediction + * \param[in] input %Input object + * \param[in] par %Parameter of the algorithm + * \param[in] method Computation method + * \return Status of checking + */ + services::Status check(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, int method) const DAAL_C11_OVERRIDE; + +protected: + using daal::algorithms::Result::check; + + /** \private */ + template + services::Status serialImpl(Archive * arch) + { + return daal::algorithms::Result::serialImpl(arch); + } +}; + +typedef services::SharedPtr ResultPtr; } // namespace interface2 /** @@ -124,7 +202,7 @@ class DAAL_EXPORT Input : public classifier::prediction::Input * \param[in] id Identifier of the input Model object * \return %Input object that corresponds to the given identifier */ - gbt::classification::ModelPtr get(classifier::prediction::ModelInputId id) const; + gbt::classification::ModelPtr get(ModelInputId id) const; /** * Sets the input NumericTable object in the prediction stage of the classification algorithm @@ -138,7 +216,7 @@ class DAAL_EXPORT Input : public classifier::prediction::Input * \param[in] id Identifier of the input object * \param[in] ptr Pointer to the input object */ - void set(classifier::prediction::ModelInputId id, const gbt::classification::ModelPtr & ptr); + void set(ModelInputId id, const gbt::classification::ModelPtr & ptr); /** * Checks the correctness of the input object @@ -151,6 +229,8 @@ class DAAL_EXPORT Input : public classifier::prediction::Input } // namespace interface1 using interface2::Parameter; +using interface2::Result; +using interface2::ResultPtr; using interface1::Input; } // namespace prediction /** @} */ diff --git a/cpp/daal/include/services/error_indexes.h b/cpp/daal/include/services/error_indexes.h index 6833ca50c97..c386f518fb3 100644 --- a/cpp/daal/include/services/error_indexes.h +++ b/cpp/daal/include/services/error_indexes.h @@ -381,7 +381,8 @@ enum ErrorID // GBT error: -30000..-30099 ErrorGbtIncorrectNumberOfTrees = -30000, /*!< Number of trees in the model is not consistent with the number of classes */ ErrorGbtPredictIncorrectNumberOfIterations = -30001, /*!< Number of iterations value in GBT parameter is not consistent with the model */ - ErrorGbtPredictShapOptions = -30002, /*< For SHAP values, calculate either contributions or interactions, not both */ + ErrorGbtPredictShapOptions = -30002, /*!< For SHAP values, calculate either contributions or interactions, not both */ + ErrorGbtPredictShapMulticlassNotSupported = -30003, /*!< For classification, SHAP values currently only support binary classification */ // Data management errors: -80001.. ErrorUserAllocatedMemory = -80001, /*!< Couldn't free memory allocated by user */ diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_container.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_container.h index fbb0bf5a62d..5ba78699c18 100755 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_container.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_container.h @@ -70,8 +70,11 @@ services::Status BatchContainer::compute() result->get(classifier::prediction::probabilities).get() : nullptr); + const bool predShapContributions = par->resultsToCompute & shapContributions; + const bool predShapInteractions = par->resultsToCompute & shapInteractions; __DAAL_CALL_KERNEL(env, internal::PredictKernel, __DAAL_KERNEL_ARGUMENTS(algorithmFPType, method), compute, - daal::services::internal::hostApp(*input), a, m, r, prob, par->nClasses, par->nIterations); + daal::services::internal::hostApp(*input), a, m, r, prob, par->nClasses, par->nIterations, predShapContributions, + predShapInteractions); } } // namespace interface2 diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i index 00a9ee5884d..4d5d4829d74 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_dense_default_batch_impl.i @@ -81,9 +81,12 @@ public: * \param m The model for which to run prediction * \param nIterations Number of iterations * \param pHostApp HostAppInterface + * \param predShapContributions Predict SHAP contributions + * \param predShapInteractions Predict SHAP interactions * \return services::Status */ - services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp); + services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, services::HostAppIface * pHostApp, + bool predShapContributions, bool predShapInteractions); protected: /** @@ -126,9 +129,12 @@ public: * \param nClasses Number of data classes * \param nIterations Number of iterations * \param pHostApp HostAppInterface + * \param predShapContributions Predict SHAP contributions + * \param predShapInteractions Predict SHAP interactions * \return services::Status */ - services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp); + services::Status run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, size_t nIterations, services::HostAppIface * pHostApp, + bool predShapContributions, bool predShapInteractions); protected: /** Dispatcher type for template dispatching */ @@ -408,8 +414,12 @@ protected: ////////////////////////////////////////////////////////////////////////////////////////// template services::Status PredictBinaryClassificationTask::run(const gbt::classification::internal::ModelImpl * m, size_t nIterations, - services::HostAppIface * pHostApp) + services::HostAppIface * pHostApp, bool predShapContributions, + bool predShapInteractions) { + // assert we're not requesting both contributions and interactions + DAAL_ASSERT(!(predShapContributions && predShapInteractions)); + DAAL_ASSERT(!nIterations || nIterations <= m->size()); DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data)); const auto nTreesTotal = (nIterations ? nIterations : m->size()); @@ -442,7 +452,7 @@ services::Status PredictBinaryClassificationTask::run(cons TArray expValPtr(nRows); algorithmFPType * expVal = expValPtr.get(); DAAL_CHECK_MALLOC(expVal); - s = super::runInternal(pHostApp, this->_res, margin, false, false); + s = super::runInternal(pHostApp, this->_res, margin, predShapContributions, predShapInteractions); if (!s) return s; auto nBlocks = daal::threader_get_threads_number(); @@ -474,7 +484,7 @@ services::Status PredictBinaryClassificationTask::run(cons algorithmFPType * expVal = expValPtr.get(); NumericTablePtr expNT = HomogenNumericTableCPU::create(expVal, 1, nRows, &s); DAAL_CHECK_MALLOC(expVal); - s = super::runInternal(pHostApp, expNT.get(), margin, false, false); + s = super::runInternal(pHostApp, expNT.get(), margin, predShapContributions, predShapInteractions); if (!s) return s; auto nBlocks = daal::threader_get_threads_number(); @@ -497,18 +507,22 @@ services::Status PredictBinaryClassificationTask::run(cons DAAL_CHECK_BLOCK_STATUS(resBD); const algorithmFPType label[2] = { algorithmFPType(1.), algorithmFPType(0.) }; algorithmFPType * res = resBD.get(); - s = super::runInternal(pHostApp, this->_res, margin, false, false); + s = super::runInternal(pHostApp, this->_res, margin, predShapContributions, predShapInteractions); if (!s) return s; - typedef services::internal::SignBit SignBit; - - PRAGMA_IVDEP - for (size_t iRow = 0; iRow < nRows; ++iRow) + // for SHAP values, the score from runInternal is what we need + if (!(predShapContributions || predShapInteractions)) { - // probability is a sigmoid(f) hence sign(f) can be checked - const algorithmFPType initial = res[iRow]; - const int sign = SignBit::get(initial); - res[iRow] = label[sign]; + // convert the score to a class label + typedef services::internal::SignBit SignBit; + PRAGMA_IVDEP + for (size_t iRow = 0; iRow < nRows; ++iRow) + { + // probability is a sigmoid(f) hence sign(f) can be checked + const algorithmFPType initial = res[iRow]; + const int sign = SignBit::get(initial); + res[iRow] = label[sign]; + } } } return s; @@ -714,23 +728,28 @@ size_t PredictMulticlassTask::getMaxClass(const algorithmF template services::Status PredictKernel::compute(services::HostAppIface * pHostApp, const NumericTable * x, const classification::Model * m, NumericTable * r, NumericTable * prob, - size_t nClasses, size_t nIterations) + size_t nClasses, size_t nIterations, bool predShapContributions, + bool predShapInteractions) { const daal::algorithms::gbt::classification::internal::ModelImpl * pModel = static_cast(m); if (nClasses == 2) { PredictBinaryClassificationTask task(x, r, prob); - return task.run(pModel, nIterations, pHostApp); + return task.run(pModel, nIterations, pHostApp, predShapContributions, predShapInteractions); } PredictMulticlassTask task(x, r, prob); - return task.run(pModel, nClasses, nIterations, pHostApp); + return task.run(pModel, nClasses, nIterations, pHostApp, predShapContributions, predShapInteractions); } template services::Status PredictMulticlassTask::run(const gbt::classification::internal::ModelImpl * m, size_t nClasses, - size_t nIterations, services::HostAppIface * pHostApp) + size_t nIterations, services::HostAppIface * pHostApp, bool predShapContributions, + bool predShapInteractions) { + // assert we're not requesting both contributions and interactions + DAAL_ASSERT(!(predShapContributions && predShapInteractions)); + DAAL_ASSERT(!nIterations || nClasses * nIterations <= m->size()); const auto nTreesTotal = (nIterations ? nIterations * nClasses : m->size()); DAAL_CHECK_MALLOC(this->_featHelper.init(*this->_data)); diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h index 0efa3707206..d024b6aeb1a 100755 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_kernel.h @@ -57,9 +57,11 @@ class PredictKernel : public daal::algorithms::Kernel * \param prob[out] Prediction class probabilities * \param nClasses[in] Number of classes in gradient boosted trees algorithm parameter * \param nIterations[in] Number of iterations to predict in gradient boosted trees algorithm parameter + * \param predShapContributions[in] Predict SHAP contributions + * \param predShapInteractions[in] Predict SHAP interactions */ services::Status compute(services::HostAppIface * pHostApp, const NumericTable * a, const classification::Model * m, NumericTable * r, - NumericTable * prob, size_t nClasses, size_t nIterations); + NumericTable * prob, size_t nClasses, size_t nIterations, bool predShapContributions, bool predShapInteractions); }; } // namespace internal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_result_fpt.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_result_fpt.cpp new file mode 100644 index 00000000000..09408dc96f1 --- /dev/null +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_result_fpt.cpp @@ -0,0 +1,96 @@ +/* file: gbt_classification_predict_result_fpt.cpp */ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* +//++ +// Implementation of the gradient boosted trees classification algorithm interface +//-- +*/ + +#include "algorithms/gradient_boosted_trees/gbt_classification_predict_types.h" +#include "data_management/data/homogen_numeric_table.h" +#include "src/services/daal_strings.h" + +namespace daal +{ +namespace algorithms +{ +namespace gbt +{ +namespace classification +{ +namespace prediction +{ +using namespace daal::services; + +template +DAAL_EXPORT services::Status Result::allocate(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, const int method) +{ + using algorithms::classifier::computeClassLabels; + using algorithms::classifier::computeClassProbabilities; + using algorithms::classifier::computeClassLogProbabilities; + using algorithms::classifier::prediction::data; + + const Input * algInput = (static_cast(input)); + data_management::NumericTablePtr dataPtr = algInput->get(data); + DAAL_CHECK_EX(dataPtr.get(), ErrorNullInputNumericTable, ArgumentName, dataStr()); + services::Status s; + const size_t nVectors = dataPtr->getNumberOfRows(); + + const Parameter * classificationParameter = static_cast(par); + + if (classificationParameter->resultsToEvaluate & computeClassLabels) + { + size_t nColumnsToAllocate = 1; + if (classificationParameter->resultsToCompute & shapContributions) + { + const size_t nColumns = dataPtr->getNumberOfColumns(); + nColumnsToAllocate = nColumns + 1; + } + else if (classificationParameter->resultsToCompute & shapInteractions) + { + const size_t nColumns = dataPtr->getNumberOfColumns(); + nColumnsToAllocate = (nColumns + 1) * (nColumns + 1); + } + + Argument::set(prediction, data_management::HomogenNumericTable::create(nColumnsToAllocate, nVectors, + data_management::NumericTableIface::doAllocate, &s)); + } + + if (classificationParameter->resultsToEvaluate & computeClassProbabilities) + { + Argument::set(probabilities, data_management::HomogenNumericTable::create( + classificationParameter->nClasses, nVectors, data_management::NumericTableIface::doAllocate, &s)); + } + + if (classificationParameter->resultsToEvaluate & computeClassLogProbabilities) + { + Argument::set(logProbabilities, data_management::HomogenNumericTable::create( + classificationParameter->nClasses, nVectors, data_management::NumericTableIface::doAllocate, &s)); + } + + return s; +} + +template DAAL_EXPORT services::Status Result::allocate(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, + const int method); + +} // namespace prediction +} // namespace classification +} // namespace gbt +} // namespace algorithms +} // namespace daal diff --git a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp index 58e3da1a52b..0ae71f30db4 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/classification/gbt_classification_predict_types.cpp @@ -22,6 +22,7 @@ */ #include "algorithms/gradient_boosted_trees/gbt_classification_predict_types.h" +#include "algorithms/classifier/classifier_predict_types.h" #include "src/services/serialization_utils.h" #include "src/services/daal_strings.h" #include "src/algorithms/dtrees/gbt/classification/gbt_classification_model_impl.h" @@ -56,7 +57,7 @@ NumericTablePtr Input::get(classifier::prediction::NumericTableInputId id) const * \param[in] id Identifier of the input object * \return %Input object that corresponds to the given identifier */ -gbt::classification::ModelPtr Input::get(classifier::prediction::ModelInputId id) const +gbt::classification::ModelPtr Input::get(ModelInputId id) const { return staticPointerCast(Argument::get(id)); } @@ -76,9 +77,9 @@ void Input::set(classifier::prediction::NumericTableInputId id, const NumericTab * \param[in] id Identifier of the input object * \param[in] value %Input object */ -void Input::set(classifier::prediction::ModelInputId id, const gbt::classification::ModelPtr & value) +void Input::set(ModelInputId id, const gbt::classification::ModelPtr & value) { - algorithms::classifier::prediction::Input::set(id, value); + algorithms::classifier::prediction::Input::set(algorithms::classifier::prediction::ModelInputId(id), value); } /** @@ -88,16 +89,14 @@ services::Status Input::check(const daal::algorithms::Parameter * parameter, int { Status s; DAAL_CHECK_STATUS(s, algorithms::classifier::prediction::Input::check(parameter, method)); - ModelPtr m = get(classifier::prediction::model); - const daal::algorithms::gbt::classification::internal::ModelImpl * pModel = - static_cast(m.get()); + classifier::ModelPtr m = get(prediction::model); + const auto * pModel = static_cast(m.get()); DAAL_ASSERT(pModel); DAAL_CHECK(pModel->getNumberOfTrees(), services::ErrorNullModel); size_t nClasses = 0, nIterations = 0; - const gbt::classification::prediction::interface2::Parameter * pPrm = - dynamic_cast(parameter); + const auto * pPrm = dynamic_cast(parameter); if (pPrm) { nClasses = pPrm->nClasses; @@ -108,19 +107,100 @@ services::Status Input::check(const daal::algorithms::Parameter * parameter, int return services::ErrorNullParameterNotSupported; } + const bool predictContribs = pPrm->resultsToCompute & shapContributions; + const bool predictInteractions = pPrm->resultsToCompute & shapInteractions; + DAAL_CHECK(!(predictContribs && predictInteractions), services::ErrorGbtPredictShapOptions); + DAAL_CHECK(!(nClasses > 2 && (predictContribs || predictInteractions)), services::ErrorGbtPredictShapMulticlassNotSupported); + auto maxNIterations = pModel->getNumberOfTrees(); if (nClasses > 2) maxNIterations /= nClasses; DAAL_CHECK((nClasses < 3) || (pModel->getNumberOfTrees() % nClasses == 0), services::ErrorGbtIncorrectNumberOfTrees); DAAL_CHECK((nIterations == 0) || (nIterations <= maxNIterations), services::ErrorGbtPredictIncorrectNumberOfIterations); - const bool predictContribs = pPrm->resultsToCompute & shapContributions; - const bool predictInteractions = pPrm->resultsToCompute & shapInteractions; - DAAL_CHECK(!(predictContribs || predictInteractions), services::ErrorMethodNotImplemented); - return s; } } // namespace interface1 + +namespace interface2 +{ +__DAAL_REGISTER_SERIALIZATION_CLASS(Result, SERIALIZATION_DECISION_FOREST_CLASSIFICATION_PREDICTION_RESULT_ID); + +Result::Result() : algorithms::classifier::prediction::Result(lastResultId + 1) {}; + +/** + * Returns the result of gradient boosted trees model-based prediction + * \param[in] id Identifier of the result + * \return Result that corresponds to the given identifier + */ +NumericTablePtr Result::get(ResultId id) const +{ + return algorithms::classifier::prediction::Result::get(algorithms::classifier::prediction::ResultId(id)); +} + +/** + * Sets the result of gradient boosted trees model-based prediction + * \param[in] id Identifier of the input object + * \param[in] value %Input object + */ +void Result::set(ResultId id, const NumericTablePtr & value) +{ + algorithms::classifier::prediction::Result::set(algorithms::classifier::prediction::ResultId(id), value); +} + +/** + * Checks the result of gradient boosted trees model-based prediction + * \param[in] input %Input object + * \param[in] par %Parameter of the algorithm + * \param[in] method Computation method + */ +services::Status Result::check(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, int method) const +{ + using algorithms::classifier::computeClassLabels; + using algorithms::classifier::computeClassProbabilities; + using algorithms::classifier::computeClassLogProbabilities; + using algorithms::classifier::prediction::data; + using algorithms::classifier::prediction::Result; + + const prediction::Parameter * classificationParameter = static_cast(par); + Status s; + + const auto predictionInput = static_cast(input); + DAAL_CHECK(predictionInput->get(data).get(), services::ErrorNullInputNumericTable); + const size_t nRows = predictionInput->get(data)->getNumberOfRows(); + + classifier::ModelPtr m = predictionInput->get(prediction::model); + DAAL_CHECK(m.get(), services::ErrorNullModel); + + if (classificationParameter->resultsToEvaluate & computeClassLabels) + { + size_t expectedNColumns = 1; + if (classificationParameter->resultsToCompute & shapContributions) + { + const size_t nColumns = predictionInput->get(data)->getNumberOfColumns(); + expectedNColumns = nColumns + 1; + } + else if (classificationParameter->resultsToCompute & shapInteractions) + { + const size_t nColumns = predictionInput->get(data)->getNumberOfColumns(); + expectedNColumns = (nColumns + 1) * (nColumns + 1); + } + DAAL_CHECK_EX(get(prediction)->getNumberOfColumns() == expectedNColumns, ErrorIncorrectNumberOfColumns, ArgumentName, predictionStr()); + DAAL_CHECK_STATUS( + s, data_management::checkNumericTable(get(prediction).get(), predictionStr(), data_management::packed_mask, 0, expectedNColumns, nRows)); + } + if (classificationParameter->resultsToEvaluate & computeClassProbabilities) + DAAL_CHECK_STATUS(s, data_management::checkNumericTable(get(probabilities).get(), probabilitiesStr(), data_management::packed_mask, 0, + classificationParameter->nClasses, nRows)); + if (classificationParameter->resultsToEvaluate & computeClassLogProbabilities) + DAAL_CHECK_STATUS(s, data_management::checkNumericTable(get(logProbabilities).get(), logProbabilitiesStr(), data_management::packed_mask, 0, + classificationParameter->nClasses, nRows)); + + return s; +} + +} // namespace interface2 + } // namespace prediction } // namespace classification } // namespace gbt diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i index 7b506fe36b5..ec3c3969dfd 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_dense_default_batch_impl.i @@ -339,7 +339,7 @@ services::Status PredictRegressionTask::predictContributio for (size_t currentTreeIndex = iTree; currentTreeIndex < iTree + nTrees; ++currentTreeIndex) { const gbt::internal::GbtDecisionTree * currentTree = _aTree[currentTreeIndex]; - st |= gbt::treeshap::treeShap(currentTree, currentX, phi, &_featHelper, + st |= gbt::treeshap::treeShap(currentTree, currentX, phi, 1, &_featHelper, condition, conditionFeature); } diff --git a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp index 82e6492e309..f05e5b7fa22 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp +++ b/cpp/daal/src/algorithms/dtrees/gbt/regression/gbt_regression_predict_types.cpp @@ -143,9 +143,11 @@ void Result::set(ResultId id, const NumericTablePtr & value) */ services::Status Result::check(const daal::algorithms::Input * input, const daal::algorithms::Parameter * par, int method) const { + using algorithms::regression::prediction::Result; + Status s; - DAAL_CHECK_STATUS(s, algorithms::regression::prediction::Result::check(input, par, method)); - const auto inputCast = static_cast(input); + DAAL_CHECK_STATUS(s, Result::check(input, par, method)); + const auto inputCast = static_cast(input); const prediction::Parameter * regressionParameter = static_cast(par); size_t expectedNColumns = 1; if (regressionParameter->resultsToCompute & shapContributions) diff --git a/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h index be5b1cb739e..3702ecc65e3 100644 --- a/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h +++ b/cpp/daal/src/algorithms/dtrees/gbt/treeshap.h @@ -237,7 +237,7 @@ float unwoundPathSumZero(const float * pWeights, uint32_t uniqueDepth, uint32_t * Important: nodeIndex is counted from 0 here! */ template -inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, +inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, size_t numOutputs, const FeatureTypes * featureHelper, size_t nodeIndex, size_t depth, size_t uniqueDepth, size_t uniqueDepthPWeights, PathElement * parentUniquePath, float * parentPWeights, algorithmFPType pWeightsResidual, float parentZeroFraction, float parentOneFraction, int parentFeatureIndex, int condition, FeatureIndexType conditionFeature, float conditionFraction) @@ -245,7 +245,6 @@ inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorith // stop if we have no weight coming down to us if (conditionFraction < FLT_EPSILON) return; - const size_t numOutputs = 1; // currently only support single-output models const ModelFPType * const splitValues = tree->getSplitPoints() - 1; const int * const defaultLeft = tree->getDefaultLeftForSplit() - 1; const FeatureIndexType * const fIndexes = tree->getFeatureIndexesForSplit() - 1; @@ -377,12 +376,12 @@ inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorith } treeShap( - tree, x, phi, featureHelper, hotIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, pWeightsResidual, - hotZeroFraction * incomingZeroFraction, incomingOneFraction, splitIndex, condition, conditionFeature, hotConditionFraction); + tree, x, phi, numOutputs, featureHelper, hotIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, + pWeightsResidual, hotZeroFraction * incomingZeroFraction, incomingOneFraction, splitIndex, condition, conditionFeature, hotConditionFraction); treeShap( - tree, x, phi, featureHelper, coldIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, pWeightsResidual, - coldZeroFraction * incomingZeroFraction, 0, splitIndex, condition, conditionFeature, coldConditionFraction); + tree, x, phi, numOutputs, featureHelper, coldIndex, depth + 1, uniqueDepth + 1, uniqueDepthPWeights + 1, uniquePath, pWeights, + pWeightsResidual, coldZeroFraction * incomingZeroFraction, 0, splitIndex, condition, conditionFeature, coldConditionFraction); } /** @@ -395,7 +394,7 @@ inline void treeShap(const gbt::internal::GbtDecisionTree * tree, const algorith * \param conditionFeature the index of the feature to fix */ template -inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, +inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, size_t numClasses, const FeatureTypes * featureHelper, int condition, FeatureIndexType conditionFeature) { services::Status st; @@ -410,8 +409,8 @@ inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, co TArray pWeights(nElements); DAAL_CHECK_MALLOC(pWeights.get()); - treeShap(tree, x, phi, featureHelper, 1, 0, 0, 0, uniquePathData.get(), pWeights.get(), 1, - 1, 1, -1, condition, conditionFeature, 1); + treeShap(tree, x, phi, numClasses, featureHelper, 1, 0, 0, 0, uniquePathData.get(), + pWeights.get(), 1, 1, 1, -1, condition, conditionFeature, 1); return st; } @@ -430,12 +429,13 @@ enum TreeShapVersion * \param tree current tree * \param x dense data matrix * \param phi dense output matrix of feature attributions + * \param numClasses number of classes in input data * \param featureHelper pointer to a FeatureTypes object (required to traverse tree) * \param condition fix one feature to either off (-1) on (1) or not fixed (0 default) * \param conditionFeature the index of the feature to fix */ template -inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, +inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, const algorithmFPType * x, algorithmFPType * phi, size_t numClasses, const FeatureTypes * featureHelper, int condition, FeatureIndexType conditionFeature, TreeShapVersion shapVersion = fast_v1) { @@ -446,11 +446,12 @@ inline services::Status treeShap(const gbt::internal::GbtDecisionTree * tree, co switch (shapVersion) { case lundberg: + DAAL_CHECK(numClasses == 1, services::ErrorIncorrectParameter); // our lundberg version only supports single-class/regression return treeshap::internal::v0::treeShap(tree, x, phi, featureHelper, condition, conditionFeature); case fast_v1: - return treeshap::internal::v1::treeShap(tree, x, phi, featureHelper, condition, - conditionFeature); + return treeshap::internal::v1::treeShap(tree, x, phi, numClasses, featureHelper, + condition, conditionFeature); default: return services::Status(ErrorMethodNotImplemented); } } diff --git a/cpp/daal/src/services/error_handling.cpp b/cpp/daal/src/services/error_handling.cpp index 4e20a556953..1059aca4425 100644 --- a/cpp/daal/src/services/error_handling.cpp +++ b/cpp/daal/src/services/error_handling.cpp @@ -932,6 +932,7 @@ void ErrorMessageCollection::parseResourceFile() add(ErrorGbtIncorrectNumberOfTrees, "Number of trees in the model is not consistent with the number of classes"); add(ErrorGbtPredictIncorrectNumberOfIterations, "Number of iterations value in GBT parameter is not consistent with the model"); add(ErrorGbtPredictShapOptions, "Incompatible SHAP options. Can calculate either contributions or interactions, not both"); + add(ErrorGbtPredictShapMulticlassNotSupported, "Multiclass classification SHAP values not supported."); //Math errors: -90000..-90099 add(ErrorDataSourseNotAvailable, "ErrorDataSourseNotAvailable"); diff --git a/examples/daal/cpp/source/gradient_boosted_trees/gbt_cls_traversed_model_builder.cpp b/examples/daal/cpp/source/gradient_boosted_trees/gbt_cls_traversed_model_builder.cpp index 71d1c22d6d4..8170728cd10 100644 --- a/examples/daal/cpp/source/gradient_boosted_trees/gbt_cls_traversed_model_builder.cpp +++ b/examples/daal/cpp/source/gradient_boosted_trees/gbt_cls_traversed_model_builder.cpp @@ -272,8 +272,7 @@ size_t testModel(daal::algorithms::gbt::classification::ModelPtr modelPtr) { algorithm.compute(); /* Retrieve the algorithm results */ - NumericTablePtr prediction = - algorithm.getResult()->get(daal::algorithms::classifier::prediction::prediction); + NumericTablePtr prediction = algorithm.getResult()->get(prediction::prediction); printNumericTable(prediction, "Gradient boosted trees prediction results (first 10 rows):", 10); printNumericTable(testGroundTruth, "Ground truth (first 10 rows):", 10); size_t nRows = 0;