import traceback
from azure.storage.blob import BlobServiceClient
from scripts.common.config_parser import *
from scripts.common.constants import GetDataFromStoreConstants, ComponentExceptions
from scripts.common.logsetup import logger


class AzureDownload:
    """Download a single blob from an Azure Storage container to a local file.

    ``query`` is a mapping that must provide:

    * ``GetDataFromStoreConstants.CONTAINER_NAME`` -- name of the container.
    * ``'connection_string'`` -- Azure Storage connection string.
    * ``GetDataFromStoreConstants.AZURE_FILE_PATH`` -- blob name inside the
      container (read by :meth:`download_from_blob`).
    * ``GetDataFromStoreConstants.LOCAL_FILE_PATH`` -- destination path on
      local disk (read by :meth:`download_from_blob`).
    """

    def __init__(self, query):
        self.query = query
        self.container = self.query[GetDataFromStoreConstants.CONTAINER_NAME]
        self.connection_string = self.query['connection_string']
        self.blob_service_client = BlobServiceClient.from_connection_string(self.connection_string)
        # Not used by download_from_blob (which goes through the service
        # client); kept as a public attribute for any external users.
        self.container_client = self.blob_service_client.get_container_client(self.container)

    def download_from_blob(self):
        """Download the configured blob to the configured local path.

        Returns:
            True once the blob has been written to the local file.

        Raises:
            Exception: wrapping (and chained from) any error raised while
                reading the query, requesting the blob, or writing the file.
        """
        try:
            azure_file_path = self.query[GetDataFromStoreConstants.AZURE_FILE_PATH]
            local_file_path = self.query[GetDataFromStoreConstants.LOCAL_FILE_PATH]
            logger.info("Creating Connection........")
            blob_client = self.blob_service_client.get_blob_client(container=self.container,
                                                                   blob=azure_file_path)
            # Issue the download request BEFORE opening (and truncating) the
            # local file, so a missing blob or an auth failure does not leave
            # an empty file behind at the destination.
            downloader = blob_client.download_blob()
            logger.info(
                f"Started downloading file to path {local_file_path} from path "
                f"{azure_file_path} on azure")
            with open(local_file_path, "wb") as download_file:
                # Stream the blob into the file chunk by chunk instead of
                # holding the whole payload in memory with readall().
                downloader.readinto(download_file)
            return True
        except Exception as e:
            # Keep the plain Exception type callers already catch, but chain
            # the original error so its type and traceback are not lost.
            raise Exception(e) from e


if __name__ == '__main__':
    # Entry point of the Azure download component: validate the required
    # config keys, build the download query, and fetch the artifact into
    # config['component_output_dir'].
    import os
    import posixpath
    import sys

    try:
        if GetDataFromStoreConstants.ARTIFACT_BASE_PATH not in config.keys():
            raise Exception(ComponentExceptions.INVALID_ARTIFACT_BASE_PATH_EXCEPTION)
        azure_file_path = config[GetDataFromStoreConstants.ARTIFACT_BASE_PATH]

        if GetDataFromStoreConstants.ARTIFACT_NAME not in config.keys():
            raise Exception(ComponentExceptions.INVALID_AZURE_FILE_NAME_EXCEPTION)
        azure_file_name = config[GetDataFromStoreConstants.ARTIFACT_NAME]

        if GetDataFromStoreConstants.CONTAINER_NAME not in config.keys():
            raise Exception(ComponentExceptions.INVALID_CONTAINER_NAME)
        azure_container_name = config[GetDataFromStoreConstants.CONTAINER_NAME]

        # Never log the raw connection string: it embeds the storage account key.
        logger.info({key: ("*****" if key == "connection_string" else config[key])
                     for key in config.keys()})

        # Key the query with the same constants AzureDownload reads it with.
        mydict = {
            # Blob names always use "/" as separator, independent of the local OS.
            GetDataFromStoreConstants.AZURE_FILE_PATH: posixpath.join(azure_file_path, azure_file_name),
            GetDataFromStoreConstants.LOCAL_FILE_PATH: os.path.join(config['component_output_dir'],
                                                                    azure_file_name),
            GetDataFromStoreConstants.CONTAINER_NAME: azure_container_name,
            "connection_string": config['connection_string'],
        }
        obj = AzureDownload(mydict)
        val = obj.download_from_blob()
        if val:
            logger.info("File Downloaded Successfully")
    except Exception:
        logger.error("Azure File Download Component Failed")
        logger.error(traceback.format_exc())
        # Exit non-zero so the orchestrating pipeline sees the failure.
        sys.exit(1)
