# DAAL pieces used below: distributed-step identifiers, the generic classifier
# input/result ids, the multinomial Naive Bayes training/prediction modules,
# and the CSV file data source.
from daal import step1Local, step2Master
from daal.algorithms import classifier
from daal.algorithms.multinomial_naive_bayes import training, prediction
from daal.data_management import FileDataSource, DataSourceIface
# Resolve the parent of this script's directory and put it at the front of
# sys.path (only once) so that the `utils` import below can resolve against it.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, createSparseTable
# Relative path of the data directory; every input file name below is built
# on top of it.
DAAL_PREFIX = os.path.join('..', 'data')
# Training data files, indexed by block number in the training loop.
# All four entries currently refer to the same file.
trainDatasetFileNames = [
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'),
]
# Training label files, indexed by block number in the training loop
# (entry i is read together with trainDatasetFileNames[i]).
# All four entries currently refer to the same file.
trainGroundTruthFileNames = [
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
    os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'),
]
# Test data file (turned into a sparse table for prediction) and the matching
# ground-truth labels file (loaded when the results are printed).
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_csr.csv')
testGroundTruthFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_labels.csv')
# Row count passed to loadDataBlock() for each block's training labels.
nTrainVectorsInBlock = 8000
# Row count passed to loadDataBlock() for the test ground-truth labels.
nTestObservations = 2000
# Module-level state shared by the training / prediction / printing steps.
# predictionResult is filled in by the prediction step.
predictionResult = None
# One slot per training block; each slot is overwritten with that block's
# sparse table during training.
trainData = [0] * nBlocks
# NOTE(review): the statements below look like the body of the training
# routine. The enclosing `def` line, the original indentation (including which
# statements belong to the `for` loop) and the closing `)` of the
# FileDataSource(...) call are not visible in this view — TODO restore them
# from the original file.
64 global trainData, trainingResult
# Master-step (step2Master) algorithm; it collects the partial models added below.
66 masterAlgorithm = training.Distributed(step2Master, nClasses, method=training.fastCSR)
# Iterate over the block indices 0 .. nBlocks-1.
68 for i
in range(nBlocks):
# Build the i-th training table from its data file and keep it in trainData.
70 trainData[i] = createSparseTable(trainDatasetFileNames[i])
# Data source over the i-th labels file.
73 trainLabelsSource = FileDataSource(
74 trainGroundTruthFileNames[i], DataSourceIface.doAllocateNumericTable,
75 DataSourceIface.doDictionaryFromContext
# Presumably loads up to nTrainVectorsInBlock label rows — confirm against the DAAL docs.
79 trainLabelsSource.loadDataBlock(nTrainVectorsInBlock)
# Local-step (step1Local) algorithm, created anew for this block.
82 localAlgorithm = training.Distributed(step1Local, nClasses, method=training.fastCSR)
# Feed the block's data table and its labels table to the local algorithm.
85 localAlgorithm.input.set(classifier.training.data, trainData[i])
86 localAlgorithm.input.set(classifier.training.labels, trainLabelsSource.getNumericTable())
# Run the local step and hand its result to the master as a partial model.
90 masterAlgorithm.input.add(training.partialModels, localAlgorithm.compute())
# Run the master step, then store the finalized result in the global
# trainingResult (presumably after the loop has finished — indentation is not
# visible here).
93 masterAlgorithm.compute()
94 trainingResult = masterAlgorithm.finalizeCompute()
# NOTE(review): the statements below look like the body of the prediction
# routine; the enclosing `def` line and the original indentation are not
# visible in this view — TODO restore them from the original file.
98 global predictionResult, testData
# Build the test table from the test data file.
101 testData = createSparseTable(testDatasetFileName)
# Batch prediction algorithm using the fastCSR method.
104 algorithm = prediction.Batch(nClasses, method=prediction.fastCSR)
# Inputs: the test table and the model taken from the global trainingResult,
# so training must have run before this point.
107 algorithm.input.setTable(classifier.prediction.data, testData)
108 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Run prediction and store the result in the global predictionResult.
111 predictionResult = algorithm.compute()
# NOTE(review): the statements below look like the body of the result-printing
# routine; the enclosing `def` line, the closing `)` of FileDataSource(...),
# and the head and closing `)` of the final call are not visible in this view
# — TODO restore them from the original file.
# Data source over the test ground-truth labels file.
116 testGroundTruth = FileDataSource(
117 testGroundTruthFileName, DataSourceIface.doAllocateNumericTable,
118 DataSourceIface.doDictionaryFromContext
# Presumably loads up to nTestObservations label rows — confirm against the DAAL docs.
120 testGroundTruth.loadDataBlock(nTestObservations)
# The remaining lines are the arguments of a call whose opening line is
# missing; presumably printNumericTables(...) imported from utils — confirm.
# They pass the ground-truth table, the prediction table read from the global
# predictionResult, two column titles, a heading, and display options.
123 testGroundTruth.getNumericTable(),
124 predictionResult.get(classifier.prediction.prediction),
125 "Ground truth",
"Classification results",
126 "NaiveBayes classification results (first 20 observations):", 20, 15, flt64=
False
129 if __name__ ==
"__main__":