from daal.algorithms.multinomial_naive_bayes import prediction, training
from daal.algorithms import classifier
from daal.data_management import FileDataSource, DataSourceIface

# Put the parent of this file's directory on the import path (once) so that the
# `utils` helper module imported below can be resolved.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, createSparseTable
# Root directory of the example data sets (a relative path).
DAAL_PREFIX = os.path.join('..', 'data')

# Training data files, indexed by block number. Every entry currently points
# at the same CSR file.
trainDatasetFileNames = [
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
]

# Label files matching trainDatasetFileNames entry for entry.
trainGroundTruthFileNames = [
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
]

# Test data set and the labels it is compared against.
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_csr.csv')
testGroundTruthFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_labels.csv')

# Row counts passed to loadDataBlock() when reading the training and test label files.
nTrainVectorsInBlock = 8000
nTestObservations = 2000
# Result of the prediction step; stays None until that step assigns it.
predictionResult = None
# One slot per training block; each placeholder is overwritten with the block's
# table while training runs.
trainData = [0] * nBlocks
# Training step (fragment): feeds each training block to an online multinomial
# naive Bayes trainer and stores the finalized result in trainingResult.
# NOTE(review): the enclosing `def` header and several intermediate lines are not
# present in this listing (the embedded listing numbers jump, e.g. 98 -> 101 and
# 105 -> 111) -- confirm the full body against the original file.
# The step assigns these module-level names.
88 global trainData, trainingResult
# Online trainer for nClasses classes using the fastCSR method.
91 algorithm = training.Online(nClasses, method=training.fastCSR)
# Process the nBlocks training blocks one at a time.
93 for i
in range(nBlocks):
# Create the table for block i from its data file with createSparseTable.
95 trainData[i] = createSparseTable(trainDatasetFileNames[i])
# Data source for the labels of block i.
96 trainLabelsSource = FileDataSource(
97 trainGroundTruthFileNames[i], DataSourceIface.doAllocateNumericTable,
98 DataSourceIface.doDictionaryFromContext
# NOTE(review): the closing parenthesis of the FileDataSource(...) call is absent here.
# Read a block of nTrainVectorsInBlock rows from the label source.
101 trainLabelsSource.loadDataBlock(nTrainVectorsInBlock)
# Hand the block's data table and label table to the trainer.
104 algorithm.input.set(classifier.training.data, trainData[i])
105 algorithm.input.set(classifier.training.labels, trainLabelsSource.getNumericTable())
# NOTE(review): no per-block compute call is visible between the inputs above and
# the finalize call below -- presumably on one of the missing lines; TODO confirm.
# Finalize the online computation and keep the result for the prediction step.
111 trainingResult = algorithm.finalizeCompute()
# Prediction step (fragment): runs batch prediction on the test data with the
# model held in trainingResult and stores the outcome in predictionResult.
# NOTE(review): the enclosing `def` header is not present in this listing --
# confirm against the original file.
# The step assigns these module-level names.
115 global predictionResult, testData
# Create the test table from the test data file with createSparseTable.
118 testData = createSparseTable(testDatasetFileName)
# Batch predictor for nClasses classes using the fastCSR method.
121 algorithm = prediction.Batch(nClasses, method=prediction.fastCSR)
# Inputs: the test table and the model taken from the training result.
124 algorithm.input.setTable(classifier.prediction.data, testData)
125 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Run prediction and keep the result for later reporting.
128 predictionResult = algorithm.compute()
# Reporting step (fragment): loads the test labels and passes them, together with
# the predicted labels, to a printing call.
# NOTE(review): the enclosing `def` header, the closing parenthesis of
# FileDataSource(...), and the line opening the call that receives the trailing
# arguments are not present in this listing -- confirm against the original file.
# Data source for the expected test labels.
133 testGroundTruth = FileDataSource(
134 testGroundTruthFileName, DataSourceIface.doAllocateNumericTable,
135 DataSourceIface.doDictionaryFromContext
# Read a block of nTestObservations rows from the label source.
137 testGroundTruth.loadDataBlock(nTestObservations)
# Arguments of the printing call: expected labels, predicted labels, two column
# titles, a heading, the values 20 and 15, and flt64=False.
# Presumably the callee is printNumericTables (imported from utils) -- TODO confirm.
140 testGroundTruth.getNumericTable(),
141 predictionResult.get(classifier.prediction.prediction),
142 "Ground truth",
"Classification results",
143 "NaiveBayes classification results (first 20 observations):", 20, 15, flt64=
False
146 if __name__ ==
"__main__":