/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file PrefixSum.cpp

 HPCG routine
 */

#include "PrefixSum.hpp"
#include "EsimdHelpers.hpp"
#include "Helpers.hpp"

// Debug printing switch: uncomment the USE_PRINTF define below to make
// DEBUG_PRINTF forward its arguments to printf and then call fflush(0), which
// flushes all open output streams. Without USE_PRINTF, DEBUG_PRINTF expands to
// an empty statement and its arguments are never evaluated.
//#define USE_PRINTF
#ifdef USE_PRINTF
// `##__VA_ARGS__` (GNU extension) drops the trailing comma when the macro is
// called with only a format string.
#define DEBUG_PRINTF(fmt, ...) do { printf(fmt, ##__VA_ARGS__); fflush(0); } while (0)
#else
#define DEBUG_PRINTF(fmt, ...) do {} while(0)
#endif


namespace {

/**
 * Modification of prefix sum implementation in https://github.com/intel/llvm/blob/sycl/sycl/test-e2e/ESIMD/PrefixSum.cpp
 * to support all sizes and 64b integers
 */

/**
 * Start of Device functions
 */
// In-register inclusive prefix scan over the simdLen lanes of row 0 of
// `simdReg`. Callers pass a 1 x simdLen view of an esimd::simd<int_t, simdLen>.
//   simdLen == 32: five log-steps directly on the 32 lanes (4-byte int_t path).
//   simdLen == 16: the register is split into two 8-lane halves, each half is
//                  scanned independently, then the total of the lower half is
//                  added to the upper half (8-byte int_t path).
// Any other simdLen is rejected at compile time; previously it compiled to a
// silent no-op that left the data unscanned.
template <size_t simdLen, typename int_t, typename vec>
static inline void regPrefixSum(vec &simdReg) {
  static_assert(simdLen == 32 || simdLen == 16,
                "regPrefixSum only supports simdLen == 32 or simdLen == 16");

  if constexpr (simdLen == 32) {
    // step 1: pairs        -> lane 2k+1 += lane 2k
    simdReg.template select<1, 1, 16, 2>(0, 1) += simdReg.template select<1, 1, 16, 2>(0, 0);
    // step 2: groups of 4  -> lanes 4k+2, 4k+3 += lane 4k+1
    simdReg.template select<1, 1, 8, 4>(0, 2) += simdReg.template select<1, 1, 8, 4>(0, 1);
    simdReg.template select<1, 1, 8, 4>(0, 3) += simdReg.template select<1, 1, 8, 4>(0, 1);
    // step 3: groups of 8  -> upper 4 lanes of each group += last lane of lower 4
    simdReg.template select<1, 1, 4, 1>(0, 4)  += simdReg.template replicate_vs_w_hs<1, 0, 4, 0>(0, 3);
    simdReg.template select<1, 1, 4, 1>(0, 12) += simdReg.template replicate_vs_w_hs<1, 0, 4, 0>(0, 11);
    simdReg.template select<1, 1, 4, 1>(0, 20) += simdReg.template replicate_vs_w_hs<1, 0, 4, 0>(0, 19);
    simdReg.template select<1, 1, 4, 1>(0, 28) += simdReg.template replicate_vs_w_hs<1, 0, 4, 0>(0, 27);
    // step 4: groups of 16 -> upper 8 lanes of each group += last lane of lower 8
    simdReg.template select<1, 1, 8, 1>(0, 8)  += simdReg.template replicate_vs_w_hs<1, 0, 8, 0>(0, 7);
    simdReg.template select<1, 1, 8, 1>(0, 24) += simdReg.template replicate_vs_w_hs<1, 0, 8, 0>(0, 23);
    // step 5: all 32       -> upper 16 lanes += lane 15
    simdReg.template select<1, 1, 16, 1>(0, 16) += simdReg.template replicate_vs_w_hs<1, 0, 16, 0>(0, 15);
  }
  else if constexpr (simdLen == 16) {
    // Split simd16 reg to 2xsimd8 reg for int64_t
    esimd::simd<int_t, 8> S1, S2;
    int_t S1End, S2End;
    S1 = simdReg.template select<1, 1, 8, 1>(0, 0);
    S2 = simdReg.template select<1, 1, 8, 1>(0, 8);

    // Prefix Sum lower part (pairs, groups of 4, then all 8)
    S1.template select<4, 2>(1) += S1.template select<4, 2>(0);
    S1.template select<2, 4>(2) += S1.template select<2, 4>(1);
    S1.template select<2, 4>(3) += S1.template select<2, 4>(1);
    S1End = S1.template select<1, 1>(3);
    S1.template select<4, 1>(4) += S1End;

    // Prefix Sum upper part (same steps, independent of the lower part)
    S2.template select<4, 2>(1) += S2.template select<4, 2>(0);
    S2.template select<2, 4>(2) += S2.template select<2, 4>(1);
    S2.template select<2, 4>(3) += S2.template select<2, 4>(1);
    S2End = S2.template select<1, 1>(3);
    S2.template select<4, 1>(4) += S2End;

    // Store back in original simdReg; the upper half is offset by the lower total
    S1End = S1.template select<1, 1>(7);
    simdReg.template select<1, 1, 8, 1>(0, 0) = S1;
    simdReg.template select<1, 1, 8, 1>(0, 8) = S2 + S1End;
  }
}

// Final reduction: a single work-item computes an inclusive prefix sum, in
// place, over `remaining` entries of `acc`. Entry k lives at element index
// (k + 1) * stride_elems - 1 of `acc`.
//   simdLen           - int_t lanes per vector (drivers use 32 for 4-byte and
//                       16 for 8-byte int_t)
//   useBlockLoadStore - full vectors use contiguous block load/store instead of
//                       32b gather/scatter; drivers select it only when
//                       stride_elems == 1
// Full vectors of simdLen entries are scanned first. A trailing partial vector
// is always scanned with a masked gather/scatter.
template<size_t simdLen, bool useBlockLoadStore, typename int_t>
void prefixSum_singleThread(sycl::item<1> it, int_t *acc, uint32_t stride_elems, uint32_t remaining) {
  (void)it; // to prevent compiler warning

  // simdLen when data is interpreted as 32b
  constexpr size_t simdLen32 = sizeof(int_t) == 4 ? simdLen : 2 * simdLen;

  // 32b offset okay, working memory size will not exceed L2 cache size
  // this is handled at driver level (prefix_sum_esimd)
  esimd::simd<uint32_t, simdLen32> voff(0,1);

  // Byte offsets of the first simdLen entries: entry j is at element
  // (j + 1) * stride_elems - 1 relative to acc.
  if constexpr (sizeof(int_t) == 4) {
    // [0,1,2,...,31]
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
  }
  else if constexpr (sizeof(int_t) == 8) {
    // gather doesn't support 64bit, so use 2xsimdLen 32b gather
    // [0,0,1,1,2,2,...,15,15]; odd lanes address the upper 32b half (+4 bytes)
    voff = voff >> 1;
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
    voff.template select<16,2>(1) += sizeof(int32_t);
  }

  esimd::simd<int_t, simdLen> S;
  esimd::simd<std::int32_t, simdLen32> S32;
  // running total carried from one vector to the next
  esimd::simd<int_t, 1> prev = 0;

  uint32_t n_iter = floor_div<uint32_t>(remaining, simdLen), i;

  // Full vectors: load, add carry to lane 0, scan in registers, store back.
  for (i = 0; i < n_iter; i++) {

    if constexpr (useBlockLoadStore)
      S = esimd_lsc_block_load<int_t, uint32_t, simdLen>(acc, i * simdLen);
    else {
      // Reinterpret as simd<int32_t,simdLen32> for gather
      S32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
        (const std::int32_t*)acc, voff);
      // Reinterpret back to simd<int_t,simdLen> for in-register prefix sum
      S = S32.template bit_cast_view<int_t, 1, simdLen>();
    }

    auto cnt_table = S.template bit_cast_view<int_t, 1, simdLen>();
    cnt_table.column(0) += prev;

    regPrefixSum<simdLen, int_t>(cnt_table);

    if constexpr (useBlockLoadStore)
      esimd_lsc_block_store<int_t, uint32_t, simdLen>(acc, i * simdLen, S);
    else {
      // Reinterpret as simd<int32_t,simdLen32> for scatter
      S32 = S.template bit_cast_view<int32_t, 1, simdLen32>();
      esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
        (std::int32_t*)acc, voff, S32);
    }

