from typing import List, Optional, Union, Dict, Any
from xml.etree.ElementTree import ElementTree, Element, fromstring
from .datetime_utils import str_to_datetime
from datetime import datetime
# Namespace prefix map for the v1 results schema: per-analysis ``AC:Results``
# elements sit directly under the document root.
ns_v1 = dict(AC='KW_XML_Export')
# Namespace prefix map for the v2 results schema: ``AA`` names the outer
# ``Results`` container, ``KWXmlResults`` the per-analysis elements inside it.
ns_v2 = dict(AA='KWKA_XML_Export', KWXmlResults='KW_XML_Results')
class NumericValue:
    """A numeric result value paired with its measure (unit) string.

    Both fields are exposed through read-only properties.
    """

    def __init__(self, measure: str, value: float):
        """
        Parameters
        ----------
        measure:
            The measure (unit) associated with the value.
        value:
            The numeric value.
        """
        self.__measure = measure
        self.__value = value

    def __repr__(self) -> str:
        return '{}(measure={!r}, value={!r})'.format(type(self).__name__, self.measure, self.value)

    @property
    def measure(self) -> str:
        """ Gets a measure as a string.
        """
        return self.__measure

    @property
    def value(self) -> float:
        """ Gets a value as a float.
        """
        return self.__value
