import ast
import json
import argparse
import os
import logging
from pathlib import Path

# Module-wide switch: when true, strdebug() renders the full AST dump of a node.
debug_output = False


def strdebug(n):
    """Describe AST node *n* for log messages.

    Returns the complete ``ast.dump`` text when ``debug_output`` is set,
    otherwise only the string form of the node's type.
    """
    return str(ast.dump(n)) if debug_output else str(type(n))


def parse_node(node):
    """Recursively convert an AST expression node into plain data.

    Returns:
        * ``None`` when *node* is ``None``;
        * the literal value for ``ast.Constant`` (e.g. ``3`` for ``3``);
        * the identifier for ``ast.Name`` (e.g. ``"x"`` for ``x``);
        * ``"(<left> <dump of operator> <right>)"`` for ``ast.BinOp``;
        * the comma-joined rendering of the elements for ``ast.Tuple``;
        * a dict with ``lower``/``upper``/``step`` for ``ast.Slice``;
        * the result of :func:`parse_subscript` for ``ast.Subscript``;
        * ``ast.dump(node)`` for every other node type.
    """
    if node is None:
        return None
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.BinOp):
        # Binary operations such as `a + b`; the operator is rendered via ast.dump.
        return f"({parse_node(node.left)} {ast.dump(node.op)} {parse_node(node.right)})"
    if isinstance(node, ast.Tuple):
        # Elements may render to non-strings (ints, None, slice dicts), e.g.
        # `Literal[1, 2]`; str.join only accepts strings, so coerce each one.
        return ",".join(str(parse_node(el)) for el in node.elts)
    if isinstance(node, ast.Slice):
        return {
            "lower": parse_node(node.lower),
            "upper": parse_node(node.upper),
            "step": parse_node(node.step),
        }
    if isinstance(node, ast.Subscript):
        return parse_subscript(node)
    return ast.dump(node)  # Fallback for unhandled cases


def parse_subscript(subscript: ast.Subscript):
    """Render a subscript expression such as ``List[int]`` as ``"List[int]"``.

    The subscripted value must be a plain name (``List``) or an attribute
    access (``typing.List``); for the latter only the final attribute name is
    kept. The slice is rendered with :func:`parse_node` and converted to text.

    Returns:
        The rendered string, or ``None`` (after logging the exception) when
        the subscript cannot be rendered.
    """
    try:
        target = subscript.value
        # `typing.List[int]` has an ast.Attribute as value, which has no `.id`.
        base = target.attr if isinstance(target, ast.Attribute) else target.id
        # The f-string converts non-string slices (e.g. the 3 in `Literal[3]`).
        return f"{base}[{parse_node(subscript.slice)}]"
    except Exception:
        logging.exception("could not parse subscript from ast.Subscript.")
        return None


