
# from scipy.spatial import distance
from edge_engine.common.logsetup import logger
# from yolov5processor.infer import ExecuteInference
# from scripts.utils.edge_utils import get_extra_fields
# from edge_engine.ai.model.modelwraper import ModelWrapper
# from scripts.utils.centroidtracker import CentroidTracker
# from scripts.common.constants import JanusDeploymentConstants
# from scripts.utils.image_utils import draw_circles_on_frame, resize_to_64_64
#
# import cv2
# import base64
# import datetime
# import numpy as np
#
# from collections import deque
# from expiringdict import ExpiringDict
# from sklearn.utils.linear_assignment_ import linear_assignment
#
# from edge_engine.common.logsetup import logger
# from edge_engine.ai.model.modelwraper import ModelWrapper
#
# from scripts.utils.tracker import Tracker
# from scripts.utils.helpers import box_iou2
# from scripts.utils.edge_utils import Utilities
from scripts.utils.infocenter import MongoLogger
# from scripts.utils.model_tracker import ModelCountTracker
# from scripts.common.constants import JanusDeploymentConstants
#
# from yolov5processor.infer import ExecuteInference
#

import sys
import base64
import time
from argparse import ArgumentParser
from collections import deque

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from action_recognition.model import create_model
from action_recognition.spatial_transforms import (CenterCrop, Compose,
                                                   Normalize, Scale, ToTensor, MEAN_STATISTICS, STD_STATISTICS)
from action_recognition.utils import load_state, generate_args
import psycopg2
from config import config
import datetime
from edge_engine.ai.model.modelwraper import ModelWrapper

# Overlay text styling used with cv2.putText.
TEXT_COLOR = (255, 222, 173)
TEXT_FONT_FACE = cv2.FONT_HERSHEY_DUPLEX
TEXT_FONT_SIZE = 1
TEXT_VERTICAL_INTERVAL = 25
NUM_LABELS_TO_DISPLAY = 1

# Fixed-length window of recent labels (see CementBagCounter.list_managment).
action_queue = [""] * 5
start_time = time.time()

# Action-transition state mutated by CementBagCounter.render_frame.
prev_action = " "
prev_action_time = 0
action_transition_counter = 0
current_action = " "

# Per-cycle record of (value, flag) pairs updated by render_frame: activity
# entries accumulate time.time() differences, start/end hold timestamps.
sequence_details = {
    "sequence_start": (None, False),
    "picking": (0.0, False),
    "needle": (0.0, False),
    "non_needle": (0.0, False),
    "passing": (0.0, False),
    "sequence_end": (None, False),
}
sequence_list = []

# Overlay texts and cycle counters.
sequence_start_text = None
sequence_end_text = None
needle_text = ""
non_needle_text = ""
accumulated_time = 0
num_cycle = 0
stoppage_counter = 0
sequence_flag = False
sequence_counter = 0

activity_dict = {"picking": 1, "needle": 2, "non_needle": 3, "passing": 4, "idle": 5}
max_cycle_id = 1


class TorchActionRecognition:
    """Per-frame action recogniser built on a "<encoder>_vtn" model.

    Every frame is embedded once by the backbone (``resnet`` + ``reduce_conv``)
    and cached in a bounded deque; the cached window is then classified by the
    ``self_attention_decoder`` + ``fc`` head, so each frame yields logits.
    """

    def __init__(self, encoder, checkpoint_path, num_classes=101):
        """Create the model on the GPU, load weights and set up preprocessing.

        :param encoder: backbone name; the model type becomes "<encoder>_vtn".
        :param checkpoint_path: checkpoint file containing a 'state_dict' entry.
        :param num_classes: number of classes of the classifier head.
        """
        vtn_name = "{}_vtn".format(encoder)
        args, _ = generate_args(model=vtn_name, n_classes=num_classes, layer_norm=False)
        wrapped_model, _ = create_model(args, vtn_name)

        # create_model returns a wrapper exposing ``.module`` (presumably
        # DataParallel — confirm); keep only the inner network.
        self.model = wrapped_model.module
        self.model.eval()
        self.model.cuda()

        saved = torch.load(str(checkpoint_path))
        load_state(self.model, saved['state_dict'])

        self.preprocessing = self.make_preprocessing(args)
        window_length = args.sample_duration * args.temporal_stride
        self.embeds = deque(maxlen=window_length)

    def preprocess_frame(self, frame):
        """Convert a BGR frame to RGB and run the preprocessing pipeline."""
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return self.preprocessing(rgb_frame)

    def infer_frame(self, frame):
        """Cache the embedding of ``frame`` and return logits for the window."""
        self.embeds.append(self._infer_embed(frame))
        return self._infer_logits(self.get_seq())

    def _infer_embed(self, frame):
        """Return the pooled backbone embedding of one preprocessed frame.

        ``unsqueeze(0)`` adds a batch dimension; the input is assumed to be a
        single (C, H, W) tensor — TODO confirm.
        """
        with torch.no_grad():
            batch = frame.unsqueeze(0).to('cuda')
            features = self.model.reduce_conv(self.model.resnet(batch))
            # 7x7 average pooling; presumably reduces a 7x7 map to 1x1 — confirm.
            pooled = F.avg_pool2d(features, 7)
        return pooled.squeeze(-1).squeeze(-1)

    def _infer_logits(self, embeddings):
        """Decode a window of embeddings and average the logits over dim 1."""
        with torch.no_grad():
            decoded = self.model.self_attention_decoder(embeddings)
            averaged = self.model.fc(decoded).mean(1)
        return averaged.cpu()

    def _infer_seq(self, frame):
        """Run the full model on a clip reshaped to (1, 16, 3, 224, 224)."""
        with torch.no_grad():
            output = self.model(frame.view(1, 16, 3, 224, 224).to('cuda'))
        return output.cpu()

    def get_seq(self):
        """Stack cached embeddings along dim 1, keep every 2nd, pad to 16 steps.

        Sequences shorter than 16 steps are tiled and truncated to exactly 16.
        NOTE(review): the stride 2 and length 16 are hard-coded, while the
        deque length comes from ``args`` — confirm they stay consistent.
        """
        strided = torch.stack(tuple(self.embeds), 1)[:, ::2, :]
        length = strided.size(1)
        if length >= 16:
            return strided
        repeats = 15 // length + 1  # == ceil(16 / length)
        return strided.repeat(1, repeats, 1)[:, :16, :]

    def make_preprocessing(self, args):
        """Build the scale -> center-crop -> tensor -> normalize pipeline."""
        mean = MEAN_STATISTICS[args.mean_dataset]
        std = STD_STATISTICS[args.mean_dataset]
        steps = [
            Scale(args.sample_size),
            CenterCrop(args.sample_size),
            ToTensor(args.norm_value),
            Normalize(mean, std),
        ]
        return Compose(steps)



