Skip to content

src/crypto/hybrid_kem.c

/*
 * QR-NSP Volcanic Edition — Hybrid KEM
 * SPDX-License-Identifier: AGPL-3.0-or-later
 * ML-KEM-1024 ⊕ X25519 — dual classical/quantum security
 *
 * Combiner: ss = SHA3-256(ss_mlkem || ss_x25519 || ct_mlkem || pk_x25519_peer)
 *
 * This follows the NIST SP 800-227 (draft) hybrid KEM approach:
 * both component KEMs must be broken to recover the shared secret.
 *
 * Wire format:
 *   Public key:  [ML-KEM pk (1568)] [X25519 pk (32)]  = 1600 bytes
 *   Secret key:  [ML-KEM sk (3168)] [X25519 sk (32)]  = 3200 bytes
 *   Ciphertext:  [ML-KEM ct (1568)] [X25519 pk (32)]  = 1600 bytes
 *   Shared secret: 32 bytes
 */

#include "mlkem_params.h"
#include <string.h>

/* ─────────────────────────────────────────────
 * X25519 — Curve25519 Diffie-Hellman
 *
 * Minimal implementation using radix-2^51 representation.
 * For production: use monocypher or libsodium.
 * This is a compact reference for zero-dependency compilation.
 * ───────────────────────────────────────────── */

#define X25519_BYTES 32

/*
 * Element of GF(2^255 - 19) held as five signed 64-bit limbs;
 * limb i carries weight 2^(51*i).
 */
typedef int64_t fe25519[5];

/* Low 51 bits set: isolates one limb's worth of bits during carrying. */
static const int64_t reduce_mask_51 = INT64_C(0x7FFFFFFFFFFFF);

/*
 * Ladder constant a24 = (A - 2) / 4 for Curve25519's Montgomery
 * coefficient A = 486662, i.e. 121665 (used as z2 = E * (AA + a24 * E)).
 */
static const int64_t a24 = 121665;

/* Read 8 bytes at b as a little-endian unsigned 64-bit integer. */
static inline uint64_t
fe_load64(const uint8_t *b)
{
    uint64_t acc = 0;
    /* Walk from the most significant byte down, shifting earlier bytes up. */
    for (int k = 7; k >= 0; k--)
        acc = (acc << 8) | (uint64_t)b[k];
    return acc;
}

/*
 * Decode a 32-byte little-endian string into five 51-bit limbs.
 * Limb i holds bits [51*i, 51*i + 51); each is cut from a 64-bit window that
 * starts at byte offset[i], shifted right by shift[i]. Bit 255 of the input
 * falls outside every limb and is therefore ignored.
 */
static void
fe_frombytes(fe25519 h, const uint8_t s[32])
{
    static const int offset[5] = { 0, 6, 12, 19, 24 };
    static const int shift[5]  = { 0, 3, 6, 1, 12 };

    for (int i = 0; i < 5; i++) {
        uint64_t window = fe_load64(s + offset[i]);
        h[i] = (int64_t)((window >> shift[i]) & (uint64_t)reduce_mask_51);
    }
}

/*
 * One carry pass over five 51-bit limbs: push each limb's excess bits into the
 * next limb, leaving limbs 0..3 in [0, 2^51). The carry out of limb 4 (weight
 * 2^255) is folded back into limb 0 times 19 when fold != 0, since
 * 2^255 = 19 (mod p). When fold == 0 it is discarded, which reduces the value
 * modulo 2^255.
 * Relies on >> of negative int64_t being an arithmetic shift, as the rest of
 * this file already does.
 */
static void
fe_carry51(int64_t t[5], int fold)
{
    for (int i = 0; i < 4; i++) {
        t[i + 1] += t[i] >> 51;
        t[i] &= reduce_mask_51;
    }
    const int64_t top = t[4] >> 51;
    t[4] &= reduce_mask_51;
    if (fold)
        t[0] += 19 * top;
}

/*
 * Serialize a field element as its canonical 32-byte little-endian encoding,
 * i.e. the unique representative in [0, p), p = 2^255 - 19.
 *
 * Accepts signed limbs of magnitude below roughly 2^62 (e.g. fe_mul /
 * fe_add / fe_sub results).
 *   1. Two folding carry passes leave every limb in [0, 2^51), so the value x
 *      lies in [0, 2^255).
 *   2. Adding 19 and folding overflows past 2^255 exactly when x >= p, in
 *      which case it yields x - p + 19. Otherwise it yields x + 19.
 *   3. Adding 2^255 - 19 and dropping bit 255 removes the 19 again.
 * The result is x mod p, computed without any data-dependent branch.
 */
