import copy
from datetime import datetime, timezone
import uuid
from typing import List, Dict, Optional
from ..rest_api import RestAPI
from .._private._data_api import DataAPI
from ..metadata import Metadata
from .data_enumerator_settings import DataEnumeratorSettings
from .interpolation_method_enum import InterpolationMethodEnum
from .extrapolation_method_enum import ExtrapolationMethodEnum
from .time_range_enum import TimeRangeEnum
from .data_point import DataPoint
from .data_enumeration import DataEnumeration
from ..vector import Vector
[docs]
class DataEnumerator:
    """Python wrapper for a Dataset Enumerator.
    .. note:: Should not be instantiated directly.
    """
[docs]
    def __init__(self, settings: DataEnumeratorSettings, rest_api: Optional[RestAPI] = None):
        self.__settings: DataEnumeratorSettings = settings
        self.__data_api: Optional[DataAPI] = DataAPI(rest_api) if rest_api is not None else None 
[docs]
    def get_default_settings(self) -> DataEnumeratorSettings:
        return copy.deepcopy(self.__settings) 
    def __get_enumeration(self, x_inputs: Dict[str, List[datetime]], y_inputs: Dict[str, List[float]], vector_metadata: Dict[str, Metadata],
                          last_points: List[DataPoint]) -> DataEnumeration:
        """Resample every input on a common time index and wrap the result.

        Parameters
        ----------
        x_inputs:
            Input times, keyed by vector id.
        y_inputs:
            Input values, keyed by vector id (same keys and order as ``x_inputs``).
        vector_metadata:
            Metadata of each input, keyed by vector id.
        last_points:
            Last point of each non-empty input, passed through unchanged.

        Returns
        -------
            A ``DataEnumeration`` holding one ``Vector`` per input, all sharing
            the same list of output times.
        """
        common_times = self.get_time_range(x_inputs, vector_metadata)
        resampled_values = self.get_outputs_values(x_inputs, y_inputs, common_times, vector_metadata)
        vectors = []
        # get_outputs_values returns one list per input, in the iteration order of the dicts.
        for vector_id, values in zip(y_inputs, resampled_values):
            vectors.append(Vector(common_times, values, vector_id=vector_id))
        return DataEnumeration(vectors, last_points)
[docs]
    def to_enumerable_from_vectors(self, params: List[Vector], settings: Optional[DataEnumeratorSettings] = None,
                                   metadata_params: Optional[List[Metadata]] = None) -> DataEnumeration:
        """Gets a data enumerator for multiple datasets from :class:`Vector`.

        Generates an enumerator that allows to iterate through the list of
        points of the given vectors using a common time index.

        Parameters
        ----------
        params:
            The list of :class:`Vector`.
        settings:
            The data enumeration settings. When given, they replace the
            settings stored on this enumerator (also for subsequent calls).
        metadata_params:
            Optional list of metadata, matched to ``params`` by position.
            Vectors without a matching entry get default metadata built from
            the vector itself.

        Returns
        -------
            The enumeration of all vectors on a common time index. For an
            empty ``params`` it contains a single empty vector.

        Raises
        ------
        ValueError
            If an element of ``params`` is not a :class:`Vector`, or if two
            vectors share the same id.
        """
        if settings is not None:
            self.__settings = settings
        if len(params) < 1:
            return DataEnumeration([Vector([], [])], [])
        # Validate the element types first: the duplicate check below reads ``.id``
        # and would otherwise fail with an AttributeError on foreign objects.
        if not all(isinstance(element, Vector) for element in params):
            raise ValueError("Parameters input are not Vector objects")
        if len({p.id for p in params}) < len(params):
            raise ValueError("Duplicate vectors specified in input parameters")
        x_inputs: Dict[str, List[datetime]] = dict()
        y_inputs: Dict[str, List[float]] = dict()
        vector_metadata: Dict[str, Metadata] = dict()
        last_points: List[DataPoint] = []
        supplied_metadata = metadata_params if metadata_params is not None else []
        for i, vector in enumerate(params):
            index = str(vector.id)
            if i < len(supplied_metadata):
                vector_metadata[index] = supplied_metadata[i]
            else:
                # Default metadata is only built when the caller gave none for this vector.
                # An empty vector has no boundary values to read; NaN is used as the
                # "absent value" marker, as elsewhere in this class.
                # NOTE(review): confirm Metadata accepts NaN for these two arguments.
                has_values = len(vector.values) > 0
                head_value = vector.values[0] if has_values else float("NaN")
                tail_value = vector.values[-1] if has_values else float("NaN")
                vector_metadata[index] = Metadata("", False, vector.first_x, head_value, tail_value, False, "",
                                                  len(vector.dates))
            x_inputs[index] = vector.dates
            y_inputs[index] = vector.values
            if len(vector.dates) > 0:
                last_points.append(DataPoint(vector.dates[-1], vector.values[-1]))
        return self.__get_enumeration(x_inputs, y_inputs, vector_metadata, last_points)
