"""
*******************************************************************
  Copyright (c) 2013, 2018 IBM Corp.

  All rights reserved. This program and the accompanying materials
  are made available under the terms of the Eclipse Public License v1.0
  and Eclipse Distribution License v1.0 which accompany this distribution.

  The Eclipse Public License is available at
     http://www.eclipse.org/legal/epl-v10.html
  and the Eclipse Distribution License is available at
    http://www.eclipse.org/org/documents/edl-v10.php.

  Contributors:
     Ian Craggs - initial implementation and/or documentation
*******************************************************************
"""

"""

Assertions are used to validate incoming data, but are omitted from outgoing packets.  This is
so that the tests that use this package can send invalid data for error testing.

"""

import logging

# Module-level logger; the "[MQTT-x.y.z-n]" conformance statements emitted while
# unpacking packets are written through it.
logger = logging.getLogger("mqttsas")

# Low-level protocol interface

class MQTTException(Exception):
  "Raised by this module when received MQTT data cannot be read as valid fields"


# Message types: control packet type numbers, from 1 (CONNECT) to 14 (DISCONNECT)
(CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL,
 PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK,
 PINGREQ, PINGRESP, DISCONNECT) = range(1, 15)

# Printable packet names, indexed by message type (index 0 is reserved)
packetNames = [
  "reserved",
  "Connect", "Connack", "Publish", "Puback", "Pubrec", "Pubrel",
  "Pubcomp", "Subscribe", "Suback", "Unsubscribe", "Unsuback",
  "Pingreq", "Pingresp", "Disconnect",
]

# Names of the packet classes, indexed by message type (index 0 is reserved)
classNames = [
  "reserved",
  "Connects", "Connacks", "Publishes", "Pubacks", "Pubrecs", "Pubrels",
  "Pubcomps", "Subscribes", "Subacks", "Unsubscribes", "Unsubacks",
  "Pingreqs", "Pingresps", "Disconnects",
]


def MessageType(byte):
  """Return the packet type held in the top four bits of a packet's first byte.

  byte: the packet buffer (an indexable sequence of ints, such as bytes), or None.

  Returns the type number (0 to 15), or None when there is no buffer or the
  buffer is empty, so that "nothing received" is reported the same way in
  both cases instead of raising IndexError.
  """
  if byte is None or len(byte) == 0:
    return None
  return byte[0] >> 4


def getPacket(aSocket):
  """Receive the next complete MQTT packet from aSocket.

  aSocket: any object with a recv(n) method that returns up to n bytes and
  returns b"" once the peer has closed the connection.

  Returns the whole packet (fixed header plus remaining bytes) as bytes, or
  None when the connection is closed - either before a packet starts or part
  way through one.  Exceptions raised by recv() are propagated.
  """
  buf = aSocket.recv(1) # get the first byte fixed header
  if buf == b"":
    return None
  # a closed socket object shows "[closed]" in its string form
  if "[closed]" in str(aSocket):
    return None
  # now get the remaining length: 7 bits per byte, least significant group
  # first, with the top bit set on every byte except the last
  multiplier = 1
  remlength = 0
  while True:
    lengthByte = aSocket.recv(1)
    if len(lengthByte) == 0:
      # connection closed inside the fixed header; retrying would never end
      return None
    buf += lengthByte
    digit = lengthByte[0]
    remlength += (digit & 127) * multiplier
    if digit & 128 == 0:
      break
    multiplier *= 128
  # receive the remaining length if there is any
  chunks = []
  received = 0
  while received < remlength:
    chunk = aSocket.recv(remlength - received)
    if len(chunk) == 0:
      # connection closed before the packet was complete
      return None
    chunks.append(chunk)
    received += len(chunk)
  return buf + b"".join(chunks)


