module Hacl.Policies

module ST = FStar.HyperStack.ST

open FStar.HyperStack.All

open FStar.HyperStack.ST
open FStar.Buffer

open Hacl.Types

(* Module abbreviations *)
module HS  = FStar.HyperStack
module B   = FStar.Buffer
module U8  = FStar.UInt8
module U32 = FStar.UInt32
module U64 = FStar.UInt64
module U128 = FStar.UInt128
module H8  = Hacl.UInt8
module H32  = Hacl.UInt32
module H64  = Hacl.UInt64
module H128  = Hacl.UInt128

(** Declassification axioms.
    Each function converts a value of a Hacl integer type (H8/H32/H64/H128)
    into the corresponding FStar machine integer (U8/U32/U64/U128). The
    refinement on the result states that the mathematical value is unchanged.
    These are [assume val]s: no implementation is given or verified here, so
    they are trusted.
    NOTE(review): presumably the Hacl.* integer types are the secret/abstract
    integers and these functions are the explicit escape hatch for revealing
    them; confirm against the definitions of Hacl.UInt8 and friends. *)
assume val declassify_u8: x:H8.t -> Tot (y:U8.t{H8.v x = U8.v y})
assume val declassify_u32: x:H32.t -> Tot (y:U32.t{H32.v x = U32.v y})
assume val declassify_u64: x:H64.t -> Tot (y:U64.t{H64.v x = U64.v y})
assume val declassify_u128: x:H128.t -> Tot (y:U128.t{H128.v x = U128.v y})

#reset-options "--initial_fuel 0 --max_fuel 0 --initial_ifuel 0 --max_ifuel 0 --z3rlimit 50"

(** If the slices [i, j) of [b1] and [b2] are not equal, then the longer
    slices [i, k) are not equal either, for any [k >= j] within both
    sequences. *)
private val lemma_not_equal_slice: #a:Type -> b1:Seq.seq a -> b2:Seq.seq a -> i:nat -> j:nat ->
  k:nat{i <= j /\ i <= k /\ j <= k /\ k <= Seq.length b1 /\ k <= Seq.length b2 } ->
  Lemma
    (requires ~(Seq.equal (Seq.slice b1 i j) (Seq.slice b2 i j)))
    (ensures  ~(Seq.equal (Seq.slice b1 i k) (Seq.slice b2 i k)))
let lemma_not_equal_slice #a b1 b2 i j k =
  (* Proof hint: state that element n of the slice [i, k) of b1 is element
     n + i of b1. This lets the solver carry a differing position inside
     [i, j) over to the longer slice [i, k). *)
  assert (forall (n:nat{n < k - i}). Seq.index (Seq.slice b1 i k) n == Seq.index b1 (n + i))

(** If [b1] and [b2] differ at position [j - 1], then their slices [i, j)
    are not equal. Position [j - 1] is the last element of those slices. *)
private val lemma_not_equal_last: #a:Type -> b1:Seq.seq a -> b2:Seq.seq a -> i:nat ->
  j:nat{i < j /\ j <= Seq.length b1 /\ j <= Seq.length b2} ->
  Lemma
    (requires ~(Seq.index b1 (j - 1) == Seq.index b2 (j - 1)))
    (ensures  ~(Seq.equal (Seq.slice b1 i j) (Seq.slice b2 i j)))
let lemma_not_equal_last #a b1 b2 i j =
  (* Index j - i - 1 of each slice is index j - 1 of the underlying
     sequence. The differing elements therefore appear at the same position
     in both slices. *)
  Seq.lemma_index_slice b1 i j (j - i - 1);
  Seq.lemma_index_slice b2 i j (j - i - 1)

(** [cmp_bytes_ b1 b2 len tmp] compares the first [len] bytes of [b1] and
    [b2], accumulating the result in the one-byte scratch buffer [tmp]. The
    loop always visits all [len] positions; there is no early exit.

    Requires: [b1], [b2] and [tmp] are live, [tmp] is disjoint from [b1] and
    [b2], and [tmp] holds 255 on entry.
    Ensures: only [tmp] is modified. The result is 0 or 255, and it is 255
    exactly when the [len]-byte prefixes of [b1] and [b2] are equal in the
    initial memory. *)
val cmp_bytes_:
  b1:uint8_p ->
  b2:uint8_p ->
  len:u32{U32.v len <= length b1 /\ U32.v len <= length b2} ->
  tmp:uint8_p{length tmp == 1 /\ disjoint_2 tmp b1 b2} ->
  Stack h8
    (requires (fun h -> live h b1 /\ live h b2 /\ live h tmp /\ H8.v (get h tmp 0) == 255))
    (ensures  (fun h0 z h1 -> modifies_1 tmp h0 h1 /\
      (H8.v z == 0 \/ H8.v z == 255) /\
      (H8.v z == 255 <==> equal h0 (sub b1 0ul len) h0 (sub b2 0ul len))))
(* Not recursive: iteration is delegated to C.Compat.Loops.for, so the
   definition does not need [let rec]. *)
