#include "../aead.h"

#include "../ske.h"
#include "../hmac.h"

#include <string.h>

/*
 * Total AEAD key length in bytes.
 *
 * The key is two SKE-sized halves. aead_enc/aead_dec use the second half
 * (offset SKE_K_L()) as the HMAC key. The whole key pointer is passed to
 * ske_enc/ske_dec, which presumably consume the first SKE_K_L() bytes.
 */
size_t AEAD_K_L(void) {
    return SKE_K_L() * 2;
}

/*
 * Authenticated encryption, encrypt-then-MAC.
 *
 * 1. Encrypts m with ske_enc.
 * 2. Computes an HMAC over the n_ads associated-data buffers followed by the
 *    SKE ciphertext.
 * 3. Emits ciphertext || tag.
 *
 * @param k      key material, AEAD_K_L() bytes. Bytes [SKE_K_L(), 2*SKE_K_L())
 *               are the HMAC key. The full pointer is handed to ske_enc
 *               (presumably using the first SKE_K_L() bytes).
 * @param ads    array of n_ads associated-data buffers (authenticated only,
 *               not encrypted).
 * @param n_ads  number of entries in ads; may be 0.
 * @param m      plaintext.
 * @param c      receives a newly allocated buffer of
 *               (ciphertext length + HMAC_OUTPUT_L()) bytes. The caller owns
 *               it (presumably released with free_bytes; confirm).
 */
void aead_enc(const uint8_t *k, bytes_i ads, size_t n_ads, bytes_i m, bytes_O c) {
    bytes enc;
    ske_enc(k, m, &enc);

    const_bytes k_hmac = {.p = k + SKE_K_L(), .l = SKE_K_L()};

    // MAC input order: all associated data, then the ciphertext.
    const_bytes inputs[n_ads + 1];
    for (size_t i = 0; i < n_ads; ++i) {
        inputs[i] = ads[i];
    }
    inputs[n_ads] = to_const_bytes(enc);

    const size_t tag_l = HMAC_OUTPUT_L();
    *c = alloc_bytes(enc.l + tag_l); // @TODO: use sparse output
    memcpy(c->p, enc.p, enc.l);
    // Write the tag straight into its final slot; no temporary buffer needed.
    hmac(&k_hmac, inputs, n_ads + 1, c->p + enc.l);

    // ske_enc produced an owned buffer; it has been copied into *c, so release
    // it here (previously leaked on every call).
    free_bytes(as_const_bytes(&enc));
}
/*
 * Returns 1 if the n bytes at a and b are equal, 0 otherwise.
 *
 * All n bytes are always examined, so the running time does not reveal the
 * position of the first mismatch. This matters for MAC verification, where
 * memcmp's early exit leaks how many leading tag bytes a forgery got right.
 */
static int aead_tag_equal_ct(const uint8_t *a, const uint8_t *b, size_t n) {
    volatile uint8_t diff = 0;
    for (size_t i = 0; i < n; ++i) {
        diff |= (uint8_t)(a[i] ^ b[i]);
    }
    return diff == 0;
}

/*
 * Authenticated decryption; the inverse of aead_enc.
 *
 * Expects c to hold ciphertext || tag, where the tag is HMAC_OUTPUT_L() bytes.
 * The tag is recomputed over the n_ads associated-data buffers followed by the
 * ciphertext, and compared in constant time. Only if it matches is the
 * ciphertext passed to ske_dec.
 *
 * @param k      key material, AEAD_K_L() bytes (same layout as aead_enc).
 * @param ads    array of n_ads associated-data buffers; must match the ones
 *               used for encryption.
 * @param n_ads  number of entries in ads; may be 0.
 * @param c      ciphertext || tag.
 * @param m      receives the plaintext from ske_dec on success. It is left
 *               untouched on failure.
 * @return 0 on success; 1 if c is too short to contain a tag or the tag does
 *         not verify.
 */
int aead_dec(const uint8_t *k, bytes_i ads, size_t n_ads, bytes_i c, bytes_O m) {
    const size_t tag_l = HMAC_OUTPUT_L();
    // Without this guard, c->l - tag_l wraps around for short inputs and the
    // HMAC/compare below read far past the end of c.
    if (c->l < tag_l) {
        return 1;
    }
    const size_t ct_l = c->l - tag_l;

    const_bytes k_hmac = {.p = k + SKE_K_L(), .l = SKE_K_L()};

    // MAC input order must mirror aead_enc: associated data, then ciphertext.
    const_bytes inputs[n_ads + 1];
    for (size_t i = 0; i < n_ads; ++i) {
        inputs[i] = ads[i];
    }
    inputs[n_ads] = (const_bytes){.p = c->p, .l = ct_l};

    bytes mac = alloc_bytes(tag_l);
    hmac(&k_hmac, inputs, n_ads + 1, mac.p);
    const int tag_ok = aead_tag_equal_ct(mac.p, c->p + ct_l, tag_l);
    free_bytes(as_const_bytes(&mac));
    if (!tag_ok) {
        return 1;
    }

    // NOTE(review): any result ske_dec may report is not checked here; confirm
    // it cannot fail once the tag has verified.
    ske_dec(k, &inputs[n_ads], m);
    return 0;
}
