#include "signal.h"
#include <stdlib.h>

// splay tree dictionary
struct __Signal_dictionary {
    // First component of the node's key; compared before `i` when ordering nodes.
    size_t t;
    // Second component of the node's key; breaks ties between equal `t`.
    size_t i;
    // Value stored under (t, i).
    const_bytes k;
    // Left child: nodes whose (t, i) orders before this node's.
    Signal_dictionary l;
    // Right child: nodes whose (t, i) orders after this node's.
    Signal_dictionary r;
};
// Make *D an empty dictionary (a NULL root).
static void Signal_dictionary_init(Signal_dictionary *D) {
    Signal_dictionary empty = NULL;
    *D = empty;
}
// Insert the value *k under key (t, i) into the tree rooted at *D.
//
// Keys are ordered by t first, then by i. The key must not already be
// present (checked with assert on every node visited on the way down).
// The struct *k is copied into the new node; the bytes it points to are
// not duplicated.
//
// The descent is iterative: callers insert keys in increasing order, which
// produces a long right-leaning chain, and a recursive descent would use one
// stack frame per stored key. Allocation failure aborts the process instead
// of dereferencing NULL.
static void Signal_dictionary_push(Signal_dictionary *D, size_t t, size_t i, const_bytes *k) {
    // `link` always points at the child slot (or the root slot) being examined.
    Signal_dictionary *link = D;
    while (*link != NULL) {
        Signal_dictionary cur = *link;
        assert(t != cur->t || i != cur->i);
        if (t < cur->t || (t == cur->t && i < cur->i)) {
            link = &cur->l;
        } else {
            link = &cur->r;
        }
    }
    Signal_dictionary v = malloc(sizeof *v);
    if (v == NULL) {
        abort();
    }
    v->t = t;
    v->i = i;
    v->k = *k;
    v->l = NULL;
    v->r = NULL;
    *link = v;
}
// Look up key (t, i) in the tree rooted at *D.
//
// Returns the value stored under the key, or a const_bytes with p == NULL and
// l == 0 when the key is absent (the tree is left untouched in that case).
//
// On a hit, each recursion level rotates the child that now holds the match
// up into its own position, so the matching node ends up at *D.
//
// If `remove` is non-zero the matching node is then unlinked and freed; the
// value struct it held is returned and its bytes are NOT freed here.
// If `remove` is zero the node stays in the tree (as the new root of *D) and
// the returned struct is a copy of the node's `k`.
static const_bytes Signal_dictionary_pop(Signal_dictionary *D, size_t t, size_t i, int remove) {
    Signal_dictionary v = *D;
    if (v == NULL) {
        // Empty subtree: report "not found".
        const_bytes k = {.p = NULL, .l = 0};
        return k;
    }
    const_bytes k;
    if (t == v->t && i == v->i) {
        // Match at this node; no rotation needed at this level.
        k = v->k;
    } else if (t < v->t || (t == v->t && i < v->i)) {
        // Search left without removing; removal only happens at the outermost call.
        k = Signal_dictionary_pop(&v->l, t, i, 0);
        if (k.p == NULL) {
            return k;
        }
        // The match is now the root of the left subtree: rotate it above v.
        Signal_dictionary l = v->l;
        v->l = l->r;
        l->r = v;
        *D = v = l;
    } else {
        // Mirror image of the branch above for the right subtree.
        k = Signal_dictionary_pop(&v->r, t, i, 0);
        if (k.p == NULL) {
            return k;
        }
        Signal_dictionary r = v->r;
        v->r = r->l;
        r->l = v;
        *D = v = r;
    }
    if (remove) {
        // v is the matching node and is the root of *D at this point.
        Signal_dictionary l = v->l, r = v->r;
        if (l == NULL) {
            *D = r;
        } else {
            // Left subtree becomes the root; hang the right subtree off its
            // rightmost node.
            *D = l;
            if (r != NULL) {
                while (l->r != NULL) {
                    l = l->r;
                }
                l->r = r;
            }
        }
        // Only the node is released; v->k was copied into k above.
        free(v);
    }
    return k;
}
// Release every node of the tree rooted at *D together with the value stored
// in each node, and leave *D == NULL so the caller is not left holding a
// dangling root (a second call is then a harmless no-op).
//
// Left subtrees are freed recursively; the right child is followed in a loop
// so a right-leaning chain does not cost one stack frame per node.
static void Signal_dictionary_free(Signal_dictionary *D) {
    Signal_dictionary v = *D;
    *D = NULL;
    while (v != NULL) {
        // Remember the right child before the node itself goes away.
        Signal_dictionary next = v->r;
        Signal_dictionary_free(&v->l);
        free_bytes(&v->k);
        free(v);
        v = next;
    }
}