let cmp_bytes_ b1 b2 len tmp =
  (* logand facts for 8-bit values combined with 0 and 255. They let the
     solver show that the accumulator always stays in {0, 255}. *)
  UInt.logand_lemma_1 #8 255;
  UInt.logand_lemma_2 #8 255;
  UInt.logand_lemma_1 #8 0;
  UInt.logand_lemma_2 #8 0;
  let h0 = ST.get() in
  (* Loop invariant after i iterations: only tmp has changed since h0,
     tmp holds 0 or 255, and tmp is 255 iff the first i bytes of b1 and b2
     are equal in the initial memory h0. *)
  let inv h (i: nat) =
    let z = get h tmp 0 in
    live h b1 /\ live h b2 /\ live h tmp /\ modifies_1 tmp h0 h /\ 0 <= i /\ i <= U32.v len /\
    (H8.v z == 255 \/ H8.v z == 0) /\
    (H8.v z == 255 <==> equal h0 (sub b1 0ul U32.(uint_to_t i)) h0 (sub b2 0ul U32.(uint_to_t i)))
  in
  (* Loop body for index i: AND the equality mask of byte i into tmp, then
     re-establish the invariant for i + 1. *)
  let f (i:U32.t{U32.(0 <= v i /\ v i < v len)}) :
    Stack unit
      (requires (fun h -> inv h (U32.v i)))
      (ensures  (fun h0 _ h1 -> U32.(inv h0 (v i) /\ inv h1 (v i + 1)))) =
    let bi1 = b1.(i) in
    let bi2 = b2.(i) in
    let z0 = tmp.(0ul) in
    tmp.(0ul) <- H8.(eq_mask bi1 bi2 &^ z0);
    let h = ST.get() in
    (* The conditional below contains only assertions and lemma calls; it
       exists solely to prove the invariant for i + 1.
       NOTE(review): the conditions inspect secret byte values. That is only
       acceptable if the whole conditional is erased at extraction; confirm
       the generated C has no data-dependent branch here. *)
    if H8.v (get h tmp 0) = 255 then
      begin
      (* tmp is still 255, so the first i bytes were equal and byte i is
         equal. The (i + 1)-byte prefixes are those i bytes snoc byte i. *)
      let s1  = as_seq h0 (sub b1 0ul U32.(i +^ 1ul)) in
      let s2  = as_seq h0 (sub b2 0ul U32.(i +^ 1ul)) in
      let s1' = as_seq h0 (sub b1 0ul i) in
      let s2' = as_seq h0 (sub b2 0ul i) in
      assert (Seq.equal s1' s2');
      assert FStar.Seq.(index s1 (U32.v i) == index s2 (U32.v i));
      assert (Seq.equal s1 (Seq.snoc s1' (Seq.index s1 (U32.v i))));
      assert (Seq.equal s2 (Seq.snoc s2' (Seq.index s2 (U32.v i))));
      Seq.lemma_eq_intro s1 s2
      end
    else if H8.v z0 = 0 then
      (* The first i bytes already differed; extending the prefix to
         i + 1 bytes keeps them different. *)
      lemma_not_equal_slice (as_seq h0 b1) (as_seq h0 b2) 0 (U32.v i) (U32.v i + 1)
    else
      (* tmp went from 255 to 0, so byte i is where the buffers differ. *)
      lemma_not_equal_last (as_seq h0 b1) (as_seq h0 b2) 0 (U32.v i + 1)
  in
  C.Compat.Loops.for 0ul len inv f;
  tmp.(0ul)


(** [cmp_bytes b1 b2 len] compares the first [len] bytes of [b1] and [b2].
    It delegates to [cmp_bytes_] with a one-byte accumulator allocated in a
    fresh stack frame.

    Returns 0 iff the [len]-byte prefixes are equal in the initial memory.
    The postcondition only guarantees a non-zero result otherwise; by
    construction that result is [lognot 0], i.e. 255. The only footprint is
    the temporary frame, as stated by [modifies_0]. *)
val cmp_bytes:
  b1:uint8_p ->
  b2:uint8_p ->
  len:u32{U32.v len <= length b1 /\ U32.v len <= length b2} ->
  Stack h8
    (requires (fun h -> live h b1 /\ live h b2))
    (ensures  (fun h0 z h1 -> modifies_0 h0 h1 /\
      (H8.v z == 0 <==> equal h0 (sub b1 0ul len) h0 (sub b2 0ul len))))
let cmp_bytes b1 b2 len =
  push_frame();
  let h0 = ST.get() in
  (* Scratch accumulator, initialised to 255 as cmp_bytes_ requires. *)
  let tmp = Buffer.create (Hacl.Cast.uint8_to_sint8 255uy) 1ul in
  let h1 = ST.get() in
  (* z is 255 when the prefixes are equal and 0 otherwise. *)
  let z = cmp_bytes_ b1 b2 len tmp in
  let h2 = ST.get() in
  pop_frame();
  (* lognot of the zero vector is all ones, and lognot is an involution on
     0. Together these give lognot 255 = 0 and lognot 0 = 255, so inverting
     z maps the equal case to 0. *)
  UInt.lognot_lemma_1 #8;
  UInt.lognot_self #8 0;
  (* Combine the allocation of tmp, from h0 to h1, with its modification,
     from h1 to h2, into modifies_0 between h0 and h2. *)
  lemma_modifies_0_1' tmp h0 h1 h2;
  H8.lognot z