#include "packing.h"
#include "params.h"
#include "poly.h"
#include "polyvec.h"


/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_pack_pk
*
* Description: Serialize the public key pk = (rho, t1).
*              Output layout: SEEDBYTES bytes of rho, followed by
*              K packed t1 polynomials of POLYT1_PACKEDBYTES bytes each.
*
* Arguments:   - uint8_t pk[]: output byte array
*              - const uint8_t rho[]: input byte array holding rho
*              - const polyveck *t1: pointer to input vector t1
**************************************************/
void PQCLEAN_DILITHIUM2AES_CLEAN_pack_pk(uint8_t pk[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_PUBLICKEYBYTES],
        const uint8_t rho[SEEDBYTES],
        const polyveck *t1) {
    uint8_t *out = pk;
    unsigned int idx;

    /* rho is copied verbatim to the front of the key */
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        *out++ = rho[idx];
    }

    /* the packed polynomials of t1 follow back to back */
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyt1_pack(out, &t1->vec[idx]);
        out += POLYT1_PACKEDBYTES;
    }
}

/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_unpack_pk
*
* Description: Deserialize the public key pk = (rho, t1).
*              Reads SEEDBYTES bytes of rho, then K packed t1
*              polynomials of POLYT1_PACKEDBYTES bytes each.
*
* Arguments:   - uint8_t rho[]: output byte array for rho
*              - polyveck *t1: pointer to output vector t1
*              - const uint8_t pk[]: input byte array holding the
*                bit-packed public key
**************************************************/
void PQCLEAN_DILITHIUM2AES_CLEAN_unpack_pk(uint8_t rho[SEEDBYTES],
        polyveck *t1,
        const uint8_t pk[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_PUBLICKEYBYTES]) {
    const uint8_t *in = pk;
    unsigned int idx;

    /* the leading SEEDBYTES bytes are rho */
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        rho[idx] = *in++;
    }

    /* the remaining bytes are the K packed polynomials of t1 */
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyt1_unpack(&t1->vec[idx], in);
        in += POLYT1_PACKEDBYTES;
    }
}

/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_pack_sk
*
* Description: Serialize the secret key. The components are written
*              in the order rho, key, tr, s1, s2, t0:
*                - rho, key, tr: SEEDBYTES bytes each
*                - s1: L polynomials of POLYETA_PACKEDBYTES bytes
*                - s2: K polynomials of POLYETA_PACKEDBYTES bytes
*                - t0: K polynomials of POLYT0_PACKEDBYTES bytes
*
* Arguments:   - uint8_t sk[]: output byte array
*              - const uint8_t rho[]: input byte array holding rho
*              - const uint8_t tr[]: input byte array holding tr
*              - const uint8_t key[]: input byte array holding key
*              - const polyveck *t0: pointer to input vector t0
*              - const polyvecl *s1: pointer to input vector s1
*              - const polyveck *s2: pointer to input vector s2
**************************************************/
void PQCLEAN_DILITHIUM2AES_CLEAN_pack_sk(uint8_t sk[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_SECRETKEYBYTES],
        const uint8_t rho[SEEDBYTES],
        const uint8_t tr[SEEDBYTES],
        const uint8_t key[SEEDBYTES],
        const polyveck *t0,
        const polyvecl *s1,
        const polyveck *s2) {
    uint8_t *out = sk;
    unsigned int idx;

    /* three fixed-size byte strings: rho, then key, then tr */
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        *out++ = rho[idx];
    }
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        *out++ = key[idx];
    }
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        *out++ = tr[idx];
    }

    /* s1 and s2 share the same per-polynomial eta encoding */
    for (idx = 0; idx < L; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyeta_pack(out, &s1->vec[idx]);
        out += POLYETA_PACKEDBYTES;
    }
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyeta_pack(out, &s2->vec[idx]);
        out += POLYETA_PACKEDBYTES;
    }

    /* t0 is stored last */
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyt0_pack(out, &t0->vec[idx]);
        out += POLYT0_PACKEDBYTES;
    }
}

/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_unpack_sk
*
* Description: Deserialize the secret key. The components are read
*              in the order rho, key, tr, s1, s2, t0:
*                - rho, key, tr: SEEDBYTES bytes each
*                - s1: L polynomials of POLYETA_PACKEDBYTES bytes
*                - s2: K polynomials of POLYETA_PACKEDBYTES bytes
*                - t0: K polynomials of POLYT0_PACKEDBYTES bytes
*
* Arguments:   - uint8_t rho[]: output byte array for rho
*              - uint8_t tr[]: output byte array for tr
*              - uint8_t key[]: output byte array for key
*              - polyveck *t0: pointer to output vector t0
*              - polyvecl *s1: pointer to output vector s1
*              - polyveck *s2: pointer to output vector s2
*              - const uint8_t sk[]: input byte array holding the
*                bit-packed secret key
**************************************************/
void PQCLEAN_DILITHIUM2AES_CLEAN_unpack_sk(uint8_t rho[SEEDBYTES],
        uint8_t tr[SEEDBYTES],
        uint8_t key[SEEDBYTES],
        polyveck *t0,
        polyvecl *s1,
        polyveck *s2,
        const uint8_t sk[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_SECRETKEYBYTES]) {
    const uint8_t *in = sk;
    unsigned int idx;

    /* three fixed-size byte strings: rho, then key, then tr */
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        rho[idx] = *in++;
    }
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        key[idx] = *in++;
    }
    for (idx = 0; idx < SEEDBYTES; ++idx) {
        tr[idx] = *in++;
    }

    /* s1 and s2 share the same per-polynomial eta encoding */
    for (idx = 0; idx < L; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyeta_unpack(&s1->vec[idx], in);
        in += POLYETA_PACKEDBYTES;
    }
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyeta_unpack(&s2->vec[idx], in);
        in += POLYETA_PACKEDBYTES;
    }

    /* t0 is stored last */
    for (idx = 0; idx < K; ++idx) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyt0_unpack(&t0->vec[idx], in);
        in += POLYT0_PACKEDBYTES;
    }
}

