import io
from xml.etree.ElementTree import ElementTree, Element, fromstring, register_namespace
from typing import Optional, Any, Dict, List, cast, Union
from .model_parameter import ModelParameter
from ..parser_exception import ParserException
from .improve_target import ImproveTarget
[docs]
class ModelParser:
    """
        Performs various operations with XML representation of KW models
    """
    __ns = {'AA': 'KWKA_XML_Export', 'KWKAModel': 'KWKA_XML_Model', 'KWKA_XML_NLR': "KWKA_XML_NLR", "KWKANUMWELLPAR": "KWKA_NUM_WELL_PAR"}
[docs]
    def __init__(self, xml_string: str) -> None:
        """
            Parse ``xml_string`` and register the namespace prefixes used when the document is exported.

            Note: ``register_namespace`` updates ElementTree's process-wide prefix registry, so the
            registrations affect serialization of every tree in the process, not only this instance.
            Registering a prefix for a URI that already has one replaces the earlier mapping, so the final
            ``'ns'`` entry becomes the prefix written for ``KWKA_XML_Model``.

            Args:
                xml_string: the XML representation of the KW models.
        """
        # Order matters: later registrations for the same URI override earlier ones.
        prefix_uri_pairs = (
            ('AA', 'KWKA_XML_Export'),
            ('KWKAModel', 'KWKA_XML_Model'),
            ('AB', 'http://www.w3.org/2001/XMLSchema-instance'),
            ('KWKA_XML_NLR', "KWKA_XML_NLR"),
            ("KWKANUMWELLPAR", "KWKA_NUM_WELL_PAR"),
            ('ns', 'KWKA_XML_Model'),
        )
        for prefix, uri in prefix_uri_pairs:
            register_namespace(prefix, uri)
        parsed_root = fromstring(xml_string)
        self.__tree = ElementTree(parsed_root)
        self.__root = self.__tree.getroot()
    def __find_parameter_in_element(self, element: Element, namespace_prefix: str, parameter_name: str) -> Element:
        """
            Return the first direct child ``namespace_prefix:parameter_name`` of ``element``.

            Args:
                element: the element whose children are searched.
                namespace_prefix: a prefix from the class namespace map (e.g. ``'KWKAModel'``).
                parameter_name: local tag name of the wanted child.

            Raises:
                ParserException: if no such child exists.
        """
        child = element.find(f"{namespace_prefix}:{parameter_name}", self.__ns)
        if child is not None:
            return child
        raise ParserException(f"Cannot find {self.__ns[namespace_prefix]}:{parameter_name} in the xml document")
    def __get_model_element(self, analysis_id: str) -> Element:
        """
            Locate the model that belongs to ``analysis_id``.

            Side effect: sets ``self.__is_numerical`` to whether the matching ``AA:ModelPerProject`` has a
            ``KWKAModel:NumericalModel`` child; other lookups rely on this flag.

            Returns:
                The ``NumericalModel`` child when present, otherwise the ``ModelPerProject`` element itself.

            Raises:
                ParserException: if the ``AA:Models`` container is missing, if a ``ModelPerProject`` examined
                    before the match has no ``KWKAModel:AnalysisID``, or if no model has the given id.
        """
        models_container = self.__find_parameter_in_element(self.__root, 'AA', 'Models')
        for candidate in models_container.findall('AA:ModelPerProject', self.__ns):
            id_element = self.__find_parameter_in_element(candidate, 'KWKAModel', 'AnalysisID')
            if id_element.text != analysis_id:
                continue
            numerical_model = candidate.find('KWKAModel:NumericalModel', self.__ns)
            self.__is_numerical = numerical_model is not None
            return candidate if numerical_model is None else numerical_model
        raise ParserException('Document does not contain model for analysis <{}>'.format(analysis_id))
    def __get_model_definition_element(self, analysis_id: str, definition_category_name: str, parameter_name: str,
                                       sub_parameter_name: Optional[str] = None) -> Optional[Element]:
        """
            Find an element inside a model definition category such as ``ModelDescription`` or
            ``ImproveParameters``.

            Looks for ``KWKAModel:<definition_category_name>/KWKAModel:<parameter_name>``. When
            ``sub_parameter_name`` is given, returns that child of the parameter instead. The child is looked
            up first in the ``KWKAModel`` namespace and then in the ``KWKA_XML_NLR`` namespace.

            Args:
                analysis_id: id of the model to search.
                definition_category_name: local tag of the definition category.
                parameter_name: local tag of the parameter inside the category.
                sub_parameter_name: optional local tag of a child of the parameter.

            Returns:
                The first matching element, or ``None`` when the category exists but holds no matching
                parameter/sub-parameter.

            Raises:
                ParserException: if the model for ``analysis_id`` or the definition category is missing.
        """
        model_element = self.__get_model_element(analysis_id)
        category_elements = model_element.findall('KWKAModel:{}'.format(definition_category_name), self.__ns)
        if not category_elements:
            raise ParserException("Cannot find the model definition in the XML representation of the KW model")
        for category_element in category_elements:
            for parameter_element in category_element.findall('KWKAModel:{}'.format(parameter_name), self.__ns):
                if sub_parameter_name is None:
                    return parameter_element
                # Sub-parameters are looked up in the model namespace first, then in the NLR namespace.
                for namespace_prefix in ('KWKAModel', 'KWKA_XML_NLR'):
                    sub_parameter_element = parameter_element.find(
                        '{}:{}'.format(namespace_prefix, sub_parameter_name), self.__ns)
                    if sub_parameter_element is not None:
                        return sub_parameter_element
        # Report "not found" as None (per the signature) so callers can raise their own ParserException
        # rather than an IndexError escaping from an empty findall() result.
        return None
    def __get_parameter_element(self, model_element: Element, parameter_name: str,
                                conditions: Optional[Dict[Any, Any]]) -> Optional[Element]:
        """
            Find the parameter element of type ``parameter_name`` inside ``model_element``.

            The candidates depend on ``self.__is_numerical``, as set by the last ``__get_model_element``
            call. Numerical models use the ``KWKAModel:NumModelParameters`` children; other models use the
            ``KWKAModel:ModelParameter`` children. A candidate has the requested type when its ``xsi:type``
            equals ``parameter_name``, either verbatim or after dropping a ``prefix:`` qualifier.

            Args:
                model_element: a model element returned by ``__get_model_element``.
                parameter_name: the parameter type to look for.
                conditions: optional mapping ``child tag -> expected text``.
                    A candidate is accepted when at least one listed child is present with the expected
                    text and no present child has a different text.
                    Listed children that are absent are ignored.
                    ``None`` or an empty mapping selects the first candidate of the requested type.

            Returns:
                The parameter element itself (callers look up its ``Value`` / ``Improve`` children), or
                ``None`` if no candidate matches.
        """
        xsi_type_attribute = '{http://www.w3.org/2001/XMLSchema-instance}type'
        candidate_path = 'KWKAModel:NumModelParameters' if self.__is_numerical else 'KWKAModel:ModelParameter'
        matching_parameter_elements = []
        for parameter_element in model_element.findall(candidate_path, self.__ns):
            # Elements without xsi:type cannot be identified by type; skip them instead of raising KeyError.
            type_attr = parameter_element.get(xsi_type_attribute, '')
            if type_attr == parameter_name or (':' in type_attr and type_attr.split(':')[1] == parameter_name):
                matching_parameter_elements.append(parameter_element)
        if not matching_parameter_elements:
            return None
        if not conditions:
            # Return the parameter element (not its Value child) so both branches hand back the same
            # kind of element to callers.
            return matching_parameter_elements[0]
        for parameter_element in matching_parameter_elements:
            matching_conditions = False
            for attribute_name, attribute_value in conditions.items():
                sub_element = parameter_element.find('KWKAModel:{}'.format(attribute_name), self.__ns)
                if sub_element is None:
                    # Absent condition children are ignored rather than treated as a mismatch.
                    continue
                if sub_element.text != attribute_value:
                    matching_conditions = False
                    break
                matching_conditions = True
            if matching_conditions:
                return parameter_element
        return None
    def __get_parameter_value_element(self, model_element: Element, parameter_name: str,
                                      conditions: Optional[Dict[Any, Any]]) -> Element:
        """
            Return the ``KWKAModel:Value`` child of the parameter selected by ``__get_parameter_element``.

            Raises:
                ParserException: if no parameter matches, or the matching element has no ``Value`` child.
        """
        selected = self.__get_parameter_element(model_element, parameter_name, conditions)
        if selected is None:
            raise ParserException('Model does not contain parameter <{}> that matches given conditions'.format(parameter_name))
        return self.__find_parameter_in_element(selected, 'KWKAModel', 'Value')
