# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT

import logging
import re
import types
from typing import Dict, Callable, List, Tuple, Union

import numpy as np
import pandas as pd

from mpp.core.internals.code_generator import generate_numpy_vectorized_code
from mpp.core.internals.code_generator import is_number
from mpp.core.types import MetricDefinition, RawDataFrameColumns as rdc, ThresholdDefinition

# Bracketed index suffix: "[n]", "[a:b]", "[a:b:s]" or "[:b]".
# NOTE(review): unlike EVENT_WITH_INDICES_RE below, this does not match the open-ended "[a:]" form.
INDICES_RE = re.compile(r"\[(\d+:\d+(?::\d+)?|\d+|:\d+)\]")
# An event reference in generated code, e.g. df['INST_RETIRED.ANY'] (name made of word characters and dots).
EVENT_RE = re.compile(r"df\['(\w|\.)+']")

# An event reference immediately followed by an index suffix: "[n]", "[a:b]", "[a:b:s]", "[:b]" or "[a:]".
EVENT_WITH_INDICES_RE = re.compile(r"df\['(\w|\.)+'](\[\d+\]|\[\d+\:\d+(?::\d+)?\]|\[:\d+\]|\[\d+:\])")
# An event reference (the name may also contain ':') that is NOT followed by '['.
EVENT_NO_INDICES_RE = re.compile(r"df\['(\w|\.|\:)+'](?!\[+)")
# A single-quoted name made only of lowercase letters, digits, dots and underscores (e.g. 'system.socket_count');
# _SliceMetricCompiler treats references matching this pattern as constants.
CONSTANT_NAME_RE = re.compile(r"\'([a-z]|\.|[0-9]|_)+\'")


