# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod
import atexit
import os
from pathlib import Path
from typing import List
import natsort
import numpy as np
from tqdm import tqdm

import multiprocess as mp

from mpp.core.unit_filter import UnitFilter
from mpp.core.types import VerboseLevel
from mpp.console_output import ConsoleOutput
from mpp.core.views import ViewType, ViewAggregationLevel


class _DataProcessor(ABC):
    """Abstract base for the serial and parallel partition processors.

    Stores the collaborators the concrete processors operate on and offers
    helpers for loading a partition's event dataframe, optionally filtered
    by unit.

    Attributes:
        data_accumulator: object that receives aggregate/statistics updates.
        view_generator: object exposing ``views`` and the view computations.
        view_writer: object used to persist detail views.
        event_reader: callable invoked with ``partition=`` and ``chunk_size=``
            whose result is consumed with ``next()``.
        unit_range_map: optional unit filter specification; unit filtering is
            skipped when it is falsy.
        detail_views: most recently generated detail views (``None`` initially).
    """

    def __init__(
        self,
        data_accumulator,
        view_generator,
        view_writer,
        event_reader,
        unit_range_map=None,
        verbose: VerboseLevel = VerboseLevel.INFO,
    ):
        self.data_accumulator = data_accumulator
        self.view_generator = view_generator
        self.view_writer = view_writer
        self.event_reader = event_reader
        self.unit_range_map = unit_range_map
        self._verbose = verbose

        self.detail_views = None

    @abstractmethod
    def process_partitions(self, partitions: List['Partition'], parallel_cores: int = 1, no_detail_views: bool = False):
        """Process every partition in ``partitions``; implemented by subclasses."""

    @abstractmethod
    def process_partition(self, partition: 'Partition'):
        """Process a single partition; implemented by subclasses."""

    def _get_dataframe_from_partition(self, partition: 'Partition'):
        """Read the first dataframe the event reader yields for ``partition``.

        The reader is created with ``chunk_size=0`` (presumably "no chunking" --
        confirm against the reader implementation). When a unit range map was
        supplied, the dataframe is passed through the unit filter first.
        """
        reader = self.event_reader(partition=partition, chunk_size=0)
        partition_df = next(reader)
        if not self.unit_range_map:
            return partition_df
        return self._filter_units_from_dataframe(partition_df)

    def _filter_units_from_dataframe(self, event_df):
        """Return ``event_df`` after applying a ``UnitFilter`` built from the unit range map."""
        core_types = self.__get_valid_core_types()
        return UnitFilter(self.unit_range_map, core_types).filter_units(event_df)

    def __get_valid_core_types(self):
        """Return the distinct device type names of all CORE-aggregated views."""
        distinct_type_names = {
            view.attributes.device.type_name
            for view in self.view_generator.views
            if view.attributes.aggregation_level == ViewAggregationLevel.CORE
        }
        return list(distinct_type_names)


class _SerialDataProcessor(_DataProcessor):
    """Processes partitions one after another in the current process."""

    def process_partitions(self, partitions: List['Partition'], parallel_cores: int = 1, no_detail_views: bool = False):
        """Process each partition in order, showing a tqdm bar at regular verbosity.

        ``parallel_cores`` and ``no_detail_views`` are accepted for interface
        compatibility and are not used by this implementation.
        """
        show_progress = ConsoleOutput.is_regular_verbosity(self._verbose)
        # Only wrap the partitions in a progress bar when output is wanted.
        iterable = tqdm(partitions, leave=show_progress) if show_progress else partitions
        for position, partition in enumerate(iterable):
            if show_progress:
                iterable.set_description(f'Processing partition {position + 1} out of {len(partitions)}')
            self.process_partition(partition)

    def process_partition(self, partition):
        """Load the partition's dataframe and fold it into the accumulator."""
        partition_df = self._get_dataframe_from_partition(partition)
        self.data_accumulator = self.process_dataframe(
            partition_df,
            partition.first_sample,
            partition.last_sample,
        )

    def process_dataframe(self, event_df, first_sample, last_sample):
        """Update aggregates, detail views and statistics from ``event_df``.

        Returns:
            The (updated) data accumulator.
        """
        aggregates = self.view_generator.compute_aggregates(event_df)
        self.data_accumulator.update_aggregates(aggregates)
        # Generate partial detail views for the partition and write to storage
        self.handle_detail_views(event_df, first_sample, last_sample)
        # Without detail views the statistics come straight from the dataframe.
        if not self.detail_views:
            self.data_accumulator.update_statistics(df=event_df)
        return self.data_accumulator

    def handle_detail_views(self, event_df, first_sample, last_sample):
        """Generate detail views; when any exist, write them and update statistics."""
        generated = self.view_generator.generate_detail_views(event_df)
        self.detail_views = generated
        if not generated:
            return
        self.view_writer.write(list(generated.values()), first_sample, last_sample)
        self.data_accumulator.update_statistics(generated)


