// Copyright (c) Facebook, Inc. and its affiliates // SPDX-License-Identifier: MIT OR Apache-2.0 package serde import ( "bytes" "errors" "fmt" "unicode/utf8" ) // `BinaryDeserializer` is a partial implementation of the `Deserializer` interface. // It is used as an embedded struct by the Bincode and BCS deserializers. type BinaryDeserializer struct { Buffer *bytes.Buffer Input []byte containerDepthBudget uint64 } func NewBinaryDeserializer(input []byte, max_container_depth uint64) *BinaryDeserializer { return &BinaryDeserializer{ Buffer: bytes.NewBuffer(input), Input: input, containerDepthBudget: max_container_depth, } } func (d *BinaryDeserializer) IncreaseContainerDepth() error { if d.containerDepthBudget == 0 { return errors.New("exceeded maximum container depth") } d.containerDepthBudget -= 1 return nil } func (d *BinaryDeserializer) DecreaseContainerDepth() { d.containerDepthBudget += 1 } // `deserializeLen` to be provided by the extending struct. func (d *BinaryDeserializer) DeserializeBytes(deserializeLen func() (uint64, error)) ([]byte, error) { len, err := deserializeLen() if err != nil { return nil, err } ret := make([]byte, len) n, err := d.Buffer.Read(ret) if err == nil && uint64(n) < len { return nil, errors.New("input is too short") } return ret, err } // `deserializeLen` to be provided by the extending struct. 
func (d *BinaryDeserializer) DeserializeStr(deserializeLen func() (uint64, error)) (string, error) {
	// NOTE: the local name `bytes` shadows the imported bytes package inside this function.
	bytes, err := d.DeserializeBytes(deserializeLen)
	if err != nil {
		return "", err
	}
	// Reject byte sequences that are not well-formed UTF-8; a plain string()
	// conversion would otherwise carry them through unchanged.
	if !utf8.Valid(bytes) {
		return "", errors.New("invalid UTF8 string")
	}
	return string(bytes), nil
}

// DeserializeBool reads a single byte and decodes it as a boolean: 0 is false,
// 1 is true, and any other value is rejected with an error naming the byte.
// A read failure from the buffer (e.g. io.EOF when it is exhausted) is returned as-is.
func (d *BinaryDeserializer) DeserializeBool() (bool, error) {
	ret, err := d.Buffer.ReadByte()
	if err != nil {
		return false, err
	}
	switch ret {
	case 0:
		return false, nil
	case 1:
		return true, nil
	default:
		return false, fmt.Errorf("invalid bool byte: expected 0 / 1, but got %d", ret)
	}
}

// DeserializeUnit consumes no input and always succeeds with the empty struct.
func (d *BinaryDeserializer) DeserializeUnit() (struct{}, error) {
	return struct{}{}, nil
}

// DeserializeChar is unimplemented.
// It consumes no input and always returns an "unimplemented" error.
func (d *BinaryDeserializer) DeserializeChar() (rune, error) {
	return 0, errors.New("unimplemented")
}

// DeserializeU8 reads a single byte from the buffer.
// On failure it returns 0 (bytes.Buffer.ReadByte's value on error) together with the error.
func (d *BinaryDeserializer) DeserializeU8() (uint8, error) {
	ret, err := d.Buffer.ReadByte()
	return uint8(ret), err
}

// DeserializeU16 reads two bytes, one per loop iteration (i = 0, then 8), and
// ORs each into ret. If a read fails it returns 0 and the error.
// NOTE(review): each byte is presumably shifted left by i (little-endian);
// the rest of this body lies outside the reviewed view, so confirm there.
func (d *BinaryDeserializer) DeserializeU16() (uint16, error) {
	var ret uint16
	for i := 0; i < 8*2; i += 8 {
		b, err := d.Buffer.ReadByte()
		if err != nil {
			return 0, err
		}
		ret = ret | uint16(b)<