from clv_forecast.constants import (
    FLOW_INPUT_DATASETS,
    FLOW_INPUT_DATASETS_TO_PRESERVE,
    FALLBACK_CONNECTIONS,
    FALLBACK_DATASETS,
)
from dku_utils.projects.connections.connection_commons import FlowConnectionsHandler
from dku_utils.projects.datasets.dataset_commons import (
    get_dataset_in_connection_settings,
)
from dku_utils.projects.project_commons import get_current_project_and_variables


def get_list_of_datasets_to_load(app_variables):
    """Return the names of the input datasets the application must load.

    ``"transactions_history"`` is always included. Each optional dataset is
    appended, in the order listed below, when its flag in ``app_variables``
    is truthy.

    Args:
        app_variables: Mapping holding the ``leverage_customer_metadata_app``
            and ``leverage_rfm_segmentation_app`` flags.

    Returns:
        A new list of dataset names.

    Raises:
        KeyError: If one of the two flags is missing from ``app_variables``.
    """
    # (flag variable, dataset enabled by that flag) -- order drives the output.
    optional_datasets = (
        ("leverage_customer_metadata_app", "customer_metadata"),
        ("leverage_rfm_segmentation_app", "customer_rfm_segments"),
    )
    selected_datasets = ["transactions_history"]
    selected_datasets.extend(
        dataset_name
        for flag_name, dataset_name in optional_datasets
        if app_variables[flag_name]
    )
    return selected_datasets


def update_connections():
    """Re-point the flow's datasets and folders using the project variables.

    Reads the connection names, managed dataset format and input table names
    from the project's ``"standard"`` variables, builds a
    ``FlowConnectionsHandler`` from them and runs its switching/connecting
    steps in a fixed order. When the main connection type is Redshift,
    Synapse or BigQuery, a fallback connection is also configured and
    ``adapt_flow_to_fast_path`` is called at the end.

    Raises:
        KeyError: If a required entry is missing from the standard variables.
    """
    project, variables = get_current_project_and_variables()
    standard_variables = variables["standard"]

    main_connection_name = standard_variables["main_connection_name_app"]
    main_connection_type = get_dataset_in_connection_settings(
        project, main_connection_name
    )["type"]
    managed_dataset_format = standard_variables["managed_dataset_format_app"]

    # Computed once: the same check gates both the fallback configuration
    # and the final fast-path adaptation.
    requires_fallback = main_connection_type in ["Redshift", "Synapse", "BigQuery"]

    fallback_connection_name = None
    fallback_downstream_recipes = None
    fallback_datasets = None
    if requires_fallback:
        fallback_connection_name = standard_variables["fallback_connection_name_app"]
        fallback_downstream_recipes = FALLBACK_CONNECTIONS
        # Copy so the handler receives its own list rather than the constant.
        fallback_datasets = list(FALLBACK_DATASETS)

    folder_connection_name = standard_variables["folder_connection_name_app"]

    # Each input dataset's source table/path is stored in the variable
    # "<dataset name>_table_name_app".
    input_dataset_names = get_list_of_datasets_to_load(standard_variables)
    table_or_path_by_dataset = {
        input_name: standard_variables["{}_table_name_app".format(input_name)]
        for input_name in input_dataset_names
    }

    handler = FlowConnectionsHandler(
        project=project,
        main_connection_name=main_connection_name,
        fallback_connection_name=fallback_connection_name,
        input_datasets=input_dataset_names,
        input_datasets_to_preserve=FLOW_INPUT_DATASETS_TO_PRESERVE,
        fallback_connection_datasets=fallback_datasets,
        fallback_connection_datasets_downstream_recipes=fallback_downstream_recipes,
        input_folders=["cluster_model"],
        bool_change_computed_folders_connections=False,
        folders_connection_name=folder_connection_name,
        project_folders_to_preserve=[],
    )

    # NOTE(review): the call order below mirrors the original sequence;
    # presumably the handler relies on it -- confirm before reordering.
    handler.switch_flow_datasets_connections(
        managed_datasets_write_file_format=managed_dataset_format
    )
    handler.switch_input_datasets_to_not_managed_sate()
    handler.switch_flow_folders_connections()
    handler.connect_flow_input_datasets(
        table_or_path_by_dataset,
        input_datasets_read_file_format=managed_dataset_format,
    )

    if requires_fallback:
        handler.adapt_flow_to_fast_path()
