# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import os
import json
import dataiku
import cv2
import torch
import pandas as pd

import warnings

# Keep DeprecationWarning noise out of the notebook output.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Training rows: the code below reads the `record_id` and `label` columns.
train_dataset = dataiku.Dataset("train")
df = train_dataset.get_dataframe()

# Folder holding the images referenced by `record_id`.
input_folder = dataiku.Folder("PRCGY0s7")
input_path = input_folder.get_path()

# Folder used as the training output directory.
output_folder = dataiku.Folder("vRa5vtTP")
output_path = output_folder.get_path()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from detectron2.utils.logger import setup_logger

# Initialise detectron2's logger with its default arguments.
# NOTE(review): presumably this is what makes training progress show up in
# the notebook output — confirm against the detectron2 logger docs.
setup_logger()

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog

from detectron2.structures import BoxMode

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Every distinct category name found in the JSON-encoded `label` column,
# in sorted order so that ids are stable from run to run.
all_categories = sorted(
    {
        annotation["category"]
        for label in df["label"]
        for annotation in json.loads(label)
    }
)

# Category name -> integer id (its position in the sorted list).
cat2id = {name: position for position, name in enumerate(all_categories)}

# Persist the name/id mapping to the "classes" dataset.
categories_df = pd.DataFrame.from_dict(
    {"value": all_categories, "id": range(len(all_categories))}
)
dataiku.Dataset("classes").write_with_schema(categories_df)


def to_dict(df):
    """Convert the annotated dataframe into a list of detectron2 dataset records.

    Each row of ``df`` must provide:

    * ``record_id`` -- image file name, resolved against the module-level
      ``input_path``;
    * ``label`` -- a JSON string encoding a list of annotations, each with a
      ``"bbox"`` and a ``"category"`` key.

    Args:
        df: dataframe with ``record_id`` and ``label`` columns.

    Returns:
        A list with one dict per row containing ``file_name``, ``image_id``
        (the row position), ``height``, ``width`` and ``annotations``. Every
        annotation carries the original ``bbox`` unchanged, tagged as
        ``BoxMode.XYWH_ABS``, and the ``category_id`` looked up in ``cat2id``.

    Raises:
        OSError: if an image cannot be read by ``cv2.imread``.
        KeyError: if an annotation uses a category missing from ``cat2id``.
    """
    result = []
    rows = zip(df["record_id"], df["label"])
    for i, (record_id, label) in enumerate(rows):
        image_path = os.path.join(input_path, record_id)
        # The image is decoded only to obtain its dimensions.
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on `img.shape`.
            raise OSError(
                f"Could not read image for row {i}: {image_path}"
            )
        height, width = img.shape[:2]
        annotations = [
            {
                "bbox": annotation["bbox"],
                "bbox_mode": BoxMode.XYWH_ABS,
                "category_id": cat2id[annotation["category"]],
            }
            for annotation in json.loads(label)
        ]
        result.append(
            {
                "file_name": image_path,
                "image_id": i,
                "height": height,
                "width": width,
                "annotations": annotations,
            }
        )
    return result


# Build the detectron2 records once, up front, from the training dataframe.
train_dict = to_dict(df)
# Register that prebuilt list under the name "train". The lambda just returns
# it; its `d` default argument is never used in the body.
# NOTE(review): re-running this cell in the same kernel presumably fails
# because "train" is already registered — confirm against DatasetCatalog docs.
DatasetCatalog.register("train", lambda d="train": train_dict)
# Attach the sorted category names as `thing_classes`; list positions match
# the ids stored in `cat2id`.
MetadataCatalog.get("train").set(thing_classes=all_categories)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from detectron2.engine import DefaultTrainer

# The same model-zoo entry supplies both the base config and the initial weights.
zoo_config = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(zoo_config))

# Data: train on the registered "train" dataset, with no test dataset.
cfg.DATASETS.TRAIN = ("train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2

# Model: zoo checkpoint, GPU when available, one output class per category.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(all_categories)

# Solver: batch size, learning rate and iteration budget.
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 4000
# NOTE(review): an empty STEPS list presumably disables the stepwise
# learning-rate decay — confirm against the detectron2 config reference.
cfg.SOLVER.STEPS = []

# Training output goes to the managed output folder.
cfg.OUTPUT_DIR = output_path
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = DefaultTrainer(cfg)
# NOTE(review): resume=False presumably starts from cfg.MODEL.WEIGHTS rather
# than from a checkpoint left in OUTPUT_DIR — confirm in the DefaultTrainer docs.
trainer.resume_or_load(resume=False)
trainer.train()
