31 from daal.algorithms.classifier.quality_metric
import multiclass_confusion_matrix
32 from daal.algorithms
import svm
33 from daal.algorithms
import kernel_function
34 from daal.algorithms
import multi_class_classifier
35 from daal.algorithms
import classifier
36 from daal.data_management
import (
37 DataSourceIface, FileDataSource, readOnly, BlockDescriptor, HomogenNumericTable,
38 NumericTableIface, MergedNumericTable
# Make the shared ``utils`` helper module importable. Its directory is the
# parent of the directory containing this script; it is prepended to sys.path
# only if it is not already present.
utils_folder = os.path.realpath(
    os.path.abspath(
        os.path.dirname(os.path.dirname(__file__))
    )
)
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, printNumericTable  # noqa: E402 (needs the sys.path entry above)

# Dense CSV inputs. These are relative ('..'-based) paths, so they depend on
# the working directory the script is launched from.
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName, testDatasetFileName = (
    os.path.join(DATA_PREFIX, csvName)
    for csvName in ('svm_multi_class_train_dense.csv', 'svm_multi_class_test_dense.csv')
)

# Double-precision SVM training/prediction algorithms (later plugged into the
# multi-class classifier) and the linear kernel both are configured with.
training = svm.training.Batch(fptype=np.float64)
prediction = svm.prediction.Batch(fptype=np.float64)
kernel = kernel_function.linear.Batch(fptype=np.float64)

# Results shared between the pipeline stages; each stage fills in its own
# entries (via ``global``) as it runs.
predictionResult = qualityMetricSetResult = predictedLabels = groundTruthLabels = None
# Training stage: load the training CSV, train a multi-class classifier that is
# configured with the module-level SVM `training`/`prediction` algorithms, and
# assign the outcome to `trainingResult`.
# Data source for the training CSV; per the flag name it does not allocate its
# own numeric table -- rows are loaded into the caller-provided mergedData below.
73 trainDataSource = FileDataSource(
74 trainDatasetFileName, DataSourceIface.notAllocateNumericTable,
75 DataSourceIface.doDictionaryFromContext
# Feature table (nFeatures columns) and label table (1 column), merged so one
# loadDataBlock call fills both. Presumably each CSV row holds nFeatures
# feature values followed by the class label -- TODO confirm against the data.
79 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
80 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
81 mergedData = MergedNumericTable(trainData, trainGroundTruth)
84 trainDataSource.loadDataBlock(mergedData)
# Multi-class classifier over nClasses classes, built on the SVM algorithms.
87 algorithm = multi_class_classifier.training.Batch(nClasses,fptype=np.float64)
89 algorithm.parameter.training = training
90 algorithm.parameter.prediction = prediction
# Training inputs: the feature table and the matching label table.
93 algorithm.input.set(classifier.training.data, trainData)
94 algorithm.input.set(classifier.training.labels, trainGroundTruth)
# Run training; the prediction stage later reads the trained model from this
# result via classifier.training.model.
97 trainingResult = algorithm.compute()
# Test stage: load the test CSV, predict its labels with the trained model, and
# publish the predictions in `predictionResult` and the true labels in
# `groundTruthLabels` (module globals, read by testModelQuality).
101 global predictionResult, groundTruthLabels
# NOTE(review): the training data source passes notAllocateNumericTable, but this
# one passes doAllocateNumericTable even though the rows are loaded into
# mergedData below -- confirm the source really needs its own table.
104 testDataSource = FileDataSource(
105 testDatasetFileName, DataSourceIface.doAllocateNumericTable,
106 DataSourceIface.doDictionaryFromContext
# Test feature table (nFeatures columns) and ground-truth label table
# (1 column), merged so one loadDataBlock call fills both.
110 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
111 groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
112 mergedData = MergedNumericTable(testData, groundTruthLabels)
115 testDataSource.loadDataBlock(mergedData)
# Prediction counterpart of the multi-class classifier, configured with the
# same module-level SVM algorithm objects that were used for training.
118 algorithm = multi_class_classifier.prediction.Batch(nClasses,fptype=np.float64)
120 algorithm.parameter.training = training
121 algorithm.parameter.prediction = prediction
# Inputs: the test features and the model taken from the training result.
124 algorithm.input.setTable(classifier.prediction.data, testData)
125 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Run prediction; the predicted labels are extracted in testModelQuality.
128 predictionResult = algorithm.compute()
def testModelQuality():
    """Compute confusion-matrix quality metrics for the multi-class predictions.

    Reads the module-level ``predictionResult`` and ``groundTruthLabels``;
    both start out as ``None`` at module level and must be populated first
    (calling ``.get`` on a ``None`` ``predictionResult`` raises AttributeError).

    Side effects: sets the module-level ``predictedLabels`` (the labels
    extracted from ``predictionResult``) and ``qualityMetricSetResult`` (the
    computed quality metric set).
    """
    global predictedLabels, qualityMetricSetResult

    # Extract the predicted labels from the prediction result.
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # Quality metric set for an nClasses-class classifier; feed its
    # confusion-matrix metric the predicted and ground-truth labels.
    qualityMetricSet = multi_class_classifier.quality_metric_set.Batch(nClasses)
    # Named explicitly (not ``input``) to avoid shadowing the built-in input().
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )
    confusionMatrixInput.set(multiclass_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(multiclass_confusion_matrix.groundTruthLabels, groundTruthLabels)

    qualityMetricSetResult = qualityMetricSet.compute()
# Continuation of a results-printing call (presumably utils.printNumericTables,
# given the two tables): ground-truth and predicted label tables, their column
# headers, a title, and what is presumably the row count (20, matching the
# title); interval/flt64 are passed through to the helper.
152 groundTruthLabels, predictedLabels,
153 "Ground truth",
"Classification results",
154 "SVM classification results (first 20 observations):", 20, interval=15, flt64=
False
# Print the confusion matrix produced by the multi-class quality metric set.
qualityMetricResult = qualityMetricSetResult.getResult(multi_class_classifier.quality_metric_set.confusionMatrix)
printNumericTable(qualityMetricResult.get(multiclass_confusion_matrix.confusionMatrix), "Confusion matrix:")

# (label, metric index) pairs, printed in this order as "<label>: <value>".
# Resolved before the row block is acquired, so a lookup failure cannot leave
# the block held.
metricsToPrint = (
    ("Average accuracy", multiclass_confusion_matrix.averageAccuracy),
    ("Error rate", multiclass_confusion_matrix.errorRate),
    ("Micro precision", multiclass_confusion_matrix.microPrecision),
    ("Micro recall", multiclass_confusion_matrix.microRecall),
    ("Micro F-score", multiclass_confusion_matrix.microFscore),
    ("Macro precision", multiclass_confusion_matrix.macroPrecision),
    ("Macro recall", multiclass_confusion_matrix.macroRecall),
    ("Macro F-score", multiclass_confusion_matrix.macroFscore),
)

block = BlockDescriptor()
qualityMetricsTable = qualityMetricResult.get(multiclass_confusion_matrix.multiClassMetrics)
# Acquire the first row of the metrics table read-only; release it in
# `finally` so the block is returned to the table even if indexing or
# printing raises.
qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
try:
    qualityMetricsData = block.getArray().flatten()
    for label, index in metricsToPrint:
        print("{0}: {1:.3f}".format(label, qualityMetricsData[index]))
finally:
    qualityMetricsTable.releaseBlockOfRows(block)
174 if __name__ ==
"__main__":
175 training.parameter.cacheSize = 100000000
176 training.parameter.kernel = kernel
177 prediction.parameter.kernel = kernel