48 from daal
import step1Local, step2Master
49 from daal.algorithms
import classifier
50 from daal.algorithms.multinomial_naive_bayes
import training, prediction
51 from daal.data_management
import FileDataSource, DataSourceIface
53 utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
54 if utils_folder
not in sys.path:
55 sys.path.insert(0, utils_folder)
56 from utils
import printNumericTables, createSparseTable
58 DAAL_PREFIX = os.path.join(
'..',
'data')
61 trainDatasetFileNames = [
62 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_csr.csv'),
63 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_csr.csv'),
64 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_csr.csv'),
65 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_csr.csv')
68 trainGroundTruthFileNames = [
69 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_labels.csv'),
70 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_labels.csv'),
71 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_labels.csv'),
72 os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_train_labels.csv')
75 testDatasetFileName = os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_test_csr.csv')
76 testGroundTruthFileName = os.path.join(DAAL_PREFIX,
'batch',
'naivebayes_test_labels.csv')
80 nTrainVectorsInBlock = 8000
81 nTestObservations = 2000
84 predictionResult =
None 85 trainData = [0] * nBlocks
90 global trainData, trainingResult
92 masterAlgorithm = training.Distributed(step2Master, nClasses, method=training.fastCSR)
94 for i
in range(nBlocks):
96 trainData[i] = createSparseTable(trainDatasetFileNames[i])
99 trainLabelsSource = FileDataSource(
100 trainGroundTruthFileNames[i], DataSourceIface.doAllocateNumericTable,
101 DataSourceIface.doDictionaryFromContext
105 trainLabelsSource.loadDataBlock(nTrainVectorsInBlock)
108 localAlgorithm = training.Distributed(step1Local, nClasses, method=training.fastCSR)
111 localAlgorithm.input.set(classifier.training.data, trainData[i])
112 localAlgorithm.input.set(classifier.training.labels, trainLabelsSource.getNumericTable())
116 masterAlgorithm.input.add(training.partialModels, localAlgorithm.compute())
119 masterAlgorithm.compute()
120 trainingResult = masterAlgorithm.finalizeCompute()
124 global predictionResult, testData
127 testData = createSparseTable(testDatasetFileName)
130 algorithm = prediction.Batch(nClasses, method=prediction.fastCSR)
133 algorithm.input.setTable(classifier.prediction.data, testData)
134 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
137 predictionResult = algorithm.compute()
142 testGroundTruth = FileDataSource(
143 testGroundTruthFileName, DataSourceIface.doAllocateNumericTable,
144 DataSourceIface.doDictionaryFromContext
146 testGroundTruth.loadDataBlock(nTestObservations)
149 testGroundTruth.getNumericTable(),
150 predictionResult.get(classifier.prediction.prediction),
151 "Ground truth",
"Classification results",
152 "NaiveBayes classification results (first 20 observations):", 20, 15, flt64=
False 155 if __name__ ==
"__main__":