#!/usr/bin/env python3

# ===============================================================================
# Copyright (c) 2023 PTC Inc., Its Subsidiary Companies, and /or its Partners.
# All Rights Reserved.
#
# Vuforia is a trademark of PTC Inc., registered in the United States and other
# countries.
# ===============================================================================

import logging
import argparse
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import threading
import itertools
import time
import pprint
from typing import Optional, Dict, List

import requests
import requests.auth
import tqdm
from requests import Response

logger = logging.getLogger(__name__)


class VuforiaWebServiceClient:
    """Minimal client for the Vuforia Web Services (VWS) REST API.

    Obtains a JWT through the OAuth2 client-credentials grant and sends it as a
    bearer token with every request made through :meth:`request`.
    """

    # Default headers for every request. This class attribute is only a template:
    # ``__init__`` gives each instance its own copy, so the Authorization header
    # added by ``login`` never leaks from one client into another.
    headers: Dict[str, str] = {"User-Agent": "VuforiaWebServiceClientPython/1.0"}

    def __init__(self, api_base: str, client_id: str, client_secret: str):
        """
        :param api_base: base URL of the API. Routes are appended as ``f"{api_base}/{route}"``,
            so it should not end with a slash.
        :param client_id: OAuth2 client id.
        :param client_secret: OAuth2 client secret.
        """
        self.api_base: str = api_base
        self.client_id: str = client_id
        self.client_secret: str = client_secret
        # Per-instance copy. Mutating the shared class-level dict would make a second
        # client skip login() and reuse the first client's token.
        self.headers: Dict[str, str] = dict(type(self).headers)

    def login(self) -> None:
        """Fetch a JWT and store it as this client's ``Authorization`` header.

        :raises requests.HTTPError: if the token endpoint answers with a 4xx/5xx status
            (for a 401 the server's response text is printed first).
        """
        logger.debug("Fetching JWT...")
        jwt_response = requests.post(
            f"{self.api_base}/oauth2/token",
            auth=requests.auth.HTTPBasicAuth(self.client_id, self.client_secret),
            data={"grant_type": "client_credentials"},
        )
        logger.debug(jwt_response)
        if jwt_response.status_code == 401:
            print(f"Invalid credentials: {jwt_response.text}")
        jwt_response.raise_for_status()

        jwt = jwt_response.json()["access_token"]
        self.headers["Authorization"] = f"Bearer {jwt}"

    def request(self, method: str, route: str, **kwargs) -> requests.Response:
        """Send an authenticated request to ``{api_base}/{route}``.

        Logs in first if no token is held yet. If the API answers 401 (expired JWT),
        logs in again and repeats the request once. The response is returned without
        a status check.

        NOTE(review): the retry re-sends the same ``kwargs``; a one-shot streamed body
        passed as ``data`` would already be consumed by the first attempt.

        :param method: HTTP method name, e.g. ``"get"``.
        :param route: path relative to ``api_base``, without a leading slash.
        :param kwargs: forwarded to ``requests.request`` (e.g. ``json``, ``params``).
        """
        if "Authorization" not in self.headers:
            self.login()

        url = f"{self.api_base}/{route}"
        response = requests.request(method=method, url=url, headers=self.headers, **kwargs)

        # 401 response means the JWT expired
        if response.status_code == 401:
            self.login()
            response = requests.request(method=method, url=url, headers=self.headers, **kwargs)

        return response

    def post(self, route: str, **kwargs) -> requests.Response:
        """Authenticated POST; see :meth:`request`."""
        return self.request("post", route, **kwargs)

    def put(self, route: str, **kwargs) -> requests.Response:
        """Authenticated PUT; see :meth:`request`."""
        return self.request("put", route, **kwargs)

    def delete(self, route: str, **kwargs) -> requests.Response:
        """Authenticated DELETE; see :meth:`request`."""
        return self.request("delete", route, **kwargs)

    def get(self, route: str, **kwargs) -> requests.Response:
        """Authenticated GET; see :meth:`request`."""
        return self.request("get", route, **kwargs)