static void
fe_tobytes(uint8_t s[32], const fe25519 h)
{
    int64_t t[5];
    memcpy(t, h, sizeof(t));

    fe_carry51(t, 1);
    fe_carry51(t, 1);

    t[0] += 19;
    fe_carry51(t, 1);

    t[0] += reduce_mask_51 - 18;          /* 2^51 - 19 */
    for (int i = 1; i < 5; i++)
        t[i] += reduce_mask_51;           /* 2^51 - 1 */
    fe_carry51(t, 0);                     /* drop the 2^255 offset */

    /* Pack the 5 x 51 = 255 bits into four little-endian 64-bit words. */
    const uint64_t w[4] = {
        (uint64_t)t[0]         | ((uint64_t)t[1] << 51),
        ((uint64_t)t[1] >> 13) | ((uint64_t)t[2] << 38),
        ((uint64_t)t[2] >> 26) | ((uint64_t)t[3] << 25),
        ((uint64_t)t[3] >> 39) | ((uint64_t)t[4] << 12),
    };
    for (int k = 0; k < 4; k++)
        for (int i = 0; i < 8; i++)
            s[8 * k + i] = (uint8_t)(w[k] >> (8 * i));
}

/*
 * Limb-wise sum h = f + g. No carrying or reduction is done here.
 * h may alias f or g because each limb is read before it is written.
 */
static void
fe_add(fe25519 h, const fe25519 f, const fe25519 g)
{
    h[0] = f[0] + g[0];
    h[1] = f[1] + g[1];
    h[2] = f[2] + g[2];
    h[3] = f[3] + g[3];
    h[4] = f[4] + g[4];
}

/*
 * Limb-wise difference h = f - g. Limbs may become negative; no carrying or
 * reduction is done here. h may alias f or g.
 */
static void
fe_sub(fe25519 h, const fe25519 f, const fe25519 g)
{
    h[0] = f[0] - g[0];
    h[1] = f[1] - g[1];
    h[2] = f[2] - g[2];
    h[3] = f[3] - g[3];
    h[4] = f[4] - g[4];
}

/*
 * Field multiplication h = f * g mod p, p = 2^255 - 19.
 *
 * Schoolbook 5x5 product accumulated in signed 128-bit integers. A partial
 * product that lands in limb i+j >= 5 is folded into limb i+j-5 times 19,
 * because 2^255 = 19 (mod p). h may alias f and/or g: all inputs are consumed
 * into r[] before h is written.
 *
 * Inputs: signed limbs of magnitude up to about 2^53 (sums/differences of
 * earlier results). Output: limbs 0, 2, 3, 4 lie in [0, 2^51); limb 1 may
 * exceed that range by a small carry. Outputs therefore stay bounded when
 * fed back in repeatedly (e.g. the 254 squarings in fe_invert).
 *
 * All carries are done in 128 bits. The carry out of limb 4 is folded into
 * limb 0, and limb 0 is then carried once more. Without that second carry,
 * limb 0 grows roughly 38x per call and overflows within a few calls.
 * Relies on >> of negative __int128 being an arithmetic shift (GCC/Clang).
 */
static void
fe_mul(fe25519 h, const fe25519 f, const fe25519 g)
{
    __int128 r[5] = {0};
    for (int i = 0; i < 5; i++) {
        for (int j = 0; j < 5; j++) {
            const __int128 prod = (__int128)f[i] * g[j];
            const int idx = i + j;
            if (idx >= 5)
                r[idx - 5] += prod * 19;
            else
                r[idx] += prod;
        }
    }

    for (int i = 0; i < 4; i++) {
        r[i + 1] += r[i] >> 51;
        r[i] &= reduce_mask_51;
    }
    r[0] += (r[4] >> 51) * 19;
    r[4] &= reduce_mask_51;

    /* Second carry: bring the folded limb 0 back under 2^51. */
    r[1] += r[0] >> 51;
    r[0] &= reduce_mask_51;

    for (int i = 0; i < 5; i++)
        h[i] = (int64_t)r[i];
}

/*
 * Field squaring h = f^2 mod p. There is no dedicated squaring formula: this
 * passes f as both operands of the general multiplier. h may alias f.
 */
static void
fe_sq(fe25519 h, const fe25519 f)
{
    const int64_t *base = f;
    fe_mul(h, base, base);
}

/*
 * Multiply a field element by an integer: h = f * s mod p.
 *
 * Products are formed in 128 bits. The original 64-bit form overflowed:
 * limbs near 2^51 times the ladder constant 121665 (about 2^17) exceed
 * INT64_MAX, which is undefined behavior. The 128-bit form is safe for any
 * int64_t s.
 * Carries follow the same scheme as fe_mul, including the second carry out of
 * limb 0, so the output has the same bounds as fe_mul's. h may alias f.
 */
