31 import daal.algorithms.optimization_solver
as optimization_solver
32 import daal.algorithms.optimization_solver.mse
33 import daal.algorithms.optimization_solver.sgd
34 import daal.algorithms.optimization_solver.iterative_solver
36 from daal.data_management
import (
37 DataSourceIface, FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface
# Make the shared example helpers importable (the `utils` module that provides
# printNumericTable). The helpers live in the parent of the directory that
# contains this script, so that folder is prepended to sys.path once.
#
# __file__ is made absolute *before* stripping path components: where __file__
# can be relative (e.g. a bare "script.py" when the script is run from its own
# directory), dirname(dirname("script.py")) is "" and abspath("") would give
# the script's own directory rather than its parent.
utils_folder = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
43 from utils
import printNumericTable
# ---------------------------------------------------------------------------
# Example inputs
# ---------------------------------------------------------------------------

# CSV file with the training data. The path is relative, so it is resolved
# against the current working directory at run time.
datasetFileName = os.path.join('..', 'data', 'batch', 'mse.csv')

# Assigned to the SGD solver's `parameter.accuracyThreshold` in the main section.
accuracyThreshold = 1e-7

# Starting point for the solver: a (4, 1) float64 column vector, wrapped in a
# HomogenNumericTable and passed as the solver's input argument.
initialPoint = np.array([[8.0], [2.0], [1.0], [4.0]], dtype=np.float64)

# NOTE(review): nFeatures, nIterations and batchSize are used by the main
# section but are not defined anywhere in the reviewed text — confirm they exist.
# Script entry point: load the MSE data set, minimize the mean-squared-error
# objective with DAAL's SGD "momentum" method starting from initialPoint, and
# print the minimum found plus the number of iterations performed.
# NOTE(review): nFeatures, nIterations and batchSize are referenced below but
# are not defined in the reviewed text — confirm they are defined at module level.
if __name__ == "__main__":

    # Open the CSV input file as a DAAL file data source.
    # NOTE(review): judging by the flag names, the source presumably does not
    # allocate its own numeric table and builds its data dictionary from the
    # file contents — confirm against the DAAL FileDataSource documentation.
    dataSource = FileDataSource(datasetFileName,
                                DataSourceIface.notAllocateNumericTable,
                                DataSourceIface.doDictionaryFromContext)

    # Unallocated tables for the feature columns and the single dependent-variable
    # column (arguments presumably (nColumns, nRows, allocation flag) — confirm).
    # Merging them lets one loadDataBlock call fill both; presumably the first
    # nFeatures CSV columns go to `data` and the last one to `dependentVariables`.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    dependentVariables = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(data, dependentVariables)

    # Read the data from the file into the merged table.
    dataSource.loadDataBlock(mergedData)

    # Number of observations actually loaded.
    nVectors = data.getNumberOfRows()

    # Mean-squared-error objective over nVectors observations, bound to the
    # loaded features and dependent variables.
    mseObjectiveFunction = optimization_solver.mse.Batch(nVectors)
    mseObjectiveFunction.input.set(optimization_solver.mse.data, data)
    mseObjectiveFunction.input.set(optimization_solver.mse.dependentVariables, dependentVariables)

    # Stochastic gradient descent solver using the momentum method on that objective.
    sgdMomentumAlgorithm = optimization_solver.sgd.Batch(mseObjectiveFunction, method=optimization_solver.sgd.momentum)

    # Starting point of the optimization.
    sgdMomentumAlgorithm.input.setInput(optimization_solver.iterative_solver.inputArgument,
                                        HomogenNumericTable(initialPoint))
    # Learning-rate sequence held in a 1x1 table allocated with NumericTableIface.doAllocate.
    # NOTE(review): the remaining argument(s) and the closing ")" of this call are
    # missing from the reviewed text. As shown, the statement is syntactically
    # incomplete and swallows the following lines — restore the missing line.
    sgdMomentumAlgorithm.parameter.learningRateSequence = HomogenNumericTable(1, 1, NumericTableIface.doAllocate,
    # Solver settings: iteration cap, mini-batch size and stopping tolerance.
    sgdMomentumAlgorithm.parameter.nIterations = nIterations
    sgdMomentumAlgorithm.parameter.batchSize = batchSize
    sgdMomentumAlgorithm.parameter.accuracyThreshold = accuracyThreshold

    # Run the optimization.
    res = sgdMomentumAlgorithm.compute()

    # Report the minimum found and the number of iterations performed.
    printNumericTable(res.getResult(optimization_solver.iterative_solver.minimum), "Minimum")
    printNumericTable(res.getResult(optimization_solver.iterative_solver.nIterations), "Number of iterations performed:")