    // advance gather/scatter offsets by one vector of entries
    voff += stride_elems * sizeof(int_t) * simdLen;
    prev = cnt_table.column(simdLen - 1);
  }

  // Trailing partial vector: lanes whose byte offset lies past the last entry
  // are masked off; masked lanes load 0 (pass_thru) and are not stored.
  if (n_iter * simdLen < remaining) {
    esimd::simd_mask<simdLen32> mask = voff < remaining * stride_elems * sizeof(int_t);
    // Gather loads 32b only.
    const esimd::simd<std::int32_t, simdLen32> pass_thru{0};
    S32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
      (const std::int32_t*)acc, voff, mask, pass_thru);
    S = S32.template bit_cast_view<int_t, 1, simdLen>();

    auto cnt_table = S.template bit_cast_view<int_t, 1, simdLen>();
    cnt_table.column(0) += prev;

    regPrefixSum<simdLen, int_t>(cnt_table);

    S32 = S.template bit_cast_view<int32_t, 1, simdLen32>();
    esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
      (std::int32_t*)acc, voff, S32, mask);

    // NOTE(review): nothing reads prev after this point
    prev = cnt_table.column(simdLen - 1);
  }
}

// Each work-item sums the blockSize entries of its chunk and stores the sum in
// the element at acc[h_pos * stride_threads + stride_threads - 1]. Other
// elements of the chunk are not modified. With stride_threads ==
// blockSize * stride_elems (as hierarchical_prefix passes), that element is
// the chunk's right-most entry.
//   Entry j of a chunk is at element (j + 1) * stride_elems - 1 of the chunk.
//   useBlockLoad - contiguous block loads instead of 32b gathers (drivers pass
//                  it only when stride_elems == 1)
// The last work-item returns without writing when remaining % blockSize != 0;
// that partial chunk is left to localPrefixSum.
template<size_t simdLen, size_t blockSize, bool useBlockLoad, typename int_t>
void localReduction(sycl::item<1> it, int_t *acc,
                    uint32_t stride_elems, uint32_t stride_threads, uint32_t remaining)
{

  // simdLen when data is interpreted as 32b
  constexpr size_t simdLen32 = sizeof(int_t) == 4 ? simdLen : 2 * simdLen;

  auto h_pos = it.get_id(0);
  auto num_wg = it.get_range(0);

  const size_t rem = remaining % blockSize;
  if (h_pos == num_wg - 1 && rem > 0)  return; // need to do the full interior ones only,  localPrefixSum will handle last remainder itself

  esimd::simd<uint32_t, simdLen32> voff(0, 1); // 0, 1, 2, ..., 31

  // global offset for a thread, using 32b offset since working memory size
  // will not exceed L2 cache size this is handled at the driver level (prefix_sum_esimd)
  uint32_t global_offset = (h_pos * stride_threads);
  acc += global_offset;

  // Byte offsets (relative to the chunk start) of the first simdLen entries:
  // entry j is at element (j + 1) * stride_elems - 1.
  if constexpr (sizeof(int_t) == 4) {
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
  }
  else if constexpr (sizeof(int_t) == 8) {
    // 64b values are gathered as pairs of 32b lanes; odd lanes are +4 bytes
    voff = voff >> 1; // 0, 0, 1, 1 ..., 15, 15
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
    voff.template select<16,2>(1) += sizeof(int32_t);
  }

  // S accumulates lane-wise partial sums over the chunk; T holds the next vector
  esimd::simd<int_t, simdLen> S, T;
  esimd::simd<int32_t, simdLen32> S32, T32;

  if constexpr (useBlockLoad)
    S = esimd_lsc_block_load<int_t, uint32_t, simdLen>(acc, 0);
  else {
    S32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
      (const std::int32_t*)acc, voff);
    S.template select<simdLen, 1>(0) = S32.template bit_cast_view<int_t, 1, simdLen>();
  }

  constexpr uint32_t n_iter = blockSize / simdLen;

