""" ******************************************************************* Copyright (c) 2013, 2018 IBM Corp. All rights reserved. This program and the accompanying materials are made available under the terms of the Eclipse Public License v1.0 and Eclipse Distribution License v1.0 which accompany this distribution. The Eclipse Public License is available at http://www.eclipse.org/legal/epl-v10.html and the Eclipse Distribution License is available at http://www.eclipse.org/org/documents/edl-v10.php. Contributors: Ian Craggs - initial implementation and/or documentation ******************************************************************* """ """ Assertions are used to validate incoming data, but are omitted from outgoing packets. This is so that the tests that use this package can send invalid data for error testing. """ import logging logger = logging.getLogger("mqttsas") # Low-level protocol interface class MQTTException(Exception): pass # Message types CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL, \ PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK, \ PINGREQ, PINGRESP, DISCONNECT = range(1, 15) packetNames = [ "reserved", \ "Connect", "Connack", "Publish", "Puback", "Pubrec", "Pubrel", \ "Pubcomp", "Subscribe", "Suback", "Unsubscribe", "Unsuback", \ "Pingreq", "Pingresp", "Disconnect"] classNames = [ "reserved", \ "Connects", "Connacks", "Publishes", "Pubacks", "Pubrecs", "Pubrels", \ "Pubcomps", "Subscribes", "Subacks", "Unsubscribes", "Unsubacks", \ "Pingreqs", "Pingresps", "Disconnects"] def MessageType(byte): if byte != None: rc = byte[0] >> 4 else: rc = None return rc def getPacket(aSocket): "receive the next packet" buf = aSocket.recv(1) # get the first byte fixed header if buf == b"": return None if str(aSocket).find("[closed]") != -1: closed = True else: closed = False if closed: return None # now get the remaining length multiplier = 1 remlength = 0 while 1: next = aSocket.recv(1) while len(next) == 0: next = aSocket.recv(1) buf += next digit = buf[-1] remlength += (digit & 127) * multiplier if digit & 128 == 0: break multiplier *= 128 # receive the remaining length if there is any rest = bytes([]) if remlength > 0: while len(rest) < remlength: rest += aSocket.recv(remlength-len(rest)) assert len(rest) == remlength return buf + rest class FixedHeaders: def __init__(self, aMessageType): self.MessageType = aMessageType self.DUP = False self.QoS = 0 self.RETAIN = False self.remainingLength = 0 def __eq__(self, fh): return self.MessageType == fh.MessageType and \ self.DUP == fh.DUP and \ self.QoS == fh.QoS and \ self.RETAIN == fh.RETAIN # and \ # self.remainingLength == fh.remainingLength def __str__(self): "return printable stresentation of our data" return classNames[self.MessageType]+'(DUP='+str(self.DUP)+ \ ", QoS="+str(self.QoS)+", Retain="+str(self.RETAIN) def pack(self, length): "pack data into string buffer ready for transmission down socket" buffer = bytes([(self.MessageType << 4) | (self.DUP << 3) |\ (self.QoS << 1) | self.RETAIN]) self.remainingLength = length buffer += self.encode(length) return buffer def encode(self, x): assert 0 <= x <= 268435455 buffer = b'' while 1: digit = x % 128 x //= 128 if x > 0: digit |= 0x80 buffer += bytes([digit]) if x == 0: break return buffer def unpack(self, buffer): "unpack data from string buffer into separate fields" b0 = buffer[0] self.MessageType = b0 >> 4 self.DUP = ((b0 >> 3) & 0x01) == 1 self.QoS = (b0 >> 1) & 0x03 self.RETAIN = (b0 & 0x01) == 1 (self.remainingLength, bytes) = self.decode(buffer[1:]) return bytes + 1 # length of fixed header def decode(self, buffer): multiplier = 1 value = 0 bytes = 0 while 1: bytes += 1 digit = buffer[0] buffer = buffer[1:] value += (digit & 127) * multiplier if digit & 128 == 0: break multiplier *= 128 return (value, bytes) def writeInt16(length): return bytes([length // 256, length % 256]) def readInt16(buf): return buf[0]*256 + buf[1] def writeUTF(data): # data could be a string, or bytes. If string, encode into bytes with utf-8 return writeInt16(len(data)) + (data if type(data) == type(b"") else bytes(data, "utf-8")) def readUTF(buffer, maxlen): if maxlen >= 2: length = readInt16(buffer) else: raise MQTTException("Not enough data to read string length") maxlen -= 2 if length > maxlen: raise MQTTException("Length delimited string too long") buf = buffer[2:2+length].decode("utf-8") logger.info("[MQTT-4.7.3-2] topic names and filters not include null") zz = buf.find("\x00") # look for null in the UTF string if zz != -1: raise MQTTException("[MQTT-1.5.3-2] Null found in UTF data "+buf) for c in range (0xD800, 0xDFFF): zz = buf.find(chr(c)) # look for D800-DFFF in the UTF string if zz != -1: raise MQTTException("[MQTT-1.5.3-1] D800-DFFF found in UTF data "+buf) if buf.find("\uFEFF") != -1: logger.info("[MQTT-1.5.3-3] U+FEFF in UTF string") return buf def writeBytes(buffer): return writeInt16(len(buffer)) + buffer def readBytes(buffer): length = readInt16(buffer) return buffer[2:2+length] class Packets: def pack(self): buffer = self.fh.pack(0) return buffer def __str__(self): return str(self.fh) def __eq__(self, packet): return self.fh == packet.fh if packet else False class Connects(Packets): def __init__(self, buffer = None): self.fh = FixedHeaders(CONNECT) # variable header self.ProtocolName = "MQTT" self.ProtocolVersion = 4 self.CleanSession = True self.WillFlag = False self.WillQoS = 0 self.WillRETAIN = 0 self.KeepAliveTimer = 30 self.usernameFlag = False self.passwordFlag = False # Payload self.ClientIdentifier = "" # UTF-8 self.WillTopic = None # UTF-8 self.WillMessage = None # binary self.username = None # UTF-8 self.password = None # binary if buffer != None: self.unpack(buffer) def pack(self): connectFlags = bytes([(self.CleanSession << 1) | (self.WillFlag << 2) | \ (self.WillQoS << 3) | (self.WillRETAIN << 5) | \ (self.usernameFlag << 6) | (self.passwordFlag << 7)]) buffer = writeUTF(self.ProtocolName) + bytes([self.ProtocolVersion]) + \ connectFlags + writeInt16(self.KeepAliveTimer) buffer += writeUTF(self.ClientIdentifier) if self.WillFlag: buffer += writeUTF(self.WillTopic) buffer += writeBytes(self.WillMessage) if self.usernameFlag: buffer += writeUTF(self.username) if self.passwordFlag: buffer += writeBytes(self.password) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == CONNECT try: fhlen = self.fh.unpack(buffer) packlen = fhlen + self.fh.remainingLength assert len(buffer) >= packlen, "buffer length %d packet length %d" % (len(buffer), packlen) curlen = fhlen # points to after header + remaining length assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS was not 0, was %d" % self.fh.QoS assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" self.ProtocolName = readUTF(buffer[curlen:], packlen - curlen) curlen += len(self.ProtocolName) + 2 assert self.ProtocolName == "MQTT", "Wrong protocol name %s" % self.ProtocolName self.ProtocolVersion = buffer[curlen] curlen += 1 connectFlags = buffer[curlen] assert (connectFlags & 0x01) == 0, "[MQTT-3.1.2-3] reserved connect flag must be 0" self.CleanSession = ((connectFlags >> 1) & 0x01) == 1 self.WillFlag = ((connectFlags >> 2) & 0x01) == 1 self.WillQoS = (connectFlags >> 3) & 0x03 self.WillRETAIN = (connectFlags >> 5) & 0x01 self.passwordFlag = ((connectFlags >> 6) & 0x01) == 1 self.usernameFlag = ((connectFlags >> 7) & 0x01) == 1 curlen +=1 if self.WillFlag: assert self.WillQoS in [0, 1, 2], "[MQTT-3.1.2-14] will qos must not be 3" else: assert self.WillQoS == 0, "[MQTT-3.1.2-13] will qos must be 0, if will flag is false" assert self.WillRETAIN == False, "[MQTT-3.1.2-14] will retain must be false, if will flag is false" self.KeepAliveTimer = readInt16(buffer[curlen:]) curlen += 2 logger.info("[MQTT-3.1.3-3] Clientid must be present, and first field") logger.info("[MQTT-3.1.3-4] Clientid must be Unicode, and between 0 and 65535 bytes long") self.ClientIdentifier = readUTF(buffer[curlen:], packlen - curlen) curlen += len(self.ClientIdentifier) + 2 if self.WillFlag: self.WillTopic = readUTF(buffer[curlen:], packlen - curlen) curlen += len(self.WillTopic) + 2 self.WillMessage = readBytes(buffer[curlen:]) curlen += len(self.WillMessage) + 2 logger.info("[[MQTT-3.1.2-9] will topic and will message fields must be present") else: self.WillTopic = self.WillMessage = None if self.usernameFlag: assert len(buffer) > curlen+2, "Buffer too short to read username length" self.username = readUTF(buffer[curlen:], packlen - curlen) curlen += len(self.username) + 2 logger.info("[MQTT-3.1.2-19] username must be in payload if user name flag is 1") else: logger.info("[MQTT-3.1.2-18] username must not be in payload if user name flag is 0") assert self.passwordFlag == False, "[MQTT-3.1.2-22] password flag must be 0 if username flag is 0" if self.passwordFlag: assert len(buffer) > curlen+2, "Buffer too short to read password length" self.password = readBytes(buffer[curlen:]) curlen += len(self.password) + 2 logger.info("[MQTT-3.1.2-21] password must be in payload if password flag is 0") else: logger.info("[MQTT-3.1.2-20] password must not be in payload if password flag is 0") if self.WillFlag and self.usernameFlag and self.passwordFlag: logger.info("[MQTT-3.1.3-1] clientid, will topic, will message, username and password all present") assert curlen == packlen, "Packet is wrong length curlen %d != packlen %d" except: logger.exception("[MQTT-3.1.4-1] server must validate connect packet and close connection without connack if it does not conform") raise def __str__(self): buf = str(self.fh)+", ProtocolName="+str(self.ProtocolName)+", ProtocolVersion=" +\ str(self.ProtocolVersion)+", CleanSession="+str(self.CleanSession) +\ ", WillFlag="+str(self.WillFlag)+", KeepAliveTimer=" +\ str(self.KeepAliveTimer)+", ClientId="+str(self.ClientIdentifier) +\ ", usernameFlag="+str(self.usernameFlag)+", passwordFlag="+str(self.passwordFlag) if self.WillFlag: buf += ", WillQoS=" + str(self.WillQoS) +\ ", WillRETAIN=" + str(self.WillRETAIN) +\ ", WillTopic='"+ self.WillTopic +\ "', WillMessage='"+str(self.WillMessage)+"'" if self.username: buf += ", username="+self.username if self.password: buf += ", password="+str(self.password) return buf+")" def __eq__(self, packet): rc = Packets.__eq__(self, packet) and \ self.ProtocolName == packet.ProtocolName and \ self.ProtocolVersion == packet.ProtocolVersion and \ self.CleanSession == packet.CleanSession and \ self.WillFlag == packet.WillFlag and \ self.KeepAliveTimer == packet.KeepAliveTimer and \ self.ClientIdentifier == packet.ClientIdentifier and \ self.WillFlag == packet.WillFlag if rc and self.WillFlag: rc = self.WillQoS == packet.WillQoS and \ self.WillRETAIN == packet.WillRETAIN and \ self.WillTopic == packet.WillTopic and \ self.WillMessage == packet.WillMessage return rc class Connacks(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, ReturnCode=0): self.fh = FixedHeaders(CONNACK) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain self.flags = 0 self.returnCode = ReturnCode if buffer != None: self.unpack(buffer) def pack(self): buffer = bytes([self.flags, self.returnCode]) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 4 assert MessageType(buffer) == CONNACK self.fh.unpack(buffer) assert self.fh.remainingLength == 2, "Connack packet is wrong length %d" % self.fh.remainingLength assert buffer[2] in [0, 1], "Connect Acknowledge Flags" self.returnCode = buffer[3] assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" def __str__(self): return str(self.fh)+", Session present="+str((self.flags & 0x01) == 1)+", ReturnCode="+str(self.returnCode)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.returnCode == packet.returnCode class Disconnects(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False): self.fh = FixedHeaders(DISCONNECT) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain if buffer != None: self.unpack(buffer) def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == DISCONNECT self.fh.unpack(buffer) assert self.fh.remainingLength == 0, "Disconnect packet is wrong length %d" % self.fh.remainingLength logger.info("[MQTT-3.14.1-1] disconnect reserved bits must be 0") assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" def __str__(self): return str(self.fh)+")" class Publishes(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0, TopicName="", Payload=b""): self.fh = FixedHeaders(PUBLISH) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.topicName = TopicName self.messageIdentifier = MsgId # payload self.data = Payload if buffer != None: self.unpack(buffer) def pack(self): buffer = writeUTF(self.topicName) if self.fh.QoS != 0: buffer += writeInt16(self.messageIdentifier) buffer += self.data buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PUBLISH fhlen = self.fh.unpack(buffer) assert self.fh.QoS in [0, 1, 2], "QoS in Publish must be 0, 1, or 2" packlen = fhlen + self.fh.remainingLength assert len(buffer) >= packlen curlen = fhlen try: self.topicName = readUTF(buffer[fhlen:], packlen - curlen) except UnicodeDecodeError: logger.info("[MQTT-3.3.2-1] topic name in publish must be utf-8") raise curlen += len(self.topicName) + 2 if self.fh.QoS != 0: self.messageIdentifier = readInt16(buffer[curlen:]) logger.info("[MQTT-2.3.1-1] packet indentifier must be in publish if QoS is 1 or 2") curlen += 2 assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0" else: logger.info("[MQTT-2.3.1-5] no packet indentifier in publish if QoS is 0") self.messageIdentifier = 0 self.data = buffer[curlen:fhlen + self.fh.remainingLength] if self.fh.QoS == 0: assert self.fh.DUP == False, "[MQTT-2.1.2-4]" return fhlen + self.fh.remainingLength def __str__(self): rc = str(self.fh) if self.fh.QoS != 0: rc += ", MsgId="+str(self.messageIdentifier) rc += ", TopicName="+str(self.topicName)+", Payload="+str(self.data)+")" return rc def __eq__(self, packet): rc = Packets.__eq__(self, packet) and \ self.topicName == packet.topicName and \ self.data == packet.data if rc and self.fh.QoS != 0: rc = self.messageIdentifier == packet.messageIdentifier return rc class Pubacks(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0): self.fh = FixedHeaders(PUBACK) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PUBACK fhlen = self.fh.unpack(buffer) assert self.fh.remainingLength == 2, "Puback packet is wrong length %d" % self.fh.remainingLength assert len(buffer) >= fhlen + self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.fh.DUP == False, "[MQTT-2.1.2-1] Puback reserved bits must be 0" assert self.fh.QoS == 0, "[MQTT-2.1.2-1] Puback reserved bits must be 0" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Puback reserved bits must be 0" return fhlen + 2 def __str__(self): return str(self.fh)+", MsgId "+str(self.messageIdentifier) def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier class Pubrecs(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0): self.fh = FixedHeaders(PUBREC) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PUBREC fhlen = self.fh.unpack(buffer) assert self.fh.remainingLength == 2, "Pubrec packet is wrong length %d" % self.fh.remainingLength assert len(buffer) >= fhlen + self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.fh.DUP == False, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0" assert self.fh.QoS == 0, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Pubrec reserved bits must be 0" return fhlen + 2 def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier class Pubrels(Packets): def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0): self.fh = FixedHeaders(PUBREL) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PUBREL fhlen = self.fh.unpack(buffer) assert self.fh.remainingLength == 2, "Pubrel packet is wrong length %d" % self.fh.remainingLength assert len(buffer) >= fhlen + self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be False in PUBREL" assert self.fh.QoS == 1, "[MQTT-2.1.2-1] QoS should be 1 in PUBREL" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] RETAIN should be False in PUBREL" logger.info("[MQTT-3.6.1-1] bits in fixed header for pubrel are ok") return fhlen + 2 def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier class Pubcomps(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0): self.fh = FixedHeaders(PUBCOMP) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PUBCOMP fhlen = self.fh.unpack(buffer) assert len(buffer) >= fhlen + self.fh.remainingLength assert self.fh.remainingLength == 2, "Pubcomp packet is wrong length %d" % self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be False in Pubcomp" assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS should be 0 in Pubcomp" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Retain should be false in Pubcomp" return fhlen + 2 def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier class Subscribes(Packets): def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0, Data=[]): self.fh = FixedHeaders(SUBSCRIBE) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId # payload - list of topic, qos pairs self.data = Data[:] if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) for d in self.data: buffer += writeUTF(d[0]) + bytes([d[1]]) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == SUBSCRIBE fhlen = self.fh.unpack(buffer) assert len(buffer) >= fhlen + self.fh.remainingLength logger.info("[MQTT-2.3.1-1] packet indentifier must be in subscribe") self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0" leftlen = self.fh.remainingLength - 2 self.data = [] while leftlen > 0: topic = readUTF(buffer[-leftlen:], leftlen) leftlen -= len(topic) + 2 qos = buffer[-leftlen] assert qos in [0, 1, 2], "[MQTT-3-8.3-2] reserved bits must be zero" leftlen -= 1 self.data.append((topic, qos)) assert len(self.data) > 0, "[MQTT-3.8.3-1] at least one topic, qos pair must be in subscribe" assert leftlen == 0 assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP must be false in subscribe" assert self.fh.QoS == 1, "[MQTT-2.1.2-1] QoS must be 1 in subscribe" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] RETAIN must be false in subscribe" return fhlen + self.fh.remainingLength def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\ ", Data="+str(self.data)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier and \ self.data == packet.data class Subacks(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0, Data=[]): self.fh = FixedHeaders(SUBACK) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId # payload - list of qos self.data = Data[:] if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) for d in self.data: buffer += bytes([d]) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == SUBACK fhlen = self.fh.unpack(buffer) assert len(buffer) >= fhlen + self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) leftlen = self.fh.remainingLength - 2 self.data = [] while leftlen > 0: qos = buffer[-leftlen] assert qos in [0, 1, 2, 0x80], "[MQTT-3.9.3-2] return code in QoS must be 0, 1, 2 or 0x80" leftlen -= 1 self.data.append(qos) assert leftlen == 0 assert self.fh.DUP == False, "[MQTT-2.1.2-1] DUP should be false in suback" assert self.fh.QoS == 0, "[MQTT-2.1.2-1] QoS should be 0 in suback" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1] Retain should be false in suback" return fhlen + self.fh.remainingLength def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\ ", Data="+str(self.data)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier and \ self.data == packet.data class Unsubscribes(Packets): def __init__(self, buffer=None, DUP=False, QoS=1, Retain=False, MsgId=0, Data=[]): self.fh = FixedHeaders(UNSUBSCRIBE) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId # payload - list of topics self.data = Data[:] if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) for d in self.data: buffer += writeUTF(d) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == UNSUBSCRIBE fhlen = self.fh.unpack(buffer) assert len(buffer) >= fhlen + self.fh.remainingLength logger.info("[MQTT-2.3.1-1] packet indentifier must be in unsubscribe") self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0" leftlen = self.fh.remainingLength - 2 self.data = [] while leftlen > 0: topic = readUTF(buffer[-leftlen:], leftlen) leftlen -= len(topic) + 2 self.data.append(topic) assert leftlen == 0 assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 1, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" logger.info("[MQTT-3-10.1-1] fixed header bits are 0,0,1,0") return fhlen + self.fh.remainingLength def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+\ ", Data="+str(self.data)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier and \ self.data == packet.data class Unsubacks(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False, MsgId=0): self.fh = FixedHeaders(UNSUBACK) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain # variable header self.messageIdentifier = MsgId if buffer != None: self.unpack(buffer) def pack(self): buffer = writeInt16(self.messageIdentifier) buffer = self.fh.pack(len(buffer)) + buffer return buffer def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == UNSUBACK fhlen = self.fh.unpack(buffer) assert len(buffer) >= fhlen + self.fh.remainingLength self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.messageIdentifier > 0, "[MQTT-2.3.1-1] packet indentifier must be > 0" self.messageIdentifier = readInt16(buffer[fhlen:]) assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" return fhlen + self.fh.remainingLength def __str__(self): return str(self.fh)+", MsgId="+str(self.messageIdentifier)+")" def __eq__(self, packet): return Packets.__eq__(self, packet) and \ self.messageIdentifier == packet.messageIdentifier class Pingreqs(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False): self.fh = FixedHeaders(PINGREQ) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain if buffer != None: self.unpack(buffer) def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PINGREQ fhlen = self.fh.unpack(buffer) assert self.fh.remainingLength == 0 assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" return fhlen def __str__(self): return str(self.fh)+")" class Pingresps(Packets): def __init__(self, buffer=None, DUP=False, QoS=0, Retain=False): self.fh = FixedHeaders(PINGRESP) self.fh.DUP = DUP self.fh.QoS = QoS self.fh.Retain = Retain if buffer != None: self.unpack(buffer) def unpack(self, buffer): assert len(buffer) >= 2 assert MessageType(buffer) == PINGRESP fhlen = self.fh.unpack(buffer) assert self.fh.remainingLength == 0 assert self.fh.DUP == False, "[MQTT-2.1.2-1]" assert self.fh.QoS == 0, "[MQTT-2.1.2-1]" assert self.fh.RETAIN == False, "[MQTT-2.1.2-1]" return fhlen def __str__(self): return str(self.fh)+")" classes = [None, Connects, Connacks, Publishes, Pubacks, Pubrecs, Pubrels, Pubcomps, Subscribes, Subacks, Unsubscribes, Unsubacks, Pingreqs, Pingresps, Disconnects] def unpackPacket(buffer): if MessageType(buffer) != None: packet = classes[MessageType(buffer)]() packet.unpack(buffer) else: packet = None return packet if __name__ == "__main__": fh = FixedHeaders(CONNECT) tests = [0, 56, 127, 128, 8888, 16383, 16384, 65535, 2097151, 2097152, 20555666, 268435454, 268435455] for x in tests: try: assert x == fh.decode(fh.encode(x))[0] except AssertionError: print("Test failed for x =", x, fh.decode(fh.encode(x))) try: fh.decode(fh.encode(268435456)) print("Error") except AssertionError: pass for packet in classes[1:]: before = str(packet()) after = str(unpackPacket(packet().pack())) try: assert before == after except: print("before:", before, "\nafter:", after) print("End")