import logging
from collections import defaultdict

import torch.distributed as dist

from dataiku.doctor.deephub.deephub_context import get_deephub_context

logger = logging.getLogger(__name__)


class PredictedDataAccumulator(object):

    def __init__(self):
        self._values_to_accumulate = defaultdict(list)
        self._accumulated_values = defaultdict(list)
        self._gathered = False

    def accumulate_value(self, name, value):
        self._values_to_accumulate[name].append(value)

    def gather(self):
        if len(self._values_to_accumulate) == 0:
            self._gathered = True
            return

        training_context = get_deephub_context()
        if not training_context.distributed:
            self._accumulated_values = self._values_to_accumulate
        else:
            logger.info("Begin 'all_gather_object()' to exchange accumulated values across workers")
            acc = [None for _ in range(training_context.world_size)]
            dist.all_gather_object(acc, self._values_to_accumulate)
            for some_values in acc:
                for name in some_values:
                    self._accumulated_values[name].extend(some_values[name])
            logger.info("End 'all_gather_object()'")
        self._gathered = True

    def has_accumulated_values(self):
        if not self._gathered:
            raise Exception("Cannot get accumulated value before gathering results")
        return len(self._accumulated_values) > 0

    def get_accumulated_value(self, name, default=None):
        if not self._gathered:
            raise Exception("Cannot get accumulated value before gathering results")
        return self._accumulated_values.get(name, default)
