import logging
import os

from KL import application_protocol
from common.node import Node, NodeId
from KL.constants import ObjectIds
from mainserver.user_manager import UserManager
from threading import Lock
from enum import Enum
from mainserver.address_space import AddressSpace, AttributeService, ViewService, MethodService, NodeManagementService
from mainserver.subscription_service import SubscriptionService
from mainserver.history import HistoryManager
from common.callback import CallbackDispatcher
from KL.application_protocol import *
from KL.attribute_ids import *
from mainserver import standard_address_space
from datetime import datetime
from common import utils


class SessionState(Enum):
    """Lifecycle states of a session: Created (0), Activated (1), Closed (2)."""

    # Members are numbered in declaration order, starting at 0.
    Created, Activated, Closed = range(3)


class InternalServer(object):
    """
    Server-side core that owns the address space and the services built on it.

    On construction it creates the address space together with its attribute,
    view, method and node-management services, loads the standard address
    space, creates the subscription service and the history manager, and opens
    an internal admin session (``isession``) used for server-side node access.
    """

    def __init__(self, shelffile=None, user_manager=None, session_cls=None):
        """
        :param shelffile: optional path of a shelf cache of the standard address
            space, passed to :meth:`load_standard_address_space`
        :param user_manager: stored as ``self.user_manager``
        :param session_cls: class used to create the internal session;
            ``InternalSession`` is used when nothing is given
        """
        self.logger = logging.getLogger(__name__)

        self.server_callback_dispatcher = CallbackDispatcher()

        self.endpoints = []
        self._channel_id_counter = 5
        self.disabled_clock = False  # for debugging we may want to disable clock that writes too much in log
        self._local_discovery_service = None  # lazy-loading

        self.aspace = AddressSpace()
        self.attribute_service = AttributeService(self.aspace)
        self.view_service = ViewService(self.aspace)
        self.method_service = MethodService(self.aspace)
        self.node_mgt_service = NodeManagementService(self.aspace)

        self.load_standard_address_space(shelffile)

        # the event loop is only created in start()
        self.loop = None
        self.asyncio_transports = []
        self.subscription_service = SubscriptionService(self.aspace)

        self.history_manager = HistoryManager(self)
        self.user_manager = user_manager

        # create a session to use on server side
        self.session_cls = session_cls or InternalSession
        self.isession = self.session_cls(self, self.aspace,
                                         self.subscription_service, "Internal", user=UserManager.User.Admin)

        # give the ServerStatus node an initial value so _set_current_time() can read it back
        initial_status = application_protocol.ServerStatusDataType()
        self.server_status_node = Node(self.isession, NodeId(ObjectIds.Server_ServerStatus))
        self.server_status_node.set_value(initial_status)
        self.current_time_node = Node(self.isession, NodeId(ObjectIds.Server_ServerStatus_CurrentTime))
        self._address_space_fixes()
        self.setup_nodes()

    def setup_nodes(self):
        """
        Set up some nodes as defined by spec
        """
        uris = ["http://opcfoundation.org/UA/"]
        ns_node = Node(self.isession, NodeId(ObjectIds.Server_NamespaceArray))
        ns_node.set_value(uris)

    def load_standard_address_space(self, shelffile=None):
        """
        Fill the address space, from a shelf cache when one exists.

        If ``shelffile`` (or ``shelffile + ".db"``) is an existing file, the
        address space is loaded from it. Otherwise it is filled from the
        generated ``standard_address_space`` module and, when ``shelffile`` was
        given, a shelf is written so it can be used at the next start up.

        :param shelffile: optional path of the shelf cache
        """
        if (shelffile is not None) and (os.path.isfile(shelffile) or os.path.isfile(shelffile + ".db")):
            # import address space from shelf
            self.aspace.load_aspace_shelf(shelffile)
        else:
            # import address space from code generated from xml
            standard_address_space.fill_address_space(self.node_mgt_service)
            # import address space directly from xml, this has performance impact so disabled
            # importer = xmlimporter.XmlImporter(self.node_mgt_service)
            # importer.import_xml("/path/to/python-opcua/schemas/Opc.Ua.NodeSet2.xml", self)

            # if a cache file was supplied a shelve of the standard address space can now be built for next start up
            if shelffile:
                self.aspace.make_aspace_shelf(shelffile)

    def _address_space_fixes(self):
        """
        Looks like the xml definition of address space has some error. This is a good place to fix them

        Adds an inverse Organizes reference from BaseObjectType to
        ObjectTypesFolder and from BaseDataType to DataTypesFolder, then writes
        the value 10000 (UInt32) to the OperationLimits nodes listed below.
        """
        folder_fixes = (
            (ObjectIds.BaseObjectType, ObjectIds.ObjectTypesFolder),
            (ObjectIds.BaseDataType, ObjectIds.DataTypesFolder),
        )
        items = []
        for source_id, folder_id in folder_fixes:
            item = AddReferencesItem()
            item.SourceNodeId = NodeId(source_id)
            item.ReferenceTypeId = NodeId(ObjectIds.Organizes)
            item.IsForward = False
            item.TargetNodeId = NodeId(folder_id)
            item.TargetNodeClass = NodeClass.Object
            items.append(item)

        # the returned results are not inspected
        self.isession.add_references(items)

        params = WriteParameters()
        for nodeid in (ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerRead,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerHistoryReadData,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerHistoryReadEvents,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerWrite,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerHistoryUpdateData,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerHistoryUpdateEvents,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerMethodCall,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerBrowse,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerRegisterNodes,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerTranslateBrowsePathsToNodeIds,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxNodesPerNodeManagement,
                       ObjectIds.Server_ServerCapabilities_OperationLimits_MaxMonitoredItemsPerCall):
            attr = WriteValue()
            attr.NodeId = NodeId(nodeid)
            attr.AttributeId = AttributeIds.Value
            attr.Value = DataValue(Variant(10000, VariantType.UInt32), StatusCode(StatusCodes.Good))
            attr.Value.ServerTimestamp = datetime.utcnow()
            params.NodesToWrite.append(attr)
        result = self.isession.write(params)
        # NOTE(review): only the first write result is checked; a failure on any
        # of the other nodes goes unnoticed here. check() presumably raises on a
        # bad status code - confirm before widening this to every result.
        result[0].check()

    def add_endpoint(self, endpoint):
        """Append ``endpoint`` to the list of endpoints held by this server."""
        self.endpoints.append(endpoint)

    def start(self):
        """
        Start the event loop thread and mark the server as running.

        Sets the ServerStatus State node to Running and the StartTime node to
        the current UTC time, and starts the once-per-second clock update
        unless ``disabled_clock`` is set.
        """
        self.logger.info("starting internal server")
        self.loop = utils.ThreadLoop()
        self.loop.start()
        # self.subscription_service.set_loop(self.loop)
        state_node = Node(self.isession, NodeId(ObjectIds.Server_ServerStatus_State))
        state_node.set_value(application_protocol.ServerState.Running, VariantType.Int32)
        Node(self.isession, NodeId(ObjectIds.Server_ServerStatus_StartTime)).set_value(datetime.utcnow())
        if not self.disabled_clock:
            self._set_current_time()

    def stop(self):
        """Close the internal session and, if a loop was started, stop and discard it."""
        self.logger.info("stopping internal server")
        self.isession.close_session()
        # self.subscription_service.set_loop(None)
        # self.history_manager.stop()
        if self.loop:
            self.loop.stop()
            # wait for ThreadLoop to finish before proceeding
            self.loop.join()
            self.loop.close()
            self.loop = None

    def _set_current_time(self):
        """
        Publish the current UTC time and re-schedule itself one second later.

        The same timestamp is written to the CurrentTime node and to the
        ``CurrentTime`` field of the ServerStatus value, so both always agree
        within one tick.
        """
        now = datetime.utcnow()
        self.current_time_node.set_value(now)
        ssdata = self.server_status_node.get_value()
        ssdata.CurrentTime = now
        self.server_status_node.set_value(ssdata)
        self.loop.call_later(1, self._set_current_time)


