# DAAL building blocks used below: the decision-tree classification
# training/prediction stages, the generic classifier input/result ids, and
# the data-management classes used to read CSV files into numeric tables.
from daal.algorithms.decision_tree.classification import prediction, training
from daal.algorithms import classifier
from daal.data_management import (
    FileDataSource,
    DataSourceIface,
    NumericTableIface,
    HomogenNumericTable,
    MergedNumericTable,
)
# Put the parent of this script's directory on sys.path (once) so that the
# `utils` module imported right below can be resolved.
# NOTE(review): relies on `os` and `sys` being imported earlier in the file;
# those imports are not visible in this excerpt.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables
# Root of the sample data; a relative path, so it is resolved against the
# working directory when the files are opened.
DAAL_PREFIX = os.path.join('..', 'data')

# CSV inputs for the three stages: training, pruning and testing.
trainDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_train.csv')
pruneDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_prune.csv')
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_test.csv')
# Module-level placeholders; the prediction stage declares both names
# `global` and rebinds them to the prediction result and the test labels.
predictionResult = None
testGroundTruth = None
def trainModel():
    """Train the decision-tree classifier and store it in the global ``model``.

    Reads the training and pruning CSV files, loads each one through a merged
    table built from a feature table and a one-column label table, feeds the
    four tables to the batch training algorithm and keeps the resulting model.

    NOTE(review): the ``def`` header, the ``global model`` declaration and the
    file-name argument / closing parenthesis of both ``FileDataSource`` calls
    are absent from this excerpt and were reconstructed. The function name in
    particular is a guess -- confirm against the callers. ``nFeatures`` and
    ``nClasses`` are expected to be defined at module level (not visible here).
    """
    # The prediction stage reads `model` as a module-level name.
    global model

    # --- training set ------------------------------------------------------
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Tables are created unallocated (0 rows); loading goes through the
    # merged view. Presumably the first nFeatures columns land in trainData
    # and the last column in trainGroundTruth -- TODO confirm.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)
    trainDataSource.loadDataBlock(mergedData)

    # --- pruning set (same layout as the training set) -----------------------
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
    pruneDataSource.loadDataBlock(pruneMergedData)

    # --- training ------------------------------------------------------------
    algorithm = training.Batch(nClasses)

    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.setTable(training.dataForPruning, pruneData)
    algorithm.input.setTable(training.labelsForPruning, pruneGroundTruth)

    trainingResult = algorithm.compute()
    model = trainingResult.get(classifier.training.model)
def testModel():
    """Run the trained model on the test set and publish the results.

    Reads the test CSV file into a feature table and a one-column label table,
    runs batch prediction with the module-level ``model`` and rebinds the
    globals ``testGroundTruth`` and ``predictionResult``.

    NOTE(review): the ``def`` header and the file-name argument / closing
    parenthesis of the ``FileDataSource`` call are absent from this excerpt
    and were reconstructed. The function name is a guess -- confirm against
    the callers. ``nFeatures`` and ``model`` are expected to exist at module
    level (not visible here).
    """
    global testGroundTruth, predictionResult

    testDataSource = FileDataSource(
        testDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Unallocated feature/label tables, loaded together through a merged view.
    testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(testData, testGroundTruth)
    testDataSource.loadDataBlock(mergedData)

    algorithm = prediction.Batch()

    print("Number of columns: {}".format(testData.getNumberOfColumns()))
    algorithm.input.setTable(classifier.prediction.data, testData)
    algorithm.input.setModel(classifier.prediction.model, model)

    predictionResult = algorithm.compute()
def printResults():
    """Print the test labels next to the predicted labels.

    NOTE(review): only the middle of this call survives in the excerpt (the
    prediction table, the two column titles and the message). The ``def``
    header, the call head, the first argument and the row limit were
    reconstructed from the "Ground truth" title and the "first 20
    observations" message. Any further arguments of the original call are not
    visible -- confirm against ``printNumericTables``' signature and the
    callers.
    """
    printNumericTables(
        testGroundTruth,
        predictionResult.get(classifier.prediction.prediction),
        "Ground truth",
        "Classification results",
        "Decision tree classification results (first 20 observations):",
        20,
    )
160 if __name__ ==
"__main__":