import argparse
import json
import os
import shutil
from collections import defaultdict
from pathlib import Path
import itertools

import numpy as np
from tqdm import tqdm

# Command-line options.  Neither option carries a machine-specific default:
# when an option is omitted its value is None and the caller must pick a
# fallback path itself.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--coco_json_path",
    default=None,
    type=str,
    help="input: coco format(json)",
)

parser.add_argument(
    "--yolo_save_root_dir",
    default=None,
    type=str,
    help="specify where to save the output dir of labels",
)


def bbox_dict_to_list(bbox, image_size):
    """Return the first point of *bbox* normalised by the image size.

    Despite the name, only ``bbox[0]`` (x) and ``bbox[1]`` (y) are read; any
    further items (e.g. a visibility flag in ``[x, y, v]``) are ignored.

    Args:
        bbox: sequence whose first two items are pixel coordinates ``x, y``.
        image_size: ``(image_width, image_height)`` in pixels.

    Returns:
        A new list ``[x / image_width, y / image_height]``.
    """
    img_w, img_h = image_size
    return [bbox[0] / img_w, bbox[1] / img_h]


def convert_bbox_to_yolo(size, box):
    """Convert a top-left-anchored pixel box to normalised YOLO form.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(x_min, y_min, width, height)`` in pixels.

    Returns:
        ``(x_center, y_center, width, height)`` divided by the image size and
        rounded to 6 decimal places.
    """
    inv_w = 1.0 / (size[0])
    inv_h = 1.0 / (size[1])
    box_w = box[2]
    box_h = box[3]
    center_x = box[0] + box_w / 2.0
    center_y = box[1] + box_h / 2.0
    # Scale by the reciprocal (rather than dividing) and round to 6 places.
    return (
        round(center_x * inv_w, 6),
        round(center_y * inv_h, 6),
        round(box_w * inv_w, 6),
        round(box_h * inv_h, 6),
    )


def convert_keypoints2_list(keypoints, img_width, img_height, num_keypoints=17):
    """Normalise a flat keypoint list ``[x1, y1, v1, x2, y2, v2, ...]``.

    Args:
        keypoints: flat sequence of ``(x, y, visibility)`` triples in pixels.
        img_width: image width used to normalise x.
        img_height: image height used to normalise y.
        num_keypoints: number of triples to emit (17 by default).  Values
            beyond ``3 * num_keypoints`` are ignored; missing values are
            filled with ``0.0``.

    Returns:
        A list of ``3 * num_keypoints`` floats
        ``[x / img_width, y / img_height, float(v), ...]`` where the x and y
        values are truncated (not rounded) to 6 decimal places.
    """
    scale = 10 ** 6  # truncate normalised coordinates to 6 decimal places

    def _truncated(value, extent):
        return int((value / extent) * scale) / scale

    n_values = len(keypoints)
    result = []
    for base in range(0, 3 * num_keypoints, 3):
        x = _truncated(keypoints[base], img_width) if base < n_values else 0.0
        y = _truncated(keypoints[base + 1], img_height) if base + 1 < n_values else 0.0
        v = float(keypoints[base + 2]) if base + 2 < n_values else 0.0
        result.extend((x, y, v))
    return result


def listToString(s):
    """Join the items of *s* into one space-separated string.

    Each item is converted with ``str()`` and followed by a single space, so a
    non-empty result keeps a trailing space (``[0, 0.5] -> "0 0.5 "``); an
    empty iterable yields ``""``.
    """
    # One str.join pass instead of repeated string concatenation.
    return "".join(str(ele) + " " for ele in s)



