# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
import numpy as np
from collections import Counter, defaultdict
import pandas as pd

# Upper bound on how many labelled objects of any one category may end up
# in the reduced training set.
N_TRAIN = 10

# Read the whole input dataset into a pandas DataFrame.
input_dataset = dataiku.Dataset("object_detection_data")
df = input_dataset.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Fix the RNG seed so the shuffled visiting order, and therefore the selected
# training subset, is the same on every run.
np.random.seed(seed=0)

# Pull the two columns out once. Looking a cell up with df.iloc[i][...] builds
# a complete row Series on every access, and the loop needs two cells per row.
train_flags = df["train"].values
label_cells = df["label"].values

train = []  # positional indices of the rows kept for the reduced training set
counts = defaultdict(int)  # category -> number of objects selected so far
for i in np.random.permutation(len(df)):
    # Only rows flagged train == 1 are candidates.
    if train_flags[i] != 1:
        continue

    # The label cell is decoded from JSON and iterated; every item must
    # provide a "category" entry. `labels` maps category -> count in this row.
    labels = Counter(x["category"] for x in json.loads(label_cells[i]))

    # All-or-nothing: skip the row if taking it would push any category over
    # the N_TRAIN budget. A row with no labelled objects is always kept.
    if any(labels[label] + counts[label] > N_TRAIN for label in labels):
        continue

    train.append(i)
    for label in labels:
        counts[label] += labels[label]

# Reduced training set: the rows whose positions were collected in `train`.
# The "train" flag column is dropped before writing.
train_df = df.iloc[train].drop("train", axis=1)

# Test set: every row whose "train" flag equals 0, also without the flag.
is_test_row = df["train"] == 0
test_df = df[is_test_row].drop("train", axis=1)

# Write each frame to its output dataset, train first and then test.
for dataset_name, frame in (("train", train_df), ("test", test_df)):
    dataiku.Dataset(dataset_name).write_with_schema(frame)