22 #ifndef __SAGA_TYPES_H__
23 #define __SAGA_TYPES_H__
25 #include "data_management/data/numeric_table.h"
26 #include "data_management/data/homogen_numeric_table.h"
27 #include "services/daal_defines.h"
28 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
29 #include "algorithms/engines/mt19937/mt19937.h"
30 #include "algorithms/optimization_solver/objective_function/logistic_loss_batch.h"
36 namespace optimization_solver
65 gradientsTable = iterative_solver::lastOptionalData + 1,
68 lastOptionalData = gradientsTable
84 struct DAAL_EXPORT Parameter :
public optimization_solver::iterative_solver::Parameter
100 const sum_of_functions::BatchPtr &
function,
101 size_t nIterations = 100,
102 double accuracyThreshold = 1.0e-05,
103 const data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
104 const size_t batchSize = 128,
105 const data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(),
109 virtual ~Parameter(){}
116 virtual services::Status check()
const DAAL_C11_OVERRIDE;
118 data_management::NumericTablePtr batchIndices;
121 data_management::NumericTablePtr learningRateSequence;
124 engines::EnginePtr engine;
136 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::Input
139 typedef optimization_solver::iterative_solver::Input super;
142 Input(
const Input& other);
152 data_management::NumericTablePtr
get(OptionalDataId id)
const;
159 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
168 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
176 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::Result
179 DECLARE_SERIALIZABLE_CAST(Result);
180 typedef optimization_solver::iterative_solver::Result super;
194 template <
typename algorithmFPType>
195 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
202 data_management::NumericTablePtr
get(OptionalDataId id)
const;
209 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
219 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
220 int method)
const DAAL_C11_OVERRIDE;
222 typedef services::SharedPtr<Result> ResultPtr;
// Re-export the interface1 declarations (Parameter, Input, Result and the
// ResultPtr shared-pointer alias) into the enclosing namespace, so client code
// can name these types without the interface1:: qualifier.
227 using interface1::Parameter;
228 using interface1::Input;
229 using interface1::Result;
230 using interface1::ResultPtr;
daal
Definition: algorithm_base_common.h:31
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:97
daal::algorithms::optimization_solver::saga::OptionalDataId
OptionalDataId
Definition: saga_types.h:63
daal::algorithms::optimization_solver::saga::defaultDense
Definition: saga_types.h:56
daal::algorithms::optimization_solver::saga::interface1::Parameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: saga_types.h:121
daal::algorithms::optimization_solver::saga::gradientsTable
Definition: saga_types.h:65
daal::algorithms::optimization_solver::saga::Method
Method
Definition: saga_types.h:54
daal::algorithms::optimization_solver::saga::interface1::Parameter::engine
engines::EnginePtr engine
Definition: saga_types.h:124
daal::algorithms::optimization_solver::saga::interface1::Parameter::seed
size_t seed
Definition: saga_types.h:122
daal::algorithms::optimization_solver::saga::interface1::Parameter
Parameter base class for the Stochastic average gradient descent algorithm
Definition: saga_types.h:84
daal::algorithms::optimization_solver::saga::interface1::Parameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: saga_types.h:118
daal::algorithms::optimization_solver::saga::interface1::Result
Results obtained with the compute() method of the saga algorithm in the batch processing mode...
Definition: saga_types.h:176