import logging
from KL.binary import struct_from_binary, header_to_binary, struct_to_binary
from common import utils
from datetime import datetime
import copy

from KL import application_protocol, hand_protocol, types

# Module-level logger; used below e.g. to report received Error messages and
# outgoing sequence-number wrap-around.
logger = logging.getLogger('protocol')


class SecureConnection(object):
    """
    Per-channel state of an OPC UA secure channel.

    Holds the security policy, the current/next/previous security tokens,
    the local and remote nonces, the outgoing sequence number counter and
    the maximum chunk size used when serializing outgoing messages.
    """

    def __init__(self, security_policy):
        # Last sequence number stamped on an outgoing chunk (0 = none sent yet).
        self._sequence_number = 0
        self._peer_sequence_number = None
        self._incoming_parts = []
        self.security_policy = security_policy
        self._policies = []
        self.security_token = application_protocol.ChannelSecurityToken()
        self.next_security_token = application_protocol.ChannelSecurityToken()
        self.prev_security_token = application_protocol.ChannelSecurityToken()
        self.local_nonce = 0
        self.remote_nonce = 0
        self._open = False
        self._allow_prev_token = False
        # Chunk size limit for outgoing messages; updated from the peer's
        # Hello / Acknowledge.
        self._max_chunk_size = 65536

    def receive_from_header_and_body(self, header, body):
        """
        Decode one transport message from its parsed ``header`` and ``body`` buffer.

        * SecureOpen / SecureMessage / SecureClose: the security header is
          inspected first (policy selection for SecureOpen while the channel
          is not yet open, ``_check_sym_header`` for the symmetric types),
          then the body is parsed into a MessageChunk and the result of
          ``self._receive(chunk)`` is returned.
        * Hello / Acknowledge: the decoded struct is returned and the outgoing
          chunk size is set to the peer's ReceiveBufferSize.
        * Error: a warning is logged and the decoded ErrorMessage is returned.

        Raises Exception for any other message type.
        """
        if header.MessageType == hand_protocol.MessageType.SecureOpen:
            # Peek at the security header; from_header_and_body below reads the
            # body again, so copy() presumably does not consume ``body``.
            data = body.copy(header.body_size)
            security_header = struct_from_binary(hand_protocol.AsymmetricAlgorithmHeader, data)

            if not self.is_open():
                # Only call select_policy if the channel isn't open. Otherwise
                # it will break the Secure channel renewal.
                self.select_policy(security_header.SecurityPolicyURI, security_header.SenderCertificate)

        elif header.MessageType in (hand_protocol.MessageType.SecureMessage, hand_protocol.MessageType.SecureClose):
            data = body.copy(header.body_size)
            security_header = struct_from_binary(hand_protocol.SymmetricAlgorithmHeader, data)
            self._check_sym_header(security_header)

        if header.MessageType in (hand_protocol.MessageType.SecureMessage, hand_protocol.MessageType.SecureOpen,
                                  hand_protocol.MessageType.SecureClose):
            chunk = MessageChunk.from_header_and_body(self.security_policy, header, body)
            return self._receive(chunk)
        elif header.MessageType == hand_protocol.MessageType.Hello:
            msg = struct_from_binary(hand_protocol.Hello, body)
            # The client's ReceiveBufferSize is the largest chunk it accepts,
            # so it bounds the chunks we send back.
            self._max_chunk_size = msg.ReceiveBufferSize
            return msg
        elif header.MessageType == hand_protocol.MessageType.Acknowledge:
            msg = struct_from_binary(hand_protocol.Acknowledge, body)
            # Per OPC UA Part 6, the Acknowledge's ReceiveBufferSize is the
            # largest chunk the server can receive, which is what limits our
            # outgoing chunks. SendBufferSize only describes what the server
            # will send to us.
            self._max_chunk_size = msg.ReceiveBufferSize
            return msg
        elif header.MessageType == hand_protocol.MessageType.Error:
            msg = struct_from_binary(hand_protocol.ErrorMessage, body)
            logger.warning("Received an error: %s", msg)
            return msg
        else:
            raise Exception("Unsupported message type {0}".format(header.MessageType))

    def open(self, params, server):
        """
        Called on the server side to open (or renew) the secure channel.

        A fresh local nonce is generated and the client's nonce is stored.
        If the channel is not open yet, or ``params.RequestType`` is Issue,
        ``security_token`` is (re)initialized with a new channel id from
        ``server`` and both symmetric keys are derived from the nonces.
        Otherwise (renewal of an open channel) ``next_security_token`` is
        prepared as a copy of the current token with TokenId incremented;
        no keys are derived in this branch.

        :param params: open-channel request parameters (RequestType,
            ClientNonce, RequestedLifetime are read).
        :param server: object providing ``get_new_channel_id()``.
        :returns: OpenSecureChannelResult with ServerNonce and SecurityToken set.
        """

        self.local_nonce = utils.create_nonce(self.security_policy.symmetric_key_size)
        self.remote_nonce = params.ClientNonce
        response = application_protocol.OpenSecureChannelResult()
        response.ServerNonce = self.local_nonce

        if not self._open or params.RequestType == application_protocol.SecurityTokenRequestType.Issue:
            self._open = True
            self.security_token.TokenId = 13  # arbitrary initial token id
            self.security_token.ChannelId = server.get_new_channel_id()
            self.security_token.RevisedLifetime = params.RequestedLifetime
            self.security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.security_token

            self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
            self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
        else:
            self.next_security_token = copy.deepcopy(self.security_token)
            self.next_security_token.TokenId += 1
            self.next_security_token.RevisedLifetime = params.RequestedLifetime
            self.next_security_token.CreatedAt = datetime.utcnow()

            response.SecurityToken = self.next_security_token

        return response

    def close(self):
        """Mark the channel as closed."""
        self._open = False

    def is_open(self):
        """Return True once ``open`` has been called and ``close`` has not."""
        return self._open

    def set_policy_factories(self, policies):
        """Store the security policy factories available to this connection."""
        self._policies = policies

    def message_to_binary(self, message, message_type=hand_protocol.MessageType.SecureMessage, request_id=0):
        """
        Serialize ``message`` into secure-channel chunks and return their concatenated bytes.

        Supported message types are SecureOpen, SecureMessage and SecureClose.
        ``self._max_chunk_size`` is passed to MessageChunk.message_to_chunks as
        the chunk size limit, together with the current token's ChannelId and
        TokenId and ``request_id``. Every chunk is stamped with the next
        outgoing sequence number; after 2**32 - 1 the counter wraps to 1.
        """
        # Materialize so the chunks can safely be iterated twice below.
        chunks = list(MessageChunk.message_to_chunks(self.security_policy, message, self._max_chunk_size,
                                                     message_type=message_type,
                                                     channel_id=self.security_token.ChannelId,
                                                     request_id=request_id, token_id=self.security_token.TokenId))
        for chunk in chunks:
            self._sequence_number += 1
            if self._sequence_number >= (1 << 32):
                logger.debug("Wrapping sequence number: %d -> 1", self._sequence_number)
                self._sequence_number = 1
            chunk.SequenceHeader.SequenceNumber = self._sequence_number
        return b"".join([chunk.to_binary() for chunk in chunks])


