// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <algorithm>

#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "common/cuda_hip/base/types.hpp"
#include "common/cuda_hip/components/atomic.hpp"
#include "common/cuda_hip/components/cooperative_groups.hpp"
#include "common/cuda_hip/components/intrinsics.hpp"
#include "common/cuda_hip/components/prefix_sum.hpp"
#include "common/cuda_hip/components/sorting.hpp"
#include "common/cuda_hip/components/thread_ids.hpp"
#include "common/cuda_hip/factorization/par_ilut_filter_kernels.hpp"
#include "common/cuda_hip/factorization/par_ilut_select_common.hpp"
#include "common/cuda_hip/factorization/par_ilut_select_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "core/factorization/par_ilut_kernels.hpp"
#include "core/matrix/coo_builder.hpp"
#include "core/matrix/csr_builder.hpp"
#include "core/matrix/csr_kernels.hpp"
#include "core/synthesizer/implementation_selection.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
/**
 * @brief The parallel ILUT factorization namespace.
 *
 * @ingroup factor
 */
namespace par_ilut_factorization {


// subwarp sizes for filter kernels
// Candidate subwarp sizes for the bucket filter kernels. For a given entry,
// each thread block of default_block_size threads covers
// default_block_size / subwarp_size rows, i.e. subwarp_size threads per row.
// The order matters: the dispatch wrapper presumably uses the first entry
// whose size is at least the average number of nonzeros per row (falling back
// to config::warp_size). If config::warp_size is 32 (as on CUDA), the last two
// entries coincide, which is harmless because they name the same
// instantiation.
using compiled_kernels =
    syn::value_list<int, 1, 2, 4, 8, 16, 32, config::warp_size>;


/**
 * Approximate threshold filter for a compile-time subwarp size.
 *
 * Runs a single sampleselect pass over the values of `m`, picks the bucket
 * that sampleselect_find_bucket reports for `rank`, stores that bucket's
 * splitter (zero for bucket 0) in `*threshold`, and writes the entries kept
 * by the bucket filter kernels into `m_out` (and optionally `m_out_coo`).
 * The result is approximate because the threshold is a bucket splitter, not
 * the exact element of the given rank.
 *
 * @param exec  executor; all kernels are launched on its stream
 * @param m  input matrix
 * @param rank  rank handed to sampleselect_find_bucket
 * @param tmp  scratch storage; resized and overwritten by this function
 * @param threshold  output: the approximate threshold value
 * @param m_out  output CSR matrix. Its row pointer array must already hold
 *               num_rows + 1 entries: it is written here, but not resized.
 *               Column index and value arrays are resized to the new nnz.
 * @param m_out_coo  optional (may be null) output COO matrix. Its column
 *                   index and value arrays become views of `m_out`'s arrays;
 *                   only its row index array is allocated here.
 */
template <int subwarp_size, typename ValueType, typename IndexType>
void threshold_filter_approx(syn::value_list<int, subwarp_size>,
                             std::shared_ptr<const DefaultExecutor> exec,
                             const matrix::Csr<ValueType, IndexType>* m,
                             IndexType rank, array<ValueType>* tmp,
                             remove_complex<ValueType>* threshold,
                             matrix::Csr<ValueType, IndexType>* m_out,
                             matrix::Coo<ValueType, IndexType>* m_out_coo)
{
    using AbsType = remove_complex<ValueType>;
    constexpr auto bucket_count = kernel::searchtree_width;

    // the input values are needed both for sampleselect and for the filter
    auto old_vals = m->get_const_values();
    IndexType size = m->get_num_stored_elements();
    auto max_num_threads = ceildiv(size, items_per_thread);
    auto max_num_blocks = ceildiv(max_num_threads, default_block_size);

    // All scratch buffers are carved out of `tmp`, whose element type is
    // ValueType, so every sub-buffer size is rounded up to whole ValueType
    // entries. Layout: [ total_counts | partial_counts | oracles | tree ]
    size_type tmp_size_totals =
        ceildiv((bucket_count + 1) * sizeof(IndexType), sizeof(ValueType));
    size_type tmp_size_partials = ceildiv(
        bucket_count * max_num_blocks * sizeof(IndexType), sizeof(ValueType));
    size_type tmp_size_oracles =
        ceildiv(size * sizeof(unsigned char), sizeof(ValueType));
    size_type tmp_size_tree =
        ceildiv(kernel::searchtree_size * sizeof(AbsType), sizeof(ValueType));
    size_type tmp_size =
        tmp_size_totals + tmp_size_partials + tmp_size_oracles + tmp_size_tree;
    tmp->resize_and_reset(tmp_size);

    auto total_counts = reinterpret_cast<IndexType*>(tmp->get_data());
    auto partial_counts =
        reinterpret_cast<IndexType*>(tmp->get_data() + tmp_size_totals);
    auto oracles = reinterpret_cast<unsigned char*>(
        tmp->get_data() + tmp_size_totals + tmp_size_partials);
    auto tree =
        reinterpret_cast<AbsType*>(tmp->get_data() + tmp_size_totals +
                                   tmp_size_partials + tmp_size_oracles);

    // presumably fills the splitter tree, the per-element bucket indices
    // (oracles) and the per-bucket counts -- see par_ilut_select_kernels
    sampleselect_count(exec, old_vals, size, tree, oracles, partial_counts,
                       total_counts);

    // determine bucket with correct rank
    auto bucket = static_cast<unsigned char>(
        sampleselect_find_bucket(exec, total_counts, rank).idx);
    // We implicitly set the first splitter to -inf, but 0 works as well.
    // Only fetch the splitter from the device when it is actually used, which
    // saves a synchronous device-to-host copy for bucket 0.
    if (bucket == 0) {
        *threshold = zero<AbsType>();
    } else {
        *threshold = exec->copy_val_to_host(tree +
                                            kernel::searchtree_inner_size +
                                            bucket);
    }

    // filter the elements
    auto old_row_ptrs = m->get_const_row_ptrs();
    auto old_col_idxs = m->get_const_col_idxs();
    // each block of default_block_size threads covers
    // default_block_size / subwarp_size rows (subwarp_size threads per row)
    auto num_rows = static_cast<IndexType>(m->get_size()[0]);
    auto rows_per_block = default_block_size / subwarp_size;
    auto num_blocks = ceildiv(num_rows, rows_per_block);
    auto new_row_ptrs = m_out->get_row_ptrs();
    // compute nnz for each row; skipped entirely for matrices without rows
    if (num_blocks > 0) {
        kernel::bucket_filter_nnz<subwarp_size>
            <<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
                old_row_ptrs, oracles, num_rows, bucket, new_row_ptrs);
    }

