Python* API Reference for Intel® Data Analytics Acceleration Library 2018 Update 1

dt_cls_traverse_model.py

1 # file: dt_cls_traverse_model.py
2 #===============================================================================
3 # Copyright 2014-2017 Intel Corporation All Rights Reserved.
4 #
5 # The source code, information and material ("Material") contained herein is
6 # owned by Intel Corporation or its suppliers or licensors, and title to such
7 # Material remains with Intel Corporation or its suppliers or licensors. The
8 # Material contains proprietary information of Intel or its suppliers and
9 # licensors. The Material is protected by worldwide copyright laws and treaty
10 # provisions. No part of the Material may be used, copied, reproduced,
11 # modified, published, uploaded, posted, transmitted, distributed or disclosed
12 # in any way without Intel's prior express written permission. No license under
13 # any patent, copyright or other intellectual property rights in the Material
14 # is granted to or conferred upon you, either expressly, by implication,
15 # inducement, estoppel or otherwise. Any license under such intellectual
16 # property rights must be express and approved by Intel in writing.
17 #
18 # Unless otherwise agreed by Intel in writing, you may not remove or alter this
19 # notice or any other notice embedded in Materials by Intel or Intel's
20 # suppliers or licensors in any way.
21 #===============================================================================
22 
23 #
24 # ! Content:
25 # ! Python example of decision tree classification model traversal.
26 # !
27 # ! The program trains a decision tree classification model on a training dataset
28 # ! (using a separate dataset for pruning) and prints the trained model via depth-first traversal.
29 # !*****************************************************************************
30 
31 #
32 
33 
34 #
35 from __future__ import print_function
36 
37 from daal.algorithms import classifier
38 from daal.algorithms import decision_tree
39 import daal.algorithms.decision_tree.classification
40 import daal.algorithms.decision_tree.classification.training
41 
42 from daal.data_management import (
43  DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
44 )
45 
# Input data set parameters: CSV files holding the training samples and the
# samples used for pruning (paths are relative to the working directory)
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"

# Number of feature columns per row, and number of classes the tree distinguishes
nFeatures, nClasses = 5, 5
52 
53 
def _loadDataSet(fileName):
    """Load a CSV file into separate feature and label numeric tables.

    Each row is read as ``nFeatures`` feature values followed by one label
    column, which is the layout of the merged table built below.

    :param fileName: path to the .csv file to read
    :return: tuple ``(data, labels)`` of HomogenNumericTable objects
    """
    # Initialize FileDataSource<CSVFeatureManager> to retrieve the input data from a .csv file
    dataSource = FileDataSource(
        fileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Create Numeric Tables for the data and the labels. They start with 0 rows
    # and no allocated memory; the rows are loaded by loadDataBlock() below
    # through the merged view that places the label column after the features.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    labels = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, labels)

    # Retrieve the data from the input file
    dataSource.loadDataBlock(mergedData)

    return data, labels


def trainModel():
    """Train a decision tree classification model that uses a pruning data set.

    The training set is read from ``trainDatasetFileName`` and the pruning set
    from ``pruneDatasetFileName``. The batch training algorithm then runs for
    ``nClasses`` classes.

    :return: the training result returned by ``Batch.compute()``; the trained
             model is obtained with ``.get(classifier.training.model)``
    """
    # Load the training and pruning data sets (features and labels each)
    trainData, trainGroundTruth = _loadDataSet(trainDatasetFileName)
    pruneData, pruneGroundTruth = _loadDataSet(pruneDatasetFileName)

    # Create an algorithm object to train the Decision tree model
    algorithm = decision_tree.classification.training.Batch(nClasses)

    # Pass the training data set, labels, and pruning dataset with labels to the algorithm
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.set(decision_tree.classification.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.classification.training.labelsForPruning, pruneGroundTruth)

    # Train the Decision tree model and retrieve the results
    return algorithm.compute()
93 
94 
# Visitor class implementing the classifier.TreeNodeVisitor interface; prints out
# tree nodes of the model when it is called back by the model traversal method
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Print each visited tree node on its own line, indented one space per level.

    Both callbacks return True. Presumably this tells the traversal to continue
    visiting nodes (TODO confirm against the TreeNodeVisitor documentation).
    """

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print a leaf node with its depth level and response value."""
        # Build the indentation in one step instead of printing a space per level
        print("{}Level {}, leaf node. Response value = {}".format(" " * level, level, response))

        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node with its depth level, feature index and split value."""
        print("{}Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
            " " * level, level, featureIndex, featureValue))

        return True
117 
118 
def printModel(m):
    """Print every node of the tree in model ``m`` using depth-first traversal.

    :param m: trained decision tree classification model to print
    """
    m.traverseDF(PrintNodeVisitor())
122 
123 
if __name__ == "__main__":
    # Train the model, extract it from the training result, and print its tree
    trainingResult = trainModel()
    trainedModel = trainingResult.get(classifier.training.model)
    printModel(trainedModel)

For more complete information about compiler optimizations, see our Optimization Notice.