#include "cka.h"
#include "prgf.h"
#include "prg.h"
#include "aead.h"

// cf. https://signal.org/docs/specifications/doubleratchet/: 32-byte root key & chain key
// SIGMA_L sizes state->sigma; OMEGA_L sizes state->omega_send / state->omega_recv.
static const size_t SIGMA_L = 32;
static const size_t OMEGA_L = 32;

// All-zero 32-byte input fed to the first prgf_update during initialisation.
// `{0}` rather than `{}`: an empty initializer list is not valid ISO C before C23.
static const uint8_t lambda[32] = {0};
static const_bytes lambda_bytes = {.p = lambda, .l = sizeof(lambda)};

// Initialisation shared by the sending and receiving party.
//
// Allocates sigma and both omega buffers, seeds sigma from k_root via
// prgf_init, marks every ciphertext slot in T as empty, zeroes all counters
// and starts with an empty skipped-key dictionary. `id` and `gamma` are left
// for the caller to set.
static void Signal_init_common(bytes_i k_root, Signal_state *state) {
    state->sigma = alloc_bytes(SIGMA_L);
    prgf_init(k_root, &state->sigma);
    state->omega_send = alloc_bytes(OMEGA_L);
    state->omega_recv = alloc_bytes(OMEGA_L);
    // Every other user of T in this file walks CKA_CT_N() slots and tests
    // .p against NULL, so clear exactly that many slots here.
    for (size_t i = 0; i < CKA_CT_N(); ++i) {
        state->T[i].p = NULL;
    }
    state->l = 0;
    state->t = 0;
    state->i_send = 0;
    state->i_recv = 0;
    Signal_dictionary_init(&state->D);
}

// Set up `state` for the party with id 0: after the common initialisation,
// one prgf_update over the zero input `lambda_bytes` fills omega_send and
// refreshes sigma, and the CKA state gamma is created with cka_init_send.
void Signal_init_send(bytes_i k_root, bytes_i k_CKA, Signal_state *state) {
    Signal_init_common(k_root, state);
    prgf_update(
        as_const_bytes(&state->sigma),
        &lambda_bytes,
        &state->omega_send,
        &state->sigma);
    cka_init_send(k_CKA, &state->gamma);
    state->id = 0;
}
// Set up `state` for the party with id 1: after the common initialisation,
// one prgf_update over the zero input `lambda_bytes` fills omega_recv and
// refreshes sigma, and the CKA state gamma is created with cka_init_recv.
void Signal_init_recv(bytes_i k_root, bytes_i k_CKA, Signal_state *state) {
    Signal_init_common(k_root, state);
    prgf_update(
        as_const_bytes(&state->sigma),
        &lambda_bytes,
        &state->omega_recv,
        &state->sigma);
    cka_init_recv(k_CKA, &state->gamma);
    state->id = 1;
}

// Release everything owned by `state`: the CKA state, any ciphertext slots in
// T that are populated, the root and chain buffers, and the dictionary of
// skipped keys. The Signal_state object itself is not freed.
void Signal_free(Signal_state *state) {
    free_bytes(as_const_bytes(&state->gamma));

    size_t slot = 0;
    while (slot < CKA_CT_N()) {
        // Slots that were never filled still have p == NULL and are skipped.
        if (state->T[slot].p != NULL) {
            free_bytes(as_const_bytes(&state->T[slot]));
        }
        ++slot;
    }

    free_bytes(as_const_bytes(&state->sigma));
    free_bytes(as_const_bytes(&state->omega_send));
    free_bytes(as_const_bytes(&state->omega_recv));

    Signal_dictionary_free(&state->D);
}

