from pydantic import BaseModel

from scripts.db.redis_connection import destination_space_db
from scripts.utils.mongo_utils import MongoCollectionBaseClass

from scripts.constants.db_constants import (
    CollectionNames,
    DatabaseNames,
)
from scripts.logging import logger


class ArtifactMetaSchema(BaseModel):
    """Pydantic model listing the fields of an artifact metadata record.

    Every field is required (none declares a default). Field meanings noted
    below are inferred from the names unless stated otherwise -- confirm
    against the code that produces these records.
    """

    # Identifier of the artifact; ArtifactsMeta.fetch_artifact_by_id queries
    # documents on a key with this same name.
    artifact_id: str
    # Artifact name; used together with artifact_type when looking up the
    # latest version in ArtifactsMeta.get_artifact_latest_version.
    name: str
    # Category of the artifact; ArtifactsMeta.fetch_artifacts_count groups on it.
    artifact_type: str
    # Version number, declared as a float.
    ver: float
    # presumably an image reference for the artifact -- TODO confirm
    image: str
    # Lifecycle status; fetch_artifacts_count matches on the value "approved".
    status: str
    # Free-form metadata dictionary (schema not visible here).
    meta: dict
    # Free-form dictionary describing the source (schema not visible here).
    source_details: dict
    comments: str
    # presumably the space this artifact belongs to -- TODO confirm
    space_id: str
    # presumably the identifier of the originating source -- TODO confirm
    source_id: str


class ArtifactsMeta(MongoCollectionBaseClass):
    """Accessor for the artifact-meta collection of the catalog database.

    The query helpers used below (``aggregate``, ``find``, ``find_one``) are
    not defined in this class; they are expected to come from
    ``MongoCollectionBaseClass``.
    """

    def __init__(self, mongo_client, space_id=None):
        """Bind the accessor to the catalog database / artifact-meta collection.

        Args:
            mongo_client: Client object forwarded unchanged to the base class.
            space_id: Optional space identifier; it is only stored on the
                instance by this constructor.
        """
        super().__init__(
            mongo_client,
            database=DatabaseNames.catalog,
            collection=CollectionNames.artifact_meta,
            space_db=destination_space_db,
        )
        self.space_id = space_id

    @property
    def key_space_id(self):
        """Name of the document key that holds the space identifier."""
        return "space_id"

    def fetch_artifacts_count(self):
        """Count approved artifacts per artifact type.

        Returns:
            A list of ``{"name": <artifact_type>, "value": <count>}`` documents,
            or ``None`` when the aggregate result is falsy or when any
            exception is raised (the error is logged, not re-raised).
        """
        # Stage 1: keep only documents whose status is "approved".
        only_approved = {"$match": {"status": "approved"}}
        # Stage 2: one bucket per artifact_type, counting its documents.
        count_per_type = {
            "$group": {
                "_id": "$artifact_type",
                "count": {"$sum": 1},
            }
        }
        # Stage 3: expose the bucket key as "name" and the count as "value".
        rename_fields = {
            "$project": {
                "name": "$_id",
                "value": "$count",
                "_id": 0,
            }
        }

        try:
            aggregated = self.aggregate([only_approved, count_per_type, rename_fields])
            return list(aggregated) if aggregated else None
        except Exception as e:
            logger.error(f"Error occurred in fetching artifacts due to {str(e)}")
            return None

    def get_artifact_meta_by_aggregate(self, query: list):
        """Run the given aggregation pipeline and return its results as a list.

        Args:
            query: Pipeline stages, passed to ``aggregate`` as ``pipelines``.
        """
        aggregated = self.aggregate(pipelines=query)
        return list(aggregated)

    def fetch_artifact_by_id(self, artifact_id):
        """Look up a single artifact document by its ``artifact_id``.

        The ``_id`` field is excluded through the ``filter_dict`` argument.
        Whatever ``find_one`` returns is handed back unchanged.
        """
        lookup = {"artifact_id": artifact_id}
        without_object_id = {"_id": 0}
        return self.find_one(lookup, filter_dict=without_object_id)

    def get_artifact_latest_version(self, artifact_name, artifact_type):
        """Compute the next version string for an artifact name/type pair.

        Queries for one record matching the name and type, requesting only
        ``ver`` with the sort spec ``{"ver": -1}`` (presumably highest version
        first -- depends on the base-class ``find`` helper).

        Returns:
            ``"<n>.0"`` where ``n`` is the integer part of the stored ``ver``
            plus one, or ``"1.0"`` when no matching record exists.
        """
        matches = list(
            self.find(
                query={
                    "name": artifact_name,
                    "artifact_type": artifact_type,
                },
                filter_dict={"_id": 0, "ver": 1},
                sort={"ver": -1},
                limit=1,
            )
        )
        if not matches:
            return "1.0"

        # Drop any fractional part of the stored version, then bump by one.
        next_major = int(float(matches[0]["ver"])) + 1
        return f"{next_major}.0"
