#include "../kem.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <botan/ffi.h>

/* Serialized sizes, in bytes, used by this Botan-based KEM backend. */

/* Public key as written by kem_gen: the PEM text from botan_pubkey_export
 * (presumably 113 characters plus the NUL terminator Botan's FFI appends to
 * string output — kem_gen checks the exact length). */
static const size_t PK_L = 114;
/* Private key as written by kem_gen: the PEM text from botan_privkey_export
 * (presumably 119 characters plus the NUL terminator — checked by kem_gen). */
static const size_t SK_L = 120;
/* Ciphertext produced by kem_enc and consumed by kem_dec. */
static const size_t CT_L = 32;
/* Shared key produced by kem_enc and kem_dec. */
static const size_t K_L = 32;

/* Return the public-key length in bytes (PK_L). */
size_t KEM_PK_L(void) {
    return PK_L;
}

/* Return the private-key length in bytes (SK_L). */
size_t KEM_SK_L(void) {
    return SK_L;
}

/* Return the ciphertext length in bytes (CT_L). */
size_t KEM_CT_L(void) {
    return CT_L;
}

/* Return the shared-key length in bytes (K_L). */
size_t KEM_K_L(void) {
    return K_L;
}

/* Error handling for every Botan FFI call in this file's KEM functions.
 * The KEM API returns void, so any failure is fatal. Plain
 * assert(call() == 0) is deliberately avoided: with NDEBUG defined it would
 * compile the Botan call itself away, not just the check. */

/* Report rc on stderr and abort unless a Botan FFI call returned 0. */
static void kem_ffi_check(int rc) {
    if (rc != 0) {
        fprintf(stderr, "kem: botan FFI call failed (%d)\n", rc);
        abort();
    }
}

/* Report and abort unless a length equals the expected one. */
static void kem_len_check(size_t got, size_t expected) {
    if (got != expected) {
        fprintf(stderr, "kem: unexpected length %zu (expected %zu)\n", got, expected);
        abort();
    }
}

/* Byte length of a raw X25519 public value. */
enum { KEM_X25519_LEN = 32 };

/* Derive the K_L-byte shared key k from X25519(own, peer_pub).
 *
 * The DH output is fed to HKDF(SHA-256) with salt = ct || pk_R (the ephemeral
 * public value followed by the recipient's raw public value), so the key is
 * bound to both public values. kem_enc and kem_dec build the same salt. */
static void kem_x25519_derive(botan_privkey_t own, const uint8_t peer_pub[KEM_X25519_LEN],
                              const uint8_t salt[2 * KEM_X25519_LEN], uint8_t *k) {
    botan_pk_op_ka_t ka;
    kem_ffi_check(botan_pk_op_key_agreement_create(&ka, own, "HKDF(SHA-256)", 0));

    /* *out_len on input is the requested KDF output length. */
    size_t k_l = K_L;
    kem_ffi_check(botan_pk_op_key_agreement(ka, k, &k_l, peer_pub, KEM_X25519_LEN,
                                            salt, 2 * KEM_X25519_LEN));
    kem_len_check(k_l, K_L);

    kem_ffi_check(botan_pk_op_key_agreement_destroy(ka));
}

/* Generate a fresh X25519 key pair.
 *
 * pk: receives exactly PK_L bytes, the PEM public key from botan_pubkey_export.
 * sk: receives exactly SK_L bytes, the PEM private key from botan_privkey_export.
 * Aborts on any Botan failure or if either encoding is not the expected length. */
void kem_gen(uint8_t *pk, uint8_t *sk) {
    botan_rng_t rng;
    kem_ffi_check(botan_rng_init(&rng, NULL));
    botan_privkey_t bsk;
    kem_ffi_check(botan_privkey_create_ecdh(&bsk, rng, "curve25519"));
    kem_ffi_check(botan_rng_destroy(rng));

    botan_pubkey_t bpk;
    kem_ffi_check(botan_privkey_export_pubkey(&bpk, bsk));
    size_t pk_l = PK_L;
    kem_ffi_check(botan_pubkey_export(bpk, pk, &pk_l, BOTAN_PRIVKEY_EXPORT_FLAG_PEM));
    kem_len_check(pk_l, PK_L);
    kem_ffi_check(botan_pubkey_destroy(bpk));

    size_t sk_l = SK_L;
    kem_ffi_check(botan_privkey_export(bsk, sk, &sk_l, BOTAN_PRIVKEY_EXPORT_FLAG_PEM));
    kem_len_check(sk_l, SK_L);
    kem_ffi_check(botan_privkey_destroy(bsk));
}

