48 from daal.algorithms.decision_tree.classification
import prediction, training
49 from daal.algorithms
import classifier
50 from daal.data_management
import (
51 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Put the parent of this script's directory on the module search path so that
# modules located there can be imported by the statements that follow.
# __file__ is made absolute *before* the two dirname() calls: a relative
# __file__ such as "script.py" would otherwise reduce to "" and resolve to the
# current working directory rather than to the parent folder.
utils_folder = os.path.realpath(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
if utils_folder not in sys.path:
    # Inserted at the front so this folder takes precedence during lookup.
    sys.path.insert(0, utils_folder)
56 from utils
import printNumericTables
# Location of the example data sets, expressed relative to the current
# working directory (the path is not anchored to this file's location).
DAAL_PREFIX = os.path.join('..', 'data')

# CSV inputs, all taken from the 'batch' sub-folder of DAAL_PREFIX.
trainDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_train.csv')
pruneDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_prune.csv')
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_test.csv')
70 predictionResult =
None 71 testGroundTruth =
None 78 trainDataSource = FileDataSource(
80 DataSourceIface.notAllocateNumericTable,
81 DataSourceIface.doDictionaryFromContext
85 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
86 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
87 mergedData = MergedNumericTable(trainData, trainGroundTruth)
90 trainDataSource.loadDataBlock(mergedData)
93 pruneDataSource = FileDataSource(
95 DataSourceIface.notAllocateNumericTable,
96 DataSourceIface.doDictionaryFromContext
100 pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
101 pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
102 pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
105 pruneDataSource.loadDataBlock(pruneMergedData)
108 algorithm = training.Batch(nClasses)
111 algorithm.input.set(classifier.training.data, trainData)
112 algorithm.input.set(classifier.training.labels, trainGroundTruth)
113 algorithm.input.setTable(training.dataForPruning, pruneData)
114 algorithm.input.setTable(training.labelsForPruning, pruneGroundTruth)
117 trainingResult = algorithm.compute()
118 model = trainingResult.get(classifier.training.model)
121 global testGroundTruth, predictionResult
124 testDataSource = FileDataSource(
126 DataSourceIface.notAllocateNumericTable,
127 DataSourceIface.doDictionaryFromContext
131 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
132 testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
133 mergedData = MergedNumericTable(testData, testGroundTruth)
136 testDataSource.loadDataBlock(mergedData)
139 algorithm = prediction.Batch()
143 algorithm.input.setTable(classifier.prediction.data, testData)
144 algorithm.input.setModel(classifier.prediction.model, model)
148 predictionResult = algorithm.compute()
155 predictionResult.get(classifier.prediction.prediction),
156 "Ground truth",
"Classification results",
157 "Decision tree classification results (first 20 observations):",
161 if __name__ ==
"__main__":