[docs]
    def to_enumerable(self, params: List[uuid.UUID], settings: Optional[DataEnumeratorSettings] = None) -> DataEnumeration:
        """Gets a data enumerator for multiple datasets.
        Generates an enumerator that allows to iterate through the list of
        points of given KAPPA Automate datasets using a common time index.
        Parameters
        ----------
        params:
            The list of KAPPA Automate vector identifiers.
        settings:
            The data enumeration service.
        """
        if settings is not None:
            self.__settings = settings
        if len(params) < 1:
            return DataEnumeration([Vector([], [])], [])
        if self.__data_api is None:
            raise ValueError("You must provide a rest API object when you initialize the data enumerator")
        x_inputs: Dict[str, List[datetime]] = dict()
        y_inputs: Dict[str, List[float]] = dict()
        vector_metadata: Dict[str, Metadata] = dict()
        last_points: List[DataPoint] = []
        ids = list(set([str(p) for p in params]))
        if len(ids) < len(params):
            raise ValueError("Duplicate vectors specified in input parameters")
        # Python sets are not keeping order
        ids = [str(p) for p in params]
        max_count_dict = dict()
        for p in ids:
            if isinstance(self.__settings.max_count, dict):
                max_count = self.__settings.max_count[p] if p in self.__settings.max_count.keys() else 100000
            elif isinstance(self.__settings.max_count, int):
                max_count = self.__settings.max_count
            else:
                max_count = 100000
            if max_count > 100000:
                raise ValueError("Max count is 100000 points")
            max_count_dict[p] = max_count
        read_vectors = self.__data_api.read_vectors(ids, max_count_dict, self.__settings.exclusive_start, self.__settings.inclusive_end)
        for i in range(0, len(params)):
            x_inputs[ids[i]] = read_vectors[ids[i]][0]
            y_inputs[ids[i]] = read_vectors[ids[i]][1]
            if len(read_vectors[ids[i]][0]) > 0:
                last_points.append(DataPoint(read_vectors[ids[i]][0][-1], read_vectors[ids[i]][1][-1]))
            vector_metadata[ids[i]] = self.__data_api.get_metadata(ids[i])
        return self.__get_enumeration(x_inputs, y_inputs, vector_metadata, last_points) 
[docs]
    def get_time_range(self, x_inputs: Dict[str, List[datetime]], vector_metadata: Dict[str, Metadata]) -> List[datetime]:
        """
        Calculate the time range of the interpolation.

        The output times are either the times of the reference vector (when
        ``reference_vector_id`` is set in the settings) or the sorted union of
        the times of all inputs. A by-step vector whose step is not at start
        additionally contributes its ``first_x``. The result is then optionally
        restricted to the range common to all inputs and to the
        ``exclusive_start`` / ``inclusive_end`` window of the settings.

        Parameters
        ----------
        x_inputs:
            list of the input times, keyed by vector id
        vector_metadata:
            metadata of the inputs, keyed by vector id

        Returns
        -------
            list of the output times within the specified time range

        Raises
        ------
        KeyError
            If the reference vector id of the settings is not a key of the inputs.
        """
        settings = self.__settings
        if settings.reference_vector_id is not None:
            ref_id = str(settings.reference_vector_id)
            # Work on a copy so the caller's list is never modified.
            x_outputs = x_inputs[ref_id].copy()
            ref_metadata = vector_metadata[ref_id]
            if ref_metadata.is_by_step and ref_metadata.first_x is not None and not ref_metadata.is_step_at_start:
                x_outputs.insert(0, ref_metadata.first_x)
        else:
            # Union of all input times, accumulated in a single set.
            merged_times = set()
            for key, x_value in x_inputs.items():
                merged_times.update(x_value)
                key_metadata = vector_metadata[key]
                if key_metadata.is_by_step and key_metadata.first_x is not None and not key_metadata.is_step_at_start:
                    merged_times.add(key_metadata.first_x)
            x_outputs = sorted(merged_times)
        if settings.time_range == TimeRangeEnum.common:
            all_series = list(x_inputs.values())
            # An empty input has no range, so nothing can be common to all inputs.
            if any(len(series) == 0 for series in all_series):
                return []
            if all_series:
                # Bounds are derived from the data itself (latest start, earliest end),
                # so naive and timezone-aware datetimes are both supported.
                latest_start = max(min(series) for series in all_series)
                earliest_end = min(max(series) for series in all_series)
                x_outputs = [output for output in x_outputs if latest_start <= output <= earliest_end]
        if settings.exclusive_start is not None:
            # Drop the leading times that are not strictly after the start.
            first_kept = 0
            while first_kept < len(x_outputs) and x_outputs[first_kept] <= settings.exclusive_start:
                first_kept += 1
            x_outputs = x_outputs[first_kept:]
        if settings.inclusive_end is not None:
            # Drop the trailing times that are after the end.
            kept_count = len(x_outputs)
            while kept_count > 0 and x_outputs[kept_count - 1] > settings.inclusive_end:
                kept_count -= 1
            x_outputs = x_outputs[:kept_count]
        return x_outputs
