56 from daal.algorithms.classifier.quality_metric
import multiclass_confusion_matrix
57 from daal.algorithms
import svm
58 from daal.algorithms
import kernel_function
59 from daal.algorithms
import multi_class_classifier
60 from daal.algorithms
import classifier
61 from daal.data_management
import (
62 DataSourceIface, FileDataSource, readOnly, BlockDescriptor, HomogenNumericTable,
63 NumericTableIface, MergedNumericTable
# Put the parent of this script's directory at the front of sys.path (if it is
# not already listed) so that the `utils` import below can be resolved from it.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, printNumericTable
# Input data set locations (relative paths).
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_test_dense.csv')

# NOTE(review): nFeatures, nClasses and trainingResult are used further down in
# this file, but their definitions are not visible in this excerpt -- confirm
# they are defined before use.

# SVM training and prediction algorithm objects. They are plugged into the
# multi-class classifier through `algorithm.parameter.training` and
# `algorithm.parameter.prediction`.
training = svm.training.Batch()
prediction = svm.prediction.Batch()

# Result of the prediction step; assigned via `global` once prediction has run.
predictionResult = None

# Kernel function object assigned to both SVM algorithms in the __main__ section.
kernel = kernel_function.linear.Batch()

# Module-level state shared between the prediction, quality-metric and printing steps.
qualityMetricSetResult = None
predictedLabels = None
groundTruthLabels = None
# NOTE(review): these statements look like the body of a training function, but
# the enclosing `def` line is not visible in this excerpt. The leading numbers
# appear to be line numbers left over from a code listing, and the
# FileDataSource( call below has no closing parenthesis here. Confirm against
# the original file before changing any code.
#
# Data source over the training CSV file.
98 trainDataSource = FileDataSource(
99 trainDatasetFileName, DataSourceIface.notAllocateNumericTable,
100 DataSourceIface.doDictionaryFromContext
# Separate tables for the features and for the labels, combined into one merged
# table (presumably nFeatures feature columns plus 1 label column -- TODO confirm
# the HomogenNumericTable argument order).
104 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
105 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
106 mergedData = MergedNumericTable(trainData, trainGroundTruth)
# The merged table is what gets passed to the data source for loading.
109 trainDataSource.loadDataBlock(mergedData)
# Multi-class classifier training algorithm for nClasses classes, configured
# with the module-level SVM training and prediction objects.
112 algorithm = multi_class_classifier.training.Batch(nClasses)
114 algorithm.parameter.training = training
115 algorithm.parameter.prediction = prediction
# Inputs: the feature table and the label table created above.
118 algorithm.input.set(classifier.training.data, trainData)
119 algorithm.input.set(classifier.training.labels, trainGroundTruth)
# NOTE(review): trainingResult is read later by the prediction code; no `global`
# statement for it is visible in this excerpt -- verify it is module-level.
122 trainingResult = algorithm.compute()
# NOTE(review): these statements look like the body of a prediction function,
# but the enclosing `def` line is not visible in this excerpt. The leading
# numbers appear to be line numbers left over from a code listing, and the
# FileDataSource( call below has no closing parenthesis here. Confirm against
# the original file before changing any code.
#
# Both names are assigned below and read by testModelQuality().
126 global predictionResult, groundTruthLabels
# Data source over the test CSV file.
# NOTE(review): this passes DataSourceIface.doAllocateNumericTable, while the
# training data source in this file passes notAllocateNumericTable -- confirm
# the difference is intentional.
129 testDataSource = FileDataSource(
130 testDatasetFileName, DataSourceIface.doAllocateNumericTable,
131 DataSourceIface.doDictionaryFromContext
# Separate tables for the test features and the ground-truth labels, combined
# into one merged table that is passed to the data source for loading.
135 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
136 groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
137 mergedData = MergedNumericTable(testData, groundTruthLabels)
140 testDataSource.loadDataBlock(mergedData)
# Multi-class classifier prediction algorithm for nClasses classes, configured
# with the module-level SVM training and prediction objects.
143 algorithm = multi_class_classifier.prediction.Batch(nClasses)
145 algorithm.parameter.training = training
146 algorithm.parameter.prediction = prediction
# Inputs: the test feature table and the model taken from trainingResult.
149 algorithm.input.setTable(classifier.prediction.data, testData)
150 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Stored in the module-level predictionResult (see the `global` statement above).
153 predictionResult = algorithm.compute()
def testModelQuality():
    """Compute the confusion-matrix quality metrics for the predicted labels.

    Reads the module-level ``predictionResult`` and ``groundTruthLabels`` and
    stores its outputs in the module-level ``predictedLabels`` and
    ``qualityMetricSetResult``. Returns nothing.
    """
    global predictedLabels, qualityMetricSetResult

    # Extract the predicted labels from the prediction result.
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # Quality metric set for nClasses classes; fetch its confusion-matrix input.
    # The local is deliberately not called `input`, to avoid shadowing the builtin.
    qualityMetricSet = multi_class_classifier.quality_metric_set.Batch(nClasses)
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )

    # The metric input takes the predicted labels and the ground-truth labels.
    confusionMatrixInput.set(multiclass_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(multiclass_confusion_matrix.groundTruthLabels, groundTruthLabels)

    qualityMetricSetResult = qualityMetricSet.compute()
# NOTE(review): this span starts in the middle of a call -- the callee name and
# its opening parenthesis are not visible in this excerpt, nor is the call's
# closing parenthesis or the enclosing `def` line. The leading numbers appear to
# be line numbers left over from a code listing. Confirm against the original
# file before changing any code.
#
# Arguments of the truncated call: the two label tables, their captions, a
# title, and display options.
177 groundTruthLabels, predictedLabels,
178 "Ground truth",
"Classification results",
179 "SVM classification results (first 20 observations):", 20, interval=15, flt64=
False
# Fetch the confusion-matrix result from the quality metric set result and
# print the confusion matrix table.
182 qualityMetricResult = qualityMetricSetResult.getResult(multi_class_classifier.quality_metric_set.confusionMatrix)
183 printNumericTable(qualityMetricResult.get(multiclass_confusion_matrix.confusionMatrix),
"Confusion matrix:")
# Read the multiClassMetrics table through a block descriptor (presumably one
# row starting at index 0 -- TODO confirm getBlockOfRows argument order) and
# flatten it to a 1-D array.
185 block = BlockDescriptor()
186 qualityMetricsTable = qualityMetricResult.get(multiclass_confusion_matrix.multiClassMetrics)
187 qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
188 qualityMetricsData = block.getArray().flatten()
# Each metric is looked up in the flattened array using the matching
# multiclass_confusion_matrix constant as the index and printed with three
# decimal places.
189 print(
"Average accuracy: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.averageAccuracy]))
190 print(
"Error rate: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.errorRate]))
191 print(
"Micro precision: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.microPrecision]))
192 print(
"Micro recall: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.microRecall]))
193 print(
"Micro F-score: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.microFscore]))
194 print(
"Macro precision: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.macroPrecision]))
195 print(
"Macro recall: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.macroRecall]))
196 print(
"Macro F-score: {0:.3f}".format(qualityMetricsData[multiclass_confusion_matrix.macroFscore]))
# Release the block that was obtained with getBlockOfRows above.
197 qualityMetricsTable.releaseBlockOfRows(block)
if __name__ == "__main__":
    # Configure the module-level SVM algorithm objects before they are used:
    # set the training cache size and give training and prediction the same
    # kernel function object.
    training.parameter.cacheSize = 100000000
    training.parameter.kernel = kernel
    prediction.parameter.kernel = kernel
202 prediction.parameter.kernel = kernel