# -*- coding: utf-8 -*-

"""
This script trains a regression model using PyTorch, logs it to MLflow, and automatically deploys 
the best-performing model as a Dataiku Saved Model with evaluation on a validation dataset.
"""

import os
import shutil
import itertools
import dataiku
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import mlflow
import mlflow.pytorch
from dataikuapi.dss.ml import DSSPredictionMLTaskSettings

# ---------------------- CONFIGURATION ----------------------
# Central settings shared by the training, model-selection and deployment steps below.
CONFIG = {
    # Prediction type passed to Dataiku when registering run inference info
    # and when creating the saved model.
    "prediction_type": "REGRESSION",
    # Identifier of the Dataiku managed folder used for MLflow tracking and
    # for importing the model version.
    "experiment_folder_id": "pshnkijv",
    # Name of the MLflow experiment; created by main() if it does not exist.
    "experiment_name": "PyTorch-Regression-Model",
    # Name of the Dataiku saved model; also used as the key of the MLflow
    # artifact that the pyfunc wrapper loads.
    "saved_model_name": "PyTorch Regression Model",
    # MLflow metric compared across runs to pick the run to deploy (lower wins).
    "deployment_metric": "test_loss",
    # Dataset used to evaluate the deployed model version.
    "eval_dataset": "validation",
    # Name of the target column in the train/validation datasets.
    "target_name": "target",
    # When True, main() deploys the best run after training.
    "auto_deploy": True,
    # Hyperparameter grid: keys correspond to the train_model() hyperparameters,
    # each value is the list of candidates combined via itertools.product.
    "parameter_grid": {
        'input_size': None,  # Replaced by load_data() with [number of feature columns]
        'hidden_size': [64, 128],
        'output_size': [1],
        'learning_rate': [0.01],
        'num_epochs': [30]
    },
    # Mapping of artifact name -> local path, filled by load_data() and
    # passed to mlflow.pyfunc.log_model().
    "artifacts": {},
}
# -----------------------------------------------------------


# Module-level Dataiku handles shared by the functions below. These statements
# run at import time, so importing this file already contacts Dataiku.
client = dataiku.api_client()
project = client.get_default_project()
mlflow_extension = project.get_mlflow_extension()
# NOTE(review): presumably purges MLflow items previously flagged as deleted —
# confirm against the Dataiku API documentation.
mlflow_extension.garbage_collect()
# Managed folder handed to project.setup_mlflow() in main().
mf = project.get_managed_folder(CONFIG["experiment_folder_id"])


class SimpleNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Build the layers.

        Args:
            input_size: Number of input features.
            hidden_size: Width of the hidden layer.
            output_size: Number of output values per sample.
        """
        super().__init__()
        # Attribute names end up in the state_dict keys, so they are kept stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply the hidden layer, the ReLU activation, then the output layer."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)


class PytorchModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper exposing a saved PyTorch model through ``predict``."""

    def load_context(self, context):
        """Load the PyTorch model from the artifact named after the saved model.

        Args:
            context: MLflow context whose ``artifacts`` mapping contains the
                model path under ``CONFIG["saved_model_name"]``.
        """
        self.model = mlflow.pytorch.load_model(context.artifacts[CONFIG["saved_model_name"]])
        # Inference only: put layers such as dropout/batch-norm in eval mode.
        self.model.eval()

    def predict(self, context, model_input):
        """Run the model on a DataFrame of features.

        Args:
            context: MLflow context (unused).
            model_input: pandas DataFrame whose values are cast to float32.

        Returns:
            pandas DataFrame holding the model outputs, one row per input row.
        """
        inputs = torch.FloatTensor(model_input.values.astype("float32"))
        # No gradients are needed at prediction time; skip building the graph.
        with torch.no_grad():
            outputs = self.model(inputs)
        return pd.DataFrame(outputs.detach().numpy())

    
# --- Utility Functions ---
def get_code_env_name():
    """Look up the ``code_env`` entry among the Dataiku custom variables."""
    custom_variables = dataiku.get_custom_variables()
    return custom_variables.get("code_env")


def load_data():
    """Load the training and evaluation datasets and finish filling CONFIG.

    Reads the ``train`` dataset and the dataset named by
    ``CONFIG["eval_dataset"]``, so the data used for the test loss is the same
    one the deployed model is later evaluated on.

    Side effects:
        * ``CONFIG["parameter_grid"]["input_size"]`` is set to a one-element
          list holding the number of feature columns.
        * ``CONFIG["artifacts"]`` gets the local model path registered under
          ``CONFIG["saved_model_name"]``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` of pandas objects, where
        the ``X_*`` frames have the target column removed.
    """
    target = CONFIG["target_name"]
    df_train = dataiku.Dataset("train").get_dataframe()
    df_validation = dataiku.Dataset(CONFIG["eval_dataset"]).get_dataframe()

    X_train = df_train.drop(target, axis=1)
    y_train = df_train[target]
    X_test = df_validation.drop(target, axis=1)
    y_test = df_validation[target]

    # The network input width is only known once the data has been read.
    CONFIG["parameter_grid"]["input_size"] = [X_train.shape[1]]
    CONFIG["artifacts"][CONFIG["saved_model_name"]] = "./pytorch_regression_model.pth"
    return X_train, y_train, X_test, y_test


