from flask import Blueprint, request
from pydantic import BaseModel

from solutions.backend.utils import return_ok

from ..utils.webapp_config import webapp_config

# Blueprint grouping the dataset endpoints; every route registered on it is
# served under the "/datasets" URL prefix.
datasets = Blueprint("datasets", __name__, url_prefix="/datasets")


@datasets.route("/get_datasets", methods=["GET"])
def get_datasets():
    """Return the nodes and edges datasets declared in the webapp config.

    Returns:
        The result of ``return_ok`` applied to a dict with the keys
        ``nodes_datasets`` and ``edges_datasets``, taken from
        ``webapp_config``.
    """
    payload = {
        "nodes_datasets": webapp_config.nodes_datasets,
        "edges_datasets": webapp_config.edges_datasets,
    }
    return return_ok(payload)


class GetDatasetColumnsParams(BaseModel):
    """Request body accepted by the ``get_dataset_columns`` endpoint."""

    # Name of the dataset whose columns are requested; the endpoint checks it
    # against the configured nodes and edges datasets.
    dataset_name: str


@datasets.route("/get_dataset_columns", methods=["POST"])
def get_dataset_columns():
    """Return the columns of the dataset named in the JSON request body.

    The body is unpacked into ``GetDatasetColumnsParams``, so it must
    provide a ``dataset_name`` field.

    Returns:
        The result of ``return_ok`` applied to ``{"columns": ...}``, where
        the columns come from ``webapp_config.get_columns``.

    Raises:
        ValueError: If ``dataset_name`` is not among the configured nodes
            or edges datasets.
    """
    payload = GetDatasetColumnsParams(**request.get_json())

    # Only datasets declared in the webapp configuration may be queried.
    known_datasets = webapp_config.nodes_datasets + webapp_config.edges_datasets

    requested = payload.dataset_name
    if requested not in known_datasets:
        raise ValueError("invalid dataset name {} ".format(str(requested)))

    return return_ok({"columns": webapp_config.get_columns(requested)})
