# -*- coding: utf-8 -*-

"""
Train, log, and deploy a Keras multi-output regression model using MLflow and Dataiku.
Supports experiment tracking, hyperparameter tuning, and automatic deployment of the best model.
"""

import os
import shutil
import itertools
import numpy as np
import pandas as pd
import dataiku
from dataiku import pandasutils as pdu
from dataikuapi.dss.ml import DSSPredictionMLTaskSettings
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import mlflow
import mlflow.pytorch


# ---------------------- CONFIGURATION ----------------------
# Central settings read by the functions below; edit values here rather than in the code.
CONFIG = {
    # Prediction type passed to Dataiku when attaching inference info to a run
    # and when creating the saved model.
    "prediction_type": "OTHER",
    # ID of the managed folder handed to project.setup_mlflow().
    "experiment_folder_id": "eXga6z7K",
    # MLflow experiment that receives the runs (created on first use).
    "experiment_name": "Keras-MO-TS-Model",
    # Display name used to look up (or create) the Dataiku saved model.
    "saved_model_name": "Keras Multi-Output Time Series Model",
    # Metric used to pick the run to deploy; the run with the LOWEST value wins.
    # NOTE: training logs the metric under the literal name "test_loss", so this
    # value has to stay in sync with that name.
    "deployment_metric": "test_loss",
    # Dataset evaluated after training to produce the test loss.
    "eval_dataset_name": "validation_ts",
    # Dataset the models are fitted on.
    "train_dataset_name": "train_ts",
    # Columns predicted by the model; together with "datetime" they are removed
    # from the feature matrix.
    "targets": ["output1", "output2"],
    # When True, the best run is deployed as a new saved-model version after training.
    "auto_deploy": True,
    # Hyperparameter grid; every combination of the listed values is trained.
    # NOTE: main() unpacks the combinations positionally as
    # (batch_size, learning_rate, num_epochs), so the key order here matters.
    "parameter_grid": {
        "batch_size": [28, 32],
        "learning_rate": [0.001],
        "num_epochs": [30]
    },
    "artifacts": {
        # Local path the Keras model is saved to before being logged as an
        # MLflow artifact. Any existing entry at this path is removed with
        # shutil.rmtree() first, so it is presumably expected to be a
        # directory -- TODO confirm against the Keras version in use.
        "model_path": "./keras_model_multioutput_ts.pth"
    }
}
# -----------------------------------------------------------


def get_code_env_name():
    """Look up the ``code_env`` entry of the Dataiku custom variables.

    Returns:
        The value stored under ``"code_env"``, or None when that variable
        is not defined.
    """
    custom_variables = dataiku.get_custom_variables()
    return custom_variables.get("code_env")


def get_project_and_folder():
    """Fetch the default Dataiku project and the managed folder for experiments.

    Returns:
        A ``(project, managed_folder)`` tuple, where the folder is the one
        identified by ``CONFIG["experiment_folder_id"]``.
    """
    client = dataiku.api_client()
    project = client.get_default_project()
    return project, project.get_managed_folder(CONFIG["experiment_folder_id"])


def load_data():
    """Read the train and validation datasets and separate features from targets.

    The target columns and the ``datetime`` column are excluded from the
    feature frames.

    Returns:
        A ``(X_train, y_train, X_test, y_test)`` tuple of pandas objects, where
        the ``*_test`` entries come from the validation dataset.
    """
    target_columns = CONFIG["targets"]
    non_feature_columns = target_columns + ["datetime"]

    train_frame = dataiku.Dataset(CONFIG["train_dataset_name"]).get_dataframe()
    validation_frame = dataiku.Dataset(CONFIG["eval_dataset_name"]).get_dataframe()

    return (
        train_frame.drop(columns=non_feature_columns),
        train_frame[target_columns],
        validation_frame.drop(columns=non_feature_columns),
        validation_frame[target_columns],
    )


def create_model(input_dim, output_dim, learning_rate):
    """Build a compiled feed-forward Keras network for multi-output regression.

    Args:
        input_dim: Number of input features.
        output_dim: Number of regression outputs (size of the final layer).
        learning_rate: Learning rate given to the Adam optimizer.

    Returns:
        A compiled ``Sequential`` model with two hidden ReLU layers (64 and 32
        units) and a linear output layer, trained with mean squared error.
    """
    network = Sequential()
    network.add(Dense(64, input_shape=(input_dim,), activation='relu'))
    network.add(Dense(32, activation='relu'))
    network.add(Dense(output_dim, activation='linear'))

    optimizer = Adam(learning_rate=learning_rate)
    network.compile(optimizer=optimizer, loss='mean_squared_error')
    return network


class KerasModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper around the saved Keras multi-output regression model."""

    def load_context(self, context):
        """Load the Keras model from the ``model_path`` artifact of ``context``."""
        self.model = tf.keras.models.load_model(context.artifacts["model_path"])

    def predict(self, context, model_input):
        """Run the Keras model on the feature columns of ``model_input``.

        Args:
            context: MLflow model context (unused here).
            model_input: pandas DataFrame holding the features; it may also
                carry the target columns and a ``datetime`` column.

        Returns:
            Whatever ``self.model.predict`` returns for the feature columns.
        """
        # Target and datetime columns are not model inputs. They are dropped
        # only when present: input without labels (e.g. plain scoring data)
        # would otherwise make drop() raise KeyError.
        features = model_input.drop(
            columns=CONFIG["targets"] + ["datetime"], errors="ignore"
        )
        return self.model.predict(features)

    
def _get_or_create_experiment_id(mlflow_ext, mlflow_handle):
    """Return the ID of the configured experiment, creating the experiment if it is not listed."""
    experiment_name = CONFIG["experiment_name"]
    known_names = {
        exp["name"] for exp in mlflow_ext.list_experiments().get("experiments", [])
    }
    if experiment_name not in known_names:
        mlflow_handle.create_experiment(experiment_name)
    return mlflow_handle.get_experiment_by_name(experiment_name).experiment_id


def setup_experiment(project, mlflow_ext, mf, mlflow):
    """Create or retrieve the MLflow experiment named in ``CONFIG["experiment_name"]``.

    Args:
        project: Dataiku project; only used to open an MLflow context when
            ``mlflow`` is None.
        mlflow_ext: Project MLflow extension, used to list existing experiments.
        mf: Managed folder passed to ``project.setup_mlflow`` when ``mlflow``
            is None.
        mlflow: MLflow handle yielded by ``project.setup_mlflow(mf)``. When it
            is None, a context is opened here for the duration of the call.

    Returns:
        The ID of the experiment.
    """
    # Reuse the caller's handle instead of shadowing it with a nested
    # setup_mlflow() context: the caller is already inside one, and opening and
    # closing a second context inside it may disturb the outer configuration
    # (NOTE(review): depends on Dataiku's handle teardown -- confirm).
    if mlflow is None:
        with project.setup_mlflow(mf) as handle:
            return _get_or_create_experiment_id(mlflow_ext, handle)
    return _get_or_create_experiment_id(mlflow_ext, mlflow)
            

def _remove_existing_artifact(path):
    """Delete a previously saved model at ``path``, whether it is a directory or a single file."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    elif os.path.exists(path):
        # shutil.rmtree() fails on a regular file, which some Keras save
        # formats produce.
        os.remove(path)


def _run_experiments(mlflow, mlflow_extension, experiment_id, data):
    """Train, evaluate and log one model per hyperparameter combination.

    Args:
        mlflow: Active MLflow handle from ``project.setup_mlflow``.
        mlflow_extension: Project MLflow extension used to attach inference info.
        experiment_id: ID of the experiment that receives the runs.
        data: ``(X_train, y_train, X_test, y_test)`` as returned by ``load_data``.
    """
    X_train, y_train, X_test, y_test = data
    grid = CONFIG["parameter_grid"]
    model_path = CONFIG["artifacts"]["model_path"]

    for values in itertools.product(*grid.values()):
        # Look hyperparameters up by name so that reordering the grid keys in
        # CONFIG cannot silently swap them.
        params = dict(zip(grid, values))
        batch_size = params["batch_size"]
        learning_rate = params["learning_rate"]
        num_epochs = params["num_epochs"]

        with mlflow.start_run(experiment_id=experiment_id) as run:
            run_id = run.info.run_id
            print(f"Starting run {run_id} with params: {values}")

            mlflow.log_param("batch_size", batch_size)
            mlflow.log_param("learning_rate", learning_rate)
            mlflow.log_param("num_epochs", num_epochs)

            model = create_model(
                input_dim=X_train.shape[1],
                output_dim=len(CONFIG["targets"]),
                learning_rate=learning_rate,
            )
            model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size, validation_split=0.1)
            loss = model.evaluate(X_test, y_test)
            print(f"Test Loss: {loss}")
            mlflow.log_metric("test_loss", loss)

            # The same local path is reused by every run, so clear the
            # previous model before saving the new one.
            _remove_existing_artifact(model_path)
            model.save(model_path)

            mlflow.pyfunc.log_model(
                artifact_path=f"{type(model).__name__}-{run_id}",
                python_model=KerasModelWrapper(),
                artifacts={"model_path": model_path}
            )

            mlflow_extension.set_run_inference_info(
                run_id=run_id,
                prediction_type=CONFIG["prediction_type"],
                code_env_name=get_code_env_name()
            )
            print(f"Run {run_id} complete\n{'-'*40}")


