#! /usr/bin/env python3

# This is an example implementation of Bao, with the goal of being as readable
# as possible and generating test vectors. There are a few differences that
# make this code much simpler than the Rust version:
#
# 1. This version's encode implementation buffers all input and output in
#    memory. The Rust version uses a more complicated tree-flipping strategy to
#    avoid using extra storage.
# 2. This version isn't incremental. The Rust version provides incremental
#    encoders and decoders, which accept small reads and writes from the
#    caller, and that requires more bookkeeping.
# 3. This version doesn't support arbitrary seeking. The most complicated bit
#    of bookkeeping in the Rust version is seeking in the incremental decoder.
#
# Some more specific details about how each part of this implementation works:
#
# *bao_decode*, *bao_slice*, and *bao_decode_slice* are recursive streaming
# implementations. Recursion is easy here because the length header at the
# start of the encoding tells us all we need to know about the layout of the
# tree. The pre-order layout means that neither of the decode functions needs
# to seek (though bao_slice does, to skip the parts that aren't in the slice).
#
# *bao_hash* (identical to the BLAKE3 hash function) is an iterative streaming
# implementation, which is closer to an incremental implementation than the
# recursive functions are. Recursion doesn't work well here, because we don't
# know the length of the input in advance. Instead, we keep a stack of subtrees
# filled so far, merging them as we go along. There is a very cute trick, where
# the number of subtree hashes that should remain in the stack is the same as
# the number of 1's in the binary representation of the count of chunks so far.
# (E.g. if you've read 255 chunks so far, then you have 8 partial subtrees: one
# of 128 chunks, one of 64 chunks, and so on. After you read the 256th chunk,
# you can merge all of those into a single subtree.) That, plus the fact that
# merging is always done smallest-to-largest / at the top of the stack, means
# that we don't need to remember the size of each subtree; just the hash is
# enough.
#
# *bao_encode* is a recursive implementation, but as noted above, it's not
# streaming. Instead, to keep things simple, it buffers the entire input and
# output in memory. The Rust implementation uses a more complicated
# tree-flipping strategy to avoid hogging memory like this, where it writes the
# output tree first in a post-order layout, and then does a second pass
# back-to-front to flip it in place to pre-order.

__doc__ = """\
Usage: bao.py hash [<inputs>...]
       bao.py encode <input> (<output> | --outboard=<file>)
       bao.py decode <hash> [<input>] [<output>] [--outboard=<file>]
       bao.py slice <start> <count> [<input>] [<output>] [--outboard=<file>]
       bao.py decode-slice <hash> <start> <count> [<input>] [<output>]
"""
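# For example, hashing, encoding, and then verified decoding might look like
# this (the file names here are hypothetical):
#
#     ./bao.py hash input.txt
#     ./bao.py encode input.txt encoded.bao
#     ./bao.py decode "$(./bao.py hash input.txt)" encoded.bao decoded.txt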
import binascii
import docopt
import hmac
import sys

# the BLAKE3 initialization constants
IV = [
    0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
    0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
]

# the BLAKE3 message schedule
MSG_SCHEDULE = [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8],
    [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1],
    [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6],
    [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4],
    [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7],
    [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13],
]

BLOCK_SIZE = 64
CHUNK_SIZE = 1024
KEY_SIZE = 32
HASH_SIZE = 32
PARENT_SIZE = 2 * HASH_SIZE
WORD_BITS = 32
WORD_BYTES = 4
WORD_MAX = 2**WORD_BITS - 1
HEADER_SIZE = 8

# domain separation flags
CHUNK_START = 1 << 0
CHUNK_END = 1 << 1
PARENT = 1 << 2
ROOT = 1 << 3
KEYED_HASH = 1 << 4
DERIVE_KEY = 1 << 5

# finalization flags
IS_ROOT = object()
NOT_ROOT = object()


def wrapping_add(a, b):
    return (a + b) & WORD_MAX


def rotate_right(x, n):
    return (x >> n | x << (WORD_BITS - n)) & WORD_MAX


# The BLAKE3 G function. This is historically related to the ChaCha
# "quarter-round" function, though note that a BLAKE3 round is more like a
# ChaCha "double-round", and the round function below calls G eight times.
def g(state, a, b, c, d, x, y):
    state[a] = wrapping_add(state[a], wrapping_add(state[b], x))
    state[d] = rotate_right(state[d] ^ state[a], 16)
    state[c] = wrapping_add(state[c], state[d])
    state[b] = rotate_right(state[b] ^ state[c], 12)
    state[a] = wrapping_add(state[a], wrapping_add(state[b], y))
    state[d] = rotate_right(state[d] ^ state[a], 8)
    state[c] = wrapping_add(state[c], state[d])
    state[b] = rotate_right(state[b] ^ state[c], 7)


# the BLAKE3 round function
def round(state, msg_words, schedule):
    # Mix the columns.
    g(state, 0, 4, 8, 12, msg_words[schedule[0]], msg_words[schedule[1]])
    g(state, 1, 5, 9, 13, msg_words[schedule[2]], msg_words[schedule[3]])
    g(state, 2, 6, 10, 14, msg_words[schedule[4]], msg_words[schedule[5]])
    g(state, 3, 7, 11, 15, msg_words[schedule[6]], msg_words[schedule[7]])
    # Mix the rows.
    g(state, 0, 5, 10, 15, msg_words[schedule[8]], msg_words[schedule[9]])
    g(state, 1, 6, 11, 12, msg_words[schedule[10]], msg_words[schedule[11]])
    g(state, 2, 7, 8, 13, msg_words[schedule[12]], msg_words[schedule[13]])
    g(state, 3, 4, 9, 14, msg_words[schedule[14]], msg_words[schedule[15]])


def words_from_bytes(buf):
    words = [0] * (len(buf) // WORD_BYTES)
    for word_i in range(len(words)):
        words[word_i] = int.from_bytes(
            buf[word_i * WORD_BYTES : (word_i + 1) * WORD_BYTES], "little"
        )
    return words


def bytes_from_words(words):
    buf = bytearray(len(words) * WORD_BYTES)
    for word_i in range(len(words)):
        buf[WORD_BYTES * word_i : WORD_BYTES * (word_i + 1)] = words[word_i].to_bytes(
            WORD_BYTES, "little"
        )
    return buf


# The truncated BLAKE3 compression function. This implementation does not
# support extended output.
def compress(cv, block, block_len, offset, flags):
    block_words = words_from_bytes(block)
    state = [
        cv[0], cv[1], cv[2], cv[3],
        cv[4], cv[5], cv[6], cv[7],
        IV[0], IV[1], IV[2], IV[3],
        offset & WORD_MAX,
        (offset >> WORD_BITS) & WORD_MAX,
        block_len,
        flags,
    ]
    for round_number in range(7):
        round(state, block_words, MSG_SCHEDULE[round_number])
    return [state[i] ^ state[i + 8] for i in range(8)]
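# A quick round-trip check of the byte/word conversions above. This helper is
# an illustrative addition, not used by anything else in this file. Words are
# little-endian, so the low-order byte comes first.
def _demo_word_conversion():
    buf = b"\x01\x00\x00\x00\x02\x00\x00\x00"
    words = words_from_bytes(buf)
    assert words == [1, 2]
    assert bytes_from_words(words) == buf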
# Compute a BLAKE3 chunk chaining value.
def chunk_chaining_value(chunk_bytes, chunk_index, finalization):
    cv = IV[:]
    i = 0
    flags = CHUNK_START
    while len(chunk_bytes) - i > BLOCK_SIZE:
        block = chunk_bytes[i : i + BLOCK_SIZE]
        cv = compress(cv, block, BLOCK_SIZE, chunk_index, flags)
        flags = 0
        i += BLOCK_SIZE
    flags |= CHUNK_END
    if finalization is IS_ROOT:
        flags |= ROOT
    block = chunk_bytes[i:]
    block_len = len(block)
    block += b"\0" * (BLOCK_SIZE - block_len)
    cv = compress(cv, block, block_len, chunk_index, flags)
    return bytes_from_words(cv)


# Compute a BLAKE3 parent node chaining value.
def parent_chaining_value(parent_bytes, finalization):
    cv = IV[:]
    flags = PARENT
    if finalization is IS_ROOT:
        flags |= ROOT
    cv = compress(cv, parent_bytes, BLOCK_SIZE, 0, flags)
    return bytes_from_words(cv)


# Verify a chunk chaining value with a constant-time comparison.
def verify_chunk(expected_cv, chunk_bytes, chunk_index, finalization):
    found_cv = chunk_chaining_value(chunk_bytes, chunk_index, finalization)
    assert hmac.compare_digest(expected_cv, found_cv), "hash mismatch"


# Verify a parent node chaining value with a constant-time comparison.
def verify_parent(expected_cv, parent_bytes, finalization):
    found_cv = parent_chaining_value(parent_bytes, finalization)
    assert hmac.compare_digest(expected_cv, found_cv), "hash mismatch"


# The standard read() function is allowed to return fewer bytes than requested
# for a number of different reasons, including but not limited to EOF. To
# guarantee we get the bytes we need, we have to call it in a loop.
def read_exact(stream, n):
    out = bytearray(n)  # initialized to n zeros
    mv = memoryview(out)
    while mv:
        n = stream.readinto(mv)  # read into `out` without an extra copy
        if n == 0:
            raise IOError("unexpected EOF")
        mv = mv[n:]  # move the memoryview forward
    return out


def encode_len(content_len):
    return content_len.to_bytes(HEADER_SIZE, "little")


def decode_len(len_bytes):
    return int.from_bytes(len_bytes, "little")


# Left subtrees contain the largest possible power-of-two number of chunks,
# with at least one byte left over for the right subtree.
def left_len(parent_len):
    available_chunks = (parent_len - 1) // CHUNK_SIZE
    power_of_two_chunks = 2 ** (available_chunks.bit_length() - 1)
    return CHUNK_SIZE * power_of_two_chunks
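# A few worked examples of the left_len rule above. This helper is an
# illustrative addition, not used by anything else in this file. Four full
# chunks split two and two, but one more byte (a fifth chunk) makes the split
# four and one.
def _demo_left_len():
    assert left_len(2 * CHUNK_SIZE) == CHUNK_SIZE
    assert left_len(4 * CHUNK_SIZE) == 2 * CHUNK_SIZE
    assert left_len(4 * CHUNK_SIZE + 1) == 4 * CHUNK_SIZE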
def bao_encode(buf, *, outboard=False):
    chunk_index = 0

    def encode_recurse(buf, finalization):
        nonlocal chunk_index
        if len(buf) <= CHUNK_SIZE:
            chunk_cv = chunk_chaining_value(buf, chunk_index, finalization)
            chunk_encoded = b"" if outboard else buf
            chunk_index += 1
            return chunk_encoded, chunk_cv
        llen = left_len(len(buf))
        # Interior nodes have no len suffix.
        left_encoded, left_cv = encode_recurse(buf[:llen], NOT_ROOT)
        right_encoded, right_cv = encode_recurse(buf[llen:], NOT_ROOT)
        node = left_cv + right_cv
        encoded = node + left_encoded + right_encoded
        return encoded, parent_chaining_value(node, finalization)

    # Only this topmost call passes IS_ROOT; recursive calls always pass
    # NOT_ROOT.
    encoded, hash_ = encode_recurse(buf, IS_ROOT)
    # The final output prefixes the encoded tree with the content length.
    output = encode_len(len(buf)) + encoded
    return output, hash_


def bao_decode(input_stream, output_stream, hash_, *, outboard_stream=None):
    tree_stream = outboard_stream or input_stream
    chunk_index = 0

    def decode_recurse(subtree_cv, content_len, finalization):
        nonlocal chunk_index
        if content_len <= CHUNK_SIZE:
            chunk = read_exact(input_stream, content_len)
            verify_chunk(subtree_cv, chunk, chunk_index, finalization)
            chunk_index += 1
            output_stream.write(chunk)
        else:
            parent = read_exact(tree_stream, PARENT_SIZE)
            verify_parent(subtree_cv, parent, finalization)
            left_cv, right_cv = parent[:HASH_SIZE], parent[HASH_SIZE:]
            llen = left_len(content_len)
            # Interior nodes have no len suffix.
            decode_recurse(left_cv, llen, NOT_ROOT)
            decode_recurse(right_cv, content_len - llen, NOT_ROOT)

    # The first HEADER_SIZE bytes are the encoded content len.
    content_len = decode_len(read_exact(tree_stream, HEADER_SIZE))
    decode_recurse(hash_, content_len, IS_ROOT)


# This is identical to the BLAKE3 hash function.
def bao_hash(input_stream):
    buf = b""
    chunks = 0
    subtrees = []
    while True:
        # We ask for CHUNK_SIZE bytes, but be careful, we can always get fewer.
        read = input_stream.read(CHUNK_SIZE)
        # If the read is EOF, do a final rollup merge of all the subtrees we
        # have, and pass the finalization flag for hashing the root node.
        if not read:
            if chunks == 0:
                # This is the only chunk and therefore the root.
                return chunk_chaining_value(buf, chunks, IS_ROOT)
            new_subtree = chunk_chaining_value(buf, chunks, NOT_ROOT)
            while len(subtrees) > 1:
                parent = subtrees.pop() + new_subtree
                new_subtree = parent_chaining_value(parent, NOT_ROOT)
            return parent_chaining_value(subtrees[0] + new_subtree, IS_ROOT)
        # If we already had a full chunk buffered, hash it and merge subtrees
        # before adding in bytes we just read into the buffer. This order of
        # operations means we know the finalization is non-root.
        if len(buf) >= CHUNK_SIZE:
            new_subtree = chunk_chaining_value(buf[:CHUNK_SIZE], chunks, NOT_ROOT)
            chunks += 1
            # This is the very cute trick described at the top.
            total_after_merging = bin(chunks).count("1")
            while len(subtrees) + 1 > total_after_merging:
                parent = subtrees.pop() + new_subtree
                new_subtree = parent_chaining_value(parent, NOT_ROOT)
            subtrees.append(new_subtree)
            buf = buf[CHUNK_SIZE:]
        buf = buf + read


# Round up to the next full chunk, and remember that the empty tree still
# counts as one chunk.
def count_chunks(content_len):
    if content_len == 0:
        return 1
    return (content_len + CHUNK_SIZE - 1) // CHUNK_SIZE


# A subtree of N chunks always has N-1 parent nodes.
def encoded_subtree_size(content_len, outboard=False):
    parents_size = PARENT_SIZE * (count_chunks(content_len) - 1)
    return parents_size if outboard else parents_size + content_len
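# A minimal end-to-end sanity check tying the pieces above together. This
# helper is an illustrative addition, not used by anything else in this file:
# the root hash returned by bao_encode matches the bao_hash of the input, and
# bao_decode reproduces the input exactly.
def _demo_roundtrip():
    import io

    buf = b"x" * (3 * CHUNK_SIZE + 1)  # an arbitrary four-chunk input
    encoded, root = bao_encode(buf)
    assert bao_hash(io.BytesIO(buf)) == root
    decoded = io.BytesIO()
    bao_decode(io.BytesIO(encoded), decoded, root)
    assert decoded.getvalue() == buf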
def bao_slice(
    input_stream, output_stream, slice_start, slice_len, outboard_stream=None
):
    tree_stream = outboard_stream or input_stream
    content_len_bytes = read_exact(tree_stream, HEADER_SIZE)
    output_stream.write(content_len_bytes)
    content_len = decode_len(content_len_bytes)
    # Slicing always tries to read at least one byte.
    if slice_len == 0:
        slice_len = 1
    slice_end = slice_start + slice_len
    # Seeking past EOF still needs to validate the final chunk. The easiest way
    # to do that is to repoint slice_start to be the byte right before the end.
    if slice_start >= content_len:
        slice_start = content_len - 1 if content_len > 0 else 0

    def slice_recurse(subtree_start, subtree_len):
        subtree_end = subtree_start + subtree_len
        if subtree_end <= slice_start:
            # Seek past the current subtree.
            parent_nodes_size = encoded_subtree_size(subtree_len, outboard=True)
            # `1` here means seek from the current position.
            tree_stream.seek(parent_nodes_size, 1)
            input_stream.seek(subtree_len, 1)
        elif slice_end <= subtree_start:
            # We've sliced all the requested content, and we're done.
            pass
        elif subtree_len <= CHUNK_SIZE:
            # The current subtree is just a chunk. Read the whole thing. The
            # recipient will need the whole thing to verify its hash,
            # regardless of whether it overlaps slice_end.
            chunk = read_exact(input_stream, subtree_len)
            output_stream.write(chunk)
        else:
            # We need to read a parent node and recurse into the current
            # subtree.
            parent = read_exact(tree_stream, PARENT_SIZE)
            output_stream.write(parent)
            llen = left_len(subtree_len)
            slice_recurse(subtree_start, llen)
            slice_recurse(subtree_start + llen, subtree_len - llen)

    slice_recurse(0, content_len)
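# A small end-to-end sketch of slicing. This helper is an illustrative
# addition, not used by anything else in this file. (It calls
# bao_decode_slice, which is defined just below.) The decoded slice contains
# exactly the requested byte range, even though the slice on the wire carries
# whole chunks and parent nodes.
def _demo_slice_roundtrip():
    import io

    buf = bytes(range(256)) * 16  # 4096 bytes, four chunks
    encoded, root = bao_encode(buf)
    sliced = io.BytesIO()
    bao_slice(io.BytesIO(encoded), sliced, CHUNK_SIZE, 100)
    decoded = io.BytesIO()
    bao_decode_slice(io.BytesIO(sliced.getvalue()), decoded, root, CHUNK_SIZE, 100)
    assert decoded.getvalue() == buf[CHUNK_SIZE : CHUNK_SIZE + 100]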
# Note that unlike bao_slice, there is no optional outboard parameter. Slices
# can be created from either a combined or an outboard tree, but the resulting
# slice itself is always combined.
def bao_decode_slice(input_stream, output_stream, hash_, slice_start, slice_len):
    content_len_bytes = read_exact(input_stream, HEADER_SIZE)
    content_len = decode_len(content_len_bytes)
    # Always try to verify at least one byte. But don't output it unless the
    # caller asked for it.
    skip_output = False
    if slice_len == 0:
        slice_len = 1
        skip_output = True
    slice_end = slice_start + slice_len
    # As above, if slice_start is past EOF, we repoint it to the last byte of
    # the encoding, to make sure that the final chunk gets validated. But
    # again, don't emit bytes unless the caller asked for them.
    if slice_start >= content_len:
        slice_start = content_len - 1 if content_len > 0 else 0
        skip_output = True

    def decode_slice_recurse(subtree_start, subtree_len, subtree_cv, finalization):
        subtree_end = subtree_start + subtree_len
        # Check content_len before skipping subtrees, to be sure we don't skip
        # validating the empty chunk.
        if subtree_end <= slice_start and content_len > 0:
            # This subtree isn't part of the slice. Keep going.
            pass
        elif slice_end <= subtree_start and content_len > 0:
            # We've verified all the requested content, and we're done.
            pass
        elif subtree_len <= CHUNK_SIZE:
            # The current subtree is just a chunk. Verify the whole thing, and
            # then output however many bytes we need.
            chunk = read_exact(input_stream, subtree_len)
            chunk_index = subtree_start // CHUNK_SIZE
            verify_chunk(subtree_cv, chunk, chunk_index, finalization)
            chunk_start = max(0, min(subtree_len, slice_start - subtree_start))
            chunk_end = max(0, min(subtree_len, slice_end - subtree_start))
            if not skip_output:
                output_stream.write(chunk[chunk_start:chunk_end])
        else:
            # We need to read a parent node and recurse into the current
            # subtree. Note that the finalization is always NOT_ROOT after this
            # point.
            parent = read_exact(input_stream, PARENT_SIZE)
            verify_parent(subtree_cv, parent, finalization)
            left_cv, right_cv = parent[:HASH_SIZE], parent[HASH_SIZE:]
            llen = left_len(subtree_len)
            decode_slice_recurse(subtree_start, llen, left_cv, NOT_ROOT)
            decode_slice_recurse(
                subtree_start + llen, subtree_len - llen, right_cv, NOT_ROOT
            )

    decode_slice_recurse(0, content_len, hash_, IS_ROOT)


def open_input(maybe_path):
    if maybe_path is None or maybe_path == "-":
        return sys.stdin.buffer
    return open(maybe_path, "rb")


def open_output(maybe_path):
    if maybe_path is None or maybe_path == "-":
        return sys.stdout.buffer
    return open(maybe_path, "w+b")


def main():
    args = docopt.docopt(__doc__)
    in_stream = open_input(args["<input>"])
    out_stream = open_output(args["<output>"])
    if args["encode"]:
        outboard = False
        if args["--outboard"] is not None:
            outboard = True
            out_stream = open_output(args["--outboard"])
        encoded, _ = bao_encode(in_stream.read(), outboard=outboard)
        out_stream.write(encoded)
    elif args["decode"]:
        hash_ = binascii.unhexlify(args["<hash>"])
        outboard_stream = None
        if args["--outboard"] is not None:
            outboard_stream = open_input(args["--outboard"])
        bao_decode(in_stream, out_stream, hash_, outboard_stream=outboard_stream)
    elif args["hash"]:
        inputs = args["<inputs>"]
        if len(inputs) > 0:
            # This loop just crashes on IO errors, which is fine for testing.
            for name in inputs:
                hash_ = bao_hash(open_input(name))
                if len(inputs) > 1:
                    print("{} {}".format(hash_.hex(), name))
                else:
                    print(hash_.hex())
        else:
            hash_ = bao_hash(in_stream)
            print(hash_.hex())
    elif args["slice"]:
        outboard_stream = None
        if args["--outboard"] is not None:
            outboard_stream = open_input(args["--outboard"])
        bao_slice(
            in_stream,
            out_stream,
            int(args["<start>"]),
            int(args["<count>"]),
            outboard_stream,
        )
    elif args["decode-slice"]:
        hash_ = binascii.unhexlify(args["<hash>"])
        bao_decode_slice(
            in_stream, out_stream, hash_, int(args["<start>"]), int(args["<count>"])
        )


if __name__ == "__main__":
    main()