#pragma unroll
  for (uint32_t j = 1; j < n_iter; j++) {
    if constexpr (useBlockLoad)
      T = esimd_lsc_block_load<int_t, uint32_t, simdLen>(acc, j * simdLen);
    else {
      voff += (stride_elems * simdLen) * sizeof(int_t);
      T32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
        (const std::int32_t*)acc, voff);
      T.template select<simdLen, 1>(0) = T32.template bit_cast_view<int_t, 1, simdLen>();
    }

    S += T;
  }

  // View S as simdLen rows x 1 column so halves can be added row-wise
  auto cnt_table = S.template bit_cast_view<int_t, simdLen, 1>();

  // sum reduction for each bin: tree reduction that halves the row count each
  // step, leaving the chunk total in row 0. Only the 32 -> 16 step is guarded
  // (4-byte int_t has 32 lanes); the brace-less if covers just the next line,
  // and the 16 -> 1 steps run for both element sizes.
  if constexpr (sizeof(int_t) == 4)
    cnt_table.template select<16, 1, 1, 1>(0, 0) +=  cnt_table.template select<16, 1, 1, 1>(16, 0);
  cnt_table.template select< 8, 1, 1, 1>(0, 0) +=  cnt_table.template select< 8, 1, 1, 1>(8, 0);
  cnt_table.template select< 4, 1, 1, 1>(0, 0) +=  cnt_table.template select< 4, 1, 1, 1>(4, 0);
  cnt_table.template select< 2, 1, 1, 1>(0, 0) +=  cnt_table.template select< 2, 1, 1, 1>(2, 0);
  cnt_table.template select< 1, 1, 1, 1>(0, 0) +=  cnt_table.template select< 1, 1, 1, 1>(1, 0);

  // Scatter S[0] to element stride_threads - 1 of the chunk. An 8-lane scatter
  // is issued, but only lane 0 (4-byte) or lanes 0-1 (8-byte, both 32b halves)
  // are enabled by the mask.
  esimd::simd<uint32_t, 8> voff8(0, 1); // 0, 1, 2, 3, ..., 7
  if constexpr (sizeof(int_t) == 8) voff8 = voff8 >> 1; // 0, 0, 1, 1, ..., 3, 3 for int64_t case

  esimd::simd_mask<8> mask = voff8 < 1; // mask gather/scatter to a single element

  if constexpr (sizeof(int_t) == 4)
    voff8 = (voff8 + (stride_threads - 1)) * sizeof(int_t);
  else if constexpr (sizeof(int_t) == 8) {
    voff8 = (voff8 + (stride_threads - 1)) * sizeof(int_t);
    voff8.template select<4,2>(1) += sizeof(int32_t);
  }

  S32 = S.template bit_cast_view<int32_t, 1, simdLen32>();
  esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, 8>(
    (std::int32_t*)acc, voff8, S32.template select<simdLen32 / 4, 1>(0), mask);
}