/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_pack_sig
*
* Description: Serialize the signature sig = (c, z, h).
*              Output layout: SEEDBYTES bytes of c, L packed z
*              polynomials of POLYZ_PACKEDBYTES bytes each, then an
*              OMEGA + K byte hint region. The first OMEGA bytes of
*              that region list the positions of the nonzero
*              coefficients of h, polynomial by polynomial; byte
*              OMEGA + i holds the running total of positions
*              written after polynomial i. Unused bytes are zero.
*
*              No bound check is done here on the number of nonzero
*              coefficients of h; presumably the caller guarantees
*              there are at most OMEGA of them (to be confirmed
*              against the signing code).
*
* Arguments:   - uint8_t sig[]: output byte array
*              - const uint8_t c[]: input challenge hash of
*                SEEDBYTES bytes
*              - const polyvecl *z: pointer to input vector z
*              - const polyveck *h: pointer to input hint vector h
**************************************************/
void PQCLEAN_DILITHIUM2AES_CLEAN_pack_sig(uint8_t sig[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_BYTES],
        const uint8_t c[SEEDBYTES],
        const polyvecl *z,
        const polyveck *h) {
    uint8_t *z_out = sig + SEEDBYTES;
    uint8_t *hint = z_out + L * POLYZ_PACKEDBYTES;
    unsigned int row, pos, count;

    /* challenge hash first */
    for (pos = 0; pos < SEEDBYTES; ++pos) {
        sig[pos] = c[pos];
    }

    /* then the packed response vector z */
    for (row = 0; row < L; ++row) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyz_pack(z_out, &z->vec[row]);
        z_out += POLYZ_PACKEDBYTES;
    }

    /* Encode h: clear the whole region so unused slots stay zero */
    for (pos = 0; pos < OMEGA + K; ++pos) {
        hint[pos] = 0;
    }

    count = 0;
    for (row = 0; row < K; ++row) {
        /* record positions in increasing order within each polynomial */
        for (pos = 0; pos < N; ++pos) {
            if (h->vec[row].coeffs[pos] != 0) {
                hint[count] = (uint8_t) pos;
                ++count;
            }
        }

        /* cumulative number of positions after this polynomial */
        hint[OMEGA + row] = (uint8_t) count;
    }
}

/*************************************************
* Name:        PQCLEAN_DILITHIUM2AES_CLEAN_unpack_sig
*
* Description: Deserialize the signature sig = (c, z, h).
*              c and z are always written. The hint region is
*              validated while it is decoded, so h may be only
*              partially filled when a malformed signature is
*              rejected.
*
* Arguments:   - uint8_t c[]: output challenge hash of SEEDBYTES
*                bytes
*              - polyvecl *z: pointer to output vector z
*              - polyveck *h: pointer to output hint vector h
*              - const uint8_t sig[]: input byte array holding the
*                bit-packed signature
*
* Returns 1 in case of malformed signature; otherwise 0.
**************************************************/
int PQCLEAN_DILITHIUM2AES_CLEAN_unpack_sig(uint8_t c[SEEDBYTES],
        polyvecl *z,
        polyveck *h,
        const uint8_t sig[PQCLEAN_DILITHIUM2AES_CLEAN_CRYPTO_BYTES]) {
    const uint8_t *z_in = sig + SEEDBYTES;
    const uint8_t *hint = z_in + L * POLYZ_PACKEDBYTES;
    unsigned int row, pos, start, end;

    /* challenge hash first */
    for (pos = 0; pos < SEEDBYTES; ++pos) {
        c[pos] = sig[pos];
    }

    /* then the packed response vector z */
    for (row = 0; row < L; ++row) {
        PQCLEAN_DILITHIUM2AES_CLEAN_polyz_unpack(&z->vec[row], z_in);
        z_in += POLYZ_PACKEDBYTES;
    }

    /* Decode h */
    start = 0;
    for (row = 0; row < K; ++row) {
        for (pos = 0; pos < N; ++pos) {
            h->vec[row].coeffs[pos] = 0;
        }

        /* cumulative counts must not decrease and must not exceed OMEGA */
        end = hint[OMEGA + row];
        if (end < start || end > OMEGA) {
            return 1;
        }

        for (pos = start; pos < end; ++pos) {
            /* Coefficients are ordered for strong unforgeability */
            if (pos > start && hint[pos] <= hint[pos - 1]) {
                return 1;
            }
            h->vec[row].coeffs[hint[pos]] = 1;
        }

        start = end;
    }

    /* Extra indices are zero for strong unforgeability */
    for (pos = start; pos < OMEGA; ++pos) {
        if (hint[pos]) {
            return 1;
        }
    }

    return 0;
}