#!/usr/bin/env python

import argparse
import os
import sys
import json
import logging
from pkg_resources import packaging
import platform
import tarfile
import shutil
from tempfile import NamedTemporaryFile, TemporaryDirectory
import subprocess
import urllib.request
import hashlib
import copy
from urllib.parse import urlsplit, urlunsplit

def download_file(url, target):
    """Download ``url`` into the local file ``target``.

    The response is streamed to disk in chunks instead of being read into
    memory in one go, so large payloads (e.g. VHDs) do not exhaust RAM.

    Args:
        url: Any URL ``urllib.request.urlopen`` accepts.
        target: Path of the file to create/overwrite.
    """
    with urllib.request.urlopen(url) as stream, open(target, "wb") as fd:
        shutil.copyfileobj(stream, fd)


def format_tags_for_azure_cli(tags):
    """Render a tag mapping as the positional values of the az CLI ``--tags`` option.

    Each entry becomes ``"name=value"``; entries whose value is falsy are
    rendered as the bare tag name.
    """
    rendered = []
    for name, value in tags.items():
        if value:
            rendered.append("{}={}".format(name, value))
        else:
            rendered.append(name)
    return rendered


def execute_command(command, check=True, redacted_command=None, hide_stdout=False, hide_stderr=False, extra_env=None):
    """Run ``command`` (an argv list, no shell) and capture its output.

    Args:
        command: argv list passed to ``subprocess.run``.
        check: When true, a non-zero return code is logged at CRITICAL level
            and the current process exits with status 1.
        redacted_command: argv to log instead of ``command`` (e.g. with
            secrets removed). Falls back to ``command`` when empty/None.
        hide_stdout: Log ``<REDACTED>`` instead of the captured stdout.
        hide_stderr: Log ``<REDACTED>`` instead of the captured stderr.
        extra_env: Variables added on top of ``os.environ`` for the child.

    Returns:
        The ``subprocess.CompletedProcess``; ``stdout``/``stderr`` are bytes.
    """
    logged_command = redacted_command if redacted_command else command
    # str() so that non-str argv items (ints, paths) do not break the log line.
    logging.info("Execute command: {}".format(" ".join(str(part) for part in logged_command)))
    process_env = os.environ.copy()
    if extra_env:
        process_env.update(extra_env)
    completed_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=process_env)
    if check and completed_process.returncode != 0:
        logging.critical(f"Command failed with return code {completed_process.returncode}")
        logging.critical("stdout: <REDACTED>" if hide_stdout else f"stdout: {completed_process.stdout}")
        logging.critical("stderr: <REDACTED>" if hide_stderr else f"stderr: {completed_process.stderr}")
        sys.exit(1)
    if not check:
        logging.info(f"returncode: {completed_process.returncode}")
    logging.info("stdout: <REDACTED>" if hide_stdout else f"stdout: {completed_process.stdout}")
    logging.info("stderr: <REDACTED>" if hide_stderr else f"stderr: {completed_process.stderr}")
    return completed_process


def _parse_tags(raw_tags):
    """Parse a comma separated ``key[=value]`` list into a dict.

    Empty entries (the default ``""`` or stray commas) are skipped, and only
    the first ``=`` separates key from value so values may contain ``=``.
    """
    tags = {}
    for entry in raw_tags.split(","):
        if not entry:
            continue
        key, _, value = entry.partition("=")
        tags[key] = value
    return tags


def _tags_args(tags):
    """Return the az CLI ``--tags`` arguments for ``tags``, or [] when there are none."""
    if not tags:
        return []
    return ["--tags"] + format_tags_for_azure_cli(tags)


def _strip_query(url):
    """Return ``url`` without query string and fragment (drops a SAS token, if any)."""
    return urlunsplit(urlsplit(url)._replace(query="", fragment=""))


def _add_query_param(url, param):
    """Return ``url`` with ``param`` appended to its query, keeping an existing query such as a SAS token."""
    parts = urlsplit(url)
    query = f"{parts.query}&{param}" if parts.query else param
    return urlunsplit(parts._replace(query=query))


def _resolve_image_attribute(option_name, value, blob_metas):
    """Resolve an --azure-* option value, reading blob metadata NAME when formatted 'metadata:NAME'.

    Exits with status 1 when the referenced metadata header is absent.
    """
    prefix = "metadata:"
    if not value.startswith(prefix):
        return value
    header = "x-ms-meta-{}".format(value[len(prefix):])
    resolved = blob_metas.get(header)
    if resolved is None:
        logging.critical("Blob metadata %s required by %s is missing on the source VHD.", header, option_name)
        sys.exit(1)
    return resolved


