#!/usr/bin/env python3 # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from dataclasses import dataclass, field import json from pathlib import Path import sys from textwrap import dedent from typing import List, Tuple, Union, Optional from pdl import ast, core from pdl.utils import indent, to_pascal_case def mask(width: int) -> str: return hex((1 << width) - 1) def deref(var: Optional[str], id: str) -> str: return f'{var}.{id}' if var else id def get_cxx_scalar_type(width: int) -> str: """Return the cxx scalar type to be used to back a PDL type.""" for n in [8, 16, 32, 64]: if width <= n: return f'uint{n}_t' # PDL type does not fit on non-extended scalar types. assert False @dataclass class FieldParser: byteorder: str offset: int = 0 shift: int = 0 extract_arrays: bool = field(default=False) chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: []) chunk_nr: int = 0 unchecked_code: List[str] = field(default_factory=lambda: []) code: List[str] = field(default_factory=lambda: []) def unchecked_append_(self, line: str): """Append unchecked field parsing code. The function check_size_ must be called to generate a size guard after parsing is completed.""" self.unchecked_code.append(line) def append_(self, line: str): """Append field parsing code. 
There must be no unchecked code left before this function is called.""" assert len(self.unchecked_code) == 0 self.code.append(line) def check_size_(self, size: str): """Generate a check of the current span size.""" self.append_(f"if (span.size() < {size}) {{") self.append_(" return false;") self.append_("}") def check_code_(self): """Generate a size check for pending field parsing.""" if len(self.unchecked_code) > 0: assert len(self.chunk) == 0 unchecked_code = self.unchecked_code self.unchecked_code = [] self.check_size_(str(self.offset)) self.code.extend(unchecked_code) self.offset = 0 def parse_bit_field_(self, field: ast.Field): """Parse the selected field as a bit field. The field is added to the current chunk. When a byte boundary is reached all saved fields are extracted together.""" # Add to current chunk. width = core.get_field_size(field) self.chunk.append((self.shift, width, field)) self.shift += width # Wait for more fields if not on a byte boundary. if (self.shift % 8) != 0: return # Parse the backing integer using the configured endianness, # extract field values. size = int(self.shift / 8) backing_type = get_cxx_scalar_type(self.shift) # Special case when no field is actually used from # the chunk. 
should_skip_value = all(isinstance(field, ast.ReservedField) for (_, _, field) in self.chunk) if should_skip_value: self.unchecked_append_(f"span.skip({size}); // skip reserved fields") self.offset += size self.shift = 0 self.chunk = [] return if len(self.chunk) > 1: value = f"chunk{self.chunk_nr}" self.unchecked_append_(f"{backing_type} {value} = span.read_{self.byteorder}<{backing_type}, {size}>();") self.chunk_nr += 1 else: value = f"span.read_{self.byteorder}<{backing_type}, {size}>()" for shift, width, field in self.chunk: v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}") if field.cond_for: self.unchecked_append_(f"uint8_t {field.id} = {v};") elif isinstance(field, ast.ScalarField): self.unchecked_append_(f"{field.id}_ = {v};") elif isinstance(field, ast.FixedField) and field.enum_id: self.unchecked_append_(f"if ({field.enum_id}({v}) != {field.enum_id}::{field.tag_id}) {{") self.unchecked_append_(" return false;") self.unchecked_append_("}") elif isinstance(field, ast.FixedField): self.unchecked_append_(f"if (({v}) != {hex(field.value)}) {{") self.unchecked_append_(" return false;") self.unchecked_append_("}") elif isinstance(field, ast.TypedefField): self.unchecked_append_(f"{field.id}_ = {field.type_id}({v});") elif isinstance(field, ast.SizeField): self.unchecked_append_(f"{field.field_id}_size = {v};") elif isinstance(field, ast.CountField): self.unchecked_append_(f"{field.field_id}_count = {v};") elif isinstance(field, ast.ReservedField): pass else: raise Exception(f'Unsupported bit field type {field.kind}') # Reset state. 
self.offset += size self.shift = 0 self.chunk = [] def parse_typedef_field_(self, field: ast.TypedefField): """Parse a typedef field, to the exclusion of Enum fields.""" if self.shift != 0: raise Exception('Typedef field does not start on an octet boundary') self.check_code_() self.append_( dedent("""\ if (!{field_type}::Parse(span, &{field_id}_)) {{ return false; }}""".format(field_type=field.type.id, field_id=field.id))) def parse_optional_field_(self, field: ast.Field): """Parse the selected optional field. Optional fields must start and end on a byte boundary.""" self.check_code_() if isinstance(field, ast.ScalarField): backing_type = get_cxx_scalar_type(field.width) self.append_(dedent(""" if ({cond_id} == {cond_value}) {{ if (span.size() < {size}) {{ return false; }} {field_id}_ = std::make_optional( span.read_{byteorder}<{backing_type}, {size}>()); }} """.format(size=int(field.width / 8), backing_type=backing_type, field_id=field.id, cond_id=field.cond.id, cond_value=field.cond.value, byteorder=self.byteorder))) elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): backing_type = get_cxx_scalar_type(field.type.width) self.append_(dedent(""" if ({cond_id} == {cond_value}) {{ if (span.size() < {size}) {{ return false; }} {field_id}_ = std::make_optional({type_id}( span.read_{byteorder}<{backing_type}, {size}>())); }} """.format(size=int(field.type.width / 8), backing_type=backing_type, type_id=field.type_id, field_id=field.id, cond_id=field.cond.id, cond_value=field.cond.value, byteorder=self.byteorder))) elif isinstance(field, ast.TypedefField): self.append_(dedent(""" if ({cond_id} == {cond_value}) {{ auto& output = {field_id}_.emplace(); if (!{type_id}::Parse(span, &output)) {{ return false; }} }} """.format(type_id=field.type_id, field_id=field.id, cond_id=field.cond.id, cond_value=field.cond.value))) else: raise Exception(f"unsupported field type {field.__class__.__name__}") def parse_array_field_lite_(self, field: 
ast.ArrayField): """Parse the selected array field. This function does not attempt to parse all elements but just to identify the span of the array.""" array_size = core.get_array_field_size(field) element_width = core.get_array_element_size(field) padded_size = field.padded_size if element_width: element_width = int(element_width / 8) if isinstance(array_size, int): size = None count = array_size elif isinstance(array_size, ast.SizeField): size = f'{field.id}_size' count = None elif isinstance(array_size, ast.CountField): size = None count = f'{field.id}_count' else: size = None count = None # Shift the span to reset the offset to 0. self.check_code_() # Apply the size modifier. if field.size_modifier and size: self.append_(f"{size} = {size} - {field.size_modifier};") # Compute the array size if the count and element width are known. if count is not None and element_width is not None: size = f"{count} * {element_width}" # Parse from the padded array if padding is present. if padded_size: self.check_size_(padded_size) self.append_("{") self.append_( f"pdl::packet::slice remaining_span = span.subrange({padded_size}, span.size() - {padded_size});") self.append_(f"span = span.subrange(0, {padded_size});") # The array size is known in bytes. if size is not None: self.check_size_(size) self.append_(f"{field.id}_ = span.subrange(0, {size});") self.append_(f"span.skip({size});") # The array count is known. The element width is dynamic. # Parse each element iteratively and derive the array span. 
elif count is not None: self.append_("{") self.append_("pdl::packet::slice temp_span = span;") self.append_(f"for (size_t n = 0; n < {count}; n++) {{") self.append_(f" {field.type_id} element;") self.append_(f" if (!{field.type_id}::Parse(temp_span, &element)) {{") self.append_(" return false;") self.append_(" }") self.append_("}") self.append_(f"{field.id}_ = span.subrange(0, span.size() - temp_span.size());") self.append_(f"span.skip({field.id}_.size());") self.append_("}") # The array size is not known, assume the array takes the # full remaining space. TODO support having fixed sized fields # following the array. else: self.append_(f"{field.id}_ = span;") self.append_("span.clear();") if padded_size: self.append_(f"span = remaining_span;") self.append_("}") def parse_array_field_full_(self, field: ast.ArrayField): """Parse the selected array field. This function does not attempt to parse all elements but just to identify the span of the array.""" array_size = core.get_array_field_size(field) element_width = core.get_array_element_size(field) element_type = field.type_id or get_cxx_scalar_type(field.width) padded_size = field.padded_size if element_width: element_width = int(element_width / 8) if isinstance(array_size, int): size = None count = array_size elif isinstance(array_size, ast.SizeField): size = f'{field.id}_size' count = None elif isinstance(array_size, ast.CountField): size = None count = f'{field.id}_count' else: size = None count = None # Shift the span to reset the offset to 0. self.check_code_() # Apply the size modifier. if field.size_modifier and size: self.append_(f"{size} = {size} - {field.size_modifier};") # Compute the array size if the count and element width are known. if count is not None and element_width is not None: size = f"{count} * {element_width}" # Parse from the padded array if padding is present. 
if padded_size: self.check_size_(padded_size) self.append_("{") self.append_( f"pdl::packet::slice remaining_span = span.subrange({padded_size}, span.size() - {padded_size});") self.append_(f"span = span.subrange(0, {padded_size});") # The array count is known statically, elements are scalar. if field.width and field.size: assert size is not None self.check_size_(size) element_size = int(field.width / 8) self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{") self.append_(f" {field.id}_[n] = span.read_{self.byteorder}<{element_type}, {element_size}>();") self.append_("}") # The array count is known statically, elements are enum values. elif isinstance(field.type, ast.EnumDeclaration) and field.size: assert size is not None self.check_size_(size) element_size = int(field.type.width / 8) backing_type = get_cxx_scalar_type(field.type.width) self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{") self.append_( f" {field.id}_[n] = {element_type}(span.read_{self.byteorder}<{backing_type}, {element_size}>());") self.append_("}") # The array count is known statically, elements have variable size. elif field.size: self.append_(f"for (size_t n = 0; n < {field.size}; n++) {{") self.append_(f" if (!{element_type}::Parse(span, &{field.id}_[n])) {{") self.append_(" return false;") self.append_(" }") self.append_("}") # The array size is known in bytes. 
elif size is not None: self.check_size_(size) self.append_("{") self.append_(f"pdl::packet::slice temp_span = span.subrange(0, {size});") self.append_(f"span.skip({size});") self.append_(f"while (temp_span.size() > 0) {{") if field.width: element_size = int(field.width / 8) self.append_(f" if (temp_span.size() < {element_size}) {{") self.append_(f" return false;") self.append_(" }") self.append_( f" {field.id}_.push_back(temp_span.read_{self.byteorder}<{element_type}, {element_size}>());") elif isinstance(field.type, ast.EnumDeclaration): backing_type = get_cxx_scalar_type(field.type.width) element_size = int(field.type.width / 8) self.append_(f" if (temp_span.size() < {element_size}) {{") self.append_(f" return false;") self.append_(" }") self.append_( f" {field.id}_.push_back({element_type}(temp_span.read_{self.byteorder}<{backing_type}, {element_size}>()));" ) else: self.append_(f" {element_type} element;") self.append_(f" if (!{element_type}::Parse(temp_span, &element)) {{") self.append_(f" return false;") self.append_(" }") self.append_(f" {field.id}_.emplace_back(std::move(element));") self.append_("}") self.append_("}") # The array count is known. The element width is dynamic. # Parse each element iteratively and derive the array span. elif count is not None: self.append_(f"for (size_t n = 0; n < {count}; n++) {{") self.append_(f" {element_type} element;") self.append_(f" if (!{field.type_id}::Parse(span, &element)) {{") self.append_(" return false;") self.append_(" }") self.append_(f" {field.id}_.emplace_back(std::move(element));") self.append_("}") # The array size is not known, assume the array takes the # full remaining space. TODO support having fixed sized fields # following the array. 
elif field.width: element_size = int(field.width / 8) self.append_(f"while (span.size() > 0) {{") self.append_(f" if (span.size() < {element_size}) {{") self.append_(f" return false;") self.append_(" }") self.append_(f" {field.id}_.push_back(span.read_{self.byteorder}<{element_type}, {element_size}>());") self.append_("}") elif isinstance(field.type, ast.EnumDeclaration): element_size = int(field.type.width / 8) backing_type = get_cxx_scalar_type(field.type.width) self.append_(f"while (span.size() > 0) {{") self.append_(f" if (span.size() < {element_size}) {{") self.append_(f" return false;") self.append_(" }") self.append_( f" {field.id}_.push_back({element_type}(span.read_{self.byteorder}<{backing_type}, {element_size}>()));" ) self.append_("}") else: self.append_(f"while (span.size() > 0) {{") self.append_(f" {element_type} element;") self.append_(f" if (!{element_type}::Parse(span, &element)) {{") self.append_(f" return false;") self.append_(" }") self.append_(f" {field.id}_.emplace_back(std::move(element));") self.append_("}") if padded_size: self.append_(f"span = remaining_span;") self.append_("}") def parse_payload_field_lite_(self, field: Union[ast.BodyField, ast.PayloadField]): """Parse body and payload fields.""" if self.shift != 0: raise Exception('Payload field does not start on an octet boundary') payload_size = core.get_payload_field_size(field) offset_from_end = core.get_field_offset_from_end(field) self.check_code_() if payload_size and getattr(field, 'size_modifier', None): self.append_(f"{field.id}_size -= {field.size_modifier};") # The payload or body has a known size. # Consume the payload and update the span in case # fields are placed after the payload. if payload_size: self.check_size_(f"{field.id}_size") self.append_(f"payload_ = span.subrange(0, {field.id}_size);") self.append_(f"span.skip({field.id}_size);") # The payload or body is the last field of a packet, # consume the remaining span. 
elif offset_from_end == 0: self.append_(f"payload_ = span;") self.append_(f"span.clear();") # The payload or body is followed by fields of static size. # Consume the span that is not reserved for the following fields. elif offset_from_end: if (offset_from_end % 8) != 0: raise Exception('Payload field offset from end of packet is not a multiple of 8') offset_from_end = int(offset_from_end / 8) self.check_size_(f'{offset_from_end}') self.append_(f"payload_ = span.subrange(0, span.size() - {offset_from_end});") self.append_(f"span.skip(payload_.size());") def parse_payload_field_full_(self, field: Union[ast.BodyField, ast.PayloadField]): """Parse body and payload fields.""" if self.shift != 0: raise Exception('Payload field does not start on an octet boundary') payload_size = core.get_payload_field_size(field) offset_from_end = core.get_field_offset_from_end(field) self.check_code_() if payload_size and getattr(field, 'size_modifier', None): self.append_(f"{field.id}_size -= {field.size_modifier};") # The payload or body has a known size. # Consume the payload and update the span in case # fields are placed after the payload. if payload_size: self.check_size_(f"{field.id}_size") self.append_(f"for (size_t n = 0; n < {field.id}_size; n++) {{") self.append_(f" payload_.push_back(span.read_{self.byteorder}();") self.append_("}") # The payload or body is the last field of a packet, # consume the remaining span. elif offset_from_end == 0: self.append_("while (span.size() > 0) {") self.append_(f" payload_.push_back(span.read_{self.byteorder}();") self.append_("}") # The payload or body is followed by fields of static size. # Consume the span that is not reserved for the following fields. 
elif offset_from_end is not None: if (offset_from_end % 8) != 0: raise Exception('Payload field offset from end of packet is not a multiple of 8') offset_from_end = int(offset_from_end / 8) self.check_size_(f'{offset_from_end}') self.append_(f"while (span.size() > {offset_from_end}) {{") self.append_(f" payload_.push_back(span.read_{self.byteorder}();") self.append_("}") def parse(self, field: ast.Field): # Field has bit granularity. # Append the field to the current chunk, # check if a byte boundary was reached. if field.cond: self.parse_optional_field_(field) elif core.is_bit_field(field): self.parse_bit_field_(field) # Padding fields. elif isinstance(field, ast.PaddingField): pass # Array fields. elif isinstance(field, ast.ArrayField) and self.extract_arrays: self.parse_array_field_full_(field) elif isinstance(field, ast.ArrayField) and not self.extract_arrays: self.parse_array_field_lite_(field) # Other typedef fields. elif isinstance(field, ast.TypedefField): self.parse_typedef_field_(field) # Payload and body fields. 
elif isinstance(field, (ast.PayloadField, ast.BodyField)) and self.extract_arrays: self.parse_payload_field_full_(field) elif isinstance(field, (ast.PayloadField, ast.BodyField)) and not self.extract_arrays: self.parse_payload_field_lite_(field) else: raise Exception(f'Unsupported field type {field.kind}') def done(self): self.check_code_() @dataclass class FieldSerializer: byteorder: str shift: int = 0 value: List[Tuple[str, int]] = field(default_factory=lambda: []) code: List[str] = field(default_factory=lambda: []) indent: int = 0 def indent_(self): self.indent += 1 def unindent_(self): self.indent -= 1 def append_(self, line: str): """Append field serializing code.""" lines = line.split('\n') self.code.extend([' ' * self.indent + line for line in lines]) def get_payload_field_size(self, var: Optional[str], payload: ast.PayloadField, decl: ast.Declaration) -> str: """Compute the size of the selected payload field, with the information of the builder for the selected declaration. The payload field can be the payload of any of the parent declarations, or the current declaration.""" if payload.parent.id == decl.id: return deref(var, 'payload_.size()') # Get the child packet declaration that will match the current # declaration further down. child = decl while child.parent_id != payload.parent.id: child = child.parent # The payload is the result of serializing the children fields. 
constant_width = 0 variable_width = [] for f in child.fields: field_size = core.get_field_size(f) if field_size is not None: constant_width += field_size elif isinstance(f, (ast.PayloadField, ast.BodyField)): variable_width.append(self.get_payload_field_size(var, f, decl)) elif isinstance(f, ast.TypedefField): variable_width.append(f"{f.id}_.GetSize()") elif isinstance(f, ast.ArrayField): variable_width.append(f"Get{to_pascal_case(f.id)}Size()") else: raise Exception("Unsupported field type") constant_width = int(constant_width / 8) if constant_width and not variable_width: return str(constant_width) temp_var = f'{payload.parent.id.lower()}_payload_size' self.append_(f"size_t {temp_var} = {constant_width};") for dyn in variable_width: self.append_(f"{temp_var} += {dyn};") return temp_var def serialize_array_element_(self, field: ast.ArrayField, var: str): """Serialize a single array field element.""" if field.width: backing_type = get_cxx_scalar_type(field.width) element_size = int(field.width / 8) self.append_( f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(output, {var});") elif isinstance(field.type, ast.EnumDeclaration): backing_type = get_cxx_scalar_type(field.type.width) element_size = int(field.type.width / 8) self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {element_size}>(" + f"output, static_cast<{backing_type}>({var}));") else: self.append_(f"{var}.Serialize(output);") def serialize_array_field_(self, field: ast.ArrayField, var: str): """Serialize the selected array field.""" if field.padded_size: self.append_(f"size_t {field.id}_end = output.size() + {field.padded_size};") if field.width == 8: self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());") else: self.append_(f"for (size_t n = 0; n < {var}.size(); n++) {{") self.indent_() self.serialize_array_element_(field, f'{var}[n]') self.unindent_() self.append_("}") if field.padded_size: self.append_(f"while (output.size() < 
{field.id}_end) {{") self.append_(" output.push_back(0);") self.append_("}") def serialize_bit_field_(self, field: ast.Field, parent_var: Optional[str], var: Optional[str], decl: ast.Declaration): """Serialize the selected field as a bit field. The field is added to the current chunk. When a byte boundary is reached all saved fields are serialized together.""" # Add to current chunk. width = core.get_field_size(field) shift = self.shift if field.cond_for: value_present = field.cond_for.cond.value value_absent = 0 if field.cond_for.cond.value else 1 self.value.append((f"({field.cond_for.id}_.has_value() ? {value_present} : {value_absent})", shift)) elif isinstance(field, ast.ScalarField): self.value.append((f"{var} & {mask(field.width)}", shift)) elif isinstance(field, ast.FixedField) and field.enum_id: self.value.append((f"{field.enum_id}::{field.tag_id}", shift)) elif isinstance(field, ast.FixedField): self.value.append((f"{field.value}", shift)) elif isinstance(field, ast.TypedefField): self.value.append((f"{var}", shift)) elif isinstance(field, ast.SizeField): max_size = (1 << field.width) - 1 value_field = core.get_packet_field(field.parent, field.field_id) size_modifier = '' if getattr(value_field, 'size_modifier', None): size_modifier = f' + {value_field.size_modifier}' if isinstance(value_field, (ast.PayloadField, ast.BodyField)): array_size = self.get_payload_field_size(var, field, decl) + size_modifier elif isinstance(value_field, ast.ArrayField): accessor_name = to_pascal_case(field.field_id) array_size = deref(var, f'Get{accessor_name}Size()') + size_modifier self.value.append((f"{array_size}", shift)) elif isinstance(field, ast.CountField): max_count = (1 << field.width) - 1 self.value.append((f"{field.field_id}_.size()", shift)) elif isinstance(field, ast.ReservedField): pass else: raise Exception(f'Unsupported bit field type {field.kind}') # Check if a byte boundary is reached. 
self.shift += width if (self.shift % 8) == 0: self.pack_bit_fields_() def pack_bit_fields_(self): """Pack serialized bit fields.""" # Should have an integral number of bytes now. assert (self.shift % 8) == 0 # Generate the backing integer, and serialize it # using the configured endiannes, size = int(self.shift / 8) backing_type = get_cxx_scalar_type(self.shift) value = [f"(static_cast<{backing_type}>({v[0]}) << {v[1]})" for v in self.value] if len(value) == 0: self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, 0);") elif len(value) == 1: self.append_(f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {value[0]});") else: self.append_( f"pdl::packet::Builder::write_{self.byteorder}<{backing_type}, {size}>(output, {' | '.join(value)});") # Reset state. self.shift = 0 self.value = [] def serialize_typedef_field_(self, field: ast.TypedefField, var: str): """Serialize a typedef field, to the exclusion of Enum fields.""" if self.shift != 0: raise Exception('Typedef field does not start on an octet boundary') if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None): raise Exception('Derived struct used in typedef field') self.append_(f"{var}.Serialize(output);") def serialize_optional_field_(self, field: ast.Field): """Serialize optional scalar or typedef fields.""" if isinstance(field, ast.ScalarField): backing_type = get_cxx_scalar_type(field.width) self.append_(dedent( """ if ({field_id}_.has_value()) {{ pdl::packet::Builder::write_{byteorder}<{backing_type}, {size}>(output, {field_id}_.value()); }}""".format(field_id=field.id, size=int(field.width / 8), backing_type=backing_type, byteorder=self.byteorder))) elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): backing_type = get_cxx_scalar_type(field.type.width) self.append_(dedent( """ if ({field_id}_.has_value()) {{ pdl::packet::Builder::write_{byteorder}<{backing_type}, {size}>( 
output, static_cast<{backing_type}>({field_id}_.value())); }}""".format(field_id=field.id, size=int(field.type.width / 8), backing_type=backing_type, byteorder=self.byteorder))) elif isinstance(field, ast.TypedefField): self.append_(dedent( """ if ({field_id}_.has_value()) {{ {field_id}_->Serialize(output); }}""".format(field_id=field.id))) else: raise Exception(f"unsupported field type {field.__class__.__name__}") def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField], var: str): """Serialize body and payload fields.""" if self.shift != 0: raise Exception('Payload field does not start on an octet boundary') self.append_(f"output.insert(output.end(), {var}.begin(), {var}.end());") def serialize(self, field: ast.Field, decl: ast.Declaration, var: Optional[str] = None): field_var = deref(var, f'{field.id}_') if hasattr(field, 'id') else None if field.cond: self.serialize_optional_field_(field) # Field has bit granularity. # Append the field to the current chunk, # check if a byte boundary was reached. elif core.is_bit_field(field): self.serialize_bit_field_(field, var, field_var, decl) # Padding fields. elif isinstance(field, ast.PaddingField): pass # Array fields. elif isinstance(field, ast.ArrayField): self.serialize_array_field_(field, field_var) # Other typedef fields. elif isinstance(field, ast.TypedefField): self.serialize_typedef_field_(field, field_var) # Payload and body fields. elif isinstance(field, (ast.PayloadField, ast.BodyField)): self.serialize_payload_field_(field, deref(var, 'payload_')) else: raise Exception(f'Unimplemented field type {field.kind}') def generate_enum_declaration(decl: ast.EnumDeclaration) -> str: """Generate the implementation of an enum type.""" enum_name = decl.id enum_type = get_cxx_scalar_type(decl.width) tag_decls = [] for t in decl.tags: # Exclude default tags: DEFAULT = .. 
if t.value is not None: tag_decls.append(f"{t.id} = {hex(t.value)},") return dedent("""\ enum class {enum_name} : {enum_type} {{ {tag_decls} }}; """).format(enum_name=enum_name, enum_type=enum_type, tag_decls=indent(tag_decls, 1)) def generate_enum_to_text(decl: ast.EnumDeclaration) -> str: """Generate the helper function that will convert an enum tag to string.""" enum_name = decl.id tag_cases = [] for t in decl.tags: # Exclude default tags: DEFAULT = .. if t.value is not None: tag_cases.append(f"case {enum_name}::{t.id}: return \"{t.id}\";") return dedent("""\ inline std::string {enum_name}Text({enum_name} tag) {{ switch (tag) {{ {tag_cases} default: return std::string("Unknown {enum_name}: " + std::to_string(static_cast(tag))); }} }} """).format(enum_name=enum_name, tag_cases=indent(tag_cases, 2)) def generate_packet_view_field_members(decl: ast.Declaration) -> List[str]: """Return the declaration of fields that are backed in the view class declaration. Backed fields include all named fields that do not have a constrained value in the selected declaration and its parents. :param decl: target declaration""" fields = core.get_unconstrained_parent_fields(decl) + decl.fields members = [] for field in fields: if field.cond_for: # Scalar fields used as condition for optional fields are treated # as fixed fields since their value is tied to the value of the # optional field. 
pass elif isinstance(field, (ast.PayloadField, ast.BodyField)): members.append("pdl::packet::slice payload_;") elif isinstance(field, ast.ArrayField): members.append(f"pdl::packet::slice {field.id}_;") elif isinstance(field, ast.ScalarField) and field.cond: members.append(f"std::optional<{get_cxx_scalar_type(field.width)}> {field.id}_{{}};") elif isinstance(field, ast.TypedefField) and field.cond: members.append(f"std::optional<{field.type_id}> {field.id}_{{}};") elif isinstance(field, ast.ScalarField): members.append(f"{get_cxx_scalar_type(field.width)} {field.id}_{{0}};") elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): members.append(f"{field.type_id} {field.id}_{{{field.type_id}::{field.type.tags[0].id}}};") elif isinstance(field, ast.TypedefField): members.append(f"{field.type_id} {field.id}_;") return members def generate_packet_field_members(decl: ast.Declaration) -> List[str]: """Return the declaration of fields that are backed in the view class declaration. Backed fields include all named fields that do not have a constrained value in the selected declaration and its parents. :param decl: target declaration""" members = [] for field in decl.fields: if field.cond_for: # Scalar fields used as condition for optional fields are treated # as fixed fields since their value is tied to the value of the # optional field. 
pass elif isinstance(field, (ast.PayloadField, ast.BodyField)) and not decl.parent: members.append("std::vector payload_;") elif isinstance(field, ast.ArrayField) and field.size: element_type = field.type_id or get_cxx_scalar_type(field.width) members.append(f"std::array<{element_type}, {field.size}> {field.id}_;") elif isinstance(field, ast.ArrayField): element_type = field.type_id or get_cxx_scalar_type(field.width) members.append(f"std::vector<{element_type}> {field.id}_;") elif isinstance(field, ast.ScalarField) and field.cond: members.append(f"std::optional<{get_cxx_scalar_type(field.width)}> {field.id}_{{}};") elif isinstance(field, ast.TypedefField) and field.cond: members.append(f"std::optional<{field.type_id}> {field.id}_{{}};") elif isinstance(field, ast.ScalarField): members.append(f"{get_cxx_scalar_type(field.width)} {field.id}_{{0}};") elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration): members.append(f"{field.type_id} {field.id}_{{{field.type_id}::{field.type.tags[0].id}}};") elif isinstance(field, ast.TypedefField): members.append(f"{field.type_id} {field.id}_;") return members def generate_packet_field_serializers(packet: ast.Declaration) -> List[str]: """Generate the code to serialize the fields of a packet builder or struct.""" serializer = FieldSerializer(byteorder=packet.file.byteorder_short) constraints = core.get_parent_constraints(packet) constraints = dict([(c.id, c) for c in constraints]) for field in core.get_packet_fields(packet): field_id = getattr(field, 'id', None) constraint = constraints.get(field_id, None) fixed_field = None if constraint and constraint.tag_id: fixed_field = ast.FixedField(enum_id=field.type_id, tag_id=constraint.tag_id, loc=field.loc, kind='fixed_field') fixed_field.parent = field.parent elif constraint: fixed_field = ast.FixedField(width=field.width, value=constraint.value, loc=field.loc, kind='fixed_field') fixed_field.parent = field.parent serializer.serialize(fixed_field or 
field, packet) return serializer.code def generate_scalar_array_field_accessor(field: ast.ArrayField) -> str: """Parse the selected scalar array field.""" element_size = int(field.width / 8) backing_type = get_cxx_scalar_type(field.width) byteorder = field.parent.file.byteorder_short if field.size: return dedent("""\ pdl::packet::slice span = {field_id}_; std::array<{backing_type}, {array_size}> elements; for (int n = 0; n < {array_size}; n++) {{ elements[n] = span.read_{byteorder}<{backing_type}, {element_size}>(); }} return elements;""").format(field_id=field.id, backing_type=backing_type, element_size=element_size, array_size=field.size, byteorder=byteorder) else: return dedent("""\ pdl::packet::slice span = {field_id}_; std::vector<{backing_type}> elements; while (span.size() >= {element_size}) {{ elements.push_back(span.read_{byteorder}<{backing_type}, {element_size}>()); }} return elements;""").format(field_id=field.id, backing_type=backing_type, element_size=element_size, byteorder=byteorder) def generate_enum_array_field_accessor(field: ast.ArrayField) -> str: """Parse the selected enum array field.""" element_size = int(field.type.width / 8) backing_type = get_cxx_scalar_type(field.type.width) byteorder = field.parent.file.byteorder_short if field.size: return dedent("""\ pdl::packet::slice span = {field_id}_; std::array<{enum_type}, {array_size}> elements; for (int n = 0; n < {array_size}; n++) {{ elements[n] = {enum_type}(span.read_{byteorder}<{backing_type}, {element_size}>()); }} return elements;""").format(field_id=field.id, enum_type=field.type.id, backing_type=backing_type, element_size=element_size, array_size=field.size, byteorder=byteorder) else: return dedent("""\ pdl::packet::slice span = {field_id}_; std::vector<{enum_type}> elements; while (span.size() >= {element_size}) {{ elements.push_back({enum_type}(span.read_{byteorder}<{backing_type}, {element_size}>())); }} return elements;""").format(field_id=field.id, enum_type=field.type_id, 
def generate_array_field_accessor(field: ast.ArrayField):
    """Parse the selected array field.

    Dispatches on the element kind: arrays declaring an element width are
    handled as scalar arrays, arrays whose element type is an enum
    declaration as enum arrays, and everything else as typedef arrays.
    Returns whatever the selected generator returns."""

    if field.width is not None:
        generator = generate_scalar_array_field_accessor
    elif isinstance(field.type, ast.EnumDeclaration):
        generator = generate_enum_array_field_accessor
    else:
        generator = generate_typedef_array_field_accessor
    return generator(field)
def generate_packet_size_getter(decl: ast.Declaration) -> List[str]:
    """Generate the body of the size getter for the current packet.

    Produces the serialized size of the packet in bytes, as a list of C++
    source lines forming a single return statement.

    Fields with a statically known size are summed into one constant;
    every other field contributes a C++ expression evaluated at runtime.

    Raises Exception for field kinds that cannot be sized."""

    # Accumulated in bits, converted to bytes once all fields are visited.
    constant_width = 0
    # C++ expressions; each one must evaluate to a size in bytes.
    variable_width = []

    for f in core.get_packet_fields(decl):
        field_size = core.get_field_size(f)
        if f.cond:
            # Optional fields only take space when a value is present.
            # Declared widths are in bits and are converted to bytes here so
            # that these terms use the same unit as the other terms
            # (payload_.size(), GetSize()) and as the constant part.
            if isinstance(f, ast.ScalarField):
                variable_width.append(f"({f.id}_.has_value() ? {f.width // 8} : 0)")
            elif isinstance(f, ast.TypedefField) and isinstance(f.type, ast.EnumDeclaration):
                variable_width.append(f"({f.id}_.has_value() ? {f.type.width // 8} : 0)")
            elif isinstance(f, ast.TypedefField):
                variable_width.append(f"({f.id}_.has_value() ? {f.id}_->GetSize() : 0)")
            else:
                raise Exception(f"unsupported field type {f.__class__.__name__}")
        elif field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("payload_.size()")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"{f.id}_.GetSize()")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"Get{to_pascal_case(f.id)}Size()")
        else:
            raise Exception("Unsupported field type")

    constant_width = constant_width // 8

    if not variable_width:
        return [f"return {constant_width};"]

    if len(variable_width) == 1:
        if constant_width:
            return [f"return {variable_width[0]} + {constant_width};"]
        return [f"return {variable_width[0]};"]

    # Several runtime terms: emit one term per line inside parentheses.
    opening = f"return {constant_width} + (" if constant_width else "return ("
    return [opening] + " +\n    ".join(variable_width).split("\n") + [");"]
pass elif isinstance(field, (ast.PayloadField, ast.BodyField)): accessors.append( dedent("""\ std::vector GetPayload() const { _ASSERT_VALID(valid_); return payload_.bytes(); } """)) elif isinstance(field, ast.ArrayField): element_type = field.type_id or get_cxx_scalar_type(field.width) array_type = (f"std::array<{element_type}, {field.size}>" if field.size else f"std::vector<{element_type}>") accessor_name = to_pascal_case(field.id) accessors.append( dedent("""\ {array_type} Get{accessor_name}() const {{ _ASSERT_VALID(valid_); {accessor} }} """).format(array_type=array_type, accessor_name=accessor_name, accessor=indent(generate_array_field_accessor(field), 1))) elif isinstance(field, ast.ScalarField): field_type = get_cxx_scalar_type(field.width) field_type = f"std::optional<{field_type}>" if field.cond else field_type accessor_name = to_pascal_case(field.id) accessors.append( dedent("""\ {field_type} Get{accessor_name}() const {{ _ASSERT_VALID(valid_); return {member_name}_; }} """).format(field_type=field_type, accessor_name=accessor_name, member_name=field.id)) elif isinstance(field, ast.TypedefField): field_qualifier = "" if isinstance(field.type, ast.EnumDeclaration) else " const&" field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id accessor_name = to_pascal_case(field.id) accessors.append( dedent("""\ {field_type}{field_qualifier} Get{accessor_name}() const {{ _ASSERT_VALID(valid_); return {member_name}_; }} """).format(field_type=field_type, field_qualifier=field_qualifier, accessor_name=accessor_name, member_name=field.id)) # Add accessors for constrained parent fields. # The accessors return a constant value in this case. 
def generate_packet_stringifier(packet: ast.PacketDeclaration) -> str:
    """Generate the packet printer.

    TODO: the emitted ToString() is a stub that always returns an empty
    string; the packet declaration is not inspected yet.
    """

    stub = """\
        std::string ToString() const {
            return "";
        }
        """
    return dedent(stub)
def generate_packet_view_friend_classes(packet: ast.PacketDeclaration) -> List[str]:
    """Generate the list of friend declarations for a packet.

    These are the direct children of the class: one
    `friend class <Child>View;` line is returned for each packet reported
    by core.get_derived_packets with traverse=False."""

    return [
        f"friend class {decl.id}View;"
        for (_, decl) in core.get_derived_packets(packet, traverse=False)
    ]
def generate_packet_constructor(struct: ast.StructDeclaration, constructor_name: str) -> str:
    """Generate the implementation of the constructor for a struct declaration.

    The constructor takes one parameter per unconstrained inherited field,
    followed by one parameter per field declared by `struct` itself.
    Inherited fields are forwarded to the parent builder constructor;
    declared fields are moved or copied into the matching `<id>_` member.

    Returns the C++ source of the constructor, or an empty string when the
    constructor would take no parameter."""

    constructor_params = []
    constructor_initializers = []
    inherited_fields = core.get_unconstrained_parent_fields(struct)
    payload_initializer = ''
    parent_initializer = []

    # Parameters for inherited fields. They are only forwarded to the parent
    # constructor below, so no member initializer is generated here.
    # NOTE(review): unlike the loops below, `cond_for` is not checked for
    # inherited fields -- confirm that this is intended.
    for field in inherited_fields:
        if isinstance(field, ast.ArrayField) and field.size:
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            constructor_params.append(f"std::array<{element_type}, {field.size}> {field.id}")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            constructor_params.append(f"std::vector<{element_type}> {field.id}")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            constructor_params.append(f"{backing_type} {field.id}")
        elif isinstance(field, ast.TypedefField):
            # Enum and struct typedefs are both passed by value.
            constructor_params.append(f"{field.type_id} {field.id}")

    # Parameters and member initializers for the fields declared here.
    for field in struct.fields:
        if field.cond_for:
            # Condition flags get no parameter.
            pass
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            constructor_params.append("std::vector<uint8_t> payload")
            if struct.parent:
                # payload_ is assigned in the constructor body instead of the
                # member initializer list when the declaration has a parent.
                payload_initializer = "payload_ = std::move(payload);"
            else:
                constructor_initializers.append("payload_(std::move(payload))")
        elif isinstance(field, ast.ArrayField) and field.size:
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            constructor_params.append(f"std::array<{element_type}, {field.size}> {field.id}")
            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
        elif isinstance(field, ast.ArrayField):
            element_type = field.type_id or get_cxx_scalar_type(field.width)
            constructor_params.append(f"std::vector<{element_type}> {field.id}")
            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")
        elif isinstance(field, ast.ScalarField):
            backing_type = get_cxx_scalar_type(field.width)
            field_type = f"std::optional<{backing_type}>" if field.cond else backing_type
            constructor_params.append(f"{field_type} {field.id}")
            constructor_initializers.append(f"{field.id}_({field.id})")
        elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
            constructor_params.append(f"{field_type} {field.id}")
            constructor_initializers.append(f"{field.id}_({field.id})")
        elif isinstance(field, ast.TypedefField):
            field_type = f"std::optional<{field.type_id}>" if field.cond else field.type_id
            constructor_params.append(f"{field_type} {field.id}")
            constructor_initializers.append(f"{field.id}_(std::move({field.id}))")

    if not constructor_params:
        return ""

    if struct.parent:
        # Build the argument list of the parent builder constructor: fields
        # constrained by this declaration receive their fixed value, the
        # others are forwarded from the parameters declared above.
        fields = core.get_unconstrained_parent_fields(struct.parent) + struct.parent.fields
        parent_constructor_params = []
        for field in fields:
            constraints = [c for c in struct.constraints if c.id == getattr(field, 'id', None)]
            if field.cond_for:
                pass
            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                # The parent is given an empty payload.
                parent_constructor_params.append("std::vector<uint8_t>{}")
            elif isinstance(field, ast.ArrayField):
                parent_constructor_params.append(f"std::move({field.id})")
            elif isinstance(field, ast.ScalarField) and constraints:
                parent_constructor_params.append(f"{constraints[0].value}")
            elif isinstance(field, ast.ScalarField):
                parent_constructor_params.append(f"{field.id}")
            elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and
                  constraints):
                parent_constructor_params.append(f"{field.type_id}::{constraints[0].tag_id}")
            elif (isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration)):
                parent_constructor_params.append(f"{field.id}")
            elif isinstance(field, ast.TypedefField):
                parent_constructor_params.append(f"std::move({field.id})")
        parent_constructor_params = ', '.join(parent_constructor_params)
        parent_initializer = [f"{struct.parent_id}Builder({parent_constructor_params})"]

    # Single-parameter constructors are marked explicit.
    explicit = 'explicit ' if len(constructor_params) == 1 else ''
    constructor_params = ', '.join(constructor_params)
    constructor_initializers = ', '.join(parent_initializer + constructor_initializers)

    return dedent("""\
        {explicit}{constructor_name}({constructor_params})
            : {constructor_initializers} {{
        {payload_initializer}
        }}""").format(explicit=explicit,
                      constructor_name=constructor_name,
                      constructor_params=constructor_params,
                      payload_initializer=payload_initializer,
                      constructor_initializers=constructor_initializers)
def generate_packet_builder(packet: ast.PacketDeclaration) -> str:
    """Generate the implementation of the Builder class for a packet declaration.

    The generated class is named `<packet id>Builder` and derives from the
    parent packet's builder when the packet has a parent, or from
    pdl::packet::Builder otherwise. It is assembled from the constructor,
    creator, serializer, size getters and field members produced by the
    other generators of this file.

    Returns the C++ source of the class."""

    class_name = f'{packet.id}Builder'
    parent_class = f'{packet.parent_id}Builder' if packet.parent_id else "pdl::packet::Builder"
    builder_constructor = generate_packet_constructor(packet, constructor_name=class_name)
    builder_creator = generate_packet_creator(packet)
    field_members = generate_packet_field_members(packet)
    field_serializers = generate_packet_field_serializers(packet)
    size_getter = generate_packet_size_getter(packet)
    array_field_size_getters = generate_array_field_size_getters(packet)

    return dedent("""\

        class {class_name} : public {parent_class} {{
        public:
            ~{class_name}() override = default;
            {class_name}() = default;
            {class_name}({class_name} const&) = default;
            {class_name}({class_name}&&) = default;
            {class_name}& operator=({class_name} const&) = default;
            {builder_constructor}
            {builder_creator}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {field_members}
        }};
        """).format(class_name=class_name,
                    parent_class=parent_class,
                    builder_constructor=builder_constructor,
                    builder_creator=builder_creator,
                    field_members=indent(field_members, 1),
                    field_serializers=indent(field_serializers, 2),
                    size_getter=indent(size_getter, 1),
                    array_field_size_getters=indent(array_field_size_getters, 1))
def generate_struct_declaration(struct: ast.StructDeclaration) -> str:
    """Generate the implementation of the class for a struct declaration.

    The generated class derives from pdl::packet::Builder and combines a
    constructor, a static Parse function, Serialize, the size getters, the
    stringifier and the field members produced by the other generators of
    this file.

    Returns the C++ source of the class.
    Raises Exception when the struct has a parent declaration."""

    if struct.parent:
        raise Exception("Struct declaration with parents are not supported")

    struct_constructor = generate_packet_constructor(struct, constructor_name=struct.id)
    field_members = generate_packet_field_members(struct)
    field_parsers = generate_struct_field_parsers(struct)
    field_serializers = generate_packet_field_serializers(struct)
    size_getter = generate_packet_size_getter(struct)
    array_field_size_getters = generate_array_field_size_getters(struct)
    stringifier = generate_packet_stringifier(struct)

    return dedent("""\

        class {struct_name} : public pdl::packet::Builder {{
        public:
            ~{struct_name}() override = default;
            {struct_name}() = default;
            {struct_name}({struct_name} const&) = default;
            {struct_name}({struct_name}&&) = default;
            {struct_name}& operator=({struct_name} const&) = default;
            {struct_constructor}

            static bool Parse(pdl::packet::slice& span, {struct_name}* output) {{
                {field_parsers}
            }}

            void Serialize(std::vector<uint8_t>& output) const override {{
                {field_serializers}
            }}

            size_t GetSize() const override {{
                {size_getter}
            }}

            {array_field_size_getters}
            {stringifier}
            {field_members}
        }};
        """).format(struct_name=struct.id,
                    struct_constructor=struct_constructor,
                    field_members=indent(field_members, 1),
                    field_parsers=indent(field_parsers, 2),
                    field_serializers=indent(field_serializers, 2),
                    stringifier=indent(stringifier, 1),
                    size_getter=indent(size_getter, 1),
                    array_field_size_getters=indent(array_field_size_getters, 1))
generated from {input_name}, with the command: // {input_command} // /!\\ Do not edit by hand #pragma once #include #include #include #include #include #include {include_header} {using_namespace} #ifndef _ASSERT_VALID #ifdef ASSERT #define _ASSERT_VALID ASSERT #else #include #define _ASSERT_VALID assert #endif // ASSERT #endif // !_ASSERT_VALID {open_namespace} """).format(input_name=input.name, input_command=' '.join(sys.argv), include_header=include_header, using_namespace=using_namespace, open_namespace=open_namespace)) # Forward declarations for packet classes. # Required for the friend class specifiers. for d in file.declarations: if isinstance(d, ast.PacketDeclaration): output.write(f"class {d.id}View;\n") for d in file.declarations: if d.id in exclude_declaration: continue if isinstance(d, ast.EnumDeclaration): output.write(generate_enum_declaration(d)) output.write(generate_enum_to_text(d)) elif isinstance(d, ast.PacketDeclaration): output.write(generate_packet_view(d)) output.write(generate_packet_builder(d)) elif isinstance(d, ast.StructDeclaration): output.write(generate_struct_declaration(d)) output.write(f"{close_namespace}\n") def main() -> int: """Generate cxx PDL backend.""" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source') parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file') parser.add_argument('--namespace', type=str, help='Generated module namespace') parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives') parser.add_argument('--using-namespace', type=str, default=[], action='append', help='Added using namespace statements') parser.add_argument('--exclude-declaration', type=str, default=[], action='append', help='Exclude declaration from the generated output') return run(**vars(parser.parse_args())) if __name__ == '__main__': 
sys.exit(main())