include "../../../arch/x64/Vale.X64.InsBasic.vaf" //include "../../../arch/x64/Vale.X64.InsMem.vaf" include "../../../arch/x64/Vale.X64.InsVector.vaf" //include "../../../arch/x64/Vale.X64.InsAes.vaf" include "Vale.AES.X64.AES128.vaf" include "Vale.AES.X64.AES256.vaf" include "../../../lib/util/x64/Vale.X64.Stack.vaf" include{:fstar}{:open} "Vale.Def.Types_s" include{:fstar}{:open} "FStar.Seq.Base" include{:fstar}{:open} "Vale.AES.AES_s" include{:fstar}{:open} "Vale.X64.Machine_s" include{:fstar}{:open} "Vale.X64.Memory" include{:fstar}{:open} "Vale.X64.State" include{:fstar}{:open} "Vale.X64.Decls" include{:fstar}{:open} "Vale.X64.QuickCode" include{:fstar}{:open} "Vale.X64.QuickCodes" include{:fstar}{:open} "Vale.Arch.Types" include{:fstar}{:open} "Vale.AES.AES256_helpers" include{:fstar}{:open} "Vale.Def.Opaque_s" include{:fstar}{:open} "Vale.AES.GCTR" include{:fstar}{:open} "Vale.X64.CPU_Features_s" module Vale.AES.X64.AES #verbatim{:interface}{:implementation} open Vale.Def.Types_s open FStar.Seq open Vale.Arch.HeapImpl open Vale.AES.AES_s open Vale.X64.Machine_s open Vale.X64.Memory open Vale.X64.State open Vale.X64.Decls open Vale.X64.InsBasic open Vale.X64.InsMem open Vale.X64.InsVector //open Vale.X64.InsAes open Vale.X64.QuickCode open Vale.X64.QuickCodes open Vale.Arch.Types //open Vale.AES.AES_helpers open Vale.AES.AES256_helpers open Vale.Def.Opaque_s open Vale.AES.X64.AES128 open Vale.AES.X64.AES256 open Vale.AES.GCTR open Vale.X64.CPU_Features_s open Vale.X64.Stack #reset-options "--z3rlimit 20" #endverbatim /////////////////////////// // KEY EXPANSION /////////////////////////// procedure KeyExpansionStdcall( inline win:bool, inline alg:algorithm, ghost input_key_b:buffer128, ghost output_key_expansion_b:buffer128) {:public} {:quick} {:exportSpecs} reads rcx; rsi; rdi; heap0; modifies rdx; heap1; memLayout; xmm1; xmm2; xmm3; xmm4; efl; lets key_ptr := if win then rcx else rdi; key_expansion_ptr := if win then rdx else rsi; key := if alg = AES_128 then quad32_to_seq(buffer128_read(input_key_b, 0, mem)) else make_AES256_key(buffer128_read(input_key_b, 0, mem), buffer128_read(input_key_b, 1, mem)); requires is_initial_heap(memLayout, mem); aesni_enabled && avx_enabled && sse_enabled; alg = AES_128 || alg = AES_256; buffers_disjoint128(input_key_b, output_key_expansion_b); validSrcAddrs128(mem, key_ptr, input_key_b, if alg = AES_128 then 1 else 2, memLayout, Secret); validDstAddrs128(mem, key_expansion_ptr, output_key_expansion_b, nr(alg) + 1, memLayout, Secret); ensures modifies_buffer128(output_key_expansion_b, old(mem), mem); forall(j) 0 <= j <= nr(alg) ==> buffer128_read(output_key_expansion_b, j, mem) == index(key_to_round_keys_LE(alg, key), j); { CreateHeaplets(list( declare_buffer128(input_key_b, 0, Secret, Immutable), declare_buffer128(output_key_expansion_b, 1, Secret, Mutable))); inline if (alg = AES_128) { KeyExpansion128Stdcall(win, input_key_b, output_key_expansion_b); } else { KeyExpansion256Stdcall(win, input_key_b, output_key_expansion_b); } DestroyHeaplets(); } /////////////////////////// // ENCRYPTION /////////////////////////// procedure AESEncryptBlock( inline alg:algorithm, ghost input:quad32, ghost key:seq(nat32), ghost round_keys:seq(quad32), ghost keys_buffer:buffer128) {:public} {:quick} reads r8; heap0; memLayout; modifies xmm0; xmm2; efl; requires aesni_enabled && sse_enabled; alg = AES_128 || alg = AES_256; is_aes_key_LE(alg, key); length(round_keys) == nr(alg) + 1; round_keys == key_to_round_keys_LE(alg, key); xmm0 == input; r8 == 
///////////////////////////
// ENCRYPTION
///////////////////////////

procedure AESEncryptBlock(
        inline alg:algorithm,
        ghost input:quad32,
        ghost key:seq(nat32),
        ghost round_keys:seq(quad32),
        ghost keys_buffer:buffer128)
    {:public}
    {:quick}
    reads
        r8; heap0; memLayout;
    modifies
        xmm0; xmm2; efl;
    requires
        aesni_enabled && sse_enabled;
        alg = AES_128 || alg = AES_256;
        is_aes_key_LE(alg, key);
        length(round_keys) == nr(alg) + 1;
        round_keys == key_to_round_keys_LE(alg, key);
        xmm0 == input;
        r8 == buffer_addr(keys_buffer, heap0);
        validSrcAddrs128(heap0, r8, keys_buffer, nr(alg) + 1, memLayout, Secret);
        forall(i:nat) i < nr(alg) + 1 ==>
            buffer128_read(keys_buffer, i, heap0) == index(round_keys, i);
    ensures
        xmm0 == aes_encrypt_LE(alg, key, input);
{
    inline if (alg = AES_128) {
        AES128EncryptBlock(input, key, round_keys, keys_buffer);
    } else {
        AES256EncryptBlock(input, key, round_keys, keys_buffer);
    }
}
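// Register contract for AESEncryptBlock: the plaintext block enters and the
// ciphertext block leaves in xmm0, while r8 holds the address of the expanded
// key schedule. The `inline if` on alg is resolved at code-generation time,
// so each printed variant (AES-128 or AES-256) contains no runtime branch.

// The block comment below preserves an older, partially ghost-annotated
// AES-128 key-inversion path: it rewrites round keys 1..9 in place with
// AESNI_imc (the aesimc instruction, i.e. InvMixColumns) to build the
// decryption schedule for AES's equivalent inverse cipher.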
/*
///////////////////////////
// KEY INVERSION
///////////////////////////

procedure KeyInversionRound(
        inline round:int,
//        inline taint:taint,
        ghost k_b:buffer128)
//        ghost key:aes_key_LE(AES_128),
//        ghost w:seq(nat32),
//        ghost dw_in:seq(nat32)
//    ) returns (
//        ghost dw_out:seq(nat32)
//    )
    requires/ensures
        validSrcAddrs128(heap0, rdx, k_b, 11);
    requires
        0 <= round <= 8;
//        SeqLength(w) == 44;
//        SeqLength(dw_in) == 4*(round+1);
//        KeyExpansionPredicate(key, AES_128, w);
//        EqInvkey_expansion_partial(key, AES_128, dw_in, round);
//        forall(j) 0 <= j <= round ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(dw_in[4*j], dw_in[4*j+1], dw_in[4*j+2], dw_in[4*j+3]);
//        forall(j) round < j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]);
    reads
        rdx;
    modifies
        heap0; xmm1; efl;
    ensures
        modifies_buffer_specific128(k_b, old(heap0), heap0, round+1, round+1);
//        forall(a) (a < rdx || a >= rdx + 176) && old(heap0)[k_b].quads?[a] ==> heap0[k_b].quads?[a] && heap0[k_b].quads[a] == old(heap0)[k_b].quads[a];
//        forall(j) round+1 < j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]);
//        SeqLength(dw_out) == 4*(round + 2);
//        forall(j) 0 <= j <= round+1 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(dw_out[4*j], dw_out[4*j+1], dw_out[4*j+2], dw_out[4*j+3]);
//        EqInvkey_expansion_partial(key, AES_128, dw_out, round+1);
{
    // assert heap0[k_b].quads[rdx + 16*(round+1)].v == Quadword(w[(round+1)*4], w[(round+1)*4 + 1], w[(round+1)*4 + 2], w[(round+1)*4 + 3]);
    Load128_buffer(heap0, xmm1, rdx, 16*(round+1), k_b, round+1);
    // assert xmm1 == Quadword(w[(round+1)*4], w[(round+1)*4 + 1], w[(round+1)*4 + 2], w[(round+1)*4 + 3]);
    // let ws := SeqRange(w, (round+1)*4, (round+1)*4 + 4);
    // assert ws[0] == w[(round+1)*4] && ws[1] == w[(round+1)*4+1] && ws[2] == w[(round+1)*4+2] && ws[3] == w[(round+1)*4+3];
    // assert xmm1 == seq_to_Quadword(ws);
    AESNI_imc(xmm1, xmm1);
    Store128_buffer(heap0, rdx, xmm1, 16*(round+1), k_b, round+1);
    // dw_out := dw_in + Quadword_to_seq(xmm1);
    // lemma_KeyInversionRoundHelper(round+1, key, w, dw_in, dw_out);
    //
    // forall(j) round+1 < j <= 10 implies heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]) by {
    // }
    //
    // forall(j) 0 <= j <= round + 1 implies heap0[k_b].quads[rdx + 16*j].v == Quadword(dw_out[4*j], dw_out[4*j+1], dw_out[4*j+2], dw_out[4*j+3]) by {
    // }
}

procedure KeyInversionRoundUnrolledRecursive(
        inline rounds:int,
//        inline taint:taint,
        ghost k_b:buffer128)
//        ghost key:aes_key_LE(AES_128),
//        ghost w:seq(nat32),
//        ghost dw_in:seq(nat32)
//    ) returns (
//        ghost dw_out:seq(nat32)
//    )
    {:recursive}
    requires/ensures
        validSrcAddrs128(heap0, rdx, k_b, 11);
    requires
        0 <= rounds <= 9;
        rdx % 16 == 0;
//        SeqLength(w) == 44;
//        SeqLength(dw_in) == 4;
//        KeyExpansionPredicate(key, AES_128, w);
//        EqInvkey_expansion_partial(key, AES_128, dw_in, 0);
//        heap0[k_b].quads[rdx+16*0].v == Quadword(dw_in[4*0], dw_in[4*0+1], dw_in[4*0+2], dw_in[4*0+3]);
//        forall(j) 0 <= j <= 0 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(dw_in[4*j], dw_in[4*j+1], dw_in[4*j+2], dw_in[4*j+3]);
//        forall(j) 0 < j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]);
    reads
        rdx;
    modifies
        heap0; xmm1; efl;
    ensures
        modifies_buffer128(k_b, old(heap0), heap0);
//        forall(a) (a < rdx || a >= rdx + 176) && old(heap0)[k_b].quads?[a] ==> heap0[k_b].quads?[a] && heap0[k_b].quads[a] == old(heap0)[k_b].quads[a];
//        SeqLength(dw_out) == 4*(rounds+1);
//        forall(j) 0 <= j <= rounds ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(dw_out[4*j], dw_out[4*j+1], dw_out[4*j+2], dw_out[4*j+3]);
//        forall(j) rounds < j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]);
//        EqInvkey_expansion_partial(key, AES_128, dw_out, rounds);
{
    // inline if (0 < rounds <= 9) {
    //     let dw_mid := KeyInversionRoundUnrolledRecursive(rounds-1, taint, k_b, key, w, dw_in);
    //     dw_out := KeyInversionRound(rounds-1, taint, k_b, key, w, dw_mid);
    // }
    // else {
    //     dw_out := dw_in;
    // }
}

procedure KeyInversionImpl(
//        ghost key:aes_key_LE(AES_128),
//        ghost w:seq(nat32),
//        inline taint:taint,
        ghost k_b:buffer128)
//    ) returns (
//        ghost dw:seq(nat32)
//    )
    requires/ensures
        validSrcAddrs128(heap0, rdx, k_b, 11);
    requires
        rdx % 16 == 0;
//        SeqLength(w) == 44;
//        forall(j) 0 <= j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(w[4*j], w[4*j+1], w[4*j+2], w[4*j+3]);
//        KeyExpansionPredicate(key, AES_128, w);
    ensures
        modifies_buffer128(k_b, old(heap0), heap0);
//        forall(a) (a < rdx || a >= rdx + 176) && old(heap0)[k_b].quads?[a] ==> heap0[k_b].quads?[a] && heap0[k_b].quads[a] == old(heap0)[k_b].quads[a];
//        SeqLength(dw) == 44;
//        EqInvKeyExpansionPredicate(key, AES_128, dw);
//        forall(j) 0 <= j <= 10 ==> heap0[k_b].quads[rdx + 16*j].v == Quadword(dw[4*j], dw[4*j+1], dw[4*j+2], dw[4*j+3]);
    reads
        rdx;
    modifies
        heap0; xmm1; efl;
{
    // lemma_KeyExpansionPredicateImpliesExpandKey(key, AES_128, w);
    // let dw1 := seq(w[0], w[1], w[2], w[3]);
    //
    // let dw2 := KeyInversionRoundUnrolledRecursive(9, taint, k_b, key, w, dw1);
    KeyInversionRoundUnrolledRecursive(9, k_b);
    //
    // dw := dw2 + seq(w[40], w[41], w[42], w[43]);
    // assert SeqLength(dw) == 44;
    // forall(j) 0 <= j <= 10 implies heap0[k_b].quads[rdx + 16*j].v == Quadword(dw[4*j], dw[4*j+1], dw[4*j+2], dw[4*j+3]) by
    // {
    // }
    // assert EqInvKeyExpansionPredicate(key, AES_128, dw);
}
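// The stdcall wrapper below loads the 128-bit key into xmm1 using the same
// win/System V register split as KeyExpansionStdcall, expands it into the
// buffer addressed by rdx, and (in the fully ghost-annotated version) would
// chain KeyInversionImpl to produce the inverse schedule; it finishes by
// zeroing the xmm registers that held key material.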
procedure KeyExpansionAndInversionStdcall(
//        inline taint:taint,
        inline win:bool,
        ghost input_key_b:buffer128,
        ghost output_key_expansion_b:buffer128)
//    ) returns (
//        ghost dw:seq(nat32)
//    )
    reads
        rcx; rsi; rdi;
    modifies
        rdx; heap0; xmm1; xmm2; xmm3; efl;
    requires
//        HasStackSlots(stack, 2);
        let key_ptr := if win then rcx else rdi;
        let key_expansion_ptr := if win then rdx else rsi;
        key_ptr % 16 == 0;
        key_expansion_ptr % 16 == 0;
        validSrcAddrs128(heap0, key_ptr, input_key_b, 1);
        validDstAddrs128(heap0, key_expansion_ptr, output_key_expansion_b, 11);
    ensures
        let key_ptr := if win then rcx else rdi;
        let key_expansion_ptr := if win then rdx else rsi;
//        let key := Quadword_to_seq(old(heap0)[input_key_id].quads[key_ptr].v);
//        SeqLength(dw) == 44;
//        ValidSrcAddrs(heap0, output_key_expansion_id, key_expansion_ptr, 128, taint, 176);
        validSrcAddrs128(heap0, key_expansion_ptr, output_key_expansion_b, 11);
        modifies_buffer128(output_key_expansion_b, old(heap0), heap0);
//        (forall(a) (a < key_expansion_ptr || a >= key_expansion_ptr + 176) && old(heap0)[output_key_expansion_id].quads?[a] ==> heap0[output_key_expansion_id].quads?[a] && heap0[output_key_expansion_id].quads[a] == old(heap0)[output_key_expansion_id].quads[a]);
//        (forall(j) 0 <= j <= 10 ==> heap0[output_key_expansion_id].quads[key_expansion_ptr + 16*j].v == Quadword(dw[4*j], dw[4*j+1], dw[4*j+2], dw[4*j+3]));
//        EqInvKeyExpansionPredicate(key, AES_128, dw);
{
    let key_ptr := if win then rcx else rdi;
    let key_expansion_ptr := if win then rdx else rsi;
    // let key := Quadword_to_seq(heap0[input_key_id].quads[key_ptr].v);

    // LoadStack(eax, 0);    // eax := key_ptr (from stack position 0)
    inline if (win) {
        Load128_buffer(heap0, xmm1, rcx, 0, input_key_b, 0);
    } else {
        Load128_buffer(heap0, xmm1, rdi, 0, input_key_b, 0);
        Mov64(rdx, rsi);
    }
    // let w := KeyExpansionImpl(key, taint, output_key_expansion_id);    // expand key from xmm1 to region pointed to by rdx
    // dw := KeyInversionImpl(key, w, taint, output_key_expansion_id);
    let key := quad32_to_seq(xmm1);
    KeyExpansionImpl(key, output_key_expansion_b);    // expand key from xmm1 to region pointed to by rdx

    // forall(j) 0 <= j <= 10 implies heap0[output_key_expansion_id].quads[key_expansion_ptr + 16*j].v == Quadword(dw[4*j], dw[4*j+1], dw[4*j+2], dw[4*j+3])
    // {
    //     assert heap0[output_key_expansion_id].quads[rdx + 16*j].v == Quadword(dw[4*j], dw[4*j+1], dw[4*j+2], dw[4*j+3]);
    //     assert rdx == key_expansion_ptr;
    // }

    // Clear secrets out of registers
    // Xor32(eax, eax);
    Pxor(xmm1, xmm1);
    Pxor(xmm2, xmm2);
    Pxor(xmm3, xmm3);
}
*/