from scapy.all import *  # noqa: F401,F403 - kept first so the explicit stdlib imports below are never shadowed

import json
import os
import re
import socket
import struct
import sys

# Scan results keyed by Modbus unit id; populated by action().
output = {}

# Modbus exception code -> human-readable name, used to label exception replies.
# Codes 0x07 and 0x09 have no entry in this table.
modbus_exception_codes = {
	0x01: 'ILLEGAL FUNCTION',
	0x02: 'ILLEGAL DATA ADDRESS',
	0x03: 'ILLEGAL DATA VALUE',
	0x04: 'SLAVE DEVICE FAILURE',
	0x05: 'ACKNOWLEDGE',
	0x06: 'SLAVE DEVICE BUSY',
	0x08: 'MEMORY PARITY ERROR',
	0x0A: 'GATEWAY PATH UNAVAILABLE',
	0x0B: 'GATEWAY TARGET DEVICE FAILED TO RESPOND',
}

def form_rsid(sid, functionId, data):
	"""Build a Modbus TCP request frame for unit `sid`.

	The frame is five zero bytes (transaction id, protocol id and the high
	byte of the length field), then the low length byte counting unit id +
	function code + `data`, then `sid`, `functionId` and `data` itself.

	Args:
		sid: unit id, one byte.
		functionId: Modbus function code, one byte.
		data: request payload bytes following the function code (may be b"").

	Returns:
		The complete frame as bytes.
	"""
	length_low = 2 + len(data)  # unit id byte + function code byte + payload
	header = bytes(5)
	return header + struct.pack('BBB', length_low, sid, functionId) + data

def discover_device_id_recursive(host, port, sid, start_id, objects_table):
	"""Collect the device identification objects of Modbus unit `sid`.

	Sends a Read Device Identification request (function 0x2B, MEI type 0x0E,
	read code 0x01) for objects starting at `start_id` and stores every object
	of the reply in `objects_table`. While the reply says more objects follow
	(byte 0xFF), it requests the next batch from the advertised next object id.

	Args:
		host: target address passed to comm().
		port: target TCP port passed to comm().
		sid: Modbus unit id.
		start_id: first object id to request (0 for a full walk).
		objects_table: dict filled in place with object id -> raw value bytes,
			accumulating objects across follow-up requests. None is treated as
			a new empty dict.

	Returns:
		bytes: all collected object values in ascending object-id order,
			joined by single spaces.
		None: no object could be read (no reply, short reply, exception reply).
	"""
	if objects_table is None:
		objects_table = {}
	rsid = form_rsid(sid, 0x2B, b"\x0E\x01" + struct.pack('B', start_id))
	result = comm(host, port, rsid)
	# comm() yields bytes on success and False / 'terr' on failure.
	# Reply layout after the 7-byte MBAP header: function code (7), MEI type (8),
	# read code (9), conformity level (10), more follows (11), next object id (12),
	# number of objects (13), then (object id, length, value...) from byte 14.
	if isinstance(result, bytes) and len(result) >= 14 and result[7] == 0x2B:
		more_follows = result[11]
		next_object_id = result[12]
		number_of_objects = result[13]
		offset = 14
		for _ in range(number_of_objects):
			if offset + 2 > len(result):
				break  # truncated reply: keep what was parsed so far
			object_id = result[offset]
			object_len = result[offset + 1]
			objects_table[object_id] = result[offset + 2:offset + 2 + object_len]
			offset += 2 + object_len
		# Only follow strictly increasing object ids: a misbehaving device can
		# then cause at most 255 follow-up requests instead of endless recursion.
		if more_follows == 0xFF and next_object_id > start_id:
			return discover_device_id_recursive(host, port, sid, next_object_id, objects_table)
	if not objects_table:
		return None
	return b" ".join(objects_table[object_id] for object_id in sorted(objects_table))


def discover_device_id(host, port, sid):
	"""Read the device identification of unit `sid` on host:port.

	Starts the object walk at object id 0x00 with a fresh, empty objects table
	and returns whatever discover_device_id_recursive() produces.
	"""
	fresh_table = {}
	return discover_device_id_recursive(host, port, sid, start_id=0x00, objects_table=fresh_table)

def extract_slave_id(response):
	"""Decode the slave ID data from a Report Slave ID (0x11) response.

	The Modbus TCP reply is a 7-byte MBAP header, the function code at byte 7,
	a byte count at byte 8, then exactly `byte count` data bytes from byte 9.
	TCP frames carry no trailing checksum, so the data must not be trimmed.

	Args:
		response: raw reply bytes as returned by comm().

	Returns:
		The data decoded as UTF-8 with undecodable bytes (e.g. a 0xFF status
		byte) dropped, or None if the response is too short, reports a zero
		byte count, or is not a bytes-like object.
	"""
	try:
		byte_count = response[8]
		if not byte_count:
			return None
		# Slice exactly byte_count bytes. The old `response[9:-1]` dropped the
		# last data byte, so the fixed-size unpack always failed.
		return response[9:9 + byte_count].decode(errors='ignore')
	except (IndexError, TypeError, AttributeError):
		return None