class _MetricCompiler:
    """
    Generate executable Python source code from `MetricDefinition` objects.

    `compile` returns the source of a module defining `get_metrics_mapping()`, which returns a dict that maps each
    metric name to a `(vectorized_metric_function, metric_function_source_code)` tuple.
    """
    # Header of the generated module: imports, numpy/pandas warning suppression and the opening of the mapping dict
    __vectorized_compute_metric_code_prefix = (
        'import numpy as np\n'
        'import pandas as pd\n'
        'from warnings import simplefilter\n\n\n'
        'def get_metrics_mapping():\n'
        '    np.seterr(all=\'ignore\')\n'
        '    simplefilter(action=\'ignore\', category=pd.errors.PerformanceWarning)\n'
        '    metric_mapping = {\n'
        '        ')

    # Closes the mapping dict and returns it from get_metrics_mapping()
    __vectorized_compute_metric_code_suffix = (
        '\n'
        '        }\n'
        '    return metric_mapping\n'
    )

    # Attribute of a metric definition whose value is used as the key of the generated mapping
    name_attr = "name"

    def compile(self, metrics: List[MetricDefinition], symbol_table=None, event_info: Tuple[pd.DataFrame, int] = None) \
            -> str:
        """
        Generate the source code of a module that computes the input metrics with vectorized numpy/pandas code.

        :param metrics: List[MetricDefinition] List of input metrics for which vectorized code will be generated
        :param symbol_table: Dictionary that maps constant names, referenced in metric formulas,
                             to their values. Symbols whose values may change per view ('system.socket_count',
                             'SOCKET_COUNT', 'system.*per_socket*' and retire latency symbols) are not inlined.
        :param event_info: passed unchanged to `_create_metric_function_body`

        :return: Python source code that defines `get_metrics_mapping()`. Metrics whose generated function body
                 is empty are omitted from the mapping.
        """
        if symbol_table is None:
            symbol_table = {}
        inlined_system_constants = self.__filter_symbols_to_inline(symbol_table)

        metric_functions = []
        for metric in metrics:
            metric_function_body = self._create_metric_function_body(metric, inlined_system_constants, event_info)
            if not metric_function_body:
                continue
            metric_source_code = self._create_metric_function(metric_function_body)
            metric_name = str(getattr(metric, self.name_attr))
            # repr() emits correctly quoted and escaped string literals, so names or formulas containing double
            # quotes or backslashes neither break nor silently alter the generated code
            metric_functions.append(f'{metric_name!r}: '
                                    f'({metric_source_code}, {metric_source_code!r}),')
        vectorized_compute_metric_code_body = '\n        '.join(metric_functions)
        vectorized_compute_metric_code = f'{self.__vectorized_compute_metric_code_prefix}' \
                                         f'{vectorized_compute_metric_code_body}' \
                                         f'{self.__vectorized_compute_metric_code_suffix}'

        return vectorized_compute_metric_code

    @staticmethod
    def _resolve_metric_namespace(metric_namespace: Dict, system_constants: Dict) -> Dict:
        """
        Resolves constant references in metric namespace

        :param metric_namespace: a dict representing the metric definition's constants and aliases
        :param system_constants: a dict that maps system constants (e.g., number of sockets) to their values

        :return: a combined dictionary that maps all names referenced by the metric definition to their values.
                 When there are no system constants and no 'retire_latencies' entry, `metric_namespace` itself is
                 returned unchanged.
        """

        def adjust_type(v):
            # str() first: constant values are not guaranteed to be strings, and only str has isnumeric()
            value_str = str(v)
            return int(value_str) if value_str.isnumeric() else value_str

        # NOTE(review): `_create_metric_namespace` flattens retire latencies into the namespace rather than storing
        # them under a 'retire_latencies' key, so this lookup presumably always yields {} — confirm
        retire_latencies = metric_namespace.get('retire_latencies', {})
        if not system_constants and not retire_latencies:
            return metric_namespace

        resolved_metric_namespace = _MetricCompiler.__resolve_constant_namespace(adjust_type, metric_namespace,
                                                                                 retire_latencies, system_constants)
        resolved_metric_namespace.update(system_constants)
        return resolved_metric_namespace

    @staticmethod
    def __resolve_constant_namespace(adjust_type, metric_namespace, retire_latencies,
                                     system_constants):
        """
        Resolve every alias of `metric_namespace` (except 'metric_name') to a value, in priority order:
        a retire latency whose alias is a system constant, an alias that names a system constant, a key that is
        itself a system constant, and finally the alias value converted by `adjust_type`.
        """
        resolved_metric_namespace = {}
        constant_aliases = {key: value for key, value in metric_namespace.items() if key != 'metric_name'}
        for key, value in constant_aliases.items():
            if key in retire_latencies and value in system_constants.keys():
                resolved_metric_namespace[key] = system_constants[retire_latencies[key]]
            elif value in system_constants.keys():
                resolved_metric_namespace[key] = system_constants[value]
            elif key in system_constants.keys():
                resolved_metric_namespace[key] = system_constants[key]
            else:
                resolved_metric_namespace[key] = adjust_type(value)
        return resolved_metric_namespace

    def _create_metric_function_body(self,
                                     metric: MetricDefinition,
                                     system_constants: Dict,
                                     event_info: Tuple[pd.DataFrame, int] = None) -> str:
        """
        Generate the vectorized metric function for input metric.

        :param metric: metric definition
        :param system_constants: a dict that maps system constants (e.g., number of sockets) to their values.
                                 The values of all constants in this dict will be inlined into the generated code.
        :param event_info: not used by this implementation; accepted so subclasses share the signature

        :return: a string corresponding to vectorized metric computation with system constant alias
                 replaced with corresponding values
        """
        # Create function namespace
        namespace = self._create_metric_namespace(metric)
        resolved_namespace = self._resolve_metric_namespace(namespace, system_constants)
        function_body = generate_numpy_vectorized_code(metric.formula, resolved_namespace, getattr(metric,
                                                                                                   self.name_attr))
        return function_body

    def _create_metric_function(self, function_body: str):
        """Wrap a generated expression into the source of a one-argument lambda over the event dataframe `df`."""
        return f'lambda df: {function_body}'

    def _create_metric_namespace(self, metric_definition: MetricDefinition) -> dict:
        """
        Merge the metric's name, event aliases, constants and retire latencies into one dict.
        Later sources override earlier ones on key collisions.
        """
        namespace = {'metric_name': metric_definition.name}
        namespace.update(metric_definition.event_aliases)
        namespace.update(metric_definition.constants)
        namespace.update(metric_definition.retire_latencies)
        return namespace

    def __filter_symbols_to_inline(self, symbol_table):
        """Return the entries of `symbol_table` whose values can be inlined into the generated code."""
        # These symbols should not be inlined because their values may change per view
        symbols_that_should_not_be_inlined = {'system.socket_count', 'SOCKET_COUNT'}
        symbols_that_should_not_be_inlined.update(
            symbol for symbol in symbol_table
            if self.__is_system_per_socket_symbol(symbol) or self.__is_retire_latency_symbol(symbol))
        return {name: value for name, value in symbol_table.items() if name not in symbols_that_should_not_be_inlined}

    @staticmethod
    def __is_retire_latency_symbol(symbol):
        return 'retire_latency' in symbol.lower()

    @staticmethod
    def __is_system_per_socket_symbol(symbol):
        return 'per_socket' in symbol.lower() and 'system.' in symbol.lower()