class AreaTargetWebAPIClient:
    """Client for the Area Target endpoints (``areatargets/datasets``) of the Vuforia Web Services API.

    Authenticated calls go through the wrapped :class:`VuforiaWebServiceClient`. The part-upload
    and asset-download URLs returned by the API are requested directly with ``requests``,
    without the VWS headers.
    """

    def __init__(self, vws_client: VuforiaWebServiceClient, max_parallel_connections: int = 2, max_tries: int = 3):
        """
        :param vws_client: client used for all authenticated API calls.
        :param max_parallel_connections: number of threads uploading file parts concurrently.
        :param max_tries: attempts per part upload before giving up (at least one attempt is always made).
        """
        self.vws_client = vws_client
        self.max_parallel_connections = max_parallel_connections
        self.max_tries = max_tries

    class AreaTarget:
        """Attribute-style view of an Area Target JSON object returned by the API.

        Every key of the source dict becomes an instance attribute. The class-level ``None``
        defaults apply when the response omits the optional fields.
        """
        id: str
        name: str
        status: str
        generation_progress: Optional[int] = None
        # NOTE(review): when a response is parsed with ``json(object_hook=AreaTarget)``, nested
        # objects such as ``error``/``warning`` become AreaTarget instances rather than dicts.
        error: Optional[dict] = None
        warning: Optional[dict] = None

        def __init__(self, d: Dict[str, object]):
            self.__dict__.update(d)

        def __repr__(self):
            return pprint.pformat(self.__dict__)

    def _upload_part(self, source_file: Path, upload_url: str, part_offset: int, part_size: int,
                     progress: tqdm.tqdm, progress_lock: threading.Lock) -> None:
        """Upload bytes ``[part_offset, part_offset + part_size)`` of *source_file* with an HTTP PUT.

        A failed attempt (a status other than 200, or a ``requests`` exception) is retried with
        exponential back-off starting at 5 seconds, up to ``self.max_tries`` attempts in total.
        At least one attempt is always made. Bytes sent by a failed attempt are subtracted from
        *progress* again.

        :raises Exception: if the final attempt returns a status other than 200.
        :raises requests.exceptions.RequestException: if the final attempt fails at the transport level.
        """

        class ReadPartWithProgress:
            """Read-only, length-bounded view of one part that advances the shared progress bar."""

            def __init__(self, file, size: int):
                self._file = file
                self.remainder = size

            def read(self, size=-1):
                # never read past the end of this part
                if size is None or size < 0:
                    size = self.remainder
                data = self._file.read(min(size, self.remainder))
                with progress_lock:
                    progress.update(len(data))
                self.remainder -= len(data)
                assert self.remainder >= 0
                return data

            def __len__(self):
                # requests sizes the request body via len(); report the full part size
                return part_size

        attempts = max(1, self.max_tries)
        sleep_seconds = 5
        with open(source_file, "rb") as f:
            for attempt in range(1, attempts + 1):
                is_last_attempt = attempt == attempts
                f.seek(part_offset)
                part_reader = ReadPartWithProgress(f, part_size)
                try:
                    upload_response = requests.put(upload_url, data=part_reader)
                except requests.exceptions.RequestException:
                    if is_last_attempt:
                        raise
                else:
                    if upload_response.status_code == 200:
                        return
                    if is_last_attempt:
                        raise Exception(f"Failed to upload part with status code: {upload_response.status_code}, "
                                        f"response: {upload_response.text}")

                # reset back the progress as we need to upload the part again
                bytes_uploaded = part_size - part_reader.remainder
                with progress_lock:
                    progress.n -= bytes_uploaded
                    progress.refresh()
                logger.warning(f"Part upload failed. Will try again in {sleep_seconds} seconds. "
                               f"Attempts: {attempt}. {bytes_uploaded} bytes need to be uploaded again.")
                time.sleep(sleep_seconds)
                sleep_seconds *= 2

    def create_target(self, source_file: Path, target_name: str = "", target_sdk: str = "10.11") -> AreaTarget:
        """Create an Area Target, upload *source_file* in parallel parts, and complete the upload.

        :param source_file: source file to upload. Its suffix (without the dot) is sent as ``source_format``.
        :param target_name: name of the target. When empty, it is derived from the file stem with
            ``_authoring`` and dots removed.
        :param target_sdk: ``target_sdk`` value sent with the creation request.
        :returns: the target returned by the completion request (also pretty-printed to stdout).
        :raises requests.HTTPError: if the creation request is rejected.
        :raises ValueError: if the API returns no upload URLs.
        """
        # derive the target name from the source filename if not provided
        if not target_name:
            target_name = source_file.stem.replace("_authoring", "").replace(".", "")

        logger.info(f"Creating Area Target for {source_file}...")
        source_size = source_file.stat().st_size
        create_response = self.vws_client.post(
            "areatargets/datasets",
            json={"name": target_name,
                  "source_format": source_file.suffix.strip("."),
                  "source_size": source_size,
                  "target_sdk": target_sdk}
        )
        if not create_response.ok:
            # surface the server's explanation; HTTPError alone only carries the status and URL
            logger.error(f"Failed to create Area Target: {create_response.text}")
        create_response.raise_for_status()
        create_at_response_json = create_response.json()
        logger.debug(create_at_response_json)
        at_id = create_at_response_json["id"]
        upload_urls = create_at_response_json["upload_urls"]
        if not upload_urls:
            raise ValueError(f"No upload URLs returned for Area Target {at_id}")
        logger.info(f"Area Target {at_id} created.")

        logger.info(f"Uploading source file {source_file}...")
        progress_lock = threading.Lock()
        with tqdm.tqdm(total=source_size, unit="B", unit_scale=True) as progress:
            # Equally split the file between parts and ensure that the last part contains the whole remainder
            quotient, remainder = divmod(source_size, len(upload_urls))
            part_sizes = [quotient] * len(upload_urls)
            part_sizes[-1] += remainder
            with ThreadPoolExecutor(max_workers=self.max_parallel_connections) as thread_pool:
                futures = [
                    thread_pool.submit(self._upload_part, source_file, upload_url,
                                       part_offset, part_size, progress, progress_lock)
                    for upload_url, part_offset, part_size
                    in zip(upload_urls, itertools.count(0, quotient), part_sizes)
                ]
                # raise exceptions from the part uploads if any
                for future in futures:
                    future.result()

        logger.info("Completing upload...")
        complete_response = self.vws_client.put(f"areatargets/datasets/{at_id}").json(object_hook=self.AreaTarget)
        pprint.pprint(complete_response)
        return complete_response

    def delete_target(self, target_id: str) -> bool:
        """Delete a target and its artifacts; return True iff the API answered 204 No Content."""
        response = self.vws_client.delete(f"areatargets/datasets/{target_id}")
        pprint.pprint(response)
        return response.status_code == 204

    def list_targets(self, targets: Optional[List[AreaTarget]] = None, cursor: Optional[str] = None) \
            -> List[AreaTarget]:
        """Return all Area Targets, fetching every page of a paginated response.

        The endpoint may answer with a plain JSON list, or with a page object ``{"data": [...]}``
        that carries a non-empty ``"cursor"`` while more pages are available.

        :param targets: targets already collected. They come first in the result; the list passed
            in is not modified.
        :param cursor: pagination cursor to start from. When empty, the first page is fetched.
        :returns: the collected targets (also pretty-printed to stdout).
        """
        collected: List[AreaTargetWebAPIClient.AreaTarget] = list(targets) if targets else []
        while True:
            # Pass the cursor through ``params`` so requests URL-encodes it. Opaque cursors may
            # contain characters such as '+', '&' or '=' that a raw query string would misread.
            response = self.vws_client.get("areatargets/datasets",
                                           params={"cursor": cursor} if cursor else None)
            payload = response.json()

            if isinstance(payload, list):
                # endpoint is not paginated
                collected += response.json(object_hook=self.AreaTarget)
                break

            collected += [self.AreaTarget(target) for target in payload["data"]]
            cursor = payload.get("cursor")
            if not cursor:
                # a missing or empty cursor marks the last page; refetching without one would restart at page 1
                break

        pprint.pprint(collected)
        return collected

    def target_status(self, target_id: str) -> AreaTarget:
        """Fetch, pretty-print and return the status object of a target."""
        response = self.vws_client.get(f"areatargets/datasets/{target_id}/status").json(object_hook=self.AreaTarget)
        pprint.pprint(response)
        return response

    def download_target(self, target_id: str, output_directory: Path,
                        download_authoring: bool, download_delivery: bool) -> None:
        """Download the selected asset groups of a target into *output_directory*.

        Each asset listed by the API (a dict with ``filename`` and ``url``) is streamed to
        ``output_directory / filename`` with a progress bar. The directory is created if needed.

        :raises requests.HTTPError: if listing the assets or downloading one of them fails.
        """
        assets: List[Dict[str, str]] = []
        for asset_type, download in [("authoring", download_authoring), ("delivery", download_delivery)]:
            if download:
                assets_response = self.vws_client.get(f"areatargets/datasets/{target_id}/{asset_type}/assets")
                assets_response.raise_for_status()
                assets.extend(assets_response.json())
        logger.info(assets)

        output_directory.mkdir(parents=True, exist_ok=True)
        for asset in assets:
            # NOTE(review): filename comes from the server; an absolute path or ".." would escape output_directory.
            filename = asset["filename"]
            # stream=True avoids loading the asset into memory; the with-block closes the response
            with requests.get(asset["url"], stream=True) as response:
                response.raise_for_status()
                file_size = int(response.headers.get("Content-Length", 0))
                with open(output_directory / filename, "wb") as f, \
                        tqdm.tqdm(desc=filename, total=file_size, unit='B', unit_scale=True) as progress:
                    for data in response.iter_content(chunk_size=10 * 1024):
                        progress.update(f.write(data))


