55 from __future__
import print_function
57 from daal.algorithms
import regression
58 from daal.algorithms
import decision_tree
59 import daal.algorithms.decision_tree.regression
60 import daal.algorithms.decision_tree.regression.training
62 from daal.data_management
import FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Input data set parameters: the CSV file used for training and the
# separate CSV file handed to the algorithm as pruning data.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"


def _loadDataset(fileName):
    """Load one CSV file into a (features, dependent variable) table pair.

    Two empty HomogenNumericTables are created (``nFeatures`` columns and
    one column) and filled together through a MergedNumericTable wrapping
    both; presumably the first ``nFeatures`` CSV columns become features
    and the next one the dependent variable.

    Args:
        fileName: path of the CSV file to read.

    Returns:
        A ``(data, groundTruth)`` tuple of HomogenNumericTable objects.
    """
    # NOTE(review): nFeatures is not defined in the visible part of this
    # file; it is assumed to be a module-level constant defined elsewhere.
    dataSource = FileDataSource(
        fileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Tables are created unallocated; loading through the merged view lets
    # a single loadDataBlock call fill both of them.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    groundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    merged = MergedNumericTable(data, groundTruth)

    dataSource.loadDataBlock(merged)
    return data, groundTruth


def trainModel(trainFileName=None, pruneFileName=None):
    """Train a decision-tree regression model with a separate pruning set.

    Args:
        trainFileName: CSV file with the training data; defaults to
            ``trainDatasetFileName`` (looked up at call time).
        pruneFileName: CSV file with the pruning data; defaults to
            ``pruneDatasetFileName`` (looked up at call time).

    Returns:
        The result of ``algorithm.compute()``; callers retrieve the model
        from it with ``.get(decision_tree.regression.training.model)``.
    """
    if trainFileName is None:
        trainFileName = trainDatasetFileName
    if pruneFileName is None:
        pruneFileName = pruneDatasetFileName

    trainData, trainGroundTruth = _loadDataset(trainFileName)
    pruneData, pruneGroundTruth = _loadDataset(pruneFileName)

    # Create the batch training algorithm and bind all four inputs: the
    # training set, its dependent variables, and the pruning counterparts.
    algorithm = decision_tree.regression.training.Batch()
    training = decision_tree.regression.training
    algorithm.input.set(training.data, trainData)
    algorithm.input.set(training.dependentVariables, trainGroundTruth)
    algorithm.input.set(training.dataForPruning, pruneData)
    algorithm.input.set(training.dependentVariablesForPruning, pruneGroundTruth)

    return algorithm.compute()
class PrintNodeVisitor(regression.TreeNodeVisitor):
    """Visitor that prints one line for each tree node it is called on.

    Every line is prefixed with two spaces per tree level so the printed
    output reflects the depth of each node.
    """

    def __init__(self):
        # Explicit two-argument super() keeps the Python 2 compatibility the
        # file opts into via ``from __future__ import print_function``.
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node found at depth *level* with its *response*.

        Returns:
            True, so that the traversal continues to the remaining nodes.
        """
        message = "Level {}, leaf node. Response value = {:.4g}".format(level, response)
        # Indent once, print the message once (also at level 0, the root).
        print("  " * level + message)
        # The bool result tells the traversal whether to keep going; True
        # continues (per the DAAL TreeNodeVisitor API — confirm for the
        # installed version).
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node at depth *level* with its feature and threshold.

        Returns:
            True, so that the traversal continues to the remaining nodes.
        """
        message = "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
            level, featureIndex, featureValue
        )
        print("  " * level + message)
        return True
def printModel(m):
    """Print the structure of the trained regression-tree model *m*.

    A new PrintNodeVisitor is handed to ``m.traverseDF`` (presumably a
    depth-first walk), which invokes the visitor for the model's nodes.
    """
    m.traverseDF(PrintNodeVisitor())
if __name__ == "__main__":
    # Script entry point: train on the CSV data sets, then print the tree.
    trainingResult = trainModel()
    trainedModel = trainingResult.get(decision_tree.regression.training.model)
    printModel(trainedModel)