from __future__ import print_function

from daal import algorithms
from daal.algorithms import decision_forest
import daal.algorithms.decision_forest.regression
import daal.algorithms.decision_forest.regression.training

from daal.data_management import (
    FileDataSource,
    DataSourceIface,
    NumericTableIface,
    HomogenNumericTable,
    MergedNumericTable,
    data_feature_utils,
)
# Input data set parameters.
# Path of the CSV file that loadData() is asked to read.
trainDatasetFileName = "../data/batch/df_regression_train.csv"
# Indices of the feature columns that loadData() marks as categorical.
categoricalFeaturesIndices = [3]
# NOTE(review): `nFeatures` and `nTrees` are read by loadData()/trainModel()
# but are not defined in the reviewed text -- confirm they are set elsewhere.
def trainModel():
    """Train a decision forest regression model and return the training result.

    Loads the feature table and the dependent-variable table for
    ``trainDatasetFileName`` via ``loadData``, hands both to a
    ``decision_forest.regression.training.Batch`` algorithm, sets the
    algorithm's ``nTrees`` parameter and returns the value produced by
    ``algorithm.compute()``.

    NOTE(review): the ``def`` line was not present in the reviewed text; the
    name and the empty parameter list are taken from the ``trainModel()``
    call under ``__main__``. ``nTrees`` is not defined in this function --
    presumably a module-level setting; confirm.
    """
    # Training features and the dependent variable come back as two tables.
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Algorithm object that performs the training.
    algorithm = decision_forest.regression.training.Batch()

    # Register both input tables with the algorithm.
    algorithm.input.set(decision_forest.regression.training.data, trainData)
    algorithm.input.set(
        decision_forest.regression.training.dependentVariable,
        trainDependentVariable,
    )

    algorithm.parameter.nTrees = nTrees

    # Run the training and hand the result object back to the caller.
    return algorithm.compute()
def loadData(fileName):
    """Load *fileName* into a feature table and a dependent-variable table.

    A ``FileDataSource`` is created for *fileName* and its contents are
    loaded into a ``MergedNumericTable`` built from two
    ``HomogenNumericTable`` objects: one with ``nFeatures`` columns and one
    with a single column. Every feature whose index appears in
    ``categoricalFeaturesIndices`` is then marked as
    ``data_feature_utils.DAAL_CATEGORICAL`` in the feature table's dictionary.

    Args:
        fileName: Path handed to ``FileDataSource``.

    Returns:
        A ``(data, dependentVar)`` tuple of the two ``HomogenNumericTable``
        objects that back the merged table.

    NOTE(review): ``nFeatures`` is not defined in this function or in the
    reviewed text -- presumably a module-level setting; confirm.
    """
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Both tables are created with zero rows and the notAllocate flag; the
    # merged table is what the data source loads into.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    # Read the file contents into the merged table.
    trainDataSource.loadDataBlock(mergedData)

    # Flag the configured feature columns as categorical.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = data_feature_utils.DAAL_CATEGORICAL

    return data, dependentVar
96 class PrintNodeVisitor(algorithms.regression.TreeNodeVisitor):
99 super(PrintNodeVisitor, self).__init__()
101 def onLeafNode(self, level, response):
103 for i
in range(level):
105 print(
"Level {}, leaf node. Response value = {:.4g}".format(level, response))
109 def onSplitNode(self, level, featureIndex, featureValue):
111 for i
in range(level):
113 print(
"Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
def printModel(m):
    """Print the number of trees in *m* and traverse each tree with a visitor.

    Args:
        m: Model object; the code calls ``m.numberOfTrees()`` and
            ``m.traverseDF(treeIndex, visitor)`` on it.

    NOTE(review): the ``def`` line was not present in the reviewed text; the
    name and the single parameter ``m`` are taken from the ``printModel(...)``
    call under ``__main__`` and from the body's use of ``m``.
    """
    # One visitor instance is reused for every tree.
    visitor = PrintNodeVisitor()
    print("Number of trees: {}".format(m.numberOfTrees()))
    for i in range(m.numberOfTrees()):
        print("Tree #{}".format(i))
        m.traverseDF(i, visitor)
if __name__ == "__main__":
    # Train the model, then print the model stored in the training result.
    trainingResult = trainModel()
    printModel(trainingResult.get(decision_forest.regression.training.model))