class FixedHeaders:
  """The fixed header at the start of every packet: type, flags, remaining length."""

  def __init__(self, aMessageType):
    self.MessageType = aMessageType
    self.DUP = False
    self.QoS = 0
    self.RETAIN = False
    self.remainingLength = 0

  def __eq__(self, fh):
    "compare type and flags; the remaining length is deliberately not compared"
    if not isinstance(fh, FixedHeaders):
      # let Python fall back to its default, so comparing with None is just False
      return NotImplemented
    return self.MessageType == fh.MessageType and \
           self.DUP == fh.DUP and \
           self.QoS == fh.QoS and \
           self.RETAIN == fh.RETAIN

  def __str__(self):
    "return printable representation of our data (the caller adds the closing bracket)"
    return classNames[self.MessageType]+'(DUP='+str(self.DUP)+ \
           ", QoS="+str(self.QoS)+", Retain="+str(self.RETAIN)

  def pack(self, length):
    """Pack the header into bytes ready for transmission down a socket.

    length: number of bytes that follow the fixed header; it is also stored
    in self.remainingLength.
    """
    # first byte: type in bits 7-4, DUP in bit 3, QoS in bits 2-1, RETAIN in bit 0
    buffer = bytes([(self.MessageType << 4) | (self.DUP << 3) |
                    (self.QoS << 1) | self.RETAIN])
    self.remainingLength = length
    buffer += self.encode(length)
    return buffer

  def encode(self, x):
    "encode x (0 to 268435455) as a variable length 'remaining length' field"
    assert 0 <= x <= 268435455
    buffer = b''
    while True:
      digit = x % 128
      x //= 128
      if x > 0:
        digit |= 0x80 # continuation bit: more length bytes follow
      buffer += bytes([digit])
      if x == 0:
        break
    return buffer

  def unpack(self, buffer):
    """Unpack the fixed header from the start of buffer into separate fields.

    Returns the length of the fixed header in bytes.
    """
    b0 = buffer[0]
    self.MessageType = b0 >> 4
    self.DUP = ((b0 >> 3) & 0x01) == 1
    self.QoS = (b0 >> 1) & 0x03
    self.RETAIN = (b0 & 0x01) == 1
    (self.remainingLength, lengthBytes) = self.decode(buffer[1:])
    return lengthBytes + 1 # length of fixed header

  def decode(self, buffer):
    """Decode a variable length field from the start of buffer.

    Returns (value, number of bytes used).  Raises IndexError if buffer ends
    before a byte without the continuation bit is found.
    """
    multiplier = 1
    value = 0
    used = 0
    while True:
      digit = buffer[used] # index rather than re-slice, so the buffer is not copied
      used += 1
      value += (digit & 127) * multiplier
      if digit & 128 == 0:
        break
      multiplier *= 128
    return (value, used)


def writeInt16(length):
  "encode an integer as two bytes, most significant byte first"
  high, low = divmod(length, 256)
  return bytes((high, low))

def readInt16(buf):
  "decode the first two bytes of buf as an integer, most significant byte first"
  high = buf[0]
  low = buf[1]
  return 256 * high + low

def writeUTF(data):
  """Return data as a length-prefixed field: two length bytes, then the content.

  data: a string, which is encoded with utf-8, or bytes, which are used as is.

  The length prefix counts the encoded bytes, not the characters, so strings
  holding non-ASCII characters get the correct prefix.
  """
  encoded = data if isinstance(data, bytes) else bytes(data, "utf-8")
  return writeInt16(len(encoded)) + encoded

def readUTF(buffer, maxlen):
  """Read a length-prefixed UTF-8 string from the start of buffer.

  buffer: bytes starting with a two byte length followed by the string.
  maxlen: number of bytes of buffer that may be used for this field.

  Returns the decoded string.  Raises MQTTException if the field does not fit
  in maxlen bytes, or the string holds a null character or a code point in
  the range U+D800 to U+DFFF; raises UnicodeDecodeError if the data is not
  valid utf-8.
  """
  if maxlen < 2:
    raise MQTTException("Not enough data to read string length")
  length = readInt16(buffer)
  maxlen -= 2
  if length > maxlen:
    raise MQTTException("Length delimited string too long")
  buf = buffer[2:2+length].decode("utf-8")
  logger.info("[MQTT-4.7.3-2] topic names and filters not include null")
  if "\x00" in buf: # look for null in the UTF string
    raise MQTTException("[MQTT-1.5.3-2] Null found in UTF data "+buf)
  # look for D800-DFFF in the UTF string; both ends of the range are included
  if any(0xD800 <= ord(c) <= 0xDFFF for c in buf):
    raise MQTTException("[MQTT-1.5.3-1] D800-DFFF found in UTF data "+buf)
  if "\uFEFF" in buf:
    logger.info("[MQTT-1.5.3-3] U+FEFF in UTF string")
  return buf