static void
fe_mul_scalar(fe25519 h, const fe25519 f, int64_t s)
{
    __int128 r[5];
    for (int i = 0; i < 5; i++)
        r[i] = (__int128)f[i] * s;

    for (int i = 0; i < 4; i++) {
        r[i + 1] += r[i] >> 51;
        r[i] &= reduce_mask_51;
    }
    r[0] += (r[4] >> 51) * 19;
    r[4] &= reduce_mask_51;
    r[1] += r[0] >> 51;
    r[0] &= reduce_mask_51;

    for (int i = 0; i < 5; i++)
        h[i] = (int64_t)r[i];
}

/*
 * Field inversion via Fermat's little theorem: out = z^(p-2) = z^(-1) mod p,
 * where p = 2^255 - 19. An input of 0 yields 0.
 *
 * Uses a fixed addition chain of 254 squarings and 11 multiplications. The
 * sequence of operations does not depend on z. Trailing comments give the
 * exponent of z held in each temporary after the step.
 * out may alias z: z is only read in the first four steps.
 */
static void
fe_invert(fe25519 out, const fe25519 z)
{
    fe25519 t0, t1, t2, t3;
    int i;

    fe_sq(t0, z);                                   /* t0 = z^2 */
    fe_sq(t1, t0);                                  /* t1 = z^4 */
    fe_sq(t1, t1);                                  /* t1 = z^8 */
    fe_mul(t1, z, t1);                              /* t1 = z^9 */
    fe_mul(t0, t0, t1);                             /* t0 = z^11 */
    fe_sq(t2, t0);                                  /* t2 = z^22 */
    fe_mul(t1, t1, t2);                             /* t1 = z^(2^5 - 1) */
    fe_sq(t2, t1);
    for (i = 1; i < 5; i++) fe_sq(t2, t2);          /* t2 = z^(2^10 - 2^5) */
    fe_mul(t1, t2, t1);                             /* t1 = z^(2^10 - 1) */
    fe_sq(t2, t1);
    for (i = 1; i < 10; i++) fe_sq(t2, t2);         /* t2 = z^(2^20 - 2^10) */
    fe_mul(t2, t2, t1);                             /* t2 = z^(2^20 - 1) */
    fe_sq(t3, t2);
    for (i = 1; i < 20; i++) fe_sq(t3, t3);         /* t3 = z^(2^40 - 2^20) */
    fe_mul(t2, t3, t2);                             /* t2 = z^(2^40 - 1) */
    fe_sq(t2, t2);
    for (i = 1; i < 10; i++) fe_sq(t2, t2);         /* t2 = z^(2^50 - 2^10) */
    fe_mul(t1, t2, t1);                             /* t1 = z^(2^50 - 1) */
    fe_sq(t2, t1);
    for (i = 1; i < 50; i++) fe_sq(t2, t2);         /* t2 = z^(2^100 - 2^50) */
    fe_mul(t2, t2, t1);                             /* t2 = z^(2^100 - 1) */
    fe_sq(t3, t2);
    for (i = 1; i < 100; i++) fe_sq(t3, t3);        /* t3 = z^(2^200 - 2^100) */
    fe_mul(t2, t3, t2);                             /* t2 = z^(2^200 - 1) */
    fe_sq(t2, t2);
    for (i = 1; i < 50; i++) fe_sq(t2, t2);         /* t2 = z^(2^250 - 2^50) */
    fe_mul(t1, t2, t1);                             /* t1 = z^(2^250 - 1) */
    fe_sq(t1, t1);
    for (i = 1; i < 5; i++) fe_sq(t1, t1);          /* t1 = z^(2^255 - 2^5) */
    fe_mul(out, t1, t0);                            /* out = z^(2^255 - 21) = z^(p - 2) */
}

/*
 * X25519 Montgomery ladder
 * Constant-time scalar multiplication on Curve25519
 */
