#! /usr/bin/env python3

# Run this file using pytest, either in this folder or at the root of the
# project. Since test_vectors.json is generated from bao.py, it's slightly
# cheating to then test bao.py against its own output. But at least this helps
# us notice changes, since the vectors are checked in rather than generated
# every time. Testing the Rust implementation against the same test vectors
# gives us some confidence that they're correct.

from binascii import hexlify, unhexlify
import io
import json
from pathlib import Path
import subprocess
import sys
import tempfile

# Imports from this directory.
import bao
import generate_input

HERE = Path(__file__).parent
BAO_PATH = HERE / "bao.py"
VECTORS_PATH = HERE / "test_vectors.json"
# Load the checked-in vectors through a context manager so the file handle is
# closed right away instead of being left for the garbage collector.
with VECTORS_PATH.open(encoding="utf-8") as _vectors_file:
    VECTORS = json.load(_vectors_file)

# Wrapper functions
# =================
#
# Most of the functions in bao.py (except bao_encode) work with streams. These
# wrappers work with bytes, and return hashes as strings, which makes them
# easier to test.


def bao_hash(content):
    """Return the bao hash of ``content`` (bytes) as a lowercase hex string."""
    return hexlify(bao.bao_hash(io.BytesIO(content))).decode("utf-8")


def blake3(b):
    """Hex hash of ``b``; an alias for bao_hash used to fingerprint outputs."""
    return bao_hash(b)


def bao_encode(content):
    """Encode ``content`` with ``outboard=False``.

    Returns a ``(encoded_bytes, hex_hash)`` tuple.
    """
    # Note that unlike the other functions, this one already takes bytes.
    encoded, hash_ = bao.bao_encode(content, outboard=False)
    return encoded, hash_.hex()


def bao_encode_outboard(content):
    """Encode ``content`` with ``outboard=True``.

    Returns a ``(outboard_bytes, hex_hash)`` tuple.
    """
    # Note that unlike the other functions, this one already takes bytes.
    outboard, hash_ = bao.bao_encode(content, outboard=True)
    return outboard, hash_.hex()


def bao_decode(hash, encoded):
    """Decode ``encoded`` against the hex root ``hash``; return the content.

    Any exception raised by bao.bao_decode propagates to the caller.
    """
    hashbytes = unhexlify(hash)
    output = io.BytesIO()
    bao.bao_decode(io.BytesIO(encoded), output, hashbytes)
    return output.getvalue()


def bao_decode_outboard(hash, content, outboard):
    """Verify ``content`` using the ``outboard`` tree and hex ``hash``.

    Returns the decoded content bytes; exceptions from bao.bao_decode
    propagate to the caller.
    """
    hashbytes = unhexlify(hash)
    output = io.BytesIO()
    bao.bao_decode(
        io.BytesIO(content), output, hashbytes, outboard_stream=io.BytesIO(outboard)
    )
    return output.getvalue()


def bao_slice(encoded, slice_start, slice_len):
    """Return the bao slice of ``encoded`` for ``slice_start``/``slice_len``."""
    output = io.BytesIO()
    bao.bao_slice(io.BytesIO(encoded), output, slice_start, slice_len)
    return output.getvalue()


def bao_slice_outboard(content, outboard, slice_start, slice_len):
    """Return the bao slice built from ``content`` plus its ``outboard`` tree."""
    output = io.BytesIO()
    bao.bao_slice(
        io.BytesIO(content),
        output,
        slice_start,
        slice_len,
        outboard_stream=io.BytesIO(outboard),
    )
    return output.getvalue()


def bao_decode_slice(slice_bytes, hash, slice_start, slice_len):
    """Decode ``slice_bytes`` against the hex root ``hash``; return the content.

    ``slice_start`` and ``slice_len`` are forwarded unchanged to
    bao.bao_decode_slice; its exceptions propagate to the caller.
    """
    hashbytes = unhexlify(hash)
    output = io.BytesIO()
    bao.bao_decode_slice(
        io.BytesIO(slice_bytes), output, hashbytes, slice_start, slice_len
    )
    return output.getvalue()


# Tests
# =====


def test_hashes():
    """Check bao_hash against every vector in the "hash" set."""
    for case in VECTORS["hash"]:
        input_len = case["input_len"]
        input_bytes = generate_input.input_bytes(input_len)
        expected_hash = case["bao_hash"]

        computed_hash = bao_hash(input_bytes)
        assert expected_hash == computed_hash


def bao_cli(*args, input=None, should_fail=False):
    """Run bao.py as a subprocess with ``args``; return its stdout as bytes.

    ``input`` (bytes or None) is fed to stdin. When ``should_fail`` is true
    the command must exit non-zero and its stderr is discarded; otherwise it
    must exit zero. The exit-status check is done with ``assert``.
    """
    # Use the interpreter that is running the tests rather than whatever
    # "python3" happens to be first on PATH, so bao.py runs under the same
    # Python version and environment.
    output = subprocess.run(
        [sys.executable, str(BAO_PATH), *args],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL if should_fail else None,
        input=input,
    )
    cmd = " ".join(["bao.py"] + list(args))
    if should_fail:
        assert output.returncode != 0, "`{}` should've failed".format(cmd)
    else:
        assert output.returncode == 0, "`{}` failed".format(cmd)
    return output.stdout