class _ParallelDataProcessor(_DataProcessor):
    """Processes partitions concurrently in a pool of worker processes.

    Each worker serializes its partial summary and detail views into the
    serializer's parent directory; the parent process then deserializes those
    files and folds them into the data accumulator.
    """

    def __init__(
        self,
        data_accumulator,
        view_generator,
        view_writer,
        event_reader,
        unit_range_map=None,
        verbose: VerboseLevel = VerboseLevel.INFO,
    ):
        # Function-scope import kept from the original (presumably avoids an
        # import cycle between mpp and cli -- confirm before hoisting).
        from cli.serdes import DataViewsSerializer
        super().__init__(data_accumulator, view_generator, view_writer, event_reader, unit_range_map, verbose)
        self.serializer = DataViewsSerializer()

    def process_partitions(self, partitions: List['Partition'], parallel_cores: int = 1, no_detail_views: bool = False):
        """Process all partitions in a process pool and merge the partial results.

        Args:
            partitions: partitions to process; one pool task per partition.
            parallel_cores: requested number of worker processes. A falsy value
                means "use all available CPUs". The effective count is capped
                by the CPU count, the number of partitions and an internal
                limit of 60, and is never less than 1.
            no_detail_views: when True, partial detail views still feed the
                statistics but are not written out through the data merger.
        """
        from cli.serdes import DataViewsDeserializer
        MAX_NUMBER_OF_CORES = 60
        parallel = parallel_cores
        with self.serializer.parent_dir as tmp_dir:
            if parallel_cores and parallel_cores > MAX_NUMBER_OF_CORES:
                print(f'Warning: parallel processing on greater than {MAX_NUMBER_OF_CORES} cores is not currently '
                    f'supported')
            if not parallel_cores:
                parallel = None
                parallel_cores = mp.cpu_count()
            # A pool needs at least one worker: without the lower bound an empty
            # partition list (or a negative request) makes the pool constructor
            # raise ValueError.
            number_of_cores = max(1, min(mp.cpu_count(), parallel_cores, len(partitions), MAX_NUMBER_OF_CORES))
            limited_by_partitions = number_of_cores == len(partitions) and len(partitions) != parallel_cores
            partition_str = f' due to {len(partitions)} partitions in the input file' if limited_by_partitions else ''
            parallel_option_str = ' (use -p to specify number of processes)' if not parallel else ''
            if ConsoleOutput.is_regular_verbosity(self._verbose):
                print(f'Processing in parallel with {number_of_cores} out of {mp.cpu_count()} processes' + partition_str
                    + parallel_option_str + '...', end='', flush=True)
            # Make sure serialized temp data is cleaned up when the interpreter exits.
            atexit.register(self.serializer.cleanup)
            with mp.Pool(processes=number_of_cores) as pool:
                pool.map(self.process_partition, partitions)
                # Let workers finish and exit cleanly before the context
                # manager tears the pool down.
                pool.close()
                pool.join()
            # Initialize a deserializer for views
            deserializer = DataViewsDeserializer()
            self.data_accumulator = self._handle_temp_detail_view_files(no_detail_views, deserializer, tmp_dir)
            self.data_accumulator = self._handle_temp_summary_view_files(deserializer, tmp_dir)

    def process_partition(self, partition: 'Partition'):
        """Worker task: compute the partition's views and serialize them.

        Summary aggregates are always written; detail views are written only
        when the view generator produces any.
        """
        event_df = self._get_dataframe_from_partition(partition)
        if ConsoleOutput.is_regular_verbosity(self._verbose):
            # One dot per processed partition as lightweight progress output.
            print('.', end='', flush=True)
        # Generate partial detail views for the partition and write to storage
        summary_computations = self.view_generator.compute_aggregates(event_df)
        self.serializer.write_views(summary_computations, partition=partition)
        detail_views = self.view_generator.generate_detail_views(event_df)
        if detail_views:
            self.serializer.write_views(list(detail_views.values()), partition=partition)

    def _handle_temp_detail_view_files(self, no_detail_views, deserializer, tmp_dir):
        """Deserialize the partial detail view files and merge them.

        Files in ``tmp_dir`` whose name contains ``__<DETAILS>__`` are read in
        natural sort order. Each one updates the accumulator's statistics and,
        unless ``no_detail_views`` is set, is written through a ``DataMerger``.

        Returns:
            The (updated) data accumulator.
        """
        from cli.writers.views import DataMerger
        # Deserialize and combine all partial detail views
        marker = '__' + ViewType.DETAILS.name + '__'
        detail_view_files = natsort.natsorted([name for name in os.listdir(tmp_dir) if marker in name])
        data_merger = DataMerger(self.view_writer)
        include_details = not no_detail_views
        show_progress = ConsoleOutput.is_regular_verbosity(self._verbose)
        progress_bar = tqdm(detail_view_files) if show_progress else detail_view_files
        for idx, filename in enumerate(progress_bar):
            if show_progress:
                progress_bar.set_description(f'\nWriting partition {idx + 1} out of {len(progress_bar)} to CSV...')
            detail_views, partition = deserializer.read_views(Path(tmp_dir) / filename)
            if include_details:
                data_merger.write_to_detail_views(detail_views, partition)
            self.data_accumulator.update_statistics(detail_views)
        return self.data_accumulator

    def _handle_temp_summary_view_files(self, deserializer, tmp_dir):
        """Deserialize the partial summary view files and fold them into the aggregates.

        Files in ``tmp_dir`` whose name contains ``__<SUMMARY>__`` are read in
        natural sort order.

        Returns:
            The (updated) data accumulator.
        """
        marker = '__' + ViewType.SUMMARY.name + '__'
        summary_view_files = [name for name in os.listdir(tmp_dir) if marker in name]
        for filename in natsort.natsorted(summary_view_files):
            summary_views, _partition = deserializer.read_views(Path(tmp_dir) / filename)
            self.data_accumulator.update_aggregates(list(summary_views.values()))
        return self.data_accumulator


