from daal.algorithms.decision_tree.regression import prediction, training
from daal.data_management import (
    FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
)

# Make the shared example helpers importable: put the parent of this script's
# directory (resolved to an absolute, symlink-free path) at the front of
# sys.path, once, so `utils` is found regardless of the working directory.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables
# Root of the example data sets. This is a relative path, so it is resolved
# against the current working directory, not this script's location.
DAAL_PREFIX = os.path.join('..', 'data')

# Input data sets: training, pruning (passed to the training algorithm as its
# pruning inputs), and test.
trainDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_train.csv')
pruneDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_prune.csv')
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_test.csv')

# NOTE(review): the data-loading code reads a module-level `nFeatures` (number of
# feature columns per CSV row), which is not defined anywhere in this file as
# shown. Confirm the value and define it here.

# Placeholders rebound by the prediction step and read when printing results.
predictionResult = None
testGroundTruth = None
# --- Training -----------------------------------------------------------------
# Load the training set. Both tables are created with 0 rows and no storage,
# and loadDataBlock fills them through the merged table. This assumes each CSV
# row holds nFeatures feature columns followed by one response column.
# NOTE(review): `nFeatures` is not defined anywhere in this file as shown;
# confirm it is set before this point.
trainDataSource = FileDataSource(
    trainDatasetFileName,
    DataSourceIface.notAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext
)
trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
mergedData = MergedNumericTable(trainData, trainGroundTruth)
trainDataSource.loadDataBlock(mergedData)

# Load the pruning set the same way; it is handed to the training algorithm
# as its pruning inputs below.
pruneDataSource = FileDataSource(
    pruneDatasetFileName,
    DataSourceIface.notAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext
)
pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
pruneDataSource.loadDataBlock(pruneMergedData)

# Configure batch training on the training set plus the pruning set.
algorithm = training.Batch()
algorithm.input.set(training.data, trainData)
algorithm.input.set(training.dependentVariables, trainGroundTruth)
algorithm.input.set(training.dataForPruning, pruneData)
algorithm.input.set(training.dependentVariablesForPruning, pruneGroundTruth)

# Run training and keep the resulting model for the prediction step.
trainingResult = algorithm.compute()
model = trainingResult.get(training.model)
# --- Prediction ---------------------------------------------------------------
# These assignments run at module level, so they rebind the module-level
# testGroundTruth / predictionResult directly; a `global` statement here
# would be a no-op.

# Load the test set into empty (0-row, unallocated) feature and ground-truth
# tables that loadDataBlock fills through the merged table.
# NOTE(review): `nFeatures` must be defined before this point; it is not
# defined anywhere in this file as shown.
testDataSource = FileDataSource(
    testDatasetFileName,
    DataSourceIface.notAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext
)
testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
mergedData = MergedNumericTable(testData, testGroundTruth)
testDataSource.loadDataBlock(mergedData)

# Predict responses for the test features with the model produced by training.
algorithm = prediction.Batch()
algorithm.input.setTable(prediction.data, testData)
algorithm.input.setModel(prediction.model, model)
predictionResult = algorithm.compute()
# Print the test set's ground-truth responses next to the predicted responses.
# NOTE(review): this call is truncated. Whatever followed the message argument,
# and the closing parenthesis, are missing from this file. The message text
# suggests a printed-row count of 20 came next; confirm against the
# printNumericTables signature in utils before restoring it.
151 printNumericTables(testGroundTruth, predictionResult.get(prediction.prediction),
152 "Ground truth",
"Regression results",
153 "Decision tree regression results (first 20 observations):",
# Script entry point.
# NOTE(review): no body for this guard is visible. The steps it should trigger
# (training, prediction, printing) are not defined as callables anywhere in
# this file as shown; confirm what belongs here.
156 if __name__ ==
"__main__":