// Each work-item computes the prefix sum of its chunk of the array, seeded with
// the running total that precedes the chunk, so the result is a global prefix sum.
//
// Work-item h_pos owns the chunk starting at acc + h_pos * stride_thread; entry
// j of the chunk is at element (j + 1) * stride_elems - 1 of the chunk. The
// seed is the element just before the chunk (acc[-1] after offsetting); work-item
// 0 starts from 0.
//
// Full chunks (all work-items except a last one whose chunk holds only
// remaining % blockSize entries) scan blockSize entries. Their right-most entry
// already holds its final prefix value from the reduction/recursion phases, so
// it is corrected instead of being accumulated a second time. A trailing
// partial chunk is scanned with masked gather/scatter.
//   simdLen           - int_t lanes per vector (drivers use 32 for 4-byte and
//                       16 for 8-byte int_t)
//   useBlockLoadStore - contiguous block load/store for full chunks (drivers
//                       pass it only when stride_elems == 1)
template<size_t simdLen, size_t blockSize, bool useBlockLoadStore, typename int_t>
void localPrefixSum(sycl::item<1> it, int_t *acc,
                    uint32_t stride_elems, uint32_t stride_thread,
                    uint32_t remaining) {

  // simdLen when data is interpreted as 32b
  constexpr size_t simdLen32 = sizeof(int_t) == 4 ? simdLen : 2 * simdLen;

  auto h_pos = it.get_id(0);
  auto num_wg = it.get_range(0);

  esimd::simd<uint32_t, simdLen32> voff(0,1);

  // global offset for a thread, using 32b offset since working memory size
  // will not exceed L2 cache size this is handled at driver level (prefix_sum_esimd)
  uint32_t global_offset = (h_pos * stride_thread);
  acc += global_offset;

  // Byte offsets (relative to the chunk start) of the first simdLen entries:
  // entry j is at element (j + 1) * stride_elems - 1.
  if constexpr (sizeof(int_t) == 4)
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
  else if constexpr (sizeof(int_t) == 8) {
    // 64b values move as pairs of 32b lanes: [0,0,1,1,...], odd lanes +4 bytes
    voff = voff >> 1;
    voff = ((voff + 1) * stride_elems - 1) * sizeof(int_t);
    voff.template select<16,2>(1) += sizeof(int32_t);
  }

  // read the accumulated sum from its previous chunk (acc[-1])
  esimd::simd<int_t, 1> prev = 0;
  if (h_pos != 0) {
    // WA gather does not take less than 8 lanes. Only the lanes covering acc[-1]
    // are enabled: unmasked, the other lanes would read acc[0..], which for a
    // short trailing chunk at the end of the array lies past the buffer end.
    esimd::simd<uint32_t, 8> voff8(0, 1);
    if constexpr (sizeof(int_t) == 8)
      voff8 = voff8 >> 1; // 0, 0, 1, 1, ..., 3, 3
    esimd::simd_mask<8> mask8 = voff8 < 1; // lane 0 (32b) or lanes 0-1 (64b)
    voff8 *= sizeof(int_t);
    if constexpr (sizeof(int_t) == 8)
      voff8.template select<4,2>(1) += sizeof(int32_t); // upper 32b half
    const esimd::simd<std::int32_t, 8> pass_thru8{0};
    auto temp32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (const std::int32_t*)(acc - 1), voff8, mask8, pass_thru8);
    // 8 x 32b reinterpreted as 8 x int32 or 4 x int64; element 0 is acc[-1]
    auto temp = temp32.template bit_cast_view<int_t, 1, simdLen / 4>();
    prev = temp.template select<1, 1, 1, 1>(0);
  }

  esimd::simd<int_t, simdLen> S;
  esimd::simd<int32_t, simdLen32> S32;

  const size_t rem = remaining % blockSize;

  if (h_pos < num_wg - 1 || rem == 0) {

    constexpr uint32_t n_iter = blockSize / simdLen;
#pragma unroll
    for (uint32_t i = 0; i < n_iter; i++) {

      if constexpr (useBlockLoadStore) {
         S = esimd_lsc_block_load<int_t, uint32_t, simdLen>(acc, i * simdLen);
      }
      else {
        S32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
          (const std::int32_t*)acc, voff);
        S = S32.template bit_cast_view<int_t, 1, simdLen>();
      }

      auto cnt_table = S.template bit_cast_view<int_t, 1, simdLen>();
      cnt_table.column(0) += prev;

      regPrefixSum<simdLen, int_t>(cnt_table);

      // during reduction phase, we've already computed prefix sum and saved in
      // the last entry. Here we avoid double counting the last entry
      if (i == n_iter - 1) {
        cnt_table.column(simdLen - 1) -= cnt_table.column(simdLen - 2);
      }

      if constexpr (useBlockLoadStore) {
        esimd_lsc_block_store<int_t, uint32_t, simdLen>(acc, i * simdLen, S);
      }
      else {
        S32 = S.template bit_cast_view<int32_t, 1, simdLen32>();
        esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
          (std::int32_t*)acc, voff, S32);
      }

      voff += stride_elems * sizeof(int_t) * simdLen;
      prev = cnt_table.column(simdLen - 1);
    }
  }
  else {
    // Handle remainder part of array: lanes past the last entry are masked off
    for (unsigned i = 0; i < rem; i += simdLen) {

      esimd::simd_mask<simdLen32> mask = voff < rem * stride_elems * sizeof(int_t);
      const esimd::simd<std::int32_t, simdLen32> pass_thru{0};
      S32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
        (const std::int32_t*)acc, voff, mask, pass_thru);
      S = S32.template bit_cast_view<int_t, 1, simdLen>();

      auto cnt_table = S.template bit_cast_view<int_t, 1, simdLen>();
      cnt_table.column(0) += prev;

      regPrefixSum<simdLen, int_t>(cnt_table);

      S32 = S.template bit_cast_view<int32_t, 1, simdLen32>();
      esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, simdLen32>(
        (std::int32_t*)acc, voff, S32, mask);

      voff += stride_elems * sizeof(int_t) * simdLen;
      prev = cnt_table.column(simdLen - 1);
    }
  }
}


