from __future__ import print_function

from daal.algorithms import classifier
from daal.algorithms import decision_forest
import daal.algorithms.decision_forest.classification
import daal.algorithms.decision_forest.classification.training

from daal.data_management import (
    FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface, DataSourceIface, features
)
# Input data set parameters
trainDatasetFileName = "../data/batch/df_classification_train.csv"
# Indices of the feature columns that loadData() marks as categorical.
categoricalFeaturesIndices = [2]

# Decision forest parameter (copied onto algorithm.parameter during training)
minObservationsInLeafNode = 8

# NOTE(review): nFeatures, nClasses, nTrees and maxTreeDepth are read by the
# training code and by loadData(), but no assignment to them is visible in
# this part of the file -- confirm they are defined before the script runs.
def trainModel():
    """Train a decision forest classification model on the training data set.

    Loads the training tables with ``loadData(trainDatasetFileName)``,
    configures a ``decision_forest.classification.training.Batch`` algorithm
    from the module-level settings and runs it.

    Returns:
        The value returned by ``algorithm.compute()``. The ``__main__``
        section reads the trained model from it with
        ``.get(classifier.training.model)``.
    """
    # NOTE(review): the ``def`` line was missing from the stored text; the
    # name and the empty parameter list are restored from the
    # ``trainModel()`` call in the ``__main__`` section -- confirm.

    # Numeric tables holding the training data and the dependent variable
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Algorithm object for decision forest classification training
    algorithm = decision_forest.classification.training.Batch(nClasses)

    # Pass the training data set and the labels to the algorithm
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainDependentVariable)

    # Copy the module-level settings onto the algorithm parameters
    algorithm.parameter.nTrees = nTrees
    algorithm.parameter.featuresPerNode = nFeatures
    algorithm.parameter.minObservationsInLeafNode = minObservationsInLeafNode
    algorithm.parameter.maxTreeDepth = maxTreeDepth

    return algorithm.compute()
def loadData(fileName):
    """Load a data file into a feature table and a dependent-variable table.

    Args:
        fileName: Path handed to ``FileDataSource``.

    Returns:
        A ``(data, dependentVar)`` tuple of ``HomogenNumericTable`` objects.
        ``data`` is created with ``nFeatures`` as its first constructor
        argument and ``dependentVar`` with ``1``; both are filled by loading
        one block through a ``MergedNumericTable`` that wraps them.
    """
    trainDataSource = FileDataSource(
        fileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Tables are created with the notAllocate flag and combined into one
    # merged table, which is the target of the load below.
    # NOTE(review): presumably the constructor arguments are
    # (columns, rows, allocation flag) -- confirm against the DAAL API.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    # Retrieve the data from the input file
    trainDataSource.loadDataBlock(mergedData)

    # Flag the configured feature columns as categorical in the data
    # dictionary of the feature table.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = features.DAAL_CATEGORICAL

    return data, dependentVar
# Visitor derived from classifier.TreeNodeVisitor. Its two callbacks print a
# formatted message describing a leaf node or a split node; printModel-style
# code later in this file passes an instance to m.traverseDF().
# NOTE(review): this class is incomplete as stored. The leading numbers are
# source-line numbers fused into the text, and the jumps between them
# (96 -> 99, 103 -> 105, 110 -> 112) show that lines are missing. The gaps are
# marked below; the code text itself is left untouched.
96 class PrintNodeVisitor(classifier.TreeNodeVisitor):
# NOTE(review): no `def` line precedes the super() call below. Presumably it
# is the body of `__init__(self)` -- confirm against the original file.
99 super(PrintNodeVisitor, self).__init__()
# Leaf-node callback: receives `level` and `response` and prints both.
101 def onLeafNode(self, level, response):
# NOTE(review): the statement executed by this loop (line 104) is not
# present. Do not assume the print() that follows is the loop body.
103 for i
in range(level):
105 print(
"Level {}, leaf node. Response value = {}".format(level, response))
# NOTE(review): any statement after the print (e.g. a return value) is not
# visible here -- confirm what the callback is expected to return.
# Split-node callback: receives `level`, `featureIndex` and `featureValue`
# and prints all three (featureValue with the `.6g` format).
108 def onSplitNode(self, level, featureIndex, featureValue):
# NOTE(review): as in onLeafNode, the loop body (line 111) is missing.
110 for i
in range(level):
112 print(
"Level {}, split node. Feature index = {}, feature value = {:.6g}".format(level, featureIndex, featureValue))
def printModel(m):
    """Print the number of trees in a model, then traverse every tree.

    Args:
        m: Model object providing ``getNumberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``.
    """
    # NOTE(review): the ``def`` line was missing from the stored text; the
    # name and the single parameter ``m`` are restored from the body and from
    # the ``printModel(...)`` call in the ``__main__`` section -- confirm.
    visitor = PrintNodeVisitor()
    treeCount = m.getNumberOfTrees()
    print("Number of trees: {}".format(treeCount))
    for treeIndex in range(treeCount):
        print("Tree #{}".format(treeIndex))
        # The visitor's callbacks do the per-node printing.
        m.traverseDF(treeIndex, visitor)
if __name__ == "__main__":
    # Train the model, then print the trees of the resulting model.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))