29 from __future__
import print_function
31 from daal.algorithms
import regression
32 from daal.algorithms
import decision_tree
33 import daal.algorithms.decision_tree.regression
34 import daal.algorithms.decision_tree.regression.training
36 from daal.data_management
import FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Input data sets, given relative to the directory the example is run from.
# The training set is used to grow the tree; the pruning set is passed to the
# algorithm separately as its data-for-pruning input.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def trainModel():
    """Train a decision-tree regression model and return the training result.

    Reads the training and pruning data sets from the CSV files named by
    ``trainDatasetFileName`` and ``pruneDatasetFileName``, feeds them to the
    batch decision-tree regression training algorithm, and runs it.

    Returns:
        The object returned by ``algorithm.compute()``. The caller retrieves
        the trained model from it with
        ``.get(decision_tree.regression.training.model)``.
    """
    # NOTE(review): nFeatures is read from module scope, but its definition is
    # not visible in this part of the file -- confirm it is defined (and
    # matches the number of feature columns in the CSV files) before this
    # function is called.

    # Data source for the training CSV. The numeric table is not allocated by
    # the data source; the dictionary is built from the file context.
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Separate tables for the features and the dependent variable. They are
    # combined into one merged table, which is the load target below, while
    # the two component tables are what the algorithm consumes.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    # Retrieve the training data from the input file.
    trainDataSource.loadDataBlock(mergedData)

    # Same procedure for the pruning data set.
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)

    # Retrieve the pruning data from the input file.
    pruneDataSource.loadDataBlock(pruneMergedData)

    # Create the batch training algorithm with its default parameters.
    algorithm = decision_tree.regression.training.Batch()

    # Pass the training data set, the pruning data set and their dependent
    # variables to the algorithm.
    algorithm.input.set(decision_tree.regression.training.data, trainData)
    algorithm.input.set(decision_tree.regression.training.dependentVariables, trainGroundTruth)
    algorithm.input.set(decision_tree.regression.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.regression.training.dependentVariablesForPruning, pruneGroundTruth)

    # Run training and hand the result back to the caller.
    return algorithm.compute()
class PrintNodeVisitor(regression.TreeNodeVisitor):
    """Tree visitor that prints one line of text for every node it is given.

    An instance is passed to the model's ``traverseDF`` method, which invokes
    ``onLeafNode`` / ``onSplitNode`` for the nodes of the tree.
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node.

        Args:
            level: Level of the node in the tree, as reported by the traversal.
            response: Response value stored in the leaf.

        Returns:
            True, so that the traversal continues with the next node.
        """
        # NOTE(review): the per-level loop body was missing from the visible
        # source; a two-space indent per tree level is assumed -- confirm
        # against the intended output format.
        indent = "  " * level
        print(indent + "Level {}, leaf node. Response value = {:.4g}".format(level, response))
        # The visitor callbacks report back to the traversal whether to keep
        # going; returning nothing (None) would not signal "continue".
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split (internal) node.

        Args:
            level: Level of the node in the tree, as reported by the traversal.
            featureIndex: Index of the feature the node splits on.
            featureValue: Feature value used by the split.

        Returns:
            True, so that the traversal continues with the next node.
        """
        # NOTE(review): same assumed two-space-per-level indent as onLeafNode.
        indent = "  " * level
        print(
            indent
            + "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
                level, featureIndex, featureValue
            )
        )
        return True
def printModel(m):
    """Print the structure of a trained decision-tree regression model.

    Args:
        m: Model object exposing ``traverseDF(visitor)``; the script passes
            the model taken from the training result.
    """
    # The visitor does the printing; the model drives it over its nodes
    # (traverseDF is presumably a depth-first traversal -- confirm in the
    # DAAL documentation).
    visitor = PrintNodeVisitor()
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Script entry point: train the model, then print the resulting tree.
    trainingResult = trainModel()
    printModel(trainingResult.get(decision_tree.regression.training.model))