def writeBytes(buffer):
  "return binary data preceded by its two byte length"
  prefix = writeInt16(len(buffer))
  return prefix + buffer

def readBytes(buffer):
  "return the binary data that follows the two byte length at the start of buffer"
  end = 2 + readInt16(buffer)
  return buffer[2:end]


class Packets:
  "Behaviour shared by the packet classes; each subclass sets the fixed header self.fh"

  def pack(self):
    "serialize a packet that consists of the fixed header only"
    return self.fh.pack(0)

  def __str__(self):
    "printable form of the fixed header"
    return str(self.fh)

  def __eq__(self, packet):
    "packets are equal when their fixed headers are; a false value never matches"
    if not packet:
      return False
    return self.fh == packet.fh


class Connects(Packets):
  """CONNECT packet: protocol name and version, connect flags, keep alive and payload.

  When a buffer is given to the constructor it is unpacked and validated
  straight away; otherwise the attributes keep the defaults set in __init__.
  """

  def __init__(self, buffer = None):
    self.fh = FixedHeaders(CONNECT)

    # variable header
    self.ProtocolName = "MQTT"
    self.ProtocolVersion = 4
    self.CleanSession = True
    self.WillFlag = False
    self.WillQoS = 0
    self.WillRETAIN = 0
    self.KeepAliveTimer = 30
    self.usernameFlag = False
    self.passwordFlag = False

    # Payload
    self.ClientIdentifier = ""   # UTF-8
    self.WillTopic = None        # UTF-8
    self.WillMessage = None      # binary
    self.username = None         # UTF-8
    self.password = None         # binary

    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes; nothing is validated, so invalid packets can be built"
    # connect flags: bit 1 clean session, bit 2 will flag, bits 3-4 will QoS,
    # bit 5 will retain, bit 6 password flag, bit 7 user name flag - the same
    # layout that unpack() reads
    connectFlags = bytes([(self.CleanSession << 1) | (self.WillFlag << 2) |
                          (self.WillQoS << 3) | (self.WillRETAIN << 5) |
                          (self.passwordFlag << 6) | (self.usernameFlag << 7)])
    buffer = writeUTF(self.ProtocolName) + bytes([self.ProtocolVersion]) + \
              connectFlags + writeInt16(self.KeepAliveTimer)
    buffer += writeUTF(self.ClientIdentifier)
    if self.WillFlag:
      buffer += writeUTF(self.WillTopic)
      buffer += writeBytes(self.WillMessage)
    if self.usernameFlag:
      buffer += writeUTF(self.username)
    if self.passwordFlag:
      buffer += writeBytes(self.password)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    """Unpack and validate a CONNECT packet held in buffer.

    Raises AssertionError, MQTTException or a decoding error when the packet
    does not conform; the failure is logged before it is re-raised.
    """
    assert len(buffer) >= 2
    assert MessageType(buffer) == CONNECT

    try:
      fhlen = self.fh.unpack(buffer)
      packlen = fhlen + self.fh.remainingLength
      assert len(buffer) >= packlen, "buffer length %d packet length %d" % (len(buffer), packlen)
      curlen = fhlen # points to after header + remaining length
      assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
      assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS was not 0, was %d" % self.fh.QoS
      assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"

      self.ProtocolName = readUTF(buffer[curlen:], packlen - curlen)
      # advance by the encoded byte length taken from the length prefix; the
      # character count of the decoded string is smaller for non-ASCII text
      curlen += readInt16(buffer[curlen:]) + 2
      assert self.ProtocolName == "MQTT", "Wrong protocol name %s" % self.ProtocolName

      self.ProtocolVersion = buffer[curlen]
      curlen += 1

      connectFlags = buffer[curlen]
      assert (connectFlags & 0x01) == 0, "[MQTT-3.1.2-3] reserved connect flag must be 0"
      self.CleanSession = ((connectFlags >> 1) & 0x01) == 1
      self.WillFlag = ((connectFlags >> 2) & 0x01) == 1
      self.WillQoS = (connectFlags >> 3) & 0x03
      self.WillRETAIN = (connectFlags >> 5) & 0x01
      self.passwordFlag = ((connectFlags >> 6) & 0x01) == 1
      self.usernameFlag = ((connectFlags >> 7) & 0x01) == 1
      curlen +=1

      if self.WillFlag:
        assert self.WillQoS in [0, 1, 2], "[MQTT-3.1.2-14] will qos must not be 3"
      else:
        assert self.WillQoS == 0, "[MQTT-3.1.2-13] will qos must be 0, if will flag is false"
        assert self.WillRETAIN == False, "[MQTT-3.1.2-14] will retain must be false, if will flag is false"

      self.KeepAliveTimer = readInt16(buffer[curlen:])
      curlen += 2
      logger.info("[MQTT-3.1.3-3] Clientid must be present, and first field")
      logger.info("[MQTT-3.1.3-4] Clientid must be Unicode, and between 0 and 65535 bytes long")
      self.ClientIdentifier = readUTF(buffer[curlen:], packlen - curlen)
      curlen += readInt16(buffer[curlen:]) + 2

      if self.WillFlag:
        self.WillTopic = readUTF(buffer[curlen:], packlen - curlen)
        curlen += readInt16(buffer[curlen:]) + 2
        self.WillMessage = readBytes(buffer[curlen:])
        curlen += len(self.WillMessage) + 2
        logger.info("[[MQTT-3.1.2-9] will topic and will message fields must be present")
      else:
        self.WillTopic = self.WillMessage = None

      if self.usernameFlag:
        assert len(buffer) > curlen+2, "Buffer too short to read username length"
        self.username = readUTF(buffer[curlen:], packlen - curlen)
        curlen += readInt16(buffer[curlen:]) + 2
        logger.info("[MQTT-3.1.2-19] username must be in payload if user name flag is 1")
      else:
        logger.info("[MQTT-3.1.2-18] username must not be in payload if user name flag is 0")
        assert self.passwordFlag == False, "[MQTT-3.1.2-22] password flag must be 0 if username flag is 0"

      if self.passwordFlag:
        assert len(buffer) > curlen+2, "Buffer too short to read password length"
        self.password = readBytes(buffer[curlen:])
        curlen += len(self.password) + 2
        logger.info("[MQTT-3.1.2-21] password must be in payload if password flag is 1")
      else:
        logger.info("[MQTT-3.1.2-20] password must not be in payload if password flag is 0")

      if self.WillFlag and self.usernameFlag and self.passwordFlag:
        logger.info("[MQTT-3.1.3-1] clientid, will topic, will message, username and password all present")

      assert curlen == packlen, "Packet is wrong length curlen %d != packlen %d" % (curlen, packlen)
    except Exception:
      logger.exception("[MQTT-3.1.4-1] server must validate connect packet and close connection without connack if it does not conform")
      raise

  def __str__(self):
    "printable form of the packet; optional fields appear only when set"
    buf = str(self.fh)+", ProtocolName="+str(self.ProtocolName)+", ProtocolVersion=" +\
          str(self.ProtocolVersion)+", CleanSession="+str(self.CleanSession) +\
          ", WillFlag="+str(self.WillFlag)+", KeepAliveTimer=" +\
          str(self.KeepAliveTimer)+", ClientId="+str(self.ClientIdentifier) +\
          ", usernameFlag="+str(self.usernameFlag)+", passwordFlag="+str(self.passwordFlag)
    if self.WillFlag:
      # str() keeps this printable when the topic was never set or holds bytes
      buf += ", WillQoS=" + str(self.WillQoS) +\
             ", WillRETAIN=" + str(self.WillRETAIN) +\
             ", WillTopic='"+ str(self.WillTopic) +\
             "', WillMessage='"+str(self.WillMessage)+"'"
    if self.username:
      buf += ", username="+str(self.username)
    if self.password:
      buf += ", password="+str(self.password)
    return buf+")"

  def __eq__(self, packet):
    "compare fixed header, variable header and client id, plus will fields when the will flag is set"
    rc = Packets.__eq__(self, packet) and \
           self.ProtocolName == packet.ProtocolName and \
           self.ProtocolVersion == packet.ProtocolVersion and \
           self.CleanSession == packet.CleanSession and \
           self.WillFlag == packet.WillFlag and \
           self.KeepAliveTimer == packet.KeepAliveTimer and \
           self.ClientIdentifier == packet.ClientIdentifier
    if rc and self.WillFlag:
      rc = self.WillQoS == packet.WillQoS and \
           self.WillRETAIN == packet.WillRETAIN and \
           self.WillTopic == packet.WillTopic and \
           self.WillMessage == packet.WillMessage
    return rc


