22 #ifndef __SGD_TYPES_H__
23 #define __SGD_TYPES_H__
25 #include "algorithms/algorithm.h"
26 #include "data_management/data/numeric_table.h"
27 #include "data_management/data/homogen_numeric_table.h"
28 #include "services/daal_defines.h"
29 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
30 #include "algorithms/engines/mt19937/mt19937.h"
36 namespace optimization_solver
// Enumerators of OptionalDataId. Numbering starts right after the last
// optional-data identifier of the iterative solver, so the two ranges do not overlap.
67 pastUpdateVector = iterative_solver::lastOptionalData + 1,
68 pastWorkValue = pastUpdateVector + 1 ,
// Alias for the highest identifier declared in this enumeration.
69 lastOptionalData = pastWorkValue
/**
 * interface1 version of BaseParameter: base class for the parameters of the
 * stochastic gradient descent (SGD) algorithm. Derives from the interface1
 * parameter class of the iterative solver.
 * NOTE(review): this listing omits lines of the original header (braces, the
 * constructor's declaring line, some members); only what is visible is annotated.
 */
85 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::interface1::Parameter
// Argument list with defaults; presumably the constructor's (its declaring line is not visible here).
// `function` is the only argument shown without a default value.
99 const sum_of_functions::interface1::BatchPtr &
function,
100 size_t nIterations = 100,
101 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
102 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// learningRateSequence defaults to a new HomogenNumericTable<double> built with
// (1, 1, doAllocate, 1.0); presumably a 1 x 1 table holding 1.0; confirm against the table constructor.
103 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
104 new data_management::HomogenNumericTable<double>(
105 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
106 size_t batchSize = 1,
// Virtual destructor with an empty body.
109 virtual ~BaseParameter() {}
// Const virtual member returning a services::Status; presumably validates the parameter values (definition not shown).
116 virtual services::Status check()
const;
// Public data members; same names as the corresponding arguments above.
118 data_management::NumericTablePtr batchIndices;
121 data_management::NumericTablePtr learningRateSequence;
// Engine handle; presumably the random number engine used by the solver (TODO confirm).
124 engines::EnginePtr engine;
// Primary template of the interface1 Parameter, selected by a value of the Method enumeration.
// It adds nothing to BaseParameter; explicit specializations declared in this header
// provide the method-specific parameter sets.
129 template<Method method>
130 struct Parameter :
public BaseParameter {};
/**
 * interface1 Parameter specialization for the defaultDense method; derives from BaseParameter.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line and the end
 * of its argument list).
 */
140 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
152 const sum_of_functions::interface1::BatchPtr &
function,
153 size_t nIterations = 100,
154 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
155 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
156 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
157 new data_management::HomogenNumericTable<double>(
158 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
166 virtual services::Status check()
const;
// Virtual destructor with an empty body.
168 virtual ~Parameter() {}
/**
 * interface1 Parameter specialization for the miniBatch method; derives from BaseParameter
 * and adds the conservativeSequence and innerNIterations members.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line and the end
 * of its argument list).
 */
180 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
198 const sum_of_functions::interface1::BatchPtr &
function,
199 size_t nIterations = 100,
200 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
201 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// The batch size default here is 128 (the BaseParameter argument list uses 1).
202 size_t batchSize = 128,
// conservativeSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
203 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
204 new data_management::HomogenNumericTable<double>(
205 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
206 size_t innerNIterations = 5,
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
207 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
208 new data_management::HomogenNumericTable<double>(
209 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
217 virtual services::Status check()
const;
// Virtual destructor with an empty body.
219 virtual ~Parameter() {}
// Members specific to this specialization; same names as the corresponding arguments above.
221 data_management::NumericTablePtr conservativeSequence;
222 size_t innerNIterations;
/**
 * interface1 Parameter specialization for the momentum method; derives from BaseParameter.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line, the end
 * of its argument list, and the data members of this specialization).
 */
235 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
// Unlike the other specializations, `momentum` comes second, before nIterations.
252 const sum_of_functions::interface1::BatchPtr&
function,
253 double momentum = 0.9,
254 size_t nIterations = 100,
255 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
256 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
257 size_t batchSize = 128,
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
258 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
259 new data_management::HomogenNumericTable<double>(
260 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
268 virtual services::Status check()
const;
// Virtual destructor with an empty body.
270 virtual ~Parameter() {}
/**
 * interface1 Input class of the SGD algorithm; derives from the interface1 Input
 * class of the iterative solver.
 * NOTE(review): braces, access specifiers and other lines of the original header
 * are missing from this listing.
 */
284 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::interface1::Input
// Shorthand for the base class.
287 typedef optimization_solver::iterative_solver::interface1::Input super;
// Copy constructor.
289 Input(
const Input& other);
// Const accessor taking an OptionalDataId and returning a NumericTablePtr;
// presumably the table stored under that identifier (definition not shown).
298 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Counterpart of get(): takes an OptionalDataId and a NumericTablePtr; presumably stores the table under `id`.
305 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const virtual member marked DAAL_C11_OVERRIDE; takes the algorithm parameter and a
// method id and returns a services::Status. Presumably validates the input objects.
314 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
/**
 * interface1 Result class: results obtained with the compute() method of the SGD
 * algorithm in the batch processing mode. Derives from the interface1 Result class
 * of the iterative solver.
 * NOTE(review): braces, access specifiers and other lines of the original header
 * are missing from this listing.
 */
324 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::interface1::Result
// Library macro applied to this class; judging by its name it declares serialization/cast support (expansion not shown).
327 DECLARE_SERIALIZABLE_CAST(Result);
// Shorthand for the base class.
328 typedef optimization_solver::iterative_solver::interface1::Result super;
// Member template parameterized by algorithmFPType; takes the algorithm input, parameter and
// method id and returns a services::Status. Presumably allocates storage for the results.
342 template <
typename algorithmFPType>
343 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
// Const accessor taking an OptionalDataId and returning a NumericTablePtr;
// presumably the table stored under that identifier (definition not shown).
350 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Counterpart of get(): takes an OptionalDataId and a NumericTablePtr; presumably stores the table under `id`.
357 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const virtual member marked DAAL_C11_OVERRIDE; takes the algorithm input, parameter and
// method id and returns a services::Status. Presumably validates the result object.
367 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
368 int method)
const DAAL_C11_OVERRIDE;
// Shared-pointer alias for the interface1 Result class.
371 typedef services::SharedPtr<Result> ResultPtr;
/**
 * interface2 version of BaseParameter: base class for the parameters of the
 * stochastic gradient descent (SGD) algorithm. Derives from the iterative solver's
 * Parameter class and takes the objective function as a sum_of_functions::BatchPtr
 * (no interface1 qualifier).
 * NOTE(review): this listing omits lines of the original header (braces, the
 * constructor's declaring line, some members); only what is visible is annotated.
 */
391 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::Parameter
// Argument list with defaults; presumably the constructor's (its declaring line is not visible here).
// `function` is the only argument shown without a default value.
405 const sum_of_functions::BatchPtr &
function,
406 size_t nIterations = 100,
407 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
408 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// learningRateSequence defaults to a new HomogenNumericTable<double> built with
// (1, 1, doAllocate, 1.0); presumably a 1 x 1 table holding 1.0; confirm against the table constructor.
409 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
410 new data_management::HomogenNumericTable<double>(
411 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
412 size_t batchSize = 1,
// Virtual destructor with an empty body.
415 virtual ~BaseParameter() {}
// Const virtual member returning a services::Status; presumably validates the parameter values (definition not shown).
422 virtual services::Status check()
const;
// Public data members; same names as the corresponding arguments above.
424 data_management::NumericTablePtr batchIndices;
427 data_management::NumericTablePtr learningRateSequence;
// Engine handle; presumably the random number engine used by the solver (TODO confirm).
430 engines::EnginePtr engine;
// Primary template of the interface2 Parameter, selected by a value of the Method enumeration.
// It adds nothing to BaseParameter; explicit specializations declared in this header
// provide the method-specific parameter sets.
435 template<Method method>
436 struct Parameter :
public BaseParameter {};
/**
 * interface2 Parameter specialization for the defaultDense method; derives from BaseParameter.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line and the end
 * of its argument list).
 */
446 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
458 const sum_of_functions::BatchPtr &
function,
459 size_t nIterations = 100,
460 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
461 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
462 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
463 new data_management::HomogenNumericTable<double>(
464 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
472 virtual services::Status check()
const;
// Virtual destructor with an empty body.
474 virtual ~Parameter() {}
/**
 * interface2 Parameter specialization for the miniBatch method; derives from BaseParameter
 * and adds the conservativeSequence and innerNIterations members.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line and the end
 * of its argument list).
 */
486 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
504 const sum_of_functions::BatchPtr &
function,
505 size_t nIterations = 100,
506 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
507 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
// The batch size default here is 128 (the BaseParameter argument list uses 1).
508 size_t batchSize = 128,
// conservativeSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
509 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
510 new data_management::HomogenNumericTable<double>(
511 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
512 size_t innerNIterations = 5,
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
513 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
514 new data_management::HomogenNumericTable<double>(
515 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
523 virtual services::Status check()
const;
// Virtual destructor with an empty body.
525 virtual ~Parameter() {}
// Members specific to this specialization; same names as the corresponding arguments above.
527 data_management::NumericTablePtr conservativeSequence;
528 size_t innerNIterations;
/**
 * interface2 Parameter specialization for the momentum method; derives from BaseParameter.
 * NOTE(review): lines of the original header are missing from this listing
 * (the template<> introducer, braces, the constructor's declaring line, the end
 * of its argument list, and the data members of this specialization).
 */
541 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
// Argument list with defaults; presumably the constructor's (declaring line not visible here).
// Unlike the other specializations, `momentum` comes second, before nIterations.
558 const sum_of_functions::BatchPtr&
function,
559 double momentum = 0.9,
560 size_t nIterations = 100,
561 double accuracyThreshold = 1.0e-05,
// batchIndices defaults to a default-constructed (empty) NumericTablePtr.
562 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
563 size_t batchSize = 128,
// learningRateSequence defaults to a new HomogenNumericTable<double> built with (1, 1, doAllocate, 1.0).
564 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
565 new data_management::HomogenNumericTable<double>(
566 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
// Const virtual member returning a services::Status; presumably validates the parameters (definition not shown).
574 virtual services::Status check()
const;
// Virtual destructor with an empty body.
576 virtual ~Parameter() {}
/**
 * interface2 Input class of the SGD algorithm; derives from the Input class of the
 * iterative solver.
 * NOTE(review): braces, access specifiers and other lines of the original header
 * are missing from this listing.
 */
590 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::Input
// Shorthand for the base class.
593 typedef optimization_solver::iterative_solver::Input super;
// Copy constructor.
595 Input(
const Input& other);
// Const accessor taking an OptionalDataId and returning a NumericTablePtr;
// presumably the table stored under that identifier (definition not shown).
604 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Counterpart of get(): takes an OptionalDataId and a NumericTablePtr; presumably stores the table under `id`.
611 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const virtual member marked DAAL_C11_OVERRIDE; takes the algorithm parameter and a
// method id and returns a services::Status. Presumably validates the input objects.
620 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
/**
 * interface2 Result class: results obtained with the compute() method of the SGD
 * algorithm in the batch processing mode. Derives from the Result class of the
 * iterative solver.
 * NOTE(review): braces, access specifiers and other lines of the original header
 * are missing from this listing.
 */
630 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::Result
// Library macro applied to this class; judging by its name it declares serialization/cast support (expansion not shown).
633 DECLARE_SERIALIZABLE_CAST(Result);
// Shorthand for the base class.
634 typedef optimization_solver::iterative_solver::Result super;
// Member template parameterized by algorithmFPType; takes the algorithm input, parameter and
// method id and returns a services::Status. Presumably allocates storage for the results.
648 template <
typename algorithmFPType>
649 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
// Const accessor taking an OptionalDataId and returning a NumericTablePtr;
// presumably the table stored under that identifier (definition not shown).
656 data_management::NumericTablePtr
get(OptionalDataId id)
const;
// Counterpart of get(): takes an OptionalDataId and a NumericTablePtr; presumably stores the table under `id`.
663 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
// Const virtual member marked DAAL_C11_OVERRIDE; takes the algorithm input, parameter and
// method id and returns a services::Status. Presumably validates the result object.
673 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
674 int method)
const DAAL_C11_OVERRIDE;
// Shared-pointer alias for the interface2 Result class.
677 typedef services::SharedPtr<Result> ResultPtr;
// Bring the interface2 declarations into the enclosing namespace, so the unqualified
// names BaseParameter, Parameter, Input, Result and ResultPtr refer to the interface2 versions.
682 using interface2::BaseParameter;
683 using interface2::Parameter;
684 using interface2::Input;
685 using interface2::Result;
686 using interface2::ResultPtr;
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:430
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:124
daal::algorithms::optimization_solver::sgd::interface1::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:221
daal
Definition: algorithm_base_common.h:31
daal::algorithms::optimization_solver::iterative_solver::interface1::Result
Results obtained with the compute() method of the iterative solver algorithm in the batch processing mode.
Definition: iterative_solver_types.h:221
daal::algorithms::optimization_solver::sgd::interface1::Parameter
Definition: sgd_types.h:130
daal::algorithms::optimization_solver::sgd::miniBatch
Definition: sgd_types.h:57
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::seed
size_t seed
Definition: sgd_types.h:428
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::seed
size_t seed
Definition: sgd_types.h:122
daal::algorithms::optimization_solver::iterative_solver::interface1::Parameter
Parameter base class for the iterative solver algorithm
Definition: iterative_solver_types.h:113
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:97
daal::algorithms::optimization_solver::sgd::interface2::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:527
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter
BaseParameter base class for the Stochastic gradient descent algorithm
Definition: sgd_types.h:391
daal::algorithms::optimization_solver::sgd::OptionalDataId
OptionalDataId
Definition: sgd_types.h:65
daal::algorithms::optimization_solver::sgd::interface1::Result
Results obtained with the compute() method of the SGD algorithm in the batch processing mode.
Definition: sgd_types.h:324
daal::algorithms::optimization_solver::sgd::defaultDense
Definition: sgd_types.h:56
daal::algorithms::optimization_solver::sgd::interface2::Parameter
Definition: sgd_types.h:436
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:424
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:118
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:427
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:121
daal::algorithms::optimization_solver::sgd::interface1::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:272
daal::algorithms::optimization_solver::sgd::interface2::Result
Results obtained with the compute() method of the SGD algorithm in the batch processing mode.
Definition: sgd_types.h:630
daal::algorithms::optimization_solver::sgd::momentum
Definition: sgd_types.h:58
daal::algorithms::optimization_solver::sgd::pastUpdateVector
Definition: sgd_types.h:67
daal::algorithms::optimization_solver::sgd::Method
Method
Definition: sgd_types.h:54
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter
BaseParameter base class for the Stochastic gradient descent algorithm
Definition: sgd_types.h:85
daal::algorithms::optimization_solver::sgd::pastWorkValue
Definition: sgd_types.h:68
daal::algorithms::optimization_solver::sgd::interface2::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:578