/*******************************************************************************
* Copyright (C) 2022 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*       This example demonstrates use of oneapi::mkl::lapack::gesvda_batch
*       to perform batched calculation of truncated SVD.
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*******************************************************************************/

#include <cmath>
#include <exception>
#include <iostream>
#include <stdexcept>

#include <oneapi/mkl.hpp>
#include "common_for_examples.hpp"

/// Runs one batched truncated-SVD problem through
/// oneapi::mkl::lapack::gesvda_batch on `dev` and validates the residuals
/// reported by the library.
///
/// Two 5x5 matrices are stored back to back (stride_a apart, leading dimension
/// lda = m, column-major as oneMKL LAPACK routines expect), uploaded to device
/// USM memory, factored, and the per-matrix residuals are copied back and
/// compared against a precision-dependent threshold.
///
/// @tparam data_t  Matrix element type: float, double, std::complex<float> or
///                 std::complex<double>.
/// @tparam real_t  Real counterpart of data_t; used for singular values, the
///                 tolerance argument and the residuals.
/// @tparam is_real True when data_t is real. Complex runs use purely imaginary
///                 copies of the same matrix entries.
/// @param dev      SYCL device on which the computation runs.
/// @return 0 if every residual is finite and below the threshold; 1 otherwise,
///         including when a device allocation fails or a SYCL/oneMKL exception
///         is thrown. All device memory is released on every path.
template <typename data_t, typename real_t = decltype(std::real((data_t)0)), bool is_real = std::is_same_v<data_t,real_t>>
int run_gesvda_batch_example(sycl::device &dev)
{
    // Problem dimensions. Note that ldvt and stride_s are derived from m, so
    // these sizes rely on m == n.
    const int64_t m = 5, n = 5, lda = m, stride_a = n*lda, ldu = m, stride_u = m*ldu, batch_size = 2;
    const int64_t stride_s = m, ldvt = m, stride_vt = n*ldvt;
    const real_t tolerance = std::is_same_v<real_t, float> ? 1e-6 : 1e-8;
    // Acceptance bound for the residuals returned by gesvda_batch.
    const real_t threshold = std::is_same_v<real_t, float> ? 1e-5 : 1e-10;

    // Host-side control parameters, value-initialised to zero; only iparm[3]
    // is changed. See the gesvda_batch reference for the meaning of each entry.
    int64_t iparm[16] = {};
    iparm[3] = 1;

    // Per-matrix rank argument, initialised to n and copied to the device
    // before the call.
    int64_t irank[batch_size];
    for (auto &r : irank) {
        r = n;
    }
    real_t Residual[batch_size] = {};

    // Real runs use the values as-is; complex runs use purely imaginary entries.
    auto v = [] (real_t arg) { if constexpr (is_real) return arg; else return data_t{0, arg}; };

    data_t A[] = {
        v( 1.0), v( 0.0), v( 0.0), v( 0.0), v( 0.0),
        v( 1.0), v( 0.2), v(-0.4), v(-0.4), v(-0.8),
        v( 1.0), v( 0.6), v(-0.2), v( 0.4), v(-1.2),
        v( 1.0), v( 1.0), v(-1.0), v( 0.6), v(-0.8),
        v( 1.0), v( 1.8), v(-0.6), v( 0.2), v(-0.6)
                                                   ,
        v( 0.2), v(-0.4), v(-0.4), v(-0.8), v( 0.0),
        v( 0.4), v( 0.2), v( 0.8), v(-0.4), v( 0.0),
        v( 0.4), v(-0.8), v( 0.2), v( 0.4), v( 0.0),
        v( 0.8), v( 0.4), v(-0.4), v( 0.2), v( 0.0),
        v( 0.0), v( 0.0), v( 0.0), v( 0.0), v( 1.0)
    };

    sycl::queue que { dev };

    // All device pointers start as nullptr, so the single cleanup path at the
    // end is valid no matter where an error occurs (sycl::free ignores nullptr).
    data_t *A_dev = nullptr, *U_dev = nullptr, *Vt_dev = nullptr, *scratchpad = nullptr;
    real_t *S_dev = nullptr, *Residual_dev = nullptr;
    int64_t *irank_dev = nullptr;

    int returnval = 0;
    try {
        A_dev = sycl::aligned_alloc_device<data_t>(64, stride_a*batch_size, que);
        U_dev = sycl::aligned_alloc_device<data_t>(64, stride_u*batch_size, que);
        Vt_dev = sycl::aligned_alloc_device<data_t>(64, stride_vt*batch_size, que);
        S_dev = sycl::aligned_alloc_device<real_t>(64, stride_s*batch_size, que);
        Residual_dev = sycl::aligned_alloc_device<real_t>(64, batch_size, que);
        irank_dev = sycl::aligned_alloc_device<int64_t>(64, batch_size, que);
        if (!A_dev || !U_dev || !Vt_dev || !S_dev || !Residual_dev || !irank_dev) {
            throw std::runtime_error("device memory allocation failed");
        }

        que.copy(A, A_dev, stride_a*batch_size).wait();
        que.copy(irank, irank_dev, batch_size).wait();

        const int64_t scratchpad_size = oneapi::mkl::lapack::gesvda_batch_scratchpad_size<data_t>(que, m, n, lda, stride_a,
                                         stride_s, ldu, stride_u, ldvt, stride_vt,
                                         batch_size);

        scratchpad = sycl::aligned_alloc_device<data_t>(64, scratchpad_size, que);
        // A zero-sized request may legitimately yield nullptr.
        if (!scratchpad && scratchpad_size > 0) {
            throw std::runtime_error("scratchpad allocation failed");
        }

        oneapi::mkl::lapack::gesvda_batch(que, iparm, irank_dev, m, n, A_dev, lda, stride_a, S_dev, stride_s,
                                          U_dev, ldu, stride_u, Vt_dev, ldvt, stride_vt, tolerance, Residual_dev,
                                          batch_size, scratchpad, scratchpad_size).wait_and_throw();

        que.copy(Residual_dev, Residual, batch_size).wait();

        bool passed = true;
        for (int64_t i = 0; i < batch_size; i++) {
            const real_t result = Residual[i];
            std::cout << " Residual entry # " << i << " Value " << result << std::endl;
            // A NaN residual must be treated as failure.
            passed = passed && !std::isnan(result) && (result < threshold);
        }

        if (passed) {
            std::cout << " Calculations successfully finished " << std::endl;
        } else {
            std::cout << " Computed residual exceeds the tolerance threshold  " << std::endl;
            returnval = 1;
        }
    } catch (std::exception const &e) {
        // sycl::exception and oneMKL exceptions both derive from std::exception.
        std::cout << " Example failed with exception: " << e.what() << std::endl;
        returnval = 1;
    }

    // Ensure no submitted work still references the buffers before freeing.
    que.wait();
    sycl::free(scratchpad, que);
    sycl::free(irank_dev, que);
    sycl::free(Residual_dev, que);
    sycl::free(S_dev, que);
    sycl::free(Vt_dev, que);
    sycl::free(U_dev, que);
    sycl::free(A_dev, que);

    return returnval;
}

