/* src/crypto/symmetric.c */

/*
 * QR-NSP Volcanic Edition — Symmetric Primitives (SHA-3 / SHAKE)
 * SPDX-License-Identifier: AGPL-3.0-or-later
 * Wrapper for ML-KEM's hash function requirements (FIPS 203 §4.1)
 *
 * ML-KEM requires:
 *   H  = SHA3-256   (hash H)
 *   G  = SHA3-512   (hash G — key derivation)
 *   J  = SHAKE-256  (implicit rejection KDF)
 *   XOF = SHAKE-128 (matrix sampling)
 *   PRF = SHAKE-256 (pseudorandom function for CBD sampling)
 *
 * This file provides a minimal Keccak-f[1600] implementation so that the
 * project compiles with zero external dependencies. For production, replace
 * it with an optimized implementation such as XKCP (the eXtended Keccak Code
 * Package) or OpenSSL's libcrypto.
 */

#include "mlkem_params.h"
#include <string.h>

/* ─────────────────────────────────────────────
 * Keccak-f[1600] State
 * ───────────────────────────────────────────── */

/*
 * Sponge context shared by every SHA-3 / SHAKE variant in this file.
 *
 * The single `pos` cursor serves both phases: while absorbing it is the byte
 * offset inside the current rate-sized block where the next input byte is
 * XORed in; once keccak_finalize() has reset it to 0 it is the offset of the
 * next output byte handed out by keccak_squeeze(). It stays in [0, rate].
 */
typedef struct {
    /* The 1600-bit Keccak state, held as 25 64-bit lanes. */
    uint64_t s[25];
    /* Byte cursor inside the current rate-sized block (see above). */
    unsigned int pos;
    unsigned int rate;   /* In bytes: 168 for SHAKE-128, 136 for SHAKE-256/SHA3-256, 72 for SHA3-512 */
} keccak_state;

/* ─────────────────────────────────────────────
 * Keccak-f[1600] Round Constants
 * ───────────────────────────────────────────── */

/*
 * One 64-bit constant per round of the permutation: round r of
 * keccak_f1600() XORs keccak_rc[r] into lane s[0] as its final (ι) step.
 */
static const uint64_t keccak_rc[24] = {
    0x0000000000000001ULL, 0x0000000000008082ULL,
    0x800000000000808aULL, 0x8000000080008000ULL,
    0x000000000000808bULL, 0x0000000080000001ULL,
    0x8000000080008081ULL, 0x8000000000008009ULL,
    0x000000000000008aULL, 0x0000000000000088ULL,
    0x0000000080008009ULL, 0x000000008000000aULL,
    0x000000008000808bULL, 0x800000000000008bULL,
    0x8000000000008089ULL, 0x8000000000008003ULL,
    0x8000000000008002ULL, 0x8000000000000080ULL,
    0x000000000000800aULL, 0x800000008000000aULL,
    0x8000000080008081ULL, 0x8000000000008080ULL,
    0x0000000080000001ULL, 0x8000000080008008ULL
};

/*
 * ρ rotation offsets, listed in the order keccak_f1600() walks the lanes:
 * entry i is the left-rotation applied to the lane that is written into
 * position keccak_piln[i]. The values are the triangular numbers
 * (i + 1)(i + 2) / 2 reduced modulo 64.
 */
static const unsigned int keccak_rotc[24] = {
     1,  3,  6, 10, 15, 21, 28, 36, 45, 55,  2, 14,
    27, 41, 56,  8, 25, 43, 62, 18, 39, 61, 20, 44
};

/*
 * π lane order. keccak_f1600() starts out carrying lane 1 and, at step i,
 * stores the (rotated) lane it carries into s[keccak_piln[i]] while picking
 * up the value that was there. The 24 entries name every lane except lane 0
 * exactly once and finish back at lane 1, closing the cycle.
 */
static const unsigned int keccak_piln[24] = {
    10,  7, 11, 17, 18,  3,  5, 16,  8, 21, 24,  4,
    15, 23, 19, 13, 12,  2, 20, 14, 22,  9,  6,  1
};

/*
 * Rotate the 64-bit value `x` left by `n` bit positions.
 *
 * `n` is reduced modulo 64 and the complementary right-shift count is masked
 * as well, so n == 0 (or any multiple of 64) returns `x` unchanged instead of
 * evaluating a 64-bit shift by 64, which is undefined behavior in C.
 * For 1 <= n <= 63 the result is the usual (x << n) | (x >> (64 - n)).
 */
static inline uint64_t
rotl64(uint64_t x, unsigned int n)
{
    n &= 63u;
    return (x << n) | (x >> ((64u - n) & 63u));
}

