22 #ifndef __SGD_TYPES_H__
23 #define __SGD_TYPES_H__
25 #include "algorithms/algorithm.h"
26 #include "data_management/data/numeric_table.h"
27 #include "data_management/data/homogen_numeric_table.h"
28 #include "services/daal_defines.h"
29 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
30 #include "algorithms/engines/mt19937/mt19937.h"
36 namespace optimization_solver
// Enumerators of OptionalDataId (the enum header and its braces are not part of this listing).
// Numbering continues from one past iterative_solver::lastOptionalData.
67 pastUpdateVector = iterative_solver::lastOptionalData + 1,
68 pastWorkValue = pastUpdateVector + 1 ,
// Marks the last identifier of this enum; equal to pastWorkValue.
69 lastOptionalData = pastWorkValue
// BaseParameter: base class for the parameters of the stochastic gradient descent algorithm.
// Publicly extends the generic iterative solver Parameter.
// NOTE(review): this listing is fragmentary - braces, the constructor name and the tail of the
// constructor signature are not shown, and the leading integers are original source line numbers.
85 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::Parameter
// Constructor parameter list (visible part only).
// function: shared pointer to the sum_of_functions batch object, passed by const reference.
99 const sum_of_functions::BatchPtr &
function,
// nIterations defaults to 100.
100 size_t nIterations = 100,
// accuracyThreshold defaults to 1.0e-05.
101 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed NumericTablePtr.
102 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// learningRateSequence defaults to a freshly allocated HomogenNumericTable<double>
// built with (1, 1, doAllocate, 1.0) - presumably a 1x1 table holding 1.0; confirm
// against the HomogenNumericTable constructor.
103 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
104 new data_management::HomogenNumericTable<double>(
105 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// batchSize defaults to 1 here; any parameters after it are not visible in this listing.
106 size_t batchSize = 1,
// Virtual destructor with an empty body.
109 virtual ~BaseParameter() {}
// Const virtual validation hook returning a services::Status; only declared here.
116 virtual services::Status check()
const;
// Public data members (table pointers and the engine pointer).
118 data_management::NumericTablePtr batchIndices;
121 data_management::NumericTablePtr learningRateSequence;
// NOTE(review): the symbol index for this file also lists a `size_t seed` member at
// source line 122, which is not shown in this listing.
124 engines::EnginePtr engine;
// Primary Parameter template, selected by the Method value.
// It adds nothing to BaseParameter; the method-specific variants are the
// Parameter<defaultDense>, Parameter<miniBatch> and Parameter<momentum> structs.
129 template<Method method>
130 struct Parameter :
public BaseParameter {};
// Parameter variant for the defaultDense method; publicly extends BaseParameter.
// NOTE(review): fragmentary listing - the braces, the constructor name and the end of the
// constructor signature are not shown.
140 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
// Constructor parameter list (visible part only); the defaults shown match those of
// BaseParameter: 100 iterations, 1.0e-05 accuracy threshold, default-constructed
// batchIndices, and a HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
152 const sum_of_functions::BatchPtr &
function,
153 size_t nIterations = 100,
154 double accuracyThreshold = 1.0e-05,
155 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
156 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
157 new data_management::HomogenNumericTable<double>(
158 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual validation hook returning a services::Status; only declared here.
166 virtual services::Status check()
const;
// Virtual destructor with an empty body.
168 virtual ~Parameter() {}
// Parameter variant for the miniBatch method; publicly extends BaseParameter and adds
// the conservativeSequence and innerNIterations members.
// NOTE(review): fragmentary listing - the braces, the constructor name and the end of the
// constructor signature are not shown.
180 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
// Constructor parameter list (visible part only).
198 const sum_of_functions::BatchPtr &
function,
199 size_t nIterations = 100,
200 double accuracyThreshold = 1.0e-05,
201 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// batchSize defaults to 128 for this method (BaseParameter's constructor defaults it to 1).
202 size_t batchSize = 128,
// conservativeSequence defaults to a HomogenNumericTable<double> built with
// (1, 1, doAllocate, 1.0) - presumably a 1x1 table holding 1.0; confirm.
203 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
204 new data_management::HomogenNumericTable<double>(
205 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// innerNIterations defaults to 5.
206 size_t innerNIterations = 5,
// learningRateSequence uses the same default table construction as conservativeSequence.
207 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
208 new data_management::HomogenNumericTable<double>(
209 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual validation hook returning a services::Status; only declared here.
217 virtual services::Status check()
const;
// Virtual destructor with an empty body.
219 virtual ~Parameter() {}
// Members specific to the miniBatch variant.
221 data_management::NumericTablePtr conservativeSequence;
222 size_t innerNIterations;
// Parameter variant for the momentum method; publicly extends BaseParameter.
// NOTE(review): fragmentary listing - the braces, the constructor name and the end of the
// constructor signature are not shown. The symbol index for this file lists a
// `double momentum` member at source line 272, which is also not shown here.
235 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
// Constructor parameter list (visible part only).
252 const sum_of_functions::BatchPtr&
function,
// momentum is the second positional parameter and defaults to 0.9.
253 double momentum = 0.9,
254 size_t nIterations = 100,
255 double accuracyThreshold = 1.0e-05,
256 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// batchSize defaults to 128 for this method.
257 size_t batchSize = 128,
// learningRateSequence defaults to a HomogenNumericTable<double> built with
// (1, 1, doAllocate, 1.0) - presumably a 1x1 table holding 1.0; confirm.
258 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
259 new data_management::HomogenNumericTable<double>(
260 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual validation hook returning a services::Status; only declared here.
268 virtual services::Status check()
const;
// Virtual destructor with an empty body.
270 virtual ~Parameter() {}
// Input class of the SGD solver; publicly extends the generic iterative solver Input.
// NOTE(review): fragmentary listing - braces, access specifiers and any other members
// are not shown.
284 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::Input
// Shorthand for the base class.
287 typedef optimization_solver::iterative_solver::Input super;
// Copy constructor; only declared here.
289 Input(
const Input& other);
// Const getter: takes an OptionalDataId and returns a NumericTablePtr.
298 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Setter: takes an OptionalDataId and the numeric table pointer to store for it.
305 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const validation hook taking the algorithm parameter and the method as an int.
// Marked with DAAL_C11_OVERRIDE, i.e. intended to override a base-class virtual.
314 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
// Result class of the SGD solver: results obtained with the compute() method of the
// algorithm in the batch processing mode. Publicly extends the iterative solver Result.
// NOTE(review): fragmentary listing - braces, access specifiers and any other members
// are not shown.
324 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::Result
// Project macro invocation; its expansion is not visible in this listing.
327 DECLARE_SERIALIZABLE_CAST(Result);
// Shorthand for the base class.
328 typedef optimization_solver::iterative_solver::Result super;
// Member template parameterised on the floating-point type; takes the algorithm input,
// parameter and method, and returns a services::Status. Only declared here.
342 template <
typename algorithmFPType>
343 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
// Const getter: takes an OptionalDataId and returns a NumericTablePtr.
350 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Setter: takes an OptionalDataId and the numeric table pointer to store for it.
357 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const validation hook taking the algorithm input, parameter and method.
// Marked with DAAL_C11_OVERRIDE, i.e. intended to override a base-class virtual.
367 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
368 int method)
const DAAL_C11_OVERRIDE;
// Shared-pointer alias for Result.
371 typedef services::SharedPtr<Result> ResultPtr;
// Re-export the interface1 types into the enclosing namespace so they can be named
// without the interface1:: qualifier.
376 using interface1::BaseParameter;
377 using interface1::Parameter;
378 using interface1::Input;
379 using interface1::Result;
380 using interface1::ResultPtr;
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:124
daal::algorithms::optimization_solver::sgd::interface1::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:221
daal
Definition: algorithm_base_common.h:31
daal::algorithms::optimization_solver::sgd::interface1::Parameter
Definition: sgd_types.h:130
daal::algorithms::optimization_solver::sgd::miniBatch
Definition: sgd_types.h:57
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::seed
size_t seed
Definition: sgd_types.h:122
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:97
daal::algorithms::optimization_solver::sgd::OptionalDataId
OptionalDataId
Definition: sgd_types.h:65
daal::algorithms::optimization_solver::sgd::interface1::Result
Results obtained with the compute() method of the sgd algorithm in the batch processing mode...
Definition: sgd_types.h:324
daal::algorithms::optimization_solver::sgd::defaultDense
Definition: sgd_types.h:56
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:118
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:121
daal::algorithms::optimization_solver::sgd::interface1::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:272
daal::algorithms::optimization_solver::sgd::momentum
Definition: sgd_types.h:58
daal::algorithms::optimization_solver::sgd::pastUpdateVector
Definition: sgd_types.h:67
daal::algorithms::optimization_solver::sgd::Method
Method
Definition: sgd_types.h:54
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter
BaseParameter base class for the Stochastic gradient descent algorithm
Definition: sgd_types.h:85
daal::algorithms::optimization_solver::sgd::pastWorkValue
Definition: sgd_types.h:68