// Advance the receiving chain until i_recv reaches `l`, storing each message
// key produced on the way in the dictionary under (t, index).
// Does nothing when i_recv is already >= l.
void skip(Signal_state *state, size_t t, size_t l) {
    while (state->i_recv < l) {
        bytes msg_key = alloc_bytes(AEAD_K_L());
        // prg writes its two outputs into omega_recv (in place) and msg_key.
        bytes outputs[] = {state->omega_recv, msg_key};
        prg(as_const_bytes(&state->omega_recv), outputs, 2);
        state->i_recv += 1;
        // The dictionary keeps the key struct; msg_key is not freed here.
        Signal_dictionary_push(&state->D, t, state->i_recv, as_const_bytes(&msg_key));
    }
}
// Fetch and remove the stored key for (t, i) from the skipped-key dictionary.
// The result has p == NULL when no such key is stored.
const_bytes try_skipped(Signal_state *state, size_t t, size_t i) {
    const int remove_on_hit = 1;
    const_bytes stored = Signal_dictionary_pop(&state->D, t, i, remove_on_hit);
    return stored;
}

// Encrypt `m` into `c` and fill in the header `h` that must accompany it.
//
// When the epoch counter's parity equals this party's id, a new epoch is
// started first: the counters roll over, the previous CKA ciphertext in T is
// replaced via cka_send, and sigma / omega_send are refreshed from the CKA
// output. Every call then derives one message key from omega_send and
// encrypts with the header (T copies, l, t, i) as associated data.
// `h->T[]` receives fresh copies made with copy_bytes.
void Signal_send(Signal_state *state, bytes_i m, Signal_header *h, bytes_O c) {
    if (state->t % 2 == state->id) {
        // Start a new sending epoch; remember how many messages the previous one had.
        ++state->t;
        state->l = state->i_send;
        state->i_send = 0;
        for (size_t slot = 0; slot < CKA_CT_N(); ++slot) {
            if (state->T[slot].p != NULL) {
                free_bytes(as_const_bytes(&state->T[slot]));
            }
        }
        bytes cka_out, next_gamma;
        cka_send(state->id, as_const_bytes(&state->gamma), (bytes *)state->T, &cka_out, &next_gamma);
        free_bytes(as_const_bytes(&state->gamma));
        state->gamma = next_gamma;
        prgf_update(as_const_bytes(&state->sigma), as_const_bytes(&cka_out), &state->omega_send, &state->sigma);
        free_bytes(as_const_bytes(&cka_out));
    }

    // Step the sending chain: prg overwrites omega_send and fills msg_key.
    ++state->i_send;
    bytes msg_key = alloc_bytes(AEAD_K_L());
    bytes outputs[] = {state->omega_send, msg_key};
    prg(as_const_bytes(&state->omega_send), outputs, 2);

    // Associated data = the T entries followed by the three header counters.
    size_t n = CKA_CT_N() + 3;
    const_bytes ads[n];
    for (size_t slot = 0; slot < CKA_CT_N(); ++slot) {
        h->T[slot] = copy_bytes(as_const_bytes(&state->T[slot]));
        ads[slot] = to_const_bytes(h->T[slot]);
    }
    h->l = state->l;
    h->t = state->t;
    h->i = state->i_send;
    ads[n - 3] = (const_bytes){.p = (const uint8_t *)&h->l, .l = sizeof(size_t)};
    ads[n - 2] = (const_bytes){.p = (const uint8_t *)&h->t, .l = sizeof(size_t)};
    ads[n - 1] = (const_bytes){.p = (const uint8_t *)&h->i, .l = sizeof(size_t)};

    aead_enc(msg_key.p, ads, n, m, c);
    free_bytes(as_const_bytes(&msg_key));
}

