from daal.algorithms.decision_tree.regression import prediction, training
from daal.data_management import (
    FileDataSource,
    DataSourceIface,
    NumericTableIface,
    HomogenNumericTable,
    MergedNumericTable,
)

# Put the parent of this script's directory at the front of sys.path (once),
# so that the `utils` import just below can be resolved from that folder.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables
# Root folder of the example data sets, expressed as a relative path.
DAAL_PREFIX = os.path.join('..', 'data')

# Input data set parameters: CSV files used for training, pruning and testing.
trainDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_train.csv')
pruneDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_prune.csv')
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_test.csv')

# Module-level placeholders. Code further down declares both names `global`
# and rebinds them: `testGroundTruth` to the table holding the test responses
# and `predictionResult` to the value returned by the prediction `compute()`.
predictionResult = None
testGroundTruth = None
51 trainDataSource = FileDataSource(
53 DataSourceIface.notAllocateNumericTable,
54 DataSourceIface.doDictionaryFromContext
58 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
59 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
60 mergedData = MergedNumericTable(trainData, trainGroundTruth)
63 trainDataSource.loadDataBlock(mergedData)
66 pruneDataSource = FileDataSource(
68 DataSourceIface.notAllocateNumericTable,
69 DataSourceIface.doDictionaryFromContext
73 pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
74 pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
75 pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
78 pruneDataSource.loadDataBlock(pruneMergedData)
81 algorithm = training.Batch()
84 algorithm.input.set(training.data, trainData)
85 algorithm.input.set(training.dependentVariables, trainGroundTruth)
86 algorithm.input.set(training.dataForPruning, pruneData)
87 algorithm.input.set(training.dependentVariablesForPruning, pruneGroundTruth)
90 trainingResult = algorithm.compute()
91 model = trainingResult.get(training.model)
94 global testGroundTruth, predictionResult
97 testDataSource = FileDataSource(
99 DataSourceIface.notAllocateNumericTable,
100 DataSourceIface.doDictionaryFromContext
104 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
105 testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
106 mergedData = MergedNumericTable(testData, testGroundTruth)
109 testDataSource.loadDataBlock(mergedData)
112 algorithm = prediction.Batch()
116 algorithm.input.setTable(prediction.data, testData)
117 algorithm.input.setModel(prediction.model, model)
120 predictionResult = algorithm.compute()
125 printNumericTables(testGroundTruth, predictionResult.get(prediction.prediction),
126 "Ground truth",
"Regression results",
127 "Decision tree regression results (first 20 observations):",
130 if __name__ ==
"__main__":