[docs]
    def get_parameter_value(self, analysis_id: str, xml_parameter_locator: ModelParameter) -> str:
        """
            Return the ``Value`` text of the parameter described by ``xml_parameter_locator``.

            The locator supplies ``parameter_type`` and ``conditions``. The text goes through ``str()``, so
            a ``Value`` element without text yields ``'None'``.

            Raises:
                ParserException: if the model, the parameter or its ``Value`` child cannot be found.
        """
        value_element = self.__get_parameter_value_element(
            self.__get_model_element(analysis_id),
            xml_parameter_locator.parameter_type,
            xml_parameter_locator.conditions,
        )
        return str(value_element.text)
[docs]
    def set_parameter_value(self, analysis_id: str, xml_parameter_locator: ModelParameter, value: str) -> None:
        """
            Overwrite the ``Value`` text of the parameter described by ``xml_parameter_locator``.

            ``value`` is stored as ``str(value)``.

            Raises:
                ParserException: if the model, the parameter or its ``Value`` child cannot be found.
        """
        target_model = self.__get_model_element(analysis_id)
        value_element = self.__get_parameter_value_element(
            target_model,
            xml_parameter_locator.parameter_type,
            xml_parameter_locator.conditions,
        )
        value_element.text = str(value)
[docs]
    def add_model_description_element(self, analysis_id: str, parameter_name: str, sub_parameter_name: str,
                                      parameter_value: str, position: int = 0) -> None:
        """
            Insert a new child into the ``ModelDescription/<parameter_name>`` element of the model.

            The new child has tag ``sub_parameter_name`` (used exactly as given, with no namespace prefix
            added) and text ``parameter_value``. It is inserted at index ``int(position)``.

            Raises:
                ParserException: if the model or its ``ModelDescription`` category cannot be found, or if the
                    lookup yields no element.
        """
        description_parameter = self.__get_model_definition_element(analysis_id, 'ModelDescription', parameter_name)
        if description_parameter is None:
            raise ParserException(
                'Model does not contain "ModelDescription" with parameter <{}> with sub-parameter <{}>'.format(parameter_name, sub_parameter_name))
        new_child = Element(sub_parameter_name)
        new_child.text = parameter_value
        description_parameter.insert(int(position), new_child)
