Python* API Reference for Intel® Data Analytics Acceleration Library 2019 Update 5

svm_multi_class_metrics_dense_batch.py

1 # file: svm_multi_class_metrics_dense_batch.py
2 #===============================================================================
3 # Copyright 2014-2019 Intel Corporation.
4 #
5 # This software and the related documents are Intel copyrighted materials, and
6 # your use of them is governed by the express license under which they were
7 # provided to you (License). Unless the License provides otherwise, you may not
8 # use, modify, copy, publish, distribute, disclose or transmit this software or
9 # the related documents without Intel's prior written permission.
10 #
11 # This software and the related documents are provided as is, with no express
12 # or implied warranties, other than those that are expressly stated in the
13 # License.
14 #===============================================================================
15 
16 #
17 # ! Content:
18 # ! Python example of multi-class support vector machine (SVM) quality metrics
19 # !
20 # !*****************************************************************************
21 
22 #
23 
24 
25 #
26 
27 import os
28 import sys
29 import numpy as np
30 
31 from daal.algorithms.classifier.quality_metric import multiclass_confusion_matrix
32 from daal.algorithms import svm
33 from daal.algorithms import kernel_function
34 from daal.algorithms import multi_class_classifier
35 from daal.algorithms import classifier
36 from daal.data_management import (
37  DataSourceIface, FileDataSource, readOnly, BlockDescriptor, HomogenNumericTable,
38  NumericTableIface, MergedNumericTable
39 )
40 
# Put the parent directory of this example on sys.path (if it is not there
# already) so that the shared ``utils`` module can be imported below.
utils_folder = os.path.realpath(
    os.path.abspath(
        os.path.dirname(os.path.dirname(__file__))
    )
)
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTable, printNumericTables

# Shape of the input data set: feature columns per row and number of classes
nFeatures = 20
nClasses = 5

# Locations of the training and testing data sets
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_test_dense.csv')

# Kernel function object assigned to both SVM algorithms in the
# ``__main__`` section
kernel = kernel_function.linear.Batch(fptype=np.float64)

# SVM training and prediction algorithms plugged into the multi-class
# classifier
training = svm.training.Batch(fptype=np.float64)
prediction = svm.prediction.Batch(fptype=np.float64)

# Results shared between the pipeline stages; filled in by trainModel(),
# testModel() and testModelQuality()
trainingResult = None
predictionResult = None
qualityMetricSetResult = None
predictedLabels = None
groundTruthLabels = None
67 
68 
def trainModel():
    """Train the multi-class SVM model and store it in ``trainingResult``.

    Loads ``trainDatasetFileName`` into a merged table made of an
    ``nFeatures``-column data table and a 1-column label table, then runs the
    multi-class classifier training algorithm configured with the module-level
    ``training`` and ``prediction`` SVM algorithm objects.
    """
    global trainingResult

    # Empty tables (0 rows, not allocated); they are filled when the merged
    # table is loaded from the data source
    features = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    labels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    merged = MergedNumericTable(features, labels)

    # Read the .csv file into the merged table
    source = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )
    source.loadDataBlock(merged)

    # Multi-class classifier trainer driven by the SVM algorithms
    trainer = multi_class_classifier.training.Batch(nClasses, fptype=np.float64)
    trainer.parameter.training = training
    trainer.parameter.prediction = prediction

    # Feed the training data set and its labels to the algorithm
    trainer.input.set(classifier.training.data, features)
    trainer.input.set(classifier.training.labels, labels)

    # Build the model and keep the algorithm result for the later stages
    trainingResult = trainer.compute()
