/*
* QR-NSP Volcanic Edition — ML-KEM-1024 Parameters (FIPS 203)
* SPDX-License-Identifier: AGPL-3.0-or-later
* Module 2: Post-Quantum Key Encapsulation
*
* ML-KEM-1024 = Module-LWE KEM at NIST Level 5
* q = 3329, n = 256, k = 4
*
* Reference: FIPS 203 (August 2024)
*/
#ifndef MLKEM_PARAMS_H
#define MLKEM_PARAMS_H
#include <stdint.h>
#include <stddef.h>
/*
 * ─────────────────────────────────────────────
 * ML-KEM-1024 parameter set (FIPS 203)
 * ─────────────────────────────────────────────
 *
 *   MLKEM_N     polynomial degree n; the ring is R_q = Z_q[X]/(X^n + 1)
 *   MLKEM_K     module rank k (k * n = 1024, hence "ML-KEM-1024")
 *   MLKEM_Q     prime modulus q = 3329 = 13 * 256 + 1
 *   MLKEM_ETA1  centered-binomial parameter eta_1 (secret / error sampling)
 *   MLKEM_ETA2  centered-binomial parameter eta_2 (encryption noise)
 *   MLKEM_DU    bits kept per coefficient when compressing the vector u
 *   MLKEM_DV    bits kept per coefficient when compressing the scalar v
 */
#define MLKEM_N    256
#define MLKEM_K    4
#define MLKEM_Q    3329
#define MLKEM_ETA1 2
#define MLKEM_ETA2 2
#define MLKEM_DU   11
#define MLKEM_DV   5
/* Derived sizes (bytes).
 * The per-polynomial sizes are computed from MLKEM_N / MLKEM_DU / MLKEM_DV,
 * so they cannot drift out of sync with the parameters. The ML-KEM-1024 value
 * of each size is noted beside it and matches the key and ciphertext sizes
 * published in FIPS 203. */
#define MLKEM_SYMBYTES 32 /* Shared secret, seeds, hashes */
/* ByteEncode_12 of one polynomial: n coefficients x 12 bits (12 = bit length of q) */
#define MLKEM_POLYBYTES ((MLKEM_N * 12) / 8)                    /* 256 x 12 / 8 = 384 */
/* One polynomial compressed to d_u resp. d_v bits per coefficient */
#define MLKEM_POLYCOMPRESSEDBYTES_DU ((MLKEM_N * MLKEM_DU) / 8) /* 256 x 11 / 8 = 352 */
#define MLKEM_POLYCOMPRESSEDBYTES_DV ((MLKEM_N * MLKEM_DV) / 8) /* 256 x 5 / 8 = 160 */
/* Public (encapsulation) key: ek = ByteEncode_12(t_hat) || rho,
 * i.e. the k encoded polynomials of t_hat followed by the 32-byte seed rho. */
#define MLKEM_PUBLICKEYBYTES (MLKEM_K * MLKEM_POLYBYTES + MLKEM_SYMBYTES)
/* 4 x 384 + 32 = 1568 */
/* Secret (decapsulation) key: dk = (s_hat || ek || H(ek) || z) */
#define MLKEM_SECRETKEYBYTES (MLKEM_K * MLKEM_POLYBYTES + \
                              MLKEM_PUBLICKEYBYTES + \
                              2 * MLKEM_SYMBYTES)
/* 1536 + 1568 + 64 = 3168 */
/* Ciphertext: c = (c1 || c2), where c1 holds the k polynomials of u compressed
 * to d_u bits each and c2 holds the polynomial v compressed to d_v bits. */
#define MLKEM_CIPHERTEXTBYTES (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU + \
                               MLKEM_POLYCOMPRESSEDBYTES_DV)
/* 4 x 352 + 160 = 1568 */
/* Shared secret */
#define MLKEM_SSBYTES 32
/* ─────────────────────────────────────────────
 * NTT Constants
 *
 * q = 3329 = 13 × 256 + 1, so q − 1 = 3328 = 2^8 × 13. Z_q^* therefore has
 * primitive 256th roots of unity but NO 512th roots. FIPS 203 uses ζ = 17,
 * a primitive 256th root of unity (17^128 ≡ −1 mod q). Consequently
 * X^256 + 1 splits only into 128 quadratic factors X^2 − ζ^(2·BitRev7(i)+1),
 * and products in the NTT domain are formed pairwise on degree-1 polynomials.
 *
 * Montgomery radix R = 2^16; R mod q = 2285 (MLKEM_MONT).
 * q^(-1) mod 2^16 = 62209 (MLKEM_QINV); as a signed 16-bit value this is −3327.
 * Barrett constant 5039 = floor(2^24 / q). NOTE(review): presumably it is
 * paired with a 24-bit shift in the reduction code; confirm there.
 * ───────────────────────────────────────────── */
#define MLKEM_MONT 2285  /* 2^16 mod q */
#define MLKEM_QINV 62209 /* q^(-1) mod 2^16 */
#define MLKEM_ZETA 17    /* Primitive 256th root of unity mod q */

