import requests
import logging
import copy
import json
import simplejson
from datetime import datetime
from requests_ntlm import HttpNtlmAuth
from osisoft_constants import OSIsoftConstants
from osisoft_endpoints import OSIsoftEndpoints
from osisoft_plugin_common import (
    assert_server_url_ok, build_requests_params,
    is_filtered_out, is_server_throttling, escape, epoch_to_iso,
    iso_to_epoch, RecordsLimit, is_iso8601, get_next_page_url, change_key_in_dict
)
from osisoft_pagination import OffsetPagination
from safe_logger import SafeLogger


# Module-level logger used by every method of the client below.
# NOTE(review): the second argument presumably lists the keys whose values SafeLogger masks
# in the output - confirm against safe_logger.
logger = SafeLogger("PI System", ["username", "password"])


class PISystemClientError(ValueError):
    """Raised by the PI System client for connection failures, HTTP error statuses and missing data links."""


class OSIsoftClient(object):

    def __init__(self, server_url, auth_type, username, password, is_ssl_check_disabled=False, can_raise=True, is_debug_mode=False, network_timer=None):
        """Open an HTTP session towards the PI Web API server.

        Arguments:
        server_url -- Server address handed to OSIsoftEndpoints; checked with assert_server_url_ok only when can_raise is set
        auth_type -- "basic" or "ntlm"; any other value leaves the session without authentication
        username, password -- Credentials handed to get_auth
        is_ssl_check_disabled -- When true, the session is created with verify=False
        can_raise -- Stored on the instance; also gates the server_url check
        is_debug_mode -- When true, get/post log the response content and status
        network_timer -- Optional object whose start(url) / stop() are called around each request
        """
        if can_raise:
            assert_server_url_ok(server_url)

        # Configure the session completely before publishing it on the instance
        http_session = requests.Session()
        http_session.auth = self.get_auth(auth_type, username, password)
        http_session.verify = not is_ssl_check_disabled
        self.session = http_session

        logger.info(
            "Initialization server_url={}, is_ssl_check_disabled={}".format(server_url, is_ssl_check_disabled)
        )
        self.endpoint = OSIsoftEndpoints(server_url)

        # Behaviour switches and per-client state
        self.network_timer = network_timer
        self.is_debug_mode = is_debug_mode
        self.can_raise = can_raise
        self.debug_level = None
        self.next_page = None

    def get_auth(self, auth_type, username, password):
        if auth_type == "basic":
            return (username, password)
        elif auth_type == "ntlm":
            return HttpNtlmAuth(username, password)
        else:
            return None

    def recursive_get_rows_from_webid(self, webid, data_type, **kwargs):
        """Yield the rows of a webid, halving the time range when the server rejects it.

        Arguments:
        webid -- Identifier forwarded to get_rows_from_webid
        data_type -- Data type forwarded to get_rows_from_webid
        kwargs -- Must contain "start_date" and "end_date". "endpoint_type" defaults to
                  "event_frames" and "can_raise" to True; "max_count" is optional.
                  The whole dict is forwarded to get_rows_from_webid.

        Raises Exception for any error that is_parameter_greater_than_max_allowed does not recognise.
        """
        # Split the time range until no more HTTP 400
        kwargs["endpoint_type"] = kwargs.get("endpoint_type", "event_frames")
        kwargs["can_raise"] = kwargs.get("can_raise", True)
        start_date = kwargs["start_date"]
        end_date = kwargs["end_date"]
        max_count = kwargs.get("max_count")
        done = False
        # Timestamp of the last row of the previous pass, used to drop a duplicated first row
        previous_item_timestamp = False
        while not done:
            logger.info("Attempting download webids from {} to {}".format(start_date, end_date))
            rows = self.get_rows_from_webid(webid, data_type, **kwargs)
            counter = 0
            try:
                row = next(rows)
                if not previous_item_timestamp or previous_item_timestamp != row.get("Timestamp"):
                    # When the number of rows returned for a timerange is equal to maxCount,
                    # we re-send this request, this time between [last time received ; initial end time]
                    # The first element of this new request is likely to be the last element of the previous request
                    # If that's not the case (not the same timestamp), we keep that row.
                    logger.info("Keeping first row")
                    yield row
                # The first row is counted even when it was dropped as a duplicate
                counter += 1
                for row in rows:
                    yield row
                    counter += 1
            except Exception as err:
                # NOTE(review): is_parameter_greater_than_max_allowed is defined outside this excerpt;
                # it presumably recognises the "too many values requested" server error - confirm.
                if is_parameter_greater_than_max_allowed(err):
                    start_timestamp, end_timestamp, half_time_iso = self.halve_time_range(start_date, end_date)
                    # First half: [start ; middle]
                    kwargs["start_date"] = start_timestamp
                    kwargs["end_date"] = half_time_iso
                    first_half_rows = self.recursive_get_rows_from_webid(
                        webid, data_type, **kwargs
                    )
                    for row in first_half_rows:
                        yield row
                    logger.info("Successfully retrieved first half ({} to {})".format(start_timestamp, half_time_iso))
                    # Second half: [middle ; end]
                    kwargs["start_date"] = half_time_iso
                    kwargs["end_date"] = end_timestamp
                    second_half_rows = self.recursive_get_rows_from_webid(
                        webid, data_type, **kwargs
                    )
                    for row in second_half_rows:
                        yield row
                    logger.info("Successfully retrieved second half ({} to {})".format(half_time_iso, end_timestamp))
                else:
                    logger.error("Error: {}".format(err))
                    raise Exception("Error: {}".format(err))
            logger.info("Successfully retrieved time range {} to {}".format(start_date, end_date))
            if counter == max_count:
                # Exactly max_count rows were counted, so more may remain:
                # query again starting from the last timestamp received
                logger.warning("Number of replies equals maxCount. Shifting startDate and trying one more time.")
                last_received_timestamp = row.get("Timestamp")
                logger.info("Last received timestamp is {}".format(last_received_timestamp))
                start_date = last_received_timestamp
                kwargs["start_date"] = start_date
                previous_item_timestamp = last_received_timestamp
            else:
                done = True

    def recursive_get_rows_from_item(self, item, data_type, start_date=None, end_date=None,
                                     interval=None, sync_time=None, boundary_type=None, record_boundary_type=None,
                                     can_raise=True, object_id=None, endpoint_type="event_frames", search_full_hierarchy=None,
                                     max_count=None, summary_type=None, summary_duration=None):
        """Yield the rows of an item, halving the time range when the server rejects it.

        Arguments:
        item -- PI tag, or path to an element or event frame, forwarded to get_rows_from_item
        data_type -- Data type forwarded to get_rows_from_item
        start_date, end_date -- Bounds of the requested time range
        interval, sync_time, boundary_type, record_boundary_type, search_full_hierarchy,
        summary_type, summary_duration -- Forwarded unchanged to get_rows_from_item
        can_raise -- When true, unhandled errors raise Exception; otherwise they are yielded
                     as a row with 'object_id' and 'Errors' keys
        object_id -- Identifier copied into error rows
        endpoint_type -- Accepted for interface compatibility; not used by this method
        max_count -- When the number of rows counted equals this value, the request is repeated
                     from the last timestamp received

        A time range that returns no row at all is a valid, empty result.
        """
        # item can be an pi tag, a path to an element or event frame
        # Split the time range until no more HTTP 400
        done = False
        # Timestamp of the last row of the previous pass, used to drop a duplicated first row
        previous_item_timestamp = False
        while not done:
            logger.info("Attempting download items from {} to {}".format(start_date, end_date))
            # can_raise is forced to True so that server errors reach the handler below
            rows = self.get_rows_from_item(item, data_type, start_date=start_date, end_date=end_date, interval=interval,
                                           sync_time=sync_time, boundary_type=boundary_type, record_boundary_type=record_boundary_type,
                                           can_raise=True, object_id=object_id,
                                           search_full_hierarchy=search_full_hierarchy, max_count=max_count,
                                           summary_type=summary_type, summary_duration=summary_duration)
            counter = 0
            try:
                row = next(rows)
                if not previous_item_timestamp or previous_item_timestamp != row.get("Timestamp"):
                    # When the number of rows returned for a timerange is equal to maxCount,
                    # we re-send this request, this time between [last time received ; initial end time]
                    # The first element of this new request is likely to be the last element of the previous request
                    # If that's not the case (not the same timestamp), we keep that row.
                    logger.info("Keeping first row")
                    yield row
                # The first row is counted even when it was dropped as a duplicate
                counter += 1
                for row in rows:
                    yield row
                    counter += 1
            except StopIteration:
                # next() found nothing to read: the range is simply empty. Without this clause the
                # broad handler below would report it as "Error: " with an empty message.
                logger.info("No rows returned for time range {} to {}".format(start_date, end_date))
            except Exception as err:
                # NOTE(review): is_parameter_greater_than_max_allowed is defined outside this excerpt;
                # it presumably recognises the "too many values requested" server error - confirm.
                if is_parameter_greater_than_max_allowed(err):
                    start_timestamp, end_timestamp, half_time_iso = self.halve_time_range(start_date, end_date)
                    first_half_rows = self.recursive_get_rows_from_item(
                        item, data_type, start_date=start_timestamp, end_date=half_time_iso,
                        interval=interval, sync_time=sync_time, boundary_type=boundary_type,
                        record_boundary_type=record_boundary_type, can_raise=True, object_id=object_id,
                        search_full_hierarchy=search_full_hierarchy, max_count=max_count, summary_type=summary_type, summary_duration=summary_duration
                    )
                    for row in first_half_rows:
                        yield row
                    logger.info("Successfully retrieved first half ({} to {})".format(start_timestamp, half_time_iso))
                    second_half_rows = self.recursive_get_rows_from_item(
                        item, data_type, start_date=half_time_iso, end_date=end_timestamp,
                        interval=interval, sync_time=sync_time, boundary_type=boundary_type,
                        record_boundary_type=record_boundary_type, can_raise=True, object_id=object_id,
                        search_full_hierarchy=search_full_hierarchy, max_count=max_count, summary_type=summary_type, summary_duration=summary_duration
                    )
                    for row in second_half_rows:
                        yield row
                    logger.info("Successfully retrieved second half ({} to {})".format(half_time_iso, end_timestamp))
                else:
                    logger.error("Error: {}".format(err))
                    if can_raise:
                        raise Exception("Error: {}".format(err))
                    # Only wrap and yield unhandled exceptions in the outer call
                    yield {'object_id': "{}".format(object_id), 'Errors': "{}".format(err)}
            logger.info("Successfully retrieved time range {} to {}".format(start_date, end_date))
            if counter == max_count:
                # Exactly max_count rows were counted, so more may remain:
                # query again starting from the last timestamp received
                logger.warning("Number of replies equals maxCount. Shifting startDate and trying one more time.")
                last_received_timestamp = row.get("Timestamp")
                logger.info("Last received timestamp is {}".format(last_received_timestamp))
                start_date = last_received_timestamp
                previous_item_timestamp = last_received_timestamp
            else:
                done = True

    def halve_time_range(self, start_date, end_date):
        """Split a time range at its midpoint.

        Arguments:
        start_date, end_date -- Bounds of the range, resolved to epoch through parse_pi_time

        Returns a (start, end, midpoint) tuple, each element produced by epoch_to_iso.
        """
        logger.warning("The time range {} -> {} is too large, splitting the job in two".format(start_date, end_date))
        range_start = self.parse_pi_time(start_date, to_epoch=True)
        range_end = self.parse_pi_time(end_date, to_epoch=True)
        midpoint_iso = epoch_to_iso(range_start + (range_end - range_start) / 2)
        return epoch_to_iso(range_start), epoch_to_iso(range_end), midpoint_iso

    def parse_pi_time(self, pi_time, to_epoch=False):
        """Checks that pi_time is iso8601.
        If not, send it to pi-server to evaluate the Pi time expression.

        Arguments:
        pi_time -- String containing an iso8601 datetime or Pi time string format
                   https://docs.aveva.com/bundle/pi-web-api-reference/page/help/topics/time-strings.html
        to_epoch -- Select the format of the returned datetime (iso8601 / epoch)

        Returns None when pi_time is empty, or when the server answer carries no timestamp.
        """

        logger.info("Parsing '{}' to_epoch={}".format(pi_time, to_epoch))
        if not pi_time:
            logger.info("No time given")
            return None
        if is_iso8601(pi_time):
            logger.info("Time is iso8601")
            if not to_epoch:
                return pi_time
        else:
            logger.info("Time is not iso8601")
        if to_epoch:
            # Local conversion first; a falsy result falls through to the server lookup
            epoch_timestamp = iso_to_epoch(pi_time)
            logger.info("'{}' converted to epoch '{}'".format(pi_time, epoch_timestamp))
            if epoch_timestamp:
                return epoch_timestamp
        logger.info("Using Pi server to resolve time string format")
        url = self.endpoint.get_calculation_time_url()
        headers = self.get_requests_headers()
        json_response = self.get(url=url, headers=headers, params={
            "expression": 'ParseTime("*")',
            "time": escape(pi_time)
        })
        # A missing or empty item list must not crash on items[0]
        items = json_response.get("Items") or [{}]
        iso_timestamp = items[0].get("Timestamp")
        if to_epoch and iso_timestamp:
            return iso_to_epoch(iso_timestamp)
        return iso_timestamp

    def get_rows_from_webid(self, webid, data_type, **kwargs):
        """Yield every row of a webid, following pagination links.

        Arguments:
        webid -- Identifier used to build the data url
        data_type -- Data type used to build the data url
        kwargs -- Request options forwarded to generic_get; "endpoint_type" defaults to "event_frames"

        An error response is yielded as a single row tagged with 'object_id'.
        A page with an empty item list yields one empty dict.
        """
        kwargs.setdefault("endpoint_type", "event_frames")
        url = self.endpoint.get_data_from_webid_url(kwargs["endpoint_type"], data_type, webid)
        while True:
            page, has_more = self.get_paginated(
                self.generic_get,
                url,
                **kwargs
            )
            if OSIsoftConstants.DKU_ERROR_KEY in page:
                page['object_id'] = "{}".format(webid)
                yield page
            else:
                # A response without item list is treated as a single row
                page_rows = page.get(OSIsoftConstants.API_ITEM_KEY, [page]) or [{}]
                for page_row in page_rows:
                    yield page_row
            if not has_more:
                break

    def get_rows_from_webids(self, input_rows, data_type, **kwargs):
        endpoint_type = kwargs.get("endpoint_type", "event_frames")
        batch_size = kwargs.get("batch_size", 500)

        batch_requests_parameters = []
        number_processed_webids = 0
        number_of_webids_to_process = len(input_rows)
        web_ids = []
        event_start_times = []
        event_end_times = []
        for input_row in input_rows:
            event_start_time = event_end_time = None
            if isinstance(input_row, dict):
                webid = input_row.get("WebId")
                event_start_time = input_row.get("StartTime")
                event_end_time = input_row.get("EndTime")
            else:
                webid = input_row
            url = self.endpoint.get_data_from_webid_url(endpoint_type, data_type, webid)
            requests_kwargs = self.generic_get_kwargs(**kwargs)
            requests_kwargs['url'] = build_query_string(url, requests_kwargs.get("params"))
            web_ids.append(webid)
            event_start_times.append(event_start_time)
            event_end_times.append(event_end_time)
            batch_requests_parameters.append(requests_kwargs)
            number_processed_webids += 1
            if (len(batch_requests_parameters) >= batch_size) or (number_processed_webids == number_of_webids_to_process):
                json_responses = self._batch_requests(batch_requests_parameters)
                batch_requests_parameters = []
                response_index = 0
                for json_response in json_responses:
                    response_content = json_response.get("Content", {})
                    webid = web_ids[response_index]
                    event_start_time = event_start_times[response_index]
                    event_end_time = event_end_times[response_index]
                    if OSIsoftConstants.DKU_ERROR_KEY in response_content:
                        if endpoint_type == "event_frames":
                            response_content['event_frame_webid'] = "{}".format(webid)
                        yield response_content
                    items = response_content.get(OSIsoftConstants.API_ITEM_KEY, [])
                    for item in items:
                        if event_start_time:
                            item['StartTime'] = event_start_time
                        if event_end_time:
                            item['EndTime'] = event_end_time
                        if endpoint_type == "event_frames":
                            item['event_frame_webid'] = "{}".format(webid)
                        yield item
                    response_index += 1
                web_ids = []

    def _batch_requests(self, batch_requests_parameters, method=None):
        """Send several requests in one call to the batch endpoint and yield each answer.

        Arguments:
        batch_requests_parameters -- List of dicts holding "url" and optionally "data" / "json"
        method -- HTTP method of every sub-request, "GET" when not given

        Yields one response section per request, in request order; an empty dict stands in
        for a section missing from the server answer.
        """
        http_method = method or "GET"
        batch_endpoint = self.endpoint.get_batch_endpoint()
        batch_body = {}
        for position, request_parameters in enumerate(batch_requests_parameters):
            sub_request = {
                "Method": http_method,
                "Resource": "{}".format(request_parameters.get("url")),
            }
            # "json" is examined last, so it wins when both keys are present
            for content_key in ("data", "json"):
                if content_key in request_parameters:
                    sub_request["Content"] = "{}".format(request_parameters.get(content_key))
            # Sub-requests are keyed by their position, as a string
            batch_body["{}".format(position)] = sub_request
        response = self.post_value(url=batch_endpoint, data=batch_body)
        sections = simplejson.loads(response.content)
        for position in range(len(batch_requests_parameters)):
            yield sections.get("{}".format(position), {})

    def generic_get_kwargs(self, **kwargs):
        """Return the "headers" and "params" of a request without sending it.

        Arguments:
        kwargs -- Request options converted by build_requests_params
        """
        request_kwargs = {"headers": self.get_requests_headers()}
        request_kwargs["params"] = build_requests_params(**kwargs)
        return request_kwargs

    def generic_get(self, url, **kwargs):
        """GET a url with the standard headers and return the JSON answer.

        Arguments:
        url -- Target url
        kwargs -- Request options converted by build_requests_params; "can_raise" (default True)
                  is also forwarded to get
        """
        return self.get(
            url=url,
            headers=self.get_requests_headers(),
            params=build_requests_params(**kwargs),
            can_raise=kwargs.get("can_raise", True)
        )

    def get_rows_from_item(self, item, data_type, **kwargs):
        """Yield the rows of an item, following pagination links.

        Arguments:
        item -- PI tag, or path to an element or event frame
        data_type -- Key of the link to follow, see get_link_from_item
        kwargs -- Request options forwarded to get_link_from_item; "object_id" is copied
                  into error rows

        An error response is first yielded as-is, tagged with 'object_id'. Every response is
        then expanded: each entry of its item list (or the response itself when it has none)
        is passed through loop_sub_items and the result is yielded.
        """
        # item can be an pi tag, a path to an element or event frame
        has_more = True
        object_id = kwargs.get("object_id")
        while has_more:
            json_response, has_more = self.get_paginated(
                self.get_link_from_item,
                item,
                data_type,
                **kwargs
            )
            if OSIsoftConstants.DKU_ERROR_KEY in json_response:
                json_response['object_id'] = "{}".format(object_id)
                yield json_response
            response_items = json_response.get(OSIsoftConstants.API_ITEM_KEY, [json_response])
            # The loop variable must not be called `item`: that name is the argument handed
            # to get_paginated again on the next page
            for response_item in response_items:
                yield self.loop_sub_items(response_item)

    def get_link_from_item(self, item, data_type, **kwargs):
        """Follow the link of an item that matches data_type and return the JSON answer.

        Arguments:
        item -- Object from which extract_link_with_key reads the link
        data_type -- Key of the link to follow
        kwargs -- Request options; "boundary_type" is renamed to "sync_time_boundary_type"
                  before conversion, "can_raise" defaults to True

        When the item has no such link, raises PISystemClientError, or returns an error dict
        if can_raise is false.
        """
        can_raise = kwargs.get("can_raise", True)
        url = self.extract_link_with_key(item, data_type)
        if url:
            headers = self.get_requests_headers()
            renamed_kwargs = change_key_in_dict(kwargs, "boundary_type", "sync_time_boundary_type")
            return self.get(
                url=url,
                headers=headers,
                params=build_requests_params(**renamed_kwargs),
                can_raise=can_raise
            )
        error_message = "This object does not have {} data type".format(data_type)
        if can_raise:
            raise PISystemClientError(error_message)
        return {OSIsoftConstants.DKU_ERROR_KEY: error_message}

    def get_rows_from_url(self, url=None, start_date=None, end_date=None, interval=None, sync_time=None, boundary_type=None,  max_count=None):
        """Yield the rows served by a url, page after page.

        Arguments:
        url -- Url to read; get_link_from_url falls back to the base url when empty
        start_date, end_date, interval, sync_time, boundary_type, max_count -- Request options
            forwarded to get_link_from_url

        Items that carry their own item list are expanded with loop_sub_items; the others are
        yielded unchanged.
        """
        pagination = OffsetPagination()
        request_options = dict(
            start_date=start_date, end_date=end_date, interval=interval,
            sync_time=sync_time, boundary_type=boundary_type, max_count=max_count
        )
        while True:
            page, has_more = pagination.get_offset_paginated(self.get_link_from_url, url, **request_options)
            # A response without item list is treated as a single item
            for page_item in page.get(OSIsoftConstants.API_ITEM_KEY, [page]):
                if OSIsoftConstants.API_ITEM_KEY not in page_item:
                    yield page_item
                    continue
                for sub_row in self.loop_sub_items(page_item):
                    yield sub_row
            if not has_more:
                break

    def get_rows_from_urls(self, links=None, data_type=None, start_date=None, end_date=None, interval=None, sync_time=None, boundary_type=None, max_count=None):
        links = links or []
        for link in links:
            url = link
            rows = self.get_rows_from_url(
                url, start_date=start_date, end_date=end_date, interval=interval,
                sync_time=sync_time, boundary_type=boundary_type, max_count=max_count
            )
            for row in rows:
                yield row

    def get_link_from_url(self, url, **kwargs):
        """GET a url and return the JSON answer.

        Arguments:
        url -- Target url; the endpoint base url is used when empty
        kwargs -- Request options converted by build_requests_params
        """
        target_url = url or self.endpoint.get_base_url()
        return self.get(
            url=target_url,
            headers=self.get_requests_headers(),
            params=build_requests_params(**kwargs)
        )

    def get_paginated(self, calling_function, *args, **kwargs):
        """Fetch one page of results and report whether another one follows.

        Arguments:
        calling_function -- Callable producing the first page; called with args and kwargs
        args, kwargs -- Arguments of calling_function; ignored when a next page is pending

        Returns a (json_response, has_more) tuple. The url of the following page, read from
        the "Next" entry of the response "Links", is remembered in self.next_page and used
        by the next call.
        """
        if self.next_page:
            json_response = self.get(self.next_page, headers=self.get_requests_headers(), params={})
        else:
            json_response = calling_function(*args, **kwargs)
        self.next_page = json_response.get("Links", {}).get("Next", None)
        has_more = bool(self.next_page)
        if has_more:
            logger.info("Next page is {}".format(self.next_page))
        if OSIsoftConstants.API_ITEM_KEY in json_response and not json_response.get(OSIsoftConstants.API_ITEM_KEY, []):
            # An empty item list ends the pagination even when a Next link is advertised.
            # The link is dropped too, otherwise the next, unrelated call would fetch it.
            has_more = False
            self.next_page = None
        return json_response, has_more

    def is_resource_path(self, reference):
        if isinstance(reference, str):
            return reference.startswith("\\")
        else:
            return False

    def get_web_id(self, resource_path):
        """Return the WebId of a resource path, or None when the answer carries none.

        Arguments:
        resource_path -- Path looked up through the resource path endpoint; when that lookup
                         reports an error, the path is resolved with traverse_path instead
        """
        json_response = self.get(
            url=self.endpoint.get_resource_path_url(),
            headers=self.get_requests_headers(),
            params=self.get_resource_path_params(resource_path),
            can_raise=False,
            error_source="get_web_id"
        )
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            logger.warning("Path {} not found by resource path search, trying by traversing".format(resource_path))
            json_response = self.traverse_path(resource_path)
        return json_response.get("WebId")

    def get_item_from_path(self, item_path):
        """Return the JSON description of the item found at item_path.

        Arguments:
        item_path -- Path looked up through the resource path endpoint

        When the lookup reports an error, the path is resolved with traverse_path instead;
        if that fails too, the failure is logged and the error answer of the lookup is returned.
        """
        lookup_response = self.get(
            url=self.endpoint.get_resource_path_url(),
            headers=self.get_requests_headers(),
            params=self.get_resource_path_params(item_path),
            can_raise=False,
            error_source="get_item_from_path"
        )
        if OSIsoftConstants.DKU_ERROR_KEY not in lookup_response:
            return lookup_response
        try:
            return self.traverse_path(item_path)
        except Exception as err:
            logger.warning("Error while traversing path {}:{}".format(item_path, err))
        return lookup_response

    def get_item_from_url(self, url):
        """GET a url without query parameters; errors are returned as a dict, not raised."""
        return self.get(
            url=url,
            headers=self.get_requests_headers(),
            params={},
            can_raise=False,
            error_source="get_item_from_url"
        )

    def get(self, url, headers, params, can_raise=True, error_source=None):
        """Send a GET request and return the decoded JSON body.

        Arguments:
        url -- Target url; params are merged into it by build_query_string
        headers -- Headers sent with the request
        params -- Query parameters
        can_raise -- When true, connection failures and HTTP errors raise PISystemClientError
        error_source -- Optional label inserted in error messages

        Returns the parsed JSON answer, or {OSIsoftConstants.DKU_ERROR_KEY: message} when an
        error was recorded without being raised.
        """
        error_message = None
        url = build_query_string(url, params)
        logger.info("Trying to connect to {}".format(url))
        limit = RecordsLimit(OSIsoftConstants.MAXIMUM_RETRIES_ON_THROTTLING)
        try:
            response = None
            # NOTE(review): is_server_throttling(None) is presumably true, so that the first
            # request is sent - confirm in osisoft_plugin_common.
            while is_server_throttling(response):
                if self.network_timer:
                    self.network_timer.start(url)
                response = self.session.get(
                    url=url,
                    headers=headers
                )
                if self.network_timer:
                    self.network_timer.stop()
                if self.is_debug_mode:
                    logger.info("get response.content={}".format(response.content)[:1000])
                    logger.info("get response.status={}".format(response.status_code))
                if limit.is_reached():
                    # Reported through the returned error dict, even when can_raise is true
                    error_message = "The maximum number of retries has been reached."
                    break
        except Exception as err:
            error_message = "Could not connect. Error: {}{}".format(formatted_error_source(error_source), err)
            logger.error(error_message)
            if can_raise:
                raise PISystemClientError(error_message)
        # The HTTP status is only checked when no error was recorded above
        if not error_message:
            error_message = self.assert_valid_response(response, can_raise=can_raise, error_source=error_source)
        if error_message:
            return {OSIsoftConstants.DKU_ERROR_KEY: error_message}
        json_response = simplejson.loads(response.content)
        return json_response

    def post_stream_value(self, webid, data):
        """POST data to the stream value url of a webid and return the response object.

        Arguments:
        webid -- Identifier used to build the stream value url
        data -- Payload sent as the JSON body
        """
        return self.post(
            url=self.endpoint.get_stream_value_url(webid),
            headers=OSIsoftConstants.WRITE_HEADERS,
            params={},
            data=data
        )

    def post_value(self, url, data):
        """POST data to a url and return the response object.

        Arguments:
        url -- Target url
        data -- Payload sent as the JSON body

        The standard request headers are sent, overridden by OSIsoftConstants.WRITE_HEADERS.
        """
        merged_headers = self.get_requests_headers()
        merged_headers.update(OSIsoftConstants.WRITE_HEADERS)
        return self.post(url=url, headers=merged_headers, params={}, data=data)

    def prepare_post_all_values(self, webid, buffer):
        """Build, without sending it, the request writing a buffer of values to a webid.

        Arguments:
        webid -- Identifier used to build the stream record url
        buffer -- Values to write, stored under the "json" key of the result

        Returns the dict produced by generic_get_kwargs, completed with "url" and "json".
        """
        record_url = self.endpoint.get_stream_record_url(webid)
        requests_kwargs = self.generic_get_kwargs(
            url=record_url, headers=OSIsoftConstants.WRITE_HEADERS, params={}, data=buffer
        )
        requests_kwargs['url'] = record_url
        requests_kwargs['json'] = buffer
        return requests_kwargs

    def post(self, url, headers, params, data, can_raise=True, error_source=None):
        """Send a POST request and return the response object.

        Arguments:
        url -- Target url; params are merged into it by build_query_string
        headers -- Headers sent with the request
        params -- Query parameters
        data -- Payload sent as the JSON body
        can_raise -- When true, an HTTP status of 400 or more raises PISystemClientError
        error_source -- Optional label inserted in error messages
        """
        full_url = build_query_string(url, params)
        logger.info("Trying to post to {}".format(full_url))
        timer = self.network_timer
        if timer:
            timer.start(full_url)
        response = self.session.post(url=full_url, headers=headers, json=data)
        if timer:
            timer.stop()
        if self.is_debug_mode:
            # The whole log line, prefix included, is truncated to the debug level
            content_line = "post response.content={}".format(response.content)
            logger.info(content_line[:self.get_debug_level()])
            logger.info("post response.status={}".format(response.status_code))
        self.assert_valid_response(response, can_raise=can_raise, error_source=error_source)
        return response

    def get_debug_level(self):
        if self.debug_level == 5000:
            self.debug_level = 1000
        if not self.debug_level:
            self.debug_level = 5000
        return self.debug_level

    def get_resource_path_params(self, resource_path):
        """Return the query parameters of a resource path lookup: the escaped path under "path"."""
        escaped_path = escape(resource_path)
        return {"path": escaped_path}

    def get_requests_headers(self):
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Accept-Encoding": "gzip, deflate, br"
        }

    def assert_valid_response(self, response, can_raise=True, error_source=None):
        """Check the HTTP status of a response.

        Arguments:
        response -- Response object exposing status_code and content
        can_raise -- When true, a status of 400 or more raises PISystemClientError
        error_source -- Optional label inserted in the error message

        Returns None for a status below 400, otherwise the error message (when not raising).
        The message is completed with the error details found in the JSON body, if any.
        """
        status_code = response.status_code
        if status_code < 400:
            return None
        error_message = "Error {}{}".format(formatted_error_source(error_source), status_code)
        try:
            payload = simplejson.loads(response.content)
            for detail_key in (OSIsoftConstants.DKU_ERROR_KEY, "Message"):
                if detail_key in payload:
                    error_message = error_message + " {}".format(payload.get(detail_key))
        except Exception as err:
            # A body that cannot be decoded only costs the extra details
            logger.error("{}".format(err))
        logger.error(error_message)
        logger.error("response.content={}".format(response.content))
        if can_raise:
            raise PISystemClientError(error_message)
        return error_message

    def loop_sub_items(self, base_row):
        """Expand a row that carries a list of sub-items.

        Arguments:
        base_row -- Dict to expand; its "Links" entry and its item list are removed in place

        Returns one row per sub-item, each a deep copy of base_row updated with the sub-item.
        A row without item list is returned through unnest instead.
        """
        base_row.pop("Links", None)
        if OSIsoftConstants.API_ITEM_KEY not in base_row:
            return unnest(base_row)
        sub_items = base_row.pop(OSIsoftConstants.API_ITEM_KEY, [])
        expanded_rows = []
        for sub_item in sub_items:
            # Deep copy so that rows never share nested objects
            merged_row = copy.deepcopy(base_row)
            merged_row.update(sub_item)
            expanded_rows.append(merged_row)
        return expanded_rows

    def get_asset_servers(self, can_raise=True):
        """List the asset servers as choices for a selector.

        Arguments:
        can_raise -- Forwarded to get

        Returns a list of {"label": name, "value": databases link} dicts, or a single
        {"label": error message} entry when the request reports an error.
        """
        json_response = self.get(
            url=self.endpoint.get_asset_servers_url(),
            headers=self.get_requests_headers(),
            params={},
            error_source="get_asset_servers",
            can_raise=can_raise
        )
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            return [{
                "label": json_response.get(OSIsoftConstants.DKU_ERROR_KEY)
            }]
        asset_servers = []
        for item in json_response.get(OSIsoftConstants.API_ITEM_KEY, []):
            asset_servers.append({
                "label": item.get("Name"),
                # Same default as get_data_servers: an item without links gives a None value
                "value": item.get("Links", {}).get("Databases")
            })
        return asset_servers

    def get_data_servers(self, can_raise=True):
        """List the data servers as choices for a selector.

        Arguments:
        can_raise -- Forwarded to get

        Returns a list of {"label": name, "value": points link} dicts, or a single
        {"label": error message} entry when the request reports an error.
        """
        json_response = self.get(
            url=self.endpoint.get_data_servers_url(),
            headers=self.get_requests_headers(),
            params={},
            error_source="get_data_servers",
            can_raise=can_raise
        )
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            return [{"label": json_response.get(OSIsoftConstants.DKU_ERROR_KEY)}]
        return [
            {"label": server.get("Name"), "value": server.get("Links", {}).get("Points")}
            for server in json_response.get(OSIsoftConstants.API_ITEM_KEY, [])
        ]

    def get_next_choices(self, next_url, next_key, params=None, use_name_as_link=False, filter=None):
        """List the items served by next_url as choices for a selector.

        Arguments:
        next_url -- Url to read
        next_key -- Key of the item link used as the choice value
        params -- Optional query parameters
        use_name_as_link -- When true, the item name is used as the value instead of the link
        filter -- Passed to is_filtered_out with each item; filtered out items are skipped

        Returns a list of {"label": name, "value": ...} dicts, or a single
        {"label": error message} entry when the request reports an error.
        """
        params = params or {}
        next_choices = []
        headers = self.get_requests_headers()
        json_response = self.get(url=next_url, headers=headers, params=params, error_source="get_next_choices")
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            return [{
                "label": json_response.get(OSIsoftConstants.DKU_ERROR_KEY)
            }]
        # A response without item list gives no choices instead of failing on None
        items = json_response.get(OSIsoftConstants.API_ITEM_KEY) or []
        for item in items:
            if is_filtered_out(item, filter):
                continue
            if use_name_as_link:
                value = item.get("Name")
            else:
                # An item without links gives a None value, as in get_data_servers
                value = item.get("Links", {}).get(next_key)
            next_choices.append({
                "label": item.get("Name"),
                "value": value
            })
        return next_choices

    def get_next_choices_as_json(self, next_url, next_key, params=None, use_name_as_link=False):
        """List the items served by next_url as choices whose value is a JSON string.

        Arguments:
        next_url -- Url to read
        next_key -- Key of the item link stored under "url" in the JSON value
        params -- Optional query parameters
        use_name_as_link -- Accepted for interface compatibility; not used by this method

        Returns a list of {"label": name, "value": json} dicts where json encodes
        {"url": link, "label": name}, or a single {"label": error message} entry when the
        request reports an error.
        """
        params = params or {}
        next_choices = []
        headers = self.get_requests_headers()
        json_response = self.get(url=next_url, headers=headers, params=params, error_source="get_next_choices")
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            return [{
                "label": json_response.get(OSIsoftConstants.DKU_ERROR_KEY)
            }]
        # A response without item list gives no choices instead of failing on None
        items = json_response.get(OSIsoftConstants.API_ITEM_KEY) or []
        for item in items:
            # An item without links gives a null url, as in get_data_servers
            link = item.get("Links", {}).get(next_key)
            next_choices.append({
                "label": item.get("Name"),
                "value": json.dumps({"url": link, "label": item.get("Name")})
            })
        return next_choices

    def search_attributes(self, database_webid, **kwargs):
        """Yield the attributes of a database that match the search criteria.

        Arguments:
        database_webid -- WebId of the database to search
        kwargs -- Search criteria, see build_element_query and build_attribute_query

        An error answer to the first request is yielded as a row. Following pages are fetched
        until get_next_page_url finds no further page.
        """
        search_attributes_base_url = self.endpoint.get_attribute_url()
        query = "Element:{{{}}} {}".format(
            self.build_element_query(**kwargs),
            self.build_attribute_query(**kwargs)
        )
        headers = self.get_requests_headers()
        params = {
            "query": query,
            "databaseWebId": database_webid
        }
        json_response = self.get(url=search_attributes_base_url, headers=headers, params=params)
        if OSIsoftConstants.DKU_ERROR_KEY in json_response:
            yield json_response
        while json_response:
            next_page_url = get_next_page_url(json_response)
            for item in json_response.get(OSIsoftConstants.API_ITEM_KEY, []):
                yield item
            if not next_page_url:
                break
            # Follow-up pages carry the same JSON headers as the first request
            json_response = self.get(url=next_page_url, headers=self.get_requests_headers(), params={})

    def build_element_query(self, **kwargs):
        """
        Build the element part of a search query.

        The keyword arguments first go through apply_manual_inputs. Each recognised argument
        with a truthy value is then rendered with its template, and the resulting tokens
        are joined with single spaces, in the order of the resolved arguments.
        """
        token_templates = {
            "element_name": "Name:'{}'",
            "search_root_path": "Root:'{}'",
            "element_template": "Template:'{}'",
            "element_type": "Type:'{}'",
            "element_category": "CategoryName:'{}'"
        }
        resolved = apply_manual_inputs(kwargs)
        tokens = [
            token_templates[name].format(value)
            for name, value in resolved.items()
            if value and name in token_templates
        ]
        return " ".join(tokens)

    def build_attribute_query(self, **kwargs):
        """
        Build the attribute part of a search query.

        Each recognised keyword argument with a truthy value is rendered with its template.
        The tokens are joined with single spaces, in the order the arguments were passed.
        Unknown arguments and empty values are ignored.
        """
        token_templates = {
            "attribute_name": "Name:'{}'",
            "attribute_category": "CategoryName:'{}'",
            "attribute_value_type": "Type:'{}'"
        }
        tokens = [
            token_templates[name].format(value)
            for name, value in kwargs.items()
            if value and name in token_templates
        ]
        return " ".join(tokens)

    def traverse(self, path_elements):
        """
        Follow the links exposed by the API, page after page, down to the item named by path_elements.

        traversing:
        piwebapi AssetServers Databases Items[].name="Well" Elements Items[].name=Assets Elements Items[].name=TX511 Attributes

        :param path_elements: list of names: asset server first, then database, then the elements.
                              An element entry may be written "element|attribute" to select one of
                              its attributes. The list is not modified.
        :return: the last item matched, or {} if no item with the expected name was found on the last page.
                 With only a server and a database name, the database item is returned.
        :raises IndexError: if path_elements contains fewer than two entries
        """
        # Work on a copy so the caller's list is left untouched by the pops below
        remaining = list(path_elements)

        # Loading piwebapi initial page
        next_url = self.endpoint.get_base_url()
        headers = self.get_requests_headers()
        json_response = self.get(url=next_url, headers=headers, params={}, error_source="traverse")

        # Asset server page
        next_url = self.extract_link_with_key(json_response, "AssetServers")
        json_response = self.get(url=next_url, headers=headers, params={}, error_source="traverse")

        # get the asset server, then list its databases
        item = self.extract_item_with_name(json_response, remaining.pop(0))
        next_url = self.extract_link_with_key(item, "Databases")
        json_response = self.get(url=next_url, headers=headers, params={}, error_source="traverse")

        # get the database, then list its root elements
        item = self.extract_item_with_name(json_response, remaining.pop(0))
        next_url = self.extract_link_with_key(item, "Elements")
        json_response = self.get(url=next_url, headers=headers, params={}, error_source="traverse")

        # Looping through elements
        for path_element in remaining:
            element, attribute = self.split_element_attribute(path_element)
            item = self.extract_item_with_name(json_response, element)
            # Go down to the attributes of the element if one is requested, to its child elements otherwise
            link_key = "Attributes" if attribute else "Elements"
            next_url = self.extract_link_with_key(item, link_key)
            json_response = self.get(url=next_url, headers=headers, params={}, error_source="traverse")
            if attribute:
                item = self.extract_item_with_name(json_response, attribute)

        return item

    def split_element_attribute(self, path_element):
        """
        Split "element|attribute" into an (element, attribute) tuple.

        The attribute is None when there is no "|" in path_element.
        Anything after a second "|" is ignored.
        """
        parts = path_element.split("|")
        attribute = parts[1] if len(parts) > 1 else None
        return parts[0], attribute

    def extract_item_with_name(self, json_response, name):
        """
        Return the first item of json_response whose "Name" equals name, or {} if there is none.

        Items without a "Name" are compared as if their name was the empty string.
        """
        candidates = json_response.get(OSIsoftConstants.API_ITEM_KEY, [])
        matches = (candidate for candidate in candidates if candidate.get("Name", "") == name)
        return next(matches, {})

    def extract_link_with_key(self, item, key):
        """
        Return item["Links"][key], or "" when the item has no "Links" or no such key.
        """
        return item.get("Links", {}).get(key, "")

    def traverse_path(self, path):
        """
        Split path on backslashes, drop the first two tokens and traverse the remaining ones.

        :param path: backslash separated path
        :return: whatever traverse() returns for the remaining tokens
        :raises IndexError: if path contains fewer than two tokens
        """
        tokens = path.split("\\")
        # NOTE(review): for a path written \\Server\Database\Element, the two tokens dropped here
        # are the empty strings produced by the leading double backslash; traverse() consumes
        # the server and database names itself
        for _ in range(2):
            tokens.pop(0)
        return self.traverse(tokens)

    def unnest_row(self, row):
        """
        Expand a row into a list of rows.

        - If the row has an items key: the items are removed from the row (the row is modified
          in place) and the result is the stripped row followed by one copy of it per item,
          each updated with the item's content.
        - Otherwise, if the row has a value key holding a dict: the result is a single copy
          of the row where the value key is replaced by the content of that dict.
        - In any other case the result is the row itself, in a one element list.
        """
        if OSIsoftConstants.API_ITEM_KEY in row:
            nested_items = row.pop(OSIsoftConstants.API_ITEM_KEY, [])
            expanded = [row]
            for nested_item in nested_items:
                merged = copy.deepcopy(row)
                merged.update(nested_item)
                expanded.append(merged)
            return expanded
        if OSIsoftConstants.API_VALUE_KEY in row:
            flattened = copy.deepcopy(row)
            nested_value = flattened.pop(OSIsoftConstants.API_VALUE_KEY)
            if isinstance(nested_value, dict):
                flattened.update(nested_value)
                return [flattened]
        return [row]


