#include "daal.h"
#include "service.h"
using namespace std;
using namespace daal;
using namespace daal::algorithms;
using namespace daal::data_management;
/* Input files: the training dataset and a reference (ground-truth) solution
   whose objective value is printed next to the one found by SAGA. */
const std::string datasetFileName = "../data/batch/XM_100.csv";
const std::string groundTruthFileName = "../data/batch/saga_solution_100_features.csv";

/* Problem size and solver settings. */
const size_t nFeatures = 100;      // number of feature columns loaded from each dataset row
const size_t nIterations = 100000; // value passed to the SAGA solver's parameter().nIterations
const float tol = 1e-8;            // value passed to the SAGA solver's parameter().accuracyThreshold
int main(int argc, char *argv[])
{
checkArguments(argc, argv, 1, &datasetFileName);
FileDataSource<CSVFeatureManager> dataSource(datasetFileName,
DataSource::notAllocateNumericTable,
DataSource::doDictionaryFromContext);
NumericTablePtr data(new HomogenNumericTable<float>(nFeatures, 0, NumericTable::doNotAllocate));
NumericTablePtr dependentVariables(new HomogenNumericTable<float>(1, 0, NumericTable::doNotAllocate));
NumericTablePtr mergedData(new MergedNumericTable(data, dependentVariables));
dataSource.loadDataBlock(mergedData.get());
services::SharedPtr<optimization_solver::logistic_loss::Batch<float> > func(new optimization_solver::logistic_loss::Batch<float>(data->getNumberOfRows()));
func->input.set(optimization_solver::logistic_loss::data, data);
func->input.set(optimization_solver::logistic_loss::dependentVariables, dependentVariables);
const size_t nParameters = (nFeatures + 1);
float argument[nParameters];
for(int i = 0; i < nParameters; i++)
argument[i] = 0;
argument[0] = 0;
argument[1] = 0;
func->parameter().penaltyL1 = 0.06;
func->parameter().penaltyL2 = 0;
func->parameter().interceptFlag = false;
func->parameter().resultsToCompute = optimization_solver::objective_function::gradient;
daal::algorithms::optimization_solver::saga::Batch<float> sagaAlgorithm(func);
sagaAlgorithm.input.set(optimization_solver::iterative_solver::inputArgument,
HomogenNumericTable<float>::create(argument, 1, nParameters));
sagaAlgorithm.parameter().nIterations = nIterations;
sagaAlgorithm.parameter().accuracyThreshold = tol;
sagaAlgorithm.parameter().batchSize = 1;
sagaAlgorithm.compute();
NumericTablePtr munimum = sagaAlgorithm.getResult()->get(optimization_solver::iterative_solver::minimum);
printNumericTable(munimum, "Minimum:");
printNumericTable(sagaAlgorithm.getResult()->get(optimization_solver::iterative_solver::nIterations), "nIterations:");
services::SharedPtr<optimization_solver::logistic_loss::Batch<float> > func_check(new optimization_solver::logistic_loss::Batch<float>(data->getNumberOfRows()));
func_check->input.set(optimization_solver::logistic_loss::data, data);
func_check->input.set(optimization_solver::logistic_loss::dependentVariables, dependentVariables);
func_check->parameter().penaltyL1 = 0.06;
func_check->parameter().penaltyL2 = 0;
func_check->parameter().interceptFlag = false;
func_check->parameter().resultsToCompute = optimization_solver::objective_function::value;
FileDataSource<CSVFeatureManager> groundTruthDS(groundTruthFileName,
DataSource::notAllocateNumericTable,
DataSource::doDictionaryFromContext);
NumericTablePtr groundTruthNT(new HomogenNumericTable<float>(1, 0, NumericTable::doNotAllocate));
groundTruthDS.loadDataBlock(groundTruthNT.get());
func_check->input.set(optimization_solver::logistic_loss::argument, groundTruthNT);
func_check->compute();
printNumericTable(func_check->getResult()->get(optimization_solver::objective_function::valueIdx),"groundTruth:");
func_check->input.set(optimization_solver::logistic_loss::argument, munimum);
func_check->compute();
printNumericTable(func_check->getResult()->get(optimization_solver::objective_function::valueIdx),"value DAAL:");
return 0;
}