#include <immintrin.h>
#include <cblas.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <omp.h>

/* Width in bytes of one 256-bit AVX register; also used as the allocation alignment. */
const size_t simd_width_bytes = 32;

/*
 * Stream through a large scratch buffer (10 Mi doubles = 80 MiB), intended to
 * be larger than the CPU caches, so the next timed kernel starts cold.
 *
 * The buffer is allocated on the first call and reused by later calls; it is
 * deliberately never freed. Every element is written, then summed through a
 * volatile accumulator so the traversal cannot be optimised away.
 *
 * Returns the sum of the buffer contents (0 + 1 + ... + n-1); callers may
 * ignore it. Exits the program if the buffer cannot be allocated, because
 * timings taken without a flush would be misleading.
 */
static double flush_cache(void)
{
    const size_t n = 10 * 1024 * 1024;
    static double *junk = NULL;

    if (!junk) {
        junk = malloc(n * sizeof *junk);
        if (!junk) {
            fprintf(stderr, "flush_cache: cannot allocate %zu bytes\n", n * sizeof *junk);
            exit(EXIT_FAILURE);
        }
    }

    volatile double tmp = 0;
    for (size_t i = 0; i < n; ++i)
        junk[i] = (double)i;
    for (size_t i = 0; i < n; ++i)
        tmp += junk[i];
    return tmp;
}

/*
 * Fill the benchmark operands.
 *
 * A is row-major with m rows of n columns and a row stride of lda doubles
 * (lda >= n), so it must hold at least m * lda elements. Every in-range entry
 * is set to 1 and the padding columns n..lda-1 of each row are zeroed.
 * x holds n elements and is set to 2, so every entry of A*x equals 2*n.
 */
void initialize(int m, int n, int lda, double* A, double* x)
{
    /* Zero the whole matrix, padding included; it is m rows of lda doubles. */
    if (m > 0 && lda > 0)
        memset(A, 0, (size_t)m * (size_t)lda * sizeof(double));

    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            A[(size_t)i * lda + j] = 1;

    /* x is multiplied by A's columns, so it has n entries (not m). */
    for (int j = 0; j < n; ++j)
        x[j] = 2;
}

/*
 * Return the mean of the n values in y (n is expected to be positive).
 *
 * The mean is stored in a volatile local before it is returned. A volatile
 * store must actually be performed, so the summation over y is kept even when
 * the caller discards the result.
 */
double check(int n, const double* y)
{
    double total = 0.0;
    int k = 0;
    while (k < n)
        total += y[k++];

    volatile double mean = total / n;
    return mean;
}

/*
 * Reference kernel: y = A * x for a row-major m x n matrix A with row stride lda.
 * y (m entries) is cleared first, then each row's dot product is accumulated
 * directly into y in column order.
 */
void dgemv_naive(int m, int n, int lda, const double* A, const double* x, double* y)
{
    int row, col;

    for (row = 0; row < m; row++)
        y[row] = 0.0;

    for (row = 0; row < m; row++) {
        const double *a_row = &A[row * lda];
        for (col = 0; col < n; col++)
            y[row] += a_row[col] * x[col];
    }
}

/*
 * y = A * x for a row-major m x n matrix A with row stride lda, computing two
 * rows per pass so each x[j] load is shared by both rows.
 *
 * Any m is supported: when m is odd, the last row is handled separately
 * instead of reading (and writing) one row past the end.
 */
void dgemv_block2(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const int m_even = m / 2 * 2;

    for (int i = 0; i < m_even; i += 2) {
        double y0 = 0, y1 = 0;
        for (int j = 0; j < n; ++j) {
            const double xj = x[j];
            y0 += A[(i + 0) * lda + j] * xj;
            y1 += A[(i + 1) * lda + j] * xj;
        }
        y[i + 0] = y0;
        y[i + 1] = y1;
    }

    /* Leftover last row when m is odd. */
    for (int i = m_even; i < m; ++i) {
        double yi = 0;
        for (int j = 0; j < n; ++j)
            yi += A[i * lda + j] * x[j];
        y[i] = yi;
    }
}

/*
 * y = A * x for a row-major m x n matrix A with row stride lda, computing four
 * rows per pass so each x[j] load is shared by four independent accumulators.
 *
 * Any m is supported: the last m % 4 rows are handled one at a time instead of
 * reading (and writing) past the end of A and y.
 */