//
// Description of example setup, APIs used and supported floating point type precisions
//
void print_example_banner() {
    // Banner text, one entry per output line. Empty entries produce blank
    // lines; each line is terminated with std::endl.
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Batched strided truncated SVD example:",
        "# ",
        "# Computes truncated SVD of a batch of matrices.",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };

    for (const char *text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each selected device, the gesvda_batch example runs with every supported
//  data type (double precision types only when the device supports them).
//
int main(int argc, char **argv) {

    print_example_banner();

    // Device types selected at build time.
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    bool failed = false;

    for (auto &deviceType : listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        // Missing device: either abort or skip, depending on the build flag.
        if (!myDevIsFound) {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
            continue;
#endif
        }

        std::cout << std::endl << "Running gesvda_batch examples on " << sycl_device_names[deviceType] << "." << std::endl;

        // Prints `description`, then runs the example with the element type
        // carried by `tag` and accumulates any failure.
        auto run_with = [&](auto tag, const char *description) {
            using element_t = decltype(tag);
            std::cout << description << std::endl;
            failed |= run_gesvda_batch_example<element_t>(myDev);
        };

        run_with(float{}, "Running with single precision real data type:");
        run_with(std::complex<float>{}, "Running with single precision complex data type:");

        if (!isDoubleSupported(myDev)) {
            std::cout << "Double precision not supported on this device " << std::endl;
            std::cout << std::endl;
            continue;
        }

        run_with(double{}, "Running with double precision real data type:");
        run_with(std::complex<double>{}, "Running with double precision complex data type:");
    }
    return failed;
}
