#include <immintrin.h>
#include <cblas.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <omp.h>

#ifndef min
/* Smaller of two values. The whole expansion is parenthesized so it composes
   safely inside larger expressions; arguments may be evaluated twice, so do
   not pass expressions with side effects. */
#define min(a, b) ((a) < (b) ? (a) : (b))
#endif

// Width of one 256-bit AVX register in bytes (256 / 8). Used as the alignment
// for the benchmark buffers and to pad the matrix row stride.
const size_t simd_width_bytes = 32;

/*
 * Touch a large scratch buffer so that data from the previous benchmark run
 * is evicted from the caches before the next timed run.
 *
 * The buffer (10 Mi doubles = 80 MiB, presumably larger than the last-level
 * cache — confirm for the target machine) is allocated on first use and kept
 * for the life of the process.
 *
 * Returns the sum of the buffer contents (0.0 if the buffer could not be
 * allocated, in which case no flush happens).
 */
static double flush_cache(void)
{
  const size_t n = 10 * 1024 * 1024;
  static double* junk = NULL;
  if (!junk) {
    junk = malloc(n * sizeof *junk);
    if (!junk) {
      /* Flushing is best-effort: skip it instead of dereferencing NULL. */
      fprintf(stderr, "flush_cache: allocation failed, cache not flushed\n");
      return 0.0;
    }
  }

  /* Write, then read, every element. The volatile accumulator forces the
     read loop to actually be performed. */
  volatile double tmp = 0;
  for (size_t i = 0; i < n; ++i)
    junk[i] = (double)i;
  for (size_t i = 0; i < n; ++i)
    tmp += junk[i];
  return tmp;
}

/*
 * Fill the benchmark inputs.
 *
 * A is an m x n row-major matrix whose rows are lda elements apart
 * (lda >= n). The whole m*lda storage is cleared, so the padding columns
 * n..lda-1 are zero. The m x n entries are then set to 1.
 * x is the n-element vector that multiplies A's columns; every entry is set to 2.
 */
void initialize(int m, int n, int lda, double* A, double* x)
{
  /* Size computed in size_t: m*lda rows of storage (not lda*n), and no int overflow. */
  if (m > 0 && lda > 0)
    memset(A, 0, (size_t)m * (size_t)lda * sizeof(double));

  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      A[(size_t)i * lda + j] = 1;

  /* The kernels read x[j] for j < n, so x has n entries (not m). */
  for (int j = 0; j < n; ++j)
    x[j] = 2;
}

/*
 * Return the arithmetic mean of y[0..n-1].
 *
 * The result is stored through a volatile object, which forces the sum to be
 * computed even when the caller discards the return value. This keeps the
 * benchmarked output "used".
 */
double check(int n, const double* y)
{
  double total = 0.0;
  int k = 0;
  while (k < n)
    total += y[k++];

  volatile double mean = total / n;
  return mean;
}

/*
 * Serial reference kernel: y = A * x.
 * A is m x n, row-major, with consecutive rows lda elements apart; x has n
 * entries and y has m entries. y is zeroed first, then accumulated in place.
 */
void dgemv_naive(int m, int n, int lda, const double* A, const double* x, double* y)
{
  int r = 0;
  while (r < m)
    y[r++] = 0;

  for (r = 0; r < m; ++r) {
    const double* row = &A[r * lda];
    for (int c = 0; c < n; ++c)
      y[r] += row[c] * x[c];
  }
}

/*
 * y = A * x with the rows distributed across OpenMP threads.
 * Each row's dot product is independent and writes only y[row], so no
 * synchronization is needed. Layout is the same as dgemv_naive
 * (row-major, row stride lda).
 */
void dgemv_omp1(int m, int n, int lda, const double* A, const double* x, double* y)
{
  #pragma omp parallel for
  for (int row = 0; row < m; row++) {
    const double* a = A + row * lda;
    y[row] = 0;
    for (int col = 0; col < n; col++)
      y[row] += a[col] * x[col];
  }
}

/*
 * y = A * x, parallelizing each row's dot product instead of the rows.
 * For every row, the column loop is split across threads and the partial
 * products are combined with an OpenMP "+" reduction. This opens one
 * parallel region per row.
 */
void dgemv_omp2(int m, int n, int lda, const double* A, const double* x, double* y)
{
  for (int row = 0; row < m; row++) {
    const double* a = &A[row * lda];
    double dot = 0.0;
    #pragma omp parallel for reduction(+:dot)
    for (int col = 0; col < n; col++)
      dot += a[col] * x[col];
    y[row] = dot;
  }
}

/*
 * y = A * x with the columns distributed across OpenMP threads.
 *
 * Each thread accumulates the contribution of its columns into a private
 * m-element buffer. The buffers are then added into y one thread at a time
 * inside a critical section. Layout is the same as dgemv_naive (row-major,
 * row stride lda). Aborts the process if a per-thread buffer cannot be
 * allocated.
 */