// For L2 cache-blocking we update the first element of the next sub-buffer with the
// last value of the current sub-buffer: acc[n_entries] += acc[n_entries - 1].
// Presumably acc[n_entries] is the first element of the following sub-buffer
// (see prefix_sum_esimd); this function only touches those two elements.
// 8-lane gathers/scatters are issued, masked so that only one int_t
// (one 32b lane, or two 32b lanes for 8-byte int_t) is read and written.
template <typename int_t>
void updateNextSubBuffer(int_t *acc, uint32_t n_entries) {

  esimd::simd<uint32_t, 8> voff(0, 1); // 0, 1, 2, 3, ..., 7

  if constexpr (sizeof(int_t) == 4) {
    esimd::simd_mask<8> mask = voff < 1; // mask gather/scatter to a single element

    // masked-off lanes are filled with 0 and never stored
    const esimd::simd<std::int32_t, 8> pass_thru{0};
    auto prev = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (const std::int32_t*)(acc + (n_entries - 1)), voff, mask, pass_thru);
    auto curr = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (const std::int32_t*)(acc + n_entries), voff, mask, pass_thru);

    prev += curr;
    esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (std::int32_t*)(acc + n_entries), voff, prev.template select<8, 1>(0), mask);
  }
  else {
    // 64b: lanes 0 and 1 carry the lower and upper 32b halves of one element
    voff = voff >> 1; // 0, 0, 1, 1, ..., 3, 3
    esimd::simd_mask<8> mask = voff < 1; // mask gather/scatter to a single element
    voff = voff * sizeof(int_t);
    voff.template select<4,2>(1) += sizeof(int32_t);

    const esimd::simd<std::int32_t, 8> pass_thru{0};
    auto prev32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (const std::int32_t*)(acc + (n_entries - 1)), voff, mask, pass_thru);
    auto curr32 = esimd_exp::lsc_gather<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (const std::int32_t*)(acc + n_entries), voff, mask, pass_thru);

    // add as 64b values so the carry between halves is handled
    auto prev = prev32.template bit_cast_view<int_t, 1, 4>();
    auto curr = curr32.template bit_cast_view<int_t, 1, 4>();

    prev += curr;
    prev32 = prev.template bit_cast_view<std::int32_t, 1, 8>();
    esimd_exp::lsc_scatter<std::int32_t, 1, ds::default_size, nc, nc, 8>(
      (std::int32_t*)(acc + n_entries), voff, prev32.template select<8, 1>(0), mask);
  }
}

// End of Device functions


/**
 * Start of Host-side drivers
 */
// Kernel name type for the parallel_for in prefixSumSingleWGDriver. It is only
// declared (never defined) because it is used solely as a kernel-name template
// argument.
template <size_t blockSize, bool prevUpdate, typename int_t>
class prefixSumSingleWGKernel;

