54 from __future__
import print_function
56 from daal
import algorithms
57 from daal.algorithms
import decision_forest
58 import daal.algorithms.decision_forest.regression
59 import daal.algorithms.decision_forest.regression.training
61 from daal.data_management
import (
62 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, data_feature_utils
# Input data set parameters.
# Path of the CSV file holding the training set, relative to the working directory.
trainDatasetFileName = "../data/batch/df_regression_train.csv"
# Column indices that loadData marks as categorical features.
categoricalFeaturesIndices = [3]
# NOTE(review): `nFeatures` and `nTrees` are read by the code below, but their
# assignments are not present in this part of the file as received - confirm
# they are defined at module level before this script is run.
def trainModel():
    """Train a decision forest regression model on the training data set.

    Loads the features and the dependent variable from
    ``trainDatasetFileName``, passes both to a batch training algorithm,
    sets the number of trees from the module-level ``nTrees`` and runs it.

    Returns:
        The object returned by ``algorithm.compute()``; the caller retrieves
        the model from it with ``decision_forest.regression.training.model``.
    """
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Create the batch algorithm object used to train the regression model.
    algorithm = decision_forest.regression.training.Batch()

    # Pass the training features and the dependent variable to the algorithm.
    trainingInput = algorithm.input
    trainingInput.set(decision_forest.regression.training.data, trainData)
    trainingInput.set(decision_forest.regression.training.dependentVariable, trainDependentVariable)

    # NOTE(review): `nTrees` is expected to be a module-level setting; its
    # definition is not visible in this part of the file - confirm.
    algorithm.parameter.nTrees = nTrees

    # Build the model and hand the training result back to the caller.
    return algorithm.compute()
def loadData(fileName):
    """Load features and the dependent variable from a CSV file.

    Args:
        fileName: Path of the CSV file to read.

    Returns:
        A ``(data, dependentVar)`` tuple of numeric tables: ``data`` is
        created with ``nFeatures`` as its first argument and ``dependentVar``
        with ``1``. Every feature whose index is listed in
        ``categoricalFeaturesIndices`` is marked as categorical in the
        dictionary of ``data``.
    """
    # The data source is told not to allocate its own numeric table and to
    # build its dictionary from context; the tables are supplied below.
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Unallocated tables for the features and for the dependent variable,
    # combined into one merged table that serves as the load target.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    # Read the file into the merged table.
    # NOTE(review): presumably this is what fills `data` and `dependentVar`,
    # since they are returned without being written anywhere else - confirm.
    trainDataSource.loadDataBlock(mergedData)

    # Flag the configured columns as categorical features.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = data_feature_utils.DAAL_CATEGORICAL

    return data, dependentVar
class PrintNodeVisitor(algorithms.regression.TreeNodeVisitor):
    """Tree node visitor that prints one line for every node it is given."""

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node.

        Args:
            level: Level of the node in the tree, as reported by the caller.
            response: Response value stored in the leaf.
        """
        # NOTE(review): the body of the `for i in range(level)` loop is
        # missing from the source as received; printing two spaces per level
        # as indentation is a reconstruction - confirm against upstream.
        print("  " * level, end="")
        print("Level {}, leaf node. Response value = {:.4g}".format(level, response))
        # NOTE(review): no return statement survives in the source as
        # received; True is assumed to mean "continue traversal" - confirm
        # against the TreeNodeVisitor documentation.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node.

        Args:
            level: Level of the node in the tree, as reported by the caller.
            featureIndex: Index of the feature the node splits on.
            featureValue: Feature value used by the split.
        """
        # NOTE(review): same reconstruction as in onLeafNode - the original
        # loop body is not present in the source as received.
        print("  " * level, end="")
        print("Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
        # NOTE(review): assumed continuation flag, see onLeafNode - confirm.
        return True
def printModel(m):
    """Print the number of trees in a model, then every tree node by node.

    Args:
        m: Trained model; it must provide ``numberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``.
    """
    # A single visitor instance is reused for all trees.
    visitor = PrintNodeVisitor()

    # Read the tree count once and use it for both the summary and the loop.
    treeCount = m.numberOfTrees()
    print("Number of trees: {}".format(treeCount))

    for treeIndex in range(treeCount):
        print("Tree #{}".format(treeIndex))
        # The model calls back into the visitor for the nodes of this tree.
        m.traverseDF(treeIndex, visitor)
if __name__ == "__main__":
    # Train the model, then print the trees of the model stored in the result.
    trainingResult = trainModel()
    printModel(trainingResult.get(decision_forest.regression.training.model))