def main():
    """Copy a VHD into an Azure SIG image version or a managed image.

    Steps:
    1. Parse the CLI arguments and optionally run ``az login``.
    2. Check whether the target image already exists. It is deleted when it
       is Failed or ``--force`` is given; otherwise the script waits for it
       or stops.
    3. In gallery mode, create the gallery and image definition if they are
       missing.
    4. Upload the VHD into a temporary managed disk with azcopy.
    5. Create the image from that disk.
    6. Delete the disk unless ``--keep`` is given.

    Any failing checked command terminates the process with status 1
    (``execute_command`` behavior).
    """
    logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', level=logging.INFO)

    # Parse args
    parser = argparse.ArgumentParser(
            description="Copies a public VHD to a Shared Image Gallery image version",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            )
    parser.add_argument("vhd", help="Source VHD Url")
    parser.add_argument("target_image", help="Full resource id of the image: /subscriptions/SUBSCRIPTION_ID/resourceGroups/RESOURCE_GROUP/providers/Microsoft.Compute/galleries/GALLERY_NAME/images/IMAGE_DEFINITION/versions/VERSION or /subscriptions/SUBSCRIPTION_ID/resourceGroups/RESOURCE_GROUP/providers/Microsoft.Compute/images/IMAGE_NAME")
    parser.add_argument("--locations", required=True, help="Comma separated list of target locations for SIG replicas. First one is the primary. At least one required")
    parser.add_argument("--auth-from-fm-config", default=None, help="Reads a FM config file ")
    parser.add_argument("--auth-client-id", default=None, help="Client ID to auth with a service account")
    parser.add_argument("--auth-client-secret", default=None, help="Client password. We recommend using a file and give the path prefixed by @.")
    parser.add_argument("--auth-client-certificate-path", default=None, help="Client certificate file for az cli auth")
    parser.add_argument("--subscription", default=None, help="Azure subscription")
    parser.add_argument("--auth-tenant", default=None, help="Tenant for az cli auth")
    parser.add_argument("--auth-managed-identity", default=None, help="Use the managed identity for az login. Value mandatory like it is for FM.")
    parser.add_argument("--azcopy-bin", default="azcopy", help="Path to the azcopy binary")
    parser.add_argument("--azcopy-skip-check", action='store_true', default=False, help="Ignore azcopy checks before exec.")
    parser.add_argument("--azure-publisher", default="metadata:AzureSIGPublisher", help="Value of publisher for image definition. Read from blob metadata value NAME if formatted 'metadata:NAME'.")
    parser.add_argument("--azure-offer", default="metadata:AzureSIGOffer", help="Value of offer for image definition. Read from blob metadata value NAME if formatted 'metadata:NAME'.")
    parser.add_argument("--azure-sku", default="metadata:AzureSIGSKU", help="Value of SKU for image definition. Read from blob metadata value NAME if formatted 'metadata:NAME'.")
    parser.add_argument("--azure-conf-dir", default=None, help="Force where to store temporary creds. Otherwise use a temp directory.")
    parser.add_argument("--tags", default="", help="Comma separated list of key[=value] of tags to add on created resources.")
    parser.add_argument("--temporary-disk-name", default=None, help="Name to give the temporary managed disk.")
    parser.add_argument("--temporary-resource-group", default=None, help="Resource group where to create temporary resources. Defaults to the resource group of the target image.")
    parser.add_argument("--disk-grant-duration", type=int, default=86400, help="Duration in seconds for the disk SAS token grant.")
    parser.add_argument("--keep", action='store_true', default=False, help="Keep temporary resources instead of cleaning them up after exec.")
    parser.add_argument("--force", action='store_true', default=False, help="Delete and recreates the image if already exists.")
    parser.add_argument("--proxy", default=None, help="Proxy to use for HTTP requests.")
    parser.add_argument("--verbosity", type=int, default=4, choices=range(0,6), help="Verbosity level.")
    args = parser.parse_args()
    # Verbosity 0..5 maps to logging levels 60 (nothing) .. 10 (DEBUG).
    logging.getLogger().setLevel(10 * (6 - args.verbosity))
    azure_locations = [location for location in args.locations.split(",") if location]
    if not azure_locations:
        logging.critical("At least one location is required in --locations.")
        sys.exit(1)
    tags = _parse_tags(args.tags)

    # The proxy is exported to child processes via env below; urllib needs it explicitly.
    if args.proxy:
        http_opener = urllib.request.build_opener(
            urllib.request.ProxyHandler({"http": args.proxy, "https": args.proxy}))
    else:
        http_opener = urllib.request.build_opener()

    with TemporaryDirectory() as azure_conf_temp_dir:
        azure_conf_dir = args.azure_conf_dir or azure_conf_temp_dir
        common_extra_env = {}

        if args.proxy:
            common_extra_env.update({
                "HTTP_PROXY": args.proxy,
                "HTTPS_PROXY": args.proxy
            })

        # Manage login
        managed_identity = args.auth_managed_identity
        client_id = args.auth_client_id
        client_secret = args.auth_client_secret
        client_certificate_path = args.auth_client_certificate_path
        tenant = args.auth_tenant
        if args.auth_from_fm_config:
            logging.info("Extract Azure auth info from FM config")
            with open(args.auth_from_fm_config, "r") as fm_config_file:
                fm_config = json.load(fm_config_file)
            client_id = fm_config["azureSettings"].get("clientId", None)
            client_secret = fm_config["azureSettings"].get("secret", None)
            client_certificate_path = fm_config["azureSettings"].get("certificatePath", None)
            tenant = fm_config["azureSettings"].get("tenantId", None)
            managed_identity = fm_config["azureSettings"].get("managedIdentityId", None)
        if managed_identity:
            logging.info("Run azure login with managed identity.")
            common_extra_env.update({"AZURE_CONFIG_DIR": azure_conf_dir})
            logging.debug("Using %s for AZURE_CONFIG_DIR", azure_conf_dir)
            execute_command(["az", "login", "--identity", "--resource-id", managed_identity], extra_env=common_extra_env)
        elif client_id and tenant:
            if client_secret:
                with NamedTemporaryFile() as secret_file:
                    # Pass a literal secret through a file (@path) so it never appears in argv or logs.
                    if client_secret[0] != '@':
                        secret_file.write(client_secret.encode('utf-8'))
                        secret_file.flush()
                        client_secret = f"@{secret_file.name}"
                    logging.info("Run azure login with client secret.")
                    common_extra_env.update({"AZURE_CONFIG_DIR": azure_conf_dir})
                    logging.debug("Using %s for AZURE_CONFIG_DIR", azure_conf_dir)
                    execute_command(["az","login","--service-principal", "--username", client_id, "--password", client_secret, "--tenant", tenant], extra_env=common_extra_env)
            elif client_certificate_path:
                logging.info("Run azure login with client certificate.")
                common_extra_env.update({"AZURE_CONFIG_DIR": azure_conf_dir})
                logging.debug("Using %s for AZURE_CONFIG_DIR", azure_conf_dir)
                execute_command(["az","login","--service-principal", "--username", client_id, "--password", client_certificate_path, "--tenant", tenant], extra_env=common_extra_env)
            else:
                logging.critical("Authentication with App Id requires at least a secret or a certificate.")
                sys.exit(1)

        # Parse the target resource id; element 7 is the resource type under Microsoft.Compute.
        target_image_args = args.target_image.split("/")
        deployment_mode = target_image_args[7] if len(target_image_args) > 7 else None
        if deployment_mode == "galleries" and len(target_image_args) == 13:
            _,_,subscription_id,_,resource_group,_,_,_,gallery_name,_,image_definition,_,image_version = target_image_args
        elif deployment_mode == "images" and len(target_image_args) == 9:
            _,_,subscription_id,_,resource_group,_,_,_,image_name = target_image_args
        elif deployment_mode in ("galleries", "images"):
            logging.critical("Malformed target image resource id: %s", args.target_image)
            sys.exit(1)
        else:
            logging.critical("Deployment mode %s unknown", deployment_mode)
            sys.exit(1)

        # Get the blob size and metas
        req = urllib.request.Request(args.vhd, method="HEAD")
        with http_opener.open(req) as response:
            blob_length = response.headers["Content-Length"]
        if not blob_length:
            logging.critical("Could not determine the size of %s", _strip_query(args.vhd))
            sys.exit(1)
        # Keep any SAS token already present in the URL when asking for metadata.
        req = urllib.request.Request(_add_query_param(args.vhd, "comp=metadata"), method="GET")
        with http_opener.open(req) as response:
            blob_metas = response.headers

        subscription = []
        if args.subscription is not None:
            subscription = ["--subscription", args.subscription]

        if deployment_mode == "galleries":
            # Check target image and handle previous existence
            completed_process = execute_command(["az", "sig", "image-version", "show",
                "--resource-group", resource_group,
                "--gallery-name",  gallery_name,
                "--gallery-image-definition", image_definition,
                "--gallery-image-version", image_version,
                "--output", "json"
                ] + subscription, check=False, extra_env=common_extra_env)
            if completed_process.returncode == 0:
                logging.info("Image already exists, check provisioning state.")
                image_infos = json.loads(completed_process.stdout)
                if args.force or image_infos["provisioningState"] == "Failed":
                    logging.warning("Image already exists in Failed state, or --force have been passed. Delete it.")
                    execute_command(["az", "sig", "image-version", "delete",
                        "--resource-group", resource_group,
                        "--gallery-name",  gallery_name,
                        "--gallery-image-definition", image_definition,
                        "--gallery-image-version", image_version,
                        "--output", "json",
                    ] + subscription, extra_env=common_extra_env)
                elif image_infos["provisioningState"] != "Succeeded":
                    logging.warning("Image already is creating, wait for it.")
                    execute_command(["az", "sig", "image-version", "wait",
                        "--resource-group", resource_group,
                        "--gallery-name",  gallery_name,
                        "--gallery-image-definition", image_definition,
                        "--gallery-image-version", image_version,
                        "--output", "json",
                        "--created",
                    ] + subscription, extra_env=common_extra_env)
                    logging.warning("Image is ready.")
                    sys.exit(0)
                else:
                    logging.info("Image already exists, do nothing.")
                    sys.exit(0)
            # TODO: a non-zero return code is assumed to mean "not found";
            # check the structured error (ResourceNotFound) instead.

            # Check image gallery
            completed_process = execute_command(["az", "sig", "show",
                "--resource-group", resource_group,
                "--gallery-name",  gallery_name,
                "--output", "json",
                ] + subscription, check=False, extra_env=common_extra_env)
            if 0 != completed_process.returncode:
                # TODO: Make real check with structured error if possible
                logging.info("Create image gallery %s", gallery_name)
                execute_command(["az", "sig", "create",
                    "--resource-group", resource_group,
                    "--gallery-name",  gallery_name,
                    "--location", azure_locations[0],
                    "--output", "json",
                    ] + subscription + _tags_args(tags),
                    extra_env=common_extra_env
                )
            else:
                logging.info("Image gallery %s exists", gallery_name)

            # Check image definition
            completed_process = execute_command(["az", "sig", "image-definition", "show",
                "--resource-group", resource_group,
                "--gallery-name",  gallery_name,
                "--gallery-image-definition",  image_definition,
                "--output", "json",
                ] + subscription, check=False, extra_env=common_extra_env)
            if 0 != completed_process.returncode:
                # TODO: Make real check with structured error if possible
                logging.info("Create image definition %s", image_definition)
                publisher = _resolve_image_attribute("--azure-publisher", args.azure_publisher, blob_metas)
                offer = _resolve_image_attribute("--azure-offer", args.azure_offer, blob_metas)
                sku = _resolve_image_attribute("--azure-sku", args.azure_sku, blob_metas)
                execute_command(["az", "sig", "image-definition", "create",
                    "--resource-group", resource_group,
                    "--gallery-name",  gallery_name,
                    "--gallery-image-definition",  image_definition,
                    "--os-type","Linux",
                    "--publisher", publisher,
                    "--hyper-v-generation", "V2",
                    "--offer", offer,
                    "--sku", sku,
                    "--output", "json"
                    ] + subscription + _tags_args(tags),
                    extra_env=common_extra_env,
                )
            else:
                logging.info("Image definition %s exists", image_definition)

        if deployment_mode == "images":
            # Check target image and handle previous existence
            completed_process = execute_command(["az", "image", "show",
                "--resource-group", resource_group,
                "--name", image_name,
                "--output", "json"
                ] + subscription, check=False, extra_env=common_extra_env)
            if completed_process.returncode == 0:
                logging.info("Image already exists, check provisioning state.")
                image_infos = json.loads(completed_process.stdout)
                if args.force or image_infos["provisioningState"] == "Failed":
                    logging.warning("Image already exists in Failed state, or --force have been passed. Delete it.")
                    execute_command(["az", "image", "delete",
                        "--resource-group", resource_group,
                        "--name", image_name,
                        "--output", "json",
                    ] + subscription, extra_env=common_extra_env)
                else:
                    logging.info("Image already exists, do nothing.")
                    sys.exit(0)
            # TODO: a non-zero return code is assumed to mean "not found";
            # check the structured error (ResourceNotFound) instead.

        # Check temporary disk and handle previous existence
        disk_resource_group = args.temporary_resource_group or resource_group
        disk_name = args.temporary_disk_name or "vhd_cache_{}".format(hashlib.sha1(args.vhd.encode('utf-8')).hexdigest())
        # Sanitize the VHD url to remove SAS token from tag if any are present
        disk_tags = {**tags, "SourceVhd": _strip_query(args.vhd)}
        logging.info("Check if disk %s exists", disk_name)
        # All disk operations target disk_resource_group, where the disk is created.
        completed_process = execute_command(["az", "disk", "show",
            "--resource-group", disk_resource_group,
            "--name",  disk_name,
            "--output", "json"
            ] + subscription, check=False, extra_env=common_extra_env)
        if completed_process.returncode == 0:
            logging.warning("Disk already exists, delete it before we proceed.")
            execute_command(["az", "disk", "delete",
                "--resource-group", disk_resource_group,
                "--name",  disk_name,
                "--output", "json",
                "--yes",
            ] + subscription, extra_env=common_extra_env)
        # TODO: a non-zero return code is assumed to mean "not found";
        # check the structured error (ResourceNotFound) instead.

        # Create new disk for upload
        completed_process = execute_command(["az", "disk", "create",
            "--resource-group", disk_resource_group,
            "--name",  disk_name,
            "--location", azure_locations[0],
            "--for-upload",
            "--upload-size-bytes", blob_length,
            "--sku", "standard_lrs",
            "--output", "json",
            "--hyper-v-generation", "V2",
            ] + subscription + _tags_args(disk_tags),
            extra_env=common_extra_env,
        )
        disk_infos = json.loads(completed_process.stdout)

        # Grant access to the disk (the returned SAS is secret, hence hide_stdout)
        completed_process = execute_command(["az", "disk", "grant-access",
            "--resource-group", disk_resource_group,
            "--name",  disk_name,
            "--access-level", "Write",
            "--duration-in-seconds", str(args.disk_grant_duration),
            "--output", "json",
            ] + subscription,
            hide_stdout=True,
            extra_env=common_extra_env,
        )
        access_sas = json.loads(completed_process.stdout)["accessSAS"]

        # Run azcopy; log the URLs without their SAS tokens.
        azcopy_command = [args.azcopy_bin, "copy", args.vhd, access_sas, "--blob-type", "PageBlob"]
        azcopy_command_redacted = [args.azcopy_bin, "copy", _strip_query(args.vhd), _strip_query(access_sas), "--blob-type", "PageBlob"]
        execute_command(azcopy_command, redacted_command=azcopy_command_redacted, extra_env=common_extra_env)

        # Revoke disk upload access. This is REQUIRED or the image creation will fail.
        execute_command(["az", "disk", "revoke-access",
            "--resource-group", disk_resource_group,
            "--name",  disk_name,
            "--output", "json",
        ] + subscription, extra_env=common_extra_env)

        if deployment_mode == "galleries":
            # Create image
            # TODO: A lot of refinement is possible here on image replication and encryption. Ignored for now.
            execute_command(["az", "sig", "image-version", "create",
                "--resource-group", resource_group,
                "--gallery-name",  gallery_name,
                "--gallery-image-definition",  image_definition,
                "--gallery-image-version", image_version,
                "--os-snapshot", disk_infos["id"],
                "--location", azure_locations[0],
                "--output", "json",
                ] + subscription
                + _tags_args(disk_tags)
                + ["--target-regions"] + azure_locations,
                extra_env=common_extra_env,
            )

        if deployment_mode == "images":
            execute_command(["az", "image", "create",
                "--resource-group", resource_group,
                "--source", disk_infos["id"],
                "--location", azure_locations[0],
                "--name", image_name,
                "--os-type", "Linux",
                "--hyper-v-generation", "V2",
                "--output", "json",
                ] + subscription
                + _tags_args(disk_tags),
                extra_env=common_extra_env,
            )

        # Cleanup
        if not args.keep:
            logging.warning("Cleanup temporary disk")
            execute_command(["az", "disk", "delete",
                "--resource-group", disk_resource_group,
                "--name",  disk_name,
                "--output", "json",
                "--yes",
            ] + subscription, extra_env=common_extra_env)

    logging.info("Done.")

if __name__ == "__main__":
    # Entry point only when executed as a script, not when imported.
    main()

