# -*- coding: utf-8 -*-
import logging

from dataiku import Dataset
from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config

from solutions.graph.dataiku.batch_generators import generate_nodes_batches
from solutions.graph.graph_builder import NODE_PRIMARY_KEY_PROP_NAME_PREFIX
from solutions.graph.models import to_node_group_definition
from solutions.graph.store.graph_metadata_snapshot_store import DataikuGraphMetadataSnapshotStore

logger = logging.getLogger(__name__)

# Read the recipe configuration once; both identifiers come from the same config.
recipe_config = get_recipe_config()
snapshot_id = recipe_config["snapshot_id"]
node_group_id = recipe_config["node_id"]

logger.info(f"Collecting nodes {node_group_id} of saved configuration {snapshot_id}.")

# The snapshot store is backed by the first dataset bound to the "snapshots_ds" input role.
snapshots_store = DataikuGraphMetadataSnapshotStore(get_input_names_for_role("snapshots_ds")[0])

# Fail early when the store returns nothing (or an empty value) for the requested id.
snapshot = snapshots_store.get_by_id(snapshot_id)
if not snapshot:
    raise ValueError(f"Cannot load saved configuration, no saved configuration {snapshot_id} is available.")

nodes = snapshot["nodes"]

# Single lookup: a missing or empty entry for the requested node group is an error.
# The message is formatted with an f-string so that a non-string id still produces
# the intended ValueError rather than a TypeError from string concatenation.
nodes_group_meta = nodes.get(node_group_id)
if not nodes_group_meta:
    raise ValueError(f"Cannot load node group, no node group named {node_group_id} is available")

node_group = nodes_group_meta["node_group"]

# The recipe writes to the first dataset bound to the "main" output role.
nodes_dataset_output = Dataset(get_output_names_for_role("main")[0])

# One node group definition is built per entry of the group's "definitions" list.
group_definitions = [
    to_node_group_definition(node_group_id, node_group, d)
    for d in nodes_group_meta["definitions"]
]

logger.info("Starting to process nodes...")

# The output schema is written from the first batch's dataframe only; every batch
# (including the first) is then appended through the same writer.
first_batch = True
with nodes_dataset_output.get_writer() as output_writer:
    for batch in generate_nodes_batches(group_definitions, sampling=None):
        definition = batch["definition"]
        df = batch["df"]

        # _dku_id, id of the node: copied from the definition's primary column.
        df[NODE_PRIMARY_KEY_PROP_NAME_PREFIX] = df[definition["primary_col"]]
        # _dku_label, label of the node group.
        df["_dku_label"] = definition["node_group"]

        if first_batch:
            nodes_dataset_output.write_schema_from_dataframe(df, drop_and_create=True)
            first_batch = False

        output_writer.write_dataframe(df)

logger.info("Done processing nodes.")
