# DAAL components used by this example: SVM training/prediction, kernel
# functions, classifier input/result identifiers, and CSV file data sources.
# NOTE(review): `os` and `sys` are used below but are not imported in this
# excerpt; presumably they are imported earlier in the file — confirm.
from daal.algorithms.svm import training, prediction
from daal.algorithms import kernel_function, classifier
from daal.data_management import DataSourceIface, FileDataSource

# Put the parent of this file's directory at the front of sys.path so that
# `utils` (presumably located there) can be imported no matter what the
# current working directory is.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
# Deliberately imported after the sys.path adjustment above, which it needs.
from utils import printNumericTables, createSparseTable
# Input data locations. These paths are relative, so they resolve against
# the current working directory, not against this file's location.
DATA_PREFIX = os.path.join('..', 'data', 'batch')

# Training set: feature file (CSR layout, per its name) plus a labels file.
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_train_csr.csv')
trainLabelsFileName = os.path.join(DATA_PREFIX, 'svm_two_class_train_labels.csv')

# Test set: same layout as the training set.
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_test_csr.csv')
testLabelsFileName = os.path.join(DATA_PREFIX, 'svm_two_class_test_labels.csv')
# Linear kernel using its CSR-specialised method, to match the sparse
# train/test tables built by createSparseTable. The same kernel object is
# assigned to both the training and the prediction algorithm.
kernel = kernel_function.linear.Batch(method=kernel_function.linear.fastCSR)

# Module-level results shared between the steps of the example. The
# prediction code reads `trainingResult` and rebinds `predictionResult`
# (via `global`). Both stay None until their step has run.
trainingResult = None
predictionResult = None
# NOTE(review): damaged excerpt of the SVM training step. Listing line
# numbers have leaked into each line. Gaps in that numbering (e.g. 70 -> 77,
# 79 -> 83) show that source lines were dropped, presumably including an
# enclosing function header. The FileDataSource( call below is never
# closed. This is not valid Python as shown; restore it from the upstream
# example rather than hand-editing.
#
# Data source over the training-labels CSV file. The two DataSourceIface
# flags presumably make the source allocate its own numeric table and
# derive its data dictionary from the file — confirm against DAAL docs.
77 trainLabelsDataSource = FileDataSource(
78 trainLabelsFileName, DataSourceIface.doAllocateNumericTable,
79 DataSourceIface.doDictionaryFromContext
# Training features are loaded as a sparse table via the createSparseTable
# helper imported from utils.
83 trainData = createSparseTable(trainDatasetFileName)
# Read the label data from file; the resulting table is fetched below with
# getNumericTable().
86 trainLabelsDataSource.loadDataBlock()
# Batch-mode SVM trainer, configured with the module-level linear kernel.
89 algorithm = training.Batch()
91 algorithm.parameter.kernel = kernel
# Kernel cache size. Units are not visible here (presumably bytes per the
# DAAL docs — TODO confirm).
92 algorithm.parameter.cacheSize = 40000000
# Bind the training features and labels as the algorithm's inputs.
95 algorithm.input.set(classifier.training.data, trainData)
96 algorithm.input.set(classifier.training.labels, trainLabelsDataSource.getNumericTable())
# Run training. The prediction code later reads the trained model out of
# this result via trainingResult.get(classifier.training.model).
99 trainingResult = algorithm.compute()
# NOTE(review): damaged excerpt of the SVM prediction step. Leaked listing
# numbers prefix each line. The `global` statement below implies an
# enclosing function whose header was dropped (gap before 103). As shown,
# getResult() is called without any compute() on this prediction algorithm.
# Verify that the lines dropped in the gap after 116 contained that call.
103 global predictionResult
# Test features are loaded as a sparse table via the utils helper.
106 testData = createSparseTable(testDatasetFileName)
# Batch-mode SVM predictor, configured with the same module-level kernel
# object that the training step uses.
109 algorithm = prediction.Batch()
111 algorithm.parameter.kernel = kernel
# Inputs: the test data, plus the model taken from the training result.
114 algorithm.input.setTable(classifier.prediction.data, testData)
116 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Publish the result through the module-level predictionResult.
122 predictionResult = algorithm.getResult()
# NOTE(review): damaged excerpt of the results-printing step. Leaked listing
# numbers prefix each line, and the FileDataSource( call below is never
# closed. The gap after 134 dropped the start of the call whose arguments
# follow (presumably printNumericTables, imported from utils), and that
# call is never closed either. This is not valid Python as shown.
#
# Data source over the test-labels CSV, with the same flags as the
# training-labels source.
128 testLabelsDataSource = FileDataSource(
129 testLabelsFileName, DataSourceIface.doAllocateNumericTable,
130 DataSourceIface.doDictionaryFromContext
133 testLabelsDataSource.loadDataBlock()
134 testGroundTruth = testLabelsDataSource.getNumericTable()
# Arguments: the ground-truth labels table, the predicted labels table, what
# are presumably two column headers and a title, the row count to print
# (20, matching the title text), and flt64=False (presumably selects float32
# formatting — confirm in utils).
137 testGroundTruth, predictionResult.get(classifier.prediction.prediction),
138 "Ground truth\t",
"Classification results",
139 "SVM classification results (first 20 observations):", 20, flt64=
False
# NOTE(review): the extraction split this entry-point guard across two lines,
# and its body is missing because the excerpt ends here. This is not valid
# Python as shown; restore it from the upstream example.
142 if __name__ ==
"__main__":