class CementBagCounter(ModelWrapper):
    """Tracks a picking -> needle / non_needle -> passing work cycle on video.

    Per-frame action probabilities (see ``run_demo``) are turned into accepted
    action transitions. The time spent in each activity, the number of
    needle -> non_needle stoppages and the completed cycle count are drawn on
    the frame and published to Mongo on every qualifying frame. Completed
    cycles are inserted into PostgreSQL.

    The class name appears to be left over from an earlier cement-bag counting
    model. All cycle state lives in module-level globals, so every instance in
    the process shares it.
    """

    # SECURITY: hard-coded DSN including credentials. It is kept only so the
    # activity-table inserts keep writing to the same database; move it to
    # configuration.
    _ACTIVITY_DB_URI = "postgres://ilens:iLens#4321@192.168.0.207:5432/ilens_ai?sslmode=disable"

    # Activity labels whose durations are accumulated in ``sequence_details``.
    _TIMED_ACTIVITIES = ("picking", "needle", "non_needle", "passing")

    def __init__(self, config, model_config, pubs, device_id):
        """Store runtime options and publishing handles.

        :param config: mapping whose "config" entry holds the runtime options.
        :param model_config: model options; only 'print_eu_dist' is read (default 200).
        :param pubs: publisher bundle; its ``rtp_write`` attribute is kept as ``self.rtp``.
        :param device_id: identifier stamped on every Mongo payload.
        """
        # NOTE(review): ModelWrapper.__init__ is not called (it was commented
        # out); confirm the base class needs no initialisation.
        self.config = config["config"]
        self.device_id = device_id
        self.rtp = pubs.rtp_write
        self.mongo_logger = MongoLogger()
        self.frame_skip = self.config.get('frame_skip', False)
        self.print_eu_dist = model_config.get('print_eu_dist', 200)

    def send_payload(self,
                     message="Action Recognition",
                     sequence_start_text=" ",
                     picking_text=" ",
                     non_needle_text=" ",
                     needle_text=" ", passing_text=" ",
                     sequence_count_text=" ",
                     stoppage_text=" ",
                     sequence_end_text=" "):
        """Insert the current cycle statistics as one event into Mongo.

        :param message: event message.
        :param sequence_start_text: cycle start timestamp, as text.
        :param picking_text: accumulated picking time, as text.
        :param non_needle_text: accumulated non-needle time, as text.
        :param needle_text: accumulated needle time, as text.
        :param passing_text: accumulated passing time, as text.
        :param sequence_count_text: number of completed cycles, as text.
        :param stoppage_text: stoppage count, as text.
        :param sequence_end_text: cycle end timestamp, as text.
        :return: None
        """
        payload = {
            "deviceId": self.device_id,
            "message": message,
            "sequence_start_text": sequence_start_text,
            "picking_text": picking_text,
            "non_needle_text": non_needle_text,
            "needle_text": needle_text,
            "passing_text": passing_text,
            "sequence_count_text": sequence_count_text,
            "stoppage_text": stoppage_text,
            "sequence_end_text": sequence_end_text
        }
        # Call through the MongoLogger instance so the method receives its own
        # object as ``self`` (not this wrapper).
        self.mongo_logger.insert_attendance_event_to_mongo(payload)

    def draw_rect(self, image, bottom_left, top_right, color=(0, 0, 0), alpha=1.):
        """Alpha-blend a solid rectangle into ``image`` in place.

        Despite the names, ``bottom_left`` is the (xmin, ymin) corner and
        ``top_right`` the (xmax, ymax) corner, in pixel coordinates.

        :param image: H x W x C array, modified in place.
        :param color: fill colour, one value per channel.
        :param alpha: opacity of the fill (1.0 = opaque).
        :return: the same ``image`` array.
        """
        xmin, ymin = bottom_left
        xmax, ymax = top_right
        region = image[ymin:ymax, xmin:xmax, :]
        image[ymin:ymax, xmin:xmax, :] = region * (1 - alpha) + np.asarray(color) * alpha
        return image

    @staticmethod
    def _execute_insert(connect, sql, values):
        """Run one parameterised INSERT on a fresh connection and commit it.

        Best effort: any error is printed and swallowed, so callers never see
        a database exception. The connection is always closed.

        :param connect: zero-argument callable returning a psycopg2 connection.
        :param sql: INSERT statement with ``%s`` placeholders.
        :param values: sequence of values bound to the placeholders.
        """
        conn = None
        try:
            conn = connect()
            with conn.cursor() as cur:
                cur.execute(sql, values)
            conn.commit()
        except (Exception, psycopg2.DatabaseError) as error:
            print(error)
        finally:
            if conn is not None:
                conn.close()
                print('Database connection closed.')

    @staticmethod
    def insert_sequence_activity_table_list(seq_act_list):
        """Insert one row into "sequence_activity_table".

        Static because the original definition took no ``self``: both
        ``CementBagCounter.insert_sequence_activity_table_list(row)`` and
        ``self.insert_sequence_activity_table_list(row)`` work.

        :param seq_act_list: (CycleID, SequenceID, SequenceName,
            SequenceDuration, OperationMasterID) values.
        """
        sql = 'INSERT INTO "sequence_activity_table" ("CycleID", "SequenceID", "SequenceName", "SequenceDuration", "OperationMasterID") VALUES (%s, %s, %s, %s, %s);'
        CementBagCounter._execute_insert(
            lambda: psycopg2.connect(CementBagCounter._ACTIVITY_DB_URI), sql, seq_act_list)

    @staticmethod
    def insert_micro_activity_time_table_list(micro_act_time_list):
        """Insert one row into "microactivity_time".

        Static for the same reason as ``insert_sequence_activity_table_list``.

        :param micro_act_time_list: (CycleId, EmpId, Picking, Non-Needle,
            Needle, Passing, TotalTime) values.
        """
        sql = 'INSERT INTO "microactivity_time" ("CycleId", "EmpId", "Picking", "Non-Needle", "Needle", "Passing", "TotalTime") VALUES (%s, %s, %s, %s, %s, %s, %s);'
        CementBagCounter._execute_insert(
            lambda: psycopg2.connect(CementBagCounter._ACTIVITY_DB_URI), sql, micro_act_time_list)

    def insert_operation_master_table_list(self, ops_master_list):
        """Insert one completed cycle into "Operations_master_dummy".

        Connection parameters come from ``config()``. They are no longer
        printed, because they include credentials.

        :param ops_master_list: (CycleID, OperationMasterID, OperationName,
            StationID, EmpID, SequenceStartTime, SequenceEndTime,
            StoppageCount) values.
        """
        sql = 'INSERT INTO "Operations_master_dummy" ("CycleID", "OperationMasterID", "OperationName", "StationID", "EmpID", "SequenceStartTime", "SequenceEndTime", "StoppageCount") VALUES (%s, %s, %s, %s, %s, %s, %s, %s);'

        def connect():
            params = config()
            print('Connecting to the PostgreSQL database...')
            return psycopg2.connect(**params)

        self._execute_insert(connect, sql, ops_master_list)

    @staticmethod
    def _fresh_sequence_details():
        """Return an empty per-cycle record, used after a cycle is closed."""
        return {"sequence_start": ("", False), "picking": (0.0, False),
                "non_needle": (0.0, False), "needle": (0.0, False),
                "passing": (0.0, False),
                "sequence_end": ("", False)}

    @staticmethod
    def _draw_status_lines(frame, lines, x=50, first_y=600, line_step=50):
        """Draw ``lines`` top-to-bottom starting at (x, first_y), one per ``line_step`` px."""
        for row, line in enumerate(lines):
            cv2.putText(frame, line, (x, first_y + row * line_step), TEXT_FONT_FACE,
                        TEXT_FONT_SIZE, TEXT_COLOR, 2)

    def render_frame(self, frame, probs, labels):
        """Update the global cycle state from one frame's prediction and draw it.

        Only the top ``NUM_LABELS_TO_DISPLAY`` labels with probability above
        40 % are considered. A label differing from the previous action is
        accepted once more than five such frames have accumulated since the
        last accepted transition. The counter is not reset by frames that
        match the previous action.

        On an accepted transition, the time since the previous transition is
        added to that previous activity in ``sequence_details``. A "picking"
        that follows a completed sequence closes the cycle. The closed cycle
        is stored in the operations-master table if its four activity times
        sum to more than 15 s.

        :param frame: BGR image array, annotated in place.
        :param probs: 1-D tensor of class probabilities.
        :param labels: class names indexed like ``probs``.
        :return: the annotated ``frame``.
        """
        global action_transition_counter, prev_action, prev_action_time, current_action
        global sequence_details, sequence_start_text, sequence_end_text, sequence_flag, sequence_counter
        global num_cycle, accumulated_time, stoppage_counter, max_cycle_id

        order = probs.argsort(descending=True)

        # Semi-transparent black bar behind the top-label text.
        status_bar_top_left = (0, 0)
        status_bar_bottom_right = (650, 25 + TEXT_VERTICAL_INTERVAL * NUM_LABELS_TO_DISPLAY)
        self.draw_rect(frame, status_bar_top_left, status_bar_bottom_right, alpha=0.5)

        for i, imax in enumerate(order[:NUM_LABELS_TO_DISPLAY]):
            # Written as ``not (... > 40)`` so NaN probabilities are skipped as before.
            if not (probs[imax] * 100 > 40):
                continue
            label = labels[imax]

            if prev_action != label:
                action_transition_counter += 1
                if action_transition_counter > 5:
                    print("diff action")
                    current_action = label

                    if current_action == "picking":
                        if sequence_flag:
                            print("cycle complete")
                            print(sequence_details)
                            sequence_list.append(sequence_details)
                            # NOTE: per-activity inserts into "sequence_activity_table" and
                            # "microactivity_time" are currently disabled here.
                            operation_master_list = (
                                max_cycle_id, "K1", "Kinari", "S1", 1,
                                sequence_details["sequence_start"][0],
                                sequence_details["sequence_end"][0], stoppage_counter)
                            cycle_time = (sequence_details["picking"][0] + sequence_details["needle"][0]
                                          + sequence_details["non_needle"][0] + sequence_details["passing"][0])
                            # Only cycles whose summed activity time exceeds 15 s are stored.
                            if cycle_time > 15:
                                self.insert_operation_master_table_list(operation_master_list)
                                max_cycle_id += 1
                            sequence_counter = 0
                            sequence_flag = False
                            stoppage_counter = 0
                            sequence_details = self._fresh_sequence_details()
                        sequence_details["sequence_start"] = (datetime.datetime.now(), True)

                    # Credit the elapsed time to the activity that just ended.
                    # NOTE(review): when a "passing" ends by closing a cycle above, its
                    # time lands in the NEW cycle's record — confirm this is intended.
                    if prev_action in self._TIMED_ACTIVITIES:
                        elapsed = float(time.time() - prev_action_time)
                        for activity in self._TIMED_ACTIVITIES:
                            duration = sequence_details[activity][0]
                            if activity == prev_action:
                                sequence_details[activity] = (round(duration + elapsed, 2), True)
                            else:
                                sequence_details[activity] = (round(duration, 2), False)
                        if sequence_flag:
                            sequence_counter += 1
                        # Give up on closing the cycle after more than three
                        # transitions without a "picking".
                        if sequence_counter > 3:
                            sequence_flag = False
                            sequence_counter = 0

                    if current_action == "passing":
                        sequence_details["sequence_end"] = (datetime.datetime.now(), True)
                        sequence_flag = True

                    if current_action == "non_needle" and prev_action == "needle":
                        stoppage_counter += 1

                    prev_action_time = time.time()
                    prev_action = label
                    action_transition_counter = 0

            text = '{}'.format(label).upper().replace("_", " ")
            # putText takes fontFace before fontScale; the two were swapped before.
            cv2.putText(frame, text, (15, TEXT_VERTICAL_INTERVAL * (i + 1)), TEXT_FONT_FACE,
                        TEXT_FONT_SIZE, TEXT_COLOR)

            if sequence_details["sequence_start"][1]:
                sequence_start_text = "SEQUENCE STARTED                  " + str(
                    sequence_details["sequence_start"][0])
            else:
                sequence_start_text = "SEQUENCE STARTED                          --"
            if current_action == "needle":
                needle_text = "NEEDLE ACTIVITY : IN PROGRESS           " + str(
                    round(sequence_details["needle"][0], 3))
            else:
                needle_text = "NEEDLE ACTIVITY :                       " + str(sequence_details["needle"][0])
            if current_action == "non_needle":
                non_needle_text = "NON NEEDLE ACTIVITY : IN PROGRESS   " + str(
                    sequence_details["non_needle"][0])
            else:
                non_needle_text = "NON NEEDLE ACTIVITY :               " + str(
                    sequence_details["non_needle"][0])
            if current_action == "picking":
                picking_text = "PICKING ACTIVITY : IN PROGRESS         " + str(sequence_details["picking"][0])
            else:
                picking_text = "PICKING ACTIVITY :                     " + str(sequence_details["picking"][0])
            if current_action == "passing":
                passing_text = "PASSING ACTIVITY : IN PROGRESS         " + str(sequence_details["passing"][0])
            else:
                passing_text = "PASSING ACTIVITY :                      " + str(sequence_details["passing"][0])
            if sequence_details["sequence_end"][1]:
                sequence_end_text = "SEQUENCE COMPLETED                 " + str(
                    sequence_details["sequence_end"][0])
                accumulated_time = (sequence_details["picking"][0] + sequence_details["needle"][0]
                                    + sequence_details["non_needle"][0] + sequence_details["passing"][0])
                num_cycle = len(sequence_list)
            else:
                sequence_end_text = "SEQUENCE COMPLETE                    --"
            sequence_count_text = "SEQUENCE CYCLE : " + str(num_cycle) + "  PREVIOUS SEQUENCE TIME :  " + str(
                round(accumulated_time, 2))
            stoppage_text = "NUMBER OF STOPPAGE :  " + str(stoppage_counter)

            self._draw_status_lines(frame, [
                sequence_start_text,
                picking_text,
                needle_text,
                non_needle_text,
                passing_text,
                sequence_end_text,
                sequence_count_text,
                stoppage_text,
            ])

            self.send_payload(
                sequence_start_text=str(sequence_details["sequence_start"][0]),
                picking_text=str(round(sequence_details["picking"][0], 3)),
                non_needle_text=str(round(sequence_details["non_needle"][0], 3)),
                needle_text=str(round(sequence_details["needle"][0], 3)),
                passing_text=str(round(sequence_details["passing"][0], 3)),
                sequence_count_text=str(num_cycle),
                stoppage_text=str(stoppage_counter),
                sequence_end_text=str(sequence_details["sequence_end"][0]))

        return frame

    def list_managment(self, action):
        """Push ``action`` into the global ``action_queue`` window, dropping the oldest.

        :return: True if the window now holds more than one distinct action
            (previously this result was computed and discarded).
        """
        action_queue.append(action)
        action_queue.pop(0)
        return self.transition_check(action_queue)

    def transition_check(self, action_list):
        """Return True when ``action_list`` contains more than one distinct value."""
        return not all(element == action_list[0] for element in action_list)

    def run_demo(self, model, video_cap, labels):
        """Classify and display every frame of ``video_cap``, throttled to ~30 fps.

        Stops when the capture yields no frame, or when Esc or 'q' is pressed.

        :param model: object providing ``preprocess_frame`` and ``infer_frame``
            (e.g. ``TorchActionRecognition``).
        :param video_cap: an opened ``cv2.VideoCapture``-like source.
        :param labels: class names indexed like the model's output.
        """
        frame_interval = 1 / 30.
        tick = time.time()
        while video_cap.isOpened():
            ok, frame = video_cap.read()
            if not ok:
                break

            logits = model.infer_frame(model.preprocess_frame(frame))
            probs = F.softmax(logits[0], dim=0)
            frame = self.render_frame(frame, probs, labels)

            tock = time.time()
            expected_time = tick + frame_interval
            if tock < expected_time:
                time.sleep(expected_time - tock)
            # The next interval starts when this one really ended (after any
            # sleep). Using the pre-sleep time made every other frame skip its sleep.
            tick = max(tock, expected_time)

            cv2.imshow("demo", frame)
            key = cv2.waitKey(1)
            if key == 27 or key == ord('q'):
                break

    # def main():

    # run_demo(model, cap, labels)

        # self.uncounted_objects = ExpiringDict(max_len=model_config.get("uncounted_obj_length", 50),
        #                                       max_age_seconds=model_config.get("uncounted_obj_age", 60))
        # self.janus_metadata = ExpiringDict(max_age_seconds=120, max_len=1)

