28 from __future__
import print_function
30 from daal
import algorithms
31 from daal.algorithms
import decision_forest
32 import daal.algorithms.decision_forest.regression
33 import daal.algorithms.decision_forest.regression.training
35 from daal.data_management
import (
36 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, features
# Input data set parameters.
# Indices of feature columns that loadData() marks as categorical.
categoricalFeaturesIndices = [3]
# Path of the training data file read by trainModel().
trainDatasetFileName = "../data/batch/df_regression_train.csv"
def trainModel():
    """Train a decision forest regression model on the training data set.

    Loads features and the dependent variable from ``trainDatasetFileName``
    via ``loadData``, configures a batch training algorithm with ``nTrees``
    trees, and runs it.

    Returns:
        The result of ``algorithm.compute()``; the trained model is obtained
        from it with ``result.get(decision_forest.regression.training.model)``.
    """
    # NOTE(review): nTrees (and nFeatures, used by loadData) are module-level
    # globals; confirm they are defined before this function is called.
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    # Batch trainer constructed with default arguments.
    algorithm = decision_forest.regression.training.Batch()

    # Bind the feature table and the dependent-variable table as inputs.
    algorithm.input.set(decision_forest.regression.training.data, trainData)
    algorithm.input.set(
        decision_forest.regression.training.dependentVariable,
        trainDependentVariable,
    )

    algorithm.parameter.nTrees = nTrees

    return algorithm.compute()
def loadData(fileName):
    """Load a training data set into a feature table and a dependent-variable table.

    The file is read into a ``MergedNumericTable`` built from an
    ``nFeatures``-column feature table and a one-column dependent-variable
    table. The features listed in ``categoricalFeaturesIndices`` are then
    marked as categorical in the feature table's data dictionary.

    Args:
        fileName: Path of the data file to read.

    Returns:
        A ``(data, dependentVar)`` tuple of ``HomogenNumericTable`` objects.
    """
    # notAllocateNumericTable: the data source does not create its own table;
    # the data is loaded into the tables constructed below. The dictionary
    # option presumably derives feature metadata from the file itself.
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Tables start with 0 rows and no storage (notAllocate); they are
    # populated by loadDataBlock() through the merged table.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    trainDataSource.loadDataBlock(mergedData)

    # Mark the configured feature columns as categorical.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = features.DAAL_CATEGORICAL

    return data, dependentVar
class PrintNodeVisitor(algorithms.regression.TreeNodeVisitor):
    """Tree node visitor that prints each node it is called back with.

    Pass an instance to a model's ``traverseDF`` method. Each node is printed
    on its own line, prefixed with one indentation unit per tree level.
    """

    # Indentation emitted once per tree level in front of each node line.
    indentUnit = "  "

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node at depth *level* with its response value.

        Returns:
            True, asking the traversal to continue with the next node.
        """
        print(self.indentUnit * level +
              "Level {}, leaf node. Response value = {:.4g}".format(level, response))
        # The visitor callbacks are bool-returning; returning None would not
        # signal "continue traversal".
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node at depth *level* with its feature index and split value.

        Returns:
            True, asking the traversal to continue with the next node.
        """
        print(self.indentUnit * level +
              "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
        return True
def printModel(m, visitor=None):
    """Print every tree of a trained decision forest regression model.

    Prints the number of trees, then for each tree a ``Tree #<index>`` header
    followed by that tree's traversal via ``m.traverseDF``.

    Args:
        m: Trained model exposing ``getNumberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``.
        visitor: Node visitor handed to ``traverseDF`` for every tree.
            Defaults to a new ``PrintNodeVisitor``.
    """
    if visitor is None:
        visitor = PrintNodeVisitor()
    # Query the tree count once and reuse it for the header and the loop.
    treeCount = m.getNumberOfTrees()
    print("Number of trees: {}".format(treeCount))
    for treeIndex in range(treeCount):
        print("Tree #{}".format(treeIndex))
        m.traverseDF(treeIndex, visitor)
if __name__ == "__main__":
    # Script entry point: train the forest, extract the model from the
    # training result, and print every tree.
    trainingResult = trainModel()
    trainedModel = trainingResult.get(decision_forest.regression.training.model)
    printModel(trainedModel)