def parse_annotation(annotation):
    """Render the annotation node of a function argument.

    Args:
        annotation: the ``annotation`` attribute of an ``ast.arg``; ``None``
            when the argument is not annotated.

    Returns:
        * the identifier for a plain name (``int`` -> ``"int"``);
        * the final attribute name for a dotted name
          (``np.ndarray`` -> ``"ndarray"``);
        * the result of ``parse_subscript`` for a subscript (``List[int]``);
        * the literal value for a constant, e.g. the text of a string
          annotation (``"Foo"`` -> ``"Foo"``);
        * ``None`` for anything else, including a missing annotation.
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Attribute):
        # Report the attribute name, as is done for decorators and base
        # classes in this file. Reading `annotation.value.id` returned the
        # qualifier instead of the type and raised AttributeError for deeper
        # dotted names such as `a.b.C`.
        return annotation.attr
    if isinstance(annotation, ast.Subscript):
        return parse_subscript(annotation)
    if isinstance(annotation, ast.Constant):
        # `.value` is the supported attribute; `.s` is a deprecated alias.
        return annotation.value

    return None


def parse_function_args(arguments: ast.arguments):
    """List the parameters declared by a function signature.

    Parameters are reported in signature order: positional-only, regular,
    ``*args``, keyword-only, then ``**kwargs``.

    Args:
        arguments: the ``args`` attribute of a function definition node.

    Returns:
        A list of ``{"name": ..., "type": ...}`` dicts, where ``type`` is the
        value produced by ``parse_annotation`` (``None`` when the annotation
        is absent or could not be parsed).
    """

    def describe(arg):
        # An annotation the parser cannot handle must only cost this
        # parameter its type, not drop the parameter and those after it.
        try:
            annotation = parse_annotation(arg.annotation)
        except AttributeError:
            logging.debug("could not parse annotation of argument %s", arg.arg)
            annotation = None
        return {"name": arg.arg, "type": annotation}

    # `posonlyargs` is missing from ast.arguments on older Python versions.
    ordered = list(getattr(arguments, "posonlyargs", []))
    ordered.extend(arguments.args)
    # The field is named `vararg`; reading `varargs` always raised
    # AttributeError, so `*args` was never reported.
    if arguments.vararg is not None:
        ordered.append(arguments.vararg)
    ordered.extend(arguments.kwonlyargs)
    if arguments.kwarg is not None:
        ordered.append(arguments.kwarg)

    # default values
    # To Do: handle kw_defaults & defaults

    return [describe(arg) for arg in ordered]


def parse_function_return_statement(node):
    """Render a function's return annotation node.

    Returns the identifier for an ``ast.Name``, the literal value for an
    ``ast.Constant`` and ``None`` for any other node (or a missing one).
    """
    if isinstance(node, ast.Name):
        return node.id
    return node.value if isinstance(node, ast.Constant) else None


def parse_decorators(decorators):
    """Collect the names of the decorators applied to a definition.

    Args:
        decorators: the ``decorator_list`` of a function or class node.

    Returns:
        A list with the identifier of every plain-name decorator and the
        final attribute name of every dotted decorator, in source order.
        Call decorators (``@cache(maxsize=1)``) and any other expression
        form are left out.
    """
    names = []
    for decorator in decorators:
        if isinstance(decorator, ast.Attribute):
            names.append(decorator.attr)
        elif isinstance(decorator, ast.Name):
            names.append(decorator.id)
        elif not isinstance(decorator, ast.Call):
            # Decorators may be arbitrary expressions (e.g. a subscript);
            # those have no `.id`, so skip them instead of raising.
            logging.debug("decorator type not handled: %s", type(decorator))
    return names


def parse_function_definition(node: ast.FunctionDef):
    """Summarise a function definition node as a dict.

    The result always holds ``name`` and ``docstring``. ``decorators`` and
    ``params`` are added only when non-empty, and ``returns`` only when the
    return annotation renders to something other than ``None``.
    """
    info = {"name": node.name, "docstring": ast.get_docstring(node)}

    decorator_names = parse_decorators(node.decorator_list)
    if decorator_names:
        info["decorators"] = decorator_names

    arguments = parse_function_args(node.args)
    if arguments:
        info["params"] = arguments

    return_annotation = parse_function_return_statement(node.returns)
    if return_annotation is not None:
        info["returns"] = return_annotation

    return info


def parse_class_definition(node: ast.ClassDef):
    """Summarise a class definition node as a dict.

    The result always holds ``name`` and ``docstring``. Optional keys, added
    only when non-empty:

    * ``decorators``: names returned by ``parse_decorators``;
    * ``bases``: identifier of plain-name bases and final attribute name of
      dotted bases (subscripted bases are skipped);
    * ``params``: the parameters of ``__init__``;
    * ``classes``: nested classes whose name does not start with ``_``;
    * ``functions``: methods (sync or async) whose name does not start with
      ``_``; ``__init__`` is never listed here.
    """
    class_def = {"name": node.name, "docstring": ast.get_docstring(node)}

    decorators = parse_decorators(node.decorator_list)
    if decorators:
        class_def["decorators"] = decorators

    bases = []
    for base in node.bases:
        if isinstance(base, ast.Name):
            bases.append(base.id)
        elif isinstance(base, ast.Attribute):
            bases.append(base.attr)
        elif not isinstance(base, ast.Subscript):  # subscripted bases are skipped silently
            # The message needs a %s placeholder for its argument; without
            # one the logging module reports a formatting error.
            logging.debug("type not handled in Class: %s", strdebug(base))
    if bases:
        class_def["bases"] = bases

    # Child node types that are knowingly ignored: plain expressions, base
    # class references (Name / Attribute / Subscript), class-level
    # assignments and annotations, `pass`, and decorator calls.
    ignored_types = (
        ast.Expr,
        ast.Name,
        ast.Assign,  # class properties TODO(jfy): we may want handle them
        ast.Attribute,
        ast.Pass,
        ast.Call,
        ast.AnnAssign,
        ast.Subscript,
    )

    classes = []
    functions = []
    for child in ast.iter_child_nodes(node):
        if isinstance(child, ast.ClassDef):
            if not child.name.startswith("_"):
                classes.append(parse_class_definition(child))
        elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if child.name == "__init__":
                # "params" is only present when the function has parameters,
                # so use .get() rather than indexing (which raised KeyError).
                init_params = parse_function_definition(child).get("params")
                if init_params:
                    class_def["params"] = init_params
            elif not child.name.startswith("_"):
                functions.append(parse_function_definition(child))
        elif not isinstance(child, ignored_types):
            logging.debug("type not handled in child node: %s", strdebug(child))
    if classes:
        class_def["classes"] = classes
    if functions:
        class_def["functions"] = functions

    return class_def


def get_import_path(file_path, root_path):
    """Return the dotted import path of *file_path* relative to *root_path*.

    Args:
        file_path: path of a module or package located under *root_path*.
        root_path: directory the import path is computed from.

    Returns:
        The relative path with its separators replaced by dots and a trailing
        ``.py`` removed, e.g. ``pkg/sub/mod.py`` -> ``"pkg.sub.mod"``.

    Raises:
        ValueError: if *file_path* is not located under *root_path*
            (raised by ``relative_to``).
    """
    # as_posix() always yields "/" separators, so the replacement also works
    # where the native separator is a backslash.
    import_path = file_path.relative_to(root_path).as_posix().replace("/", ".")
    if import_path.endswith(".py"):
        import_path = import_path[:-3]
    return import_path

def parse_file(file_path: Path, root_path: Path, from_dir: Path) -> dict:
    """Parse one Python source file into a module description.

    Args:
        file_path: file to parse; it is read as UTF-8.
        root_path: directory the module's ``importPath`` is relative to.
        from_dir: directory the module's ``filePath`` is relative to;
            ``root_path`` is used when it is ``None``.

    Returns:
        A dict with ``docstring``, ``name``, ``importPath`` and ``filePath``,
        plus ``classes`` / ``functions`` when the module defines public ones
        (names not starting with ``_``). ``None`` is returned when the file
        cannot be decoded, or when it has no docstring and no public class
        or function.

    Raises:
        SyntaxError: propagated from ``ast.parse`` for invalid source.
    """
    logging.info("parsing %s", file_path)
    if from_dir is None:
        from_dir = root_path

    # Only reading needs the file handle; close it before parsing.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
    except UnicodeDecodeError:
        # logging.exception already records the active exception; passing it
        # as an extra argument without a placeholder is a formatting error.
        logging.exception("error while reading %s", file_path)
        return None

    module_node = ast.parse(content)
    docstring = ast.get_docstring(module_node)
    classes = []
    functions = []
    for child in ast.iter_child_nodes(module_node):
        if isinstance(child, ast.ClassDef):
            if not child.name.startswith("_"):
                classes.append(parse_class_definition(child))
        elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if not child.name.startswith("_"):
                functions.append(parse_function_definition(child))

    # do not write module if empty
    if not classes and not functions and not docstring:
        return None

    mm = dict()
    if classes:
        mm["classes"] = classes
    if functions:
        mm["functions"] = functions
    mm["docstring"] = docstring
    mm["name"] = file_path.stem
    mm["importPath"] = get_import_path(file_path, root_path)
    mm["filePath"] = file_path.relative_to(from_dir).as_posix()

    return mm

def read_cli_args():
    """Build the command-line parser and return the parsed arguments.

    Options: ``--input-paths`` (required, one or more), ``--max-file-size``
    (int, default 10000), ``--from-dir``, ``--debug`` (flag), and the
    positional ``output``.
    """
    cli = argparse.ArgumentParser(
        description="Extract docstrings from python libraries"
    )
    # Declared in the order they appear in the generated help text.
    option_table = (
        (
            "--input-paths",
            {"required": True, "nargs": "+", "help": "Paths to scan"},
        ),
        (
            "--max-file-size",
            {
                "required": False,
                "type": int,
                "help": "Max scanned python file size in KB",
                "default": 10000,
            },
        ),
        (
            "--from-dir",
            {
                "required": False,
                "help": "Starting path to consider looking for the file",
            },
        ),
        (
            "--debug",
            {
                "required": False,
                "action": "store_true",
                "help": "Activate debug output",
            },
        ),
        ("output", {"help": "Output file"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()

def parse_package(input_path: Path, output_path: Path, rootdir: Path, from_dir: Path, max_file_size_in_kb: int):
    """Describe the direct content of one directory as a package.

    Args:
        input_path: directory to inspect (not recursive, except to check
            whether a sub-directory contains any ``*.py`` file).
        output_path: only its ``name`` is used, as the package name.
        rootdir: directory that import paths are relative to.
        from_dir: forwarded to ``parse_file``.
        max_file_size_in_kb: files larger than this (size in bytes / 1000)
            are not parsed and get a ``documentationError`` entry instead.

    Returns:
        A dict with ``name`` and ``importPath`` (``"."`` for *rootdir*
        itself), plus ``packages`` (sub-directories containing Python files)
        and ``modules`` (parsed files or error entries) when non-empty, both
        ordered by ``sort_modules_packages``. Entries whose name starts with
        a dot are ignored.
    """
    relative_dir = input_path.relative_to(rootdir)
    packages = {
        "name": output_path.name,
        # as_posix() guarantees "/" separators whatever the platform.
        "importPath": relative_dir.as_posix().replace("/", "."),
    }

    def error_entry(path, title, message):
        # Placeholder module recorded when a file cannot be documented.
        return {
            "importPath": get_import_path(path, rootdir),
            "documentationError": {"title": title, "message": message},
            "name": path.stem,
        }

    subpackages = []
    modules = []
    for p in input_path.iterdir():
        if p.name.startswith("."):
            # ignore dotfiles
            continue

        if p.is_dir():
            # `rglob` is a generator, so any() stops at the first python file found
            if any(p.rglob("*.py")):
                subpackages.append(
                    {"name": p.name, "path": str(relative_dir / p.name)}
                )
            continue

        if not (p.is_file() and p.suffix == ".py"):
            continue

        file_size_kb = p.stat().st_size / 1000
        if file_size_kb > max_file_size_in_kb:
            modules.append(error_entry(
                p,
                "File size exceeded",
                f"File {p.name} exceeded the limit of {max_file_size_in_kb}KB. Current size is {file_size_kb}KB",
            ))
            continue

        # A file that fails to parse is reported as an error entry so that
        # the remaining files of the package are still processed.
        try:
            parsed_file = parse_file(p, rootdir, from_dir)
        except Exception as e:
            modules.append(error_entry(p, f"Issue parsing file {p.name}", str(e)))
            logging.error("issue parsing file %s", p.name)
        else:
            if parsed_file is not None:
                modules.append(parsed_file)

    if subpackages:
        packages["packages"] = sort_modules_packages(subpackages)
    if modules:
        packages["modules"] = sort_modules_packages(modules)
    return packages


def sort_modules_packages(modules):
    """Return a new list of entries ordered by their ``name``.

    Entries whose name starts with ``__`` come first; inside each of the two
    groups the entries are sorted by name.
    """

    def rank(entry):
        name = entry["name"]
        # False sorts before True, which puts the dunder names in front.
        return (not name.startswith("__"), name)

    return sorted(modules, key=rank)


def does_root_package_has_subpackages(parsed):
    """Tell whether *parsed* is the root package (``importPath`` of ``"."``)
    and lists sub-packages under its ``packages`` key."""
    if "packages" not in parsed:
        return False
    return parsed["importPath"] == "."


def scan_path(input_path: Path, output_path: Path, rootdir: Path, from_dir: Path, max_file_size: int, output_root: Path = None):
    """Document a directory tree and write one JSON file per package.

    The root package (import path ``"."``) is written to
    ``<output_root>/<output_root.name>.json``; every other package goes to
    ``<output_root>/packages/<importPath>.json``. A package is written only
    when it has modules, or when it is the root package and has sub-packages.

    Args:
        input_path: directory to document.
        output_path: path whose ``name`` becomes the package name; for the
            top-level call it is also the output directory.
        rootdir: directory that import paths are relative to.
        from_dir: forwarded to ``parse_package``.
        max_file_size: per-file size limit in KB, forwarded to
            ``parse_package``.
        output_root: directory receiving the JSON files. Defaults to
            *output_path*; it is passed unchanged to the recursive calls,
            where *output_path* is extended with the sub-package name.
    """
    if output_root is None:
        # Top-level call. Keeping the root directory in a parameter removes
        # the dependency on a module-level `output` variable, which only
        # exists when the file is run as a script.
        output_root = output_path

    parsed = parse_package(input_path, output_path, rootdir, from_dir, max_file_size)

    packages_path = output_root / "packages"
    # do not write on disk if no content to display
    if "modules" in parsed or does_root_package_has_subpackages(parsed):
        import_path = parsed["importPath"]
        if import_path == ".":
            output_root.mkdir(parents=True, exist_ok=True)
            of_path = output_root / f"{output_root.name}.json"
            can_merge = of_path.exists()
        else:
            packages_path.mkdir(parents=True, exist_ok=True)
            of_path = packages_path / f"{import_path}.json"
            # Detect if a package with a similar name still exists
            # It's a bit tricky:
            # if there is a __init__.py, it will "win"
            # if there is no __init__.py and a duplicate, packages will be "merged"
            # So... let's just check if package exists for now...
            if of_path.exists():
                logging.info(f"file {of_path} already exists, skipping")
                return
            can_merge = False

        if can_merge:  # Another module in the root dir exists, merge it with the current one
            with open(of_path, "r", encoding="utf-8") as f:
                input_parsed = json.load(f)
            # NOTE(review): only "modules" are merged; the "packages" list of
            # the existing file is replaced by the new one. Confirm intended.
            modules = input_parsed.get("modules", []) or []
            already_added_module = {m["name"] for m in modules}
            for m in parsed.get("modules", []):
                # Modules already present in the existing file take precedence.
                if m["name"] not in already_added_module:
                    modules.append(m)
                    already_added_module.add(m["name"])
            if modules:
                parsed["modules"] = sort_modules_packages(modules)

        with open(of_path, "w", encoding="utf-8") as of:
            json.dump(parsed, of)
            logging.info(f"wrote {of_path}")

    # now parse sub packages
    for subpackage in parsed.get("packages", []):
        scan_path(
            input_path / subpackage["name"],
            output_path / subpackage["name"],
            rootdir,
            from_dir,
            max_file_size,
            output_root,
        )


# Script entry point: configure logging, read the command line, then scan
# every input path into the output directory.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    args = read_cli_args()

    # Module-level flag read by strdebug() to choose how nodes are rendered.
    debug_output = args.debug

    from_dir = Path(args.from_dir) if args.from_dir else None # to handle multiple project pythonPaths
    # Root output directory, shared by all the scanned input paths.
    output = Path(args.output)
    for path_to_scan in args.input_paths:
        path_to_scan = Path(path_to_scan)
        # Each input path is scanned with itself as the root directory.
        scan_path(path_to_scan, output, path_to_scan, from_dir, args.max_file_size)
