54 from __future__
import print_function
56 from daal.algorithms
import classifier
57 from daal.algorithms
import decision_tree
58 import daal.algorithms.decision_tree.classification
59 import daal.algorithms.decision_tree.classification.training
61 from daal.data_management
import (
62 DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
# Input data set locations, relative to the directory the example is run from.
# The first file is used to grow the tree, the second to prune it.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def trainModel():
    """Train a decision tree classifier and return the training result.

    Reads the training and pruning CSV files, loads each one through a merged
    table built from a feature table and a single-column label table, and
    feeds all four tables to the batch training algorithm.

    ``nFeatures`` and ``nClasses`` are read from module scope; their
    definitions are not part of the section reviewed here.

    Returns:
        The object returned by ``algorithm.compute()``. The caller extracts
        the trained model from it with ``classifier.training.model``.
    """
    # --- Training set -----------------------------------------------------
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Feature and label tables are created unallocated; loading through the
    # merged table is what populates them, and they (not the merged table)
    # are what the algorithm receives below.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    trainDataSource.loadDataBlock(mergedData)

    # --- Pruning set (same layout as the training set) ---------------------
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)

    pruneDataSource.loadDataBlock(pruneMergedData)

    # --- Training ----------------------------------------------------------
    algorithm = decision_tree.classification.training.Batch(nClasses)

    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.set(decision_tree.classification.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.classification.training.labelsForPruning, pruneGroundTruth)

    return algorithm.compute()
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree visitor that prints one line per node it is called back on.

    Each line is indented according to the node's level so the printed
    output mirrors the shape of the tree.
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node found at depth *level* with its *response* value."""
        # NOTE(review): the statement that produced the indentation is not
        # present in the reviewed text; two spaces per level is an assumption.
        # Indent first, then print the message exactly once -- looping over
        # the message itself would repeat it `level` times and skip the root.
        print("  " * level, end="")
        print("Level {}, leaf node. Response value = {}".format(level, response))
        # NOTE(review): returning True presumably tells the traversal to
        # continue -- confirm against the TreeNodeVisitor documentation.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node found at depth *level*.

        *featureIndex* and *featureValue* are printed as received;
        the value is formatted with four significant digits.
        """
        # Same indentation scheme as onLeafNode (see the review note there).
        print("  " * level, end="")
        print("Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
        # NOTE(review): presumably True means "keep traversing" -- confirm.
        return True
def printModel(m):
    """Print the nodes of the trained model *m* to standard output.

    Creates a PrintNodeVisitor and hands it to ``m.traverseDF``, which calls
    the visitor back for the nodes it walks (depth-first, judging by the
    method name -- confirm against the library documentation).
    """
    visitor = PrintNodeVisitor()
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train the classifier, then print the tree stored in the training result.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))