def prepare_dataloaders(X_train, y_train, X_test, y_test):
    """Wrap the train/test pandas data into PyTorch data loaders.

    Features become float tensors and targets become float column vectors.
    Both loaders use batches of 32; only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    def _as_dataset(features, labels):
        # Same conversion for both splits: float features, (n, 1) float targets.
        feature_tensor = torch.from_numpy(features.values).float()
        label_tensor = torch.from_numpy(labels.values).view(-1, 1).float()
        return TensorDataset(feature_tensor, label_tensor)

    training_set = _as_dataset(X_train, y_train)
    evaluation_set = _as_dataset(X_test, y_test)
    train_loader = DataLoader(training_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(evaluation_set, batch_size=32, shuffle=False)
    return train_loader, test_loader


def train_model(input_size, hidden_size, output_size, learning_rate, num_epochs, train_loader, test_loader, experiment_id):
    """Train one SimpleNN configuration and log it to MLflow.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        output_size: Number of outputs per sample.
        learning_rate: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        experiment_id: MLflow experiment under which the run is created.

    Returns:
        Tuple ``(run_id, avg_test_loss)`` where ``avg_test_loss`` is the MSE
        averaged over all evaluation samples.

    Raises:
        ValueError: If either loader yields no samples.
    """
    model = SimpleNN(input_size, hidden_size, output_size)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    with mlflow.start_run(experiment_id=experiment_id) as run:
        run_id = run.info.run_id
        mlflow.log_params({
            "input_size": input_size,
            "hidden_size": hidden_size,
            "output_size": output_size,
            "learning_rate": learning_rate,
            "num_epochs": num_epochs
        })

        model.train()
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            epoch_samples = 0
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()
                # Weight each batch by its size so a short final batch does
                # not distort the epoch average.
                batch_size = inputs.size(0)
                epoch_loss += loss.item() * batch_size
                epoch_samples += batch_size
            if epoch_samples == 0:
                raise ValueError("train_loader yielded no samples")
            # Log the mean loss over the whole epoch rather than the loss of
            # whichever mini-batch happened to come last.
            mlflow.log_metric("train_loss", epoch_loss / epoch_samples, step=epoch)

        # Evaluation
        model.eval()
        test_loss = 0.0
        test_samples = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                batch_size = inputs.size(0)
                test_loss += criterion(model(inputs), targets).item() * batch_size
                test_samples += batch_size
        if test_samples == 0:
            raise ValueError("test_loader yielded no samples")
        # Sample-weighted mean: the exact MSE over the evaluation data.
        avg_test_loss = test_loss / test_samples
        mlflow.log_metric("test_loss", avg_test_loss)

        # Save and log model. mlflow.pytorch.save_model refuses to write into
        # an existing path, so whatever a previous run left behind is removed.
        model_path = CONFIG["artifacts"][CONFIG["saved_model_name"]]
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        elif os.path.exists(model_path):
            os.remove(model_path)
        mlflow.pytorch.save_model(model, model_path)

        # The artifact path is rebuilt from the class name and run id at
        # deployment time, so the naming scheme must stay in sync.
        mlflow.pyfunc.log_model(
            artifact_path=f"{type(model).__name__}-{run_id}",
            python_model=PytorchModelWrapper(),
            artifacts=CONFIG["artifacts"]
        )

        mlflow_extension.set_run_inference_info(
            run_id=run_id,
            prediction_type=CONFIG["prediction_type"],
            code_env_name=get_code_env_name()
        )

        return run_id, avg_test_loss


def get_best_run(experiment_id):
    """Return the run with the lowest deployment metric in an experiment.

    Runs that did not log ``CONFIG["deployment_metric"]`` (or logged NaN) are
    ignored, so a run without the metric can never be selected. On ties the
    first run returned by the search wins.

    Args:
        experiment_id: MLflow experiment to search.

    Returns:
        The best MLflow run, or ``None`` when no run has a usable metric.
    """
    metric_name = CONFIG["deployment_metric"]
    best_run = None
    best_value = float("inf")
    for _, run_info in mlflow.search_runs([experiment_id]).iterrows():
        run = mlflow.get_run(run_info["run_id"])
        value = run.data.metrics.get(metric_name)
        if value is None:
            continue
        # NaN compares False against everything, so it is skipped here too.
        if value < best_value:
            best_run, best_value = run, value
    return best_run


def deploy_best_model(best_run, experiment_id):
    """Deploy the best MLflow run as a new version of the Dataiku saved model.

    The saved model named ``CONFIG["saved_model_name"]`` is reused when it
    exists and created otherwise. The run's model artifact is imported from
    the experiment managed folder as a new version, activated, evaluated on
    ``CONFIG["eval_dataset"]`` and then deployed through the MLflow extension.

    Args:
        best_run: MLflow run whose logged model should be deployed.
        experiment_id: Experiment the run belongs to (used to build the
            artifact path inside the managed folder).
    """
    sm_id = next(
        (sm["id"] for sm in project.list_saved_models() if sm["name"] == CONFIG["saved_model_name"]),
        None,
    )
    if sm_id:
        saved_model = project.get_saved_model(sm_id)
    else:
        saved_model = project.create_mlflow_pyfunc_model(
            CONFIG["saved_model_name"], CONFIG["prediction_type"])

    # Start from "v<count + 1>" but skip ids that are already taken: after a
    # version has been deleted the plain count can point at an existing id.
    existing_ids = {version["id"] for version in saved_model.list_versions()}
    version_number = len(existing_ids) + 1
    while f"v{version_number}" in existing_ids:
        version_number += 1
    version_id = f"v{version_number}"

    run_id = best_run.info.run_id
    code_env_name = get_code_env_name()
    # Must mirror the artifact_path used when the model was logged.
    model_path = f"{experiment_id}/{run_id}/artifacts/{SimpleNN.__name__}-{run_id}"

    mlflow_version = saved_model.import_mlflow_version_from_managed_folder(
        version_id=version_id,
        managed_folder=CONFIG["experiment_folder_id"],
        path=model_path,
        code_env_name=code_env_name
    )

    saved_model.set_active_version(mlflow_version.version_id)
    mlflow_version.set_core_metadata(target_column_name=CONFIG["target_name"],
                                     get_features_from_dataset=CONFIG["eval_dataset"])
    mlflow_version.evaluate(CONFIG["eval_dataset"])
    # NOTE(review): this deploys the same run under the version id that was
    # just imported above — confirm both steps are needed.
    mlflow_extension.deploy_run_model(
        run_id=run_id,
        sm_id=saved_model.id,
        code_env_name=code_env_name,
        evaluation_dataset=CONFIG["eval_dataset"],
        version_id=version_id,
        target_column_name=CONFIG["target_name"]
    )


def main():
    """Train every configuration of the parameter grid and optionally deploy the best run.

    Steps:
        1. Load the data and build the data loaders.
        2. Get or create the MLflow experiment named ``CONFIG["experiment_name"]``.
        3. Train one model per combination of ``CONFIG["parameter_grid"]``.
        4. When ``CONFIG["auto_deploy"]`` is true, deploy the run with the best
           ``CONFIG["deployment_metric"]``.

    Raises:
        RuntimeError: If auto-deploy is enabled but no run in the experiment
            logged the deployment metric.
    """
    X_train, y_train, X_test, y_test = load_data()
    train_loader, test_loader = prepare_dataloaders(X_train, y_train, X_test, y_test)

    with project.setup_mlflow(mf) as mlflow_context:
        experiment = mlflow_context.get_experiment_by_name(CONFIG["experiment_name"])
        if experiment is None:
            mlflow_context.create_experiment(CONFIG["experiment_name"])
            experiment = mlflow_context.get_experiment_by_name(CONFIG["experiment_name"])
        experiment_id = experiment.experiment_id

        grid = CONFIG["parameter_grid"]
        for values in itertools.product(*grid.values()):
            # Bind values to their grid keys so the call does not depend on
            # the order in which the keys are declared.
            params = dict(zip(grid, values))
            run_id, test_loss = train_model(
                train_loader=train_loader,
                test_loader=test_loader,
                experiment_id=experiment_id,
                **params
            )
            print(f"Finished run {run_id} with test_loss={test_loss}")

        if CONFIG["auto_deploy"]:
            metric_name = CONFIG["deployment_metric"]
            best_run = get_best_run(experiment_id)
            best_value = None if best_run is None else best_run.data.metrics.get(metric_name)
            if best_value is None:
                # Fail with an explicit message instead of an AttributeError or
                # KeyError further down.
                raise RuntimeError(
                    f"No run in experiment {experiment_id} logged '{metric_name}'; nothing to deploy"
                )
            print(f"Deploying best run {best_run.info.run_id} with {metric_name}={best_value}")
            deploy_best_model(best_run, experiment_id)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()