#include "daal.h"
#include "service.h"
using namespace std;
using namespace daal;
using namespace daal::algorithms::implicit_als;
/* Default path of the training data set (CSV). main() passes this variable's
 * address to checkArguments() from service.h, presumably so it can be
 * overridden from the command line - see service.h to confirm. */
std::string trainDatasetFileName = "../data/batch/implicit_als_dense.csv";

/* Number of latent factors; the same value is handed to the initialization,
 * training and prediction algorithms so that their models agree. */
const size_t nFactors(2);

/* State passed between the pipeline stages defined below. */
NumericTablePtr dataTable;          /* input data, loaded by initializeModel() */
ModelPtr initialModel;              /* starting model, produced by initializeModel() */
training::ResultPtr trainingResult; /* output of trainModel(), consumed by testModel() */

/* Pipeline stages, called in this order from main(). */
void initializeModel();
void trainModel();
void testModel();
int main(int argc, char *argv[])
{
    /* Hand the data-set path to the service.h argument helper; the '1'
     * presumably denotes how many file-name arguments follow - see service.h. */
    checkArguments(argc, argv, 1, &trainDatasetFileName);

    /* Run the three stages in sequence: each one reads the globals the
     * previous stage wrote. */
    initializeModel(); /* load data, build the starting model */
    trainModel();      /* refine it with implicit ALS training */
    testModel();       /* predict and print ratings */

    return 0;
}
void initializeModel()
{
FileDataSource<CSVFeatureManager> dataSource(trainDatasetFileName, DataSource::doAllocateNumericTable,
DataSource::doDictionaryFromContext);
dataSource.loadDataBlock();
dataTable = dataSource.getNumericTable();
training::init::Batch<> initAlgorithm;
initAlgorithm.parameter.nFactors = nFactors;
initAlgorithm.input.set(training::init::data, dataTable);
initAlgorithm.compute();
initialModel = initAlgorithm.getResult()->get(training::init::model);
}
void trainModel()
{
training::Batch<> algorithm;
algorithm.input.set(training::data, dataTable);
algorithm.input.set(training::inputModel, initialModel);
algorithm.parameter.nFactors = nFactors;
algorithm.compute();
trainingResult = algorithm.getResult();
}
void testModel()
{
prediction::ratings::Batch<> algorithm;
algorithm.parameter.nFactors = nFactors;
algorithm.input.set(prediction::ratings::model, trainingResult->get(training::model));
algorithm.compute();
NumericTablePtr predictedRatings = algorithm.getResult()->get(prediction::ratings::prediction);
printNumericTable(predictedRatings, "Predicted ratings:");
}