void dgemv_omp3(int m, int n, int lda, const double* A, const double* x, double* y)
{
  /* No output rows: nothing to do. Also avoids a zero/negative buffer size. */
  if (m <= 0)
    return;

  memset(y, 0, (size_t)m * sizeof(double));
  #pragma omp parallel
  {
    /* Heap, not a VLA: a VLA would be UB for m == 0 and can overflow the
       thread stack for large m. calloc also zero-initializes. */
    double* z = calloc((size_t)m, sizeof(double));
    if (!z) {
      fprintf(stderr, "dgemv_omp3: out of memory\n");
      abort();
    }

    #pragma omp for
    for (int j = 0; j < n; ++j)
      for (int i = 0; i < m; ++i)
        z[i] += A[i * lda + j] * x[j];

    /* Merge this thread's partial sums into the shared result. */
    #pragma omp critical
    {
      for (int i = 0; i < m; ++i)
        y[i] += z[i];
    }

    free(z);
  }
}

/*
 * y = A * x computed one column at a time.
 * For each column (serially), every row does y[row] += A[row][col] * x[col].
 * The rows of that update are split across OpenMP threads, so a new parallel
 * region is opened per column. Layout is the same as dgemv_naive.
 */
void dgemv_omp4(int m, int n, int lda, const double* A, const double* x, double* y)
{
  for (int row = 0; row < m; row++)
    y[row] = 0;

  for (int col = 0; col < n; col++) {
    #pragma omp parallel for
    for (int row = 0; row < m; row++)
      y[row] += A[row * lda + col] * x[col];
  }
}

/*
 * y = A * x by splitting the rows of A into one contiguous block per OpenMP
 * thread. Each block is handed to a single-threaded cblas_dgemv call.
 * A is m x n, row-major, with row stride lda.
 */
void dgemv_omp5(int m, int n, int lda, const double* A, const double* x, double* y)
{
  /* Parallelism comes from OpenMP here, so keep each BLAS call single-threaded.
     NOTE(review): this looks like a process-wide OpenBLAS setting that
     outlives this call. dgemv_blas resets it explicitly. */
  openblas_set_num_threads(1);
  #pragma omp parallel
  {
    const int tid = omp_get_thread_num();
    const int nthreads = omp_get_num_threads();
    const int chunk = (m + nthreads - 1) / nthreads;
    const int row0 = tid * chunk;
    int rows = m - row0;
    if (rows > chunk)
      rows = chunk;

    /* Trailing threads may get no rows when m does not divide evenly.
       A row block starts row0 * lda elements into A (row-major), not row0. */
    if (rows > 0)
      cblas_dgemv(CblasRowMajor, CblasNoTrans, rows, n, 1.0,
                  &A[(size_t)row0 * lda], lda, x, 1, 0.0, &y[row0], 1);
  }
}

/*
 * y = A * x with a single cblas_dgemv call over the whole matrix.
 * OpenBLAS is set to 4 threads and parallelizes internally.
 * A is m x n, row-major, with row stride lda.
 */
void dgemv_blas(int m, int n, int lda, const double* A, const double* x, double* y)
{
  openblas_set_num_threads(4);
  /* y = 1.0 * A * x + 0.0 * y, with unit strides for x (incX) and y (incY). */
  cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0, A, lda, x, 1, 0.0, y, 1);
}


/* Common signature of the dgemv_* kernels: compute y (m entries) from the
   m x n row-major matrix A (row stride lda) and the n-entry vector x. */
typedef void (*gemv_function_t)(int m, int n, int lda, const double* A, const double* x, double* y);

static void time_function(gemv_function_t gemv, char *name)
{
  char filename[256];
  sprintf(filename, "%s.dat", name);
  FILE* f = fopen(filename, "w");

  fprintf(stderr, "profiling %s...\n", name);

  for (int i = 1024; i <= 1024; i += 100) {
    int m = i, n = i, lda;
    double *A = NULL, *x = NULL, *y = NULL;
    int simd_width = simd_width_bytes / sizeof(double);
    lda = (m + simd_width - 1) / simd_width;
    lda *= simd_width;

    posix_memalign((void**)&A, simd_width_bytes, lda * n * sizeof(double));
    posix_memalign((void**)&x, simd_width_bytes, m * sizeof(double));
    posix_memalign((void**)&y, simd_width_bytes, n * sizeof(double));

    initialize(m, n, lda, A, x);

    const int nsteps = i < 1000 ? 100 : 20;

    double t = 0;
    double t0, t1;
    for (int step = 0; step < nsteps; ++step) {
      flush_cache();
      t0 = omp_get_wtime();
      gemv(m, n, lda, A, x, y);
      t1 = omp_get_wtime();
      t += (t1-t0);
      check(n, y);
    }

    double time = t / nsteps;
    long FLOP = n * m * 2;
    double FLOPS = FLOP / time;

    fprintf(f, "%d\t%g\t%g\n", i, time, FLOPS);
    fprintf(stdout, "%s [n=%d] time=%.3lf ms -> %.3lf GFLOPs\n", name, i, time*1e3, FLOPS/1e9);

    free(A);
    free(x);
    free(y);
  }
  fclose(f);
}

int main()
{
  time_function(dgemv_naive, "dgemv_naive");
  time_function(dgemv_omp1, "dgemv_omp1");
  time_function(dgemv_omp2, "dgemv_omp2");
  time_function(dgemv_omp3, "dgemv_omp3");
  time_function(dgemv_omp4, "dgemv_omp4");
  time_function(dgemv_omp5, "dgemv_omp5");
  time_function(dgemv_omp5, "dgemv_blas");

  return 0;
}
