/*******************************************************************************
* Copyright (C) 2021 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of DPCPP API oneapi::mkl::blas::gemm
*       using unified shared memory to perform General
*       Matrix-Matrix Multiplication on a single SYCL device or on multiple
*       SYCL devices using explicit scaling.
*
*       C = alpha * op(A) * op(B) + beta * C
*
*       where op() is defined by one of oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*
*       The supported floating point data types for gemm matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <memory>
#include <stdexcept>
#include <vector>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/blas.hpp"
#include "mkl.h"

// local includes
#include "common_for_examples.hpp"


//
// Main example for Gemm consisting of
// initialization of A, B and C matrices as well as
// scalars alpha and beta.  Then the product
//
// C = alpha * op(A) * op(B) + beta * C
//
// is performed and finally the results are post processed.
//
// The work is distributed over up to two devices taken from device_in's
// platform; device_in itself is always the first device used. Every device
// receives the full A matrix and a contiguous range of columns of B and C.
// When n does not divide evenly, the last device takes the extra column(s).
//
// All USM allocations are owned by std::unique_ptr objects whose deleter
// calls sycl::free, so memory is released on every exit path.
//
template <typename fp>
void run_gemm_example(const sycl::device &device_in) {

    //
    // Initialize data for Gemm
    //
    // C = alpha * op(A) * op(B)  + beta * C
    //

    oneapi::mkl::transpose transA = oneapi::mkl::transpose::trans;
    // B must stay non-transposed: it is partitioned by columns further down.
    oneapi::mkl::transpose transB = oneapi::mkl::transpose::nontrans;

    // Matrix data sizes. n does not need to be even; uneven splits are handled.
    // NOTE: n is assumed to be at least the number of devices used (at most 2).
    int m = 45;
    int n = 98;
    int k = 67;

    // Leading dimensions of data
    int ldA = 103;
    int ldB = 105;
    int ldC = 106;

    // Choose scalar values.
    fp alpha = set_fp_value(fp(2.0), fp(-0.5));
    fp beta  = set_fp_value(fp(3.0), fp(-1.5));

    // We start with a single device (device_in). Find other devices on that platform.
    // device_in is moved to the front so that the limit below can never drop it.
    auto devices = device_in.get_platform().get_devices();
    auto self = std::find(devices.begin(), devices.end(), device_in);
    if (self == devices.end())
        devices.insert(devices.begin(), device_in);
    else
        std::rotate(devices.begin(), self, self + 1);

    // For this example, limit to 2 devices.
    if (devices.size() >= 2)
        devices.resize(2);

    const int device_count = static_cast<int>(devices.size());

    std::cout << "\t\tUsing " << device_count << " device(s).\n";

    // Create context and execution queue(s).
    sycl::context ctx(devices);

    std::vector<sycl::queue> queues;
    for (auto &dev: devices)
        queues.emplace_back(ctx, dev);

    // Every USM allocation below (host or device) belongs to ctx and is freed through it.
    // The deleter is only invoked for non-null pointers. ctx and queues are declared
    // before any usm_ptr, so they outlive all of the allocations.
    auto usm_free = [&ctx](fp *ptr) { sycl::free(ptr, ctx); };
    using usm_ptr = std::unique_ptr<fp, decltype(usm_free)>;

    // Prepare data buffers on host. Use host USM memory for efficient transfer to device(s).
    int sizea = (transA == oneapi::mkl::transpose::nontrans) ? ldA * k : ldA * m;
    int sizeb = ldB * n;
    int sizec = ldC * n;

    usm_ptr A_host(sycl::malloc_host<fp>(sizea, ctx), usm_free);
    usm_ptr B_host(sycl::malloc_host<fp>(sizeb, ctx), usm_free);
    usm_ptr C_host(sycl::malloc_host<fp>(sizec, ctx), usm_free);

    // Any allocation that did succeed is released by its usm_ptr when we throw.
    if (!A_host || !B_host || !C_host)
        throw std::runtime_error("Failed to allocate USM memory.");

    rand_matrix(A_host.get(), transA, m, k, ldA);
    rand_matrix(B_host.get(), transB, k, n, ldB);
    rand_matrix(C_host.get(), oneapi::mkl::transpose::nontrans, m, n, ldC);

    //
    // Split the GEMM between devices in the n direction: device d computes the
    // columns [col_first[d], col_first[d] + col_count[d]) of C.
    // The full A matrix is copied to every device; only the matching columns
    // of B and C are copied to each device.
    //
    std::vector<int> col_first(device_count);
    std::vector<int> col_count(device_count);
    const int cols_per_device = n / device_count;
    for (int d = 0; d < device_count; d++) {
        col_first[d] = d * cols_per_device;
        col_count[d] = (d == device_count - 1) ? n - col_first[d] : cols_per_device;
    }

    // Allocate all device memory before submitting any copy. A failure then
    // throws while no transfer is in flight, so freeing the buffers is safe.
    std::vector<usm_ptr> A_device;
    std::vector<usm_ptr> B_device;
    std::vector<usm_ptr> C_device;
    A_device.reserve(device_count);
    B_device.reserve(device_count);
    C_device.reserve(device_count);
    for (int d = 0; d < device_count; d++) {
        A_device.emplace_back(sycl::malloc_device<fp>(sizea, queues[d]), usm_free);
        B_device.emplace_back(sycl::malloc_device<fp>(ldB * col_count[d], queues[d]), usm_free);
        C_device.emplace_back(sycl::malloc_device<fp>(ldC * col_count[d], queues[d]), usm_free);
        if (!A_device[d] || !B_device[d] || !C_device[d])
            throw std::runtime_error("Failed to allocate USM memory.");
    }

    // Copy A/B/C from host to device(s).
    for (int d = 0; d < device_count; d++) {
        queues[d].copy(A_host.get(), A_device[d].get(), sizea);
        queues[d].copy(B_host.get() + ldB * col_first[d], B_device[d].get(), ldB * col_count[d]);
        queues[d].copy(C_host.get() + ldC * col_first[d], C_device[d].get(), ldC * col_count[d]);
    }

    // Wait for copies to complete.
    for (auto &q: queues) q.wait();

    //
    // Execute GEMM
    //

    try {
        // Each device multiplies the full op(A) by its own column block of B.
        for (int d = 0; d < device_count; d++)
            oneapi::mkl::blas::gemm(queues[d], transA, transB, m, col_count[d], k, alpha,
                                    A_device[d].get(), ldA, B_device[d].get(), ldB,
                                    beta, C_device[d].get(), ldC);
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tCaught synchronous SYCL exception during GEMM:\n"
                  << e.what() << std::endl;
    }

    // Wait for GEMM call(s) to complete.
    for (auto &q: queues) q.wait();

    //
    // Copy C from devices to host
    //

    for (int d = 0; d < device_count; d++)
        queues[d].copy(C_device[d].get(), C_host.get() + ldC * col_first[d], ldC * col_count[d]);

    // Wait for copies to complete.
    for (auto &q: queues) q.wait();

    //
    // Post Processing
    //

    std::cout << "\n\t\tGEMM parameters:\n";
    std::cout << "\t\t\ttransA = " << transA << ", transB = " << transB << std::endl;
    std::cout << "\t\t\tm = " << m << ", n = " << n << ", k = " << k << std::endl;
    std::cout << "\t\t\tlda = " << ldA << ", ldB = " << ldB << ", ldC = " << ldC << std::endl;
    std::cout << "\t\t\talpha = " << alpha << ", beta = " << beta << std::endl;


    std::cout << "\n\t\tUpper left 2x2 blocks of A, B, C matrices:" << std::endl;

    // Output the top 2x2 block of A/B/C.
    print_2x2_matrix_values(A_host.get(), ldA, "A");
    print_2x2_matrix_values(B_host.get(), ldB, "B");
    print_2x2_matrix_values(C_host.get(), ldC, "C");

    // All host and device USM memory is released by the usm_ptr destructors
    // when this function returns; every queue has been waited on above.
}

//
// Prints the introductory banner for this example: the operation computed,
// the oneMKL API it exercises, and the floating point types it runs with.
// The banner is framed by an empty line before and after it.
//
void print_example_banner() {

    // One entry per output line; each is terminated with std::endl when printed.
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# General Matrix-Matrix Multiplication using Multiple Devices ",
        "# ",
        "# C = alpha * A * B + beta * C",
        "# ",
        "# where A, B and C are general dense matrices and alpha, beta are",
        "# floating point type precision scalars.",
        "# ",
        "# Using apis:",
        "#   gemm",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };

    for (const char *line : banner_lines)
        std::cout << line << std::endl;
}


//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each device selected and each data type supported, Gemm Example
//  runs with all supported data types
//
int main (int argc, char ** argv) {


    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto dev_type : list_of_devices) {

        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, dev_type);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[dev_type] << ".\n";

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            run_gemm_example<float>(my_dev);

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                run_gemm_example<double>(my_dev);
            }

            std::cout << "\tRunning with single precision complex data type:" << std::endl;
            run_gemm_example<std::complex<float>>(my_dev);

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                std::cout << "\tRunning with double precision complex data type:" << std::endl;
                run_gemm_example<std::complex<double>>(my_dev);
            }
        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[dev_type] << " devices found; Fail on missing devices is enabled.\n";
                return 1;
#else
            std::cout << "No " << sycl_device_names[dev_type] << " devices found; skipping " << sycl_device_names[dev_type] << " tests.\n";
#endif
        }


    }

    mkl_free_buffers();
    return 0;

}