class MessageChunk(types.FrozenClass):
    def __init__(self, security_policy, body=b'', msg_type=hand_protocol.MessageType.SecureMessage,
                 chunk_type=hand_protocol.ChunkType.Single):
        """
        Create an empty chunk of type ``msg_type``.

        The security header class follows the message type: a symmetric
        header for SecureMessage / SecureClose, an asymmetric header for
        SecureOpen. Any other type raises Exception. ``security_policy`` is
        stored as given (from_header_and_body passes a cryptography object).
        """
        self.MessageHeader = hand_protocol.Header(msg_type, chunk_type)
        symmetric_kinds = (hand_protocol.MessageType.SecureMessage, hand_protocol.MessageType.SecureClose)
        if msg_type in symmetric_kinds:
            security_header_cls = hand_protocol.SymmetricAlgorithmHeader
        elif msg_type == hand_protocol.MessageType.SecureOpen:
            security_header_cls = hand_protocol.AsymmetricAlgorithmHeader
        else:
            raise Exception("Unsupported message type: {0}".format(msg_type))
        self.SecurityHeader = security_header_cls()
        self.SequenceHeader = hand_protocol.SequenceHeader()
        self.Body = body
        self.security_policy = security_policy

    @staticmethod
    def from_header_and_body(security_policy, header, buf):
        """
        Build a MessageChunk from a parsed transport ``header`` and the buffer ``buf`` holding its body.

        ``header.body_size`` bytes are copied out of ``buf`` and ``buf`` is
        advanced past them. The security header is parsed (symmetric for
        SecureMessage / SecureClose, asymmetric for SecureOpen). The rest is
        decrypted with the matching cryptography object of
        ``security_policy``. A trailing signature, if the crypto reports one,
        is verified over header + security header + plaintext. Padding is then
        removed and the sequence header and body are parsed from what remains.

        Raises Exception for any other message type.
        """
        assert len(buf) >= header.body_size, 'Full body expected here'
        payload = buf.copy(header.body_size)
        buf.skip(header.body_size)

        kind = header.MessageType
        if kind in (hand_protocol.MessageType.SecureMessage, hand_protocol.MessageType.SecureClose):
            sec_header = struct_from_binary(hand_protocol.SymmetricAlgorithmHeader, payload)
            cipher = security_policy.symmetric_cryptography
        elif kind == hand_protocol.MessageType.SecureOpen:
            sec_header = struct_from_binary(hand_protocol.AsymmetricAlgorithmHeader, payload)
            cipher = security_policy.asymmetric_cryptography
        else:
            raise Exception("Unsupported message type: {0}".format(kind))

        result = MessageChunk(cipher)
        result.MessageHeader = header
        result.SecurityHeader = sec_header

        plaintext = cipher.decrypt(payload.read(len(payload)))
        sig_len = cipher.vsignature_size()
        if sig_len > 0:
            # The signature trails the plaintext; it covers both headers too.
            plaintext, signature = plaintext[:-sig_len], plaintext[-sig_len:]
            signed_bytes = (header_to_binary(result.MessageHeader)
                            + struct_to_binary(result.SecurityHeader)
                            + plaintext)
            cipher.verify(signed_bytes, signature)

        remainder = utils.Buffer(cipher.remove_padding(plaintext))
        result.SequenceHeader = struct_from_binary(hand_protocol.SequenceHeader, remainder)
        result.Body = remainder.read(len(remainder))
        return result
