# -*- coding: utf-8 -*-
import logging
import os
from pathlib import Path

from dataiku import Dataset, Folder
from dataiku.core.base import is_container_exec
from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config

from solutions.graph.graph_db_instance_manager import (
    SUPPORTED_STORAGES,
    AbstractDbInstance,
    LocalDbInstance,
    LocalReplicaDbInstance,
)
from solutions.graph.kuzu.algorithms.page_rank import PageRankParams, ProjectedGraph, compute_page_rank

logger = logging.getLogger(__name__)

# Input managed folder (recipe role "graph_db_folder") that holds the Kuzu graph database.
db_folder_name = get_input_names_for_role("graph_db_folder")[0]
db_folder = Folder(db_folder_name)

# Set below to either a direct local instance or a locally-downloaded replica.
db_instance: AbstractDbInstance

recipe_config = get_recipe_config()
# Location of the database inside the folder, relative to the folder root.
relative_path_to_db = recipe_config["path_to_db"]

# Fail fast on storage backends the graph instance manager does not support.
db_folder_type = db_folder.get_info(sensitive_info=True)["type"]
if db_folder_type not in SUPPORTED_STORAGES:
    # The folder is the recipe's *input* database folder, not an output folder.
    raise Exception(
        f"This recipe does not support database folder type {db_folder_type}. "
        f"It should be one of: {', '.join(SUPPORTED_STORAGES)}."
    )

if db_folder_type == "Filesystem" and not is_container_exec():
    # Local filesystem folder reachable from this process: open the DB files in place.
    path_to_db = os.path.join(db_folder.get_path(), relative_path_to_db)
    db_instance = LocalDbInstance(path_to_db, readonly=False)
else:
    # Connection needs to be opened in write mode to allow creating the projected graph.
    # Write mode cannot be used with S3 or GCS remote instances, so we download the files locally.
    db_instance = LocalReplicaDbInstance(Path(relative_path_to_db), db_folder, readonly=False)

# Run PageRank over a projected graph and stream the scores, batch by batch,
# into the recipe's output dataset (role "output_ds").
with db_instance:
    with db_instance.get_new_conn() as conn_context_manager:
        projected_graph = ProjectedGraph("page_rank_graph", recipe_config["node_groups"], recipe_config["edge_groups"])
        # Config values are cast explicitly so UI-provided strings/floats are normalized
        # before they reach the algorithm.
        params = PageRankParams(
            damping_factor=float(recipe_config.get("damping_factor", 0.85)),
            max_iterations=int(recipe_config.get("max_iterations", 100)),
            tolerance=float(recipe_config.get("tolerance", 1e-7)),
            normalizeInitial=bool(recipe_config.get("normalize_initial", True)),
        )
        # Cast like the other numeric settings above (previously passed through uncast).
        batch_size = int(recipe_config.get("batch_size", 10000))

        logger.info(f"Computing PageRank on Kuzu db in folder {db_folder_name} at {relative_path_to_db}.")
        output_ds = Dataset(get_output_names_for_role("output_ds")[0])

        first_batch = True
        total_rows = 0
        with output_ds.get_writer() as output_writer:
            for df in compute_page_rank(conn_context_manager.connection, projected_graph, params, batch_size):
                # The schema is derived from the first batch, replacing any existing one.
                if first_batch:
                    output_ds.write_schema_from_dataframe(df, drop_and_create=True)
                    first_batch = False

                result_length = len(df)
                total_rows += result_length

                logger.info(f"Writing batch of {result_length} rows to output dataset {output_ds.name}.")
                output_writer.write_dataframe(df)
                logger.debug(f"Done writing batch of {result_length} rows to output dataset {output_ds.name}.")

        if first_batch:
            # No batch was produced, so the output schema was never written; surface this
            # instead of finishing silently with an unchanged schema.
            logger.warning(
                f"PageRank produced no results; output dataset {output_ds.name} schema was not written."
            )
        else:
            logger.info(f"Wrote {total_rows} rows to output dataset {output_ds.name}.")

logger.info("Done computing PageRank.")
