56 from daal.algorithms
import kernel_function
57 from daal.algorithms.classifier.quality_metric
import binary_confusion_matrix
58 from daal.algorithms
import svm
59 from daal.algorithms
import classifier
60 from daal.data_management
import (
61 DataSourceIface, FileDataSource, readOnly, BlockDescriptor,
62 HomogenNumericTable, NumericTableIface, MergedNumericTable
# Resolve the directory one level above this script's own directory and put
# it at the front of the module search path (only if it is not already
# listed), so that modules living there can be imported.
utils_folder = os.path.realpath(
    os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
)
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
68 from utils
import printNumericTables, printNumericTable
# Relative directory that holds the batch-mode example data sets.
DATA_PREFIX = os.path.join('..', 'data', 'batch')

# CSV files passed to the FileDataSource objects of the training and
# prediction steps respectively.
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_test_dense.csv')
# Linear kernel object; the same instance is assigned to the parameters of
# both the SVM training and the SVM prediction algorithm.
kernel = kernel_function.linear.Batch()

# Module-level slots used to hand results from one processing step to the
# next. Each one stays None until the step that produces it has run.
predictionResult = None
qualityMetricSetResult = None

predictedLabels = None
groundTruthLabels = None
93 trainDataSource = FileDataSource(
94 trainDatasetFileName, DataSourceIface.notAllocateNumericTable,
95 DataSourceIface.doDictionaryFromContext
99 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
100 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
101 mergedData = MergedNumericTable(trainData, trainGroundTruth)
104 trainDataSource.loadDataBlock(mergedData)
107 algorithm = svm.training.Batch()
109 algorithm.parameter.kernel = kernel
110 algorithm.parameter.cacheSize = 600000000
113 algorithm.input.set(classifier.training.data, trainData)
114 algorithm.input.set(classifier.training.labels, trainGroundTruth)
117 trainingResult = algorithm.compute()
120 global predictionResult, groundTruthLabels
123 testDataSource = FileDataSource(
124 testDatasetFileName, DataSourceIface.doAllocateNumericTable,
125 DataSourceIface.doDictionaryFromContext
129 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
130 groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
131 mergedData = MergedNumericTable(testData, groundTruthLabels)
134 testDataSource.loadDataBlock(mergedData)
137 algorithm = svm.prediction.Batch()
139 algorithm.parameter.kernel = kernel
142 algorithm.input.setTable(classifier.prediction.data, testData)
143 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
147 predictionResult = algorithm.compute()
def testModelQuality():
    """Compute the SVM quality metric set for the current prediction result.

    Works entirely through module-level globals:

    * reads ``predictionResult`` and ``groundTruthLabels`` (both are
      initialised to ``None`` at module level, so they must have been
      populated before this function is called);
    * stores the predicted labels in ``predictedLabels``;
    * stores the computed metric set in ``qualityMetricSetResult``.

    Takes no arguments and returns nothing.
    """
    global predictedLabels, qualityMetricSetResult, groundTruthLabels

    # Pull the predicted labels out of the prediction result.
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # Quality metric set object for the SVM algorithm.
    qualityMetricSet = svm.quality_metric_set.Batch()

    # Input object of the confusion-matrix metric. Deliberately not called
    # `input`, so the Python builtin of that name is not shadowed.
    confusionInput = qualityMetricSet.getInputDataCollection().getInput(
        svm.quality_metric_set.confusionMatrix
    )
    confusionInput.set(binary_confusion_matrix.predictedLabels, predictedLabels)
    confusionInput.set(binary_confusion_matrix.groundTruthLabels, groundTruthLabels)

    # Run the computation and keep the result for later reporting.
    qualityMetricSetResult = qualityMetricSet.compute()
173 groundTruthLabels, predictedLabels,
174 "Ground truth",
"Classification results",
175 "SVM classification results (first 20 observations):", 20, interval=15, flt64=
False
179 qualityMetricResult = qualityMetricSetResult.getResult(svm.quality_metric_set.confusionMatrix)
180 printNumericTable(qualityMetricResult.get(binary_confusion_matrix.confusionMatrix),
"Confusion matrix:")
182 block = BlockDescriptor()
183 qualityMetricsTable = qualityMetricResult.get(binary_confusion_matrix.binaryMetrics)
184 qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
185 qualityMetricsData = block.getArray().flatten()
186 print(
"Accuracy: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.accuracy]))
187 print(
"Precision: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.precision]))
188 print(
"Recall: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.recall]))
189 print(
"F-score: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.fscore]))
190 print(
"Specificity: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.specificity]))
191 print(
"AUC: {0:.3f}".format(qualityMetricsData[binary_confusion_matrix.AUC]))
192 qualityMetricsTable.releaseBlockOfRows(block)
194 if __name__ ==
"__main__":