from edge_engine.common.logsetup import logger
from scripts.utils.infocenter import MongoLogger
import time
from collections import deque
import cv2
import numpy as np
#import torch
#import torch.nn.functional as F
#from action_recognition.model import create_model
#from action_recognition.spatial_transforms import (CenterCrop, Compose,
#                                                   Normalize, Scale, ToTensor, MEAN_STATISTICS, STD_STATISTICS)
#from action_recognition.utils import load_state, generate_args
from config import config
import datetime
import pytz
from edge_engine.ai.model.modelwraper import ModelWrapper
import json

# --- Overlay text styling (passed to cv2.putText) -------------------------
TEXT_COLOR = (0, 0, 255)
TEXT_FONT_FACE = cv2.FONT_HERSHEY_PLAIN
TEXT_FONT_SIZE = 3
TEXT_VERTICAL_INTERVAL = 50
NUM_LABELS_TO_DISPLAY = 1

# --- Shared, process-wide tracking state -----------------------------------
# These names are module-level globals: ActionRecognition reads and rebinds
# them while rendering frames, so every instance in the process shares them.

# Fixed-length window of the most recent action labels.
action_queue = ["", "", "", "", ""]
start_time = time.time()
prev_action = " "
prev_action_time = 0
action_transition_counter = 0
# Activity label -> numeric id. (This mapping used to be defined twice in this
# block with identical contents; it is now defined once.)
activity_dict = {"picking": 1, "needle": 2, "non_needle": 3, "passing": 4, "idle": 5}
# Per-cycle bookkeeping. Every value is a 2-tuple: (value, flag).
# For the four activities the value is the accumulated time in seconds.
sequence_details = {
    "sequence_start": (None, False),
    "picking": (0.0, False),
    "needle": (0.0, False),
    "non_needle": (0.0, False),
    "passing": (0.0, False),
    "sequence_end": (None, False),
}
sequence_list = []
sequence_start_text = None
sequence_end_text = None
needle_text = ""
non_needle_text = ""
current_action = " "
accumulated_time = 0
num_cycle = 0
stoppage_counter = 0
sequence_flag = False
sequence_counter = 0
max_cycle_id = 1
# Operation name -> short operation code.
ops = {"kinari": "K1", "collar trim": "C1", "cuff trim": "CF1"}

class TorchActionRecognition:
    """Streaming wrapper around an ``<encoder>_vtn`` action-recognition model.

    Frames are embedded one at a time, the embeddings are kept in a bounded
    deque, and class logits are produced from the buffered sequence.

    NOTE(review): this class uses ``torch``, ``F``, ``create_model``,
    ``generate_args``, ``load_state`` and the spatial transforms, whose imports
    are commented out at the top of this file; it cannot run until they are
    restored. All inference is placed on the ``'cuda'`` device.
    """

    def __init__(self, encoder, checkpoint_path, num_classes=101):
        """Build the model, load its weights and prepare the frame buffer.

        :param encoder: encoder name; the model type is ``"<encoder>_vtn"``
        :param checkpoint_path: path of a checkpoint holding a ``'state_dict'`` entry
        :param num_classes: number of output classes
        """
        model_type = f"{encoder}_vtn"
        args, _ = generate_args(model=model_type, n_classes=num_classes, layer_norm=False)
        wrapped_model, _ = create_model(args, model_type)

        # create_model returns a wrapper; the usable network is its ``.module``.
        self.model = wrapped_model.module
        self.model.eval()
        self.model.cuda()

        saved = torch.load(str(checkpoint_path))
        load_state(self.model, saved['state_dict'])

        self.preprocessing = self.make_preprocessing(args)
        buffer_length = args.sample_duration * args.temporal_stride
        self.embeds = deque(maxlen=buffer_length)

    def preprocess_frame(self, frame):
        """Convert a BGR frame to RGB and run the preprocessing pipeline on it."""
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return self.preprocessing(rgb_frame)

    def infer_frame(self, frame):
        """Embed ``frame``, add it to the buffer and return logits for the buffer."""
        self.embeds.append(self._infer_embed(frame))
        return self._infer_logits(self.get_seq())

    def _infer_embed(self, frame):
        """Return the pooled embedding of one preprocessed frame tensor."""
        with torch.no_grad():
            batch = frame.unsqueeze(0).to('cuda')
            features = self.model.reduce_conv(self.model.resnet(batch))
            pooled = F.avg_pool2d(features, 7)
        # Drop the two trailing (spatial) dimensions left by the pooling.
        return pooled.squeeze(-1).squeeze(-1)

    def _infer_logits(self, embeddings):
        """Decode a sequence of embeddings into logits (averaged over dim 1) on CPU."""
        with torch.no_grad():
            decoded = self.model.self_attention_decoder(embeddings)
            logits = self.model.fc(decoded).mean(1)
        return logits.cpu()

    def _infer_seq(self, frame):
        """Run the full model on a clip viewed as ``(1, 16, 3, 224, 224)``."""
        with torch.no_grad():
            clip = frame.view(1, 16, 3, 224, 224).to('cuda')
            output = self.model(clip)
        return output.cpu()

    def get_seq(self):
        """Return the buffered embeddings as one sequence tensor.

        The buffer is stacked along dim 1 and every second entry is kept. If
        fewer than 16 entries remain, the sequence is tiled and cut to exactly
        16; longer sequences are returned as they are.
        """
        stacked = torch.stack(tuple(self.embeds), 1)
        strided = stacked[:, ::2, :]
        length = strided.size(1)

        if length >= 16:
            return strided

        # Smallest repeat count whose tiled length reaches 16.
        copies = 15 // length + 1
        return strided.repeat(1, copies, 1)[:, :16, :]

    def make_preprocessing(self, args):
        """Compose the scale / center-crop / to-tensor / normalize pipeline."""
        mean = MEAN_STATISTICS[args.mean_dataset]
        std = STD_STATISTICS[args.mean_dataset]
        steps = [
            Scale(args.sample_size),
            CenterCrop(args.sample_size),
            ToTensor(args.norm_value),
            Normalize(mean, std),
        ]
        return Compose(steps)