static void
x25519_scalarmult(uint8_t out[32], const uint8_t scalar[32], const uint8_t point[32])
{
    uint8_t e[32];
    memcpy(e, scalar, 32);
    e[0]  &= 248;
    e[31] &= 127;
    e[31] |= 64;

    fe25519 x1, x2, z2, x3, z3, tmp0, tmp1;
    fe_frombytes(x1, point);

    /* x2 = 1, z2 = 0, x3 = x1, z3 = 1 */
    memset(x2, 0, sizeof(fe25519)); x2[0] = 1;
    memset(z2, 0, sizeof(fe25519));
    memcpy(x3, x1, sizeof(fe25519));
    memset(z3, 0, sizeof(fe25519)); z3[0] = 1;

    int swap = 0;

    for (int pos = 254; pos >= 0; pos--) {
        int b = (e[pos / 8] >> (pos & 7)) & 1;
        swap ^= b;

        /* Constant-time conditional swap */
        for (int i = 0; i < 5; i++) {
            int64_t mask = -(int64_t)swap;
            int64_t d;
            d = mask & (x2[i] ^ x3[i]); x2[i] ^= d; x3[i] ^= d;
            d = mask & (z2[i] ^ z3[i]); z2[i] ^= d; z3[i] ^= d;
        }
        swap = b;

        /* Montgomery ladder step */
        fe25519 A, AA, B, BB, E, C, D, DA, CB;
        fe_add(A, x2, z2);
        fe_sq(AA, A);
        fe_sub(B, x2, z2);
        fe_sq(BB, B);
        fe_sub(E, AA, BB);
        fe_add(C, x3, z3);
        fe_sub(D, x3, z3);
        fe_mul(DA, D, A);
        fe_mul(CB, C, B);

        fe_add(tmp0, DA, CB);
        fe_sq(x3, tmp0);

        fe_sub(tmp0, DA, CB);
        fe_sq(tmp1, tmp0);
        fe_mul(z3, x1, tmp1);

        fe_mul(x2, AA, BB);
        fe_mul_scalar(tmp0, E, a24);
        fe_add(tmp0, tmp0, AA);
        fe_mul(z2, E, tmp0);
    }

    /* Final swap */
    for (int i = 0; i < 5; i++) {
        int64_t mask = -(int64_t)swap;
        int64_t d;
        d = mask & (x2[i] ^ x3[i]); x2[i] ^= d; x3[i] ^= d;
        d = mask & (z2[i] ^ z3[i]); z2[i] ^= d; z3[i] ^= d;
    }

    /* out = x2 / z2 */
    fe_invert(z2, z2);
    fe_mul(x2, x2, z2);
    fe_tobytes(out, x2);
}

/* Curve25519 base point encoded little-endian: u = 9, all other bytes zero. */
static const uint8_t x25519_basepoint[32] = { 0x09 };

/*
 * Derive an X25519 public key: pk = X25519(sk, 9).
 * sk must already hold the caller's random secret; it is only read.
 * Clamping happens on a copy inside x25519_scalarmult, so sk is stored
 * unclamped.
 */
static void
x25519_keypair(uint8_t pk[32], uint8_t sk[32])
{
    const uint8_t *generator = x25519_basepoint;
    x25519_scalarmult(pk, sk, generator);
}

/* ═════════════════════════════════════════════
 * Hybrid KEM: ML-KEM-1024 ⊕ X25519
 * ═════════════════════════════════════════════ */

/*
 * Hybrid object sizes. Each hybrid object is the ML-KEM component followed by
 * a 32-byte X25519 component, matching the wire-format table in the file
 * header. The trailing numbers hold only if mlkem_params.h defines the
 * ML-KEM-1024 sizes (TODO confirm).
 */
#define HYBRID_PUBLICKEYBYTES  (MLKEM_PUBLICKEYBYTES + X25519_BYTES)   /* 1600 */
#define HYBRID_SECRETKEYBYTES  (MLKEM_SECRETKEYBYTES + X25519_BYTES)   /* 3200 */
#define HYBRID_CIPHERTEXTBYTES (MLKEM_CIPHERTEXTBYTES + X25519_BYTES)  /* 1600 */
/* The combined shared secret is one 32-byte hash output. */
#define HYBRID_SSBYTES         32

/* Defined in kem.c; presumably fills out[0..len) with random bytes.
 * Callers in this file treat any non-zero return as failure. */
extern int randombytes(uint8_t *out, size_t len);

/*
 * Generate a hybrid ML-KEM + X25519 key pair.
 *
 * pk: receives [ML-KEM pk || X25519 pk].
 * sk: receives [ML-KEM sk || X25519 sk] (the X25519 part is stored unclamped).
 * Returns 0 on success, -1 if ML-KEM key generation or randombytes fails.
 *
 * pk and sk are written only on success. Temporary secret material is wiped
 * on every path, including failures. The wipes go through a volatile function
 * pointer so they cannot be removed as dead stores.
 */