[docs]
    def get_outputs_values(self,
                           x_inputs: Dict[str, List[datetime]],
                           y_inputs: Dict[str, List[float]],
                           x_outputs: List[datetime],
                           vector_metadata: Dict[str, Metadata]) -> List[List[float]]:
        """
        Get the values after interpolation
        Parameters
        ----------
        x_inputs:
            list of the input times
        y_inputs:
            list of the input values
        x_outputs:
            list of the output times that will be use for the interpolation and extrapolation
        vector_metadata:
            List with the info about the vectors
        Returns
        -------
            Lists of the interpolated values
        """
        y_outputs: List[List[float]] = [[] for _ in range(len(x_inputs))]
        for k in range(0, len(x_inputs)):
            xid = list(x_inputs.keys())[k]
            x_input = x_inputs[xid]
            y_input = y_inputs[xid]
            len_vect = len(x_input)
            if len_vect > 0:
                vector_metadata_xid = vector_metadata[xid]
                by_step = vector_metadata_xid.is_by_step
                first_x = vector_metadata_xid.first_x if by_step and vector_metadata_xid.first_x is not None else \
                    
x_input[0]
                step_at_start = vector_metadata_xid.is_step_at_start
                extrapolation_method = self.__settings.extrapolation_method \
                    
if not (by_step and self.__settings.extrapolation_method == ExtrapolationMethodEnum.use_slope) \
                    
else ExtrapolationMethodEnum.boundary_value
                j = 0
                for ref_date in x_outputs:
                    if ref_date < first_x or ref_date > x_input[len_vect - 1]:
                        if by_step and step_at_start and ref_date > x_input[len_vect - 1]:
                            y = self.extrapolation(x_input, y_input, first_x, ref_date,
                                                   ExtrapolationMethodEnum.boundary_value)
                        else:
                            y = self.extrapolation(x_input, y_input, first_x, ref_date, extrapolation_method)
                    else:
                        while ref_date > x_input[j] and j < len_vect - 1:
                            j += 1
                        if j == 0:
                            if by_step or ref_date == x_input[0]:
                                y = y_input[0]
                            else:
                                y = self.interpolation(1, by_step, step_at_start, x_input, y_input, ref_date)
                        else:
                            y = self.interpolation(j, by_step, step_at_start, x_input, y_input, ref_date)
                    y_outputs[k].append(float(y))
            else:
                if self.__settings.extrapolation_method == ExtrapolationMethodEnum.zero:
                    for i in range(0, len(x_outputs)):
                        y_outputs[k].append(0)
                else:
                    for i in range(0, len(x_outputs)):
                        y_outputs[k].append(float("NaN"))
        return y_outputs 
[docs]
    def interpolation(self,
                      next_index: int,
                      by_step: bool,
                      step_at_start: bool,
                      ref_vect_x: List[datetime],
                      ref_vect_y: List[float],
                      ref_date: datetime) -> float:
        next_x = ref_vect_x[next_index]
        next_y = ref_vect_y[next_index]
        previous_x = ref_vect_x[next_index - 1]
        previous_y = ref_vect_y[next_index - 1]
        if ref_date == next_x:
            return next_y
        if self.__settings.interpolation_method == InterpolationMethodEnum.linear_interpolation or by_step:
            if len(ref_vect_x) <= 1:
                return ref_vect_y[0]
            else:
                if by_step:
                    if step_at_start:
                        return previous_y
                    else:
                        return next_y
                else:
                    return self.linear_interpolation(ref_date, previous_x, next_x, previous_y, next_y)
        elif self.__settings.interpolation_method == InterpolationMethodEnum.absent_value:
            return float("NaN")
        elif self.__settings.interpolation_method == InterpolationMethodEnum.zero:
            return 0
        elif self.__settings.interpolation_method == InterpolationMethodEnum.previous_value:
            return previous_y
        elif self.__settings.interpolation_method == InterpolationMethodEnum.next_value:
            return next_y
        else:
            return round((next_y + previous_y) / 2.0, 5) 
[docs]
    @staticmethod
    def linear_interpolation(reference_x: datetime, previous_x: datetime, next_x: datetime, previous_y: float, next_y: float) -> float:
        if next_x == previous_x:
            return 0.5 * (previous_y + next_y)
        return previous_y + ((next_y - previous_y) / (next_x - previous_x).total_seconds()) * (reference_x - previous_x).total_seconds()