[docs]
    def add_improve_parameter(self, analysis_id: str, parameter_name: str, parameter_value: str,
                              position: int = 0) -> None:
        """
            Insert a new child into the model's ``KWKAModel:ImproveParameters`` element.

            The new child has tag ``parameter_name`` (used exactly as given, with no namespace prefix added)
            and text ``parameter_value``. It is inserted at index ``int(position)``.

            Raises:
                ParserException: if the model or its ``ImproveParameters`` element cannot be found.
        """
        model_element = self.__get_model_element(analysis_id)
        # Validate the container itself: a freshly constructed Element is never None, and indexing an
        # empty findall() result would raise IndexError instead of the intended ParserException.
        improve_parameters_element = model_element.find('KWKAModel:ImproveParameters', self.__ns)
        if improve_parameters_element is None:
            raise ParserException('Model does not contain "KWKAModel:ImproveParameters" with parameter <{}>'.format(parameter_name))
        element = Element(parameter_name)
        element.text = parameter_value
        improve_parameters_element.insert(int(position), element)
[docs]
    def get_model_description_value(self, analysis_id: str, parameter_name: str,
                                    sub_parameter_name: Optional[str] = None) -> str:
        """
            Return the text of ``ModelDescription/<parameter_name>``, or of its ``sub_parameter_name`` child.

            The text goes through ``str()``, so an element without text yields ``'None'``.

            Raises:
                ParserException: if the model or its ``ModelDescription`` category cannot be found, or if the
                    lookup yields no element.
        """
        found = self.__get_model_definition_element(analysis_id, 'ModelDescription', parameter_name, sub_parameter_name)
        if found is not None:
            return str(found.text)
        raise ParserException('Model does not contain "ModelDescription" with parameter <{}>'.format(parameter_name))
