from common.backend.schemas.common_schemas import (
    ChatSettingsSchema,
    GeneratedMediaSchema,
    ProcessedFileSchema,
    QueryContextSchema,
    RetrievalSchema,
    UserPreferencesSchema,
)
from marshmallow import Schema, fields


class AskRequestSchema(Schema):
    """Validate the payload of an "ask" request.

    ``user``, ``query`` and ``context`` are mandatory; every other field may
    be omitted by the caller.
    """

    user = fields.Str(required=True)
    query = fields.Str(required=True)

    # The attribute name ``context`` clashes with ``marshmallow.Schema.context``,
    # which is why the type checker is silenced on the line below.
    # TODO: rename to ``query_context`` to avoid further conflicts.
    context = fields.Nested(QueryContextSchema, required=True)  # type: ignore
    # Optional, and an explicit null value is accepted as well.
    conversationId = fields.Str(allow_none=True)
    # TODO: decide whether an integer ``messageIndex`` field should be kept here.
    selectedRetrieval = fields.Nested(RetrievalSchema)
    files = fields.List(fields.Nested(ProcessedFileSchema))
    chatSettings = fields.Nested(ChatSettingsSchema)
    userPreferences = fields.Nested(UserPreferencesSchema)


class AskResponseDataSchema(Schema):
    """Describe the ``data`` object nested inside ``AskResponseSchema``.

    ``id``, ``messageIndex``, ``answer``, ``query``, ``timestamp`` and
    ``context`` are mandatory; the remaining fields are optional.
    """

    # --- mandatory fields ---
    id = fields.Str(required=True)
    messageIndex = fields.Integer(required=True)
    answer = fields.Str(required=True)
    query = fields.Str(required=True)
    timestamp = fields.Float(required=True)

    # The attribute name ``context`` clashes with ``marshmallow.Schema.context``,
    # which is why the type checker is silenced on the line below.
    # TODO: rename to ``query_context`` to avoid further conflicts.
    context = fields.Nested(QueryContextSchema, required=True)  # type: ignore

    # --- optional fields ---
    trace = fields.Dict()
    usedRetrieval = fields.Nested(RetrievalSchema)
    conversationInfo = fields.Dict()
    generatedMedia = fields.List(fields.Nested(GeneratedMediaSchema))
    files = fields.List(fields.Nested(ProcessedFileSchema))
    llmContext = fields.Dict()


class AskResponseSchema(Schema):
    """Envelope of an ask response: a status string plus the nested data."""

    # Both members are mandatory.
    status = fields.Str(required=True)
    data = fields.Nested(AskResponseDataSchema, required=True)
