#include "../hkdf.h"

#include "../hmac.h"

/* Length in bytes of the HKDF pseudorandom key; equal to the HMAC output size. */
size_t HKDF_K_L(void) {
    return HMAC_OUTPUT_L();
}

/*
 * HKDF-Extract: key = HMAC(salt, input).
 * `input` is handed to hmac() as an array holding a single chunk.
 * `key` must have room for HKDF_K_L() bytes (the HMAC output size).
 */
void hkdf_extract(bytes_i salt, bytes_i input, uint8_t *key) {
    const size_t chunk_count = 1;
    hmac(salt, input, chunk_count, key);
}
/*
 * HKDF-Expand: fills each outputs[i] (outputs[i].l bytes) with key material
 * derived from `key` (HKDF_K_L() bytes, e.g. the result of hkdf_extract) and
 * `info`.
 *
 * Each block is HMAC(key, T(prev) | info | counter), where T(prev) is the
 * previous block (empty for the first block). The one-byte counter starts at
 * 0 and is incremented after every block. Blocks are chained across all
 * outputs, but every output starts on a fresh block: unused trailing bytes of
 * the last block used for outputs[i] are discarded. Zero-length outputs
 * consume no blocks.
 *
 * At most UINT8_MAX + 1 blocks may be produced in total; exceeding that trips
 * an assertion because the counter byte would repeat.
 *
 * NOTE(review): RFC 5869 numbers blocks from 0x01 (at most 255 blocks), so this
 * counter start differs from the RFC and its test vectors; confirm this is
 * intentional before relying on interop with other HKDF implementations.
 * NOTE(review): the previous block is passed to hmac() from the same buffer
 * hmac() writes its result into; this assumes hmac() reads all inputs before
 * writing its output. Confirm against the hmac implementation.
 */
void hkdf_expand(const uint8_t *key, bytes_i info, bytes_o outputs, size_t n_outputs) {
    const_bytes _key = {.p = key, .l = HKDF_K_L()};
    bytes output = alloc_bytes(HMAC_OUTPUT_L());
    uint8_t index = 0;
    // Number of blocks produced so far. Unlike `index`, this never wraps, so it
    // can tell "nothing produced yet" apart from "counter overflowed".
    size_t n_blocks = 0;
    const_bytes inputs[3] = {{.p = NULL, .l = 0}, *info, {.p = &index, .l = 1}};
    for (size_t i = 0; i < n_outputs; ++i) {
        size_t l = 0;
        while (l < outputs[i].l) {
            assert(n_blocks <= UINT8_MAX && "hkdf_expand: counter overflow");
            hmac(&_key, inputs, 3, output.p);
            // Chain this block into the next one. Done after every block
            // (not only the very first (i, l)) so that a zero-length first
            // output cannot leave T(prev) empty for later blocks.
            inputs[0] = to_const_bytes(output);
            ++index;
            ++n_blocks;
            size_t remaining = outputs[i].l - l;
            size_t take = remaining < output.l ? remaining : output.l;
            memcpy(outputs[i].p + l, output.p, take);
            l += take;
        }
    }
    free_bytes(as_const_bytes(&output));
}

/*
 * Best-effort zeroization of secret bytes. The stores go through a volatile
 * pointer so the compiler cannot drop them as dead stores, which it may do
 * with a plain memset on a buffer that is about to go out of scope.
 */
static void hkdf_wipe(uint8_t *p, size_t n) {
    volatile uint8_t *vp = p;
    for (size_t i = 0; i < n; ++i) {
        vp[i] = 0;
    }
}

/*
 * One-shot HKDF: extracts an HKDF_K_L()-byte key from `salt` and `input`, then
 * expands it with `info` into outputs[0..n_outputs-1] (see hkdf_expand for how
 * the outputs are filled). The intermediate key is kept on the stack and wiped
 * before returning.
 */
void hkdf(bytes_i salt, bytes_i input, bytes_i info, bytes_o outputs, size_t n_outputs) {
    uint8_t k[HKDF_K_L()];
    hkdf_extract(salt, input, k);
    hkdf_expand(k, info, outputs, n_outputs);
    hkdf_wipe(k, sizeof k);
}
