35 from __future__
import print_function
37 from daal.algorithms
import classifier
38 from daal.algorithms
import decision_forest
39 import daal.algorithms.decision_forest.classification
40 import daal.algorithms.decision_forest.classification.training
42 from daal.data_management
import (
43 FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface, DataSourceIface, data_feature_utils
# Input data set parameters.
# Relative path of the CSV file that trainModel() passes to loadData().
trainDatasetFileName = "../data/batch/df_classification_train.csv"
# Indices into the data dictionary of the features that loadData() flags as
# data_feature_utils.DAAL_CATEGORICAL.
categoricalFeaturesIndices = [2]

# Decision forest parameter copied to
# algorithm.parameter.minObservationsInLeafNode by the training code.
minObservationsInLeafNode = 8

# NOTE(review): nFeatures, nClasses, nTrees and maxTreeDepth are read by the
# training and loading code in this file, but their assignments are not present
# in the visible listing -- confirm they are defined before running.
# NOTE(review): the `def` line was absent from the extracted listing; the name
# and the empty parameter list are taken from the `trainModel()` call in the
# `__main__` section -- confirm against the original file.
def trainModel():
    """Train a decision forest classification model on the training data set.

    Loads the feature and label tables with ``loadData``, feeds them to a
    ``decision_forest.classification.training.Batch`` algorithm configured
    from the module-level parameters, and runs the computation.

    Returns:
        The object returned by ``algorithm.compute()``. The caller reads the
        trained model from it via ``classifier.training.model``.
    """
    # Numeric tables holding the training features and the dependent variable.
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Training algorithm for nClasses classes.
    algorithm = decision_forest.classification.training.Batch(nClasses)

    # Inputs: training features and their labels.
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainDependentVariable)

    # Forest parameters, all taken from module-level settings.
    algorithm.parameter.nTrees = nTrees
    algorithm.parameter.featuresPerNode = nFeatures
    algorithm.parameter.minObservationsInLeafNode = minObservationsInLeafNode
    algorithm.parameter.maxTreeDepth = maxTreeDepth

    return algorithm.compute()
def loadData(fileName):
    """Load a CSV data set into a feature table and a dependent-variable table.

    Two unallocated ``HomogenNumericTable`` objects are created (``nFeatures``
    columns and 1 column), wrapped in a ``MergedNumericTable``, and the file
    data source loads its data block into that merged table. The features
    listed in ``categoricalFeaturesIndices`` are then flagged as categorical
    in the feature table's dictionary.

    Args:
        fileName: Path of the CSV file handed to ``FileDataSource``.

    Returns:
        tuple: ``(data, dependentVar)`` -- the feature table and the
        single-column dependent-variable table.
    """
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Both tables start unallocated; loading goes through the merged view.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    trainDataSource.loadDataBlock(mergedData)

    # Flag the configured feature columns as categorical.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = data_feature_utils.DAAL_CATEGORICAL

    return data, dependentVar
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree node visitor that prints one line for every node it is called on.

    An instance is handed to the model's ``traverseDF`` method, which calls
    ``onLeafNode`` / ``onSplitNode`` back for the nodes of a tree.

    NOTE(review): the extracted listing lost the ``__init__`` header, the
    bodies of both ``for i in range(level)`` loops, and whatever followed each
    ``print`` call. Those parts are reconstructed below and marked; confirm
    them against the original file.
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print the level and response value of a leaf node.

        Args:
            level: Level of the node, as passed by the traversal.
            response: Response value of the leaf.
        """
        # NOTE(review): reconstructed loop body -- presumably one indentation
        # unit per level; the exact indent string is not visible in the listing.
        print("  " * level, end="")
        print("Level {}, leaf node. Response value = {}".format(level, response))
        # NOTE(review): reconstructed -- presumably a truthy return tells the
        # traversal to continue; verify against the TreeNodeVisitor contract.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print the level, feature index and feature value of a split node.

        Args:
            level: Level of the node, as passed by the traversal.
            featureIndex: Index of the feature the node splits on.
            featureValue: Split value, printed with ``{:.6g}`` formatting.
        """
        # NOTE(review): reconstructed loop body -- see onLeafNode.
        print("  " * level, end="")
        print("Level {}, split node. Feature index = {}, feature value = {:.6g}".format(level, featureIndex, featureValue))
        # NOTE(review): reconstructed -- see onLeafNode.
        return True
# NOTE(review): the `def` line was absent from the extracted listing; the name
# comes from the `printModel(...)` call in the `__main__` section and the
# parameter name `m` from the body -- confirm against the original file.
def printModel(m):
    """Print the number of trees in a model, then every tree's nodes.

    Args:
        m: Model object providing ``numberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``.
    """
    visitor = PrintNodeVisitor()
    # Read the tree count once; it drives both the header and the loop.
    treeCount = m.numberOfTrees()
    print("Number of trees: {}".format(treeCount))
    for treeIndex in range(treeCount):
        print("Tree #{}".format(treeIndex))
        m.traverseDF(treeIndex, visitor)
# Script entry point: train the model, then print the trees of the result.
if __name__ == "__main__":
    trainingResult = trainModel()
    # The trained model is stored in the result under classifier.training.model.
    printModel(trainingResult.get(classifier.training.model))