class Connacks(Packets):
  "CONNACK packet: connect acknowledge flags byte and return code"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, ReturnCode=0):
    self.fh = FixedHeaders(CONNACK)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    self.flags = 0
    self.returnCode = ReturnCode
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = bytes([self.flags, self.returnCode])
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate a CONNACK packet held in buffer"
    assert len(buffer) >= 4
    assert MessageType(buffer) == CONNACK
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 2, "Connack packet is wrong length %d" % self.fh.remainingLength
    assert len(buffer) >= fhlen + 2
    assert buffer[fhlen] in  [0, 1], "Connect Acknowledge Flags"
    # keep the flags byte so that __str__ reports the received session present bit
    self.flags = buffer[fhlen]
    self.returnCode = buffer[fhlen + 1]
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"

  def __str__(self):
    return str(self.fh)+", Session present="+str((self.flags & 0x01) == 1)+", ReturnCode="+str(self.returnCode)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and return codes match"
    return Packets.__eq__(self, packet) and \
           self.returnCode == packet.returnCode


class Disconnects(Packets):
  "DISCONNECT packet: a fixed header only"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False):
    self.fh = FixedHeaders(DISCONNECT)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    if buffer is not None:
      self.unpack(buffer)

  def unpack(self, buffer):
    "unpack and validate a DISCONNECT packet held in buffer"
    assert len(buffer) >= 2
    assert MessageType(buffer) == DISCONNECT
    self.fh.unpack(buffer)
    assert self.fh.remainingLength == 0, "Disconnect packet is wrong length %d" % self.fh.remainingLength
    logger.info("[MQTT-3.14.1-1] disconnect reserved bits must be 0")
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"

  def __str__(self):
    return str(self.fh)+")"