// Launches a single-work-item ESIMD kernel that computes the whole prefix sum
// of `n_entries` entries (spaced `elem_stride` elements apart) in place.
//  -- Handles small arrays
//  -- When prevUpdate is set and the data is contiguous (elem_stride == 1), the
//     kernel also calls updateNextSubBuffer(buf, n_entries) afterwards.
//  -- When n_entries == 0 no kernel is enqueued; the returned event then only
//     tracks `dependencies`.
template <size_t blockSize, bool prevUpdate, typename int_t>
static inline sycl::event prefixSumSingleWGDriver(sycl::queue &q, int rank, int_t *buf,
                                           uint32_t elem_stride, uint32_t n_entries,
                                           const std::vector<sycl::event>& dependencies) {
  (void)rank; // unused in this driver

  // int_t lanes per vector: 32 x 4-byte or 16 x 8-byte integers
  constexpr size_t simdLen = (sizeof(int_t) == 4) ? 32 : 16;

  return q.submit([&](sycl::handler &cgh) {
    cgh.depends_on(dependencies);

    // Empty input: dependency-only command group, no kernel
    if (n_entries == 0)
      return;

    cgh.parallel_for<prefixSumSingleWGKernel<blockSize, prevUpdate, int_t>>(
      sycl::range<1>{1},
      [=](sycl::item<1> it) SYCL_ESIMD_KERNEL {
        // Strided data: gather/scatter path, never updates the next sub-buffer
        if (elem_stride != 1) {
          prefixSum_singleThread<simdLen, false, int_t>(it, buf, elem_stride, n_entries);
          return;
        }
        // Contiguous data: block load/store path
        prefixSum_singleThread<simdLen, true, int_t>(it, buf, elem_stride, n_entries);
        if constexpr (prevUpdate) updateNextSubBuffer<int_t>(buf, n_entries);
      });
  });
}

// Kernel name type for the parallel_for in localReductionDriver (declaration only).
template <size_t blockSize, typename int_t>
class localReductionKernel;

// Launches one work-item per blockSize-entry chunk (ceil(n_entries / blockSize)
// work-items). Each work-item runs localReduction on its chunk, which writes
// the chunk sum into the chunk's right-most element. Contiguous data
// (elem_stride == 1) uses block loads, strided data uses gathers.
template <size_t blockSize, typename int_t>
static inline sycl::event localReductionDriver(sycl::queue &q, int rank, int_t *buf, uint32_t elem_stride,
                                        uint32_t thread_stride, uint32_t n_entries,
                                        const std::vector<sycl::event>& dependencies) {
  (void)rank; // to prevent compiler warning

  // int_t lanes per vector: 32 x 4-byte or 16 x 8-byte integers
  constexpr size_t simdLen = (sizeof(int_t) == 4) ? 32 : 16;

  // one work-item per (possibly partial) chunk
  const size_t num_wg = ceil_div<size_t>(n_entries, blockSize);

  return q.submit([&](sycl::handler &cgh) {
    cgh.depends_on(dependencies);

    cgh.parallel_for<localReductionKernel<blockSize, int_t>>(
      sycl::range<1>{num_wg},
      [=](sycl::item<1> it) SYCL_ESIMD_KERNEL {
        if (elem_stride == 1)
          localReduction<simdLen, blockSize, true, int_t>(it, buf, elem_stride, thread_stride, n_entries);
        else
          localReduction<simdLen, blockSize, false, int_t>(it, buf, elem_stride, thread_stride, n_entries);
      });
  });
}



// Kernel name type for the parallel_for in localPrefixSumDriver (declaration only).
template <size_t blockSize, bool prevUpdate, typename int_t>
class localPrefixSumKernel;

// Launches one work-item per blockSize-entry chunk (ceil(n_entries / blockSize)
// work-items). Each work-item runs localPrefixSum on its chunk. For contiguous
// data (elem_stride == 1) with prevUpdate set, the last work-item additionally
// calls updateNextSubBuffer(buf, n_entries) after its own scan.
template <size_t blockSize, bool prevUpdate, typename int_t>
static inline sycl::event localPrefixSumDriver(sycl::queue &q, int rank, int_t *buf, uint32_t elem_stride,
                                        uint32_t thread_stride, uint32_t n_entries,
                                        const std::vector<sycl::event>& dependencies) {
  (void)rank; // unused in this driver

  // int_t lanes per vector: 32 x 4-byte or 16 x 8-byte integers
  constexpr size_t simdLen = (sizeof(int_t) == 4) ? 32 : 16;

  // one work-item per (possibly partial) chunk
  const size_t num_wg = ceil_div<size_t>(n_entries, blockSize);

  return q.submit([&](sycl::handler &cgh) {
    cgh.depends_on(dependencies);

    cgh.parallel_for<localPrefixSumKernel<blockSize, prevUpdate, int_t>>(
      sycl::range<1>{num_wg},
      [=](sycl::item<1> it) SYCL_ESIMD_KERNEL {
        // Strided data: gather/scatter path, never updates the next sub-buffer
        if (elem_stride != 1) {
          localPrefixSum<simdLen, blockSize, false, int_t>(it, buf, elem_stride, thread_stride, n_entries);
          return;
        }
        localPrefixSum<simdLen, blockSize, true, int_t>(it, buf, elem_stride, thread_stride, n_entries);
        if constexpr (prevUpdate) {
          // only the work-item owning the final chunk touches the next sub-buffer
          if (it.get_id(0) == num_wg - 1)
            updateNextSubBuffer<int_t>(buf, n_entries);
        }
      });
  });
}



