from __future__ import print_function

from daal.algorithms import regression
from daal.algorithms import decision_tree
import daal.algorithms.decision_tree.regression
import daal.algorithms.decision_tree.regression.training

from daal.data_management import (
    FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
)
# Input data set parameters: CSV files read by the training routine.
# One file supplies the training observations, the other the observations
# used for pruning the tree.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def trainModel():
    """Train a decision tree regression model from the train and prune CSV files.

    Reads ``trainDatasetFileName`` and ``pruneDatasetFileName``, splits each
    file into a feature table and a single-column dependent-variable table,
    and runs the batch decision tree regression training algorithm on them.

    Returns:
        The object returned by ``algorithm.compute()``; the caller retrieves
        the trained model from it.
    """
    # NOTE(review): ``nFeatures`` (number of feature columns in each CSV file)
    # is expected to be a module-level constant; its definition is not visible
    # in this part of the file -- confirm it exists.

    # Data source for the training observations.
    trainDataSource = FileDataSource(
        trainDatasetFileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Tables for the training features and their dependent variable. They are
    # created with ``notAllocate`` and zero rows, then merged so one load call
    # can populate both.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    # Retrieve the training data from the input file.
    trainDataSource.loadDataBlock(mergedData)

    # Data source for the pruning observations.
    pruneDataSource = FileDataSource(
        pruneDatasetFileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Same feature / dependent-variable layout for the pruning set.
    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)

    # Retrieve the pruning data from the input file.
    pruneDataSource.loadDataBlock(pruneMergedData)

    # Batch training algorithm for decision tree regression.
    algorithm = decision_tree.regression.training.Batch()

    # Hand the training and pruning tables to the algorithm.
    algorithm.input.set(decision_tree.regression.training.data, trainData)
    algorithm.input.set(decision_tree.regression.training.dependentVariables, trainGroundTruth)
    algorithm.input.set(decision_tree.regression.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.regression.training.dependentVariablesForPruning, pruneGroundTruth)

    # Run the training and return its result object.
    return algorithm.compute()
class PrintNodeVisitor(regression.TreeNodeVisitor):
    """Tree visitor that prints one line per visited node.

    Each callback receives the node's ``level`` and prints a description of
    the node: the response value for a leaf, or the feature index and feature
    value for a split.
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a description of a leaf node found at depth ``level``."""
        # NOTE(review): the loop body was missing from the listing; it is
        # presumed to emit one indent step per level so deeper nodes appear
        # further to the right -- confirm the exact indent string.
        for i in range(level):
            print("  ", end="")
        print("Level {}, leaf node. Response value = {:.4g}".format(level, response))
        # NOTE(review): presumably the visitor must return True for the
        # traversal to continue -- verify against the TreeNodeVisitor docs.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a description of a split node found at depth ``level``."""
        # NOTE(review): loop body reconstructed as in onLeafNode -- confirm.
        for i in range(level):
            print("  ", end="")
        print("Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
        # NOTE(review): presumably True means "keep traversing" -- verify.
        return True
def printModel(m):
    """Print every node of the trained model ``m``.

    Args:
        m: An object exposing ``traverseDF(visitor)``; the ``__main__`` block
            passes the model taken from the training result.
    """
    visitor = PrintNodeVisitor()
    # The visitor's callbacks print a line for each node handed to them.
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train on the CSV data sets, then print the resulting tree.
    trainingResult = trainModel()
    printModel(trainingResult.get(decision_tree.regression.training.model))