class Publishes(Packets):
  "PUBLISH packet: topic name, packet identifier (when QoS is not 0) and payload"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0, TopicName="", Payload=b""):
    self.fh = FixedHeaders(PUBLISH)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.topicName = TopicName
    self.messageIdentifier = MsgId
    # payload
    self.data = Payload
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeUTF(self.topicName)
    if self.fh.QoS != 0:
      buffer +=  writeInt16(self.messageIdentifier)
    buffer += self.data
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    """Unpack and validate a PUBLISH packet held in buffer.

    Returns the total length of the packet in bytes.
    """
    assert len(buffer) >= 2
    assert MessageType(buffer) == PUBLISH
    fhlen = self.fh.unpack(buffer)
    assert self.fh.QoS in [0, 1, 2], "QoS in Publish must be 0, 1, or 2"
    packlen = fhlen + self.fh.remainingLength
    assert len(buffer) >= packlen
    curlen = fhlen
    try:
      self.topicName = readUTF(buffer[curlen:], packlen - curlen)
    except UnicodeDecodeError:
      logger.info("[MQTT-3.3.2-1] topic name in publish must be utf-8")
      raise
    # advance by the encoded byte length taken from the length prefix; the
    # character count of the decoded topic is smaller for non-ASCII text
    curlen += readInt16(buffer[curlen:]) + 2
    if self.fh.QoS != 0:
      self.messageIdentifier = readInt16(buffer[curlen:])
      logger.info("[MQTT-2.3.1-1] packet indentifier must be in publish if QoS is 1 or 2")
      curlen += 2
      assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0"
    else:
      logger.info("[MQTT-2.3.1-5] no packet indentifier in publish if QoS is 0")
      self.messageIdentifier = 0
    self.data = buffer[curlen:packlen]
    if self.fh.QoS == 0:
      assert self.fh.DUP == False, "[MQTT-2.1.2-4]"
    return packlen

  def __str__(self):
    rc = str(self.fh)
    if self.fh.QoS != 0:
      rc += ", MsgId="+str(self.messageIdentifier)
    rc += ", TopicName="+str(self.topicName)+", Payload="+str(self.data)+")"
    return rc

  def __eq__(self, packet):
    "equal when header, topic and payload match, plus the packet id when QoS is not 0"
    rc = Packets.__eq__(self, packet) and \
         self.topicName == packet.topicName and \
         self.data == packet.data
    if rc and self.fh.QoS != 0:
      rc = self.messageIdentifier == packet.messageIdentifier
    return rc