/* Encapsulate a fresh shared key to the public key pk.
 *
 * Botan's FFI offers no ECIES/KEM for EC keys: botan_pk_op_encrypt_create on
 * an ECDH key yields BOTAN_FFI_ERROR_NOT_IMPLEMENTED (-40). The KEM is
 * therefore built from ephemeral-static X25519 key agreement:
 *   ct = raw public value of a fresh ephemeral key e
 *   k  = HKDF(SHA-256)(X25519(e, pk_R), salt = ct || pk_R)
 *
 * pk: PK_L-byte PEM public key as produced by kem_gen.
 * ct: out, CT_L bytes (must equal one raw X25519 public value).
 * k:  out, K_L bytes.
 * Aborts on any failure. */
void kem_enc(const uint8_t *pk, uint8_t *ct, uint8_t *k) {
    /* The salt and ciphertext layout below assume ct is one raw public value. */
    kem_len_check(CT_L, KEM_X25519_LEN);

    /* Recipient's raw public value, extracted from its PEM encoding. */
    uint8_t pk_r[KEM_X25519_LEN];
    botan_pubkey_t bpk;
    kem_ffi_check(botan_pubkey_load(&bpk, pk, PK_L));
    kem_ffi_check(botan_pubkey_x25519_get_pubkey(bpk, pk_r));
    kem_ffi_check(botan_pubkey_destroy(bpk));

    /* Fresh ephemeral key pair; its raw public value is the ciphertext. */
    botan_rng_t rng;
    kem_ffi_check(botan_rng_init(&rng, NULL));
    botan_privkey_t eph;
    kem_ffi_check(botan_privkey_create_ecdh(&eph, rng, "curve25519"));
    kem_ffi_check(botan_rng_destroy(rng));

    size_t ct_l = CT_L;
    kem_ffi_check(botan_pk_op_key_agreement_export_public(eph, ct, &ct_l));
    kem_len_check(ct_l, CT_L);

    uint8_t salt[2 * KEM_X25519_LEN];
    memcpy(salt, ct, KEM_X25519_LEN);
    memcpy(salt + KEM_X25519_LEN, pk_r, KEM_X25519_LEN);

    kem_x25519_derive(eph, pk_r, salt, k);
    kem_ffi_check(botan_privkey_destroy(eph));
}

/* Decapsulate: recover the shared key k from ciphertext ct using sk.
 *
 * Computes X25519(sk, ct) and derives k exactly as kem_enc does, with
 * salt = ct || (own raw public value).
 *
 * sk: SK_L-byte PEM private key as produced by kem_gen.
 * ct: CT_L-byte ciphertext from kem_enc.
 * k:  out, K_L bytes.
 * Aborts on any failure. */
void kem_dec(const uint8_t *sk, const uint8_t *ct, uint8_t *k) {
    kem_len_check(CT_L, KEM_X25519_LEN);

    botan_privkey_t bsk;
    kem_ffi_check(botan_privkey_load(&bsk, NULL, sk, SK_L, NULL));

    /* Same salt layout as kem_enc: ct || pk_R, where pk_R is our own public value. */
    uint8_t salt[2 * KEM_X25519_LEN];
    memcpy(salt, ct, KEM_X25519_LEN);
    size_t pub_l = KEM_X25519_LEN;
    kem_ffi_check(botan_pk_op_key_agreement_export_public(bsk, salt + KEM_X25519_LEN, &pub_l));
    kem_len_check(pub_l, KEM_X25519_LEN);

    kem_x25519_derive(bsk, ct, salt, k);
    kem_ffi_check(botan_privkey_destroy(bsk));
}
