# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: MIT

import logging
import re
from abc import abstractmethod, ABC
from dataclasses import fields
from typing import Dict

from mpp.core.types import MetricDefinition, ThresholdDefinition


class _MetricFields(ABC):
    """
    Base class for readers that expose the individual fields of one metric definition.

    Subclasses wrap a raw metric definition (kept in ``self.metric_def``) and implement one
    property per field. The ``fields`` property gathers those property values in the order of
    the ``MetricDefinition`` dataclass fields.
    """

    def __init__(self, metric_def):
        # key = legacy constant name, value = current constant name
        self.legacy_constant_map = {
            'system.cha_count/system.socket_count': 'CHAS_PER_SOCKET',
        }
        self.metric_def = metric_def

    @property
    def fields(self):
        """Tuple with this object's value for every ``MetricDefinition`` dataclass field, in field order."""
        values = [getattr(self, definition_field.name) for definition_field in fields(MetricDefinition)]
        return tuple(values)

    @staticmethod
    def _adjust_formula_if_sampling_time(name: str, formula: str):
        """
        Convert ``formula`` to Python syntax, dividing it by ``samples`` for the sampling time metric.

        The division yields the average sampling time per sample in the aggregate.

        :param name: metric name
        :param formula: metric formula as read from the definition
        :return: the formula after conversion by ``_ToPythonConverter``
        """
        is_sampling_time = name == "metric_EDP EMON Sampling time (seconds)"
        expression = (formula + ' / samples') if is_sampling_time else formula
        return _ToPythonConverter(expression).convert()

    @staticmethod
    def _adjust_constants_if_sampling_time(name: str, constants: Dict):
        """
        Add the ``samples`` constant (bound to ``$processed_samples``) for the sampling time metric.

        :param name: metric name
        :param constants: constants dictionary; updated in place when the metric is the sampling time
        :return: the same ``constants`` dictionary
        """
        if name != "metric_EDP EMON Sampling time (seconds)":
            return constants
        constants['samples'] = '$processed_samples'
        return constants

    @property
    @abstractmethod
    def name(self):
        """Value for the ``name`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def throughput_metric_name(self):
        """Value for the ``throughput_metric_name`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def description(self):
        """Value for the ``description`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def formula(self):
        """Value for the ``formula`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def event_aliases(self):
        """Value for the ``event_aliases`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def constants(self):
        """Value for the ``constants`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def retire_latencies(self):
        """Value for the ``retire_latencies`` field; implemented by subclasses."""

    @property
    def canonical_name(self):
        """Returns the class-level ``MetricDefinition.canonical_name`` attribute."""
        return MetricDefinition.canonical_name

    @property
    @abstractmethod
    def level(self):
        """Value for the ``level`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def unit_of_measure(self):
        """Value for the ``unit_of_measure`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def category(self):
        """Value for the ``category`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def threshold(self):
        """Value for the ``threshold`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def resolution_levels(self):
        """Value for the ``resolution_levels`` field; implemented by subclasses."""

    @property
    @abstractmethod
    def metric_group(self):
        """Value for the ``metric_group`` field; implemented by subclasses."""

    @property
    def human_readable_expression(self):
        """Returns the class-level ``MetricDefinition.human_readable_expression`` attribute."""
        return MetricDefinition.human_readable_expression


class _XmlMetricFields(_MetricFields):
    """
    Exposes the fields of a metric definition stored as an XML element.

    ``metric_def`` must offer the ElementTree-style element API used below (``get``, ``find``,
    ``findall`` and ``.text``). Child elements without text (``.text`` is ``None``) are tolerated
    instead of raising.
    """

    def __init__(self, metric_def):
        super().__init__(metric_def)
        # Events and retire latencies are parsed once here; the properties return these dicts
        self._event_alias_map, self._latencies = self._get_dicts_for_events()

    @property
    def name(self):
        """The ``name`` attribute of the metric element."""
        return self.metric_def.get('name')

    @property
    def throughput_metric_name(self):
        """Text of the first ``<throughput-metric-name>`` child, or ``''``."""
        return self._get_throughput_name(self.metric_def)

    @property
    def description(self):
        """Text of the first ``<description>`` child, or ``''``."""
        return self._get_description(self.metric_def)

    @property
    def formula(self):
        """Text of the ``<formula>`` child converted to Python syntax (``''`` is used when absent or empty)."""
        formula_tag = self.metric_def.find('formula')
        formula = formula_tag.text if formula_tag is not None else None
        # An element without text has ``text`` set to None; treat it like a missing element so
        # the formula conversion always receives a string
        return self._adjust_formula_if_sampling_time(self.name, formula or '')

    @property
    def event_aliases(self):
        """Alias -> text for ``<event>`` children that are not retire latencies, plus all ``<time>`` children."""
        return self._event_alias_map

    @property
    def constants(self):
        """Alias -> constant for every ``<constant>`` child, with legacy names replaced."""
        constants = self._get_constants(self.metric_def)
        return self._adjust_constants_if_sampling_time(self.name, constants)

    @property
    def retire_latencies(self):
        """Alias -> text for ``<event>`` children whose text contains the retire latency marker."""
        return self._latencies

    @property
    def level(self):
        """Not read from XML; returns the class-level ``MetricDefinition.level`` attribute."""
        return MetricDefinition.level

    @property
    def unit_of_measure(self):
        """Not read from XML; returns the class-level ``MetricDefinition.unit_of_measure`` attribute."""
        return MetricDefinition.unit_of_measure

    @property
    def category(self):
        """Text of the ``<category>`` child, or ``MetricDefinition.category`` when the child is absent."""
        category_tag = self.metric_def.find('category')
        if category_tag is not None:
            return category_tag.text
        return MetricDefinition.category

    @property
    def threshold(self):
        """``ThresholdDefinition`` built from the first ``<threshold>`` child, or ``MetricDefinition.threshold``."""
        return self._get_threshold()

    @property
    def resolution_levels(self):
        """Not read from XML; returns the class-level ``MetricDefinition.resolution_levels`` attribute."""
        return MetricDefinition.resolution_levels

    @property
    def metric_group(self):
        """Not read from XML; returns the class-level ``MetricDefinition.metric_group`` attribute."""
        return MetricDefinition.metric_group

    def _get_threshold(self):
        """
        Build the threshold definition from the first ``<threshold>`` child.

        :return: a ``ThresholdDefinition``, or ``MetricDefinition.threshold`` when there is no
                 ``<threshold>`` child or its ``<formula>`` is missing or empty
        """
        threshold_tags = self.metric_def.findall('threshold')
        if not threshold_tags:
            return MetricDefinition.threshold
        # for now only supporting one threshold, will increase to multiple later
        threshold_tag = threshold_tags[0]

        threshold_formula_tag = threshold_tag.find('formula')
        if threshold_formula_tag is None:
            logging.debug('Threshold is missing a formula')
            return MetricDefinition.threshold
        threshold_formula = threshold_formula_tag.text
        if not threshold_formula:
            return MetricDefinition.threshold

        threshold_metric_aliases = {tag.get('alias'): tag.text for tag in threshold_tag.findall('metric')}
        threshold_formula_raw_tag = threshold_tag.find('base_formula')
        threshold_formula_raw = threshold_formula_raw_tag.text if threshold_formula_raw_tag is not None else None

        return ThresholdDefinition(name=self.name,
                                   metric_aliases=threshold_metric_aliases,
                                   formula=threshold_formula,
                                   formula_raw=threshold_formula_raw)

    def _get_constants(self, metric_def) -> Dict[str, str]:
        """
        Map each ``<constant>`` alias to its text, replacing legacy constant expressions.

        The text is stripped only for the lookup in ``self.legacy_constant_map``; when there is
        no legacy entry the original (unstripped) text is stored. A constant without text is
        stored as ``None`` instead of raising.
        """
        constants = {}
        for const in metric_def.findall('constant'):
            text = const.text
            lookup_key = text.strip() if text is not None else None
            # Hack - replace the "system.cha_count/system.socket_count" expression with the
            # "CHAS_PER_SOCKET" constant
            constants[const.get('alias')] = self.legacy_constant_map.get(lookup_key, text)
        return constants

    @staticmethod
    def _get_throughput_name(metric_def) -> str:
        """Return the text of the first ``<throughput-metric-name>`` child, or ``''`` if absent or empty."""
        for tmn in metric_def.findall('throughput-metric-name'):
            # assume only one <throughput-metric-name> element
            return tmn.text if tmn.text is not None else ''
        return ''

    @staticmethod
    def _get_description(metric_def) -> str:
        """Return the text of the first ``<description>`` child, or ``''`` if absent or empty."""
        for d in metric_def.findall('description'):
            # assume only one <description> element
            return d.text if d.text is not None else ''
        return ''

    def _get_dicts_for_events(self):
        """
        Split the ``<event>`` children into regular events and retire latencies.

        An event whose text contains ``RETIRE_LATENCY_STRING`` is a retire latency; every other
        event (including one without text) is a regular event. ``<time>`` children are merged
        into the regular events.

        :return: tuple ``(events, latencies)`` of alias -> text dictionaries
        """
        # Imported locally rather than at module level (presumably to avoid a circular import
        # with mpp.parsers.metrics - TODO confirm)
        from mpp.parsers.metrics import RETIRE_LATENCY_STRING
        events = {}
        latencies = {}
        for md in self.metric_def.findall('event'):
            text = md.text
            if text is not None and RETIRE_LATENCY_STRING in text:
                latencies[md.get('alias')] = text
            else:
                events[md.get('alias')] = text
        times = self._get_dict_for_tag('time', self.metric_def)
        events.update(times)
        return events, latencies

    @staticmethod
    def _get_dict_for_tag(tag: str, metric_def):
        """Return an alias -> text dictionary for every ``tag`` child of ``metric_def``."""
        return {md.get('alias'): md.text for md in metric_def.findall(tag)}


class _JsonMetricFields(_MetricFields):
    """
    Exposes the fields of a metric definition stored as a JSON object (a ``dict``-like mapping).

    ``LegacyName``, ``Formula``, ``Events``, ``Constants`` and ``Level`` are read with ``[]`` and
    therefore must be present; the remaining keys are optional and default to ``''``.
    """

    @staticmethod
    def __get_optional_key(data, key, default=''):
        """Return ``data[key]`` when the key exists, otherwise ``default``."""
        return data.get(key, default)

    @staticmethod
    def _set_alias_map(definitions, alternatives=None):
        """
        Build an alias -> name dictionary from a list of ``{'Alias': ..., 'Name': ...}`` entries.

        :param definitions: iterable of mappings with ``Alias`` and ``Name`` keys
        :param alternatives: optional mapping of stripped names to the name that replaces them
        :return: dictionary keyed by alias; a later entry overrides an earlier one with the same alias
        """
        replacements = {} if alternatives is None else alternatives
        return {
            entry['Alias']: replacements.get(entry['Name'].strip(), entry['Name'])
            for entry in definitions
        }

    @property
    def name(self):
        """The ``LegacyName`` value."""
        return self.metric_def['LegacyName']

    @property
    def throughput_metric_name(self):
        """The optional ``ThroughputName`` value."""
        return self.__get_optional_key(self.metric_def, 'ThroughputName')

    @property
    def description(self):
        """The optional ``BriefDescription`` value."""
        return self.__get_optional_key(self.metric_def, 'BriefDescription')

    @property
    def formula(self):
        """The ``Formula`` value converted to Python syntax."""
        raw_formula = self.metric_def['Formula']
        return self._adjust_formula_if_sampling_time(self.name, raw_formula)

    @property
    def event_aliases(self):
        """Alias -> name dictionary built from the ``Events`` list."""
        events = self.metric_def['Events']
        return self._set_alias_map(events)

    @property
    def constants(self):
        """Alias -> name dictionary built from the ``Constants`` list, with legacy names replaced."""
        alias_map = self._set_alias_map(self.metric_def['Constants'], self.legacy_constant_map)
        return self._adjust_constants_if_sampling_time(self.name, alias_map)

    @property
    def retire_latencies(self):
        """Always an empty dictionary for JSON definitions."""
        return {}

    @property
    def level(self):
        """The ``Level`` value."""
        return self.metric_def['Level']

    @property
    def unit_of_measure(self):
        """The optional ``UnitOfMeasure`` value."""
        return self.__get_optional_key(self.metric_def, 'UnitOfMeasure')

    @property
    def category(self):
        """The optional ``Category`` value."""
        return self.__get_optional_key(self.metric_def, 'Category')

    @property
    def threshold(self):
        """Always ``None``; the ``Threshold`` key is currently not read."""
        # Disabled lookup: self.__get_optional_key(self.metric_def, 'Threshold', None)
        return None

    @property
    def resolution_levels(self):
        """The optional ``ResolutionLevels`` string split on ``', '``."""
        joined_levels = self.__get_optional_key(self.metric_def, 'ResolutionLevels')
        return joined_levels.split(', ')

    @property
    def metric_group(self):
        """The optional ``MetricGroup`` value."""
        return self.__get_optional_key(self.metric_def, 'MetricGroup')


class _ToPythonConverter:
    """
    Converts non-Python expressions into Python equivalents.

    The only construct handled is the C-style ternary operator (``cond ? a : b``), which is
    rewritten as a Python conditional expression (``a if cond else b``). Expressions that do not
    contain it are returned unchanged.
    """
    def __init__(self, expression):
        self.expression = expression
        self.regex_patterns = {'ternary': r'(.*)\?(.*):(.*)'}

    def convert(self):
        """
        Convert the expression and store the result back in ``self.expression``.

        :return: the converted expression, or the original one when nothing needed converting
        """
        expression_type = self.__determine_expression_type()
        if expression_type == 'ternary':
            self.expression = self.__convert_ternary_expression()
        return self.expression

    def __determine_expression_type(self):
        """Return the key of the first pattern in ``regex_patterns`` that matches, or ``None``."""
        for key, pattern in self.regex_patterns.items():
            if re.match(pattern, self.expression):
                return key
        return None

    @staticmethod
    def _find_subexpression_spans(formula):
        """
        Locate every parenthesised subexpression of ``formula``.

        :return: list of half-open ``(start, end)`` index pairs delimiting the text between each
                 pair of matching parentheses, ordered by closing parenthesis (so a nested
                 subexpression precedes the one enclosing it), followed by the span of the
                 whole formula
        """
        open_positions = []
        spans = []
        for index, char in enumerate(formula):
            if char == '(':
                open_positions.append(index + 1)
            elif char == ')' and open_positions:
                # A ')' without a matching '(' is ignored instead of raising IndexError
                spans.append((open_positions.pop(), index))
        spans.append((0, len(formula)))
        return spans

    def __convert_ternary_expression(self):
        """
        Rewrite every C-style ternary in ``self.expression`` as a Python conditional expression.

        Each pass converts the innermost subexpression that contains a ternary. The condition and
        both values are copied verbatim, including any surrounding whitespace.

        NOTE(review): chained ternaries on the same nesting level (``a ? b : c ? d : e``) are split
        by greedy matching; add parentheses in the formula to make the grouping explicit.

        :return: the converted formula
        """
        pattern = self.regex_patterns['ternary']
        out_formula = self.expression
        # Is there a C-style ternary operator?
        while re.match(pattern, out_formula):
            for start, end in self._find_subexpression_spans(out_formula):
                ternary_match = re.match(pattern, out_formula[start:end])
                if ternary_match:
                    cond, val1, val2 = ternary_match.groups()
                    converted = '{0} if {1} else {2}'.format(val1, cond, val2)
                    # Splice at the matched position only: str.replace would also rewrite other
                    # places where the same text occurs, e.g. as the tail of a longer operand
                    out_formula = out_formula[:start] + converted + out_formula[end:]
                    break
            else:
                # Unreachable in practice (the whole-formula span always matches), but guarantees
                # the loop cannot spin without making progress
                break

        return out_formula