98 
99 
def testModel():
    """Run multi-class SVM prediction on the test set.

    Loads ``testDatasetFileName`` into a merged table made of an
    ``nFeatures``-column data table and a 1-column label table, predicts with
    the model held in the module-level ``trainingResult`` (set by
    ``trainModel``), and stores the outcome in the module-level
    ``predictionResult``. The label table read from the file is kept in the
    module-level ``groundTruthLabels``.
    """
    global predictionResult, groundTruthLabels

    # Initialize FileDataSource to retrieve the test data from a .csv file.
    # The data is loaded into the tables created below, so the data source is
    # told not to allocate a table of its own (same as in trainModel).
    testDataSource = FileDataSource(
        testDatasetFileName, DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Create Numeric Tables for testing data and labels
    testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(testData, groundTruthLabels)

    # Retrieve the data from the input file
    testDataSource.loadDataBlock(mergedData)

    # Create an algorithm object to predict multi-class SVM values
    algorithm = multi_class_classifier.prediction.Batch(nClasses, fptype=np.float64)

    algorithm.parameter.training = training
    algorithm.parameter.prediction = prediction

    # Pass the testing data set and the trained model to the algorithm
    algorithm.input.setTable(classifier.prediction.data, testData)
    algorithm.input.setModel(
        classifier.prediction.model,
        trainingResult.get(classifier.training.model)
    )

    # Predict multi-class SVM values and keep the prediction result
    predictionResult = algorithm.compute()
129 
130 
def testModelQuality():
    """Compute confusion-matrix quality metrics for the predictions.

    Reads the module-level ``predictionResult`` and ``groundTruthLabels``
    (set by ``testModel``). Stores the predicted labels in the module-level
    ``predictedLabels`` and the computed quality metric set result in the
    module-level ``qualityMetricSetResult``.
    """
    global predictedLabels, qualityMetricSetResult

    # Retrieve predicted labels
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # Create a quality metric set object to compute quality metrics of the
    # multi-class classifier algorithm
    qualityMetricSet = multi_class_classifier.quality_metric_set.Batch(nClasses)

    # Named so that it does not shadow the builtin ``input``
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )

    confusionMatrixInput.set(multiclass_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(multiclass_confusion_matrix.groundTruthLabels, groundTruthLabels)

    # Compute the quality metrics and keep the resulting collection
    qualityMetricSetResult = qualityMetricSet.compute()
147 
def printResults():
    """Print the classification results and the quality metrics.

    Reads the module-level ``groundTruthLabels``, ``predictedLabels`` and
    ``qualityMetricSetResult`` produced by the earlier stages. Prints the
    ground-truth and predicted labels side by side, the confusion matrix, and
    the scalar multi-class metrics taken from the first row of the metrics
    table.
    """
    # Print the classification results
    printNumericTables(
        groundTruthLabels, predictedLabels,
        "Ground truth", "Classification results",
        "SVM classification results (first 20 observations):", 20, interval=15, flt64=False
    )

    # Print the quality metrics
    qualityMetricResult = qualityMetricSetResult.getResult(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )
    printNumericTable(
        qualityMetricResult.get(multiclass_confusion_matrix.confusionMatrix),
        "Confusion matrix:"
    )

    # Output format of each metric paired with its index in the metrics row
    metricFormats = (
        ("Average accuracy: {0:.3f}", multiclass_confusion_matrix.averageAccuracy),
        ("Error rate: {0:.3f}", multiclass_confusion_matrix.errorRate),
        ("Micro precision: {0:.3f}", multiclass_confusion_matrix.microPrecision),
        ("Micro recall: {0:.3f}", multiclass_confusion_matrix.microRecall),
        ("Micro F-score: {0:.3f}", multiclass_confusion_matrix.microFscore),
        ("Macro precision: {0:.3f}", multiclass_confusion_matrix.macroPrecision),
        ("Macro recall: {0:.3f}", multiclass_confusion_matrix.macroRecall),
        ("Macro F-score: {0:.3f}", multiclass_confusion_matrix.macroFscore),
    )

    block = BlockDescriptor()
    qualityMetricsTable = qualityMetricResult.get(multiclass_confusion_matrix.multiClassMetrics)
    qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
    try:
        qualityMetricsData = block.getArray().flatten()
        for metricFormat, metricIndex in metricFormats:
            print(metricFormat.format(qualityMetricsData[metricIndex]))
    finally:
        # Release the block even if reading or printing a metric fails
        qualityMetricsTable.releaseBlockOfRows(block)
173 
if __name__ == "__main__":
    # Both SVM stages use the same kernel function object
    prediction.parameter.kernel = kernel
    training.parameter.kernel = kernel
    training.parameter.cacheSize = 10 ** 8

    # Run the pipeline: train, predict, evaluate the predictions, report
    trainModel()
    testModel()
    testModelQuality()
    printResults()

For more complete information about compiler optimizations, see our Optimization Notice.