#     def _pre_process(self, x):
#         """
#         Do preprocessing here, if any
#         :param x: payload
#         :return: payload
#         """
#         return x
#
#     def _post_process(self, x):
#         """
#         Apply post processing here, if any
#         :param x: payload
#         :return: payload
#         """
#         self.rtp.publish(x)  # video stream
#         return x
#
#     def send_payload(self, frame, label='CementBagDetected', bg_color="#474520", font_color="#FFFF00", alert_sound=None,
#                      message="Cement Bag Detected!"):
#         """
#         Insert event to Mongox
#         :param message:
#         :param frame:
#         :param label:
#         :param bg_color:
#         :param font_color:
#         :param alert_sound:
#         :return: None
#         """
#
#         payload = {"deviceId": self.device_id, "message": message,
#                    "frame": 'data:image/jpeg;base64,' + base64.b64encode(
#                        cv2.imencode('.jpg', frame)[1].tostring()).decode("utf-8"), "activity": label,
#                    "bg_color": bg_color, "font_color": font_color, "alert_sound": alert_sound}
#
#         self.mongo_logger.insert_attendance_event_to_mongo(payload)
#
#     def kalman_tracker(
#             self,
#             bboxs,
#             img,
#     ):
#
#         z_box = bboxs
#         x_box = []
#
#         if len(self.tracker_list) > 0:
#             for trk in self.tracker_list:
#                 x_box.append(trk.box)
#
#         matched, unmatched_dets, unmatched_trks = self.assign_detections_to_trackers(x_box, z_box, iou_thrd=0.03)
#
#         # Deal with matched detections
#         if matched.size > 0:
#             for trk_idx, det_idx in matched:
#                 z = z_box[det_idx]
#                 z = np.expand_dims(z, axis=0).T
#                 tmp_trk = self.tracker_list[trk_idx]
#                 tmp_trk.kalman_filter(z)
#                 xx = tmp_trk.x_state.T[0].tolist()
#                 xx = [xx[0], xx[2], xx[4], xx[6]]
#                 x_box[trk_idx] = xx
#                 tmp_trk.box = xx
#                 tmp_trk.hits += 1
#
#         # Deal with unmatched detections
#         if len(unmatched_dets) > 0:
#             for idx in unmatched_dets:
#                 z = z_box[idx]
#                 z = np.expand_dims(z, axis=0).T
#                 tmp_trk = Tracker()  # Create a new tracker
#                 x = np.array([[z[0], 0, z[1], 0, z[2], 0, z[3], 0]]).T
#                 tmp_trk.x_state = x
#                 tmp_trk.predict_only()
#                 xx = tmp_trk.x_state
#                 xx = xx.T[0].tolist()
#                 xx = [xx[0], xx[2], xx[4], xx[6]]
#                 tmp_trk.box = xx
#                 tmp_trk.id = self.track_id_list.popleft()  # assign an ID for the tracker
#
#                 self.tracker_list.append(tmp_trk)
#                 x_box.append(xx)
#
#         # Deal with unmatched tracks
#         if len(unmatched_trks) > 0:
#             for trk_idx in unmatched_trks:
#                 tmp_trk = self.tracker_list[trk_idx]
#                 tmp_trk.no_losses += 1
#                 tmp_trk.predict_only()
#                 xx = tmp_trk.x_state
#                 xx = xx.T[0].tolist()
#                 xx = [xx[0], xx[2], xx[4], xx[6]]
#                 tmp_trk.box = xx
#                 x_box[trk_idx] = xx
#
#         # The list of tracks to be annotated
#         good_tracker_list = []
#         objects = []
#         boxs = []
#         for trk in self.tracker_list:
#             if (trk.hits >= self.min_hits) and (trk.no_losses <= self.max_age):
#                 good_tracker_list.append(trk)
#                 x_cv2 = trk.box
#                 left, top, right, bottom = x_cv2[1], x_cv2[0], x_cv2[3], x_cv2[2]
#                 centroid = [int(left + ((right - left) / 2)), bottom]
#                 objects.append([int(trk.id), centroid])
#                 boxs.append(x_cv2)
#
#         deleted_tracks = filter(lambda _x: _x.no_losses > self.max_age, self.tracker_list)
#
#         for trk in deleted_tracks:
#             self.track_id_list.append(trk.id)
#
#         self.tracker_list = [x for x in self.tracker_list if x.no_losses <= self.max_age]
#        # print("object is ", str(objects))
#
#         return img, objects, boxs
#
#     @staticmethod
#     def assign_detections_to_trackers(
#             trackers,
#             detections,
#             iou_thrd=0.3,
#     ):
#         """
#         From current list of trackers and new detections, output matched detections,
#         un matched trackers, unmatched detections.
#         """
#         iou_mat = np.zeros((len(trackers), len(detections)), dtype=np.float32)
#         for t, trk in enumerate(trackers):
#             for d, det in enumerate(detections):
#                 iou_mat[t, d] = box_iou2(trk, det)
#
#         matched_idx = linear_assignment(-iou_mat)
#
#         unmatched_trackers, unmatched_detections = [], []
#         for t, trk in enumerate(trackers):
#             if t not in matched_idx[:, 0]:
#                 unmatched_trackers.append(t)
#
#         for d, det in enumerate(detections):
#             if d not in matched_idx[:, 1]:
#                 unmatched_detections.append(d)
#
#         matches = []
#
#         for m in matched_idx:
#             if iou_mat[m[0], m[1]] < iou_thrd:
#                 unmatched_trackers.append(m[0])
#                 unmatched_detections.append(m[1])
#             else:
#                 matches.append(m.reshape(1, 2))
#
#         if len(matches) == 0:
#             matches = np.empty((0, 2), dtype=int)
#         else:
#             matches = np.concatenate(matches, axis=0)
#
#         return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
#
#
#
#     def get_line_coordinates(self):
#         """
#         Get the line coordinates from the deployment JSON
#         """
#         if not self.janus_metadata.get('metadata'):
#             self.janus_metadata['metadata'] = get_extra_fields(self.device_id)
#
#         _coordinates = [self.janus_metadata['metadata'].get(coordinate_key) for coordinate_key in
#                         JanusDeploymentConstants.LINE_COORDINATES]
#         _alignment = self.janus_metadata['metadata'].get(JanusDeploymentConstants.ALIGNMENT_KEY)\
#         # _coordinates = [550, 200, 555, 1100]
#         #
#         # _alignment = "vertical"
#         return _alignment, _coordinates
#
#     def line_point_position(self, point):
#         """
#         Get the position of point w.r.t. the line
#         :param point: point to be compared
#         :return: boolean
#         """
#         _alignment, line_coordinates = self.get_line_coordinates()
#
#         assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
#         assert len(point) == 2, "Point variable is invalid"
#
#         _slope = (line_coordinates[3] - line_coordinates[1]) / (line_coordinates[2] - line_coordinates[0])
#         _point_equation_value = point[1] - line_coordinates[1] - _slope * (point[0] - line_coordinates[0])
#         if _point_equation_value > 0:
#             return True
#         else:
#             return False
#
#     def validate_point_position(self, point):
#         """
#         Validate the position of the point w.r.t. the line
#         :param point: centroid
#         :return: bool
#         """
#         _alignment, line_coordinates = self.get_line_coordinates()
#         assert _alignment in [JanusDeploymentConstants.VERTICAL, JanusDeploymentConstants.HORIZONTAL], \
#             "Invalid alignment variable"
#         if _alignment == JanusDeploymentConstants.VERTICAL:
#         # _alignment, line_coordinates = self.get_line_coordinates()
#         # assert _alignment in ["horizontal", "vertical"], \
#         #     "Invalid alignment variable"
#         # print(point)
#         # if _alignment == "vertical":
#             line_y2 = line_coordinates[3]
#             line_y1 = line_coordinates[1]
#             if line_y1 < point[1] < line_y2 or line_y2 < point[1] < line_y1:
#                 return True
#             else:
#                 return False
#         else:
#             line_x2 = line_coordinates[2]
#             line_x1 = line_coordinates[0]
#             if line_x1 < point[0] < line_x2 or line_x2 < point[0] < line_x1:
#
#                 return True
#             else:
#                 return False
#
#     def update_bag_count(self, frame, detection_objects, class_name, detections):
#         """
#         Maintains the bag counts
#         :param frame: image
#         :param detection_objects: detection object having object id and centroids
#         """
#         #for class_name, (objectID, centroid) in zip(classes, detection_objects):
#         for (object_id, det) in zip(detection_objects, detections):
#             centroid = object_id[1]
#             object_id = object_id[0]
#             logger.debug(detections)
#             #print(object_id)
#             frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 0, 255),
#                                           thickness=-1)
#             if self.validate_point_position(centroid):
#                 logger.debug("centroid detected")
#                 #if self.validate_point_position(centroid):
#
#                 # # if not isinstance(self.count, int):
#                 # #     logger.debug("Initializing the count variable")
#                 # print("again entering")
#                 # # Initializing the bag count
#                 # if (class_name == "acc_gold"):
#                 #     if not isinstance(self.count_gold, int):
#                 #         logger.debug("Initializing the count variable")
#                 #         self.count_gold = 0
#                 # elif (class_name == "acc_suraksha"):
#                 #     if not isinstance(self.count_suraksha, int):
#                 #         logger.debug("Initializing the count variable")
#                 #         self.count_suraksha = 0
#                 # elif (class_name == "acc_buildcem"):
#                 #     if not isinstance(self.count_whitecem, int):
#                 #         logger.debug("Initializing the count variable")
#                 #         self.count_whitecem = 0
#
#                 if not isinstance(self.initial_object_position, bool):
#                     logger.debug("Initializing the initial object position")
#                     #self.initial_object_position = self.line_point_position(point=centroid)
#                     self.initial_object_position = True
#                     logger.debug(self.initial_object_position)
#
#                 _point_position = self.line_point_position(point=centroid)
#                 #print("object ID is : ", str(objectID))
#                 logger.debug(self.uncounted_objects)
#
#                 # Check point in the same side as the initial object
#                 if _point_position == self.initial_object_position:
#                     logger.debug("same side only")
#                     #print(class_name)
#                     # Check the object is not already counted
#                     if object_id not in self.uncounted_objects:
#                         self.uncounted_objects[object_id] = centroid
#
#
#                 elif object_id in self.uncounted_objects:
#                     self.uncounted_objects.pop(object_id, None)
#                     if ("acc_gold" in class_name):
#                         self.count_gold += 1
#                         mrp_result = self.distances(detections)
#                         if mrp_result:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='ACC GOLD Bag Detected: Print Detected!')
#                             logger.info(f"Count: {self.count_gold}, Print Found: True")
#                         else:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='ACC GOLD Bag Detected: Print Missing!')
#                             logger.info(f"Count: {self.count_gold}, Print Found: False")
#
#
#
#                     elif ("acc_suraksha_plus" in class_name):
#                         self.count_suraksha += 1
#                         logger.debug(self.count_suraksha)
#                         mrp_result = self.distances(detections)
#                         if mrp_result:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='ACC SURAKSHA PLUS Bag Detected: Print Detected!')
#                             logger.info(f"Count: {self.count_suraksha}, Print Found: True")
#                         else:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='ACC SURAKSHA PLUS Bag Detected: Print Missing!')
#                             logger.info(f"Count: {self.count_suraksha}, Print Found: False")
#
#                     elif ("ambuja_whitecem" in class_name):
#                         self.count_whitecem += 1
#                         mrp_result = self.distances(detections)
#                         if mrp_result:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='PPC White Bag Detected: Print Detected!')
#                             logger.info(f"Count: {self.count_whitecem}, Print Found: True")
#                         else:
#                             self.send_payload(resize_to_64_64(frame=frame),
#                                               message='PPC White Bag Detected: Print Missing!')
#                             logger.info(f"Count: {self.count_whitecem}, Print Found: False")
#
#                     frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 255, 0),
#                                                   thickness=-1)
#                     # if centroid['has_print']:
#                     #     self.send_payload(resize_to_64_64(frame=frame), message='Print Detected!')
#                     #     logger.info(f"Count: {self.count}, Print Found: True")
#                     # else:
#                     #     self.send_payload(resize_to_64_64(frame=frame), message='Print Missing!')
#                     #     logger.info(f"Count: {self.count}, Print Found: False")
#                 else:
#                     frame = draw_circles_on_frame(frame, centroid, radius=10, color=(0, 255, 0),
#                                                   thickness=-1)
#
#         count_text_gold = "ACC_GOLD: " + str(self.count_gold)
#         count_text_suraksha = "ACC_SURAKSHA_PLUS: " + str(self.count_suraksha)
#         count_text_whitecem = "ACC_WHITE_CEM: " + str(self.count_whitecem)
#         cv2.putText(frame, count_text_gold, (1300, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
#                     cv2.LINE_AA)
#         cv2.putText(frame, count_text_suraksha, (1300, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
#                     cv2.LINE_AA)
#         cv2.putText(frame, count_text_whitecem, (1300, 600), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3,
#                     cv2.LINE_AA)
#         return frame
#
#     def draw_line_over_image(self, frame, color=(255, 255, 255)):
#         """
#         Draws line over the counting line
#         :param frame: frame for
#         :param color:
#         :return:
#         """
#         _alignment, line_coordinates = self.get_line_coordinates()
#         assert len(line_coordinates) == 4, "Line coordinates variable is invalid"
#
#         # return cv2.line(frame, (line_coordinates[0], line_coordinates[1]), (line_coordinates[2], line_coordinates[3]),
#         #                 color, 3)
#
#         self.drawline(frame, (line_coordinates[0], line_coordinates[1]), (line_coordinates[2],
#                                                                           line_coordinates[3]), color, thickness=3)
#         return frame
#
#     @staticmethod
#     def drawline(img, pt1, pt2, color, thickness=1, style='dotted', gap=20):
#         dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5
#         pts = []
#         for i in np.arange(0, dist, gap):
#             r = i / dist
#             x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5)
#             y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5)
#             p = (x, y)
#             pts.append(p)
#
#         if style == 'dotted':
#             for p in pts:
#                 cv2.circle(img, p, thickness, color, -1)
#         else:
#             s = pts[0]
#             e = pts[0]
#             i = 0
#             for p in pts:
#                 s = e
#                 e = p
#                 if i % 2 == 1:
#                     cv2.line(img, s, e, color, thickness)
#                 i += 1
#
#     def distances(self, detections):
#         mrp_cord = list()
#         cem_bag_cord = list()
#         for det in detections:
#             if(det["class"] == "mrp"):
#                 mrp_cord.append(det["points"])
#             else:
#                 cem_bag_cord.append(det["points"])
#         for c_cord in cem_bag_cord:
#             for m_cord in mrp_cord:
#
#                 if (m_cord[0] > c_cord[0] and m_cord[0] < c_cord[2] and
#                         m_cord[1] > c_cord[1] and m_cord[1] < c_cord[3]):
#                     logger.debug("print is detected")
#                     #cv2.waitKey(0)
#                     return True
#                 else:
#                     return False
#
#
#     def inference(
#             self,
#             frame,
#             classes,
#
#     ):
#         dets = self.yp.predict(frame)
#         class_name = list()
#         bboxs = []
#
#         if dets:
#             for i in dets:
#                 if(i["class"] in classes):
#                     class_name.append(i["class"])
#                     #cv2.rectangle(frame, (i["points"][0], i["points"][1]), (i["points"][2], i["points"][3]), (255, 255, 0), 2)
#                     bboxs.append([i["points"][1], i["points"][0], i["points"][3], i["points"][2]])
#
# #        frame = cv2.rectangle(frame, (bboxs[0][0], bboxs[0][1]), (bboxs[0][2], bboxs[0][3]),(255, 255, 0) , 2)
#         return bboxs, frame, dets, class_name
#
    def _predict(self, obj):
        """
        Run detection, tracking and bag counting on one frame of the stream.

        :param obj: dict carrying the current image under the ``'frame'`` key.
            That key is overwritten with the processed frame, resized to
            ``(FRAME_WIDTH, FRAME_HEIGHT)`` taken from ``self.config``.
        :return: the same ``obj`` dict. If any step inside the ``try`` block
            fails, the error is logged and ``obj['frame']`` is set to the
            resized input frame instead.
        """
        # NOTE(review): this calls ``run_demo`` with the module-level ``model``,
        # ``cap`` and ``labels`` globals on every call, before (and outside) the
        # error handling below, so anything it raises reaches the caller.
        # Confirm that a per-frame call is really intended here.
        self.run_demo(model, cap, labels)
        class_list = ["acc_gold", "acc_suraksha_plus", "ambuja_whitecem"]
        output_size = (self.config.get('FRAME_WIDTH'), self.config.get('FRAME_HEIGHT'))
        try:
            frame = obj['frame']

            if self.frame_skip and self.frame_skipping["skip_current_frame"]:
                # Skipped frame: reuse everything cached from the previous frame
                # instead of running the detector again.
                dets, _dets, class_name = self.frame_skipping["detection_value"]
                self.frame_skipping["skip_current_frame"] = False
            else:
                # ``inference`` yields (tracker boxes, frame, detections, class
                # names); the tracker and the counter below need all of them.
                dets, frame, _dets, class_name = self.inference(frame, class_list)
                if self.frame_skip:
                    # Cache the full result (not just the raw detector output) so
                    # the next, skipped frame has every value used further down.
                    self.frame_skipping["detection_value"] = (dets, _dets, class_name)
                    self.frame_skipping["skip_current_frame"] = True

            frame = self.draw_line_over_image(frame)

            frame, objects, boxs = self.kalman_tracker(dets, frame)
            logger.debug("PRINTING KALMAN OUTPUT")
            logger.debug(objects)
            logger.debug(boxs)

            frame = self.update_bag_count(frame=frame, detection_objects=objects,
                                          class_name=class_name, detections=_dets)
            logger.debug("self.uncounted_objects --> {}".format(self.uncounted_objects))

            obj['frame'] = cv2.resize(frame, output_size)
        except Exception as e:
            # Best effort: never drop the frame, just pass it through resized.
            logger.exception(f"Error: {e}", exc_info=True)
            obj['frame'] = cv2.resize(obj['frame'], output_size)

        return obj