class ActionRecognition(ModelWrapper):
    """Draw the current activity and cycle timing statistics on video frames.

    Detections are looked up per frame id in a pre-computed JSON file when the
    input source type is ``'videofile'``. The Torch model is not constructed
    (the call is commented out in ``__init__``), so other source types have no
    detections and their frames pass through unchanged.

    Timing and cycle state is kept in module-level globals, so it is shared by
    every instance in the process.
    """

    def __init__(self, config, model_config, pubs, device_id):
        """Read configuration, load offline detections and the label list.

        :param config: mapping with a ``"config"`` section and
            ``config['inputConf']['sourceType']``
        :param model_config: mapping; ``print_eu_dist`` is read (default 200)
        :param pubs: object exposing ``rtp_write`` (used to publish frames)
        :param device_id: identifier stored in every Mongo payload
        """
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        self.type = config['inputConf']['sourceType']
        self.frame_skip = self.config.get('frame_skip', False)
        if self.type == 'videofile':
            # Context manager so the handle is closed even if parsing fails.
            with open('gokaldas.json', "r") as f:
                self.dets = json.load(f)
        self.print_eu_dist = model_config.get('print_eu_dist', 200)

        self.encoder = 'resnet34'
        self.checkpoint = 'data/save_10.pth'
        labels_path = "data/labels.txt"
        with open(labels_path) as fd:
            self.labels = fd.read().strip().split('\n')
        # self.model = TorchActionRecognition(self.encoder, self.checkpoint, num_classes=len(self.labels))

    def _pre_process(self, x):
        """
        Do preprocessing here, if any
        :param x: payload
        :return: payload
        """
        return x

    def _post_process(self, x):
        """
        Publish the payload on the RTP writer and hand it back.
        :param x: payload
        :return: payload
        """
        logger.info("************************")
        self.rtp.publish(x)  # video stream
        return x

    def send_payload(self,
            message = "Action Recognition",
            sequence_start_time= " ",
            picking_time= 0,
            non_needle_time= 0,
            needle_time= 0, passing_time= 0,
            cycle_id= 0,
            stoppage_count = 0,
            sequence_end_time= " ",
            needle_ratio = 0,
            operation_master_id="NA",
            operation_name="NA",
            station_id="NA",
            emp_id="NA"
    ):
        """
        Insert a cycle event into Mongo.

        Every argument is copied unchanged into the payload under a key of the
        same name; ``deviceId`` is taken from this instance.

        :param message: free-text event label
        :param sequence_start_time: start timestamp of the cycle
        :param picking_time: time spent picking
        :param non_needle_time: time spent on non-needle activity
        :param needle_time: time spent on needle activity
        :param passing_time: time spent passing
        :param cycle_id: identifier of the cycle
        :param stoppage_count: number of stoppages in the cycle
        :param sequence_end_time: end timestamp of the cycle
        :param needle_ratio: needle time as a share of the cycle
        :param operation_master_id: operation master identifier
        :param operation_name: operation name
        :param station_id: station identifier
        :param emp_id: employee identifier
        :return: None
        """

        payload = {
            "deviceId": self.device_id,
            "message": message,
            "sequence_start_time": sequence_start_time,
            "picking_time": picking_time,
            "non_needle_time": non_needle_time,
            "needle_time": needle_time,
            "passing_time": passing_time,
            "cycle_id": cycle_id,
            "stoppage_count": stoppage_count,
            "sequence_end_time": sequence_end_time,
            "needle_ratio": needle_ratio,
            "operation_master_id": operation_master_id,
            "operation_name": operation_name,
            "station_id": station_id,
            "emp_id": emp_id
        }

        # Call through the MongoLogger instance created in __init__. The
        # previous unbound call passed this ActionRecognition object as the
        # logger's ``self``.
        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    def draw_rect(self, image, bottom_left, top_right, color=(0, 0, 0), alpha=1.):
        """Blend a filled rectangle into ``image`` in place and return it.

        :param image: array indexed as ``[y, x, channel]``
        :param bottom_left: ``(xmin, ymin)`` corner
        :param top_right: ``(xmax, ymax)`` corner
        :param color: per-channel fill colour
        :param alpha: blend weight of ``color`` (1.0 fully replaces the pixels)
        :return: the same ``image`` object
        """
        xmin, ymin = bottom_left
        xmax, ymax = top_right

        image[ymin:ymax, xmin:xmax, :] = image[ymin:ymax, xmin:xmax, :] * (1 - alpha) + np.asarray(color) * alpha
        return image

    @staticmethod
    def _timestamp():
        """Return the current Asia/Kolkata time as ``YYYY-MM-DD HH:MM:SS``."""
        now = datetime.datetime.now().astimezone(tz=pytz.timezone("Asia/Kolkata"))
        return now.strftime("%Y-%m-%d %H:%M:%S")

    def _update_activity_state(self, dets):
        """Advance the shared activity/cycle state for one detection."""
        global action_transition_counter, prev_action, prev_action_time
        global current_action, stoppage_counter, sequence_flag, sequence_counter

        if prev_action == dets:
            return

        # Debounce: a change is accepted only after more than five differing
        # frames. NOTE(review): the counter is not reset when the label goes
        # back to prev_action, so the differing frames need not be consecutive.
        action_transition_counter = action_transition_counter + 1
        if action_transition_counter <= 5:
            return

        current_action = dets
        timed_activities = ("picking", "needle", "non_needle", "passing")

        # A new "picking" after a "passing" closes the running cycle, provided
        # more than 20 seconds have been accumulated.
        if current_action == "picking" and sequence_flag:
            cycle_time = sum(sequence_details[name][0] for name in timed_activities)
            if cycle_time > 20:
                stamp = self._timestamp()
                sequence_details["sequence_end"] = (stamp, True)
                # Store a snapshot; appending the live dict would make every
                # entry of sequence_list an alias of the same object.
                sequence_list.append(dict(sequence_details))
                sequence_details["sequence_start"] = (stamp, True)

        # Credit the elapsed time to the activity that just ended; its flag
        # becomes True and the other three are re-rounded with flag False.
        if prev_action in timed_activities:
            elapsed = float(time.time() - prev_action_time)
            for name in timed_activities:
                total = sequence_details[name][0]
                if name == prev_action:
                    sequence_details[name] = (round(total + elapsed, 2), True)
                else:
                    sequence_details[name] = (round(total, 2), False)
            # sequence_flag is cleared once more than three timed activities
            # have ended while it was set.
            if sequence_flag:
                sequence_counter = sequence_counter + 1
            if sequence_counter > 3:
                sequence_flag = False
                sequence_counter = 0

        if current_action == "passing":
            sequence_flag = True

        if current_action == "non_needle" and prev_action == "needle":
            stoppage_counter = stoppage_counter + 1

        prev_action_time = time.time()
        prev_action = dets
        action_transition_counter = 0

    def _draw_summary(self, frame):
        """Draw the eight-line cycle summary onto ``frame`` in place."""
        global sequence_start_text, sequence_end_text, num_cycle, accumulated_time

        picking = sequence_details["picking"][0]
        needle = sequence_details["needle"][0]
        non_needle = sequence_details["non_needle"][0]
        passing = sequence_details["passing"][0]

        if sequence_details["sequence_start"][1]:
            sequence_start_text = "SEQUENCE STARTED   " + str(
                sequence_details["sequence_start"][0])
        else:
            sequence_start_text = "SEQUENCE STARTED   --"

        if current_action == "needle":
            needle_text = "NEEDLE ACTIVITY : IN PROGRESS   " + str(round(needle, 3))
        else:
            needle_text = "NEEDLE ACTIVITY :  " + str(needle)

        if current_action == "non_needle":
            non_needle_text = "NON NEEDLE ACTIVITY : IN PROGRESS   " + str(non_needle)
        else:
            non_needle_text = "NON NEEDLE ACTIVITY :   " + str(non_needle)

        if current_action == "picking":
            picking_text = "PICKING ACTIVITY : IN PROGRESS   " + str(picking)
        else:
            picking_text = "PICKING ACTIVITY :   " + str(picking)

        if current_action == "passing":
            passing_text = "PASSING ACTIVITY : IN PROGRESS   " + str(passing)
        else:
            passing_text = "PASSING ACTIVITY :   " + str(passing)

        if sequence_details["sequence_end"][1]:
            sequence_end_text = "SEQUENCE COMPLETED   " + str(
                sequence_details["sequence_end"][0])
            accumulated_time = picking + needle + non_needle + passing
            num_cycle = len(sequence_list)
        else:
            sequence_end_text = "SEQUENCE COMPLETE   --"

        sequence_count_text = "SEQUENCE CYCLE : " + str(num_cycle)
        stoppage_text = "NUMBER OF STOPPAGE :  " + str(stoppage_counter)

        summary_lines = (
            sequence_start_text,
            picking_text,
            needle_text,
            non_needle_text,
            passing_text,
            sequence_end_text,
            sequence_count_text,
            stoppage_text,
        )
        for row, line in enumerate(summary_lines):
            # cv2.putText takes fontFace before fontScale; these two arguments
            # used to be passed in the opposite order.
            cv2.putText(frame, line, (50, 600 + row * TEXT_VERTICAL_INTERVAL),
                        TEXT_FONT_FACE, TEXT_FONT_SIZE, TEXT_COLOR, 2)

    def render_frame(self, frame, frameId):
        """Annotate ``frame`` with the current activity and the cycle summary.

        Best effort: any exception is logged and the frame is returned as it
        is at that point.

        :param frame: image to draw on (modified in place)
        :param frameId: frame identifier; ``str(frameId)`` is the key into the
            loaded detections
        :return: the frame
        """
        try:
            if self.type != 'videofile':
                # No detections exist for other sources (the live model is
                # disabled), so leave the frame untouched.
                logger.info(f"Inference Failed : no detections available for source type {self.type!r}")
                return frame

            dets = self.dets[str(frameId)]['detections']

            text = 'Current Activity: {}'.format(dets)
            text = text.upper().replace("_", " ")
            cv2.putText(frame, text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, TEXT_FONT_SIZE,
                        TEXT_COLOR, 2, 1)

            self._update_activity_state(dets)
            self._draw_summary(frame)
            return frame

        except Exception as e:
            logger.info(f"Inference Failed : {e}")
            return frame

    def list_managment(self, action):
        """Push ``action`` into the fixed-length action window.

        :param action: newest action label
        :return: True when the window now holds more than one distinct label
            (the result of ``transition_check``, which used to be discarded)
        """
        action_queue.append(action)
        action_queue.pop(0)
        return self.transition_check(action_queue)

    def transition_check(self, action_list):
        """Return True unless every element of ``action_list`` equals the first."""
        return not all(element == action_list[0] for element in action_list)

    def run_demo(self, frame, frameId):
        """Render the overlay for one frame and return the frame."""
        return self.render_frame(frame, frameId)

    def _predict(self, obj):
        """Annotate ``obj["frame"]`` and resize it to the configured output size.

        :param obj: dict with ``"frame"`` and ``"frameId"`` entries
        :return: the same dict with ``"frame"`` replaced
        """
        output_size = (self.config.get("FRAME_WIDTH"), self.config.get("FRAME_HEIGHT"))
        try:
            # Fixed 30 ms pause per frame; presumably paces playback - TODO confirm.
            time.sleep(0.03)
            # The overlay coordinates are laid out for a 1920x1080 canvas.
            frame = cv2.resize(obj["frame"], (1920, 1080))
            frame = self.run_demo(frame, obj["frameId"])
            obj["frame"] = cv2.resize(frame, output_size)

        except Exception as e:
            logger.exception(f"Error: {e}", exc_info=True)
            # Fall back to the frame without an overlay, at the output size.
            obj["frame"] = cv2.resize(obj["frame"], output_size)

        return obj