/* Compile-time consistency check of the Montgomery constants against q.
 * Nested #if so MLKEM_Q is never evaluated when it is not defined. */
#ifdef MLKEM_Q
#if (65536 % MLKEM_Q) != MLKEM_MONT || ((MLKEM_Q * MLKEM_QINV) % 65536) != 1
#error "MLKEM_MONT / MLKEM_QINV are inconsistent with MLKEM_Q"
#endif
#endif
/* ─────────────────────────────────────────────
 * Core Types
 * ───────────────────────────────────────────── */

/*
 * mlkem_poly: one element of R_q = Z_q[X]/(X^256 + 1),
 * held as MLKEM_N signed 16-bit coefficients.
 */
typedef struct mlkem_poly {
    int16_t coeffs[MLKEM_N];
} mlkem_poly;

/*
 * mlkem_polyvec: a module element, i.e. MLKEM_K polynomials stored
 * contiguously.
 */
typedef struct mlkem_polyvec {
    mlkem_poly vec[MLKEM_K];
} mlkem_polyvec;

/*
 * mlkem_keypair: the encoded public key (MLKEM_PUBLICKEYBYTES bytes)
 * followed by the encoded secret key (MLKEM_SECRETKEYBYTES bytes).
 */
typedef struct mlkem_keypair {
    uint8_t pk[MLKEM_PUBLICKEYBYTES];
    uint8_t sk[MLKEM_SECRETKEYBYTES];
} mlkem_keypair;
/* ─────────────────────────────────────────────
 * AVX-512 Detection
 *
 * MLKEM_USE_AVX512 is 1 when the compiler predefines both __AVX512F__ and
 * __AVX512BW__ (i.e. it is targeting AVX-512F + AVX-512BW), and 0 otherwise.
 * A build may predefine MLKEM_USE_AVX512 to override the automatic choice,
 * e.g. -DMLKEM_USE_AVX512=0 to force the portable code path for testing.
 * Forcing it to 1 is only meaningful when the compiler can emit AVX-512F/BW.
 * ───────────────────────────────────────────── */
#ifndef MLKEM_USE_AVX512
#if defined(__AVX512F__) && defined(__AVX512BW__)
#define MLKEM_USE_AVX512 1
#else
#define MLKEM_USE_AVX512 0
#endif
#endif
/* ─────────────────────────────────────────────
 * Function Declarations
 * ───────────────────────────────────────────── */

/*
 * NTT layer (ntt.c).
 *   mlkem_ntt, mlkem_invntt: operate on *p in place (forward / inverse
 *       transform); the exact coefficient domain and scaling are defined
 *       in ntt.c.
 *   mlkem_basemul: writes the product of *a and *b into *r. NOTE(review):
 *       presumably both inputs are expected in NTT form; confirm in ntt.c.
 */
void mlkem_ntt(mlkem_poly *p);
void mlkem_invntt(mlkem_poly *p);
void mlkem_basemul(mlkem_poly *r,
                   const mlkem_poly *a,
                   const mlkem_poly *b);
/*
 * Polynomial layer (poly.c).
 */

/* Noise sampling from a byte buffer of eta * n / 4 bytes (128 bytes for
 * eta = 2). NOTE(review): "cbd" presumably denotes centered-binomial
 * sampling (FIPS 203 SamplePolyCBD) over PRF output; confirm in poly.c. */
void mlkem_poly_cbd_eta1(mlkem_poly *r,
                         const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]);
void mlkem_poly_cbd_eta2(mlkem_poly *r,
                         const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]);

/* Serialization to / from MLKEM_POLYBYTES bytes (12 bits per coefficient). */
void mlkem_poly_tobytes(uint8_t r[MLKEM_POLYBYTES],
                        const mlkem_poly *p);
void mlkem_poly_frombytes(mlkem_poly *r,
                          const uint8_t a[MLKEM_POLYBYTES]);

/* Compression to d_u / d_v bits per coefficient and the inverse mapping.
 * The byte buffers are unsized here; presumably they span
 * MLKEM_POLYCOMPRESSEDBYTES_DU / _DV bytes respectively (confirm in poly.c). */
void mlkem_poly_compress_du(uint8_t *r,
                            const mlkem_poly *p);
void mlkem_poly_decompress_du(mlkem_poly *r,
                              const uint8_t *a);
void mlkem_poly_compress_dv(uint8_t *r,
                            const mlkem_poly *p);
void mlkem_poly_decompress_dv(mlkem_poly *r,
                              const uint8_t *a);

/* Conversion between a MLKEM_SYMBYTES-byte message and a polynomial. */
void mlkem_poly_frommsg(mlkem_poly *r,
                        const uint8_t msg[MLKEM_SYMBYTES]);
void mlkem_poly_tomsg(uint8_t msg[MLKEM_SYMBYTES],
                      const mlkem_poly *p);

/* Arithmetic: r = a + b and r = a - b (presumably coefficient-wise), and
 * mlkem_poly_reduce, which modifies *p in place (presumably reducing each
 * coefficient modulo q). */
void mlkem_poly_add(mlkem_poly *r,
                    const mlkem_poly *a,
                    const mlkem_poly *b);