// Decrypt ciphertext `c` (with header `h`) into `m`, advancing the receiving
// side of the ratchet as needed.
//
// Returns 0xffff when the header is rejected without attempting decryption;
// otherwise returns whatever aead_dec returns (the original code treated
// non-zero as failure).
//
// *use_cka_free is always written: true only when this call passed h->T to
// cka_recv (new epoch), false on every other path including early rejection.
//
// A header is rejected when
//   - h->t has the wrong parity for this party or is more than one epoch ahead,
//   - h->i is 0 (Signal_send numbers messages from 1; accepting 0 would make
//     `h->i - 1` wrap around and skip() would run practically forever),
//   - no stored key matches and the message is not ahead of the live receive
//     chain (a replay or a stale epoch). Deriving a key for such a message
//     would consume a chain step that belongs to a later message and
//     desynchronise the session.
//
// NOTE(review): state is still advanced before the message is authenticated,
// and the number of keys skip() may store is unbounded (h->l / h->i are taken
// from the unauthenticated header). Consider a MAX_SKIP limit.
int Signal_recv(Signal_state *state, bytes_i c, const Signal_header *h, bool *use_cka_free, bytes_O m) {
    *use_cka_free = false;
    if (h->t % 2 != state->id || h->t > state->t + 1 || h->i == 0) {
        return 0xffff;
    }
    if (h->t == state->t + 1) {
        // New epoch from the peer: stash the unread keys of the previous
        // receiving epoch (h->t - 2), then ratchet forward with the CKA output.
        skip(state, h->t - 2, h->l);
        ++state->t;
        state->i_recv = 0;
        bytes I, gamma;
        cka_recv(state->id, as_const_bytes(&state->gamma), as_const_bytes((bytes *)h->T), &I, &gamma);
        *use_cka_free = true;
        free_bytes(as_const_bytes(&state->gamma));
        state->gamma = gamma;
        prgf_update(as_const_bytes(&state->sigma), as_const_bytes(&I), &state->omega_recv, &state->sigma);
        free_bytes(as_const_bytes(&I));
    }
    const_bytes K = try_skipped(state, h->t, h->i);
    if (K.p == NULL) {
        // omega_recv belongs to the most recent epoch with this party's
        // receive parity: state->t itself, or state->t - 1 once we have
        // started sending in state->t.
        const size_t recv_epoch = (state->t % 2 == state->id) ? state->t : state->t - 1;
        if (h->t != recv_epoch || h->i <= state->i_recv) {
            return 0xffff;
        }
        // Store the keys of the messages we jump over, then derive this one.
        skip(state, h->t, h->i - 1);
        ++state->i_recv;
        bytes _K = alloc_bytes(AEAD_K_L());
        bytes r[] = {state->omega_recv, _K};
        prg(as_const_bytes(&state->omega_recv), r, 2);
        K = to_const_bytes(_K);
    }
    // Associated data = the T entries followed by the three header counters,
    // mirroring Signal_send.
    size_t n = CKA_CT_N() + 3;
    const_bytes ads[n];
    for (size_t i = 0; i < CKA_CT_N(); ++i) {
        ads[i] = to_const_bytes(h->T[i]);
    }
    ads[CKA_CT_N() + 0] = (const_bytes){.p = (const uint8_t *)&h->l, .l = sizeof(size_t)};
    ads[CKA_CT_N() + 1] = (const_bytes){.p = (const uint8_t *)&h->t, .l = sizeof(size_t)};
    ads[CKA_CT_N() + 2] = (const_bytes){.p = (const uint8_t *)&h->i, .l = sizeof(size_t)};
    int result = aead_dec(K.p, ads, n, c, m);
    // The message key is single-use whether or not decryption succeeded.
    free_bytes(&K);
    return result;
}

// Exported flag, not read anywhere in this file.
// NOTE(review): presumably tells callers whether a header may be released
// right after Signal_send — confirm against the users of this symbol.
const bool SIGNAL_FREE_HEADER_AT_SEND = false;

// Release the T entries of header `h`.
// With use_cka_free set, the whole array is handed to cka_free_ct; otherwise
// each of the CKA_CT_N() entries is released individually with free_bytes.
void Signal_free_header(Signal_header *h, bool use_cka_free) {
    if (!use_cka_free) {
        size_t slot = 0;
        while (slot < CKA_CT_N()) {
            free_bytes(as_const_bytes(&h->T[slot]));
            ++slot;
        }
        return;
    }
    cka_free_ct(as_const_bytes((bytes *)h->T));
}