class Pubacks(Packets):
  "PUBACK packet: a fixed header followed by a packet identifier"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0):
    self.fh = FixedHeaders(PUBACK)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate a PUBACK packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PUBACK
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 2, "Puback packet is wrong length %d" % self.fh.remainingLength
    assert len(buffer) >= fhlen + self.fh.remainingLength
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] Puback reserved bits must be 0"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1] Puback reserved bits must be 0"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Puback reserved bits must be 0"
    return fhlen + 2

  def __str__(self):
    # same "MsgId=...)" form, with the closing bracket, as the other ack packets
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and packet identifiers match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier


class Pubrecs(Packets):
  "PUBREC packet: a fixed header followed by a packet identifier"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0):
    self.fh = FixedHeaders(PUBREC)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate a PUBREC packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PUBREC
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 2, "Pubrec packet is wrong length %d" % self.fh.remainingLength
    assert len(buffer) >= fhlen + self.fh.remainingLength
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0"
    return fhlen + 2

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and packet identifiers match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier


class Pubrels(Packets):
  "PUBREL packet: a fixed header (QoS bits set to 1) followed by a packet identifier"

  def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0):
    self.fh = FixedHeaders(PUBREL)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate a PUBREL packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PUBREL
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 2, "Pubrel packet is wrong length %d" % self.fh.remainingLength
    assert len(buffer) >= fhlen + self.fh.remainingLength
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be False in PUBREL"
    assert self.fh.QoS == 1, "[MQTT-2.1.2-1] QoS should be 1 in PUBREL"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] RETAIN should be False in PUBREL"
    logger.info("[MQTT-3.6.1-1] bits in fixed header for pubrel are ok")
    return fhlen + 2

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and packet identifiers match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier


class Pubcomps(Packets):
  "PUBCOMP packet: a fixed header followed by a packet identifier"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0):
    self.fh = FixedHeaders(PUBCOMP)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate a PUBCOMP packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PUBCOMP
    fhlen = self.fh.unpack(buffer)
    assert len(buffer) >= fhlen + self.fh.remainingLength
    assert self.fh.remainingLength == 2, "Pubcomp packet is wrong length %d" % self.fh.remainingLength
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be False in Pubcomp"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS should be 0 in Pubcomp"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Retain should be false in Pubcomp"
    return fhlen + 2

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and packet identifiers match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier


class Subscribes(Packets):
  "SUBSCRIBE packet: packet identifier and a list of (topic filter, requested QoS) pairs"

  def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0, Data=[]):
    self.fh = FixedHeaders(SUBSCRIBE)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    # payload - list of topic, qos pairs (copied, so the caller's list is never shared)
    self.data = list(Data)
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    for d in self.data:
      buffer += writeUTF(d[0]) + bytes([d[1]])
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    """Unpack and validate a SUBSCRIBE packet held in buffer.

    Returns the total length of the packet in bytes.
    """
    assert len(buffer) >= 2
    assert MessageType(buffer) == SUBSCRIBE
    fhlen = self.fh.unpack(buffer)
    packlen = fhlen + self.fh.remainingLength
    assert len(buffer) >= packlen
    logger.info("[MQTT-2.3.1-1] packet indentifier must be in subscribe")
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0"
    # walk forward from the start of the payload, so that bytes following this
    # packet in the buffer cannot shift what is read
    curlen = fhlen + 2
    self.data = []
    while curlen < packlen:
      topic = readUTF(buffer[curlen:], packlen - curlen)
      # advance by the encoded byte length, not the decoded character count
      curlen += readInt16(buffer[curlen:]) + 2
      assert curlen < packlen, "requested QoS byte missing after topic filter"
      qos = buffer[curlen]
      assert qos in [0, 1, 2], "[MQTT-3-8.3-2] reserved bits must be zero"
      curlen += 1
      self.data.append((topic, qos))
    assert len(self.data) > 0, "[MQTT-3.8.3-1] at least one topic, qos pair must be in subscribe"
    assert curlen == packlen
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP must be false in subscribe"
    assert self.fh.QoS == 1, "[MQTT-2.1.2-1] QoS must be 1 in subscribe"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] RETAIN must be false in subscribe"
    return packlen

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\
           ", Data="+str(self.data)+")"

  def __eq__(self, packet):
    "equal when fixed header, packet identifier and topic list all match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier and \
           self.data == packet.data