int
hybrid_keypair_generate(uint8_t pk[HYBRID_PUBLICKEYBYTES],
                        uint8_t sk[HYBRID_SECRETKEYBYTES])
{
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    int rc = -1;
    mlkem_keypair kp;
    uint8_t x_sk[32];
    uint8_t x_pk[32];

    if (mlkem_keypair_generate(&kp) != 0)
        goto out;
    if (randombytes(x_sk, sizeof(x_sk)) != 0)
        goto out;
    x25519_keypair(x_pk, x_sk);

    memcpy(pk, kp.pk, MLKEM_PUBLICKEYBYTES);
    memcpy(pk + MLKEM_PUBLICKEYBYTES, x_pk, 32);
    memcpy(sk, kp.sk, MLKEM_SECRETKEYBYTES);
    memcpy(sk + MLKEM_SECRETKEYBYTES, x_sk, 32);
    rc = 0;

out:
    wipe(&kp, 0, sizeof(kp));
    wipe(x_sk, 0, sizeof(x_sk));
    return rc;
}

int
hybrid_encapsulate(uint8_t ct[HYBRID_CIPHERTEXTBYTES],
                   uint8_t ss[HYBRID_SSBYTES],
                   const uint8_t pk[HYBRID_PUBLICKEYBYTES])
{
    /* ML-KEM encapsulate */
    uint8_t ss_mlkem[MLKEM_SSBYTES];
    if (mlkem_encapsulate(ct, ss_mlkem, pk) != 0) return -1;

    /* X25519 ephemeral DH */
    uint8_t x_ek[32]; /* ephemeral secret */
    if (randombytes(x_ek, 32) != 0) return -1;

    uint8_t x_epk[32]; /* ephemeral public */
    x25519_keypair(x_epk, x_ek);

    uint8_t ss_x25519[32];
    x25519_scalarmult(ss_x25519, x_ek, pk + MLKEM_PUBLICKEYBYTES);

    /* ct = [mlkem_ct || x25519_epk] */
    memcpy(ct + MLKEM_CIPHERTEXTBYTES, x_epk, 32);

    /* Combine: ss = SHA3-256(ss_mlkem || ss_x25519 || mlkem_ct || x25519_peer_pk) */
    uint8_t combiner_input[MLKEM_SSBYTES + 32 + MLKEM_CIPHERTEXTBYTES + 32];
    uint8_t *p = combiner_input;
    memcpy(p, ss_mlkem, MLKEM_SSBYTES); p += MLKEM_SSBYTES;
    memcpy(p, ss_x25519, 32);           p += 32;
    memcpy(p, ct, MLKEM_CIPHERTEXTBYTES); p += MLKEM_CIPHERTEXTBYTES;
    memcpy(p, pk + MLKEM_PUBLICKEYBYTES, 32);

    mlkem_hash_h(ss, combiner_input, sizeof(combiner_input));

    /* Zeroize */
    memset(x_ek, 0, 32);
    memset(ss_mlkem, 0, sizeof(ss_mlkem));
    memset(ss_x25519, 0, sizeof(ss_x25519));
    memset(combiner_input, 0, sizeof(combiner_input));

    return 0;
}

int
hybrid_decapsulate(uint8_t ss[HYBRID_SSBYTES],
                   const uint8_t ct[HYBRID_CIPHERTEXTBYTES],
                   const uint8_t sk[HYBRID_SECRETKEYBYTES],
                   const uint8_t pk[HYBRID_PUBLICKEYBYTES])
{
    /* ML-KEM decapsulate */
    uint8_t ss_mlkem[MLKEM_SSBYTES];
    mlkem_decapsulate(ss_mlkem, ct, sk);

    /* X25519 DH with ephemeral public key from ct */
    const uint8_t *x_epk = ct + MLKEM_CIPHERTEXTBYTES;
    const uint8_t *x_sk  = sk + MLKEM_SECRETKEYBYTES;

    uint8_t ss_x25519[32];
    x25519_scalarmult(ss_x25519, x_sk, x_epk);

    /* Same combiner as encapsulate */
    uint8_t combiner_input[MLKEM_SSBYTES + 32 + MLKEM_CIPHERTEXTBYTES + 32];
    uint8_t *p = combiner_input;
    memcpy(p, ss_mlkem, MLKEM_SSBYTES); p += MLKEM_SSBYTES;
    memcpy(p, ss_x25519, 32);           p += 32;
    memcpy(p, ct, MLKEM_CIPHERTEXTBYTES); p += MLKEM_CIPHERTEXTBYTES;
    memcpy(p, pk + MLKEM_PUBLICKEYBYTES, 32);

    mlkem_hash_h(ss, combiner_input, sizeof(combiner_input));

    memset(ss_mlkem, 0, sizeof(ss_mlkem));
    memset(ss_x25519, 0, sizeof(ss_x25519));
    memset(combiner_input, 0, sizeof(combiner_input));

    return 0;
}