"""
Author: Owaiz Mustafa Khan
Email: owaiz.mustafakhan@rockwellautomation.com
"""



from pymongo import MongoClient
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.database import Database

from scripts.config.app_configurations import DBConf

# Retrieval

# Single client created at import time and shared by the lookup helpers
# below (see _get_database); the connection URI comes from the app config.
mongo_client = MongoClient(host=DBConf.MONGO_URI)

def _get_database(name: str) -> Database:
    """
    This function is used to get a handle to a database on the shared MongoDB client
    :param name: Name of the database to open
    :return: database object of Database class
    """
    return mongo_client.get_database(name)


def get_collection(collection_name: str, db_name: str) -> Collection:
    """
    This function is used to look up a collection by name inside a MongoDB database
    :param collection_name: The name of collection you want to access
    :param db_name: Name of the database in which this collection is present
    :return: collection object of Collection class
    """
    db = _get_database(db_name)
    return db.get_collection(collection_name)

def execute_aggregate(collection: Collection, query: list) -> list:
    """
    This function is used to run an aggregation pipeline and collect its output
    :param collection: Object of the collection the pipeline should run on
    :param query: The aggregation pipeline, as a list of stage documents
    :return: A list of the documents produced by the pipeline, or [] if nothing matched
    """
    # ``aggregate`` always hands back a cursor object (never None), so a
    # truthiness guard on it can never pick an empty-list branch; an empty
    # result set already materialises as [].
    return list(collection.aggregate(query))


# Operations
def find_all(collection: Collection) -> list:
    """
    This function is used to get all the records present in a collection
    :param collection: Object of the collection you want to access
    :return: A list of all records present in the collection (with their ``_id``
        field removed) or [] if there are no records in the collection
    """
    # Reuse the shared pipeline runner instead of duplicating cursor handling.
    return execute_aggregate(collection, [{'$project': {'_id': 0}}])

def find_index(collection: Collection,
               include_database_name: bool = True,
               include_schema_name: bool = True,
               include_table_name: bool = True,
               full_projection: bool = False,
               index_name: str = None,
               table_name: str = None,
               schema_name: str = None,
               database_name: str = None) -> list:
    """
    This function is used to find index definitions stored in a collection.
    Records are flattened along ``schemas -> tables -> indexes`` so the result
    contains one entry per index. Filters left as None (or empty) are ignored.
    :param collection: Object of the collection you want to access
    :param include_database_name: True if you want to include database name in result else False [default: True]
    :param include_schema_name: True if you want to include schema name in result else False [default: True]
    :param include_table_name: True if you want to include table name in result else False [default: True]
    :param full_projection: True if you want the whole flattened record (minus ``_id``) instead of
        the selected fields; the include_* flags are ignored in that case [default: False]
    :param index_name: Name of the index you want to find
    :param table_name: Name of the table in which you want to find the index
    :param schema_name: Name of the schema in which you want to find the index
    :param database_name: Name of the database in which you want to find the index
    :return: A list of matching index entries, or [] if none match
    """
    # Optional filters; falsy values mean "do not filter on this level".
    match = {}
    if database_name:
        match["database_name"] = database_name
    if schema_name:
        match["schemas.schema_name"] = schema_name
    if table_name:
        match["schemas.tables.table_name"] = table_name
    if index_name:
        match["schemas.tables.indexes.name"] = index_name

    query = []

    if match:
        # Pre-filter whole records before unwinding. On arrays these dotted
        # paths match when *any* element matches, so every record that could
        # pass the exact match below is kept, while irrelevant records are
        # dropped before the (expensive) unwinds.
        query.append({'$match': match})

    # Flatten to one document per (schema, table, index) combination.
    query.append({"$unwind": "$schemas"})
    query.append({"$unwind": "$schemas.tables"})
    query.append({"$unwind": "$schemas.tables.indexes"})

    if match:
        # After unwinding, the same criteria select the exact schema/table/index.
        query.append({'$match': match})

    if full_projection:
        query.append({'$project': {'_id': 0}})
    else:
        projection = {'_id': 0}
        if include_database_name:
            projection['database_name'] = 1
        if include_schema_name:
            projection['schema_name'] = '$schemas.schema_name'
        if include_table_name:
            projection['table_name'] = '$schemas.tables.table_name'
        projection['index'] = '$schemas.tables.indexes'
        query.append({'$project': projection})

    return execute_aggregate(collection, query)

