57 import daal.algorithms.optimization_solver
as optimization_solver
58 import daal.algorithms.optimization_solver.mse
59 import daal.algorithms.optimization_solver.sgd
60 import daal.algorithms.optimization_solver.iterative_solver
62 from daal.data_management
import (
63 DataSourceIface, FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface
# Make the shared example helpers importable. The directory two levels up from
# this file (the parent of the script's own directory) is prepended to
# sys.path, only if it is not already there, before importing from ``utils``.
# NOTE(review): presumably the ``utils`` package lives in that directory; confirm.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
# Must come after the sys.path tweak above.
from utils import printNumericTable
# Example parameters.
# NOTE(review): halfNIterations, nFeatures, batchSize and learningRate are used
# by this script but are not defined in the visible code; presumably they are
# set on lines missing from this view. Confirm.

# Relative path; it is resolved against the current working directory.
datasetFileName = os.path.join('..', 'data', 'batch', 'mse.csv')

# Accuracy stop threshold handed to the solver.
accuracyThreshold = 1e-7

# Total iteration budget. The main section runs the solver twice, each run with
# halfNIterations iterations.
nIterations = 2 * halfNIterations

# Initial guess for the solver: a 4x1 column of float64 values.
startPoint = np.array(
    [[8],
     [2],
     [1],
     [4]],
    dtype=np.float64,
)
if __name__ == "__main__":
    # Short alias for the iterative-solver namespace that supplies every
    # input/result key used below.
    solver = optimization_solver.iterative_solver

    def _report(result, minimum_label):
        """Print the minimum held by ``result`` under ``minimum_label``, then its iteration count."""
        printNumericTable(result.getResult(solver.minimum), minimum_label)
        printNumericTable(result.getResult(solver.nIterations),
                          "Number of iterations performed:")

    # Open the CSV file. The flags ask the data source not to allocate the
    # numeric table itself, and (per the flag name) to build its dictionary
    # from the file context.
    dataSource = FileDataSource(
        datasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Features (nFeatures columns) and the dependent variable (1 column) are
    # read in a single pass through a MergedNumericTable that wraps both.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    dependentVariables = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(data, dependentVariables)
    dataSource.loadDataBlock(mergedData)

    nVectors = data.getNumberOfRows()

    # The mse objective function over all loaded rows.
    mseObjectiveFunction = optimization_solver.mse.Batch(nVectors)
    mseObjectiveFunction.input.set(optimization_solver.mse.data, data)
    mseObjectiveFunction.input.set(optimization_solver.mse.dependentVariables, dependentVariables)

    # SGD solver using the momentum method on that objective.
    sgdAlgorithm = optimization_solver.sgd.Batch(
        mseObjectiveFunction, method=optimization_solver.sgd.momentum)

    # First run: start from startPoint and request the optional result so it
    # can be passed to the second run as optionalArgument.
    sgdAlgorithm.input.setInput(solver.inputArgument, HomogenNumericTable(startPoint))
    sgdAlgorithm.parameter.learningRate = HomogenNumericTable(
        1, 1, NumericTableIface.doAllocate, learningRate)
    sgdAlgorithm.parameter.nIterations = halfNIterations
    sgdAlgorithm.parameter.accuracyThreshold = accuracyThreshold
    sgdAlgorithm.parameter.batchSize = batchSize
    sgdAlgorithm.parameter.optionalResultRequired = True

    res = sgdAlgorithm.compute()
    _report(res, "Minimum after first compute():")

    # Second run: resume from the first run's minimum and optional result.
    sgdAlgorithm.input.setInput(solver.inputArgument, res.getResult(solver.minimum))
    sgdAlgorithm.input.setInput(solver.optionalArgument, res.getResult(solver.optionalResult))

    res = sgdAlgorithm.compute()
    _report(res, "Minimum after second compute():")