from marshmallow import INCLUDE, Schema, fields, validate


class GeneratedMediaSchema(Schema):
    """Validate one generated media entry.

    ``data``, ``format`` and ``path`` must be present on load;
    ``referredFilePath`` may be omitted.
    """

    data = fields.Str(required=True)
    format = fields.Str(required=True)
    path = fields.Str(required=True)
    # Optional: marshmallow fields default to required=False.
    referredFilePath = fields.Str()


class FileSchema(Schema):
    """Validate a file reference.

    ``name`` and ``format`` are mandatory; ``path`` is optional.
    """

    name = fields.Str(required=True)
    format = fields.Str(required=True)
    path = fields.Str()
    # Disabled: a ``type`` field restricted to image/audio/video.
    # type = fields.String(validate=validate.OneOf(["image", "audio", "video"]))


class ProcessedFileSchema(Schema):
    """Validate a file entry after processing.

    ``name``, ``format`` and ``path`` are mandatory. ``thumbnail`` is
    optional and may be explicitly ``None``; ``chainType`` and
    ``jsonFilePath`` are optional strings.
    """

    name = fields.Str(required=True)
    format = fields.Str(required=True)
    path = fields.Str(required=True)
    # Disabled: a ``type`` field restricted to image/audio/video.
    # type = fields.String(validate=validate.OneOf(["image", "audio", "video"]))
    thumbnail = fields.Str(allow_none=True)  # null is accepted here
    chainType = fields.Str()
    jsonFilePath = fields.Str()


class FilterSchema(Schema):
    """Validate a key/value filter pair.

    ``key`` is a required string; ``value`` is passed through unvalidated.
    """

    key = fields.Str(required=True)
    value = fields.Raw()  # any JSON-compatible value, no type check


class SourceSchema(Schema):
    """Validate a single source item (used as an element of
    ``AggregatedToolSourcesSchema.items``).

    Every field is optional. Dict-valued fields take string keys with
    unvalidated (``Raw``) values.
    """

    type = fields.String()
    metadata = fields.Dict(keys=fields.String(), values=fields.Raw())  # TODO check values type
    generatedSqlQuery = fields.String()
    records = fields.Dict(keys=fields.String(), values=fields.Raw())  # TODO check values type

    sample = fields.Dict(keys=fields.String(), values=fields.Raw())  # TODO check values type
    images = fields.List(fields.Dict(keys=fields.String(), values=fields.Raw()))
    # NOTE(review): RetrievalSchema declares ``usedTables`` as List(String),
    # while here it is a plain String. Confirm which shape producers send
    # before aligning them -- switching this field to a list would reject
    # payloads that are accepted today.
    usedTables = fields.String()
    title = fields.String()
    url = fields.String()

    textSnippet = fields.String()
    markdownSnippet = fields.String()
    htmlSnippet = fields.String()

class AggregatedToolSourcesSchema(Schema):
    """Group the sources of one tool call under an optional description.

    Each element of ``items`` is validated by ``SourceSchema``.
    """

    toolCallDescription = fields.Str()
    items = fields.List(fields.Nested(SourceSchema))

class RetrievalSchema(Schema):
    """Validate a retrieval entry.

    ``name`` and ``type`` are required. ``filters`` maps string keys to lists
    of unvalidated values. ``sources`` is a list of
    ``AggregatedToolSourcesSchema`` groups.
    """

    name = fields.Str(required=True)
    type = fields.Str(required=True)
    alias = fields.Str()
    filters = fields.Dict(keys=fields.Str(), values=fields.List(fields.Raw()))  # TODO check values type
    # TODO(review): the original "check values type" note here looks
    # copy-pasted; items are already validated by AggregatedToolSourcesSchema.
    sources = fields.List(fields.Nested(AggregatedToolSourcesSchema))
    generatedSqlQuery = fields.Str()
    usedTables = fields.List(fields.Str())

class HistorySchema(Schema):
    """Validate one past query/answer exchange.

    ``query``, ``answer`` and a float ``timestamp`` are required. Any
    generated media or attached files are nested lists validated by
    ``GeneratedMediaSchema`` and ``ProcessedFileSchema``.
    """

    query = fields.Str(required=True)
    answer = fields.Str(required=True)
    timestamp = fields.Float(required=True)
    generatedMedia = fields.List(fields.Nested(GeneratedMediaSchema))
    files = fields.List(fields.Nested(ProcessedFileSchema))


class QueryContextSchema(Schema):
    """Validate the context sent alongside a query.

    ``applicationType``, ``applicationId`` and a float ``timestamp`` are
    required. ``history`` is a list of ``HistorySchema`` entries, and
    ``localization`` is an unconstrained dict.
    """

    applicationType = fields.Str(required=True)
    applicationId = fields.Str(required=True)
    botId = fields.Str()
    device = fields.Str()
    team = fields.Str()
    timestamp = fields.Float(required=True)
    history = fields.List(fields.Nested(HistorySchema))
    localization = fields.Dict()  # TODO should we keep it?


class ChatSettingsSchema(Schema):
    """Validate per-request chat settings.

    Both boolean flags are required. ``requestedResponseFormat`` must be one
    of "markdown", "text" or "html". ``requestedResponseLengthLimit`` is an
    optional integer.
    """

    createConversation = fields.Bool(required=True)
    withTitle = fields.Bool(required=True)
    requestedResponseFormat = fields.Str(
        required=True,
        validate=validate.OneOf(["markdown", "text", "html"]),
    )
    # Disabled option kept for reference:
    # forceJsonOutput = fields.Boolean()
    requestedResponseLengthLimit = fields.Int()


class UserPreferencesSchema(Schema):
    """Validate user preferences.

    ``language`` and ``timezone`` are optional string-to-string dicts. Keys
    not declared here are kept on load instead of raising an error.
    """

    class Meta:
        """Schema options: retain undeclared input keys (marshmallow INCLUDE)."""

        unknown = INCLUDE

    language = fields.Dict(keys=fields.Str(), values=fields.Str())
    timezone = fields.Dict(keys=fields.Str(), values=fields.Str())


# class MeshTraceSchema(Schema):
#     type = fields.String()
#     begin = fields.String()
#     end = fields.String()
#     duration = fields.Integer()
#     name = fields.String()
#     children = fields.List(fields.Nested('self'))
#     attributes = fields.Dict()
#     inputs = fields.Dict()
#     outputs = fields.Dict()
#     usageMetadata = fields.Dict()

class MessageFeedbackSchema(Schema):
    """Validate feedback on a message; all three string fields are required."""

    value = fields.Str(required=True)
    choice = fields.Str(required=True)
    message = fields.Str(required=True)