48 #ifndef __SGD_TYPES_H__
49 #define __SGD_TYPES_H__
51 #include "algorithms/algorithm.h"
52 #include "data_management/data/numeric_table.h"
53 #include "data_management/data/homogen_numeric_table.h"
54 #include "services/daal_defines.h"
55 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
56 #include "algorithms/engines/mt19937/mt19937.h"
62 namespace optimization_solver
93 pastUpdateVector = iterative_solver::lastOptionalData + 1,
94 pastWorkValue = pastUpdateVector + 1 ,
95 lastOptionalData = pastWorkValue
111 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::Parameter
125 const sum_of_functions::BatchPtr &
function,
126 size_t nIterations = 100,
127 double accuracyThreshold = 1.0e-05,
128 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
129 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
130 new data_management::HomogenNumericTable<double>(
131 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
132 size_t batchSize = 1,
135 virtual ~BaseParameter() {}
142 virtual services::Status check()
const;
144 data_management::NumericTablePtr batchIndices;
147 data_management::NumericTablePtr learningRateSequence;
150 engines::EnginePtr engine;
/**
 * Primary template of the SGD parameter class, selected by the Method enumeration value.
 * The primary template adds nothing to BaseParameter (its body is empty); the
 * method-specific parameter sets are provided by the Parameter<defaultDense>,
 * Parameter<miniBatch> and Parameter<momentum> specializations.
 */
155 template<Method method>
156 struct Parameter :
public BaseParameter {};
166 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
178 const sum_of_functions::BatchPtr &
function,
179 size_t nIterations = 100,
180 double accuracyThreshold = 1.0e-05,
181 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
182 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
183 new data_management::HomogenNumericTable<double>(
184 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
192 virtual services::Status check()
const;
194 virtual ~Parameter() {}
206 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
224 const sum_of_functions::BatchPtr &
function,
225 size_t nIterations = 100,
226 double accuracyThreshold = 1.0e-05,
227 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
228 size_t batchSize = 128,
229 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
230 new data_management::HomogenNumericTable<double>(
231 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
232 size_t innerNIterations = 5,
233 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
234 new data_management::HomogenNumericTable<double>(
235 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
243 virtual services::Status check()
const;
245 virtual ~Parameter() {}
247 data_management::NumericTablePtr conservativeSequence;
248 size_t innerNIterations;
261 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
278 const sum_of_functions::BatchPtr&
function,
279 double momentum = 0.9,
280 size_t nIterations = 100,
281 double accuracyThreshold = 1.0e-05,
282 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
283 size_t batchSize = 128,
284 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
285 new data_management::HomogenNumericTable<double>(
286 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
294 virtual services::Status check()
const;
296 virtual ~Parameter() {}
310 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::Input
313 typedef optimization_solver::iterative_solver::Input super;
315 Input(
const Input& other);
324 data_management::NumericTablePtr
get(OptionalDataId id)
const;
331 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
340 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
350 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::Result
353 DECLARE_SERIALIZABLE_CAST(Result);
354 typedef optimization_solver::iterative_solver::Result super;
368 template <
typename algorithmFPType>
369 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
376 data_management::NumericTablePtr
get(OptionalDataId id)
const;
383 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
393 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
394 int method)
const DAAL_C11_OVERRIDE;
/** Shared (reference-counted) pointer to an SGD Result object */
397 typedef services::SharedPtr<Result> ResultPtr;
/*
 * Re-export the interface1 declarations into the enclosing scope so callers can refer to
 * them without spelling out interface1. The enclosing scope is presumably the sgd
 * namespace (the qualified names are daal::algorithms::optimization_solver::sgd::interface1::*);
 * confirm against the full header.
 */
402 using interface1::BaseParameter;
403 using interface1::Parameter;
404 using interface1::Input;
405 using interface1::Result;
406 using interface1::ResultPtr;
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:150
daal::algorithms::optimization_solver::sgd::interface1::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:247
daal
Definition: algorithm_base_common.h:57
daal::algorithms::optimization_solver::sgd::interface1::Parameter
Definition: sgd_types.h:156
daal::algorithms::optimization_solver::sgd::miniBatch
Definition: sgd_types.h:83
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::seed
size_t seed
Definition: sgd_types.h:148
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:123
daal::algorithms::optimization_solver::sgd::OptionalDataId
OptionalDataId
Definition: sgd_types.h:91
daal::algorithms::optimization_solver::sgd::interface1::Result
Results obtained with the compute() method of the sgd algorithm in the batch processing mode...
Definition: sgd_types.h:350
daal::algorithms::optimization_solver::sgd::defaultDense
Definition: sgd_types.h:82
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:144
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:147
daal::algorithms::optimization_solver::sgd::interface1::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:298
daal::algorithms::optimization_solver::sgd::momentum
Definition: sgd_types.h:84
daal::algorithms::optimization_solver::sgd::pastUpdateVector
Definition: sgd_types.h:93
daal::algorithms::optimization_solver::sgd::Method
Method
Definition: sgd_types.h:80
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter
BaseParameter: base class for the parameters of the Stochastic gradient descent algorithm
Definition: sgd_types.h:111
daal::algorithms::optimization_solver::sgd::pastWorkValue
Definition: sgd_types.h:94