# Copyright (c) Facebook, Inc. and its affiliates
# SPDX-License-Identifier: MIT OR Apache-2.0

import dataclasses
import collections
import io
import typing
from copy import copy
from typing import get_type_hints

import serde_types as st
import serde_binary as sb

# Largest length prefix accepted when (de)serializing: 2^31 - 1.
MAX_LENGTH = (1 << 31) - 1
# Largest value representable as an unsigned 32-bit integer.
MAX_U32 = (1 << 32) - 1
# Nesting budget passed to the binary (de)serializer base classes.
MAX_CONTAINER_DEPTH = 500


class BcsSerializer(sb.BinarySerializer):
    """BCS serializer that writes into an in-memory ``io.BytesIO``.

    Lengths and variant indices are ULEB128-encoded, and map entries are
    reordered by their serialized bytes via ``sort_map_entries``.
    """

    def __init__(self):
        super().__init__(
            output=io.BytesIO(), container_depth_budget=MAX_CONTAINER_DEPTH
        )

    def serialize_u32_as_uleb128(self, value: int):
        """Write ``value`` in ULEB128 form.

        Seven bits are emitted per byte, least-significant group first, with
        the high bit set on every byte except the last. The upper bound is
        checked by the callers (``serialize_len``, ``serialize_variant_index``).

        Raises:
            st.SerializationError: if ``value`` is negative.
        """
        if value < 0:
            raise st.SerializationError(
                "Cannot encode a negative value as uleb128."
            )
        encoded = bytearray()
        while value >= 0x80:
            encoded.append((value & 0x7F) | 0x80)
            value >>= 7
        encoded.append(value)
        # Emit the whole encoding with a single write instead of one per byte.
        self.output.write(encoded)

    def serialize_len(self, value: int):
        """Write a length prefix as ULEB128.

        Raises:
            st.SerializationError: if ``value`` exceeds ``MAX_LENGTH``.
        """
        if value > MAX_LENGTH:
            raise st.SerializationError("Length exceeds the maximum supported value.")
        self.serialize_u32_as_uleb128(value)

    def serialize_variant_index(self, value: int):
        """Write an enum variant index as ULEB128.

        Raises:
            st.SerializationError: if ``value`` exceeds ``MAX_U32``.
        """
        if value > MAX_U32:
            raise st.SerializationError(
                "Variant index exceeds the maximum supported value."
            )
        self.serialize_u32_as_uleb128(value)

    def sort_map_entries(self, offsets: typing.List[int]):
        """Reorder already-written map entries by their serialized bytes.

        Args:
            offsets: stream position at which each entry starts, in write
                order. The last entry is taken to run to the current end of
                the output. The caller's list is not modified.
        """
        if not offsets:
            return
        start = offsets[0]
        # Copy the entries out and release both memoryviews on exit: BytesIO
        # refuses writes while any view of its buffer is still exported.
        with self.output.getbuffer() as view, view[start:] as tail:
            region = tail.tobytes()
        end = start + len(region)
        bounds = [*offsets, end]
        entries = sorted(
            region[lo - start : hi - start] for lo, hi in zip(bounds, bounds[1:])
        )
        self.output.seek(start)
        for entry in entries:
            self.output.write(entry)
        # Sorting only permutes the entries, so the rewrite ends at the old end.
        assert self.output.tell() == end


class BcsDeserializer(sb.BinaryDeserializer):
    """BCS deserializer reading ``content`` through an in-memory ``io.BytesIO``."""

    def __init__(self, content):
        super().__init__(
            input=io.BytesIO(content), container_depth_budget=MAX_CONTAINER_DEPTH
        )

    def deserialize_uleb128_as_u32(self) -> int:
        """Read a ULEB128 value of at most five bytes that fits in a u32.

        Raises:
            st.DeserializationError: if the value exceeds ``MAX_U32``, if five
                bytes are consumed without a terminating byte, or if a
                multi-byte encoding ends with a zero digit.
        """
        value = 0
        for shift in range(0, 32, 7):
            byte = int.from_bytes(self.read(1), "little", signed=False)
            digit = byte & 0x7F
            value |= digit << shift
            if value > MAX_U32:
                raise st.DeserializationError(
                    "Overflow while parsing uleb128-encoded uint32 value"
                )
            if digit == byte:
                # High bit clear: this is the final byte. A zero final byte
                # after the first one means the encoding was padded, which is
                # rejected so every value has a single valid encoding.
                if shift > 0 and digit == 0:
                    raise st.DeserializationError(
                        "Invalid uleb128 number (unexpected zero digit)"
                    )
                return value
        # All five bytes carried the continuation bit.
        raise st.DeserializationError(
            "Overflow while parsing uleb128-encoded uint32 value"
        )

    def deserialize_len(self) -> int:
        """Read a ULEB128 length prefix.

        Raises:
            st.DeserializationError: if the length exceeds ``MAX_LENGTH``.
        """
        value = self.deserialize_uleb128_as_u32()
        if value > MAX_LENGTH:
            raise st.DeserializationError("Length exceeds the maximum supported value.")
        return value

    def deserialize_variant_index(self) -> int:
        """Read an enum variant index encoded as ULEB128."""
        return self.deserialize_uleb128_as_u32()

    def check_that_key_slices_are_increasing(
        self, slice1: typing.Tuple[int, int], slice2: typing.Tuple[int, int]
    ):
        """Ensure the key bytes at ``slice1`` sort strictly before ``slice2``.

        Each slice is a ``(start, end)`` range into the input buffer.

        Raises:
            st.DeserializationError: if the keys are equal or out of order.
        """
        # One view for both reads, released deterministically on exit.
        with self.input.getbuffer() as view:
            key1 = view[slice1[0] : slice1[1]].tobytes()
            key2 = view[slice2[0] : slice2[1]].tobytes()
        if key1 >= key2:
            raise st.DeserializationError(
                "Serialized keys in a map must be ordered by increasing lexicographic order"
            )


def serialize(obj: typing.Any, obj_type) -> bytes:
    """Serialize ``obj`` as ``obj_type`` with BCS and return the bytes."""
    serializer = BcsSerializer()
    serializer.serialize_any(obj, obj_type)
    return serializer.get_buffer()


def deserialize(content: bytes, obj_type) -> typing.Tuple[typing.Any, bytes]:
    """Deserialize a value of ``obj_type`` from ``content``.

    Returns:
        A ``(value, remaining)`` pair, where ``remaining`` comes from the base
        class's ``get_remaining_buffer()`` (presumably the unread tail of
        ``content`` — confirm against ``serde_binary``).
    """
    deserializer = BcsDeserializer(content)
    value = deserializer.deserialize_any(obj_type)
    return value, deserializer.get_remaining_buffer()