48 #ifndef __GBT_CLASSIFICATION_PREDICT_H__
49 #define __GBT_CLASSIFICATION_PREDICT_H__
51 #include "algorithms/algorithm.h"
52 #include "algorithms/classifier/classifier_predict.h"
53 #include "algorithms/gradient_boosted_trees/gbt_classification_model.h"
54 #include "algorithms/gradient_boosted_trees/gbt_classification_predict_types.h"
62 namespace classification
// BatchContainer: provides methods to run implementations of gradient boosted
// trees classification prediction (batch processing).
// NOTE(review): this text is a Doxygen source-listing extract. The leading
// numbers are line numbers of the original gbt_classification_predict.h, and
// the class braces and access specifiers are missing here.
//
// Template parameters:
//   algorithmFPType - floating-point type of the computation.
//   method          - computation method.
//   cpu             - CPU type; presumably selects the CPU-specific
//                     implementation (TODO confirm).
89 template<
typename algorithmFPType, Method method, CpuType cpu>
90 class DAAL_EXPORT BatchContainer :
public PredictionContainerIface
// Constructs the container for the given DAAL environment descriptor.
// The definition is not part of this extract.
97 BatchContainer(daal::services::Environment::env *daalEnv);
// Runs the prediction and returns its status. It is marked DAAL_C11_OVERRIDE,
// i.e. it overrides a virtual inherited through PredictionContainerIface.
// The definition is not part of this extract.
104 services::Status compute() DAAL_C11_OVERRIDE;
// Batch: predicts gradient boosted trees classification results.
// NOTE(review): this text is a Doxygen source-listing extract. The leading
// numbers are line numbers of the original gbt_classification_predict.h.
// Braces, the destructor and some statements are missing here.
//
// Template parameters:
//   algorithmFPType - floating-point type, also used as the result allocation
//                     type in allocateResult(); defaults to DAAL_ALGORITHM_FP_TYPE.
//   method          - computation method; defaults to defaultDense.
127 template<
typename algorithmFPType = DAAL_ALGORITHM_FP_TYPE, Method method = defaultDense>
128 class Batch :
public classifier::prediction::Batch
// Constructs the algorithm for nClasses classes and allocates a new Parameter
// stored in _par. _par is presumably a base-class member; its release is not
// visible in this extract.
// NOTE(review): this single-argument constructor is not explicit, so a size_t
// converts implicitly to Batch. Consider `explicit` if no caller relies on it.
137 Batch(
size_t nClasses)
139 _par =
new Parameter(nClasses);
// Copy constructor. Copies the base-class state and the input, then allocates
// a fresh Parameter copied from other's parameters (the two objects do not
// share a Parameter).
149 Batch(
const Batch<algorithmFPType, method> &other) : classifier::prediction::Batch(other), input(other.input)
151 _par =
new Parameter(other.parameter());
// Returns the algorithm parameters by downcasting the _par pointer.
165 Parameter& parameter() {
return *
static_cast<Parameter*
>(_par); }
// Const overload of parameter().
171 const Parameter& parameter()
const {
return *
static_cast<const Parameter*
>(_par); }
// Returns a pointer to this algorithm's input object (the `input` member).
177 Input * getInput() DAAL_C11_OVERRIDE {
return &input; }
// Returns the computation method (the `method` template argument) as an int.
183 virtual int getMethod() const DAAL_C11_OVERRIDE {
return(
int)method; }
// Returns a copy of this algorithm wrapped in a SharedPtr; delegates to cloneImpl().
190 services::SharedPtr<Batch<algorithmFPType, method> > clone()
const
192 return services::SharedPtr<Batch<algorithmFPType, method> >(cloneImpl());
// Creates a heap-allocated copy through the copy constructor. clone() takes
// ownership of the returned raw pointer.
196 virtual Batch<algorithmFPType, method> * cloneImpl() const DAAL_C11_OVERRIDE
198 return new Batch<algorithmFPType, method>(*this);
// Allocates the result object for `input` using algorithmFPType and records
// the raw result pointer in _res.
// NOTE(review): the statement that returns the status `s` is not part of this extract.
201 services::Status allocateResult() DAAL_C11_OVERRIDE
203 services::Status s = _result->allocate<algorithmFPType>(&input, 0, 0);
204 _res = _result.get();
// Creates the BatchContainer that executes the computation, passing the
// environment _env. NOTE(review): the enclosing function (presumably an
// initialization helper called from the constructors) is not visible in
// this extract.
211 _ac =
new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env);
// Make the interface1 versions of BatchContainer and Batch visible in the
// enclosing namespace (presumably daal::algorithms::gbt::classification::prediction),
// so callers need not spell out interface1.
216 using interface1::BatchContainer;
217 using interface1::Batch;
daal
Definition: algorithm_base_common.h:57
daal::algorithms::gbt::classification::prediction::interface1::Batch::parameter
Parameter & parameter()
Definition: gbt_classification_predict.h:165
daal::algorithms::gbt::classification::prediction::interface1::Batch::parameter
const Parameter & parameter() const
Definition: gbt_classification_predict.h:171
daal::algorithms::PredictionContainerIface
Abstract interface class that provides virtual methods to access and run implementations of the algorithms.
Definition: prediction.h:66
daal::algorithms::gbt::classification::prediction::interface1::Batch::~Batch
~Batch()
Definition: gbt_classification_predict.h:156
daal::algorithms::gbt::classification::prediction::interface1::Batch::Batch
Batch(const Batch< algorithmFPType, method > &other)
Definition: gbt_classification_predict.h:149
daal::algorithms::gbt::classification::prediction::interface1::Parameter
Parameters of the prediction algorithm.
Definition: gbt_classification_predict_types.h:97
daal::batch
Definition: daal_defines.h:131
daal::algorithms::gbt::classification::prediction::interface1::Batch::input
Input input
Definition: gbt_classification_predict.h:131
daal::algorithms::gbt::classification::prediction::interface1::BatchContainer
Provides methods to run implementations of the gradient boosted trees algorithm. This class is associated with the daal::algorithms::gbt::classification::prediction::interface1::Batch class.
Definition: gbt_classification_predict.h:90
daal::algorithms::gbt::classification::prediction::interface1::Batch::clone
services::SharedPtr< Batch< algorithmFPType, method > > clone() const
Definition: gbt_classification_predict.h:190
daal::algorithms::classifier::prediction::prediction
Definition: classifier_predict_types.h:102
daal::algorithms::gbt::classification::prediction::interface1::Batch::getMethod
virtual int getMethod() const DAAL_C11_OVERRIDE
Definition: gbt_classification_predict.h:183
daal::algorithms::gbt::classification::prediction::interface1::Batch::getInput
Input * getInput() DAAL_C11_OVERRIDE
Definition: gbt_classification_predict.h:177
daal::algorithms::gbt::classification::prediction::interface1::Batch
Predicts gradient boosted trees classification results.
Definition: gbt_classification_predict.h:128
daal::algorithms::gbt::classification::prediction::interface1::Batch::Batch
Batch(size_t nClasses)
Definition: gbt_classification_predict.h:137