import dataiku
import html
import time
import urllib.robotparser
import re
import json
import requests
import urllib.parse
import requests
import os
import logging
import concurrent.futures
import functools
from langchain.document_loaders import UnstructuredURLLoader
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import OpenAIEmbeddings
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.tools.base import BaseTool

# HTTP headers sent when scraping web pages (passed to UnstructuredURLLoader
# in extract_content). The User-Agent is that of a desktop Chrome browser.
HEADER = {
    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36"
}
# Chunking parameters handed to the text splitter below; lengths are measured
# with ``len`` (see ``length_function``).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
# Model name used for the default (HuggingFace) embeddings.
MODEL = "sentence-transformers/msmarco-distilbert-cos-v5"

# Default embeddings. The model cache folder is read from the
# SENTENCE_TRANSFORMERS_HOME environment variable (None when it is unset).
embeddings = HuggingFaceEmbeddings(
    model_name=MODEL, cache_folder=os.getenv("SENTENCE_TRANSFORMERS_HOME")
)

# If the secrets returned by the Dataiku API client contain an entry whose key
# is "openai_key", the default embeddings are replaced by OpenAI embeddings
# configured with that secret's value. Only the first matching secret is used.
# NOTE(review): OpenAIEmbeddings is already imported at the top of the file;
# the import below is redundant.
from langchain.embeddings import OpenAIEmbeddings
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
    if secret["key"] == "openai_key":
        embeddings = OpenAIEmbeddings(openai_api_key=secret["value"])
        break

# Splitter applied to scraped pages in extract_content: splits on newlines and
# uses the chunk size / overlap defined above.
text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
)

# Timeout, in seconds, applied to the scraping of a single web page (used by
# the with_timeout decorator on extract_content).
timeout_webpage_scraping = 10

# Non-greedy pattern matching anything that looks like a markup tag.
# "." does not match newlines here, so a tag spanning several lines is kept.
clean_html_regex = re.compile('<.*?>')


def clean_html(raw_html):
    """
    Remove tag-like markup from a string, then decode its HTML entities.

    Args:
        raw_html: Text possibly containing HTML tags and entities.

    Returns:
        The text without the matched tags and with entities such as
        ``&amp;`` replaced by the characters they stand for.
    """
    without_tags = clean_html_regex.sub('', raw_html)
    return html.unescape(without_tags)