def main(root_dir, ana_txt_save_path_txt, json_file):
    """Convert a COCO-style keypoint JSON into YOLO-pose label files.

    Annotations are grouped per image.  Annotations without a ``"keypoints"``
    key contribute their ``"bbox"`` (``[x_min, y_min, width, height]``);
    annotations with one contribute every ``[x, y, v]`` triple of their flat
    ``"keypoints"`` list, in the order they appear in the JSON.  Each image
    that has a positive-size box gets ``<file_name stem>.txt`` in
    ``ana_txt_save_path_txt`` containing a single line::

        0 cx cy w h x1 y1 v1 x2 y2 v2 ...

    Box and keypoint coordinates are divided by the image width/height;
    visibility flags are written unchanged.  Every object is written with
    class index 0, although ``classes.txt`` lists every COCO category; revisit
    this if more than one object class is needed.  If an image has several
    valid boxes only the last one is used, because keypoint annotations carry
    no link to a particular box.  Images without a valid box are skipped.

    Args:
        root_dir: directory that receives ``classes.txt`` (one category name
            per line, in JSON order); created if missing.
        ana_txt_save_path_txt: directory that receives the per-image label
            files; created if missing.
        json_file: path to the COCO-format annotation JSON.
    """
    with open(json_file, "r") as fp:
        data = json.load(fp)

    os.makedirs(root_dir, exist_ok=True)
    label_dir = Path(ana_txt_save_path_txt)
    label_dir.mkdir(parents=True, exist_ok=True)

    with open(os.path.join(root_dir, "classes.txt"), "w") as fp:
        for category in data["categories"]:
            fp.write(f"{category['name']}\n")

    # Key images by their raw id: formatting ids with "%g" would map distinct
    # ids >= 1e6 onto the same string (e.g. "1.23457e+06").
    images = {img["id"]: img for img in data["images"]}

    anns_by_image = defaultdict(list)
    for ann in data["annotations"]:
        anns_by_image[ann["image_id"]].append(ann)

    for img_id, anns in tqdm(anns_by_image.items(), desc=f"Annotations {json_file}"):
        img = images[img_id]
        row = _label_row(anns, img["width"], img["height"])
        if row is None:
            tqdm.write(f"Skipping {img['file_name']}: no valid bounding box annotation")
            continue

        out_path = (label_dir / img["file_name"]).with_suffix(".txt")
        # file_name may contain sub-directories; make sure they exist.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as fp:
            # Values separated by single spaces, with a trailing space.
            fp.write("".join(str(value) + " " for value in row))


def _label_row(anns, img_w, img_h):
    """Build ``[0, cx, cy, w, h, x1, y1, v1, ...]`` for one image, or None if it has no valid box."""
    box = None
    keypoints = []
    for ann in anns:
        if "keypoints" in ann:
            kps = ann["keypoints"]
            # Flat [x, y, v, x, y, v, ...]: normalise x/y, keep v unchanged.
            for i in range(0, len(kps) - 2, 3):
                keypoints += [kps[i] / img_w, kps[i + 1] / img_h, kps[i + 2]]
        else:
            x, y, bw, bh = (float(v) for v in ann["bbox"])
            if bw <= 0 or bh <= 0:
                continue  # degenerate box: never emit it
            # Top-left corner -> centre, then normalise by the image size.
            box = [(x + bw / 2) / img_w, (y + bh / 2) / img_h, bw / img_w, bh / img_h]
    if box is None:
        return None
    return [0] + box + keypoints


if __name__ == "__main__":
    args = parser.parse_args()
    print("Parsing and creating directories...")
    # Fallback paths, used only when the corresponding option is missing or
    # empty (a non-empty parser default takes precedence).
    DEFAULT_ROOT_DIR = "/home/shikhin/Pictures/rectangular_gauge_pose/"
    DEFAULT_COCO_JSON_FILE = "/home/shikhin/Pictures/rectangular_gauge_pose/rectangular_gauge_keypoint_annotation-6.json"

    ROOT_DIR = args.yolo_save_root_dir or DEFAULT_ROOT_DIR
    COCO_JSON_FILE = args.coco_json_path or DEFAULT_COCO_JSON_FILE
    YOLO_ANNO_TXT_SAVE_PATH = os.path.join(ROOT_DIR, "labels", "train")
    print(ROOT_DIR, COCO_JSON_FILE, YOLO_ANNO_TXT_SAVE_PATH)

    main(ROOT_DIR, YOLO_ANNO_TXT_SAVE_PATH, COCO_JSON_FILE)

## USAGE
# python coco_keypointjson2yolo.py  --coco_json_path /home/hodor/dev/data/coco_test/person_keypoints_train2017.json  --yolo_save_root_dir /home/hodor/dev/data/coco_test/yolo/