parser = ArgumentParser()
parser.add_argument("--encoder", help="What encoder to use ", default='resnet34')
parser.add_argument("--checkpoint", help="Path to pretrained model (.pth) file",
                    default="/home/shikhin/PycharmProjects/activity_recognition_gokaldas/action_recognition_light/model_v2/save_10.pth")
parser.add_argument("--input-video", type=str, help="Path to input video",
                    default="/home/shikhin/PycharmProjects/activity_recognition_gokaldas/action_recognition_light/model_v2_test/kinari.mp4")
parser.add_argument("--labels", help="Path to labels file (new-line separated file with label names)", type=str,
                    default="/home/shikhin/PycharmProjects/activity_recognition_gokaldas/action_recognition_light/labels.txt")
# This runs at import time, and the resulting globals (model, cap, labels) are
# used by the class above. Use parse_known_args so that importing this module
# from a host process with its own command-line flags does not abort with
# SystemExit; unrecognised flags are reported rather than silently dropped.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    logger.warning(f"Ignoring unrecognised command-line arguments: {_unknown_args}")

# One label per line. splitlines() also handles CRLF files, where split('\n')
# would leave a trailing '\r' on every label.
with open(args.labels, encoding="utf-8") as fd:
    labels = fd.read().strip().splitlines()

model = TorchActionRecognition(args.encoder, args.checkpoint, num_classes=len(labels))


cap = cv2.VideoCapture(args.input_video)
if not cap.isOpened():
    logger.warning(f"Could not open input video: {args.input_video}")
# cap = cv2.VideoCapture("yawn.mp4")
# cap = cv2.VideoCapture("rtsp://localhost:8554/stream")
# CementBagCounter_1 = CementBagCounter(ModelWrapper)
# CementBagCounter_1.run_demo(model, cap, labels)



# if __name__ == '__main__':
#     sys.exit(main())