void mlkem_poly_sub(mlkem_poly *r,
                    const mlkem_poly *a,
                    const mlkem_poly *b);
void mlkem_poly_reduce(mlkem_poly *p);
/*
 * Polynomial-vector layer (poly.c): each function acts on a module element
 * of MLKEM_K polynomials.
 */

/* In-place forward / inverse transform of every polynomial in *v. */
void mlkem_polyvec_ntt(mlkem_polyvec *v);
void mlkem_polyvec_invntt(mlkem_polyvec *v);

/* Writes into the single polynomial *r an accumulated product of *a and *b.
 * NOTE(review): presumably r = sum_i a[i] * b[i] in the NTT domain; confirm. */
void mlkem_polyvec_pointwise_acc(mlkem_poly *r,
                                 const mlkem_polyvec *a,
                                 const mlkem_polyvec *b);

/* Compressed encoding (unsized buffers; presumably
 * MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU bytes) and its inverse. */
void mlkem_polyvec_compress(uint8_t *r,
                            const mlkem_polyvec *v);
void mlkem_polyvec_decompress(mlkem_polyvec *r,
                              const uint8_t *a);

/* Plain encoding (unsized buffers; presumably MLKEM_K * MLKEM_POLYBYTES
 * bytes) and its inverse. */
void mlkem_polyvec_tobytes(uint8_t *r,
                           const mlkem_polyvec *v);
void mlkem_polyvec_frombytes(mlkem_polyvec *r,
                             const uint8_t *a);

/* r = a + b, and in-place reduction of *v. */
void mlkem_polyvec_add(mlkem_polyvec *r,
                       const mlkem_polyvec *a,
                       const mlkem_polyvec *b);
void mlkem_polyvec_reduce(mlkem_polyvec *v);
/*
 * KEM API (kem.c). Every entry point returns an int status.
 * NOTE(review): the meaning of the status (e.g. 0 on success) is defined in
 * kem.c. Per FIPS 203, decapsulating an invalid ciphertext is handled by
 * implicit rejection rather than an error; confirm what the return value
 * of mlkem_decapsulate signals.
 */

/* Fills kp->pk and kp->sk with a fresh key pair. */
int mlkem_keypair_generate(mlkem_keypair *kp);

/* Given public key pk, writes ciphertext ct and shared secret ss. */
int mlkem_encapsulate(uint8_t ct[MLKEM_CIPHERTEXTBYTES],
                      uint8_t ss[MLKEM_SSBYTES],
                      const uint8_t pk[MLKEM_PUBLICKEYBYTES]);

/* Given ciphertext ct and secret key sk, writes shared secret ss. */
int mlkem_decapsulate(uint8_t ss[MLKEM_SSBYTES],
                      const uint8_t ct[MLKEM_CIPHERTEXTBYTES],
                      const uint8_t sk[MLKEM_SECRETKEYBYTES]);
/*
 * Symmetric primitives (symmetric.c): SHA3/SHAKE wrappers.
 *
 * For reference, FIPS 203 defines H = SHA3-256 (32-byte output),
 * G = SHA3-512 (64-byte output), J = SHAKE256 truncated to 32 bytes,
 * XOF = SHAKE128, and PRF_eta(s, b) = SHAKE256(s || b) with 64 * eta output
 * bytes. NOTE(review): the implementations are in symmetric.c and are not
 * visible here; confirm that they match.
 *
 * Buffer bounds are written in terms of MLKEM_SYMBYTES (32) rather than as
 * bare literals. The values are unchanged (32, 64, 34), so the prototypes
 * are type-identical to before.
 */

/* h receives MLKEM_SYMBYTES (32) bytes. */
void mlkem_hash_h(uint8_t h[MLKEM_SYMBYTES], const uint8_t *in, size_t inlen);

/* h receives 2 * MLKEM_SYMBYTES (64) bytes. */
void mlkem_hash_g(uint8_t h[2 * MLKEM_SYMBYTES], const uint8_t *in, size_t inlen);

/* Writes MLKEM_SSBYTES bytes to ss.
 * NOTE(review): final FIPS 203 ML-KEM has no separate KDF step (round-3
 * Kyber did). If this implements J for implicit rejection, consider
 * documenting it as such; confirm against kem.c. */
void mlkem_kdf(uint8_t ss[MLKEM_SSBYTES], const uint8_t *in, size_t inlen);

/* seed is MLKEM_SYMBYTES + 2 (34) bytes; presumably rho followed by two
 * matrix-index bytes, as in FIPS 203 SampleNTT. `state` is an opaque XOF
 * state whose concrete type is defined in symmetric.c. */
void mlkem_xof_absorb(void *state, const uint8_t seed[MLKEM_SYMBYTES + 2]);
void mlkem_xof_squeeze(uint8_t *out, size_t outlen, void *state);

/* Writes outlen bytes derived from the MLKEM_SYMBYTES-byte key and a
 * one-byte nonce. */
void mlkem_prf(uint8_t *out, size_t outlen, const uint8_t key[MLKEM_SYMBYTES], uint8_t nonce);
#endif /* MLKEM_PARAMS_H */