import unittest
from unittest.mock import MagicMock, patch

from common.backend.models.base import LLMStep
from common.solutions.chains.image_generation.image_generation_chain import ImageGenerationChain


# Dotted path of the module under test; patches must target names as they are
# looked up *inside* this module, not where they are originally defined.
_CHAIN_MODULE = "common.solutions.chains.image_generation.image_generation_chain"


class TestImageGenerationChain(unittest.TestCase):
    """Quota-handling tests for ``ImageGenerationChain.run_image_generation_query``.

    Collaborators looked up in the chain module (auth, the ``DKULLM`` class,
    quota helpers, ``NoRetrievalChain``) are patched so that only the chain's
    own branching logic is exercised.
    """

    WEBAPP_CONFIG = {"max_images_per_user_per_week": 5}

    def _build_chain(self, user_profile, **chain_kwargs):
        """Create an ``ImageGenerationChain`` wired to mock LLMs and a mock profile store.

        Args:
            user_profile: Profile dict handed to the chain and returned by the
                mocked ``user_profile_sql_manager.get_user_profile``.
            **chain_kwargs: Extra keyword arguments forwarded to the chain
                constructor (e.g. ``include_user_profile_in_prompt``).

        Returns:
            ``(chain, llm)``. ``llm`` is the conversational LLM mock the chain
            was given, so tests can assert on what it is passed to.
        """
        # Separate, named mocks keep failure messages readable and guarantee the
        # two LLM roles are distinct objects.
        llm = MagicMock(name="llm")
        image_generation_llm = MagicMock(name="image_generation_llm")
        user_profile_sql_manager = MagicMock(name="user_profile_sql_manager")
        user_profile_sql_manager.get_user_profile.return_value = user_profile
        chain = ImageGenerationChain(
            llm=llm,
            image_generation_llm=image_generation_llm,
            user_query="Generate an image",
            referred_image=None,
            user_profile=user_profile,
            trace=MagicMock(name="trace"),
            user_profile_sql_manager=user_profile_sql_manager,
            **chain_kwargs,
        )
        return chain, llm

    def _run_query(self, chain):
        """Drain the chain's generator using the webapp's weekly image limit."""
        return list(
            chain.run_image_generation_query(
                max_images_to_generate=self.WEBAPP_CONFIG["max_images_per_user_per_week"]
            )
        )

    # Decorators apply bottom-up, so mock arguments arrive in reverse order.
    # DKULLM is patched once so the chain module never touches the real class.
    @patch(f"{_CHAIN_MODULE}.get_num_images_user_can_generate")
    @patch(f"{_CHAIN_MODULE}.get_nbr_images_to_generate")
    @patch(f"{_CHAIN_MODULE}.NoRetrievalChain")
    @patch(f"{_CHAIN_MODULE}.DKULLM")
    @patch(f"{_CHAIN_MODULE}.get_auth_user")
    def test_run_image_generation_query_limit_reached(
        self,
        mock_get_user,
        _mock_dkullm,
        mock_no_retrieval_chain,
        mock_get_nbr_images_to_generate,
        mock_get_num_images_user_can_generate,
    ):
        """With no quota left, the chain answers via ``NoRetrievalChain`` instead of generating."""
        mock_get_user.return_value = "user"
        # The user wants 3 images but has none left this week.
        mock_get_num_images_user_can_generate.return_value = 0
        mock_get_nbr_images_to_generate.return_value = 3
        mock_no_retrieval_instance = mock_no_retrieval_chain.return_value
        mock_no_retrieval_instance.get_as_json.return_value = {
            "answer": "max_images_per_user_per_week_reached"
        }
        user_profile = {"media": {"image": {"nbr_images_to_generate": {"value": 3}}}}
        include_user_profile_in_prompt = MagicMock(name="include_user_profile_in_prompt")

        image_gen_chain, llm = self._build_chain(
            user_profile, include_user_profile_in_prompt=include_user_profile_in_prompt
        )
        results = self._run_query(image_gen_chain)

        self.assertEqual(results[0], {"step": LLMStep.STREAMING_END})
        self.assertEqual(results[1], {"answer": "max_images_per_user_per_week_reached"})
        mock_get_num_images_user_can_generate.assert_called_with(user_profile)
        mock_no_retrieval_chain.assert_called_with(
            llm, include_user_profile_in_prompt=include_user_profile_in_prompt
        )
        mock_no_retrieval_instance.get_as_json.assert_called_once()

    @patch(f"{_CHAIN_MODULE}.get_num_images_user_can_generate")
    @patch(f"{_CHAIN_MODULE}.set_nbr_images_to_generate")
    @patch(f"{_CHAIN_MODULE}.DKULLM")
    @patch(f"{_CHAIN_MODULE}.get_auth_user")
    def test_run_image_generation_requested_images_more_than_quota_left(
        self,
        mock_get_user,
        _mock_dkullm,
        mock_set_nbr_images_to_generate,
        mock_get_num_images_user_can_generate,
    ):
        """A request above the remaining quota is clamped to the quota, then generation starts."""
        mock_get_user.return_value = "user"
        # The user asks for 4 images but only 3 remain this week.
        mock_get_num_images_user_can_generate.return_value = 3
        user_profile = {"media": {"image": {"nbr_images_to_generate": {"value": 4}}}}
        # Must be configured BEFORE the chain runs; assigning it afterwards
        # (as previously done) has no effect on the code under test.
        mock_set_nbr_images_to_generate.return_value = {
            "media": {"image": {"nbr_images_to_generate": {"value": 3}}}
        }

        image_gen_chain, _ = self._build_chain(user_profile)
        results = self._run_query(image_gen_chain)

        mock_get_num_images_user_can_generate.assert_called_with(user_profile)
        mock_set_nbr_images_to_generate.assert_called_with(user_profile, 3)
        self.assertEqual(results[0], {"step": LLMStep.GENERATING_IMAGE})
