#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate googletest unit tests for the cxx PDL backend.

Reads a PDL file in PDL-JSON form and a JSON list of test vectors, and emits
a C++ source file with one parser test and one serializer test per test
vector of every (non-skipped) packet declaration.
"""

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional, TextIO

from pdl import ast, core
from pdl.utils import indent, to_pascal_case

# Number of bytes rendered on each line of a generated C++ initializer list.
_BYTES_PER_LINE = 16

# Indentation of the elements of a generated C++ initializer list.
_ELEMENT_INDENT = '    '


def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    Args:
        width: Bit width of the PDL scalar.

    Returns:
        The smallest of `uint8_t`, `uint16_t`, `uint32_t`, `uint64_t` that
        can hold `width` bits.

    Raises:
        ValueError: If `width` is larger than 64 bits.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types. Raise rather than
    # `assert False`: asserts are stripped under `python -O`, which would make
    # this function return None and emit `None` into the generated C++.
    raise ValueError(f'PDL scalar width {width} does not fit in a 64-bit C++ integer type')


def _hex_lines(values, per_line: int = _BYTES_PER_LINE, prefix: str = '') -> List[str]:
    """Render integer `values` as comma-terminated C++ hex literals.

    `per_line` literals are placed on each returned line, and `prefix` is
    prepended to every line. `values` must support `len()` and slicing
    (e.g. `bytes` or a list of ints).
    """
    return [
        prefix + ' '.join(f'0x{v:x},' for v in values[i:i + per_line])
        for i in range(0, len(values), per_line)
    ]


def _elements_per_line(width: int) -> int:
    """Return how many `width`-bit literals are rendered per initializer line."""
    # Same line payload as a byte array: 16 bytes = 128 bits per line.
    return (_BYTES_PER_LINE * 8) // width


def _cxx_container(element_type: str, size: Optional[int]) -> str:
    """Return the C++ type backing a PDL array: std::array if sized, else std::vector."""
    return f'std::array<{element_type}, {size}>' if size else f'std::vector<{element_type}>'


def _braced_initializer(declaration: str, lines: List[str]) -> List[str]:
    """Return the lines of the C++ statement `<declaration> { <lines> };`."""
    return [f'{declaration} {{', *lines, '};']


def _sanitize_var(var: str) -> str:
    """Turn an indexed C++ expression such as `x[3]` into the identifier fragment `x_3`."""
    return var.replace('[', '_').replace(']', '')


def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    Emits one `TEST_F(<parser_test_suite>, <packet>_Case<n>)` per entry of
    `tests`. Each case wraps the vector's `packed` hex string in a
    `pdl::packet::slice`, parses it with the generated `View::Create` chain,
    asserts the view is valid, then checks the fields listed in `unpacked`.

    Args:
        parser_test_suite: Name of the gtest fixture class.
        packet: Packet declaration the test vectors are attached to.
        tests: Test vectors: dicts with `packed` (hex string), `unpacked`
            (field id -> expected value) and an optional `packet` key naming
            the (child) packet actually encoded.

    Returns:
        The C++ source of the generated test cases.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Each view is created from its parent's view; the root view is
        # created from the raw `input` slice.
        parent = parse_packet(decl.parent) if decl.parent else 'input'
        return f'{decl.id}View::Create({parent})'

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        # Struct members are read directly; packet views expose getters.
        if isinstance(decl, ast.StructDeclaration):
            return f'{var}.{field_id}_'
        return f'{var}.Get{to_pascal_case(field_id)}()'

    def expect_container(getter: str, container: str, field_var: str, lines: List[str]) -> List[str]:
        # Declare the expected container, then compare the accessor against it.
        return [
            *_braced_initializer(f'{container} expected_{field_var}', lines),
            f'ASSERT_EQ({getter}, expected_{field_var});',
        ]

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        """Return C++ assertions comparing the fields of `var` with `expected`."""
        checks = []
        sanitized_var = _sanitize_var(var)
        for (field_id, value) in expected.items():
            field = core.get_packet_field(decl, field_id)
            field_var = f'{sanitized_var}_{field_id}'
            getter = get_field(decl, var, field_id)
            is_typedef = isinstance(field, ast.TypedefField)

            if isinstance(field, ast.ScalarField):
                # Optional scalars compare against std::optional values.
                if field.cond:
                    value = f'std::make_optional({value})' if value is not None else 'std::nullopt'
                checks.append(f'ASSERT_EQ({getter}, {value});')
            elif is_typedef and isinstance(field.type, ast.EnumDeclaration) and field.cond:
                expected_value = (f'std::make_optional({field.type_id}({value}))'
                                  if value is not None else 'std::nullopt')
                checks.append(f'ASSERT_EQ({getter}, {expected_value});')
            elif is_typedef and isinstance(
                    field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration)):
                checks.append(f'ASSERT_EQ({getter}, {field.type_id}({value}));')
            elif is_typedef and field.cond and value is None:
                checks.append(f'ASSERT_TRUE(!{getter}.has_value());')
            elif is_typedef:
                # Struct-typed field: bind it (unwrapping the optional when the
                # field is conditional) and recurse into its members.
                accessor = f'{getter}.value()' if field.cond else getter
                checks.append(f'{field.type_id} const& {field_var} = {accessor};')
                checks.extend(check_members(field.type, field_var, value))
            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.extend(
                    expect_container(getter, 'std::vector<uint8_t>', field_var,
                                     _hex_lines(value, prefix=_ELEMENT_INDENT)))
            elif isinstance(field, ast.ArrayField) and field.width:
                # Array of scalars.
                container = _cxx_container(get_cxx_scalar_type(field.width), field.size)
                checks.extend(
                    expect_container(getter, container, field_var,
                                     _hex_lines(value, _elements_per_line(field.width), _ELEMENT_INDENT)))
            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                container = _cxx_container(field.type_id, field.size)
                checks.extend(
                    expect_container(getter, container, field_var,
                                     [f'{_ELEMENT_INDENT}{field.type_id}({v}),' for v in value]))
            elif isinstance(field, ast.ArrayField):
                # Array of structs: check the size, then every element recursively.
                checks.append(f'{_cxx_container(field.type_id, field.size)} {field_var} = {getter};')
                checks.append(f'ASSERT_EQ({field_var}.size(), {len(value)});')
                for (n, element) in enumerate(value):
                    checks.extend(check_members(field.type, f'{field_var}[{n}]', element))
            # Any other field kind, or an id that does not resolve to a field,
            # produces no check.
        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        # NOTE(review): member checks are generated against `packet` (the
        # declaration the vectors are attached to), not `child_packet`, even
        # though the C++ variable is a `<child>View`. Fields that
        # `core.get_packet_field` cannot resolve on `packet` produce no check;
        # confirm that child-only fields are meant to stay unchecked.
        checks = check_members(packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\
                TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                    pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                        {input_bytes}
                    }}));
                    {child_packet_id}View packet = {parse_packet};
                    ASSERT_TRUE(packet.IsValid());
                    {checks}
                }}

                """).format(parser_test_suite=parser_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            input_bytes=indent(_hex_lines(bytes.fromhex(test['packed'])), 2),
                            parse_packet=parse_packet(child_packet),
                            checks=indent(checks, 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    Emits one `TEST_F(<serializer_test_suite>, <packet>_Case<n>)` per entry of
    `tests`. Each case builds the packet from the vector's `unpacked` values
    with the generated `Builder` and compares `SerializeToBytes()` with the
    vector's `packed` hex string.

    Args:
        serializer_test_suite: Name of the gtest fixture class.
        packet: Packet declaration the test vectors are attached to.
        tests: Test vectors, same shape as for `generate_packet_parser_test`.

    Returns:
        The C++ source of the generated test cases.
    """

    def build_packet(decl: ast.Declaration, var: str, initializer: dict) -> Tuple[str, List[str]]:
        """Return a C++ constructor expression for `decl` and the declarations it needs.

        `var` prefixes the names of the intermediate C++ variables that hold
        array and payload values; those variables are declared by the
        returned statement lines and moved into the constructor call.
        """
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        sanitized_var = _sanitize_var(var)

        def add_container(container: str, field_var: str, lines: List[str]) -> None:
            # Declare an intermediate container and move it into the constructor.
            declarations.extend(_braced_initializer(f'{container} {field_var}', lines))
            parameters.append(f'std::move({field_var})')

        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if field.cond_for:
                # Fields with `cond_for` set (presumably the flags guarding
                # optional fields) are not passed to the constructor.
                continue
            elif field.cond and value is None:
                parameters.append('std::nullopt')
            elif isinstance(field, ast.ScalarField):
                parameters.append(f'std::make_optional({value})' if field.cond else f'{value}')
            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                enum_value = f'{field.type_id}({value})'
                parameters.append(f'std::make_optional({enum_value})' if field.cond else enum_value)
            elif isinstance(field, ast.TypedefField):
                (element, element_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(element_declarations)
                parameters.append(element)
            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                add_container('std::vector<uint8_t>', field_var, _hex_lines(value, prefix=_ELEMENT_INDENT))
            elif isinstance(field, ast.ArrayField) and field.width:
                # Array of scalars.
                add_container(_cxx_container(get_cxx_scalar_type(field.width), field.size), field_var,
                              _hex_lines(value, _elements_per_line(field.width), _ELEMENT_INDENT))
            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                add_container(_cxx_container(field.type_id, field.size), field_var,
                              [f'{_ELEMENT_INDENT}{field.type_id}({v}),' for v in value])
            elif isinstance(field, ast.ArrayField):
                # Array of structs: build every element first (collecting the
                # declarations they need), then the array itself.
                elements = []
                for (n, element_value) in enumerate(value):
                    (element, element_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                    elements.append(element)
                    declarations.extend(element_declarations)
                add_container(_cxx_container(field.type_id, field.size), field_var,
                              [f'{_ELEMENT_INDENT}{element},' for element in elements])
            # Any other field kind contributes no constructor parameter.

        # Packets are constructed through their Builder; structs directly.
        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\
                TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                    std::vector<uint8_t> expected_output {{
                        {output_bytes}
                    }};

                    {intermediate_declarations}
                    {child_packet_id}Builder packet = {built_packet};
                    ASSERT_EQ(packet.SerializeToBytes(), expected_output);
                }}

                """).format(serializer_test_suite=serializer_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            output_bytes=indent(_hex_lines(bytes.fromhex(test['packed'])), 2),
                            built_packet=built_packet,
                            intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)


def run(input: TextIO, output: TextIO, test_vectors: TextIO, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write the generated C++ test file to `output`.

    Args:
        input: PDL-JSON source of the PDL file under test.
        output: Destination of the generated C++ source.
        test_vectors: JSON test vectors: a list of
            `{'packet': <declaration id>, 'tests': [<test vector>, ...]}`.
        include_header: Headers emitted as extra `#include <...>` directives.
        using_namespace: Namespaces emitted as `using namespace ...;` lines.
        namespace: Namespace enclosing the generated tests.
        parser_test_suite: Name of the parser gtest fixture.
        serializer_test_suite: Name of the serializer gtest fixture.
    """
    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {ns};' for ns in using_namespace)

    # Declarations whose test vectors are excluded from generation.
    skipped_tests = {
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    }

    output.write(
        dedent("""\
            // File generated from {input_name} and {test_vectors_name}, with the command:
            // {input_command}
            // /!\\ Do not edit by hand

            #include <array>
            #include <cstdint>
            #include <memory>
            #include <optional>
            #include <string>
            #include <utility>
            #include <vector>

            #include <gtest/gtest.h>
            #include <packet_runtime.h>

            {include_header}
            {using_namespace}

            namespace {namespace} {{

            class {parser_test_suite} : public testing::Test {{}};
            class {serializer_test_suite} : public testing::Test {{}};

            """).format(parser_test_suite=parser_test_suite,
                        serializer_test_suite=serializer_test_suite,
                        input_name=input.name,
                        input_command=' '.join(sys.argv),
                        test_vectors_name=test_vectors.name,
                        include_header=include_directives,
                        using_namespace=using_directives,
                        namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue
        # Flatten the test lists of every test-vector group targeting this packet.
        matching_tests = [test for group in tests if group['packet'] == decl.id for test in group['tests']]
        if matching_tests:
            output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
            output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")


def main() -> int:
    """Generate the C++ unit tests of the cxx PDL backend."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    return 0


if __name__ == '__main__':
    sys.exit(main())