/* ─────────────────────────────────────────────
 * Keccak-f[1600] Permutation (24 rounds)
 * ───────────────────────────────────────────── */

/*
 * Apply the 24-round Keccak-f[1600] permutation to `s` in place.
 *
 * Lanes are indexed as s[x + 5*y]. Each round runs θ, the fused ρ/π walk,
 * χ and ι in that order; the rotation offsets, lane order and round
 * constants come from keccak_rotc, keccak_piln and keccak_rc.
 */
static void
keccak_f1600(uint64_t s[25])
{
    for (int rnd = 0; rnd < 24; rnd++) {
        uint64_t parity[5];
        uint64_t carried;

        /* θ: fold each column's parity into its two neighbouring columns. */
        for (int x = 0; x < 5; x++) {
            parity[x] = s[x] ^ s[x + 5] ^ s[x + 10] ^ s[x + 15] ^ s[x + 20];
        }
        for (int x = 0; x < 5; x++) {
            const uint64_t d = parity[(x + 4) % 5] ^ rotl64(parity[(x + 1) % 5], 1);

            for (int y = 0; y < 25; y += 5) {
                s[y + x] ^= d;
            }
        }

        /*
         * ρ and π together: follow the single 24-lane cycle starting at
         * lane 1, dropping the rotated carried lane into its destination and
         * picking up whatever was stored there for the next step.
         */
        carried = s[1];
        for (int step = 0; step < 24; step++) {
            const unsigned int dst = keccak_piln[step];
            const uint64_t displaced = s[dst];

            s[dst] = rotl64(carried, keccak_rotc[step]);
            carried = displaced;
        }

        /* χ: non-linear mix within each row, computed from a row snapshot. */
        for (int y = 0; y < 25; y += 5) {
            uint64_t row[5];

            for (int x = 0; x < 5; x++) {
                row[x] = s[y + x];
            }
            for (int x = 0; x < 5; x++) {
                s[y + x] = row[x] ^ (~row[(x + 1) % 5] & row[(x + 2) % 5]);
            }
        }

        /* ι: break round symmetry with this round's constant. */
        s[0] ^= keccak_rc[rnd];
    }
}

/* ─────────────────────────────────────────────
 * Sponge Absorb / Squeeze
 * ───────────────────────────────────────────── */

/*
 * Reset `state` to an empty sponge: all lanes zero, cursor at offset 0, and
 * the block size set to `rate` bytes.
 */
static void
keccak_init(keccak_state *state, unsigned int rate)
{
    /* Members not named in the initializer (s, pos) are zero-initialized. */
    *state = (keccak_state){ .rate = rate };
}

/*
 * XOR byte `b` into the sponge state at byte offset `off` (0..199).
 *
 * FIPS 202 orders the state bytes so that byte `off` occupies bits
 * 8*(off % 8) .. 8*(off % 8) + 7 of lane off / 8. Performing the update with
 * a shift, rather than through a uint8_t view of the lanes, yields the same
 * result on little-endian hosts and the specified one on big-endian hosts.
 */
static inline void
keccak_xor_byte(uint64_t s[25], unsigned int off, uint8_t b)
{
    s[off >> 3] ^= (uint64_t)b << (8 * (off & 7u));
}

/* Read the sponge-state byte at byte offset `off` (same ordering as above). */
static inline uint8_t
keccak_get_byte(const uint64_t s[25], unsigned int off)
{
    return (uint8_t)(s[off >> 3] >> (8 * (off & 7u)));
}

/*
 * Absorb `inlen` bytes from `in` into the sponge.
 *
 * Input is XORed into the current rate-sized block starting at state->pos;
 * every time the block fills up the permutation is applied and the cursor
 * returns to 0. May be called repeatedly to absorb a message in pieces.
 * With inlen == 0 nothing is read from `in`.
 */
static void
keccak_absorb(keccak_state *state, const uint8_t *in, size_t inlen)
{
    const unsigned int rate = state->rate;

    while (inlen > 0) {
        /* Take as much as fits in what is left of the current block. */
        size_t chunk = rate - state->pos;
        if (chunk > inlen) {
            chunk = inlen;
        }

        for (size_t i = 0; i < chunk; i++) {
            keccak_xor_byte(state->s, state->pos + (unsigned int)i, in[i]);
        }

        state->pos += (unsigned int)chunk;
        in += chunk;
        inlen -= chunk;

        if (state->pos == rate) {
            keccak_f1600(state->s);
            state->pos = 0;
        }
    }
}