[docs]
    def set_model_description_value(self, analysis_id: str, parameter_name: str, sub_parameter_name: Optional[str],
                                    value: str) -> None:
        """
            Overwrite the text of ``ModelDescription/<parameter_name>``, or of its ``sub_parameter_name``
            child, with ``str(value)``.

            Raises:
                ParserException: if the model or its ``ModelDescription`` category cannot be found, or if the
                    lookup yields no element.
        """
        found = self.__get_model_definition_element(analysis_id, 'ModelDescription', parameter_name, sub_parameter_name)
        if found is not None:
            found.text = str(value)
            return
        raise ParserException(
            'Model does not contain "ModelDescription" with parameter <{}> with sub-parameter <{}>'.format(parameter_name, sub_parameter_name))
[docs]
    def get_improve_parameter_value(self, analysis_id: str, parameter_name: str,
                                    sub_parameter_name: Optional[str] = None) -> Optional[str]:
        """
            Return the text of ``ImproveParameters/<parameter_name>``, or of its ``sub_parameter_name`` child.

            Always returns a string, despite the ``Optional`` annotation. The text goes through ``str()``,
            so an element without text yields ``'None'``.

            Raises:
                ParserException: if the model or its ``ImproveParameters`` category cannot be found, or if the
                    lookup yields no element.
        """
        found = self.__get_model_definition_element(analysis_id, 'ImproveParameters', parameter_name, sub_parameter_name)
        if found is not None:
            return str(found.text)
        raise ParserException(
            'Model does not contain "ImproveParameters" with parameter <{}> with sub-parameter <{}>'.format(parameter_name, sub_parameter_name))
[docs]
    def set_improve_parameter_value(self, analysis_id: str, parameter_name: str, sub_parameter_name: str,
                                    value: Union[bool, str]) -> None:
        """
            Overwrite the text of ``ImproveParameters/<parameter_name>/<sub_parameter_name>``.

            The stored text is ``str(value).lower()``, so ``True`` / ``False`` become ``'true'`` / ``'false'``.

            Raises:
                ParserException: if the model or its ``ImproveParameters`` category cannot be found, or if the
                    lookup yields no element.
        """
        found = self.__get_model_definition_element(analysis_id, 'ImproveParameters', parameter_name, sub_parameter_name)
        if found is not None:
            found.text = str(value).lower()
            return
        raise ParserException(
            'Model does not contain "ImproveParameters" with parameter <{}> with sub-parameter <{}>'.format(parameter_name, sub_parameter_name))