def test_hash_cli():
    """Check the `hash` subcommand against the largest "hash" vector."""
    # CLI tests just use the final (largest) test vector in each set, to avoid
    # shelling out hundreds of times. There's no need to exhaustively test the
    # implementation via the CLI, because it's tested on its own above.
    # Instead, we just need to verify once that it's hooked up properly.
    case = VECTORS["hash"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    expected_hash = case["bao_hash"]

    computed_hash = bao_cli("hash", input=input_bytes).decode().strip()
    assert expected_hash == computed_hash


def assert_decode_failure(f, *args):
    """Assert that ``f(*args)`` raises AssertionError or IOError.

    Any other exception type propagates and fails the calling test.
    """
    try:
        f(*args)
    except (AssertionError, IOError):
        pass
    else:
        raise AssertionError("failure expected, but no exception raised")


def test_encoded():
    """Check encoding, decoding, and corruption detection for "encode" vectors."""
    for case in VECTORS["encode"]:
        input_len = case["input_len"]
        input_bytes = generate_input.input_bytes(input_len)
        output_len = case["output_len"]
        expected_bao_hash = case["bao_hash"]
        encoded_blake3 = case["encoded_blake3"]
        corruptions = case["corruptions"]

        # First make sure the encoded output is what it's supposed to be.
        encoded, hash_ = bao_encode(input_bytes)
        assert expected_bao_hash == hash_
        assert output_len == len(encoded)
        assert encoded_blake3 == blake3(encoded)

        # Now test decoding.
        output = bao_decode(hash_, encoded)
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails.
        wrong_hash = "0" * len(hash_)
        assert_decode_failure(bao_decode, wrong_hash, encoded)

        # Make sure each of the corruption points causes decoding to fail.
        for c in corruptions:
            corrupted = bytearray(encoded)
            corrupted[c] ^= 1
            assert_decode_failure(bao_decode, hash_, corrupted)


def make_tempfile(b=b""):
    """Return a NamedTemporaryFile holding ``b``, flushed and rewound to 0.

    The file is deleted when closed, so callers should use it in a ``with``
    statement.
    """
    f = tempfile.NamedTemporaryFile()
    f.write(b)
    f.flush()
    f.seek(0)
    return f


def test_encoded_cli():
    """Check the `encode`/`decode` subcommands against the largest vector."""
    case = VECTORS["encode"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # The temp files are closed (and deleted) when the block exits.
    with make_tempfile(input_bytes) as input_file, make_tempfile() as encoded_file:
        # First make sure the encoded output is what it's supposed to be.
        bao_cli("encode", input_file.name, encoded_file.name)
        encoded = encoded_file.read()
        assert output_len == len(encoded)
        assert encoded_blake3 == blake3(encoded)

        # Now test decoding.
        output = bao_cli("decode", expected_bao_hash, encoded_file.name)
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli("decode", wrong_hash, encoded_file.name, should_fail=True)


def test_outboard():
    """Check outboard encoding, decoding, and corruption detection."""
    for case in VECTORS["outboard"]:
        input_len = case["input_len"]
        input_bytes = generate_input.input_bytes(input_len)
        output_len = case["output_len"]
        expected_bao_hash = case["bao_hash"]
        encoded_blake3 = case["encoded_blake3"]
        outboard_corruptions = case["outboard_corruptions"]
        input_corruptions = case["input_corruptions"]

        # First make sure the encoded output is what it's supposed to be.
        outboard, hash_ = bao_encode_outboard(input_bytes)
        assert expected_bao_hash == hash_
        assert output_len == len(outboard)
        assert encoded_blake3 == blake3(outboard)

        # Now test decoding.
        output = bao_decode_outboard(hash_, input_bytes, outboard)
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails.
        wrong_hash = "0" * len(hash_)
        assert_decode_failure(bao_decode_outboard, wrong_hash, input_bytes, outboard)

        # Make sure each of the outboard corruption points causes decoding to
        # fail.
        for c in outboard_corruptions:
            corrupted = bytearray(outboard)
            corrupted[c] ^= 1
            assert_decode_failure(bao_decode_outboard, hash_, input_bytes, corrupted)

        # Make sure each of the input corruption points causes decoding to
        # fail.
        for c in input_corruptions:
            corrupted = bytearray(input_bytes)
            corrupted[c] ^= 1
            assert_decode_failure(bao_decode_outboard, hash_, corrupted, outboard)


def test_outboard_cli():
    """Check `encode --outboard` and `decode --outboard` on the largest vector."""
    case = VECTORS["outboard"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    output_len = case["output_len"]
    expected_bao_hash = case["bao_hash"]
    encoded_blake3 = case["encoded_blake3"]

    # The temp files are closed (and deleted) when the block exits.
    with make_tempfile(input_bytes) as input_file, make_tempfile() as outboard_file:
        # First make sure the encoded output is what it's supposed to be.
        bao_cli("encode", input_file.name, "--outboard", outboard_file.name)
        outboard = outboard_file.read()
        assert output_len == len(outboard)
        assert encoded_blake3 == blake3(outboard)

        # Now test decoding.
        output = bao_cli(
            "decode", expected_bao_hash, input_file.name, "--outboard", outboard_file.name
        )
        assert input_bytes == output

        # Make sure decoding with the wrong hash fails. bao_cli itself asserts
        # the non-zero exit status; the output is not needed.
        wrong_hash = "0" * len(expected_bao_hash)
        bao_cli(
            "decode",
            wrong_hash,
            input_file.name,
            "--outboard",
            outboard_file.name,
            should_fail=True,
        )


def test_slices():
    """Check slicing (combined and outboard) and slice decoding for all vectors."""
    for case in VECTORS["slice"]:
        input_len = case["input_len"]
        input_bytes = generate_input.input_bytes(input_len)
        expected_bao_hash = case["bao_hash"]
        slices = case["slices"]

        encoded, hash_ = bao_encode(input_bytes)
        outboard, hash_outboard = bao_encode_outboard(input_bytes)
        assert expected_bao_hash == hash_
        assert expected_bao_hash == hash_outboard

        for slice_case in slices:
            slice_start = slice_case["start"]
            slice_len = slice_case["len"]
            output_len = slice_case["output_len"]
            output_blake3 = slice_case["output_blake3"]
            corruptions = slice_case["corruptions"]

            # Make sure the slice output is what it should be.
            slice_bytes = bao_slice(encoded, slice_start, slice_len)
            assert output_len == len(slice_bytes)
            assert output_blake3 == blake3(slice_bytes)

            # Make sure slicing an outboard tree is the same.
            outboard_slice_bytes = bao_slice_outboard(
                input_bytes, outboard, slice_start, slice_len
            )
            assert slice_bytes == outboard_slice_bytes

            # Test decoding the slice, and compare it to the input. Note that
            # slicing a byte array in Python allows indices past the end of the
            # array, and sort of silently caps them.
            input_slice = input_bytes[slice_start:][:slice_len]
            output = bao_decode_slice(slice_bytes, hash_, slice_start, slice_len)
            assert input_slice == output

            # Make sure decoding with the wrong hash fails.
            wrong_hash = "0" * len(hash_)
            assert_decode_failure(
                bao_decode_slice, slice_bytes, wrong_hash, slice_start, slice_len
            )

            # Make sure each of the slice corruption points causes decoding to
            # fail.
            for c in corruptions:
                corrupted = bytearray(slice_bytes)
                corrupted[c] ^= 1
                assert_decode_failure(
                    bao_decode_slice, corrupted, hash_, slice_start, slice_len
                )


def test_slices_cli():
    """Check the `slice` and `decode-slice` subcommands on the largest vector."""
    case = VECTORS["slice"][-1]
    input_len = case["input_len"]
    input_bytes = generate_input.input_bytes(input_len)
    expected_bao_hash = case["bao_hash"]
    slices = case["slices"]

    # The temp files are closed (and deleted) when the blocks exit.
    with make_tempfile(input_bytes) as input_file:
        with make_tempfile() as encoded_file, make_tempfile() as outboard_file:
            bao_cli("encode", input_file.name, encoded_file.name)
            bao_cli("encode", input_file.name, "--outboard", outboard_file.name)

            # Use the first slice in the list. Currently they're all the same
            # length.
            slice_case = slices[0]
            slice_start = slice_case["start"]
            slice_len = slice_case["len"]
            output_len = slice_case["output_len"]
            output_blake3 = slice_case["output_blake3"]

            # Make sure the slice output is what it should be.
            slice_bytes = bao_cli(
                "slice", str(slice_start), str(slice_len), encoded_file.name
            )
            assert output_len == len(slice_bytes)
            assert output_blake3 == blake3(slice_bytes)

            # Make sure slicing an outboard tree is the same.
            outboard_slice_bytes = bao_cli(
                "slice",
                str(slice_start),
                str(slice_len),
                input_file.name,
                "--outboard",
                outboard_file.name,
            )
            assert slice_bytes == outboard_slice_bytes

    # Test decoding the slice, and compare it to the input. Note that
    # slicing a byte array in Python allows indices past the end of the
    # array, and sort of silently caps them.
    input_slice = input_bytes[slice_start:][:slice_len]
    output = bao_cli(
        "decode-slice",
        expected_bao_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
    )
    assert input_slice == output

    # Make sure decoding with the wrong hash fails.
    wrong_hash = "0" * len(expected_bao_hash)
    bao_cli(
        "decode-slice",
        wrong_hash,
        str(slice_start),
        str(slice_len),
        input=slice_bytes,
        should_fail=True,
    )