//
// Hierarchical prefix sum, computed in place by recursion
//  -- Has no extra memory requirement
//  -- n_entries will not exceed the cache-blocking size chosen in prefix_sum_esimd
//
// Entry k lives at element (k + 1) * elem_stride - 1 of buf. Entries are split
// into chunks of entry_per_th entries (thread_stride elements), one chunk per
// work-item. entry_per_th == blockSize selects the blockSize kernels; any other
// value selects the blockSizeLow kernels.
//   1. localReductionDriver stores the sum of every full chunk in its
//      right-most entry.
//   2. Those right-most entries are prefix-summed by a recursive call that
//      treats them as a strided array (elem_stride = thread_stride).
//   3. localPrefixSumDriver scans every chunk, seeded with the right-most entry
//      of the preceding chunk.
// Arrays of at most 2 * blockSize entries are handled by a single work-item.
// Returns the event of the last enqueued command.
template <size_t blockSize, size_t blockSizeLow, bool prevUpdate, typename int_t>
static inline sycl::event hierarchical_prefix(sycl::queue &q, int rank, int_t *buf, uint32_t elem_stride,
                                 uint32_t thread_stride, uint32_t n_entries, uint32_t entry_per_th,
                                 const std::vector<sycl::event>& dependencies = {})
{
  constexpr size_t remainder_threshold = 2 * blockSize;

  // If array is small enough just launch a single work-group kernel to compute prefix sum
  if (n_entries <= remainder_threshold) {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: base_case singleWG, elem_stride = %u\n", rank, elem_stride);
    return prefixSumSingleWGDriver<blockSize, prevUpdate, int_t>(
            q, rank, buf, elem_stride, n_entries, dependencies);
  }

  // Step 1: partition the array into chunks of size blockSize (or blockSizeLow) and
  // compute the reduction for each chunk. The reductions are stored in the right-most
  // entry of each chunk; other elements in the chunk are not modified.
  sycl::event ev_iterative;
  if (entry_per_th == blockSize) {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: localReductionDriver<blockSize>, elem_stride = %u, thread_stride = %u\n", rank, elem_stride, thread_stride);
    ev_iterative = localReductionDriver<blockSize, int_t>(
            q, rank, buf, elem_stride, thread_stride, n_entries, dependencies);
  }
  else {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: localReductionDriver<blockSizeLow>, elem_stride = %u, thread_stride = %u\n", rank, elem_stride, thread_stride);
    ev_iterative = localReductionDriver<blockSizeLow, int_t>(
            q, rank, buf, elem_stride, thread_stride, n_entries, dependencies);
  }

  // Step 2: recursively compute the prefix sum of all the reductions computed above
  // (one per full chunk).
  const uint32_t n_chunks = n_entries / entry_per_th;

  // Default to the reduction event so that step 3 always waits for step 1, even
  // if there were nothing to recurse on (a default sycl::event is already complete).
  sycl::event ev_recurse = ev_iterative;
  if (n_chunks > 4096) {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: recurse, n_entries / entry_per_th = %u > 4096\n", rank, n_chunks);
    ev_recurse = hierarchical_prefix<blockSize, blockSizeLow, prevUpdate, int_t>(
      q, rank, buf, thread_stride, thread_stride * blockSize, n_chunks, blockSize, {ev_iterative});
  }
  else if (n_chunks > 0) {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: recurse, 0 < n_entries / entry_per_th = %u <= 4096\n", rank, n_chunks);
    // if number of remaining entries <= 4K, each thread accumulates a smaller
    // number of entries to keep EUs saturated
    ev_recurse = hierarchical_prefix<blockSize, blockSizeLow, prevUpdate, int_t>(
      q, rank, buf, thread_stride, thread_stride * blockSizeLow, n_chunks, blockSizeLow, {ev_iterative});
  }

  // Step 3: the right-most value of each chunk now holds its global prefix value.
  // Compute the local prefix sum of each chunk (in parallel), seeded from the
  // preceding chunk's right-most element, to obtain the global prefix sum.
  sycl::event ev_last;
  if (entry_per_th == blockSize) {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: localPrefixSumDriver<blockSize>, entry_per_th = %u, elem_stride = %u, thread_stride = %u\n", rank, entry_per_th, elem_stride, thread_stride);
    ev_last = localPrefixSumDriver<blockSize, prevUpdate, int_t>(
            q, rank, buf, elem_stride, thread_stride, n_entries, {ev_recurse});
  }
  else {
    DEBUG_PRINTF("[rank %d]: hierarchical_prefix: localPrefixSumDriver<blockSizeLow>, entry_per_th = %u, elem_stride = %u, thread_stride = %u\n", rank, entry_per_th, elem_stride, thread_stride);
    ev_last = localPrefixSumDriver<blockSizeLow, prevUpdate, int_t>(
            q, rank, buf, elem_stride, thread_stride, n_entries, {ev_recurse});
  }

  return ev_last;
}