class ResultsXMLParsingException(Exception):
    """Raised when the results XML lacks a requested analysis, category, parameter or period value."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Also kept as a plain attribute so callers can read ``exc.message`` directly.
        self.message: str = message
 
class AnalysisResults:
    """
    Wrapper for the results.xml from KW document.
    Presents analysis results XML in convenient way, supporting both v1 and v2 schemas.

    * v1: ``AC:Results`` elements (one per analysis) directly under the document root.
    * v2: an ``AA:Results`` container holding ``KWXmlResults:Results`` elements (one per analysis).

    .. note:: :py:obj:`Document.analysis_results` property is populated on-demand and is cached for the duration of the :class:`Connection`.
        If you need to get actual values, use the :py:meth:`Document.get_results_xml` method and instantiate this class with it.
    """

    def __init__(self, xml_string: str):
        """Parse the results XML and keep its root element.

        Parameters
        ----------
        xml_string:
            The content of the results XML document.

        Raises
        ------
        xml.etree.ElementTree.ParseError
            If ``xml_string`` is not well-formed XML.
        """
        self.__tree: ElementTree = ElementTree(fromstring(xml_string))
        self.__root: Element = self.__tree.getroot()

    def __v2_results_elements(self) -> List[Element]:
        """Return the per-analysis v2 ``Results`` elements, or an empty list if there is no v2 container."""
        container = self.__root.find('AA:Results', ns_v2)
        if container is None:
            return []
        return container.findall('KWXmlResults:Results', ns_v2)

    def __get_results_element(self, analysis_name: str) -> Element:
        """Return the ``Results`` element whose ``name`` attribute equals ``analysis_name``.

        v1 elements are searched first, then v2 elements.

        Raises
        ------
        ResultsXMLParsingException
            If no results element carries that name.
        """
        candidates = self.__root.findall('AC:Results', ns_v1) + self.__v2_results_elements()
        for elem in candidates:
            if elem.get('name') == analysis_name:
                return elem
        raise ResultsXMLParsingException('Saphir document does not contain results for <{}> analysis'.format(analysis_name))

    def __find_element(self, element: Element, name: str) -> Optional[Element]:
        """Return the first direct child ``name`` of ``element``: v1 namespace first, then v2 results namespace."""
        sub_element = element.find('AC:{}'.format(name), ns_v1)
        if sub_element is not None:
            return sub_element
        return element.find('KWXmlResults:{}'.format(name), ns_v2)

    def __find_elements(self, element: Element, name: str) -> List[Element]:
        """Return all direct children ``name`` of ``element`` in the v1 namespace, or in the v2 results namespace if there are none."""
        sub_elements = element.findall('AC:{}'.format(name), ns_v1)
        if sub_elements:
            return sub_elements
        return element.findall('KWXmlResults:{}'.format(name), ns_v2)

    def get_analysis_names(self) -> List[str]:
        """ Returns a list of analysis names.

        If any v1 results are present, only their names are returned; otherwise
        the names of the v2 results are returned (possibly an empty list).
        """
        v1_elements = self.__root.findall('AC:Results', ns_v1)
        elements = v1_elements if v1_elements else self.__v2_results_elements()
        return [str(elem.get('name')) for elem in elements]

    def get_productivity_index_table_elements(self, analysis_name: Optional[str] = None) -> Union[Dict[str, Dict[str, Any]], Dict[str, Any]]:
        """Returns the productivity index table rows (v2 schema only), keyed by period name.

        Parameters
        ----------
        analysis_name:
            Optional period name to select. Despite the parameter name, it is
            compared against each row's ``PeriodName``. When given, only that
            row's field dict is returned.

        Returns
        -------
        dict
            Without ``analysis_name``: ``{period_name: {field_tag: text, ...}}``.
            With ``analysis_name``: the ``{field_tag: text, ...}`` dict of that row.
            Field tags have their XML namespace stripped; the ``PeriodName`` field
            itself is omitted.

        Raises
        ------
        ValueError
            If the document has no v2 ``Results`` section or no productivity index
            table, or if a table row has no ``PeriodName``.
        KeyError
            If ``analysis_name`` is given but no row has that period name.
        """
        results_element = self.__root.find('AA:Results', ns_v2)
        if results_element is None:
            raise ValueError("There is no Results section in this document")
        productivity_index_results = results_element.find('KWXmlResults:ProductivityIndexResults', ns_v2)
        if productivity_index_results is None:
            raise ValueError("There is no productivity index table in this document")

        table_elements_dict: Dict[str, Dict[str, Any]] = dict()
        for element in productivity_index_results.findall('KWXmlResults:PITableElement', ns_v2):
            period_name_element = element.find('KWXmlResults:PeriodName', ns_v2)
            if period_name_element is None:
                raise ValueError("A productivity index table element has no PeriodName")
            period_name = str(period_name_element.text)
            if analysis_name is not None and period_name != analysis_name:
                continue
            fields: Dict[str, Any] = {}
            for child in element:
                # Skip the PeriodName field by tag; comparing by text would also
                # drop any other field whose value happens to equal the period name.
                if child.tag == period_name_element.tag:
                    continue
                # Strip the "{namespace-uri}" prefix ElementTree puts on tags;
                # rsplit also tolerates tags without a namespace.
                fields[child.tag.rsplit('}', 1)[-1]] = child.text
            table_elements_dict[period_name] = fields

        if analysis_name is None:
            return table_elements_dict
        return table_elements_dict[analysis_name]

    def get_parameter_value(self, analysis_name: str, category_name: str, parameter_name: str) -> Union[NumericValue, str, datetime]:
        """ Returns a value of a given parameter.

        Parameters
        ----------
        analysis_name:
            The name of analysis.
        category_name:
            The name of parameter category.
        parameter_name:
            The name of the parameter.

        Returns
        -------
        NumericValue
            For a ``doublevalue`` that has both a ``unit`` attribute and text.
        str
            For a ``stringvalue`` (empty string when the element has no text).
        datetime
            For a ``datetimevalue``.

        Raises
        ------
        ResultsXMLParsingException
            If the analysis, the category, or a parameter with a usable value is missing.
        ValueError
            If a ``datetimevalue`` cannot be parsed.
        """
        analysis_element = self.__get_results_element(analysis_name)
        category_element = next(
            (elem for elem in self.__find_elements(analysis_element, 'Category') if elem.get('name') == category_name),
            None,
        )
        if category_element is None:
            raise ResultsXMLParsingException('Saphir document results does not contain <{}> category'.format(category_name))

        for parameter_element in category_element:
            if parameter_element.get('name') != parameter_name:
                continue
            # Value kinds are tried in order; a doublevalue without unit or text
            # falls through to the next kind.
            value_element = self.__find_element(parameter_element, 'doublevalue')
            if value_element is not None:
                unit = value_element.get("unit")
                if unit is not None and value_element.text is not None:
                    return NumericValue(unit, float(value_element.text))
            value_element = self.__find_element(parameter_element, 'stringvalue')
            if value_element is not None:
                # An empty <stringvalue/> has text None; return "" rather than "None".
                return value_element.text or ''
            value_element = self.__find_element(parameter_element, 'datetimevalue')
            if value_element is not None:
                date_value = str_to_datetime(value_element.text)
                if date_value is None:
                    raise ValueError(f'Error while parsing {value_element.text!r} as datetime')
                return date_value
        raise ResultsXMLParsingException('Saphir document results does not contain <{}> parameter in <{}> category'.format(parameter_name, category_name))

    def get_period_names(self, analysis_name: str) -> List[str]:
        """ Returns a list of extracted flow period names for a given analysis.

        Parameters
        ----------
        analysis_name:
            The name of analysis.

        Returns
        -------
        list of str
            The ``ProductionEventName`` text of each extracted period that has one;
            empty if the analysis has no ``Extracted_Flow_Periods``.

        Raises
        ------
        ResultsXMLParsingException
            If the analysis does not exist.
        """
        analysis_element = self.__get_results_element(analysis_name)
        periods_element = self.__find_element(analysis_element, 'Extracted_Flow_Periods')
        if periods_element is None:
            return []
        names: List[str] = []
        for period_element in self.__find_elements(periods_element, 'Extracted_Period'):
            event_name = self.__find_element(period_element, 'ProductionEventName')
            if event_name is not None and event_name.text is not None:
                names.append(event_name.text)
        return names

    def get_period_property_value(self, analysis_name: str, period_name: str, value_name: str) -> str:
        """ Returns a value for a property of a given period.

        Parameters
        ----------
        analysis_name:
            The name of analysis.
        period_name:
            The name of the extracted flow period.
        value_name:
            The name of the period property.

        Returns
        -------
        str
            The text of the property element in the first period named ``period_name``.

        Raises
        ------
        ResultsXMLParsingException
            If the analysis or period is missing, or the first matching period has
            no such property (or it is empty).
        """
        missing_message = 'Saphir document results does not contain <{}> value in <{}> flow period'.format(value_name, period_name)
        analysis_element = self.__get_results_element(analysis_name)
        periods_element = self.__find_element(analysis_element, 'Extracted_Flow_Periods')
        if periods_element is not None:
            for period_element in self.__find_elements(periods_element, 'Extracted_Period'):
                event_name = self.__find_element(period_element, 'ProductionEventName')
                if event_name is None or event_name.text != period_name:
                    continue
                # Only the first period with a matching name is consulted.
                value_element = self.__find_element(period_element, value_name)
                if value_element is None or value_element.text is None:
                    raise ResultsXMLParsingException(missing_message)
                return value_element.text
        raise ResultsXMLParsingException(missing_message)