[docs]
    def set_model_parameter_improve_value(self, analysis_id: str, parameter: ModelParameter,
                                          improve_parameter_name: str, value: str) -> None:
        """
            Set ``Improve/<improve_parameter_name>`` of the parameter described by ``parameter``.

            The stored text is ``str(value).lower()``.

            Raises:
                ParserException: if the model or the parameter cannot be found, or if the parameter element
                    has no ``Improve/<improve_parameter_name>`` child.
        """
        parameter_element = self.__get_parameter_element(
            self.__get_model_element(analysis_id), parameter.parameter_type, parameter.conditions)
        if parameter_element is None:
            raise ParserException(f"Cannot find the parameter {parameter.parameter_type} with conditions {parameter.conditions} in the model")
        improve_section = self.__find_parameter_in_element(parameter_element, 'KWKAModel', 'Improve')
        setting_element = self.__find_parameter_in_element(improve_section, 'KWKAModel', improve_parameter_name)
        setting_element.text = str(value).lower()
[docs]
    def get_improve_parameter_targets(self, analysis_id: str) -> List[ImproveTarget]:
        """
            Read every ``KWKAModel:ImproveTargets`` entry of the model's ``ImproveParameters``.

            ``IsSelected`` is parsed as an XML Schema boolean: ``'true'`` or ``'1'`` (case- and
            whitespace-insensitive) gives True, anything else gives False. ``GlobalWeight`` is converted to
            ``float``.

            Returns:
                One ``ImproveTarget(target_type, global_weight, is_selected)`` per entry, in document order.

            Raises:
                ParserException: if the model, ``ImproveParameters`` or a required child is missing, or if a
                    ``GlobalWeight`` is not a number.
        """
        improve_target_list = []
        for parameter_element in self.__get_improve_parameter_target_elements(analysis_id):
            target_type_value = cast(str, self.__find_parameter_in_element(parameter_element, 'KWKAModel', 'Type').text)
            is_selected_text = self.__find_parameter_in_element(parameter_element, 'KWKAModel', 'IsSelected').text
            # bool() of any non-empty string is True (including "false"), so compare the text explicitly.
            is_selected_value = (is_selected_text or '').strip().lower() in ('true', '1')
            global_weight_text = self.__find_parameter_in_element(parameter_element, 'KWKAModel', 'GlobalWeight').text
            try:
                # typing.cast performs no runtime conversion; parse the text into a real float.
                global_weight_value = float(cast(str, global_weight_text))
            except (TypeError, ValueError) as error:
                raise ParserException(
                    f"GlobalWeight <{global_weight_text}> of ImproveTargets Type {target_type_value} is not a number") from error
            improve_target_list.append(ImproveTarget(target_type_value, global_weight_value, is_selected_value))
        return improve_target_list
[docs]
    def set_improve_parameter_targets(self, analysis_id: str, improve_parameter_targets: List[ImproveTarget]) -> None:
        """
            Update the ``ImproveTargets`` entries that match the given targets.

            For each target, the first entry whose ``Type`` text equals ``target.target_type`` gets
            ``IsSelected = str(target.is_selected).lower()`` and ``GlobalWeight = str(target.global_weight)``.
            Targets are applied one at a time, so entries updated before a failure stay modified.

            Raises:
                ValueError: if no entry has a target's type.
                ParserException: if the model, ``ImproveParameters`` or a required child is missing.
        """
        target_elements = self.__get_improve_parameter_target_elements(analysis_id)
        for target in improve_parameter_targets:
            # Lazy scan: Type children are only inspected up to the first match.
            matching_element = next(
                (candidate for candidate in target_elements
                 if target.target_type == self.__find_parameter_in_element(candidate, 'KWKAModel', 'Type').text),
                None,
            )
            if matching_element is None:
                raise ValueError(f"ImproveTargets Type {target.target_type} not found in the xml document")
            self.__find_parameter_in_element(matching_element, 'KWKAModel', 'IsSelected').text = str(target.is_selected).lower()
            self.__find_parameter_in_element(matching_element, 'KWKAModel', 'GlobalWeight').text = str(target.global_weight)