class CompiledMetric:
    """
    A compiled metric that can be executed as a regular python function.

    NameError, TypeError, KeyError and pandas IntCastingNaNError raised by the compiled function are not
    propagated: the call returns `np.nan` instead, and a warning is logged only for the first failure.
    """

    def __init__(self, metric_def: Union[MetricDefinition, ThresholdDefinition],
                 compiled_metric_func: Callable,
                 metric_func_source_code: Union[str, None] = None):
        self.__metric_def = metric_def
        self.__metric_function: Callable = compiled_metric_func
        self.__metric_source_code = metric_func_source_code
        # Cleared after the first logged failure so repeated calls do not flood the log
        self._log_error: bool = True

    def __call__(self, *args, **kwargs):
        """
        Execute the compiled metrics
        :param args: additional arguments
        :param kwargs: additional optional/keyword arguments
        :return: computed metric, or np.nan if the computation raised one of the handled exceptions
        """
        try:
            return self.__metric_function(*args, **kwargs)
        except (NameError, TypeError, KeyError, pd.errors.IntCastingNaNError) as e:
            return self._handle_exception(e)

    @property
    def definition(self) -> Union[MetricDefinition, ThresholdDefinition]:
        """
        :return: the metric (or threshold) definition this function was compiled from
        """
        return self.__metric_def

    @property
    def source_code(self) -> Union[str, None]:
        """
        :return: metric function source code, or None if it was not provided
        """
        return self.__metric_source_code

    def _handle_exception(self, e):
        """Log the first failure of this metric and return NaN as the metric value."""
        if self._log_error:
            logging.warning("Cannot calculate '%s'. %s", self.definition.name, e)
            self._log_error = False
        return np.nan


