import logging
from threading import RLock, Lock
import time
from KL import hand_protocol, application_protocol
from KL.status_codes import StatusCodes
from KL.types import VariantType, StatusCode, NodeId
import struct
from enum import Enum
import sys
import KL
from common.connection import SecureConnection
from KL.binary import header_to_binary, struct_from_binary, nodeid_from_binary, struct_to_binary
from KL.constants import ObjectIds
from KL import binary

# Python 3 dropped the ``unicode`` builtin; alias it to ``str`` so code
# written against the Python 2 name keeps working.
if sys.version_info[0] >= 3:
    unicode = str


class PublishRequestData(object):
    """Record of a PublishRequest that is waiting to be answered.

    The header fields start out as ``None`` and are filled in by the code
    that queues the request; ``timestamp`` records when the record was
    created (seconds since the epoch, from ``time.time()``).
    """

    def __init__(self):
        # Arrival time first; the headers are assigned by the caller afterwards.
        self.timestamp = time.time()
        self.requesthdr = None  # RequestHeader of the pending PublishRequest
        self.seqhdr = None  # sequence header used to address the eventual response


class Processor(object):
    """Per-connection request processor for the binary server protocol.

    ``process`` is fed decoded chunk headers and bodies. It answers Hello
    messages, opens and closes the secure channel, and dispatches service
    requests (sessions, reads, writes, subscriptions, ...) to the internal
    server and to the session it creates. Responses are written to ``socket``.
    """

    def __init__(self, internal_server, socket):
        """
        Args:
            internal_server: server object providing ``create_session`` and
                ``get_endpoints``.
            socket: transport exposing ``get_extra_info`` and ``write``
                (NOTE(review): presumably an asyncio-style transport).
        """
        self.logger = logging.getLogger(__name__)
        self.iserver = internal_server
        self.name = socket.get_extra_info('peername')
        self.sockname = socket.get_extra_info('sockname')
        self.session = None
        self.socket = socket
        # Serializes writes to the socket: publish responses are sent through
        # the callback handed to the session, possibly from another thread.
        self._socketlock = Lock()
        # Re-entrant because forward_publish_response() is invoked while the
        # publish branch of _process_message already holds this lock.
        self._datalock = RLock()
        self._publishdata_queue = []  # PublishRequests waiting for a notification
        self._publish_result_queue = []  # used when we need to wait for PublishRequest
        self._connection = SecureConnection(hand_protocol.SecurityPolicy())

    def set_policies(self, policies):
        """Currently a no-op: only the default policy from ``__init__`` is used."""
        # self._connection.set_policy_factories(policies)
        pass

    def send_response(self, requesthandle, seqhdr, response, msgtype=hand_protocol.MessageType.SecureMessage):
        """Encode ``response`` for the secure channel and write it to the socket.

        Args:
            requesthandle: client handle echoed into the response header.
            seqhdr: sequence header of the request; its RequestId addresses the reply.
            response: response structure with a ``ResponseHeader``.
            msgtype: message type of the outgoing chunk(s).
        """
        with self._socketlock:
            response.ResponseHeader.RequestHandle = requesthandle
            data = self._connection.message_to_binary(
                struct_to_binary(response), message_type=msgtype, request_id=seqhdr.RequestId)

            self.socket.write(data)

    def process(self, header, body):
        """Handle one received chunk.

        Args:
            header: decoded chunk header; ``header.MessageType`` selects the handler.
            body: chunk payload, handed to the secure connection for decoding.

        Returns:
            False when the channel was closed or a service handler asked to stop
            (presumably the caller then drops the connection), True otherwise.

        Raises:
            Exception: carrying ``StatusCodes.BadTcpMessageTypeInvalid`` for
                unsupported message kinds.
        """
        msg = self._connection.receive_from_header_and_body(header, body)
        if isinstance(msg, hand_protocol.Message):
            if header.MessageType == hand_protocol.MessageType.SecureOpen:
                self.open_secure_channel(msg.SecurityHeader(), msg.SequenceHeader(), msg.body())

            elif header.MessageType == hand_protocol.MessageType.SecureClose:
                self._connection.close()
                return False

            elif header.MessageType == hand_protocol.MessageType.SecureMessage:
                return self.process_message(msg.SequenceHeader(), msg.body())
        elif isinstance(msg, hand_protocol.Hello):
            # Echo the client's buffer sizes back in the Acknowledge.
            ack = hand_protocol.Acknowledge()
            ack.ReceiveBufferSize = msg.ReceiveBufferSize
            ack.SendBufferSize = msg.SendBufferSize
            data = tcp_to_binary(hand_protocol.MessageType.Acknowledge, ack)
            # Same lock as send_response so socket writes never interleave.
            with self._socketlock:
                self.socket.write(data)
        elif msg is None:
            pass  # msg is a ChunkType.Intermediate of a SecureMessage; wait for the rest
        else:
            self.logger.warning("Unsupported message type: %s", header.MessageType)
            raise Exception(StatusCodes.BadTcpMessageTypeInvalid)
        return True

    def open_secure_channel(self, algohdr, seqhdr, body):
        """Handle an OpenSecureChannel request and reply with the channel parameters.

        NOTE(review): ``process`` calls this method, but the class did not
        define it. The implementation assumes ``SecureConnection`` provides
        ``select_policy(uri, peer_certificate, mode)`` and
        ``open(params, server)``, and that ``application_protocol`` defines
        the OpenSecureChannel request/response structures. Confirm both.
        """
        request = struct_from_binary(application_protocol.OpenSecureChannelRequest, body)

        self._connection.select_policy(
            algohdr.SecurityPolicyURI, algohdr.SenderCertificate, request.Parameters.SecurityMode)

        channel = self._connection.open(request.Parameters, self.iserver)

        response = application_protocol.OpenSecureChannelResponse()
        response.Parameters = channel
        self.send_response(request.RequestHeader.RequestHandle, seqhdr, response,
                           hand_protocol.MessageType.SecureOpen)

    def process_message(self, seqhdr, body):
        """Decode a service request, dispatch it, and turn failures into ServiceFaults.

        Returns:
            The result of ``_process_message``, or True after a ServiceFault
            has been sent for a failed request.
        """
        typeid = nodeid_from_binary(body)
        requesthdr = struct_from_binary(application_protocol.RequestHeader, body)
        try:
            return self._process_message(typeid, requesthdr, seqhdr, body)
        except Exception as e:
            status = StatusCode(self._status_code_from_exception(e))
            response = application_protocol.ServiceFault()
            response.ResponseHeader.ServiceResult = status
            self.logger.info("sending service fault response: %s (%s)", status, status.name)
            self.send_response(requesthdr.RequestHandle, seqhdr, response)
            return True

    def _status_code_from_exception(self, exc):
        """Return the numeric status code carried by ``exc``.

        Handlers raise ``Exception(StatusCodes.X)``, so the code is the first
        positional argument. Any other exception is an unexpected internal
        failure: it is logged with its traceback and reported as
        BadInternalError. (The original passed the exception object itself
        to ``StatusCode``.) Must be called from inside an ``except`` block so
        that ``logger.exception`` can see the traceback.
        """
        if exc.args and isinstance(exc.args[0], int):
            return exc.args[0]
        self.logger.exception("unexpected error while processing request")
        return StatusCodes.BadInternalError

    def _require_session(self):
        """Return the current session, or raise BadSessionIdInvalid if there is none."""
        if not self.session:
            raise Exception(StatusCodes.BadSessionIdInvalid)
        return self.session

    def _discovery_service(self):
        """Return ``self.local_discovery_service``, or raise BadServiceUnsupported.

        NOTE(review): nothing in this class assigns ``local_discovery_service``.
        Presumably it is attached from outside; when it is missing, clients now
        get a proper fault instead of an AttributeError.
        """
        service = getattr(self, "local_discovery_service", None)
        if service is None:
            raise Exception(StatusCodes.BadServiceUnsupported)
        return service

    def _process_message(self, typeid, requesthdr, seqhdr, body):
        """Dispatch one decoded service request on its type id and send the response.

        Args:
            typeid: NodeId of the request's binary encoding.
            requesthdr: decoded RequestHeader; its RequestHandle is echoed back.
            seqhdr: sequence header of the incoming message, used to address the reply.
            body: stream positioned just after the request header.

        Returns:
            True to keep serving; False after CloseSecureChannel, or when a
            PublishRequest arrives without a session.

        Raises:
            Exception: with a ``StatusCodes`` value as its argument for
                unsupported services or a missing session. ``process_message``
                turns these into ServiceFaults.
        """
        if typeid == NodeId(ObjectIds.CreateSessionRequest_Encoding_DefaultBinary):
            self.logger.info("Create session request")
            params = struct_from_binary(application_protocol.CreateSessionParameters, body)

            # create the session on server
            self.session = self.iserver.create_session(self.name)
            # get a session creation result to send back
            sessiondata = self.session.create_session(params, sockname=self.sockname)

            response = application_protocol.CreateSessionResponse()
            response.Parameters = sessiondata
            # NOTE(review): on the server side the policy's "client_certificate"
            # appears to hold our own certificate and "server_certificate" the
            # peer's; confirm against the security policy implementation.
            response.Parameters.ServerCertificate = self._connection.security_policy.client_certificate
            if self._connection.security_policy.server_certificate is None:
                data = params.ClientNonce
            else:
                data = self._connection.security_policy.server_certificate + params.ClientNonce
            response.Parameters.ServerSignature.Signature = \
                self._connection.security_policy.asymmetric_cryptography.signature(data)

            response.Parameters.ServerSignature.Algorithm = self._connection.security_policy.AsymmetricSignatureURI

            self.logger.info("sending create session response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.CloseSessionRequest_Encoding_DefaultBinary):
            self.logger.info("Close session request")

            if self.session:
                deletesubs = binary.Primitives.Boolean.unpack(body)
                self.session.close_session(deletesubs)
            else:
                self.logger.info("Request to close non-existing session")

            response = application_protocol.CloseSessionResponse()
            self.logger.info("sending close session response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.ActivateSessionRequest_Encoding_DefaultBinary):
            self.logger.info("Activate session request")
            params = struct_from_binary(application_protocol.ActivateSessionParameters, body)

            if not self.session:
                self.logger.info("request to activate non-existing session")
                raise Exception(StatusCodes.BadSessionIdInvalid)

            # Check the client's signature over (certificate +) the session nonce.
            if self._connection.security_policy.client_certificate is None:
                data = self.session.nonce
            else:
                data = self._connection.security_policy.client_certificate + self.session.nonce
            self._connection.security_policy.asymmetric_cryptography.verify(data, params.ClientSignature.Signature)

            result = self.session.activate_session(params)

            response = application_protocol.ActivateSessionResponse()
            response.Parameters = result

            self.logger.info("sending activate session response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.ReadRequest_Encoding_DefaultBinary):
            self.logger.info("Read request")
            params = struct_from_binary(application_protocol.ReadParameters, body)

            results = self._require_session().read(params)

            response = application_protocol.ReadResponse()
            response.Results = results

            self.logger.info("sending read response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.WriteRequest_Encoding_DefaultBinary):
            self.logger.info("Write request")
            params = struct_from_binary(application_protocol.WriteParameters, body)

            results = self._require_session().write(params)

            response = application_protocol.WriteResponse()
            response.Results = results

            self.logger.info("sending write response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.BrowseRequest_Encoding_DefaultBinary):
            self.logger.info("Browse request")
            params = struct_from_binary(application_protocol.BrowseParameters, body)

            results = self._require_session().browse(params)

            response = application_protocol.BrowseResponse()
            response.Results = results

            self.logger.info("sending browse response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.GetEndpointsRequest_Encoding_DefaultBinary):
            self.logger.info("get endpoints request")
            params = struct_from_binary(application_protocol.GetEndpointsParameters, body)

            endpoints = self.iserver.get_endpoints(params, sockname=self.sockname)

            response = application_protocol.GetEndpointsResponse()
            response.Endpoints = endpoints

            self.logger.info("sending get endpoints response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.FindServersRequest_Encoding_DefaultBinary):
            self.logger.info("find servers request")
            params = struct_from_binary(application_protocol.FindServersParameters, body)

            servers = self._discovery_service().find_servers(params)

            response = application_protocol.FindServersResponse()
            response.Servers = servers

            self.logger.info("sending find servers response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.RegisterServerRequest_Encoding_DefaultBinary):
            self.logger.info("register server request")
            serv = struct_from_binary(application_protocol.RegisteredServer, body)

            self._discovery_service().register_server(serv)

            response = application_protocol.RegisterServerResponse()

            self.logger.info("sending register server response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.RegisterServer2Request_Encoding_DefaultBinary):
            self.logger.info("register server 2 request")
            params = struct_from_binary(application_protocol.RegisterServer2Parameters, body)

            results = self._discovery_service().register_server2(params)

            response = application_protocol.RegisterServer2Response()
            response.ConfigurationResults = results

            self.logger.info("sending register server 2 response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.TranslateBrowsePathsToNodeIdsRequest_Encoding_DefaultBinary):
            self.logger.info("translate browsepaths to nodeids request")
            params = struct_from_binary(application_protocol.TranslateBrowsePathsToNodeIdsParameters, body)

            paths = self._require_session().translate_browsepaths_to_nodeids(params.BrowsePaths)

            response = application_protocol.TranslateBrowsePathsToNodeIdsResponse()
            response.Results = paths

            self.logger.info("sending translate browsepaths to nodeids response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.AddNodesRequest_Encoding_DefaultBinary):
            self.logger.info("add nodes request")
            params = struct_from_binary(application_protocol.AddNodesParameters, body)

            results = self._require_session().add_nodes(params.NodesToAdd)

            response = application_protocol.AddNodesResponse()
            response.Results = results

            self.logger.info("sending add node response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.DeleteNodesRequest_Encoding_DefaultBinary):
            self.logger.info("delete nodes request")
            params = struct_from_binary(application_protocol.DeleteNodesParameters, body)

            results = self._require_session().delete_nodes(params)

            response = application_protocol.DeleteNodesResponse()
            response.Results = results

            self.logger.info("sending delete node response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.AddReferencesRequest_Encoding_DefaultBinary):
            self.logger.info("add references request")
            params = struct_from_binary(application_protocol.AddReferencesParameters, body)

            results = self._require_session().add_references(params.ReferencesToAdd)

            response = application_protocol.AddReferencesResponse()
            response.Results = results

            self.logger.info("sending add references response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.DeleteReferencesRequest_Encoding_DefaultBinary):
            self.logger.info("delete references request")
            params = struct_from_binary(application_protocol.DeleteReferencesParameters, body)

            results = self._require_session().delete_references(params.ReferencesToDelete)

            response = application_protocol.DeleteReferencesResponse()
            response.Parameters.Results = results

            self.logger.info("sending delete references response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.CreateSubscriptionRequest_Encoding_DefaultBinary):
            self.logger.info("create subscription request")
            params = struct_from_binary(application_protocol.CreateSubscriptionParameters, body)

            # The session pushes notifications back through forward_publish_response.
            result = self._require_session().create_subscription(params, self.forward_publish_response)

            response = application_protocol.CreateSubscriptionResponse()
            response.Parameters = result

            self.logger.info("sending create subscription response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.ModifySubscriptionRequest_Encoding_DefaultBinary):
            self.logger.info("modify subscription request")
            params = struct_from_binary(application_protocol.ModifySubscriptionParameters, body)

            result = self._require_session().modify_subscription(params, self.forward_publish_response)

            response = application_protocol.ModifySubscriptionResponse()
            response.Parameters = result

            self.logger.info("sending modify subscription response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.DeleteSubscriptionsRequest_Encoding_DefaultBinary):
            self.logger.info("delete subscriptions request")
            params = struct_from_binary(application_protocol.DeleteSubscriptionsParameters, body)

            results = self._require_session().delete_subscriptions(params.SubscriptionIds)

            response = application_protocol.DeleteSubscriptionsResponse()
            response.Results = results

            self.logger.info("sending delete subscriptions response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.CreateMonitoredItemsRequest_Encoding_DefaultBinary):
            self.logger.info("create monitored items request")
            params = struct_from_binary(application_protocol.CreateMonitoredItemsParameters, body)
            results = self._require_session().create_monitored_items(params)

            response = application_protocol.CreateMonitoredItemsResponse()
            response.Results = results

            self.logger.info("sending create monitored items response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.ModifyMonitoredItemsRequest_Encoding_DefaultBinary):
            self.logger.info("modify monitored items request")
            params = struct_from_binary(application_protocol.ModifyMonitoredItemsParameters, body)
            results = self._require_session().modify_monitored_items(params)

            response = application_protocol.ModifyMonitoredItemsResponse()
            response.Results = results

            self.logger.info("sending modify monitored items response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.DeleteMonitoredItemsRequest_Encoding_DefaultBinary):
            self.logger.info("delete monitored items request")
            params = struct_from_binary(application_protocol.DeleteMonitoredItemsParameters, body)

            results = self._require_session().delete_monitored_items(params)

            response = application_protocol.DeleteMonitoredItemsResponse()
            response.Results = results

            self.logger.info("sending delete monitored items response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.HistoryReadRequest_Encoding_DefaultBinary):
            self.logger.info("history read request")
            params = struct_from_binary(application_protocol.HistoryReadParameters, body)

            results = self._require_session().history_read(params)

            response = application_protocol.HistoryReadResponse()
            response.Results = results

            self.logger.info("sending history read response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.RegisterNodesRequest_Encoding_DefaultBinary):
            self.logger.info("register nodes request")
            params = struct_from_binary(application_protocol.RegisterNodesParameters, body)
            self.logger.info("Node registration not implemented")

            # Registration is not implemented: hand the requested ids straight back.
            response = application_protocol.RegisterNodesResponse()
            response.Parameters.RegisteredNodeIds = params.NodesToRegister

            self.logger.info("sending register nodes response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.UnregisterNodesRequest_Encoding_DefaultBinary):
            self.logger.info("unregister nodes request")
            # Parsed only to validate the request; nothing to undo since
            # registration is not implemented.
            params = struct_from_binary(application_protocol.UnregisterNodesParameters, body)

            response = application_protocol.UnregisterNodesResponse()

            self.logger.info("sending unregister nodes response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.PublishRequest_Encoding_DefaultBinary):
            self.logger.info("publish request")

            if not self.session:
                return False

            params = struct_from_binary(application_protocol.PublishParameters, body)

            data = PublishRequestData()
            data.requesthdr = requesthdr
            data.seqhdr = seqhdr
            with self._datalock:
                self._publishdata_queue.append(data)  # will be used to send publish answers from server
                # A notification arrived before any request could carry it: send it now.
                if self._publish_result_queue:
                    result = self._publish_result_queue.pop(0)
                    self.forward_publish_response(result)
            self.session.publish(params.SubscriptionAcknowledgements)
            self.logger.info("publish forward to server")

        elif typeid == NodeId(ObjectIds.RepublishRequest_Encoding_DefaultBinary):
            self.logger.info("re-publish request")

            params = struct_from_binary(application_protocol.RepublishParameters, body)
            msg = self._require_session().republish(params)

            response = application_protocol.RepublishResponse()
            response.NotificationMessage = msg

            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary):
            self.logger.info("close secure channel request")
            self._connection.close()
            response = application_protocol.CloseSecureChannelResponse()
            self.send_response(requesthdr.RequestHandle, seqhdr, response)
            return False

        elif typeid == NodeId(ObjectIds.CallRequest_Encoding_DefaultBinary):
            self.logger.info("call request")

            params = struct_from_binary(application_protocol.CallParameters, body)

            results = self._require_session().call(params.MethodsToCall)

            response = application_protocol.CallResponse()
            response.Results = results

            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.SetMonitoringModeRequest_Encoding_DefaultBinary):
            self.logger.info("set monitoring mode request")

            params = struct_from_binary(application_protocol.SetMonitoringModeParameters, body)

            # FIXME: Implement SetMonitoringMode
            # Send dummy results to keep clients happy
            response = application_protocol.SetMonitoringModeResponse()
            results = application_protocol.SetMonitoringModeResult()
            results.Results = [StatusCode(StatusCodes.Good) for _ in params.MonitoredItemIds]
            response.Parameters = results

            self.logger.info("sending set monitoring mode response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        elif typeid == NodeId(ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary):
            self.logger.info("set publishing mode request")

            params = struct_from_binary(application_protocol.SetPublishingModeParameters, body)

            # FIXME: Implement SetPublishingMode
            # Send dummy results to keep clients happy
            response = application_protocol.SetPublishingModeResponse()
            results = application_protocol.SetPublishingModeResult()
            # Use the module-level StatusCode/StatusCodes, as every other branch does.
            results.Results = [StatusCode(StatusCodes.Good) for _ in params.SubscriptionIds]
            response.Parameters = results

            self.logger.info("sending set publishing mode response")
            self.send_response(requesthdr.RequestHandle, seqhdr, response)

        else:
            self.logger.warning("Unknown message received %s", typeid)
            raise Exception(StatusCodes.BadServiceUnsupported)

        return True

    def forward_publish_response(self, result):
        """Send ``result`` as the answer to the oldest still-valid PublishRequest.

        This method is passed to the session as the callback in
        create/modify_subscription. It is also called by the publish branch
        when a result was queued earlier. If no PublishRequest is pending,
        ``result`` is queued until the next one arrives. Queued requests whose
        ``TimeoutHint`` (treated as milliseconds; 0 means no timeout) has
        elapsed are discarded.

        NOTE(review): this method was referenced but not defined on the
        class. It assumes ``application_protocol.PublishResponse`` exists.
        """
        self.logger.info("forward publish response %s", result)
        with self._datalock:
            while True:
                if not self._publishdata_queue:
                    self._publish_result_queue.append(result)
                    self.logger.info(
                        "Server wants to send publish answer but no publish request is available, "
                        "enqueuing notification, length of result queue is %s",
                        len(self._publish_result_queue))
                    return
                requestdata = self._publishdata_queue.pop(0)
                timeout_hint = requestdata.requesthdr.TimeoutHint
                if timeout_hint == 0 or time.time() - requestdata.timestamp < timeout_hint / 1000.0:
                    break

        response = application_protocol.PublishResponse()
        response.Parameters = result
        self.send_response(requestdata.requesthdr.RequestHandle, requestdata.seqhdr, response)

    def close(self):
        """
        to be called when client has disconnected to ensure we really close
        everything we should
        """
        self.logger.info("Cleanup client connection: %s", self.name)
        if self.session:
            self.session.close_session(True)


def tcp_to_binary(message_type, message):
    """Serialize a connection-level message (header plus body) to bytes.

    Every field listed in ``message.types`` is encoded as a little-endian
    UInt32, in declaration order, whatever type it declares. That is enough
    for Acknowledge, the only message serialized in this module, but it
    would fail for messages with non-integer fields.

    Args:
        message_type: value placed in the header's message-type field.
        message: object with a ``types`` sequence of ``(name, type)`` pairs
            and one attribute per name.

    Returns:
        The encoded header followed by the encoded body.
    """
    # b"F" is presumably the "final chunk" marker of the chunk-type field.
    header = hand_protocol.Header(message_type, b"F")
    body = b"".join(
        Primitives1.UInt32.pack(getattr(message, field_name))
        for field_name, _declared_type in message.types
    )
    header.body_size = len(body)
    return header_to_binary(header) + body


class _Primitive1(object):
    def __init__(self, fmt):
        self._fmt = fmt
        st = struct.Struct(fmt.format(1))
        self.size = st.size
        self.format = st.format

    def pack(self, data):
        return struct.pack(self.format, data)

    def unpack(self, data):
        return struct.unpack(self.format, data.read(self.size))[0]


class Primitives1(object):
    """Namespace of fixed-size primitive codecs used by ``tcp_to_binary``."""

    # Little-endian unsigned 32-bit integer.
    UInt32 = _Primitive1("<{:d}I")

# def header_to_binary(hdr):
#     b = []
#     b.append(struct.pack("<3ss", hdr.MessageType, hdr.ChunkType))
#     size = hdr.body_size + 8
#     b.append(Primitives1.UInt32.pack(size))
#     return b"".join(b)


# def header_from_binary(data):
#     hdr = hand_protocol.Header()
#     hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", data.read(8))
#     hdr.body_size = hdr.packet_size - 8
#     if hdr.MessageType in (hand_protocol.MessageType.SecureOpen, hand_protocol.MessageType.SecureClose, hand_protocol.MessageType.SecureMessage):
#         hdr.body_size -= 4
#         hdr.ChannelId = Primitives1.UInt32.unpack(data)
#     return hdr

# class TCPHeader:
#
#     def __init__(self, source_port, dest_port, seq_num, ack_num, flags):
#         self.source_port = source_port
#         self.dest_port = dest_port
#         self.seq_num = seq_num
#         self.ack_num = ack_num
#         self.flags = flags
#
#     def to_binary(self):
#         format_string = "<3ss"  # Format string for struct.pack
#         binary_data = struct.pack(format_string, self.source_port, self.dest_port,
#                                   self.seq_num, self.ack_num, self.flags)
#         return binary_data


# Example TCP header values
# source_port = 12345
# dest_port = 80
# seq_num = 1000
# ack_num = 2000
# flags = 0b0101010101010101
#
# # Create TCP header instance
# tcp_header = TCPHeader(source_port, dest_port, seq_num, ack_num, flags)
#
# # Convert TCP header to binary
# tcp_binary = tcp_header.to_binary()