void dgemv_block4(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const int m4 = m / 4 * 4;

    for (int i = 0; i < m4; i += 4) {
        double y0 = 0, y1 = 0, y2 = 0, y3 = 0;
        for (int j = 0; j < n; ++j) {
            const double xj = x[j];
            y0 += A[(i + 0) * lda + j] * xj;
            y1 += A[(i + 1) * lda + j] * xj;
            y2 += A[(i + 2) * lda + j] * xj;
            y3 += A[(i + 3) * lda + j] * xj;
        }
        y[i + 0] = y0;
        y[i + 1] = y1;
        y[i + 2] = y2;
        y[i + 3] = y3;
    }

    /* Leftover rows when m is not a multiple of 4. */
    for (int i = m4; i < m; ++i) {
        double yi = 0;
        for (int j = 0; j < n; ++j)
            yi += A[i * lda + j] * x[j];
        y[i] = yi;
    }
}

/*
 * OpenMP version of the two-row kernel: y = A * x for a row-major m x n matrix
 * A with row stride lda. Pairs of rows are distributed across threads; each
 * iteration writes only its own two entries of y.
 *
 * Any m is supported: when m is odd, the last row is handled after the
 * parallel loop instead of reading (and writing) one row past the end.
 */
void dgemv_block_omp(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const int m_even = m / 2 * 2;

    /* Variables declared inside the loop body are private to each thread. */
#pragma omp parallel for
    for (int i = 0; i < m_even; i += 2) {
        double y0 = 0, y1 = 0;
        for (int j = 0; j < n; ++j) {
            const double xj = x[j];
            y0 += A[(i + 0) * lda + j] * xj;
            y1 += A[(i + 1) * lda + j] * xj;
        }
        y[i + 0] = y0;
        y[i + 1] = y1;
    }

    /* Leftover last row when m is odd (a single dot product, done serially). */
    for (int i = m_even; i < m; ++i) {
        double yi = 0;
        for (int j = 0; j < n; ++j)
            yi += A[i * lda + j] * x[j];
        y[i] = yi;
    }
}

/*
 * AVX/FMA kernel: y = A * x for a row-major m x n matrix A with row stride lda.
 * Each row is reduced with one 4-wide FMA accumulator, followed by a scalar
 * loop over the last n % 4 columns. The tail loop means the kernel never reads
 * past x[n-1] or past column n-1 of a row.
 *
 * Preconditions: _mm256_load_pd requires 32-byte-aligned addresses. A and x
 * must therefore be aligned to simd_width_bytes, and lda must be a multiple
 * of the SIMD width (simd_width_bytes / sizeof(double)).
 * The horizontal sum indexes the __m256d directly, which relies on the
 * GCC/Clang vector extension.
 */
void dgemv_simd(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const int simd_width = (int)(simd_width_bytes / sizeof(double));
    /* Columns [0, n_vec) are processed with full vectors. */
    const int n_vec = n / simd_width * simd_width;

    for (int i = 0; i < m; ++i) {
        const double *a_row = A + i * lda;
        __m256d acc = _mm256_setzero_pd();

        int j = 0;
        for (; j < n_vec; j += simd_width) {
            const __m256d aa = _mm256_load_pd(a_row + j);
            const __m256d xx = _mm256_load_pd(x + j);
            acc = _mm256_fmadd_pd(aa, xx, acc);
        }

        double sum = (acc[0] + acc[1]) + (acc[2] + acc[3]);
        for (; j < n; ++j)
            sum += a_row[j] * x[j];

        y[i] = sum;
    }
}

/*
 * AVX/FMA kernel with two-row blocking: y = A * x for a row-major m x n matrix
 * A with row stride lda. Each vector load of x feeds FMAs for two rows.
 *
 * Any m and n are supported. The last n % 4 columns are finished with scalar
 * code, and when m is odd the final row is processed alone. Neither A, x nor
 * y is accessed out of bounds.
 *
 * Preconditions: _mm256_load_pd requires 32-byte-aligned addresses. A and x
 * must therefore be aligned to simd_width_bytes, and lda must be a multiple
 * of the SIMD width (simd_width_bytes / sizeof(double)).
 * The horizontal sums index the __m256d directly, which relies on the
 * GCC/Clang vector extension.
 */
