22 import daal.algorithms.implicit_als.prediction.ratings
as ratings
23 import daal.algorithms.implicit_als.training
as training
24 import daal.algorithms.implicit_als.training.init
as init
# Make the directory one level above this script's directory importable
# (inserted at the front of sys.path, and only once) so the helper module
# `utils` can be imported from there.
utils_folder = os.path.dirname(os.path.dirname(__file__))
utils_folder = os.path.realpath(os.path.abspath(utils_folder))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)

from utils import createSparseTable, printNumericTable
# Root directory of the example data sets. This is a relative path, so it is
# resolved against the current working directory when the file is opened.
DAAL_PREFIX = os.path.join('..', 'data')

# Training data file read by initializeModel().
trainDatasetFileName = os.path.join(DAAL_PREFIX, os.path.join('batch', 'implicit_als_csr.csv'))
def initializeModel():
    """Read the training data and compute the initial implicit ALS model.

    Uses the ``init.fastCSR`` initialization method, configured with the
    module-level ``nFactors``. Returns nothing. Instead it rebinds two
    module globals:

    * ``dataTable``: the table returned by ``createSparseTable`` for
      ``trainDatasetFileName``.
    * ``initialModel``: the ``init.model`` entry of the initialization result.
    """
    global initialModel, dataTable

    dataTable = createSparseTable(trainDatasetFileName)

    # Configure the initializer fully before giving it the input table.
    initializer = init.Batch(method=init.fastCSR)
    initializer.parameter.nFactors = nFactors
    initializer.input.set(init.data, dataTable)

    initialModel = initializer.compute().get(init.model)
# Train the implicit ALS model with the fastCSR method. It starts from
# `dataTable` and `initialModel`, which initializeModel() assigns as globals.
# NOTE(review): no enclosing `def` header is visible for these statements.
# Whether `trainingResult` below is a local or a module global depends on that
# missing header (e.g. a `global trainingResult` line). Confirm before relying
# on it from the prediction step.
67 algorithm = training.Batch(method=training.fastCSR)
# Pass the training table and the initial model to the algorithm.
70 algorithm.input.setTable(training.data, dataTable)
71 algorithm.input.setModel(training.inputModel, initialModel)
# Same module-level nFactors value used by the initialization step.
73 algorithm.parameter.nFactors = nFactors
77 trainingResult = algorithm.compute()
# Predict ratings from the trained model and print them.
# NOTE(review): no enclosing `def` header is visible for these statements.
# `trainingResult` is read here but assigned in the training statements above.
# Confirm it is shared between them (e.g. as a module global).
83 algorithm = ratings.Batch()
# Same module-level nFactors value used for initialization and training.
84 algorithm.parameter.nFactors = nFactors
86 algorithm.input.set(ratings.model, trainingResult.get(training.model))
88 res = algorithm.compute()
90 predictedRatings = res.get(ratings.prediction)
92 printNumericTable(predictedRatings,
"Predicted ratings:")
94 if __name__ ==
"__main__":