class InternalSession(object):
    """
    Session through which requests are forwarded to the internal server's services.

    Session ids and authentication tokens are taken from the class-level
    counters ``_counter`` and ``_auth_counter``; both counters are advanced
    each time a session is created.
    """
    _counter = 10
    _auth_counter = 1000

    def __init__(self, internal_server, aspace, submgr, name, user=UserManager.User.Anonymous):
        """
        :param internal_server: server whose services this session calls
        :param aspace: address space, stored as ``self.aspace``
        :param submgr: subscription service, stored as ``self.subscription_service``
        :param name: session name, used in the log message and in ``__str__``
        :param user: user passed to the write and node-management services
        """
        self.logger = logging.getLogger(__name__)
        self._lock = Lock()

        self.iserver = internal_server
        self.aspace = aspace
        self.subscription_service = submgr

        self.name = name
        self.user = user
        self.nonce = None
        self.state = SessionState.Created
        self.subscriptions = []

        # take the current counter values for this session, then advance the shared counters
        self.session_id = NodeId(self._counter)
        self.authentication_token = NodeId(self._auth_counter)
        InternalSession._counter += 1
        InternalSession._auth_counter += 1

        self.logger.info("Created internal session %s", self.name)

    def __str__(self):
        """Return a one-line description with name, user, session id and auth token."""
        template = "InternalSession(name:{0}, user:{1}, id:{2}, auth_token:{3})"
        return template.format(self.name, self.user, self.session_id, self.authentication_token)

    def add_references(self, params):
        """Forward ``params`` and this session's user to the node management service."""
        node_mgt = self.iserver.node_mgt_service
        return node_mgt.add_references(params, self.user)

    def write(self, params):
        """Forward ``params`` and this session's user to the attribute service."""
        attributes = self.iserver.attribute_service
        return attributes.write(params, self.user)

    def read(self, params):
        """Forward ``params`` to the attribute service; the user is not passed along."""
        return self.iserver.attribute_service.read(params)

    def close_session(self):
        """Mark the session as closed."""
        # only the state changes here; subscriptions are left untouched
        self.state = SessionState.Closed

    def browse(self, params):
        """Forward ``params`` to the view service."""
        views = self.iserver.view_service
        return views.browse(params)

    def add_nodes(self, params):
        """Forward ``params`` and this session's user to the node management service."""
        node_mgt = self.iserver.node_mgt_service
        return node_mgt.add_nodes(params, self.user)