class Subacks(Packets):
  "SUBACK packet: packet identifier and a list of return codes"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0, Data=[]):
    self.fh = FixedHeaders(SUBACK)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    # payload - list of qos (copied, so the caller's list is never shared)
    self.data = list(Data)
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    for d in self.data:
      buffer += bytes([d])
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    """Unpack and validate a SUBACK packet held in buffer.

    Returns the total length of the packet in bytes.
    """
    assert len(buffer) >= 2
    assert MessageType(buffer) == SUBACK
    fhlen = self.fh.unpack(buffer)
    packlen = fhlen + self.fh.remainingLength
    assert len(buffer) >= packlen
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.fh.remainingLength >= 2
    self.data = []
    # read the return codes forward from the end of the packet identifier, so
    # that bytes following this packet in the buffer cannot shift what is read
    for qos in buffer[fhlen + 2:packlen]:
      assert qos in [0, 1, 2, 0x80], "[MQTT-3.9.3-2] return code in QoS must be 0, 1, 2 or 0x80"
      self.data.append(qos)
    assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be false in suback"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS should be 0 in suback"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Retain should be false in suback"
    return packlen

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\
           ", Data="+str(self.data)+")"

  def __eq__(self, packet):
    "equal when fixed header, packet identifier and return codes all match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier and \
           self.data == packet.data


class Unsubscribes(Packets):
  "UNSUBSCRIBE packet: packet identifier and a list of topic filters"

  def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0, Data=[]):
    self.fh = FixedHeaders(UNSUBSCRIBE)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    # payload - list of topics (copied, so the caller's list is never shared)
    self.data = list(Data)
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    for d in self.data:
      buffer += writeUTF(d)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    """Unpack and validate an UNSUBSCRIBE packet held in buffer.

    Returns the total length of the packet in bytes.
    """
    assert len(buffer) >= 2
    assert MessageType(buffer) == UNSUBSCRIBE
    fhlen = self.fh.unpack(buffer)
    packlen = fhlen + self.fh.remainingLength
    assert len(buffer) >= packlen
    logger.info("[MQTT-2.3.1-1] packet indentifier must be in unsubscribe")
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0"
    # walk forward from the start of the payload, so that bytes following this
    # packet in the buffer cannot shift what is read
    curlen = fhlen + 2
    self.data = []
    while curlen < packlen:
      topic = readUTF(buffer[curlen:], packlen - curlen)
      # advance by the encoded byte length, not the decoded character count
      curlen += readInt16(buffer[curlen:]) + 2
      self.data.append(topic)
    assert curlen == packlen
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 1, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"
    logger.info("[MQTT-3-10.1-1] fixed header bits are 0,0,1,0")
    return packlen

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\
           ", Data="+str(self.data)+")"

  def __eq__(self, packet):
    "equal when fixed header, packet identifier and topic list all match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier and \
           self.data == packet.data