def main() -> None:
    """Parse the command line and run the requested Area Target Web API command."""

    def parse_bool(value: str) -> bool:
        """Parse a boolean option value such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``.

        ``type=bool`` must not be used for this: ``bool("False")`` is ``True``, so every
        non-empty value would switch the option on.
        """
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r} (expected true or false)")

    parser = argparse.ArgumentParser(description="Example client for Area Target Web API")
    parser.add_argument("--api_url", type=str, default="https://vws.vuforia.com",
                        help="URL of the Area Target Web API")
    parser.add_argument("--client_id", type=str, help="OAuth2 client id", required=True)
    parser.add_argument("--client_secret", type=str, help="OAuth2 client secret", required=True)
    parser.add_argument("--max_parallel_connections", type=int, default=2,
                        help="Maximum number of parallel connections the client can make to speed up uploads of large "
                             "files.")
    parser.add_argument("--max_tries", type=int, default=3,
                        help="Maximum number of attempts the client can make when re-trying upload of data.")
    parser.add_argument("--verbose", "-v", action="count", default=0,
                        help="Enable detailed logging. Repeat the option for more verbose logging.")

    subparsers = parser.add_subparsers(dest="command", help="Action to execute on the API")
    create_command = subparsers.add_parser("create", help="Create a Cloud Area Target")
    create_command.add_argument("source_file", type=Path, help="Path to the source 3DT or E57 file.")
    create_command.add_argument("--target_name", type=str, default="", help="Name of the created target")
    create_command.add_argument("--target_sdk", type=str, default="10.11",
                                help="Target SDK value for the created target")

    delete_command = subparsers.add_parser("delete", help="Deletes Cloud Area Target with all the associated artifacts")
    delete_command.add_argument("target_id", type=str, help="Id of the target to delete")

    subparsers.add_parser("list", help="List all stored Cloud Area Targets")

    status_command = subparsers.add_parser("status", help="Shows status of a given Cloud Area Target")
    status_command.add_argument("target_id", type=str, help="Id of the target to show")

    download_command = subparsers.add_parser("download", help="Download assets for a given Cloud Area Target")
    download_command.add_argument("target_id", type=str, help="Id of the target to download")
    download_command.add_argument("output_directory", type=Path,
                                  help="Directory to which the downloaded assets will be stored")
    download_command.add_argument("--download_authoring", type=parse_bool, default=True, metavar="{true,false}",
                                  help="Whether to download authoring assets (authoring.3dt, unitypackage, "
                                       "navmesh.glb)")
    download_command.add_argument("--download_delivery", type=parse_bool, default=True, metavar="{true,false}",
                                  help="Whether to download delivery assets (dat, xml, occlusion.3dt)")

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    # -v enables INFO logging; -vv (or more) enables DEBUG
    if args.verbose:
        logging.basicConfig(level=logging.INFO if args.verbose == 1 else logging.DEBUG)

    vws_client = VuforiaWebServiceClient(args.api_url, args.client_id, args.client_secret)
    client = AreaTargetWebAPIClient(vws_client, args.max_parallel_connections, args.max_tries)

    if args.command == "create":
        client.create_target(args.source_file, args.target_name, args.target_sdk)
    elif args.command == "delete":
        client.delete_target(args.target_id)
    elif args.command == "list":
        client.list_targets()
    elif args.command == "status":
        client.target_status(args.target_id)
    elif args.command == "download":
        client.download_target(args.target_id, args.output_directory,
                               args.download_authoring, args.download_delivery)


if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a script,
    # not when it is imported as a module.
    main()