def add_index(collection: Collection, data: dict):
    """
    This function is used to append a new index definition to a table
    :param collection: Object of the collection you want to access
    :param data: Request payload. Reads ``database_name``, ``schema_name``
        (default ``'public'``), ``table_name`` and ``index`` (a dict whose
        ``name``, ``columns``, ``unique`` and ``type`` are stored)
    :return: True if the index was pushed into a matching table, else False
    """
    schema_name = data.get('schema_name', 'public')
    table_name = data.get('table_name')

    query = {
        "database_name": data.get('database_name'),
        "schemas.schema_name": schema_name,
        "schemas.tables.table_name": table_name
    }

    index = data.get('index')
    # NOTE(review): $push appends unconditionally, so an index whose name
    # already exists on the table is stored a second time.
    update = {
        "$push": {
            "schemas.$[s].tables.$[t].indexes": {
                "name": index.get('name'),
                "columns": index.get('columns'),
                "unique": index.get('unique'),
                "type": index.get('type')
            }
        }
    }

    # Restrict the push to the schema/table named in the request.
    array_filters = [
        {"s.schema_name": schema_name},
        {"t.table_name": table_name}
    ]

    result = collection.update_one(query, update, array_filters=array_filters)
    return result.modified_count > 0

def update_index(collection: Collection, data: dict):
    """
    This function is used to update an existing index definition of a table
    :param collection: Object of the collection you want to access
    :param data: Request payload. Reads ``database_name``, ``schema_name``
        (default ``'public'``), ``table_name``, ``index`` (only its ``name`` is
        used, to locate the stored index) and ``new_index`` (any subset of
        ``name``, ``columns``, ``unique`` and ``type``; only supplied keys are changed)
    :return: True if the stored index was modified, else False
    """
    schema_name = data.get('schema_name', 'public')
    table_name = data.get('table_name')
    current_name = data.get('index').get('name')

    query = {
        "database_name": data.get('database_name'),
        "schemas.schema_name": schema_name,
        "schemas.tables.table_name": table_name,
        "schemas.tables.indexes.name": current_name
    }

    # Only overwrite the attributes the caller actually supplied. Setting all
    # four unconditionally wrote null over any attribute missing from
    # new_index (e.g. a rename-only request wiped the index's columns).
    new_index = data.get('new_index')
    missing = object()
    index_path = "schemas.$[s].tables.$[t].indexes.$[i]"
    changes = {}
    for field in ("name", "columns", "unique", "type"):
        value = new_index.get(field, missing)
        if value is not missing:
            changes[f"{index_path}.{field}"] = value

    if not changes:
        # Nothing to write; skip the round-trip (older MongoDB servers also
        # reject an empty $set).
        print("No matching index found or index was not updated.")
        return False

    update = {"$set": changes}

    # Narrow the positional placeholders to the requested schema/table/index.
    array_filters = [
        {"s.schema_name": schema_name},
        {"t.table_name": table_name},
        {"i.name": current_name}
    ]

    result = collection.update_one(query, update, array_filters=array_filters)

    if result.modified_count > 0:
        print("Index Updated successfully.")
        return True
    print("No matching index found or index was not updated.")
    return False

def delete_index(collection: Collection, data: dict):
    """
    This function is used to remove an index definition from a table
    :param collection: Object of the collection you want to access
    :param data: Request payload. Reads ``database_name``, ``schema_name``
        (default ``'public'``), ``table_name`` and ``index`` (only its ``name`` is used)
    :return: True if an index entry was removed, else False
    """
    db = data.get('database_name')
    schema = data.get('schema_name', 'public')
    table = data.get('table_name')
    target = data.get('index').get('name')

    outcome = collection.update_one(
        {
            "database_name": db,
            "schemas.schema_name": schema,
            "schemas.tables.table_name": table,
            "schemas.tables.indexes.name": target,
        },
        # $pull drops index entries with this name from the table(s) selected
        # by the array filters below.
        {"$pull": {"schemas.$[s].tables.$[t].indexes": {"name": target}}},
        array_filters=[{"s.schema_name": schema}, {"t.table_name": table}],
    )

    removed = outcome.modified_count > 0
    print("Index deleted successfully." if removed
          else "No matching index found or index was not deleted.")
    return removed