/*
 * Pad the absorbed message and switch the sponge to the squeezing phase.
 *
 * `suffix` carries the domain-separation bits together with the first
 * padding bit (callers in this file pass 0x06 for SHA-3 and 0x1F for SHAKE);
 * 0x80 in the last byte of the block is the closing padding bit. When the
 * cursor already sits on the last byte, both XORs land on that same byte.
 * The permutation is then applied once and the cursor reset to 0.
 */
static void
keccak_finalize(keccak_state *state, uint8_t suffix)
{
    keccak_xor_byte(state->s, state->pos, suffix);
    keccak_xor_byte(state->s, state->rate - 1, 0x80);
    keccak_f1600(state->s);
    state->pos = 0;
}

/*
 * Write the next `outlen` output bytes of the sponge to `out`.
 *
 * Expects keccak_finalize() to have been called first. Output continues from
 * state->pos, so successive calls yield one continuous stream; the state is
 * permuted again each time a full rate-sized block has been handed out and
 * more output is requested. With outlen == 0 nothing is written to `out`.
 */
static void
keccak_squeeze(keccak_state *state, uint8_t *out, size_t outlen)
{
    const unsigned int rate = state->rate;

    while (outlen > 0) {
        /* Current block exhausted: produce the next one. */
        if (state->pos == rate) {
            keccak_f1600(state->s);
            state->pos = 0;
        }

        size_t chunk = rate - state->pos;
        if (chunk > outlen) {
            chunk = outlen;
        }

        for (size_t i = 0; i < chunk; i++) {
            out[i] = keccak_get_byte(state->s, state->pos + (unsigned int)i);
        }

        state->pos += (unsigned int)chunk;
        out += chunk;
        outlen -= chunk;
    }
}

/* ─────────────────────────────────────────────
 * SHA3-256: H(m) → 32 bytes
 * ───────────────────────────────────────────── */

/*
 * One-shot SHA3-256: hash the `inlen` bytes at `in` and write the 32-byte
 * digest to `h`.
 */
static void
sha3_256(uint8_t h[32], const uint8_t *in, size_t inlen)
{
    enum {
        SHA3_256_RATE   = 136,  /* (1600 - 2*256) bits = 1088 bits = 136 bytes */
        SHA3_256_OUTLEN = 32,
        SHA3_SUFFIX     = 0x06  /* SHA-3 domain separator */
    };
    keccak_state sponge;

    keccak_init(&sponge, SHA3_256_RATE);
    keccak_absorb(&sponge, in, inlen);
    keccak_finalize(&sponge, SHA3_SUFFIX);
    keccak_squeeze(&sponge, h, SHA3_256_OUTLEN);
}

/* ─────────────────────────────────────────────
 * SHA3-512: G(m) → 64 bytes
 * ───────────────────────────────────────────── */

/*
 * One-shot SHA3-512: hash the `inlen` bytes at `in` and write the 64-byte
 * digest to `h`.
 *
 * This is ML-KEM's G, which the file header describes as key derivation, so
 * the sponge state is scrubbed before returning rather than left behind on
 * the stack.
 */
static void
sha3_512(uint8_t h[64], const uint8_t *in, size_t inlen)
{
    keccak_state state;

    keccak_init(&state, 72);     /* rate = 1600 - 2×512 = 576 bits = 72 bytes */
    keccak_absorb(&state, in, inlen);
    keccak_finalize(&state, 0x06); /* SHA-3 domain separator */
    keccak_squeeze(&state, h, 64);

    /*
     * Best-effort wipe: the stores go through a volatile-qualified pointer so
     * the compiler does not discard them as writes to a dead object.
     */
    volatile uint8_t *scrub = (volatile uint8_t *)&state;
    for (size_t i = 0; i < sizeof state; i++) {
        scrub[i] = 0;
    }
}

/* ─────────────────────────────────────────────
 * SHAKE-128: XOF for matrix sampling
 * ───────────────────────────────────────────── */

/*
 * NOTE(review): nothing in the visible part of this file reads or writes
 * xof_state; mlkem_xof_absorb() and mlkem_xof_squeeze() operate on the
 * caller-supplied state_ptr instead. Confirm whether this object is still
 * needed; an unused file-scope static usually draws a compiler warning.
 */
static keccak_state xof_state; /* Thread-local in production */

/*
 * Start a SHAKE-128 XOF stream: reset the caller-provided sponge, absorb the
 * 34-byte `seed`, and finalize so that mlkem_xof_squeeze() can be called.
 *
 * `state_ptr` is used as a keccak_state, so it must point to writable
 * storage of at least mlkem_xof_state_size() bytes.
 * NOTE(review): the storage presumably also has to be aligned for uint64_t;
 * confirm how callers allocate it.
 */