class DataProcessorIds:
    """String identifiers for the available data processor implementations."""

    # Identifier of the multi-process implementation.
    PARALLEL = 'parallel'
    # Identifier of the single-process implementation.
    SERIAL = 'serial'

class DataProcessorFactory:
    """Creates the serial or parallel data processor on request."""

    # Maps a processor identifier to the class implementing it.
    data_processors = {
        DataProcessorIds.PARALLEL: _ParallelDataProcessor,
        DataProcessorIds.SERIAL: _SerialDataProcessor,
    }

    def create(self, is_parallel, data_accumulator, view_generator, view_writer, event_reader, unit_filters, verbose=0):
        """Instantiate a data processor.

        Args:
            is_parallel: truthy selects the parallel processor, otherwise the
                serial one is used.
            data_accumulator, view_generator, view_writer, event_reader:
                collaborators forwarded unchanged to the processor.
            unit_filters: forwarded as the processor's ``unit_range_map``.
            verbose: verbosity level forwarded to the processor.

        Returns:
            The newly constructed processor instance.
        """
        processor_id = DataProcessorIds.PARALLEL if is_parallel else DataProcessorIds.SERIAL
        processor_cls = self.data_processors[processor_id]
        return processor_cls(data_accumulator, view_generator, view_writer, event_reader, unit_filters, verbose)
