22 from daal.algorithms.decision_tree.classification
import prediction, training
23 from daal.algorithms
import classifier
24 from daal.data_management
import (
25 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Make the shared ``utils`` helpers importable: they live in the directory
# that contains this script's own directory.
# ``__file__`` is made absolute *before* taking parent directories. On older
# interpreters ``__main__.__file__`` can be a bare name such as "script.py".
# In that case ``dirname(dirname("script.py"))`` collapses to "", which
# resolves to the current directory instead of its parent.
utils_folder = os.path.realpath(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Prepend (once) so the local ``utils`` wins over any same-named module.
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
30 from utils
import printNumericTables
# Root of the example data tree. Note that this is relative to the current
# working directory, not to the script's location.
DAAL_PREFIX = os.path.join('..', 'data')

# CSV inputs for the three stages of the example: fitting the tree, pruning it,
# and evaluating the pruned model. All live under <DAAL_PREFIX>/batch.
trainDatasetFileName, pruneDatasetFileName, testDatasetFileName = (
    os.path.join(DAAL_PREFIX, 'batch', csv_name)
    for csv_name in (
        'decision_tree_train.csv',
        'decision_tree_prune.csv',
        'decision_tree_test.csv',
    )
)
# Module-level slots for the evaluation outputs. They stay None until the
# prediction step rebinds both of them through a ``global`` declaration.
predictionResult = testGroundTruth = None
52 trainDataSource = FileDataSource(
54 DataSourceIface.notAllocateNumericTable,
55 DataSourceIface.doDictionaryFromContext
59 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
60 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
61 mergedData = MergedNumericTable(trainData, trainGroundTruth)
64 trainDataSource.loadDataBlock(mergedData)
67 pruneDataSource = FileDataSource(
69 DataSourceIface.notAllocateNumericTable,
70 DataSourceIface.doDictionaryFromContext
74 pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
75 pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
76 pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
79 pruneDataSource.loadDataBlock(pruneMergedData)
82 algorithm = training.Batch(nClasses)
85 algorithm.input.set(classifier.training.data, trainData)
86 algorithm.input.set(classifier.training.labels, trainGroundTruth)
87 algorithm.input.setTable(training.dataForPruning, pruneData)
88 algorithm.input.setTable(training.labelsForPruning, pruneGroundTruth)
91 trainingResult = algorithm.compute()
92 model = trainingResult.get(classifier.training.model)
95 global testGroundTruth, predictionResult
98 testDataSource = FileDataSource(
100 DataSourceIface.notAllocateNumericTable,
101 DataSourceIface.doDictionaryFromContext
105 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
106 testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
107 mergedData = MergedNumericTable(testData, testGroundTruth)
110 testDataSource.loadDataBlock(mergedData)
113 algorithm = prediction.Batch()
117 algorithm.input.setTable(classifier.prediction.data, testData)
118 algorithm.input.setModel(classifier.prediction.model, model)
122 predictionResult = algorithm.compute()
129 predictionResult.get(classifier.prediction.prediction),
130 "Ground truth",
"Classification results",
131 "Decision tree classification results (first 20 observations):",
135 if __name__ ==
"__main__":