# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd
import numpy as np

# Per-label sample sizes. Every distinct `label_text` value contributes
# `num_train` rows to "train", `num_test` rows to "test", and the first
# `num_train_small` of its training rows to "train_small".
num_train = 100
num_train_small = 8
num_test = 100

df = dataiku.Dataset("data").get_dataframe()
grouped = df.groupby('label_text')

# Collect the per-label slices and concatenate once after the loop; calling
# pd.concat on a growing frame inside the loop re-copies all previously
# accumulated rows on every iteration.
train_parts, test_parts, train_small_parts = [], [], []
rows_per_label = num_train + num_test

for name, group in grouped:
    # Sampling is done without replacement, so a label needs at least
    # `rows_per_label` rows. Fail with a message that names the label.
    if len(group) < rows_per_label:
        raise ValueError(
            "label {!r} has only {} rows; {} are required "
            "({} train + {} test)".format(
                name, len(group), rows_per_label, num_train, num_test
            )
        )

    # One random draw per label, split positionally: the train and test
    # slices do not overlap, and the small training slice is taken from the
    # start of the training slice.
    sample = group.sample(n=rows_per_label)
    train_parts.append(sample.iloc[:num_train])
    test_parts.append(sample.iloc[num_train:])
    train_small_parts.append(sample.iloc[:num_train_small])

# pd.concat raises on an empty list, so fall back to an empty frame when the
# input produced no label groups.
df2 = pd.concat(train_parts) if train_parts else pd.DataFrame()
df3 = pd.concat(test_parts) if test_parts else pd.DataFrame()
df4 = pd.concat(train_small_parts) if train_small_parts else pd.DataFrame()

dataiku.Dataset("train").write_with_schema(df2)
dataiku.Dataset("test").write_with_schema(df3)
dataiku.Dataset("train_small").write_with_schema(df4)