22 from daal.algorithms.svm
import training, prediction
23 from daal.algorithms
import classifier, kernel_function, multi_class_classifier
24 from daal.data_management
import DataSourceIface, FileDataSource
# Make the shared ``utils`` helper importable. It lives in the parent of the
# directory that contains this script. ``realpath`` already returns an
# absolute, symlink-resolved path, so no separate ``abspath`` is needed.
utils_folder = os.path.realpath(os.path.dirname(os.path.dirname(__file__)))
if utils_folder not in sys.path:
    # Prepend so this copy of ``utils`` takes precedence over any other module
    # of the same name that is already on the path.
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, createSparseTable  # noqa: E402 (needs the sys.path tweak above)
# Input data locations.
# NOTE: these are relative paths, so (assuming the loaders open them as given)
# they resolve against the current working directory, not this file's
# location. Run the script from a directory where ``../data/batch`` exists.
data_dir = os.path.join('..', 'data', 'batch')

# Training set: CSR-format feature file plus one label per observation.
trainDatasetFileName = os.path.join(data_dir, 'svm_multi_class_train_csr.csv')
trainLabelsFileName = os.path.join(data_dir, 'svm_multi_class_train_labels.csv')

# Test set: same layout as the training set.
testDatasetFileName = os.path.join(data_dir, 'svm_multi_class_test_csr.csv')
testLabelsFileName = os.path.join(data_dir, 'svm_multi_class_test_labels.csv')
# SVM training / prediction algorithms (from daal.algorithms.svm). The
# multi-class classifier delegates to them: they are assigned to its
# ``parameter.training`` / ``parameter.prediction``.
trainingAlg = training.Batch()
predictionAlg = prediction.Batch()

# Linear kernel using the fastCSR computation method, matching the CSR
# (sparse) input tables. The __main__ block attaches it to both algorithms.
kernel = kernel_function.linear.Batch(method=kernel_function.linear.fastCSR)

# Placeholders for results produced later in the script.
predictionResult = None
testGroundTruth = None
# --- Training step: build the multi-class SVM model. ---
# NOTE(review): this span reads like a function body whose `def` line is not
# visible here. Confirm the enclosing function, and confirm it declares
# `global trainingResult`: the prediction code further down reads
# trainingResult at module scope.
# Training labels come from trainLabelsFileName via FileDataSource, configured
# with the doAllocateNumericTable / doDictionaryFromContext options.
# NOTE(review): the closing ')' of this call is not present in the visible text.
55 trainLabelsDataSource = FileDataSource(
56 trainLabelsFileName, DataSourceIface.doAllocateNumericTable,
57 DataSourceIface.doDictionaryFromContext
# Training features are loaded from the CSR file through the utils helper.
61 trainData = createSparseTable(trainDatasetFileName)
# Load the label rows; getNumericTable() below then returns them.
64 trainLabelsDataSource.loadDataBlock()
# Multi-class classifier training object for nClasses classes.
# NOTE(review): nClasses is not defined in the visible lines. It is presumably
# a module-level constant; confirm.
67 algorithm = multi_class_classifier.training.Batch(nClasses)
# Use the shared SVM training/prediction algorithms as the multi-class
# classifier's underlying algorithms.
69 algorithm.parameter.training = trainingAlg
70 algorithm.parameter.prediction = predictionAlg
# Inputs: the CSR feature table and the labels table loaded above.
73 algorithm.input.set(classifier.training.data, trainData)
74 algorithm.input.set(classifier.training.labels, trainLabelsDataSource.getNumericTable())
# Train. The model is later read via trainingResult.get(classifier.training.model).
78 trainingResult = algorithm.compute()
# --- Testing step: predict labels for the test set. ---
# NOTE(review): the enclosing `def` line is not visible in this excerpt.
# Rebind the module-level predictionResult, which is initialised to None at
# module level.
82 global predictionResult
# Test features are loaded from the CSR file through the utils helper.
85 testData = createSparseTable(testDatasetFileName)
# Multi-class prediction object.
# NOTE(review): nClasses is not defined in the visible lines. It is presumably
# a module-level constant; confirm.
88 algorithm = multi_class_classifier.prediction.Batch(nClasses)
# Same SVM training/prediction algorithms as the training step uses.
90 algorithm.parameter.training = trainingAlg
91 algorithm.parameter.prediction = predictionAlg
# Inputs: the test table and the model produced by the training step.
# The training step must have run first so that trainingResult is set.
94 algorithm.input.setTable(classifier.prediction.data, testData)
95 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Run prediction. The predicted labels are later read with
# predictionResult.get(classifier.prediction.prediction).
99 predictionResult = algorithm.compute()
# --- Results step: load the ground-truth test labels and report them
# alongside the predictions. ---
# Test labels are read from testLabelsFileName with the same FileDataSource
# options as the training labels.
# NOTE(review): the enclosing `def` line and the closing ')' of this call are
# not present in the visible text.
105 testLabelsDataSource = FileDataSource(
106 testLabelsFileName, DataSourceIface.doAllocateNumericTable,
107 DataSourceIface.doDictionaryFromContext
110 testLabelsDataSource.loadDataBlock()
# NOTE(review): no `global testGroundTruth` is visible. If this runs inside a
# function, it binds a local name and the module-level testGroundTruth
# (initialised to None) is left unchanged. Confirm intent.
111 testGroundTruth = testLabelsDataSource.getNumericTable()
# Arguments to a call whose opening line is not visible. This is presumably
# printNumericTables, which is imported from utils; confirm. The arguments are
# the ground-truth table, the predicted-labels table, two column titles and a
# header string.
# NOTE(review): any remaining arguments and the closing ')' are not visible.
114 testGroundTruth, predictionResult.get(classifier.prediction.prediction),
115 "Ground truth",
"Classification results",
116 "Multi-class SVM classification sample program results (first 20 observations):",
# Script entry point: configure the shared SVM algorithms before the
# training and testing steps run.
120 if __name__ ==
"__main__":
# Cache size for SVM training. DAAL's SVM documentation describes cacheSize
# as the byte size of the kernel-matrix cache; confirm for the DAAL version
# in use.
121 trainingAlg.parameter.cacheSize = 100000000
# Training and prediction share the same kernel object (linear, fastCSR method).
122 trainingAlg.parameter.kernel = kernel
123 predictionAlg.parameter.kernel = kernel
# NOTE(review): the calls that actually run training, testing and result
# printing are not visible in this excerpt. Confirm that they follow.