def format_output_row(row):
    """
    Generator flattening an API row into output rows.

    - Duplicates the row for each element of the Items key
    - Does that recursively (streamsets contains items in items)
    - Unnests the Value key if it is an object

    The "Links" key is removed from the rows produced from items.
    A row whose items list is empty or null produces no output row.
    The input row is not modified.
    """
    if OSIsoftConstants.API_ITEM_KEY in row:
        # Only the fields surrounding the items are copied for each item. Copying the whole row
        # (items list included) for every item would be quadratic in the number of items.
        parent = {key: value for key, value in row.items() if key != OSIsoftConstants.API_ITEM_KEY}
        for item in row.get(OSIsoftConstants.API_ITEM_KEY) or []:
            child_row = copy.deepcopy(parent)
            child_row.update(item)
            child_row.pop("Links", None)
            for new_row in format_output_row(child_row):
                yield new_row
    elif "Value" in row and isinstance(row.get("Value"), dict):
        unnested_row = copy.deepcopy(row)
        value = unnested_row.pop("Value", None)
        unnested_row.update(value)
        yield unnested_row
    else:
        yield row


class OSIsoftWriter(object):
    """Writes rows one by one, either to a stream identified by its webid or to a value URL."""

    def __init__(self, client, path, column_names, value_url=False):
        """
        :param client: client used to resolve the webid and post the values
        :param path: path resolved into a webid with client.get_web_id, or,
                     when value_url is set, the target passed to client.post_value
        :param column_names: list of the column names of the rows passed to write_row
        :param value_url: if True, path is used as is and no webid lookup is done
        :raises PISystemClientError: if column_names has no 'Value' column
        """
        self.client = client
        if value_url:
            self.webid = path
        else:
            self.webid = self.client.get_web_id(path)
        self.free_timing = "Timestamp" in column_names
        self.timestamp_rank, self.value_rank = self.get_column_rank(column_names)
        self.value_url = value_url
        self.path = path

    def get_column_rank(self, column_names):
        """
        Return the (timestamp_rank, value_rank) positions of the 'Timestamp' and 'Value' columns.

        timestamp_rank is None when there is no 'Timestamp' column.

        :raises PISystemClientError: if there is no 'Value' column
        """
        if "Timestamp" in column_names:
            logger.info("'Timestamp' column found")
            timestamp_rank = column_names.index("Timestamp")
        else:
            logger.info("No 'Timestamp' column found. Using current time")
            timestamp_rank = None
        if "Value" in column_names:
            value_rank = column_names.index("Value")
        else:
            raise PISystemClientError("The 'Value' column cannot be found in the input dataset")
        return timestamp_rank, value_rank

    def write_row(self, row):
        """
        Row is a tuple with N + 1 elements matching the schema passed to get_writer.
        The last element is a dict of columns not found in the schema

        Rows whose value is None or an empty string are skipped with a warning.
        Returns the result of the client's post call, or None if the row was skipped.
        """
        if self.timestamp_rank is not None:
            timestamp = self.timestamp_conversion(row[self.timestamp_rank])
        else:
            timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        value = row[self.value_rank]
        # Only missing values are skipped: 0, 0.0 and False are legitimate values to write
        if value is None or value == "":
            logger.warning("Empty value at timestamp {}".format(timestamp))
            return
        data = {
            "Timestamp": timestamp,
            "Value": value
        }
        if self.value_url:
            return self.client.post_value(self.path, data)
        else:
            return self.client.post_stream_value(self.webid, data)

    def timestamp_conversion(self, timestamp):
        """Return the timestamp unchanged."""
        return timestamp

    def close(self):
        """Nothing to release."""
        pass

    # https://eme/piwebapi/streams/{webid}/value
    # body:
    # {
    #   "Timestamp": "2015-04-03T18:46:10.39135 -7",
    #   "Value": 42.0,
    # }


