#include "daal.h"
#include "service.h"
using namespace std;
using namespace daal;
using namespace daal::data_management;
using namespace daal::algorithms::implicit_als;
/* Floating-point type passed as the algorithm template argument */
typedef float algorithmFPType;

/* Value assigned to the nFactors parameter of the ALS algorithms below */
const size_t nFactors = 2;

/* Path to the training ratings file; main() hands its address to
 * checkArguments(), presumably so the command line can override it */
string trainDatasetFileName = "../data/batch/implicit_als_csr.csv";

/* State passed from one pipeline stage to the next */
NumericTablePtr dataTable;          /* written by initializeModel(), read by trainModel() */
ModelPtr initialModel;              /* written by initializeModel(), read by trainModel() */
training::ResultPtr trainingResult; /* written by trainModel(), read by testModel() */

/* Pipeline stages, defined further down in this file */
void initializeModel();
void trainModel();
void testModel();
int main(int argc, char *argv[])
{
    /* Presumably lets argv override the training file path (helper from service.h) */
    checkArguments(argc, argv, 1, &trainDatasetFileName);

    /* Run the example's stages strictly in order: init -> train -> predict */
    typedef void (*Stage)();
    const Stage stages[] = { initializeModel, trainModel, testModel };
    const size_t stageCount = sizeof(stages) / sizeof(stages[0]);
    for (size_t stageIndex = 0; stageIndex < stageCount; ++stageIndex)
    {
        stages[stageIndex]();
    }

    return 0;
}
void initializeModel()
{
dataTable = NumericTablePtr(createSparseTable<float>(trainDatasetFileName));
training::init::Batch<algorithmFPType, training::init::fastCSR> initAlgorithm;
initAlgorithm.parameter.nFactors = nFactors;
initAlgorithm.input.set(training::init::data, dataTable);
initAlgorithm.compute();
initialModel = initAlgorithm.getResult()->get(training::init::model);
}
void trainModel()
{
training::Batch<algorithmFPType, training::fastCSR> algorithm;
algorithm.input.set(training::data, dataTable);
algorithm.input.set(training::inputModel, initialModel);
algorithm.parameter.nFactors = nFactors;
algorithm.compute();
trainingResult = algorithm.getResult();
}
void testModel()
{
prediction::ratings::Batch<> algorithm;
algorithm.parameter.nFactors = nFactors;
algorithm.input.set(prediction::ratings::model, trainingResult->get(training::model));
algorithm.compute();
NumericTablePtr predictedRatings = algorithm.getResult()->get(prediction::ratings::prediction);
printNumericTable(predictedRatings, "Predicted ratings:");
}