#!/usr/bin/env python3 # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Tests the generated python backend against standard PDL # constructs, with matching input vectors. import dataclasses import enum import json import typing import unittest from importlib import resources # (le|be)_backend are the names of the modules generated from the canonical # little endian and big endian test grammars. The purpose of this module # is to validate the generated parsers against the set of pre-generated # test vectors in canonical/(le|be)_test_vectors.json. 
import le_backend
import be_backend

# Packets whose test vectors are skipped by both the parser and the
# serializer tests below.
SKIPPED_TESTS = [
    "Packet_Array_Field_VariableElementSize_ConstantSize",
    "Packet_Array_Field_VariableElementSize_VariableSize",
    "Packet_Array_Field_VariableElementSize_VariableCount",
    "Packet_Array_Field_VariableElementSize_UnknownSize",
]


def match_object(self, left, right):
    """Recursively match a python class object against a reference json object.

    Args:
        self: the unittest.TestCase used to report assertion failures.
        left: the value produced by the generated backend.
        right: the reference value decoded from the JSON test vector.

    Reference dicts are matched attribute by attribute (attributes of
    ``left`` that are absent from ``right`` are not checked), reference lists
    element by element after comparing lengths, and every other reference
    value (int, None, ...) with assertEqual, so no scalar is silently ignored.
    """
    if isinstance(right, dict):
        for (name, expected) in right.items():
            self.assertTrue(hasattr(left, name), f"missing attribute {name!r}")
            match_object(self, getattr(left, name), expected)
    elif isinstance(right, list):
        self.assertEqual(len(left), len(right))
        # Lengths are equal at this point, so zip does not truncate.
        for (left_item, right_item) in zip(left, right):
            match_object(self, left_item, right_item)
    else:
        self.assertEqual(left, right)


def create_object(typ, value):
    """Build an object of the selected type using the input value.

    Args:
        typ: type annotation of the object to build. Supported forms are
            dataclasses (built field by field from a dict), list[T],
            typing.Optional[T] / typing.Union[...], bytes, bytearray,
            enum.Enum subclasses exposing a ``from_int`` method, and int.
        value: reference value decoded from a JSON test vector.

    Returns:
        The object built from ``value``.

    Raises:
        Exception: if ``typ`` is not one of the supported annotations.
    """
    if dataclasses.is_dataclass(typ):
        field_types = {f.name: f.type for f in dataclasses.fields(typ)}
        values = {name: create_object(field_types[name], v)
                  for (name, v) in value.items()}
        return typ(**values)

    origin = typing.get_origin(typ)
    if origin is list:
        element_type = typing.get_args(typ)[0]
        return [create_object(element_type, v) for v in value]
    if origin is typing.Union:
        # typing.Optional[T] expands to typing.Union[T, None]: build the first
        # member that is not NoneType, regardless of where None appears.
        if value is None:
            return None
        member_type = next(t for t in typing.get_args(typ) if t is not type(None))
        return create_object(member_type, value)
    if typ is bytes:
        return bytes(value)
    if typ is bytearray:
        return bytearray(value)
    # issubclass() raises TypeError for annotations that are not classes
    # (e.g. string forward references); guard it so such annotations reach
    # the explicit error below instead.
    if isinstance(typ, type) and issubclass(typ, enum.Enum):
        return typ.from_int(value)
    if typ is int:
        return value
    raise Exception(f"unsupported type annotation {typ}")


def _load_test_vectors(filename):
    """Load and return the JSON test vectors stored in the tests.canonical package."""
    with resources.files('tests.canonical').joinpath(filename).open('r') as f:
        return json.load(f)


def _iter_test_vectors(filename):
    """Yield (packet, tests) for every entry of ``filename`` not in SKIPPED_TESTS."""
    for item in _load_test_vectors(filename):
        # 'packet' is the name of the packet being tested,
        # 'tests' lists input vectors that must match the
        # selected packet.
        packet = item['packet']
        tests = item['tests']
        if packet in SKIPPED_TESTS:
            continue
        yield packet, tests


def _check_parser(test_case, module, filename):
    """Parse every packed vector of ``filename`` with ``module`` and match the result."""
    for packet, tests in _iter_test_vectors(filename):
        with test_case.subTest(packet=packet):
            # Retrieve the class object from the generated
            # module, in order to invoke the proper parse
            # method for this test.
            cls = getattr(module, packet)
            for test in tests:
                result = cls.parse_all(bytes.fromhex(test['packed']))
                match_object(test_case, result, test['unpacked'])


def _check_serializer(test_case, module, filename):
    """Build every unpacked vector of ``filename`` with ``module`` and compare its serialization."""
    for packet, tests in _iter_test_vectors(filename):
        with test_case.subTest(packet=packet):
            for test in tests:
                # Retrieve the class object from the generated module, in
                # order to invoke the proper constructor for this test. A
                # vector may override its group's packet name with its own
                # 'packet' key.
                cls = getattr(module, test.get('packet', packet))
                obj = create_object(cls, test['unpacked'])
                result = obj.serialize()
                test_case.assertEqual(result, bytes.fromhex(test['packed']))


class PacketParserTest(unittest.TestCase):
    """Validate the generated parser against pre-generated test
    vectors in canonical/(le|be)_test_vectors.json"""

    def testLittleEndian(self):
        _check_parser(self, le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        _check_parser(self, be_backend, 'be_test_vectors.json')


class PacketSerializerTest(unittest.TestCase):
    """Validate the generated serializer against pre-generated test
    vectors in canonical/(le|be)_test_vectors.json"""

    def testLittleEndian(self):
        _check_serializer(self, le_backend, 'le_test_vectors.json')

    def testBigEndian(self):
        _check_serializer(self, be_backend, 'be_test_vectors.json')


class CustomPacketParserTest(unittest.TestCase):
    """Manual testing for custom fields."""

    def testCustomField(self):
        # Run the same checks against both backends; subTest keeps reporting
        # the big endian results even if the little endian checks fail.
        for backend in (le_backend, be_backend):
            with self.subTest(backend=backend.__name__):
                result = backend.Packet_Custom_Field_ConstantSize.parse_all([1])
                self.assertEqual(result.a.value, 1)

                result = backend.Packet_Custom_Field_VariableSize.parse_all([1])
                self.assertEqual(result.a.value, 1)

                result = backend.Struct_Custom_Field_ConstantSize.parse_all([1])
                self.assertEqual(result.s.a.value, 1)

                result = backend.Struct_Custom_Field_VariableSize.parse_all([1])
                self.assertEqual(result.s.a.value, 1)


if __name__ == '__main__':
    unittest.main(verbosity=3)