30 from daal.algorithms
import kernel_function
31 from daal.algorithms.classifier.quality_metric
import binary_confusion_matrix
32 from daal.algorithms
import svm
33 from daal.algorithms
import classifier
34 from daal.data_management
import (
35 DataSourceIface, FileDataSource, readOnly, BlockDescriptor,
36 HomogenNumericTable, NumericTableIface, MergedNumericTable
# Make the helper module imported next (``utils``) importable: it lives one
# directory above the folder that contains this script.
# abspath() is applied to __file__ *before* the dirname() calls. If __file__ is
# a bare relative name such as "example.py" (Python < 3.9 when run as a
# script), dirname(dirname(...)) would otherwise collapse to "" and resolve to
# the current working directory instead of the script's parent directory.
utils_folder = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
42 from utils
import printNumericTables, printNumericTable
# Directory holding the example CSV files. The path is relative, so it is
# resolved against the current working directory when the files are opened.
DATA_PREFIX = os.path.join('..', 'data', 'batch')

# Two-class SVM training and test sets, both located under DATA_PREFIX.
trainDatasetFileName, testDatasetFileName = (
    os.path.join(DATA_PREFIX, csvName)
    for csvName in ('svm_two_class_train_dense.csv', 'svm_two_class_test_dense.csv')
)
# Linear kernel object; the same instance is assigned to both the training
# and the prediction algorithm's `parameter.kernel`.
kernel = kernel_function.linear.Batch()

# Module-level slots that the later stages (prediction, quality metrics)
# populate through `global` declarations.
predictionResult = qualityMetricSetResult = None
predictedLabels = groundTruthLabels = None
# --- Training stage ---------------------------------------------------------
# NOTE(review): the enclosing `def` line is not visible in this view. These
# statements appear to be the body of the training function. For
# `trainingResult` (assigned at the end) to reach the prediction stage, it
# must be declared `global` there — that declaration is not visible; confirm.

# Data source over the training CSV. It is created with
# notAllocateNumericTable; rows are loaded into the caller-supplied
# mergedData table below.
trainDataSource = FileDataSource(
    trainDatasetFileName, DataSourceIface.notAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext
)

# Feature table (nFeatures wide) and a one-column label table, both created
# with doNotAllocate. They are wrapped in a MergedNumericTable so that a single
# load fills both.
# NOTE(review): presumably the trailing CSV column becomes the label — confirm.
trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
mergedData = MergedNumericTable(trainData, trainGroundTruth)

# Read the file contents into the merged table.
trainDataSource.loadDataBlock(mergedData)

# Batch SVM trainer, configured with the module-level kernel.
algorithm = svm.training.Batch()

algorithm.parameter.kernel = kernel
# NOTE(review): 600000000 is presumably a cache size in bytes (~600 MB) — confirm.
algorithm.parameter.cacheSize = 600000000

# Features and labels for training.
algorithm.input.set(classifier.training.data, trainData)
algorithm.input.set(classifier.training.labels, trainGroundTruth)

# Train. The prediction stage reads the model from this result via
# trainingResult.get(classifier.training.model).
trainingResult = algorithm.compute()
# --- Prediction stage -------------------------------------------------------
# NOTE(review): the enclosing `def` line is not visible in this view. The
# `global` statement below indicates these lines form a function body that
# publishes its results for the quality-metric stage.
global predictionResult, groundTruthLabels

# Data source over the test CSV.
# NOTE(review): unlike the training source, this passes doAllocateNumericTable.
# However, the rows are loaded into the caller-supplied mergedData below, so the
# source's own table looks unused; notAllocateNumericTable would match the
# training stage — confirm before changing.
testDataSource = FileDataSource(
    testDatasetFileName, DataSourceIface.doAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext
)

# Feature table plus a one-column label table, merged so one load fills both.
# groundTruthLabels is a module global and is later compared against the
# predictions.
testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
mergedData = MergedNumericTable(testData, groundTruthLabels)

# Read the file contents into the merged table.
testDataSource.loadDataBlock(mergedData)

# Batch SVM predictor, using the same kernel object as training.
algorithm = svm.prediction.Batch()

algorithm.parameter.kernel = kernel

# Test features and the model produced by the training stage.
algorithm.input.setTable(classifier.prediction.data, testData)
algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))

# Predict; the result is stored globally for testModelQuality().
predictionResult = algorithm.compute()
def testModelQuality():
    """Compute the SVM quality-metric set (binary confusion matrix) on the test predictions.

    Reads the module-level ``predictionResult`` and ``groundTruthLabels``,
    which the prediction stage must have populated beforehand. Stores the
    extracted predicted labels in the module global ``predictedLabels`` and
    the computed metrics in the module global ``qualityMetricSetResult``.
    """
    # Only names that are assigned need `global`; groundTruthLabels is only read.
    global predictedLabels, qualityMetricSetResult

    # Predicted class labels from the prediction result.
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # SVM quality-metric set; its confusion-matrix input is filled below.
    qualityMetricSet = svm.quality_metric_set.Batch()

    # Named explicitly so the builtin input() is not shadowed.
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        svm.quality_metric_set.confusionMatrix)
    confusionMatrixInput.set(binary_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(binary_confusion_matrix.groundTruthLabels, groundTruthLabels)

    # Evaluate all configured quality metrics.
    qualityMetricSetResult = qualityMetricSet.compute()
147 groundTruthLabels, predictedLabels,
148 "Ground truth",
"Classification results",
149 "SVM classification results (first 20 observations):", 20, interval=15, flt64=
False
# Confusion-matrix metric from the quality-metric set computed earlier.
qualityMetricResult = qualityMetricSetResult.getResult(svm.quality_metric_set.confusionMatrix)
printNumericTable(qualityMetricResult.get(binary_confusion_matrix.confusionMatrix),
                  "Confusion matrix:")

# Borrow row 0 of the binary-metrics table (read-only) and print each metric.
# The borrowed block is released in `finally` so it is returned to the table
# even if indexing or printing raises.
block = BlockDescriptor()
qualityMetricsTable = qualityMetricResult.get(binary_confusion_matrix.binaryMetrics)
qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
try:
    qualityMetricsData = block.getArray().flatten()
    print("Accuracy: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.accuracy]))
    print("Precision: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.precision]))
    print("Recall: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.recall]))
    print("F-score: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.fscore]))
    print("Specificity: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.specificity]))
    print("AUC: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.AUC]))
finally:
    qualityMetricsTable.releaseBlockOfRows(block)
168 if __name__ ==
"__main__":