def comm(host, port, rsid):
	"""Send one request frame to host:port over TCP and return the reply.

	Opens a new connection per call with a 3-second timeout (applied to
	connect, send and receive), sends all of `rsid`, and performs a single
	recv() of at most 1024 bytes. The socket is always closed afterwards.

	Args:
		host: target address.
		port: target TCP port.
		rsid: request bytes, e.g. built by form_rsid().

	Returns:
		bytes: what a single recv() returned (b"" if the peer closed the
			connection; a reply split across TCP segments may be partial).
		False: a ConnectionError (refused, reset, aborted) occurred.
		'terr': any other error, e.g. a timeout or an unreachable host.
	"""
	BUFFER_SIZE = 1024
	try:
		# `with` closes the socket on every path; previously one descriptor
		# leaked per probe.
		with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
			s.settimeout(3)
			s.connect((host, port))
			s.sendall(rsid)
			return s.recv(BUFFER_SIZE)
	except ConnectionError:
		return False
	except Exception:
		# Deliberately broad so callers can count failures. It is no longer a
		# bare except: Ctrl-C and SystemExit now propagate instead of becoming 'terr'.
		return 'terr'

def action(host,port,aggressive):
	"""Scan the Modbus unit ids of host:port and collect what each unit reports.

	For every unit id 1..247, a Report Slave ID request (function 0x11) is sent
	through comm(). Every unit whose reply is longer than 8 bytes gets an entry
	in the module-level `output` dict. `output` is cleared at the start of each
	call, so units from an earlier scan never leak into this one.
	For a 0x11 reply the slave ID data is decoded. For the 0x91 exception reply
	the exception name is stored under "Error". In both cases the unit's device
	identification is then requested via discover_device_id().

	Args:
		host: target address passed to comm().
		port: target TCP port passed to comm().
		aggressive: if False, return as soon as one unit has answered with
			0x11/0x91; if True, scan every unit id.

	Returns:
		dict: entry of the first unit answering 0x11/0x91, with an extra "sid"
			key (non-aggressive mode only).
		str: JSON dump of `output` once every unit id has been tried.
		False: comm() reported a connection error, or returned 'terr' after
			more than two failed probes had already been counted.
		None: an unexpected exception occurred; it is printed, not raised.
	"""
	output.clear()
	# Failed probes ('terr' or short/empty replies) since the last successful
	# device identification.
	count = 0
	try:
		for sid in range(1, 248):  # unit ids 1..247 inclusive (range end is exclusive)
			rsid = form_rsid(sid, 0x11, b"")
			result = comm(host, port, rsid)
			if result is not False and len(result) > 8:
				output[sid] = {'Slave ID data': 'Unknown', 'Device identification': 'Unknown'}
				function_code = result[7]
				if function_code not in (17, 145):
					# Unexpected function code: keep the 'Unknown' entry and move on.
					continue
				if function_code == 17:
					slave_id = extract_slave_id(result)
					output[sid]["Slave ID data"] = slave_id if slave_id else "Unknown"
				else:
					# 145 == 0x80 | 0x11: exception reply, byte 8 is the exception code.
					exception_code = result[8]
					# .get(): codes missing from the table (e.g. 7, 9) used to raise
					# KeyError here and abort the whole scan.
					output[sid]["Error"] = modbus_exception_codes.get(
						exception_code, "Unknown exception, Code=" + str(exception_code))
				device_table = discover_device_id(host, port, sid)
				if device_table:
					# Replace control bytes with spaces, then collapse runs of spaces.
					text = re.sub(r'[\x00-\x1f]', ' ', device_table.decode(errors='ignore'))
					text = re.sub(r' {2,}', ' ', text).strip()
					if text:
						output[sid]["Device identification"] = text
					count = 0
				if not aggressive:
					output[sid]["sid"] = "sid" + str(sid)
					return output[sid]
			elif result == 'terr' and count > 2:
				return False
			elif result is False:
				return False
			else:
				count += 1
		return json.dumps(output)
	except Exception as e:
		exc_type, exc_obj, exc_tb = sys.exc_info()
		fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
		print(exc_type, fname, exc_tb.tb_lineno, e)
		return None


def get_info(ip, port, aggressive=False):
	"""Scan ip:port for Modbus units.

	Delegates to action() with the same arguments and returns its result
	unchanged: a per-unit dict, a JSON string, False, or None, depending on
	`aggressive` and the outcome of the scan.
	"""
	scan_result = action(ip, port, aggressive)
	return scan_result