class MetricComputer:
    """
    Compute metrics on a given dataframe
    """
    # Incremented once per instance; gives every generated code module a distinct name
    __module_index = 0

    GROUP_INDEX_NAME = rdc.GROUP
    TIMESTAMP_INDEX_NAME = rdc.TIMESTAMP
    # Default compiler; an instance attribute with the same name (set before __init__ runs) takes precedence
    metric_compiler = _MetricCompiler()

    def __init__(self,
                 metric_definitions: List[MetricDefinition],
                 symbol_table: Dict = None,
                 event_info: Tuple[pd.DataFrame, int] = None):
        """
        Initialize the metric computer

        :param metric_definitions: List of metrics to compute
        :param symbol_table: Dictionary that maps constant names, referenced in metric formulas,
                             to their values. The dictionary is copied; the caller's object is not modified.
        :param event_info: passed unchanged to the metric compiler
        """
        self.__group_index_name = None
        self.__timestamp_index_name = None
        self._metric_definitions = metric_definitions
        # Work on a private copy: __update_symbol_table adds derived 'system.*' entries, which must not leak into
        # the caller's dictionary or into a default shared by every instance
        self.__symbol_table = {} if symbol_table is None else dict(symbol_table)
        self.__update_symbol_table(metric_definitions)
        self.__block_level_metrics_requested = False
        MetricComputer.__module_index += 1
        self._compiled_metrics = self._generate_compiled_metrics(metric_definitions, self.__symbol_table,
                                                                 event_info)

    @property
    def symbol_table(self):
        """
        :return: a copy of this computer's symbol table (the initial symbol table plus derived constants)
        """
        return self.__symbol_table.copy()

    @property
    def compiled_metrics(self) -> List[CompiledMetric]:
        """
        :return: a list of compiled metrics
        """
        return self._compiled_metrics.copy()

    @property
    def metric_definitions(self) -> List[MetricDefinition]:
        """
        :return: a list of metric definitions
        """
        return self._metric_definitions.copy()

    def update_compiled_metrics(self, compiled_metrics: List[CompiledMetric]):
        """Replace the list of compiled metrics used by `compute_metric`."""
        self._compiled_metrics = compiled_metrics

    def compute_metric(self, df: pd.DataFrame,
                       constant_values: Dict = None,
                       calculate_block_level: bool = False,
                       group_index_name: str = GROUP_INDEX_NAME,
                       timestamp_index_name: str = TIMESTAMP_INDEX_NAME) -> pd.DataFrame:
        """
        Compute metrics on the input dataframe

        :param df: input dataframe containing event counts.
                   `df` is expected to have the following structure:
                   - Columns: event names, where each column represents a single event
                   - Index: If `calculate_block_level` is True, `df` must have a multi-index where the first two levels
                            represent timestamp and event group id.
                            `df` can have additional levels, e.g. for socket, core, thread...
                   - Rows: event counts for each event
        :param constant_values: an optional dictionary that maps constant expressions, used in metric formulas,
                                to their values (e.g. 'system.socket_count')
        :param calculate_block_level: whether to calculate block level metrics in addition to sample level metrics
        :param group_index_name: index level name in `df` to use for determining event group id.
                                 Used only when `calculate_block_level` is True.
        :param timestamp_index_name: index level name in `df` to use for determining the timestamp.
                                     Used only when `calculate_block_level` is True.

        :return: a new dataframe where each column is a metric. The index of the result dataframe is identical to `df`
        :raises KeyError: if block-level metrics are requested and the index names are not in the first two levels
        """

        def verify_preconditions():
            if calculate_block_level:
                expected_index_names = df.index.names[:2]
                if group_index_name not in expected_index_names or timestamp_index_name not in expected_index_names:
                    raise KeyError(f'Unable to calculate block-level metrics: '
                                   f'"{group_index_name}" and "{timestamp_index_name}" must be in the input '
                                   f'dataframe index')

        verify_preconditions()
        self.__group_index_name = group_index_name
        self.__timestamp_index_name = timestamp_index_name
        self.__block_level_metrics_requested = calculate_block_level
        df_with_constants = self.__add_constant_values_to_input_dataframe(df, constant_values)
        sample_level_metrics_df = self.__calculate_sample_level_metrics(df_with_constants, self.compiled_metrics)
        block_level_metrics_df = self.__calculate_block_level_metrics(
            df_with_constants, sample_level_metrics_df, self.compiled_metrics)
        all_metrics_df = self.__merge_sample_and_block_results(sample_level_metrics_df, block_level_metrics_df)
        all_metrics_df = all_metrics_df.astype(float)
        return all_metrics_df

    @staticmethod
    def __merge_sample_and_block_results(sample_level_metrics_df, block_level_metrics_df):
        """Overlay the non-NaN block-level values onto the sample-level results (in place) and return them."""
        if block_level_metrics_df.empty:
            return sample_level_metrics_df

        sample_level_metrics_df.update(block_level_metrics_df)
        return sample_level_metrics_df

    @staticmethod
    def _import_code(code, module_name):
        """
        Execute `code` in a new module object named `module_name` and return that module.
        The module is not registered in `sys.modules`.
        """
        # NOTE(security): exec runs arbitrary code. The source is generated from metric formulas, so metric
        # definition files must come from a trusted source.
        module = types.ModuleType(module_name)
        exec(code, module.__dict__)
        return module

    def _generate_compiled_metrics(self, metric_definition_list, symbol_table, event_info):
        """Compile the metric definitions into generated code and wrap each resulting function in a CompiledMetric."""
        vectorized_compute_metric_code = self.metric_compiler.compile(metric_definition_list, symbol_table, event_info)
        generated_module_name = f'generated_code_{self.__module_index}'
        generated_module = self._import_code(vectorized_compute_metric_code, generated_module_name)
        metric_mapping = generated_module.get_metrics_mapping()

        compiled_metrics = self._get_compiled_metrics(metric_mapping)

        return compiled_metrics

    def _get_compiled_metrics(self, metric_mapping):
        """Build CompiledMetric objects, in definition order, for every metric present in `metric_mapping`."""
        compiled_metrics = []
        for metric in self._metric_definitions:
            metric_name = getattr(metric, self.metric_compiler.name_attr)
            if metric_name in metric_mapping:
                func, source = metric_mapping[metric_name]
                compiled_metrics.append(CompiledMetric(metric, func, source))
        return compiled_metrics

    @staticmethod
    def __all_metric_references_are_available(metric_def: MetricDefinition,
                                              df: pd.DataFrame,
                                              system_symbols: Dict) -> bool:
        """
        Check that every event alias, non-numeric constant and retire latency referenced by `metric_def` is a
        column of `df` (constants may also be provided by `system_symbols`). Logs the first missing reference.
        """
        for event in metric_def.event_aliases:
            symbol_name = metric_def.event_aliases[event]
            if symbol_name not in df.columns:
                logging.debug(f'Excluding \'{metric_def.name}\' from reports because {symbol_name} is unavailable.')
                return False

        for constant in metric_def.constants:
            symbol_name = metric_def.constants[constant]
            if not is_number(symbol_name) and \
                    symbol_name not in system_symbols and \
                    symbol_name not in df.columns:
                logging.debug(f'Excluding \'{metric_def.name}\' from reports because {symbol_name} is missing or '
                              f'invalid.')
                return False

        for retire_latency in metric_def.retire_latencies:
            symbol_name = metric_def.retire_latencies[retire_latency]
            if symbol_name not in df.columns:
                logging.debug(f'Excluding \'{metric_def.name}\' from reports because {symbol_name} is unavailable.')
                return False

        return True

    def __calculate_sample_level_metrics(self,
                                         df: pd.DataFrame,
                                         metrics_to_compute: List[CompiledMetric]) -> pd.DataFrame:
        """
        Compute sample-level metrics for the input dataframe.
        Sample level metrics are only calculated for rows in the input dataframe that contain all values
        references by the metric formula.

        :param df: input dataframe
        :param metrics_to_compute: list of metrics to compute

        :return: a dataframe with calculated sample-level metrics, and an index identical to `df`.
        """
        result_df = pd.DataFrame()
        for compiled_metric in metrics_to_compute:
            if self.__all_metric_references_are_available(compiled_metric.definition, df, self.__symbol_table):
                result_df[compiled_metric.definition.name] = compiled_metric(df)
        return result_df

    def __calculate_block_level_metrics(self,
                                        df: pd.DataFrame,
                                        sample_level_result: pd.DataFrame,
                                        metrics_to_compute: List[CompiledMetric]) -> pd.DataFrame:
        """
        Compute block-level metrics for the input dataframe.

        Block-level metrics are metrics that require events across more than one sample within a block.
        For these metrics, we calculate the block average for all of the events, then use these averages to calculate
        the metric values.
        We then apply the block average metric values to all metrics that weren't calculated per sample,
        and store this value in the last timestamp of the block. This allows the metric to be averaged with all
        the other events/metrics into larger intervals without distorting the data.

        :param df: input dataframe
        :param sample_level_result: dataframe containing the computed sample-level metrics, indexed by timestamp
        :param metrics_to_compute: list of metrics to compute

        :return: a dataframe with calculated block-level metrics, indexed by the same levels as the index of `df`.
                 Empty when block-level metrics were not requested or their lengths do not match the block index.
        """
        if not self.__block_level_metrics_requested:
            return pd.DataFrame()

        block_avg_df = self.__compute_event_avg_per_group(df)

        # Compute metrics only for those whose sample-level values are entirely NaN
        block_result_df = pd.DataFrame()
        block_result = {compiled_metric.definition.name: np.array(compiled_metric(block_avg_df)).flatten()
                        for compiled_metric in metrics_to_compute
                        if self.__all_metric_references_are_available(compiled_metric.definition,
                                                                      df, self.__symbol_table)
                        and sample_level_result[compiled_metric.definition.name].isnull().values.all()}

        if block_result:
            try:
                block_index = self._get_block_df_index(block_avg_df)
                block_result_df = pd.DataFrame(block_result, block_index)
            except ValueError as e:
                if 'Length of values' not in str(e):
                    raise
                logging.warning(f'Unable to calculate block-level metrics: {e}')
                block_result_df = pd.DataFrame()
        return block_result_df

    def __compute_event_avg_per_group(self, df):
        # Compute event average values for each device (socket, core, thread) in each group
        index_names = list(df.index.names)
        index_names.remove(self.__timestamp_index_name)
        block_avg_df = df.groupby(index_names).mean(numeric_only=True)

        # Associate the event averages with the group's max timestamp.
        # This is done by recreating the index, adding each group's max timestamp to the index,
        # and reusing the original index levels order
        block_avg_df.index = pd.MultiIndex.from_frame(
            df.index.to_frame()[self.__timestamp_index_name].groupby(
                index_names).max().reset_index()[df.index.names])
        return block_avg_df

    def _get_block_df_index(self, df):
        # Used for _SliceMetricComputer to get proper block df index
        return df.index

    def __update_symbol_table(self, metric_definition_list: List[MetricDefinition]):
        """Add derived 'system.*per_socket' constants from the symbol table and from every metric's constants."""
        self.__update_symbol_table_constants()
        self.__update_symbol_table_metric_constants(metric_definition_list)

    def __update_symbol_table_constants(self):
        updated_symbol_table_constants = self.__update_system_constants(self.__symbol_table)
        self.__symbol_table.update(updated_symbol_table_constants)

    def __update_symbol_table_metric_constants(self, metric_definition_list: List[MetricDefinition]):
        for metric in metric_definition_list:
            constants = self.__update_system_constants(metric.constants)
            self.__symbol_table.update(constants)

    @staticmethod
    def __update_system_constants(constants: Dict[str, str]):
        """
        Special logic to extract the value out of 'per_socket' metric constants so these can be set to 1 for uncore
        views. Any special logic to change the symbol_table that requires the metric_constants should go here.

        This only needs to be done if there is a hardcoded metric constant in the metric file that needs to be updated
        later for a specific view.
        i.e. {'channels_populated_per_socket': 8} becomes {'channels_populated_per_socket':
        'system.channels_populated_per_socket', 'system.channels_populated_per_socket': 8}
        @param constants: a dictionary of a metric constant name and corresponding value
        @return: An updated metric constant dictionary for (previously) hardcoded metric constants
        """
        updated_system_constants = {}
        for name, value in constants.items():
            if str(value).isnumeric() and name.lower().endswith('per_socket') and 'system.' not in name.lower():
                system_constant = f'system.{name}'
                updated_system_constants[name] = system_constant
                updated_system_constants[system_constant] = float(value)
        return updated_system_constants

    def __add_constant_values_to_input_dataframe(self, df: pd.DataFrame, constant_values: dict):
        """
        Return a copy of `df` with a column for each constant that is not already a column of `df`,
        or `df` itself when there are no constants.
        """
        if not constant_values:
            return df
        try:
            df_with_constants = df.copy()
        except Exception as e:
            # Always raises
            self.__handle_exceptions(e)

        for constant_name, constant_value in constant_values.items():
            # If the symbol already exists in the dataframe, don't replace it with a constant
            if constant_name not in df.columns:
                df_with_constants[constant_name] = constant_value
        return df_with_constants

    def __handle_exceptions(self, e: Exception):
        """
        Re-raise `e`: allocation failures become DataFrameMemoryError, anything else a generic Exception.
        The original exception is chained as the cause in both cases.
        """
        error_type = type(e).__name__
        default_exception_msg = f" ({error_type}, message: {str(e)})"
        if isinstance(e, MemoryError) and 'Unable to allocate' in str(e):
            raise DataFrameMemoryError(original_exception=e) from e
        raise Exception(default_exception_msg) from e

