Python* API Reference for Intel® Data Analytics Acceleration Library 2018 Update 3

dt_cls_traverse_model.py

1 # file: dt_cls_traverse_model.py
2 #===============================================================================
3 # Copyright 2014-2018 Intel Corporation.
4 #
5 # This software and the related documents are Intel copyrighted materials, and
6 # your use of them is governed by the express license under which they were
7 # provided to you (License). Unless the License provides otherwise, you may not
8 # use, modify, copy, publish, distribute, disclose or transmit this software or
9 # the related documents without Intel's prior written permission.
10 #
11 # This software and the related documents are provided as is, with no express
12 # or implied warranties, other than those that are expressly stated in the
13 # License.
14 #===============================================================================
15 
16 #
17 # ! Content:
18 # ! Python example of decision tree classification model traversal.
19 # !
20 # ! The program trains the decision tree classification model on a training
21 # ! data set and prints the trained model by traversing it depth-first.
22 # !*****************************************************************************
23 
24 #
25 
26 
27 #
28 from __future__ import print_function
29 
30 from daal.algorithms import classifier
31 from daal.algorithms import decision_tree
32 import daal.algorithms.decision_tree.classification
33 import daal.algorithms.decision_tree.classification.training
34 
35 from daal.data_management import (
36  DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
37 )
38 
# Locations of the CSV files read for training and for pruning
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"

# nFeatures sizes the feature tables; nClasses is passed to the training algorithm
nFeatures, nClasses = 5, 5
45 
46 
def trainModel():
    """Train the decision tree classification model and return the training result.

    Reads the training and pruning CSV files named by the module-level
    ``trainDatasetFileName`` and ``pruneDatasetFileName``, passes the resulting
    feature and label tables to a ``decision_tree.classification.training.Batch``
    algorithm created for ``nClasses`` classes, and returns whatever its
    ``compute()`` call returns.
    """

    def _readFeaturesAndLabels(csvPath):
        # Read one .csv file into a (features, labels) pair of numeric tables
        source = FileDataSource(
            csvPath,
            DataSourceIface.notAllocateNumericTable,
            DataSourceIface.doDictionaryFromContext,
        )
        features = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
        labels = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
        # Loading goes through a merged table built from both tables, so that
        # one loadDataBlock call populates the features and the labels together
        source.loadDataBlock(MergedNumericTable(features, labels))
        return features, labels

    # Retrieve the training data first, then the pruning data
    trainFeatures, trainLabels = _readFeaturesAndLabels(trainDatasetFileName)
    pruneFeatures, pruneLabels = _readFeaturesAndLabels(pruneDatasetFileName)

    # Create an algorithm object to train the Decision tree model
    trainer = decision_tree.classification.training.Batch(nClasses)

    # Pass the training data set, labels, and pruning dataset with labels to the algorithm
    inputTables = (
        (classifier.training.data, trainFeatures),
        (classifier.training.labels, trainLabels),
        (decision_tree.classification.training.dataForPruning, pruneFeatures),
        (decision_tree.classification.training.labelsForPruning, pruneLabels),
    )
    for inputId, table in inputTables:
        trainer.input.set(inputId, table)

    # Train the Decision tree model and retrieve the results
    return trainer.compute()
86 
87 
88 # Visitor class implementing the TreeNodeVisitor interface; it prints the tree
89 # nodes of the model as the model traversal method calls it back
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree node visitor that prints one line per visited node.

    Each line is indented by one space per tree level, so deeper nodes appear
    further to the right.
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print the level and response value of a leaf node.

        Always returns True (presumably "continue traversal" -- confirm
        against the DAAL TreeNodeVisitor documentation).
        """
        indent = " " * level
        print(indent + "Level {}, leaf node. Response value = {}".format(level, response))
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print the level, feature index and feature value of a split node.

        The feature value is formatted with four significant digits.
        Always returns True (presumably "continue traversal" -- confirm
        against the DAAL TreeNodeVisitor documentation).
        """
        indent = " " * level
        description = "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
            level, featureIndex, featureValue
        )
        print(indent + description)
        return True
110 
111 
def printModel(m):
    """Print the nodes of model ``m`` by passing a PrintNodeVisitor to its ``traverseDF`` method."""
    m.traverseDF(PrintNodeVisitor())
115 
116 
if __name__ == "__main__":
    # Train the model, extract it from the training result, and print its nodes
    trainedModel = trainModel().get(classifier.training.model)
    printModel(trainedModel)

For more complete information about compiler optimizations, see our Optimization Notice.