35 from __future__
import print_function
37 from daal.algorithms
import classifier
38 from daal.algorithms
import decision_tree
39 import daal.algorithms.decision_tree.classification
40 import daal.algorithms.decision_tree.classification.training
42 from daal.data_management
import (
43 DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
# Input data set locations (paths are relative to the working directory).
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def trainModel():
    """Train a decision-tree classifier from the training and pruning CSV files.

    Reads ``trainDatasetFileName`` and ``pruneDatasetFileName`` through
    ``FileDataSource``, loading each file into a ``MergedNumericTable`` built
    from a feature table and a one-column ground-truth table, then runs the
    batch decision-tree classification training algorithm on them.

    Relies on the module-level names ``nFeatures`` and ``nClasses``, which are
    not defined in the part of the file visible here.

    Returns:
        Whatever ``algorithm.compute()`` returns (the caller reads the trained
        model from it via ``classifier.training.model``).
    """
    # NOTE(review): the ``def`` line and the closing parentheses of both
    # FileDataSource calls were missing from the source as received; the name
    # and zero-argument signature are taken from the call in the __main__ guard.

    # Data source for the training CSV file.
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Tables are created with notAllocate and filled by loadDataBlock below;
    # presumably features occupy the first nFeatures columns of the CSV and
    # the label the last one -- confirm against the data files.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    # Loading into the merged table populates trainData and trainGroundTruth.
    trainDataSource.loadDataBlock(mergedData)

    # Data source for the pruning CSV file, handled the same way.
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)

    pruneDataSource.loadDataBlock(pruneMergedData)

    # Create the training algorithm for nClasses classes.
    algorithm = decision_tree.classification.training.Batch(nClasses)

    # Pass the training data, its labels, and the pruning data/labels.
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.set(decision_tree.classification.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.classification.training.labelsForPruning, pruneGroundTruth)

    # Run training and hand the result back to the caller.
    return algorithm.compute()
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree-node visitor that prints one line for every node it is called on.

    An instance is passed to the model's ``traverseDF`` method, which calls
    back ``onLeafNode`` / ``onSplitNode`` for the nodes it visits.
    """

    # NOTE(review): the ``def __init__`` header was missing from the source as
    # received; it is restored around the surviving ``super().__init__()`` call.
    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print the level and response value of a leaf node.

        Args:
            level: Level of the node in the tree, as supplied by the traversal.
            response: Response value stored in the leaf.

        Returns:
            True (see the review note on the return statement).
        """
        # NOTE(review): the loop body was missing from the source as received.
        # Reconstructed as one indent step per tree level -- confirm the
        # exact indent string against the upstream example.
        for _ in range(level):
            print("  ", end="")
        print("Level {}, leaf node. Response value = {}".format(level, response))
        # NOTE(review): the return statement was missing from the source as
        # received. True is presumably what keeps the traversal going --
        # confirm against the TreeNodeVisitor documentation.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print the level, feature index and feature value of a split node.

        Args:
            level: Level of the node in the tree, as supplied by the traversal.
            featureIndex: Index of the feature the node splits on.
            featureValue: Feature value used by the split (printed with
                four significant digits).

        Returns:
            True (see the review note on the return statement).
        """
        # NOTE(review): loop body reconstructed, same as in onLeafNode.
        for _ in range(level):
            print("  ", end="")
        print(
            "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
                level, featureIndex, featureValue
            )
        )
        # NOTE(review): return value reconstructed, same as in onLeafNode.
        return True
def printModel(m):
    """Print the nodes of model ``m``.

    Args:
        m: A trained model exposing ``traverseDF(visitor)``.
    """
    # NOTE(review): the ``def`` line was missing from the source as received;
    # the name comes from the call in the __main__ guard and the parameter
    # name from its use below.
    visitor = PrintNodeVisitor()
    # traverseDF calls the visitor back for the nodes it visits
    # (presumably in depth-first order -- confirm).
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train the classifier, then print the model taken from the training result.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))