class _SliceMetricCompiler(_MetricCompiler):
    """
    Metric compiler for formulas that reference events with a unit index suffix, e.g. df['EVENT'][0:2].

    After the base class generates a function body, every event reference in it is aggregated over
    `aggregation_level_columns`. Sliced references are first masked to the selected values of the 'unit'
    index level. Slice end indices are inclusive.
    """

    def __init__(self, aggregation_level_columns: List[str]):
        self.aggregation_level_columns = aggregation_level_columns
        # Position of each aggregation column, keyed by column name
        self.multiindex_map = {column: idx for idx, column in enumerate(
            self.aggregation_level_columns)}

    def _create_metric_function_body(self,
                                     metric: MetricDefinition,
                                     system_constants: Dict,
                                     event_info: Tuple[pd.DataFrame, int] = None) -> str:
        """
        Generate the vectorized metric function for input metric.

        :param metric: metric definition
        :param system_constants: a dict that maps system constants (e.g., number of sockets) to their values.
                                 The values of all constants in this dict will be inlined into the generated code.
        :param event_info: forwarded to the base implementation

        :return: a string corresponding to vectorized metric computation with system constant alias
                 replaced with corresponding values
        """
        metric_function_body = super()._create_metric_function_body(metric, system_constants, event_info)
        updated_metric_function_body = self._resolve_metric_function(metric_function_body)
        return updated_metric_function_body

    def _resolve_metric_function(self, formula: str):
        """
        Rewrite the event references of a generated function body that contains an index suffix.
        Bodies without a suffix, and empty bodies, are returned unchanged.
        """
        updated_formula = formula
        # Guard against an empty/None body, which re.search would reject with a TypeError
        if formula and re.search(INDICES_RE, formula):  # TODO: add exception handling
            # Plain references first: the rewritten text is followed by '.', so the sliced-reference
            # pattern below cannot match it again
            updated_formula = re.sub(EVENT_NO_INDICES_RE,
                                     lambda x: self.__handle_non_sliced_events(x.group()),
                                     formula)
            updated_formula = re.sub(EVENT_WITH_INDICES_RE,
                                     lambda x: self.__handle_sliced_events(x.group()), updated_formula)
        return updated_formula

    def __handle_sliced_events(self, event_with_indices):
        """Rewrite e.g. "df['X'][1:3]" into an aggregation of df['X'] restricted to the selected units."""
        event_name = re.search(EVENT_RE, event_with_indices).group()
        units = f"{event_name}.index.get_level_values('unit')"
        start, separator, end = self.__get_indices_str(event_with_indices).partition(':')
        if separator and not end:
            # "[a:]" has no upper bound at code-generation time: keep every unit id >= a
            unit_mask = f"{units} >= {int(start)}"
        else:
            unit_mask = f"{units}.isin({self.__get_indices(event_with_indices)})"
        groupby_str = self.__get_groupby_str()
        agg_str = self.__get_agg_str(event_name)
        updated_event = f"{event_name}.where({unit_mask}){groupby_str}" \
                        f"{agg_str}"
        return updated_event

    @staticmethod
    def __get_indices_str(event_with_indices):
        """Return the text inside the trailing index brackets, e.g. "1:3" for "df['X'][1:3]"."""
        return event_with_indices[event_with_indices.rindex('[') + 1:].rstrip(']')

    def __get_indices(self, event_with_indices):
        """
        Convert a bounded index suffix into the unit ids it selects (end index inclusive):
        "[n]" -> [n], "[a:b]" -> range(a, b + 1), "[a:b:s]" -> range(a, b + 1, s), "[:b]" -> range(0, b + 1).
        """
        # An omitted start index means "from unit 0"
        indices = [int(x) if x else 0 for x in self.__get_indices_str(event_with_indices).split(':')]
        if self.__start_and_end_indices_exist(indices):
            indices = self.__get_range_of_indices(indices)
        return indices

    def __handle_non_sliced_events(self, event_without_indices):
        """Wrap a reference without an index suffix into its aggregation over all units."""
        groupby_str = self.__get_groupby_str()
        agg_str = self.__get_agg_str(event_without_indices)
        updated_event = f"({event_without_indices}" \
                        f"{groupby_str}{agg_str})"
        return updated_event

    def __get_groupby_str(self):
        """Group by the aggregation columns, or turn the series into a single-column frame when there are none."""
        if self.aggregation_level_columns:
            groupby_str = f".groupby({self.aggregation_level_columns})"
        else:
            groupby_str = f".to_frame('{rdc.VALUE}')"
        return groupby_str

    @staticmethod
    def __get_agg_str(event_name):
        """Average constant references; sum everything else (min_count=1 keeps all-NaN groups NaN)."""
        constant_name_search = re.search(CONSTANT_NAME_RE, event_name)
        if constant_name_search:  # TODO: make applicable to all constants not 'inlined' and retire latency
            return '.mean()'
        return ".sum(min_count=1)"

    @staticmethod
    def __start_and_end_indices_exist(indices):
        return 2 <= len(indices) <= 3

    @staticmethod
    def __get_range_of_indices(indices):
        start_index = indices[0]
        end_index_inclusive = indices[1] + 1
        step_size = indices[2] if len(indices) > 2 else 1
        indices = range(start_index, end_index_inclusive, step_size)
        return indices