[docs]
    def set_improve_dt_start(self, analysis_id: str, dt_start: float) -> None:
        """
            Store ``str(dt_start).lower()`` in ``ImproveParameters/ElapsedDtStart`` of the model.

            Raises:
                ParserException: if the model, ``ImproveParameters`` or ``ElapsedDtStart`` cannot be found.
        """
        improve_parameters_element = self.__find_parameter_in_element(
            self.__get_model_element(analysis_id), 'KWKAModel', 'ImproveParameters')
        dt_start_element = self.__find_parameter_in_element(improve_parameters_element, "KWKAModel", "ElapsedDtStart")
        dt_start_element.text = str(dt_start).lower()
    def __get_improve_parameter_target_elements(self, analysis_id: str) -> List[Element]:
        """
            Return all ``KWKAModel:ImproveTargets`` children of the model's ``ImproveParameters``.

            The result may be empty.

            Raises:
                ParserException: if the model or its ``ImproveParameters`` element cannot be found.
        """
        improve_parameters_element = self.__find_parameter_in_element(
            self.__get_model_element(analysis_id), 'KWKAModel', 'ImproveParameters')
        return improve_parameters_element.findall('KWKAModel:ImproveTargets', self.__ns)
[docs]
    def export(self) -> str:
        """
            Serialize the whole document to a string, including an XML declaration.

            Namespace prefixes in the output follow ElementTree's process-wide prefix registry.
        """
        with io.StringIO() as buffer:
            self.__tree.write(buffer, encoding='unicode', method='xml', xml_declaration=True)
            return buffer.getvalue()
[docs]
    def include_or_exclude_parameter_for_improve(self, analysis_id: str, parameter_name: str, is_included: bool,
                                                 conditions: Optional[Dict[Any, Any]]) -> None:
        """
            Set ``Improve/Include`` of the matching parameter to ``'true'`` or ``'false'``.

            Unlike the other setters, a parameter that cannot be found is silently ignored.

            Raises:
                ParserException: if the model cannot be found, or if the matching parameter element has no
                    ``Improve/Include`` child.
        """
        model_element = self.__get_model_element(analysis_id)
        parameter_element = self.__get_parameter_element(model_element, parameter_name, conditions)
        if parameter_element is None:
            return
        improve_section = self.__find_parameter_in_element(parameter_element, 'KWKAModel', 'Improve')
        include_flag = self.__find_parameter_in_element(improve_section, 'KWKAModel', 'Include')
        if is_included:
            include_flag.text = "true"
        else:
            include_flag.text = "false"
[docs]
    def remove_analysis(self, analysis_id: str) -> None:
        """
            Remove the first ``AA:ModelPerProject`` whose ``AnalysisID`` text equals ``analysis_id``.

            Does nothing if no model has that id.

            Raises:
                ParserException: if ``AA:Models`` is missing, or a ``ModelPerProject`` examined before the
                    match has no ``KWKAModel:AnalysisID``.
        """
        models_container = self.__find_parameter_in_element(self.__root, 'AA', 'Models')
        model_to_remove = next(
            (candidate for candidate in models_container.findall('AA:ModelPerProject', self.__ns)
             if self.__find_parameter_in_element(candidate, 'KWKAModel', 'AnalysisID').text == analysis_id),
            None,
        )
        if model_to_remove is not None:
            models_container.remove(model_to_remove)
[docs]
    def remove_model_description(self, analyses_id: List[str]) -> None:
        """
            Remove the description element from each listed model.

            Numerical models lose their ``NumModelDescription`` element; other models lose their
            ``ModelDescription`` element. Models are processed in order, so models handled before a failure
            stay modified.

            Raises:
                ParserException: if a model or its description element cannot be found.
        """
        for analysis_id in analyses_id:
            model_element = self.__get_model_element(analysis_id)
            # __get_model_element has just refreshed __is_numerical for this model.
            description_tag = "NumModelDescription" if self.__is_numerical else 'ModelDescription'
            model_element.remove(self.__find_parameter_in_element(model_element, 'KWKAModel', description_tag))