from __future__ import print_function

from daal.algorithms import classifier
from daal.algorithms import decision_tree
import daal.algorithms.decision_tree.classification
import daal.algorithms.decision_tree.classification.training

from daal.data_management import (
    DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
)
# Input data sets: one CSV for training the tree and a separate CSV used for
# pruning it. Both are relative paths, resolved against the current working
# directory.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def _loadLabeledData(fileName, nFeatures):
    """Load a labeled CSV data set into separate feature and label tables.

    A MergedNumericTable joins an ``nFeatures``-column feature table and a
    1-column label table, so that a single load fills both.

    :param fileName: path of the .csv file to read
    :param nFeatures: number of feature columns in the feature table
    :return: tuple ``(data, groundTruth)`` of HomogenNumericTable objects
    """
    dataSource = FileDataSource(
        fileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Both tables start with 0 rows and notAllocate; presumably their storage
    # is sized when the data block is loaded below — confirm against DAAL docs.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    groundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, groundTruth)

    dataSource.loadDataBlock(mergedData)
    return data, groundTruth


def trainModel():
    """Train a decision tree classifier using a separate pruning data set.

    Loads the training and pruning CSV files, feeds both to the DAAL decision
    tree classification training algorithm, and runs it.

    :return: the result of ``algorithm.compute()``; the trained model is read
        from it with ``classifier.training.model``.
    """
    # NOTE(review): nFeatures and nClasses are module-level settings that are
    # not visible in this excerpt — confirm they are defined alongside the
    # dataset file names.
    trainData, trainGroundTruth = _loadLabeledData(trainDatasetFileName, nFeatures)
    pruneData, pruneGroundTruth = _loadLabeledData(pruneDatasetFileName, nFeatures)

    algorithm = decision_tree.classification.training.Batch(nClasses)

    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    # The pruning set is supplied through its own dedicated inputs.
    algorithm.input.set(decision_tree.classification.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.classification.training.labelsForPruning, pruneGroundTruth)

    return algorithm.compute()
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree visitor that prints one line per visited decision tree node.

    Each line is prefixed with ``_INDENT`` repeated once per tree level, so
    the output mirrors the depth of the tree.
    """

    # NOTE(review): the per-level indentation string was not visible in the
    # excerpt this was restored from; two spaces is an assumption — confirm.
    _INDENT = "  "

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node.

        :param level: tree level of the node, as reported by the traversal
        :param response: response value stored in the leaf
        :return: True, so that the traversal continues to the next node
        """
        message = "Level {}, leaf node. Response value = {}".format(level, response)
        # Indent once, print once: the message must not be repeated per level.
        print(self._INDENT * level + message)
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node.

        :param level: tree level of the node, as reported by the traversal
        :param featureIndex: index of the feature the node splits on
        :param featureValue: split threshold, printed with 4 significant digits
        :return: True, so that the traversal continues to the next node
        """
        message = "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
            level, featureIndex, featureValue)
        print(self._INDENT * level + message)
        return True
def printModel(m):
    """Print a trained decision tree model, one line per node.

    :param m: trained model object exposing ``traverseDF``; the caller passes
        the model taken from the training result via ``classifier.training.model``
    """
    # traverseDF (depth-first, per its name) calls back into the visitor for
    # each split and leaf node.
    visitor = PrintNodeVisitor()
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train the classifier, then print the resulting tree.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))