// Main driver: computes the in-place inclusive prefix sum of array[0, length)
// by calling hierarchical_prefix on consecutive sub-arrays sized to the L2
// cache (192MB on PVC), processed one after another.
//  -- Every sub-array except the last also adds its final value into the first
//     element of the next sub-array (prevUpdate = true), so the next scan
//     continues the global running sum.
//  -- No problem size limitations
//  -- Returns the event of the last enqueued command. For length == 0, the
//     returned event completes once `dependencies` have completed.
template <size_t blockSize, size_t blockSizeLow, typename int_t>
static inline sycl::event prefix_sum_esimd(sycl::queue &queue,
                                            int rank,
                                            const size_t length,
                                            int_t *array,
                                            const std::vector<sycl::event> &dependencies)
{

  const size_t L2CacheSize = 1024 * 1024 * 192; // 192 MB for PVC

  // Partition the array into sub-arrays of cacheBK entries; prefix sums are then
  // computed for each sub-array serially. The sizes passed to hierarchical_prefix
  // therefore never exceed cacheBK.
  const size_t cacheBK = L2CacheSize / sizeof(int_t); // Cache-blocking factor

  DEBUG_PRINTF("[rank %d]: prefix_sum_esimd: length = %zu, cacheBK = %zu\n", rank, length, cacheBK);

  if (length == 0) {
    // Nothing to scan. A default-constructed sycl::event is already complete, so
    // returning one would let callers run ahead of `dependencies`; submit a
    // dependency-only command group instead.
    return queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); });
  }

  sycl::event evt;
  for (size_t i = 0; i < length; i += cacheBK) {

    auto subbuf = array + i;

    // Bounded by cacheBK, so it fits the 32b sizes used by the kernels.
    const uint32_t n_sub = static_cast<uint32_t>(std::min(cacheBK, length - i));

    // First call depends on dependencies, subsequent calls depend on evt
    std::vector<sycl::event> dep = i == 0 ? dependencies : std::vector<sycl::event>({evt});

    if (i + cacheBK < length) {
      DEBUG_PRINTF("[rank %d]: prefix_sum_esimd: for loop updatePrev = true : i = %zu\n", rank, i);
      // Remaining array to process is more than cacheBK so we update first element
      // of next sub-buffer with last element of current sub-buffer.
      evt = hierarchical_prefix<blockSize, blockSizeLow, true, int_t>(
        queue, rank, subbuf, 1, blockSize, n_sub, blockSize, dep);
    }
    else {
      DEBUG_PRINTF("[rank %d]: prefix_sum_esimd: for loop updatePrev = false: i = %zu\n", rank, i);
      // Last sub-buffer: there is no next sub-buffer to update.
      evt = hierarchical_prefix<blockSize, blockSizeLow, false, int_t>(
        queue, rank, subbuf, 1, blockSize, n_sub, blockSize, dep);
    }
  }

  return evt;
}

} // anonymous namespace



//
// Prefix sum of integer arrays: in-place inclusive scan of array[0, length),
// ordered after `dependencies`. No rank is available here, so -1 is forwarded
// to prefix_sum_esimd as a placeholder.
//
sycl::event prefix_sum(sycl::queue &queue,
                       const size_t length,
                       local_int_t *array,
                       const std::vector<sycl::event> &dependencies)
{
  constexpr int unknown_rank = -1;
  constexpr size_t block_size = 512;    // entries per work-item (large levels)
  constexpr size_t block_size_low = 32; // entries per work-item (small levels)
  return prefix_sum_esimd<block_size, block_size_low, local_int_t>(
      queue, unknown_rank, length, array, dependencies);
}

// Overload of prefix_sum that also takes the caller's rank, which is forwarded
// only so that debug output can be tagged with it.
sycl::event prefix_sum(sycl::queue &queue,
                       const int rank,
                       const size_t length,
                       local_int_t *array,
                       const std::vector<sycl::event> &dependencies)
{
  constexpr size_t block_size = 512;    // entries per work-item (large levels)
  constexpr size_t block_size_low = 32; // entries per work-item (small levels)
  return prefix_sum_esimd<block_size, block_size_low, local_int_t>(
      queue, rank, length, array, dependencies);
}