void
mlkem_xof_absorb(void *state_ptr, const uint8_t seed[34])
{
    enum {
        SHAKE128_RATE = 168,   /* (1600 - 2*128) bits = 1344 bits = 168 bytes */
        XOF_SEED_LEN  = 34,
        SHAKE_SUFFIX  = 0x1F   /* SHAKE domain separator */
    };
    keccak_state *const sponge = state_ptr;

    keccak_init(sponge, SHAKE128_RATE);
    keccak_absorb(sponge, seed, XOF_SEED_LEN);
    keccak_finalize(sponge, SHAKE_SUFFIX);
}

/*
 * Write the next `outlen` bytes of XOF output to `out`.
 *
 * `state_ptr` is treated as the keccak_state that mlkem_xof_absorb() set up;
 * the sponge cursor advances, so repeated calls continue the same stream.
 */
void
mlkem_xof_squeeze(uint8_t *out, size_t outlen, void *state_ptr)
{
    keccak_state *st = (keccak_state *)state_ptr;
    keccak_squeeze(st, out, outlen);
}

/* ─────────────────────────────────────────────
 * SHAKE-256: PRF for CBD sampling
 * PRF(s, b) = SHAKE-256(s || b)
 * ───────────────────────────────────────────── */

/*
 * PRF(key, nonce): write the first `outlen` bytes of
 * SHAKE-256(key || nonce) to `out`.
 *
 * key    - MLKEM_SYMBYTES-byte PRF key.
 * nonce  - single byte appended after the key.
 *
 * The sponge state depends on the key, so it is scrubbed before returning
 * rather than left behind on the stack.
 */
void
mlkem_prf(uint8_t *out, size_t outlen, const uint8_t key[MLKEM_SYMBYTES], uint8_t nonce)
{
    keccak_state state;

    keccak_init(&state, 136);    /* SHAKE-256 rate */
    keccak_absorb(&state, key, MLKEM_SYMBYTES);
    keccak_absorb(&state, &nonce, 1);
    keccak_finalize(&state, 0x1F); /* SHAKE domain separator */
    keccak_squeeze(&state, out, outlen);

    /*
     * Best-effort wipe: the stores go through a volatile-qualified pointer so
     * the compiler does not discard them as writes to a dead object.
     */
    volatile uint8_t *scrub = (volatile uint8_t *)&state;
    for (size_t i = 0; i < sizeof state; i++) {
        scrub[i] = 0;
    }
}

/* ─────────────────────────────────────────────
 * Public API wrappers
 * ───────────────────────────────────────────── */

/*
 * ML-KEM hash H: SHA3-256 of the `inlen` bytes at `in`, written to the
 * 32-byte buffer `h`. Public entry point for the file-local sha3_256().
 */
void
mlkem_hash_h(uint8_t h[32], const uint8_t *in, size_t inlen)
{
    sha3_256(h, in, inlen);
}

/*
 * ML-KEM hash G: SHA3-512 of the `inlen` bytes at `in`, written to the
 * 64-byte buffer `h`. Public entry point for the file-local sha3_512().
 */
void
mlkem_hash_g(uint8_t h[64], const uint8_t *in, size_t inlen)
{
    sha3_512(h, in, inlen);
}

/*
 * ML-KEM function J: write the first MLKEM_SSBYTES bytes of
 * SHAKE-256(in[0 .. inlen-1]) to `ss`.
 *
 * The file header describes J as the implicit-rejection KDF, so the sponge
 * state is scrubbed before returning rather than left behind on the stack.
 */
void
mlkem_kdf(uint8_t ss[MLKEM_SSBYTES], const uint8_t *in, size_t inlen)
{
    keccak_state state;

    keccak_init(&state, 136);    /* SHAKE-256 rate */
    keccak_absorb(&state, in, inlen);
    keccak_finalize(&state, 0x1F); /* SHAKE domain separator */
    keccak_squeeze(&state, ss, MLKEM_SSBYTES);

    /*
     * Best-effort wipe: the stores go through a volatile-qualified pointer so
     * the compiler does not discard them as writes to a dead object.
     */
    volatile uint8_t *scrub = (volatile uint8_t *)&state;
    for (size_t i = 0; i < sizeof state; i++) {
        scrub[i] = 0;
    }
}

/*
 * Return sizeof(keccak_state) in bytes, for external callers that cannot see
 * the type. Presumably used to size the buffer passed as state_ptr to
 * mlkem_xof_absorb() / mlkem_xof_squeeze(); verify against the callers.
 */
size_t mlkem_xof_state_size(void) { return sizeof(keccak_state); }