class _SliceMetricComputer(MetricComputer):
    """
    MetricComputer limited to metrics whose formulas contain an index suffix (e.g. df['EVENT'][0:3]).
    Those metrics are compiled by a _SliceMetricCompiler that aggregates over `aggregation_level_columns`.
    """

    def __init__(self, metric_definition_list, symbol_table, aggregation_level_columns: List[str]):
        sliced_definitions = self.__get_slice_metrics(metric_definition_list)
        # Instance attribute shadows MetricComputer.metric_compiler; it must exist before the base
        # initializer compiles the metrics
        self.metric_compiler = _SliceMetricCompiler(aggregation_level_columns)

        super().__init__(sliced_definitions, symbol_table)

    def _get_block_df_index(self, df):
        # One index entry per aggregation group, in groupby order
        grouped_by_level = df.groupby(self.metric_compiler.aggregation_level_columns)
        return grouped_by_level.first().index

    @staticmethod
    def __get_slice_metrics(metric_definition_list):
        return [definition for definition in metric_definition_list
                if INDICES_RE.search(definition.formula) is not None]


class DataFrameMemoryError(Exception):
    """
    Error reporting that memory could not be allocated while processing a DataFrame.

    The message suggests remedies to the user. When `original_exception` is given, the message also ends with
    the type and message of that underlying error.
    """

    _DEFAULT_MESSAGE = (" A memory allocation error occurred while "
                        "processing a DataFrame. To resolve this issue, you can try one of the following: "
                        "1. Adjust the chunk size using the '--chunk-size' command-line option, or "
                        "2. Close any other memory-intensive applications running in the background.")

    def __init__(self, original_exception=None, message=_DEFAULT_MESSAGE):
        details = ''
        if original_exception:
            details = (f" (Original error type: {type(original_exception).__name__}, "
                       f"message: {original_exception})")
        super().__init__(message + details)