class BraveSearchWrapper(object):
    """
    Minimal client for the Brave web search API.

    Args:
        num_results: Maximum number of results returned by ``results``.
        safesearch: Value sent as the ``safesearch`` query parameter.

    Raises:
        KeyError: If the ``BRAVE_API_KEY`` environment variable is not set.
    """

    # Seconds to wait for the API. Without an explicit timeout, ``requests``
    # may block indefinitely on an unresponsive server.
    _REQUEST_TIMEOUT = 10

    def __init__(self, num_results, safesearch="strict"):
        self.api_key = os.environ["BRAVE_API_KEY"]
        self.num_results = num_results
        self.safesearch = safesearch

    @staticmethod
    def _format_result(item):
        """Convert one raw API result into a title/link/snippet dictionary."""
        title = item["title"]
        if "age" in item:
            title = f"{title} ({item['age']})"
        # Tolerate results without a description instead of raising KeyError.
        description = item.get("description") or ""
        return {
            "title": title,
            "link": item["url"],
            # Strip markup tags, then decode HTML entities.
            "snippet": html.unescape(re.sub(r'<[^>]+>', '', description)),
        }

    # NOTE(review): lru_cache on a method keys the cache on (self, query) and
    # keeps instances referenced by the cache alive. The decorator is kept so
    # that existing callers relying on the caching behavior are unaffected.
    @functools.lru_cache()
    def results(self, query):
        """
        Search the web for ``query``.

        Results are requested by pages of at most 20 items until
        ``num_results`` items can be returned.

        Args:
            query: The search query (must be hashable because of the cache).

        Returns:
            A list of at most ``num_results`` dictionaries with the keys
            ``title``, ``link`` and ``snippet``.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer in time.
        """
        headers = {"X-Subscription-Token": self.api_key}
        results = []
        for i in range((self.num_results-1)//20+1):
            params = {"q": query, "count": min(20, self.num_results), "result_filter": "web", "offset": i, "safesearch": self.safesearch}
            response = requests.get(
                "https://api.search.brave.com/res/v1/web/search",
                headers=headers,
                params=params,
                timeout=self._REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            json_result = response.json()
            if "web" in json_result:
                results += [self._format_result(x) for x in json_result["web"]["results"]]
        return results[:self.num_results]

class YouSearchWrapper(object):
    """
    Minimal client for the You.com search and news APIs.

    Args:
        num_results: Maximum number of results returned by each search method.
        safesearch: Value sent as the ``safesearch`` query parameter.

    Raises:
        KeyError: If the ``YDC_API_KEY`` environment variable is not set.
    """

    # Seconds to wait for the API. Without an explicit timeout, ``requests``
    # may block indefinitely on an unresponsive server.
    _REQUEST_TIMEOUT = 10

    def __init__(self, num_results, safesearch="strict"):
        self.api_key = os.environ["YDC_API_KEY"]
        self.num_results = num_results
        self.safesearch = safesearch

    def _fetch_pages(self, url, query):
        """Yield the decoded JSON payload of each result page for ``query``."""
        headers = {"X-API-Key": self.api_key}
        for i in range((self.num_results-1)//20+1):
            params = {"query": query, "count": min(20, self.num_results), "offset": i, "safesearch": self.safesearch}
            response = requests.get(
                url,
                headers=headers,
                params=params,
                timeout=self._REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            yield response.json()

    @staticmethod
    def _format_news(item):
        """Convert one raw news result into a title/link/snippet dictionary."""
        title = item["title"]
        # The age is appended only when present, as done for Brave results,
        # instead of raising KeyError when it is missing.
        if "age" in item:
            title = f"{title} ({item['age']})"
        return {
            "title": title,
            "link": item["url"],
            "snippet": item["description"],
        }

    # NOTE(review): lru_cache on a method keys the cache on (self, query) and
    # keeps instances referenced by the cache alive. The decorator is kept so
    # that existing callers relying on the caching behavior are unaffected.
    @functools.lru_cache()
    def results(self, query):
        """
        Search the web for ``query``.

        Args:
            query: The search query (must be hashable because of the cache).

        Returns:
            A list of at most ``num_results`` dictionaries with the keys
            ``title``, ``link`` and ``snippet``. The snippet is the cleaned
            description followed by the first two snippets of the hit.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer in time.
        """
        results = []
        for json_result in self._fetch_pages("https://api.ydc-index.io/search", query):
            if "hits" in json_result:
                results += [
                    {
                        "title": x['title'],
                        "link": x["url"],
                        "snippet": clean_html(x["description"]) + "\n\n" + "\n\n".join(x["snippets"][:2])
                    } for x in json_result["hits"]
                ]
        return results[:self.num_results]

    @functools.lru_cache()
    def results_news(self, query):
        """
        Search recent news for ``query``.

        Args:
            query: The search query (must be hashable because of the cache).

        Returns:
            A list of at most ``num_results`` dictionaries with the keys
            ``title``, ``link`` and ``snippet``.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer in time.
        """
        results = []
        for json_result in self._fetch_pages("https://api.ydc-index.io/news", query):
            if "news" in json_result:
                results += [self._format_news(x) for x in json_result["news"]["results"]]
        return results[:self.num_results]

class BraveSearch(BaseTool):
    """LangChain tool returning Brave web search results as a JSON string."""

    name: str = "brave_search"
    description: str = (
        "an internet search engine. "
        "useful for when you need to search the internet."
        " input should be a search query."
    )
    # Client performing the actual API calls.
    search_wrapper: BraveSearchWrapper

    @classmethod
    def create(cls, num_results, safesearch="strict"):
        """Build the tool together with its underlying search client.

        Args:
            num_results: The number of results returned by the search engine.
            safesearch: Whether to activate Safe Search. Can be `strict`, `moderate` or `off`.

        Returns:
            A tool.
        """
        return cls(search_wrapper=BraveSearchWrapper(num_results, safesearch))

    def _run(self, query: str, run_manager=None) -> str:
        """Run the search for ``query`` and serialize the hits as JSON."""
        hits = self.search_wrapper.results(query)
        return json.dumps(hits)

class YouSearch(BaseTool):
    """LangChain tool returning You.com web search results as a JSON string."""

    name: str = "you_search"
    description: str = (
        "an internet search engine. "
        "useful for when you need to search the internet."
        " input should be a search query."
    )
    # Client performing the actual API calls.
    search_wrapper: YouSearchWrapper

    @classmethod
    def create(cls, num_results, safesearch="strict"):
        """Build the tool together with its underlying search client.

        Args:
            num_results: The number of results returned by the search engine.
            safesearch: Whether to activate Safe Search. Can be `strict`, `moderate` or `off`.

        Returns:
            A tool.
        """
        return cls(search_wrapper=YouSearchWrapper(num_results, safesearch))

    def _run(self, query: str, run_manager=None) -> str:
        """Run the search for ``query`` and serialize the hits as JSON."""
        hits = self.search_wrapper.results(query)
        return json.dumps(hits)

class YouSearchNews(BaseTool):
    """Tool that queries the You.com news API."""

    name: str = "you_search_news"
    # The trailing space after the first sentence matters: the literals are
    # concatenated, and without it the text reads "recent news.useful for".
    description: str = (
        "a search engine to retrieve recent news. "
        "useful for when you need to answer questions about current events."
        " input should be a search query."
    )
    # Client performing the actual API calls.
    search_wrapper: YouSearchWrapper

    @classmethod
    def create(
        cls, num_results, safesearch="strict"
    ):
        """Create a tool from the expected number of results.

        Args:
            num_results: The number of results returned by the search engine.
            safesearch: Whether to activate Safe Search. Can be `strict`, `moderate` or `off`.

        Returns:
            A tool.
        """
        wrapper = YouSearchWrapper(num_results, safesearch)
        return cls(search_wrapper=wrapper)

    def _run(
        self,
        query: str,
        run_manager=None,
    ) -> str:
        """Search the news for ``query`` and return the results as a JSON string."""
        return json.dumps(self.search_wrapper.results_news(query))

def get_base_url(url):
    """
    Return the ``scheme://netloc`` prefix of an URL.

    Path, query string and fragment are dropped.
    """
    parts = urllib.parse.urlparse(url)
    scheme, netloc = parts.scheme, parts.netloc
    return f"{scheme}://{netloc}"

def get_robotparser(base_url):
    """
    Build a robots.txt parser for the website rooted at a base URL.

    The robots.txt file is downloaded as part of this call.

    Args:
        base_url: Scheme and host of the website, e.g. ``https://example.com``.

    Returns:
        A ``urllib.robotparser.RobotFileParser`` loaded from the website's
        robots.txt file.
    """
    parser = urllib.robotparser.RobotFileParser()
    parser.set_url(urllib.parse.urljoin(base_url, "robots.txt"))
    parser.read()
    # A website is assumed to be crawlable if robots.txt is out of reach
    # Cf. https://developers.google.com/search/docs/crawling-indexing/robots/robots-faq#h1c
    parser.disallow_all = False
    # Record the time of the fetch on the parser.
    parser.last_checked = time.time()
    return parser

def filter_urls(urls):
    """
    Filter the URLs that are disallowed in the corresponding robots.txt

    URLs are grouped by base URL so that each robots.txt file is fetched only
    once. When a robots.txt file cannot be retrieved or decoded, all the URLs
    of that website are kept.

    Args:
        urls: Iterable of URLs to check.

    Returns:
        The list of allowed URLs, grouped by base URL in order of first
        appearance.
    """
    # Local import: only needed for the exception type caught below.
    import http.client

    urls_by_base_url = {}
    for url in urls:
        urls_by_base_url.setdefault(get_base_url(url), []).append(url)

    results = []
    for base_url, site_urls in urls_by_base_url.items():
        try:
            rp = get_robotparser(base_url)
        except (OSError, ValueError, http.client.HTTPException):
            # OSError covers URLError, connection resets and socket timeouts;
            # ValueError covers malformed URLs and robots.txt files that are
            # not valid UTF-8; HTTPException covers broken HTTP responses.
            # In all these cases robots.txt is out of reach, so the website
            # is assumed to be crawlable.
            results += site_urls
            continue
        results += [url for url in site_urls if rp.can_fetch("*", url)]
    return results

def with_timeout(timeout):
    """
    Build a decorator that bounds the time a caller waits for a function.

    The decorated function runs in a dedicated worker thread. If it does not
    complete within ``timeout`` seconds, ``TimeoutError`` is raised in the
    caller. A thread cannot be interrupted, so the function itself keeps
    running in the background until it finishes on its own.

    Args:
        timeout: Maximum waiting time in seconds.

    Returns:
        A decorator preserving the signature and metadata of the function.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # The executor is deliberately NOT used as a context manager:
            # leaving a ``with`` block calls ``shutdown(wait=True)``, which
            # would block until the function completes and defeat the timeout.
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(func, *args, **kwargs)
            try:
                # Wait for the future to complete with a timeout
                return future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                # Has no effect on a call that is already running, but
                # prevents a call that has not started yet from starting.
                future.cancel()
                raise TimeoutError(f"Function timed out after {timeout} seconds")
            finally:
                # Release the worker thread without waiting for it.
                executor.shutdown(wait=False)
        return wrapped
    return decorator

@with_timeout(int(timeout_webpage_scraping))
def extract_content(url):
    """
    Extract the content of an URL as LangChain documents.

    The page is loaded with ``UnstructuredURLLoader`` using the module-level
    ``HEADER`` and split into chunks with the module-level ``text_splitter``.

    Args:
        url: Address of the web page to scrape.

    Returns:
        The documents produced by splitting the content of the page.

    Raises:
        TimeoutError: If the extraction takes longer than
            ``timeout_webpage_scraping`` seconds.
    """
    loader = UnstructuredURLLoader([url], mode="single", strategy="fast", headers=HEADER)
    documents = loader.load_and_split(text_splitter=text_splitter)
    # Log only once the page has actually been loaded and split, so that the
    # message is not emitted for an extraction that fails or hangs.
    logging.info(f"Content extracted: {url}")
    return documents

def index_urls(urls):
    """
    Build a FAISS vector store from the content of several web pages.

    Pages whose extraction times out are skipped.

    Args:
        urls: URLs of the web pages to index.

    Returns:
        A FAISS vector store built with the module-level ``embeddings``, or
        ``None`` when there is no URL or when no content could be extracted.
    """
    if not urls:
        return None
    docs = []
    for url in urls:
        try:
            docs += extract_content(url)
        except TimeoutError:
            # Best effort: keep indexing the other pages, but leave a trace.
            logging.warning(f"Content extraction timed out, page skipped: {url}")
    if not docs:
        # Nothing was extracted (e.g. every page timed out): return None as
        # for an empty input rather than building an index from no document.
        return None
    return FAISS.from_documents(docs, embeddings)

def format_chunks(chunks):
    """
    Render each chunk as ``"<source>: <content>"``.

    Args:
        chunks: Iterable of documents exposing ``metadata['source']`` and
            ``page_content``.

    Returns:
        The list of formatted strings, in the order of the input.
    """
    formatted = []
    for chunk in chunks:
        source = chunk.metadata['source']
        formatted.append(f"{source}: {chunk.page_content}")
    return formatted