class OSIsoftBatchWriter(object):
    """Buffers the values to write and sends them through the client's batch requests."""

    # Each row of data added (write_row) first goes in a streak buffer
    # streak buffer is meant to be used to push a flow of values to one AF path / webid
    # If the new row concerns another webid, the current streak buffer is flushed into the request buffer
    # When request buffer is full, it is flushed into the batch endpoint

    # Each accepted write_row call adds an entry to the responses list
    # The pointer into that list is passed as a _dku_counter parameter, extracted from the requests just before sending the batch
    # Upon receiving the batch response, each individual response is reordered
    #   and the status code / error messages are stored into the responses list at the right pointer

    # Possible improvement: flush the responses list (and result writing in the recipe.py) at each _flush_requests to keep memory from going up.

    # My Regards to the unsung Hero who had to review this code

    def __init__(self, client, max_streak_buffer_size=500, max_requests_buffer_size=500):
        """
        :param client: client providing prepare_post_all_values and _batch_requests
        :param max_streak_buffer_size: number of points after which the current streak is flushed
        :param max_requests_buffer_size: number of pending requests after which a batch is sent
        """
        logger.info("Initializing OSIsoftBatchWriter, msbs={}, mrbs={}".format(max_streak_buffer_size, max_requests_buffer_size))
        self.client = client
        self.streak_buffer = []     # list of points from a single webid -> can be all sent in one request
        self.requests_buffer = []   # list of independent requests -> sent one batch at a time
        self.responses = []         # building a list of responses in same order as write_row calls
        self.current_webid = None
        self.max_buffer_size = max_streak_buffer_size
        self.max_requests_buffer_size = max_requests_buffer_size
        self.current_streak = 0
        self.row_number = 0

    def write_row(self, webid, timestamp, value):
        """
        Buffer one point to write on webid.

        :return: None when the point was buffered, or {"Error": message} when the timestamp
                 is invalid. In that case the row is skipped and no entry is added to the responses list.
        """
        if not validate_timestamp(timestamp):
            error_message = "Timestamp '{}' has an invalid format".format(
                    timestamp
            )
            logger.error(error_message)
            # No valid timestamp so we skip this row
            return {
                "Error": error_message
            }
        if self.current_webid is None:
            logger.info("webid : {}".format(webid))
            self.current_webid = webid
        if webid != self.current_webid or len(self.streak_buffer) >= self.max_buffer_size:
            logger.info("webid: {} / {}".format(webid, self.current_webid))
            self._flush_streak()
            self.current_webid = webid
        self.streak_buffer.append(
            {
                "Timestamp": "{}".format(timestamp),
                "Value": "{}".format(value),
                "_dku_counter": self.row_number
            }
        )
        self.row_number += 1
        logger.info("Pushed in streak buffer. Size is now {}".format(len(self.streak_buffer)))
        # mark in self.responses that status of this write depends on result of streak X
        self.responses.append({"streak": self.current_streak})
        return None

    def _flush_streak(self):
        """Turn the buffered points of the current webid into one request of the requests buffer."""
        # streak buffer must be flushed every time the webid changes or before closing the session
        if not self.streak_buffer:
            # Nothing buffered (for instance close() without any accepted row):
            # do not queue a request without points
            return
        logger.info("flushing streak")
        kwargs = self.client.prepare_post_all_values(self.current_webid, self.streak_buffer)
        self.requests_buffer.append(kwargs)
        logger.info("pushed to requests buffer {}".format(len(self.requests_buffer)))
        if len(self.requests_buffer) >= self.max_requests_buffer_size:
            logger.info("Buffer is full, flushing current requests")
            self._flush_requests()
        self.streak_buffer = []
        self.current_streak += 1

    def _flush_requests(self):
        """Send the pending requests as one batch and store each response for the rows it covers."""
        if not self.requests_buffer:
            # Happens when close() follows a flush triggered by a full buffer: no empty batch is sent
            return
        logger.info("flushing current requests")
        request_buffer, streaks_pointers = prepare_request_buffer(self.requests_buffer)
        json_responses = self.client._batch_requests(request_buffer, method='POST')
        for json_response, streak_pointers in zip(json_responses, streaks_pointers):
            # All the rows of a streak share the response of the request that carried them
            for streak_pointer in streak_pointers:
                self.responses[streak_pointer] = json_response
        self.requests_buffer = []

    def close(self):
        """Flush everything still buffered and return the responses list."""
        logger.info("closing")
        self._flush_streak()
        self._flush_requests()
        return self.responses


