54 from __future__
import print_function
56 from daal.algorithms
import classifier
57 from daal.algorithms
import decision_forest
58 import daal.algorithms.decision_forest.classification
59 import daal.algorithms.decision_forest.classification.training
61 from daal.data_management
import (
62 FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface, DataSourceIface, data_feature_utils
# Input data set parameters.
# Path of the CSV file that is passed to loadData() for training.
trainDatasetFileName = "../data/batch/df_classification_train.csv"
# Indices of the features that loadData() flags as DAAL_CATEGORICAL in the
# data dictionary.
categoricalFeaturesIndices = [2]

# Decision forest parameter; copied onto
# algorithm.parameter.minObservationsInLeafNode before training.
minObservationsInLeafNode = 8
def trainModel():
    """Train a decision forest classifier on the training data set.

    Loads the feature and label tables with ``loadData``, configures a
    batch training algorithm and runs it.

    Returns:
        Whatever ``algorithm.compute()`` returns; the ``__main__`` block reads
        the trained model from it with ``get(classifier.training.model)``.
    """
    # NOTE(review): nClasses, nTrees, nFeatures and maxTreeDepth are expected
    # to be module-level settings; confirm they are defined next to
    # minObservationsInLeafNode.

    # Features and labels come back as two separate tables.
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Batch training algorithm constructed for nClasses classes.
    algorithm = decision_forest.classification.training.Batch(nClasses)

    # Training inputs: feature table and label table.
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainDependentVariable)

    # Forest settings.
    algorithm.parameter.nTrees = nTrees
    algorithm.parameter.featuresPerNode = nFeatures
    algorithm.parameter.minObservationsInLeafNode = minObservationsInLeafNode
    algorithm.parameter.maxTreeDepth = maxTreeDepth

    return algorithm.compute()
def loadData(fileName):
    """Load a data set file into a feature table and a label table.

    Args:
        fileName: Path of the file handed to ``FileDataSource``.

    Returns:
        A ``(data, dependentVar)`` tuple of ``HomogenNumericTable`` objects,
        constructed with ``nFeatures`` and ``1`` as their first size argument
        respectively and filled by a single ``loadDataBlock`` call.
    """
    # The tables are supplied below, so the source is created with the
    # notAllocateNumericTable flag.
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Both tables are created with the notAllocate flag and wrapped in one
    # merged table, which is what the data source loads into.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    trainDataSource.loadDataBlock(mergedData)

    # Flag every configured feature as categorical in the data dictionary.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = data_feature_utils.DAAL_CATEGORICAL

    return data, dependentVar
# Tree node visitor that prints one "Level ..." line per visited node.
# An instance is created when the model is printed and handed to
# m.traverseDF(i, visitor).
# NOTE(review): this class looks incomplete as captured. The
# `def __init__(self):` header that should enclose the super().__init__()
# call is not visible, and neither `for i in range(level):` loop shows a
# body. Restore the missing lines from the original file before relying on
# this class; do not assume the print calls are the loop bodies.
122 class PrintNodeVisitor(classifier.TreeNodeVisitor):
# Initializes the classifier.TreeNodeVisitor base class.
125 super(PrintNodeVisitor, self).__init__()
# Leaf callback: receives `level` and the leaf's `response` and reports both.
# Presumably invoked by the tree traversal for each leaf node - TODO confirm.
127 def onLeafNode(self, level, response):
# Runs `level` times; the loop body is not visible here (see note above).
129 for i
in range(level):
131 print(
"Level {}, leaf node. Response value = {}".format(level, response))
# Split callback: receives `level`, the split's `featureIndex` and
# `featureValue`; the value is printed with the {:.6g} format.
# Presumably invoked by the tree traversal for each split node - TODO confirm.
134 def onSplitNode(self, level, featureIndex, featureValue):
# Runs `level` times; the loop body is not visible here (see note above).
136 for i
in range(level):
138 print(
"Level {}, split node. Feature index = {}, feature value = {:.6g}".format(level, featureIndex, featureValue))
def printModel(m):
    """Print the number of trees in a model, then traverse every tree.

    Args:
        m: Model object providing ``numberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``.
    """
    # One visitor instance is reused for all trees.
    visitor = PrintNodeVisitor()
    print("Number of trees: {}".format(m.numberOfTrees()))
    for i in range(m.numberOfTrees()):
        print("Tree #{}".format(i))
        # The visitor's callbacks print the nodes of tree i.
        m.traverseDF(i, visitor)
if __name__ == "__main__":
    # Train the forest, then print the structure of the resulting model.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))