/*******************************************************************************
* Copyright (C) 2020 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of the oneapi::mkl::rng::device::mcg59
*       random number generator to produce random numbers
*       using a uniform distribution on a SYCL device (CPU, GPU).
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <vector>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/rng/device.hpp"

// local includes
#include "common_for_examples.hpp"

// example parameters
constexpr std::uint64_t seed = 777;
constexpr std::size_t n = 1024;
constexpr int n_print = 10;

//
// Generates n uniformly distributed random numbers of type Type on the device
// behind `queue` using the mcg59 engine of the rng device API, prints the first
// n_print of them, and then verifies every value against a sequential run of
// the same engine on the host.
//
// Template parameters:
//   Type    - element type of the generated numbers
//   VecSize - amount of numbers produced by one generate() call in the kernel;
//             1 selects the scalar path, any other value the vector path
//
// Returns 0 when all device values match the host values, 1 when a SYCL
// exception was caught while running the kernel, and otherwise the number of
// mismatching elements.
//
template <typename Type, std::int32_t VecSize>
int run_example(sycl::queue& queue) {
    // The kernel is launched over n / VecSize work-items, each producing VecSize
    // numbers; a vector size that does not divide n would leave the tail of the
    // output untouched and make the host comparison below fail.
    static_assert(VecSize > 0, "VecSize must be positive");
    static_assert(n % static_cast<std::size_t>(VecSize) == 0,
                  "n must be a multiple of VecSize");
    // The print loop below reads the first n_print elements of an n-element buffer.
    static_assert(n_print >= 0 && static_cast<std::size_t>(n_print) <= n,
                  "n_print must not exceed n");

    // VecSize is a compile-time constant, so select the message at compile time
    // (same construct the kernel uses for the store path)
    if constexpr (VecSize == 1) {
        std::cout << "\tRunning scalar example" << std::endl;
    }
    else {
        std::cout << "\tRunning vector example with " << VecSize << " vector size" << std::endl;
    }
    // prepare arrays for random numbers: device results and host reference
    std::vector<Type> r_dev(n), r_host(n);

    // submit a kernel to generate on device
    {
        // the buffer wraps r_dev's storage; results are read from r_dev itself
        // only after this scope (and with it the buffer) has ended
        sycl::buffer<Type> r_buf(r_dev.data(), r_dev.size());

        try {
            queue.submit([&](sycl::handler& cgh) {
                sycl::accessor r_acc(r_buf, cgh, sycl::write_only);
                cgh.parallel_for(sycl::range<1>(n / VecSize), [=](sycl::item<1> item) {
                    const std::size_t item_id = item.get_id(0);
                    // every work-item builds its own engine from the common seed;
                    // the second argument is item_id * VecSize, i.e. the index of
                    // the first output element this work-item is responsible for
                    oneapi::mkl::rng::device::mcg59<VecSize> engine(seed, item_id * VecSize);
                    oneapi::mkl::rng::device::uniform<Type> distr;

                    auto res = oneapi::mkl::rng::device::generate(distr, engine);
                    if constexpr (VecSize == 1) {
                        // scalar result: one element per work-item
                        r_acc[item_id] = res;
                    }
                    else {
                        // vector result: written through the accessor at offset item_id
                        res.store(item_id, r_acc);
                    }
                });
            });
            // block until the kernel is done and surface asynchronous errors
            queue.wait_and_throw();
        }
        catch (sycl::exception const& e) {
            std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
            return 1;
        }

        std::cout << "\t\tOutput of generator:" << std::endl;

        auto r_acc = sycl::host_accessor(r_buf, sycl::read_only);
        std::cout << "first " << n_print << " numbers of " << n << ": " << std::endl;
        for (int i = 0; i < n_print; i++) {
            std::cout << r_acc[i] << " ";
        }
        std::cout << std::endl;
    } // buffer life-time ends

    // compare results with host-side generation: a single scalar engine with
    // the same seed produces the reference values one by one
    oneapi::mkl::rng::device::mcg59<1> engine(seed);
    oneapi::mkl::rng::device::uniform<Type> distr;

    int err = 0;
    for (std::size_t i = 0; i < n; i++) {
        r_host[i] = oneapi::mkl::rng::device::generate(distr, engine);
        // exact comparison is deliberate: any difference between the device and
        // the host value is counted as an error
        if (r_host[i] != r_dev[i]) {
            std::cout << "error in " << i << " element " << r_host[i] << " " << r_dev[i]
                      << std::endl;
            err++;
        }
    }
    return err;
}

//
// description of example setup, APIs used
//
void print_example_banner() {
    // horizontal rule framing the banner text
    const char* const rule =
        "########################################################################";

    // banner content, one entry per output line (empty entries are blank lines)
    const char* const banner_lines[] = {
        "",
        rule,
        "# Generate uniformly distributed random numbers example: ",
        "# ",
        "# Using APIs:",
        "# mcg59 uniform",
        "# ",
        rule,
        "",
    };

    // each line is terminated (and flushed) individually
    for (const char* text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main() {
    print_example_banner();

    // handler to catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:\n" << e.what() << std::endl;
            }
        }
    };

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);
        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            sycl::queue queue(my_dev, exception_handler);

            std::cout << "\n\tRunning with single precision real data type:" << std::endl;
            if (run_example<float, 1>(queue) || run_example<float, 4>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (isDoubleSupported(my_dev)) {
                std::cout << "\n\tRunning with double precision real data type:" << std::endl;
                if (run_example<double, 1>(queue) || run_example<double, 4>(queue)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
            else {
                std::cout << "Double precision is not supported for this device" << std::endl;
            }
            std::cout << "\n\tRunning with integer data type:" << std::endl;
            if (run_example<std::int32_t, 1>(queue) || run_example<std::int32_t, 4>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            std::cout << "\n\tRunning with unsigned integer data type:" << std::endl;
            if (run_example<std::uint32_t, 1>(queue) || run_example<std::uint32_t, 4>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