def validate_timestamp(timestamp):
    """
    Tell whether timestamp is a string formatted as YYYY-MM-DDTHH:MM:SSZ,
    with or without fractional seconds.

    :return: True if one of the accepted formats matches, False otherwise (non string values included)
    """
    valid_formats = ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ")
    for valid_format in valid_formats:
        try:
            datetime.strptime(timestamp, valid_format)
        except (TypeError, ValueError):
            # ValueError: the text does not match this format, TypeError: not a string at all
            continue
        return True
    return False


def formatted_error_source(error_source):
    """Return "(<error_source>) " with its trailing space, or "" when error_source is empty."""
    if not error_source:
        return ""
    return "({}) ".format(error_source)


def build_query_string(url, params):
    """
    Append params to url as a query string.

    requests doesn't handle backslash in params, so we build the query string by hand:
    keys and values are written as they are, without percent-encoding.
    A list or tuple value produces one key=element pair per element.
    If url already contains a query string, the new pairs are appended to it with "&".

    :param url: base url, with or without an existing query string
    :param params: dict of parameters, may be None or empty
    :return: the url, unchanged when there is no parameter to add
    """
    params = params or {}
    tokens = []
    for key in params:
        value = params.get(key)
        elements = value if isinstance(value, (list, tuple)) else [value]
        for element in elements:
            tokens.append(key + "=" + str(element))
    if not tokens:
        return url
    if "?" not in url:
        separator = "?"
    elif url.endswith("?") or url.endswith("&"):
        # The url is already waiting for its next parameter
        separator = ""
    else:
        separator = "&"
    return url + separator + "&".join(tokens)


