#!/usr/bin/env python3 def make_lengths(): n = 8 lengths = [] for d in range(1, n + 1): for c in range(1, n + 1): for b in range(1, n + 1): for a in range(1, n + 1): lengths.append((a, b, c, d)) return lengths def print_header(): print("""#![cfg_attr(rustfmt, rustfmt_skip)] #[cfg(target_arch = "x86")] use std::arch::x86::__m256i; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::__m256i; #[repr(C)] pub union Hack { pub v: __m256i, b: [i8; 32], }""") def print_lengths(lengths): print("pub static LENGTH: [u8; {}] = [".format(len(lengths)), end="") for i, (a, b, c, d) in enumerate(lengths): if i % 32 == 0: print("\n ", end="") else: print(" ", end="") print("{},".format(a + b + c + d), end="") print("\n];") def print_decode_shuffle_1(lengths): print( "pub static DECODE_SHUFFLE_1: [Hack; {}] = [".format(len(lengths)), end="" ) for a, b, c, d in lengths: print("\n Hack { b: [", end="") first = True next_byte = 0 for n in [a, b]: for i in range(0, 8): if not first: print(", ", end="") first = False if n > i: byte = next_byte next_byte += 1 else: byte = -1 print("{}".format(byte), end="") for n in [c, d]: for i in range(0, 8): print(", ", end="") if n > i: if next_byte >= 16: byte = next_byte - 16 else: byte = -1 next_byte += 1 else: byte = -1 print(byte, end="") print("] },", end="") print("\n];") def print_decode_shuffle_2(lengths): print( "pub static DECODE_SHUFFLE_2: [Hack; {}] = [".format(len(lengths)), end="" ) for a, b, c, d in lengths: print("\n Hack { b: [", end="") first = True next_byte = a + b for n in [c, d]: for i in range(0, 8): if not first: print(", ", end="") first = False if n > i and next_byte < 16: byte = next_byte next_byte += 1 else: byte = -1 print(byte, end="") for _ in range(0, 16): print(", -1", end="") print("] },", end="") print("\n];") def print_encode_shuffle_1(lengths): print( "pub static ENCODE_SHUFFLE_1: [Hack; {}] = [".format(len(lengths)), end="" ) for a, b, c, d in lengths: print("\n Hack { b: [", end="") first = True base = 0 next_byte = 0 for n in [a, b]: for i in range(0, n): if not first: print(", ", end="") first = False byte = base + i print(byte, end="") next_byte += 1 base += 8 for _ in range(a + b, 16): print(", -1", end="") base = 0 written = 0 for n in [c, d]: for i in range(0, n): if next_byte >= 16: byte = base + i written += 1 print(",", byte, end="") next_byte += 1 base += 8 for _ in range(written, 16): print(", -1", end="") print("] },", end="") print("\n];") def print_encode_shuffle_2(lengths): print( "pub static ENCODE_SHUFFLE_2: [Hack; {}] = [".format(len(lengths)), end="" ) for a, b, c, d in lengths: print("\n Hack { b: [", end="") written = 16 + a + b for i in range(0, written): if i != 0: print(", ", end="") print("-1", end="") next_byte = a + b for idx, n in enumerate([c, d]): for i in range(0, n): if written < 32: if next_byte < 16: byte = idx * 8 + i else: byte = -1 print(",", byte, end="") written += 1 next_byte += 1 for _ in range(written, 32): print(", -1", end="") print("] },", end="") print("\n];") if __name__ == "__main__": lengths = make_lengths() print_header() print() print_lengths(lengths) print() print_decode_shuffle_1(lengths) print() print_decode_shuffle_2(lengths) print() print_encode_shuffle_1(lengths) print() print_encode_shuffle_2(lengths)