def _select_best_run(mlflow, experiment_id):
    """Return the run with the lowest deployment metric, or None if no run has a usable value."""
    import math  # local import: only needed for the NaN check below

    metric_name = CONFIG["deployment_metric"]
    best_run, best_value = None, None
    for _, run_info in mlflow.search_runs(experiment_id).iterrows():
        run = mlflow.get_run(run_info["run_id"])
        value = run.data.metrics.get(metric_name)
        # Runs that never logged the metric, or logged NaN (diverged
        # training), cannot be ranked and must not be deployed.
        if value is None or math.isnan(value):
            continue
        if best_run is None or value < best_value:
            best_run, best_value = run, value
    return best_run


def _deploy_best_run(project, mlflow, mlflow_extension):
    """Deploy the best run of the configured experiment as a new saved-model version.

    Raises:
        RuntimeError: If no run of the experiment has a usable deployment metric.
    """
    metric_name = CONFIG["deployment_metric"]
    experiment = mlflow.get_experiment_by_name(CONFIG["experiment_name"])
    best_run = _select_best_run(mlflow, experiment.experiment_id)
    if best_run is None:
        raise RuntimeError(
            f"No run in experiment '{CONFIG['experiment_name']}' has a usable "
            f"'{metric_name}' metric; nothing to deploy."
        )

    run_id = best_run.info.run_id
    print(f"Best run ID: {run_id} with {CONFIG['deployment_metric']} = "
          f"{best_run.data.metrics.get(CONFIG['deployment_metric'])}")

    sm_id = next((sm["id"] for sm in project.list_saved_models()
                  if sm["name"] == CONFIG["saved_model_name"]), None)

    if sm_id:
        print(f"Found Saved Model: {CONFIG['saved_model_name']} with ID: {sm_id}")
        saved_model = project.get_saved_model(sm_id)
    else:
        saved_model = project.create_mlflow_pyfunc_model(CONFIG["saved_model_name"], CONFIG["prediction_type"])
        sm_id = saved_model.id
        print(f"Created new Saved Model with ID: {sm_id}")

    # First free "v<N>" identifier; an unbounded counter avoids a bare
    # StopIteration once many versions exist.
    existing_versions = {version["id"] for version in saved_model.list_versions()}
    version_id = next(
        f"v{i}" for i in itertools.count(1) if f"v{i}" not in existing_versions
    )

    print(f"Deploying version {version_id} of model {CONFIG['saved_model_name']}")
    mlflow_extension.deploy_run_model(
        run_id,
        sm_id,
        code_env_name=get_code_env_name(),
        version_id=version_id,
    )


def main():
    """Train one model per hyperparameter combination, log each to MLflow, and optionally deploy the best.

    Steps:
        1. Resolve the project, managed folder and MLflow extension, and
           garbage-collect the extension.
        2. Load the train/validation data.
        3. Inside an MLflow context, create or fetch the experiment and run
           the hyperparameter grid.
        4. If ``CONFIG["auto_deploy"]`` is set, deploy the run with the lowest
           ``CONFIG["deployment_metric"]`` as a new saved-model version.

    Raises:
        RuntimeError: If auto-deploy is enabled but no run has a usable metric.
    """
    project, mf = get_project_and_folder()
    mlflow_extension = project.get_mlflow_extension()
    mlflow_extension.garbage_collect()

    data = load_data()

    # MLflow experiment setup and training runs
    with project.setup_mlflow(mf) as mlflow:
        experiment_id = setup_experiment(project, mlflow_extension, mf, mlflow)
        print(f"Experiment name is: {CONFIG['experiment_name']} and ID is: {experiment_id}")
        _run_experiments(mlflow, mlflow_extension, experiment_id, data)

    # Auto deploy best model
    if CONFIG["auto_deploy"]:
        with project.setup_mlflow(mf) as mlflow:
            _deploy_best_run(project, mlflow, mlflow_extension)
            

if __name__ == "__main__":
    main()