class Unsubacks(Packets):
  "UNSUBACK packet: a fixed header followed by a packet identifier"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0):
    self.fh = FixedHeaders(UNSUBACK)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    # variable header
    self.messageIdentifier = MsgId
    if buffer is not None:
      self.unpack(buffer)

  def pack(self):
    "pack the packet into bytes for transmission; nothing is validated"
    buffer = writeInt16(self.messageIdentifier)
    buffer = self.fh.pack(len(buffer)) + buffer
    return buffer

  def unpack(self, buffer):
    "unpack and validate an UNSUBACK packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == UNSUBACK
    fhlen = self.fh.unpack(buffer)
    # the packet carries nothing but the packet identifier, as in the other acks
    assert self.fh.remainingLength == 2, "Unsuback packet is wrong length %d" % self.fh.remainingLength
    assert len(buffer) >= fhlen + self.fh.remainingLength
    self.messageIdentifier = readInt16(buffer[fhlen:])
    assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0"
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"
    return fhlen + self.fh.remainingLength

  def __str__(self):
    return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")"

  def __eq__(self, packet):
    "equal when the fixed headers and packet identifiers match"
    return Packets.__eq__(self, packet) and \
           self.messageIdentifier == packet.messageIdentifier


class Pingreqs(Packets):
  "PINGREQ packet: a fixed header only"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False):
    self.fh = FixedHeaders(PINGREQ)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    if buffer is not None:
      self.unpack(buffer)

  def unpack(self, buffer):
    "unpack and validate a PINGREQ packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PINGREQ
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 0
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"
    return fhlen

  def __str__(self):
    return str(self.fh)+")"


class Pingresps(Packets):
  "PINGRESP packet: a fixed header only"

  def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False):
    self.fh = FixedHeaders(PINGRESP)
    self.fh.DUP = DUP
    self.fh.QoS = QoS
    # the fixed header field is named RETAIN; setting it is what makes pack() send the bit
    self.fh.RETAIN = Retain
    if buffer is not None:
      self.unpack(buffer)

  def unpack(self, buffer):
    "unpack and validate a PINGRESP packet; returns the packet length in bytes"
    assert len(buffer) >= 2
    assert MessageType(buffer) == PINGRESP
    fhlen = self.fh.unpack(buffer)
    assert self.fh.remainingLength == 0
    assert self.fh.DUP == False, "[MQTT-2.1.2-1]"
    assert self.fh.QoS == 0, "[MQTT-2.1.2-1]"
    assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]"
    return fhlen

  def __str__(self):
    return str(self.fh)+")"

# Packet classes indexed by message type; index 0 (reserved) has no class
classes = [
  None,
  Connects,
  Connacks,
  Publishes,
  Pubacks,
  Pubrecs,
  Pubrels,
  Pubcomps,
  Subscribes,
  Subacks,
  Unsubscribes,
  Unsubacks,
  Pingreqs,
  Pingresps,
  Disconnects,
]

def unpackPacket(buffer):
  """Build the packet object that matches the type of the packet in buffer.

  buffer: the bytes of one complete packet, or None.

  Returns an instance of the matching packet class with the buffer unpacked
  into it, or None when MessageType() reports no type for the buffer.
  Raises MQTTException for a type number that has no packet class (0 or 15);
  validation failures raised while unpacking are propagated.
  """
  packetType = MessageType(buffer)
  if packetType is None:
    return None
  if not 1 <= packetType < len(classes):
    raise MQTTException("Unknown packet type %d" % packetType)
  packet = classes[packetType]()
  packet.unpack(buffer)
  return packet

if __name__ == "__main__":
  # Self test 1: remaining length values must survive an encode/decode round trip.
  # Plain comparisons are used so the checks still run when asserts are disabled.
  fh = FixedHeaders(CONNECT)
  tests = [0, 56, 127, 128, 8888, 16383, 16384, 65535, 2097151, 2097152,
           20555666, 268435454, 268435455]
  for x in tests:
    decoded = fh.decode(fh.encode(x))
    if decoded[0] != x:
      print("Test failed for x =", x, decoded)
  # one more than the largest encodable value must be rejected by encode()
  try:
    fh.decode(fh.encode(268435456))
    print("Error")
  except AssertionError:
    pass

  # Self test 2: a default-constructed packet of each class should print the
  # same before and after a pack/unpack round trip.  A failure to unpack is
  # reported and the run continues with the next class.
  for packet in classes[1:]:
    before = str(packet())
    try:
      after = str(unpackPacket(packet().pack()))
    except Exception as exc:
      print("before:", before, "\nunpack failed:", repr(exc))
      continue
    if before != after:
      print("before:", before, "\nafter:", after)
  print("End")