/*******************************************************************************
* Copyright 1999-2018 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
!  Content:
!      C B L A S _ Z G E M M  Example Program Text ( C Interface )
!******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "mkl.h"

#include "mkl_example.h"

/*
 * Example driver for cblas_zgemm.
 *
 * Usage: <program> in_file
 *
 * The data file named by argv[1] is read in this order:
 *   1. integers m, n, k
 *   2. complex scalars alpha, beta
 *   3. transA, transB, layout
 *   4. matrix A, matrix B, matrix C
 *
 * The inputs are printed, cblas_zgemm is called once, and the resulting
 * matrix C is printed.
 *
 * Returns 0 on success and 1 on any failure (missing argument, file open
 * error, malformed input, negative dimension, or allocation failure).
 * Every exit path after the file has been opened goes through the single
 * cleanup section at the end, so the file handle and the matrix buffers
 * are always released.
 */
int main(int argc, char *argv[])
{
      FILE *in_file = NULL;
      char *in_file_name;
      int   status = 1;   /* stays 1 until the whole run has succeeded */

      MKL_INT         m, n, k;
      MKL_INT         lda, ldb, ldc;
      MKL_INT         rmaxa, cmaxa, rmaxb, cmaxb, rmaxc, cmaxc;
      MKL_Complex16   alpha, beta;
      MKL_Complex16  *a = NULL, *b = NULL, *c = NULL;
      CBLAS_LAYOUT    layout;
      CBLAS_TRANSPOSE transA, transB;
      MKL_INT         ma, na, mb, nb;

      printf("\n     C B L A S _ Z G E M M  EXAMPLE PROGRAM\n");

/*       Get input parameters                                  */

      /* argc < 2 (rather than == 1) also covers argc == 0, where argv[1]
         must not be read. */
      if( argc < 2 ) {
          printf("\n You must specify in_file data file as 1-st parameter");
          return 1;
      }
      in_file_name = argv[1];

/*       Get input data                                        */

      if( (in_file = fopen( in_file_name, "r" )) == NULL ) {
         printf("\n ERROR on OPEN '%s' with mode=\"r\"\n", in_file_name);
         return 1;
      }
      if( GetIntegerParameters(in_file, &m, &n, &k) != 3 ) {
          printf("\n ERROR of m, n, k reading\n");
          goto cleanup;
      }
      /* The dimensions come from the input file; reject negative values
         before they are used to compute allocation sizes. */
      if( m < 0 || n < 0 || k < 0 ) {
          printf("\n ERROR: m, n, k must be non-negative\n");
          goto cleanup;
      }
      if( GetScalarsZ(in_file, &alpha, &beta) != 2 ) {
          printf("\n ERROR of alpha, beta reading\n");
          goto cleanup;
      }
      if( GetCblasCharParameters(in_file, &transA, &transB, &layout) != 3 ) {
          printf("\n ERROR of transA, transB, layout reading\n");
          goto cleanup;
      }

      /* (ma x na) and (mb x nb) are the shapes of A and B as read from the
         file; they depend on whether the operand is transposed.  Each rmax*
         is one larger than the corresponding row count, so when rmax* is
         used as the leading dimension (the non-row-major case below) it
         exceeds the number of rows. */
      if( transA == CblasNoTrans ) {
          rmaxa = m + 1;
          cmaxa = k;
          ma    = m;
          na    = k;
      } else {
          rmaxa = k + 1;
          cmaxa = m;
          ma    = k;
          na    = m;
      }
      if( transB == CblasNoTrans ) {
          rmaxb = k + 1;
          cmaxb = n;
          mb    = k;
          nb    = n;
      } else {
          rmaxb = n + 1;
          cmaxb = k;
          mb    = n;
          nb    = k;
      }
      rmaxc = m + 1;
      cmaxc = n;

      /* Element counts are formed in size_t so the product is not computed
         in (possibly narrower, signed) MKL_INT arithmetic. */
      a = (MKL_Complex16 *)mkl_calloc((size_t)rmaxa * (size_t)cmaxa,
                                      sizeof(MKL_Complex16), 64);
      b = (MKL_Complex16 *)mkl_calloc((size_t)rmaxb * (size_t)cmaxb,
                                      sizeof(MKL_Complex16), 64);
      c = (MKL_Complex16 *)mkl_calloc((size_t)rmaxc * (size_t)cmaxc,
                                      sizeof(MKL_Complex16), 64);
      if ( a == NULL || b == NULL || c == NULL ) {
           printf("\n Can't allocate memory arrays");
           goto cleanup;
      }

      /* Leading dimension: column count for row-major storage, padded row
         count otherwise. */
      if( layout == CblasRowMajor ) {
          lda=cmaxa;
          ldb=cmaxb;
          ldc=cmaxc;
      } else {
          lda=rmaxa;
          ldb=rmaxb;
          ldc=rmaxc;
      }
      if( GetArrayZ(in_file, &layout, GENERAL_MATRIX, &ma, &na, a, &lda) != 0 ) {
        printf("\n ERROR of array A reading\n");
        goto cleanup;
      }
      if( GetArrayZ(in_file, &layout, GENERAL_MATRIX, &mb, &nb, b, &ldb) != 0 ) {
        printf("\n ERROR of array B reading\n");
        goto cleanup;
      }
      if( GetArrayZ(in_file, &layout, GENERAL_MATRIX, &m,  &n,  c, &ldc) != 0 ) {
        printf("\n ERROR of array C reading\n");
        goto cleanup;
      }

      /* All input has been consumed; close now and clear the handle so the
         cleanup section does not close it a second time. */
      fclose(in_file);
      in_file = NULL;

/*       Print input data                                      */

      printf("\n     INPUT DATA");
      printf("\n       M="INT_FORMAT"  N="INT_FORMAT"  K="INT_FORMAT, m, n, k);
      printf("\n       ALPHA =(%5.1f,%5.1f )  BETA =(%5.1f,%5.1f )",
              alpha.real, alpha.imag, beta.real, beta.imag);
      PrintParameters("TRANSA, TRANSB", transA, transB);
      PrintParameters("LAYOUT", layout);
      PrintArrayZ(&layout, FULLPRINT, GENERAL_MATRIX, &ma, &na, a, &lda, "A");
      PrintArrayZ(&layout, FULLPRINT, GENERAL_MATRIX, &mb, &nb, b, &ldb, "B");
      PrintArrayZ(&layout, FULLPRINT, GENERAL_MATRIX, &m, &n, c, &ldc, "C");

/*      Call ZGEMM subroutine ( C Interface )                  */

      cblas_zgemm(layout, transA, transB, m, n, k, &alpha,
                  a, lda, b, ldb, &beta, c, ldc);

/*       Print output data                                     */

      printf("\n\n     OUTPUT DATA");
      PrintArrayZ(&layout, FULLPRINT, GENERAL_MATRIX, &m, &n, c, &ldc, "C");

      status = 0;

/*       Release resources (shared by success and error paths) */

cleanup:
      if( in_file != NULL ) {
          fclose(in_file);
      }
      /* Buffers are NULL if allocation was never reached or failed. */
      if( a != NULL ) {
          mkl_free(a);
      }
      if( b != NULL ) {
          mkl_free(b);
      }
      if( c != NULL ) {
          mkl_free(c);
      }

      return status;
}