    // build row pointers
    components::prefix_sum_nonnegative(exec, new_row_ptrs, num_rows + 1);

    // build matrix
    auto new_nnz = exec->copy_val_to_host(new_row_ptrs + num_rows);
    // resize arrays and update aliases
    matrix::CsrBuilder<ValueType, IndexType> builder{m_out};
    builder.get_col_idx_array().resize_and_reset(new_nnz);
    builder.get_value_array().resize_and_reset(new_nnz);
    auto new_col_idxs = m_out->get_col_idxs();
    auto new_vals = m_out->get_values();
    // stays null when no COO output is requested
    IndexType* new_row_idxs{};
    if (m_out_coo) {
        matrix::CooBuilder<ValueType, IndexType> coo_builder{m_out_coo};
        coo_builder.get_row_idx_array().resize_and_reset(new_nnz);
        // the COO output shares column indices and values with the CSR output
        coo_builder.get_col_idx_array() =
            make_array_view(exec, new_nnz, new_col_idxs);
        coo_builder.get_value_array() =
            make_array_view(exec, new_nnz, new_vals);
        new_row_idxs = m_out_coo->get_row_idxs();
    }
    if (num_blocks > 0) {
        kernel::bucket_filter<subwarp_size>
            <<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
                old_row_ptrs, old_col_idxs, as_device_type(old_vals), oracles,
                num_rows, bucket, new_row_ptrs, new_row_idxs, new_col_idxs,
                as_device_type(new_vals));
    }
}


// Generates select_threshold_filter_approx, which takes a syn::value_list of
// candidate subwarp sizes plus a predicate, and calls the matching
// threshold_filter_approx<subwarp_size> overload above. Per the
// implementation-selection helper, this is presumably the first candidate
// accepted by the predicate.
GKO_ENABLE_IMPLEMENTATION_SELECTION(select_threshold_filter_approx,
                                    threshold_filter_approx);


/**
 * Approximate threshold filter entry point: chooses a subwarp size at runtime
 * and forwards to the matching compile-time implementation.
 *
 * A candidate subwarp size from compiled_kernels is eligible when it is at
 * least the average number of stored elements per row (integer division), or
 * when it equals config::warp_size. The warp-size case guarantees that some
 * candidate is always accepted.
 *
 * @param exec  executor to run on
 * @param m  input matrix
 * @param rank  rank used to select the threshold bucket
 * @param tmp  scratch storage, resized and overwritten
 * @param threshold  output: the approximate threshold value
 * @param m_out  output CSR matrix (row pointers must be pre-sized)
 * @param m_out_coo  optional output COO matrix (may be null)
 */
template <typename ValueType, typename IndexType>
void threshold_filter_approx(std::shared_ptr<const DefaultExecutor> exec,
                             const matrix::Csr<ValueType, IndexType>* m,
                             IndexType rank, array<ValueType>& tmp,
                             remove_complex<ValueType>& threshold,
                             matrix::Csr<ValueType, IndexType>* m_out,
                             matrix::Coo<ValueType, IndexType>* m_out_coo)
{
    const auto num_rows = m->get_size()[0];
    const auto total_nnz = m->get_num_stored_elements();
    // Avoid integer division by zero for matrices without rows. The
    // implementation skips its kernel launches in that case anyway.
    const size_type total_nnz_per_row =
        num_rows > 0 ? total_nnz / num_rows : size_type{};
    select_threshold_filter_approx(
        compiled_kernels(),
        [&](int compiled_subwarp_size) {
            // compiled sizes are positive, so the cast is value-preserving
            return total_nnz_per_row <=
                       static_cast<size_type>(compiled_subwarp_size) ||
                   compiled_subwarp_size == config::warp_size;
        },
        syn::value_list<int>(), syn::type_list<>(), exec, m, rank, &tmp,
        &threshold, m_out, m_out_coo);
}

// Explicitly instantiate the threshold_filter_approx kernel (presumably the
// runtime-dispatching overload above, whose signature the DECLARE macro is
// expected to match) for every supported value/index type combination.
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
    GKO_DECLARE_PAR_ILUT_THRESHOLD_FILTER_APPROX_KERNEL);


}  // namespace par_ilut_factorization
}  // namespace GKO_DEVICE_NAMESPACE
}  // namespace kernels
}  // namespace gko