void dgemv_simd_block2(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const int simd_width = (int)(simd_width_bytes / sizeof(double));
    const int n_vec = n / simd_width * simd_width;
    const int m_even = m / 2 * 2;

    for (int i = 0; i < m_even; i += 2) {
        const double *a0 = A + (i + 0) * lda;
        const double *a1 = A + (i + 1) * lda;
        __m256d acc0 = _mm256_setzero_pd();
        __m256d acc1 = _mm256_setzero_pd();

        int j = 0;
        for (; j < n_vec; j += simd_width) {
            const __m256d xx = _mm256_load_pd(x + j);
            acc0 = _mm256_fmadd_pd(_mm256_load_pd(a0 + j), xx, acc0);
            acc1 = _mm256_fmadd_pd(_mm256_load_pd(a1 + j), xx, acc1);
        }

        double s0 = (acc0[0] + acc0[1]) + (acc0[2] + acc0[3]);
        double s1 = (acc1[0] + acc1[1]) + (acc1[2] + acc1[3]);
        for (; j < n; ++j) {
            s0 += a0[j] * x[j];
            s1 += a1[j] * x[j];
        }

        y[i + 0] = s0;
        y[i + 1] = s1;
    }

    /* Leftover last row when m is odd: same vector + scalar-tail scheme. */
    for (int i = m_even; i < m; ++i) {
        const double *a_row = A + i * lda;
        __m256d acc = _mm256_setzero_pd();

        int j = 0;
        for (; j < n_vec; j += simd_width)
            acc = _mm256_fmadd_pd(_mm256_load_pd(a_row + j), _mm256_load_pd(x + j), acc);

        double s = (acc[0] + acc[1]) + (acc[2] + acc[3]);
        for (; j < n; ++j)
            s += a_row[j] * x[j];

        y[i] = s;
    }
}

/*
 * y = A * x via CBLAS dgemv for a row-major m x n matrix A with row stride lda.
 * Uses alpha = 1 and beta = 0, so y's previous contents are overwritten.
 * x and y are contiguous (unit stride).
 */
void dgemv_blas(int m, int n, int lda, const double* A, const double* x, double* y)
{
    const double alpha = 1.0, beta = 0.0;
    /* The increments are int parameters; the old code passed 1.0 here. */
    const int inc_x = 1, inc_y = 1;

    cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, alpha, A, lda, x, inc_x, beta, y, inc_y);
}


/* Common signature of the dgemv_* kernels: y = A*x, A row-major m x n with row stride lda. */
typedef void (*gemv_function_t)(int m, int n, int lda,
                                const double *A, const double *x, double *y);

static void time_function(gemv_function_t gemv, char *name)
{
    char filename[256];
    sprintf(filename, "%s.dat", name);
    FILE* f = fopen(filename, "w");

    // fprintf(stderr, "profiling %s...\n", name);

    // for (int i = 100; i < 5000; i += 100) {
    // for (int i = 100; i <= 5000; i += 100) {
    for (int i = 1024; i <= 1024; i += 100) {
        int m = i, n = i, lda;
        double *A = NULL, *x = NULL, *y = NULL;
        int simd_width = simd_width_bytes / sizeof(double);
        lda = (m + simd_width - 1) / simd_width;
        lda *= simd_width;

        posix_memalign((void**)&A, simd_width_bytes, lda * n * sizeof(double));
        posix_memalign((void**)&x, simd_width_bytes, m * sizeof(double));
        posix_memalign((void**)&y, simd_width_bytes, n * sizeof(double));

        initialize(m, n, lda, A, x);

        const int nsteps = i < 1000 ? 100 : 20;

        double t = 0;
        double t0, t1;
        for (int step = 0; step < nsteps; ++step) {
            flush_cache();
            t0 = omp_get_wtime();
            gemv(m, n, lda, A, x, y);
            t1 = omp_get_wtime();
            t += (t1-t0);
            check(n, y);
        }

        double time = t / nsteps;
        long FLOP = n * m * 2;
        double FLOPS = FLOP / time;

        fprintf(f, "%d\t%g\t%g\n", i, time, FLOPS);
        fprintf(stdout, "%s [n=%d] time=%.3lf ms -> %.3lf GFLOPs\n", name, i, time*1e3, FLOPS/1e9);

        free(A);
        free(x);
        free(y);
    }
    fclose(f);
}

int main()
{
    time_function(dgemv_naive, "naive_double");
#if 1
    time_function(dgemv_block2, "block2_double");
    time_function(dgemv_block4, "block4_double");
    time_function(dgemv_block_omp, "block2+omp_double");
    time_function(dgemv_simd, "simd_double");
    time_function(dgemv_simd_block2, "simd+block2_double");
    time_function(dgemv_blas, "blas_double");
#endif

    return 0;
}