def unnest(row):
    """
    Replace, in place, a dict held under the "Value" key of row by its own content.

    The keys of the nested dict are converted to strings. Rows without a dict
    under "Value" are left untouched. The row itself is returned.
    """
    nested = row.get("Value")
    if isinstance(nested, dict):
        del row["Value"]
        for key, value in nested.items():
            row["{}".format(key)] = value
    return row


def apply_manual_inputs(kwargs):
    """
    Return a copy of kwargs where the special selector values are resolved.

    - "_DKU_manual_input": the value is taken from the "<key>_manual_input" entry
    - "_DKU_variable_select": the value is the custom variable named by the
      "<key>_variable_select" entry, converted to a string
    - the "*_manual_input" and "*_variable_select" helper entries are dropped
    - any other entry is copied as is

    :raises Exception: if no variable is selected, or if the selected variable does not exist
    """
    helper_suffixes = ("_manual_input", "_variable_select")
    resolved = {}
    for name, value in kwargs.items():
        if value == "_DKU_manual_input":
            resolved[name] = kwargs.get("{}_manual_input".format(name))
            continue
        if value == "_DKU_variable_select":
            # Imported only when a variable has to be resolved
            import dataiku
            variable_name = kwargs.get("{}_variable_select".format(name))
            variables = dataiku.get_custom_variables()
            if not variable_name:
                raise Exception("No variable was selected for {}".format(name))
            if variable_name not in variables:
                raise Exception("Variable '{}' used in {} does not exists".format(variable_name, name))
            resolved[name] = "{}".format(variables.get(variable_name))
            continue
        if not name.endswith(helper_suffixes):
            resolved[name] = value
    return resolved


def is_parameter_greater_than_max_allowed(error_message):
    """Tell whether error_message is an Error 400 about a parameter greater than the maximum allowed."""
    message = "{}".format(error_message)
    if "Error 400" not in message:
        return False
    return "is greater than the maximum allowed" in message


def prepare_request_buffer(request_buffer):
    """
    Strip the _dku_counter markers from the rows of each request, in place.

    :param request_buffer: list of requests, whose rows are read from the "json" key
    :return: (request_buffer, row_counter) where row_counter holds, for each request, the list
             of the _dku_counter values removed from its rows (None for a row without marker).
             It is later used to determine which response status code / error applies
             to which write_row call.
    """
    row_counter = []
    for request in request_buffer:
        rows = request.get("json", [])
        row_counter.append([row.pop("_dku_counter", None) for row in rows])
    return request_buffer, row_counter
