# scapy's star import injects a very large namespace (IP, ICMP, ARP, Ether,
# sr1, srp1, ...). It comes first so none of the explicit imports below can be
# shadowed by a same-named object re-exported from scapy.
from scapy.all import *

import json
import re
import socket
import threading
import time
from contextlib import closing
from datetime import datetime

import requests
import yaml

import json_operations
import scan_port
from helpers.Packet_Analyzer import Packet_Analyzer

# Scanner configuration, loaded once at import time and read by Active_Scan.
# The `with` block closes the file handle instead of leaking it for the
# lifetime of the process.
with open("config.yaml") as yamlfile:
    data = yaml.load(yamlfile, Loader=yaml.FullLoader)

class Active_Scan():
    """ICMP-sweep the IPv4 range configured in config.yaml and probe live hosts.

    Configuration comes from the module-level ``data`` dict. Each host that
    answers an ICMP echo is appended to ``self.found_hosts``, POSTed to the
    configured server, and recorded in assets.json via
    ``json_operations.insert_asset``. ``scan_for_assets`` then checks every
    found host for open ports from ``OT_PORT_LIST``.
    """

    # Serialises the read-check-insert on assets.json across the probe threads
    # of every instance. Without it, two threads can both see an IP as missing
    # and insert it twice, or interleave writes. Writers outside this class do
    # not take this lock.
    _assets_lock = threading.Lock()

    def __init__(self):
        # IPs that answered the ICMP echo, in discovery order. Never cleared,
        # so it accumulates across scans made with the same instance.
        self.found_hosts = []
        # POST target used by send_icmp for every discovered host.
        self.url = data["configuration"]["unifytwin_server_ip_address"]
        # Presumably the well-known ICS/OT service ports (EtherNet/IP 44818,
        # S7comm 102, OMRON FINS 9600, Modbus/TCP 502, BACnet/IP 47808) - confirm.
        self.OT_PORT_LIST = [44818, 102, 9600, 502, 47808]
        self.pa = Packet_Analyzer()

    def is_valid_ipv4(self, address):
        """Return True if *address* is a dotted-quad IPv4 string with octets 0-255.

        ``re.fullmatch`` is used because ``$`` on its own also matches before a
        trailing newline, which would let "1.2.3.4\\n" through.
        """
        pattern = r"^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
        return re.fullmatch(pattern, address) is not None

    def check_if_ip_exist_and_insert(self, ip, mac_addr, p):
        """Record *ip* as an ICMP-discovered asset unless assets.json already lists it.

        Args:
            ip: Dotted-quad address of the responding host.
            mac_addr: Host MAC from the ARP reply, or None when unresolved.
                Unresolved MACs are stored as "Unknown" and vendor lookup
                is skipped.
            p: The scapy reply packet; its raw bytes (hex) are passed to
                Packet_Analyzer.get_os.
        """
        with self._assets_lock:
            try:
                with open("assets.json") as json_file:
                    known_assets = json.load(json_file)
            except (OSError, ValueError):
                # A missing or unparsable inventory is treated as empty.
                known_assets = {}
            if ip in known_assets:
                return
            vendor = self.pa.get_vendor(mac_addr) if mac_addr else "Unknown"
            json_operations.insert_asset(
                ip,
                datetime.now().strftime("%d-%m-%y %H:%M:%S"),
                "Unknown",
                "ICMP",
                mac_addr if mac_addr else "Unknown",
                vendor,
                "Unknown",
                "Unknown",
                self.pa.get_os(bytes(p).hex()),
                "Unknown",
            )

    def send_icmp(self, ip):
        """Ping *ip*; if it answers with an echo-reply, record and report it."""
        reply = sr1(IP(dst=ip) / ICMP() / "XXXXXXXXXXX", timeout=2, verbose=False)
        # sr1 also returns ICMP error messages that quote our probe, for example
        # destination-unreachable from a router. Only an echo-reply (type 0)
        # means the target itself is up.
        if reply is None or ICMP not in reply or reply[ICMP].type != 0:
            return
        # Send the ARP who-has and read the responder's hardware address.
        # Reading `.src` off a locally built, unsent frame would yield this
        # machine's own MAC. ARP only resolves hosts on the local segment, so
        # off-subnet hosts end up with None.
        arp_reply = srp1(Ether(dst="ff:ff:ff:ff:ff:ff") / ARP(pdst=ip), timeout=2, verbose=False)
        mac_addr = arp_reply[ARP].hwsrc if arp_reply is not None and ARP in arp_reply else None
        self.found_hosts.append(ip)
        # NOTE(review): this posts the entire loaded config (`data`), not
        # details of the discovered host - confirm this is the intended payload.
        try:
            # The timeout keeps an unresponsive server from stalling a probe worker forever.
            requests.post(url=self.url, json=data, timeout=10)
        except requests.RequestException as err:
            # A reporting failure must not prevent the asset from being recorded.
            print(f"Failed to notify {self.url}: {err}")
        self.check_if_ip_exist_and_insert(ip, mac_addr, reply)

    def scan_single_host(self, host):
        """Probe *host* once the process has fewer than ``threads + 5`` live threads.

        Blocks the caller, polling once per second, until the thread budget
        from ``data["configuration"]["threads"]`` allows it to proceed.
        """
        while threading.active_count() >= int(data["configuration"]["threads"]) + 5:
            time.sleep(1)
        self.send_icmp(host)

    @staticmethod
    def _ip_to_int(address):
        """Convert a validated dotted-quad string to its 32-bit integer value."""
        value = 0
        for octet in address.split("."):
            value = (value << 8) | int(octet)
        return value

    @staticmethod
    def _int_to_ip(value):
        """Format a 32-bit integer as a dotted-quad IPv4 string."""
        return ".".join(str((value >> shift) & 0xFF) for shift in (24, 16, 8, 0))

    def _address_range(self, start, end):
        """Return the inclusive range of integer addresses to probe, or an empty range.

        The range is refused (with a printed message) when either end is not
        valid IPv4, when the two ends differ in the first octet, or when a
        172.x range leaves 172.16-172.31. An empty range also results when
        start > end.
        """
        if not self.is_valid_ipv4(start) or not self.is_valid_ipv4(end):
            print("Invalid start or end address.")
            return range(0)
        first = start.split(".")
        last = end.split(".")
        # Spanning first octets is refused; presumably this guards against
        # accidental multi-/8 sweeps - NOTE(review): confirm intent.
        if int(first[0]) != int(last[0]):
            print("Start and end addresses must share the same first octet.")
            return range(0)
        if int(first[0]) == 172 and not (16 <= int(first[1]) <= 31 and 16 <= int(last[1]) <= 31):
            print("invalid address format")
            return range(0)
        return range(self._ip_to_int(start), self._ip_to_int(end) + 1)

    def _probe_all(self, addresses, max_workers):
        """Run send_icmp on every address with at most *max_workers* concurrent probes.

        Returns only after every probe has finished. Exceptions raised by a
        probe are printed rather than lost.
        """
        from concurrent.futures import ThreadPoolExecutor

        max_workers = max(1, max_workers)
        # Caps queued-but-unstarted work so sweeping a /8 does not allocate
        # millions of futures up front.
        slots = threading.BoundedSemaphore(2 * max_workers)

        def _on_done(future):
            slots.release()
            error = future.exception()
            if error is not None:
                print(f"ICMP probe failed: {error!r}")

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            for address in addresses:
                slots.acquire()
                pool.submit(self.send_icmp, address).add_done_callback(_on_done)

    def scan_host_range(self, start, end):
        """ICMP-sweep every address from *start* to *end* inclusive.

        Probes run on a worker pool sized by ``data["configuration"]["threads"]``.
        The call waits only for its own probes, not for unrelated threads in
        the process.

        Returns:
            ``self.found_hosts``, the accumulated list of responsive IPs. It is
            also returned when the range is rejected.
        """
        address_range = self._address_range(start, end)
        if address_range:
            self._probe_all(map(self._int_to_ip, address_range),
                            int(data["configuration"]["threads"]))
        return self.found_hosts

    def find_open_port(self, ip):
        """TCP-connect to each port in OT_PORT_LIST on *ip*; pass open ones to scan_port.map_ports."""
        for dst_port in self.OT_PORT_LIST:
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
                sock.settimeout(3)  # seconds allowed per connection attempt
                if sock.connect_ex((ip, dst_port)) == 0:  # 0: the TCP handshake completed
                    scan_port.map_ports(ip, dst_port)

    def scan_for_assets(self):
        """Run the configured active scan, then port-probe every responsive host.

        Does nothing when ``data["configuration"]["active_scan"]`` is falsy. In
        that case the start/end keys are not read.
        """
        config = data["configuration"]
        if not config["active_scan"]:
            return
        hosts = self.scan_host_range(config["active_scan_start"], config["active_scan_end"])
        for host in hosts:
            self.find_open_port(host)