/*
* QR-NSP Volcanic Edition — AES-256-GCM Implementation
* Reference scalar + AES-NI hardware-accelerated path
*
* GCM = CTR mode encryption + GHASH authentication
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
#include "qrnsp_aead.h"
#include <string.h>
/* ─────────────────────────────────────────────
* AES-NI detection
* ───────────────────────────────────────────── */
#if defined(__AES__) && defined(__PCLMUL__) && defined(__x86_64__)
#define USE_AESNI 1
#include <immintrin.h>
#include <wmmintrin.h>
#else
#define USE_AESNI 0
#endif
/* ═════════════════════════════════════════════
* AES-256 Core (Scalar Reference)
* ═════════════════════════════════════════════ */
/*
 * AES forward S-box (FIPS-197 section 5.1.1): multiplicative inverse
 * in GF(2^8) followed by the affine transform, pre-tabulated.
 * Indexed directly by the input byte.  NOTE: data-dependent table
 * lookups are not constant-time; the scalar path is a reference /
 * fallback only.
 */
static const uint8_t sbox[256] = {
    0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
    0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
    0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
    0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
    0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
    0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
    0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
    0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
    0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
    0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
    0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
    0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
    0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
    0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
    0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
    0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16
};
/*
 * Key-schedule round constants, rcon[i] = x^i in GF(2^8)
 * (FIPS-197 section 5.2).  The AES-256 schedule below consumes only
 * the first seven entries; all ten are kept for reference.
 */
static const uint8_t rcon[10] = {
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36
};
/* GF(2^8) multiplication helpers for MixColumns (xtime = multiply by {02}; gmul = general multiply) */
static uint8_t
xtime(uint8_t x)
{
    /* Multiply by x (i.e. by {02}) in GF(2^8): shift left and, when
     * the top bit carries out, reduce by the AES polynomial 0x11b. */
    uint8_t doubled = (uint8_t)(x << 1);
    return (x & 0x80) ? (uint8_t)(doubled ^ 0x1b) : doubled;
}
static uint8_t
gmul(uint8_t a, uint8_t b)
{
    /* General GF(2^8) multiply (Russian-peasant method): accumulate
     * `a` for each set bit of `b`, doubling `a` modulo the AES
     * polynomial 0x11b at every step. */
    uint8_t acc = 0;
    while (b) {
        if (b & 1)
            acc ^= a;
        a = (uint8_t)((a << 1) ^ ((a & 0x80) ? 0x1b : 0x00));
        b >>= 1;
    }
    return acc;
}
/*
 * AES-256 key schedule (FIPS-197 section 5.2): expand the 32-byte key
 * into 60 words (15 round keys, 240 bytes).  Words 0-7 are the key
 * itself; thereafter every 8th word gets RotWord+SubWord+Rcon, and
 * the word halfway through each 8-word group gets SubWord only (the
 * AES-256-specific extra substitution).
 */
static void
aes256_key_expand(uint8_t rk[240], const uint8_t key[32])
{
    memcpy(rk, key, 32);
    unsigned int rc = 0;                      /* next rcon index */
    for (unsigned int w = 8; w < 60; w++) {
        const uint8_t *prev = &rk[(w - 1) * 4];
        uint8_t t0 = prev[0], t1 = prev[1], t2 = prev[2], t3 = prev[3];
        if ((w & 7) == 0) {
            /* RotWord, SubWord, then XOR the round constant into byte 0. */
            uint8_t rot = t0;
            t0 = (uint8_t)(sbox[t1] ^ rcon[rc++]);
            t1 = sbox[t2];
            t2 = sbox[t3];
            t3 = sbox[rot];
        } else if ((w & 7) == 4) {
            /* SubWord without rotation. */
            t0 = sbox[t0];
            t1 = sbox[t1];
            t2 = sbox[t2];
            t3 = sbox[t3];
        }
        rk[w * 4 + 0] = rk[(w - 8) * 4 + 0] ^ t0;
        rk[w * 4 + 1] = rk[(w - 8) * 4 + 1] ^ t1;
        rk[w * 4 + 2] = rk[(w - 8) * 4 + 2] ^ t2;
        rk[w * 4 + 3] = rk[(w - 8) * 4 + 3] ^ t3;
    }
}
/*
 * AES-256 single-block encryption (14 rounds), scalar reference path.
 *
 * Encrypts the 16-byte block `in` into `out` using the 240-byte
 * expanded schedule `rk` from aes256_key_expand.  The state is kept
 * column-major (s[0..3] = column 0), so row r of the state is the
 * strided set {s[r], s[r+4], s[r+8], s[r+12]}.  `out` and `in` may
 * alias: the work happens in a local copy.
 *
 * Not constant-time (data-dependent sbox[] loads); intended as a
 * reference/fallback when AES-NI is unavailable.
 */
static void
aes256_encrypt_block(uint8_t out[16], const uint8_t in[16], const uint8_t rk[240])
{
    uint8_t s[16];
    memcpy(s, in, 16);
    /* AddRoundKey (round 0) */
    for (int i = 0; i < 16; i++) s[i] ^= rk[i];
    for (int round = 1; round <= 14; round++) {
        /* SubBytes */
        for (int i = 0; i < 16; i++) s[i] = sbox[s[i]];
        /* ShiftRows: row r rotates left by r positions. */
        uint8_t t;
        /* Row 1: rotate left by 1. */
        t = s[1]; s[1] = s[5]; s[5] = s[9]; s[9] = s[13]; s[13] = t;
        /* Row 2: rotate by 2 == two independent swaps. */
        t = s[2]; s[2] = s[10]; s[10] = t; t = s[6]; s[6] = s[14]; s[14] = t;
        /* Row 3: rotate left by 3 == rotate right by 1. */
        t = s[15]; s[15] = s[11]; s[11] = s[7]; s[7] = s[3]; s[3] = t;
        /* MixColumns (skipped on the final round, per FIPS-197). */
        if (round < 14) {
            for (int c = 0; c < 4; c++) {
                uint8_t a0 = s[c*4], a1 = s[c*4+1], a2 = s[c*4+2], a3 = s[c*4+3];
                /* Column times the circulant {02,03,01,01};
                 * 3*x is computed as xtime(x) ^ x. */
                s[c*4]   = xtime(a0) ^ xtime(a1) ^ a1 ^ a2 ^ a3;
                s[c*4+1] = a0 ^ xtime(a1) ^ xtime(a2) ^ a2 ^ a3;
                s[c*4+2] = a0 ^ a1 ^ xtime(a2) ^ xtime(a3) ^ a3;
                s[c*4+3] = xtime(a0) ^ a0 ^ a1 ^ a2 ^ xtime(a3);
            }
        }
        /* AddRoundKey */
        for (int i = 0; i < 16; i++) s[i] ^= rk[round * 16 + i];
    }
    memcpy(out, s, 16);
}
/* ═════════════════════════════════════════════
* AES-NI Accelerated Path
* ═════════════════════════════════════════════ */
#if USE_AESNI
/* AES-256 key expansion using AES-NI */
/*
 * Key-schedule helper for AES-256's "even" round keys (word index
 * i == 0 mod 8).  `gen` is the output of _mm_aeskeygenassist_si128 on
 * the previous round key with the proper round constant: its 32-bit
 * lane 3 holds RotWord(SubWord(w)) ^ rcon, which the 0xFF shuffle
 * broadcasts to all four lanes.  The slli/xor ladder converts `key`'s
 * four words into their running (prefix) XOR, which is then combined
 * with the broadcast lane — exactly the FIPS-197 recurrence, four
 * words at a time.
 *
 * `rcon_shift` is vestigial (the round constant is already baked into
 * `gen` by the caller); it is retained so existing call sites keep
 * compiling, and explicitly voided to silence -Wunused-parameter.
 */
static inline __m128i
aes_keygen_assist(__m128i key, __m128i gen, int rcon_shift)
{
    (void)rcon_shift;
    gen = _mm_shuffle_epi32(gen, 0xFF);
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    return _mm_xor_si128(key, gen);
}
/*
 * Key-schedule helper for AES-256's "odd" round keys (word index
 * i == 4 mod 8): temp = SubWord(previous word), with no rotation and
 * no round constant.  _mm_aeskeygenassist_si128(gen, 0) leaves
 * SubWord of gen's word 3 in 32-bit lane 2; the 0xAA shuffle
 * broadcasts that lane to all four.  The slli/xor ladder forms the
 * running (prefix) XOR of `key`'s four words before combining.
 */
static inline __m128i
aes_keygen_assist2(__m128i key, __m128i gen)
{
    gen = _mm_shuffle_epi32(_mm_aeskeygenassist_si128(gen, 0), 0xAA);
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    return _mm_xor_si128(key, gen);
}
/*
 * AES-256 key schedule via AES-NI.  The two 128-bit halves of the
 * 32-byte key seed rk[0]/rk[1]; each subsequent even-indexed round
 * key uses AESKEYGENASSIST with the next round constant
 * (RotWord+SubWord path), and each odd-indexed one uses the
 * SubWord-only helper.  AES-256 consumes rcon values 0x01..0x40.
 * The rcon argument of _mm_aeskeygenassist_si128 must be a
 * compile-time immediate, hence the fully unrolled sequence.
 */
static void
aesni_key_expand(__m128i rk[15], const uint8_t key[32])
{
    rk[0] = _mm_loadu_si128((const __m128i *)key);
    rk[1] = _mm_loadu_si128((const __m128i *)(key + 16));
    rk[2] = aes_keygen_assist(rk[0], _mm_aeskeygenassist_si128(rk[1], 0x01), 0);
    rk[3] = aes_keygen_assist2(rk[1], rk[2]);
    rk[4] = aes_keygen_assist(rk[2], _mm_aeskeygenassist_si128(rk[3], 0x02), 0);
    rk[5] = aes_keygen_assist2(rk[3], rk[4]);
    rk[6] = aes_keygen_assist(rk[4], _mm_aeskeygenassist_si128(rk[5], 0x04), 0);
    rk[7] = aes_keygen_assist2(rk[5], rk[6]);
    rk[8] = aes_keygen_assist(rk[6], _mm_aeskeygenassist_si128(rk[7], 0x08), 0);
    rk[9] = aes_keygen_assist2(rk[7], rk[8]);
    rk[10] = aes_keygen_assist(rk[8], _mm_aeskeygenassist_si128(rk[9], 0x10), 0);
    rk[11] = aes_keygen_assist2(rk[9], rk[10]);
    rk[12] = aes_keygen_assist(rk[10], _mm_aeskeygenassist_si128(rk[11], 0x20), 0);
    rk[13] = aes_keygen_assist2(rk[11], rk[12]);
    rk[14] = aes_keygen_assist(rk[12], _mm_aeskeygenassist_si128(rk[13], 0x40), 0);
}
/*
 * AES-256 one-block encryption with AES-NI: initial whitening XOR,
 * 13 full rounds, then the MixColumns-free final round.  The round
 * key index is not an instruction immediate, so a loop is equivalent
 * to the unrolled form.
 */
static inline __m128i
aesni_encrypt_block(__m128i block, const __m128i rk[15])
{
    __m128i st = _mm_xor_si128(block, rk[0]);
    for (int r = 1; r < 14; r++)
        st = _mm_aesenc_si128(st, rk[r]);
    return _mm_aesenclast_si128(st, rk[14]);
}
/*
 * Carry-less 128x128-bit multiply plus reduction for GHASH using
 * PCLMULQDQ.  t0..t3 are the four 64x64 schoolbook partial products;
 * the two middle terms are folded into the low (t0) and high (t3)
 * halves, then the 256-bit product is reduced to 128 bits with two
 * folds against the constant 0xC200000000000001 (a bit-reflected form
 * of the GHASH polynomial x^128 + x^7 + x^2 + x + 1).
 *
 * NOTE(review): PCLMUL-based GHASH implementations normally operate on
 * byte/bit-reflected blocks (a PSHUFB byte-reverse before/after the
 * multiply, per Intel's GCM white paper); this routine multiplies the
 * raw blocks exactly as loaded by gcm_ghash_update.  Verify against
 * ghash_mul_scalar and the NIST GCM test vectors that the NI and
 * scalar paths produce identical tags — TODO confirm.
 */
static inline __m128i
ghash_mul_ni(__m128i a, __m128i b)
{
    __m128i t0 = _mm_clmulepi64_si128(a, b, 0x00);
    __m128i t1 = _mm_clmulepi64_si128(a, b, 0x01);
    __m128i t2 = _mm_clmulepi64_si128(a, b, 0x10);
    __m128i t3 = _mm_clmulepi64_si128(a, b, 0x11);
    /* Fold the cross terms into the low/high 128-bit halves. */
    t1 = _mm_xor_si128(t1, t2);
    t2 = _mm_slli_si128(t1, 8);
    t1 = _mm_srli_si128(t1, 8);
    t0 = _mm_xor_si128(t0, t2);
    t3 = _mm_xor_si128(t3, t1);
    /* Reduce mod x^128 + x^7 + x^2 + x + 1 (two folding steps). */
    __m128i poly = _mm_set_epi32(0, 0, 0xC2000000, 0x00000001);
    t1 = _mm_clmulepi64_si128(t0, poly, 0x00);
    t0 = _mm_shuffle_epi32(t0, 78);   /* swap 64-bit halves */
    t0 = _mm_xor_si128(t0, t1);
    t1 = _mm_clmulepi64_si128(t0, poly, 0x00);
    t0 = _mm_shuffle_epi32(t0, 78);   /* swap 64-bit halves */
    t0 = _mm_xor_si128(t0, t1);
    return _mm_xor_si128(t0, t3);
}
#endif /* USE_AESNI */
/* ═════════════════════════════════════════════
* GHASH (Scalar Reference)
* GF(2^128) multiplication for authentication
* ═════════════════════════════════════════════ */
/*
 * Mirror the 16 bytes of `in` into `out` (out[0] = in[15], ...).
 * GCM's GHASH is specified with big-endian bit ordering, so some
 * representations need a byte reversal.  Not safe for in-place use
 * (`out` and `in` must not alias).
 */
static void
block_reverse(uint8_t out[16], const uint8_t in[16])
{
    const uint8_t *src = in + 15;
    for (int i = 0; i < 16; i++)
        out[i] = *src--;
}
/*
 * Bit-serial GF(2^128) multiply for GHASH: R = A * B modulo
 * x^128 + x^7 + x^2 + x + 1, in GCM's big-endian bit ordering
 * (SP 800-38D algorithm 1).  Walks A's bits MSB-first, accumulating
 * shifted copies of B; a right shift with conditional XOR of 0xE1
 * into byte 0 performs the modular reduction.  R may alias A or B.
 */
static void
ghash_mul_scalar(uint8_t R[16], const uint8_t A[16], const uint8_t B[16])
{
    uint8_t acc[16] = {0};
    uint8_t V[16];
    memcpy(V, B, 16);
    for (int byte = 0; byte < 16; byte++) {
        for (int bit = 7; bit >= 0; bit--) {
            if (A[byte] & (1u << bit)) {
                for (int k = 0; k < 16; k++)
                    acc[k] ^= V[k];
            }
            /* V = V * x: right shift across bytes, reduce on carry-out. */
            uint8_t carry = V[15] & 1;
            for (int k = 15; k > 0; k--)
                V[k] = (uint8_t)((V[k] >> 1) | (V[k - 1] << 7));
            V[0] >>= 1;
            if (carry)
                V[0] ^= 0xE1;
        }
    }
    memcpy(R, acc, 16);
}
/* ═════════════════════════════════════════════
* GCM State
* ═════════════════════════════════════════════ */
/*
 * Per-message GCM state: expanded AES round keys (scalar and, when
 * compiled in, AES-NI form), the GHASH subkey H, the pre-counter
 * block J0 derived from the nonce, and the running authenticator.
 * Contains key material — must be zeroized after use.
 */
typedef struct {
    uint8_t rk_scalar[240];  /* Scalar round keys (15 x 16 bytes) */
#if USE_AESNI
    __m128i rk_ni[15] __attribute__((aligned(16)));  /* AES-NI round keys */
    __m128i H_ni;            /* Hash subkey for PCLMULQDQ path */
#endif
    uint8_t H[16];           /* Hash subkey: AES_K(0^128) */
    uint8_t J0[16];          /* Pre-counter block (nonce || 0x00000001) */
    uint8_t ghash_state[16]; /* Running GHASH accumulator */
    uint64_t aad_len;        /* Total AAD bytes processed */
    uint64_t ct_len;         /* Total ciphertext bytes processed */
    int use_ni;              /* Nonzero when the AES-NI path is active */
} gcm_ctx;
/*
 * Initialize GCM state: expand the 256-bit key, derive the GHASH
 * subkey H = AES_K(0^128), and build the pre-counter block J0 from
 * the 96-bit nonce as J0 = nonce || 0x00000001 (the SP 800-38D
 * fast path for 12-byte IVs).  Only 96-bit nonces are supported;
 * other lengths would require running the nonce through GHASH.
 */
static void
gcm_init(gcm_ctx *ctx, const uint8_t key[32], const uint8_t nonce[12])
{
    memset(ctx, 0, sizeof(*ctx));
#if USE_AESNI
    ctx->use_ni = 1;
    aesni_key_expand(ctx->rk_ni, key);
    /* H = AES_K(0^128) */
    ctx->H_ni = aesni_encrypt_block(_mm_setzero_si128(), ctx->rk_ni);
    _mm_storeu_si128((__m128i *)ctx->H, ctx->H_ni);
#else
    ctx->use_ni = 0;
    aes256_key_expand(ctx->rk_scalar, key);
    uint8_t zero[16] = {0};
    aes256_encrypt_block(ctx->H, zero, ctx->rk_scalar);
#endif
    /* J0 = nonce || 0x00000001 (for 96-bit nonce) */
    memcpy(ctx->J0, nonce, 12);
    ctx->J0[12] = 0; ctx->J0[13] = 0; ctx->J0[14] = 0; ctx->J0[15] = 1;
}
/*
 * inc32 (SP 800-38D): increment the rightmost 32 bits of the block as
 * a big-endian counter, wrapping modulo 2^32.  Bytes 0-11 (the nonce
 * portion) are never touched.
 */
static void
inc32(uint8_t block[16])
{
    for (int i = 15; i >= 12; i--) {
        if (++block[i] != 0)
            break;              /* no carry into the next byte */
    }
}
/*
 * Encrypt one 16-byte block under the session key, dispatching to the
 * AES-NI path when compiled in and active, otherwise to the scalar
 * reference implementation.  `in` and `out` need not be aligned
 * (unaligned loads/stores are used on the NI path).
 */
static void
gcm_encrypt_block(gcm_ctx *ctx, uint8_t out[16], const uint8_t in[16])
{
#if USE_AESNI
    if (ctx->use_ni) {
        __m128i b = _mm_loadu_si128((const __m128i *)in);
        __m128i r = aesni_encrypt_block(b, ctx->rk_ni);
        _mm_storeu_si128((__m128i *)out, r);
        return;
    }
#endif
    aes256_encrypt_block(out, in, ctx->rk_scalar);
}
/*
 * Fold one 16-byte block into the authenticator:
 *     state = (state XOR data) * H   in GF(2^128).
 * Dispatches to the PCLMULQDQ multiply when compiled in and active,
 * otherwise to the bit-serial scalar multiply.
 *
 * NOTE(review): the NI path feeds the raw block bytes to ghash_mul_ni
 * while the scalar path follows the spec's bit-reflected ordering;
 * PCLMUL GHASH normally byte-reverses its operands first.  Confirm
 * both paths yield identical tags on the same inputs — TODO verify
 * against NIST GCM test vectors.
 */
static void
gcm_ghash_update(gcm_ctx *ctx, const uint8_t data[16])
{
#if USE_AESNI
    if (ctx->use_ni) {
        __m128i s = _mm_loadu_si128((const __m128i *)ctx->ghash_state);
        __m128i d = _mm_loadu_si128((const __m128i *)data);
        s = _mm_xor_si128(s, d);
        s = ghash_mul_ni(s, ctx->H_ni);
        _mm_storeu_si128((__m128i *)ctx->ghash_state, s);
        return;
    }
#endif
    uint8_t tmp[16];
    for (int i = 0; i < 16; i++) tmp[i] = ctx->ghash_state[i] ^ data[i];
    ghash_mul_scalar(ctx->ghash_state, tmp, ctx->H);
}
/* ═════════════════════════════════════════════
* GCM Encrypt / Decrypt
* ═════════════════════════════════════════════ */
/*
 * AES-256-GCM authenticated encryption (NIST SP 800-38D).
 *
 * ct    : output, ptlen ciphertext bytes
 * tag   : output, AEAD_TAG_BYTES (16-byte) authentication tag
 * pt    : plaintext, ptlen bytes
 * aad   : additional authenticated data (authenticated, not encrypted)
 * nonce : AEAD_NONCE_BYTES (12-byte) nonce — MUST be unique per key;
 *         nonce reuse under GCM leaks keystream and the auth key
 * key   : AEAD_KEY_BYTES (32-byte) key
 *
 * Returns 0 on success, -1 if ptlen exceeds the GCM plaintext limit.
 */
int
aead_encrypt(uint8_t *ct, uint8_t tag[AEAD_TAG_BYTES],
             const uint8_t *pt, size_t ptlen,
             const uint8_t *aad, size_t aadlen,
             const uint8_t nonce[AEAD_NONCE_BYTES],
             const uint8_t key[AEAD_KEY_BYTES])
{
    /* SP 800-38D 5.2.1.1: plaintext is limited to 2^39 - 256 bits
     * (2^36 - 32 bytes).  Beyond that the 32-bit block counter would
     * wrap back to J0+1 and reuse keystream. */
    if ((uint64_t)ptlen > 68719476704ULL)
        return -1;
    gcm_ctx ctx;
    gcm_init(&ctx, key, nonce);
    /* GHASH the AAD: full blocks, then one zero-padded partial block. */
    size_t i;
    for (i = 0; i + 16 <= aadlen; i += 16)
        gcm_ghash_update(&ctx, aad + i);
    if (i < aadlen) {
        uint8_t pad[16] = {0};
        memcpy(pad, aad + i, aadlen - i);
        gcm_ghash_update(&ctx, pad);
    }
    ctx.aad_len = aadlen;
    /* CTR encryption + GHASH over the ciphertext.  The first counter
     * is J0+1; J0 itself is reserved for masking the tag. */
    uint8_t ctr[16], keystream[16];
    memcpy(ctr, ctx.J0, 16);
    inc32(ctr);
    for (i = 0; i + 16 <= ptlen; i += 16) {
        gcm_encrypt_block(&ctx, keystream, ctr);
        for (int j = 0; j < 16; j++) ct[i + j] = pt[i + j] ^ keystream[j];
        gcm_ghash_update(&ctx, ct + i);
        inc32(ctr);
    }
    if (i < ptlen) {
        /* Final partial block: encrypt, then GHASH it zero-padded. */
        gcm_encrypt_block(&ctx, keystream, ctr);
        uint8_t pad[16] = {0};
        for (size_t j = 0; j < ptlen - i; j++) ct[i + j] = pt[i + j] ^ keystream[j];
        memcpy(pad, ct + i, ptlen - i);
        gcm_ghash_update(&ctx, pad);
    }
    ctx.ct_len = ptlen;
    /* Final GHASH block: len(A) || len(C) in bits, big-endian 64-bit. */
    uint8_t len_block[16] = {0};
    uint64_t aad_bits = ctx.aad_len * 8;
    uint64_t ct_bits = ctx.ct_len * 8;
    for (int j = 0; j < 8; j++) {
        len_block[j] = (uint8_t)(aad_bits >> (56 - 8 * j));
        len_block[8 + j] = (uint8_t)(ct_bits >> (56 - 8 * j));
    }
    gcm_ghash_update(&ctx, len_block);
    /* Tag = AES_K(J0) XOR GHASH */
    uint8_t s[16];
    gcm_encrypt_block(&ctx, s, ctx.J0);
    for (int j = 0; j < 16; j++) tag[j] = s[j] ^ ctx.ghash_state[j];
    /* Zeroize secrets with volatile stores: a plain memset of dying
     * locals is a dead store the compiler may legally elide.  Also
     * wipe s, which holds the tag mask E_K(J0). */
    volatile uint8_t *v;
    size_t z;
    v = (volatile uint8_t *)&ctx;
    for (z = 0; z < sizeof ctx; z++) v[z] = 0;
    v = (volatile uint8_t *)keystream;
    for (z = 0; z < sizeof keystream; z++) v[z] = 0;
    v = (volatile uint8_t *)s;
    for (z = 0; z < sizeof s; z++) v[z] = 0;
    return 0;
}
/*
 * AES-256-GCM authenticated decryption (NIST SP 800-38D).
 *
 * Verifies the tag over `aad` and `ct` BEFORE producing any plaintext;
 * on mismatch the output buffer is zeroed and -1 is returned.  Tag
 * comparison is constant-time (no early exit).
 *
 * pt    : output, ctlen plaintext bytes
 * ct    : ciphertext, ctlen bytes
 * tag   : expected AEAD_TAG_BYTES (16-byte) tag
 * aad   : additional authenticated data
 * nonce : 12-byte nonce used at encryption time
 * key   : 32-byte key
 *
 * Returns 0 on success, -1 on authentication failure or oversize input.
 */
int
aead_decrypt(uint8_t *pt,
             const uint8_t *ct, size_t ctlen,
             const uint8_t tag[AEAD_TAG_BYTES],
             const uint8_t *aad, size_t aadlen,
             const uint8_t nonce[AEAD_NONCE_BYTES],
             const uint8_t key[AEAD_KEY_BYTES])
{
    /* Same message-size bound as encryption (SP 800-38D): past
     * 2^36 - 32 bytes the 32-bit CTR counter would wrap. */
    if ((uint64_t)ctlen > 68719476704ULL)
        return -1;
    gcm_ctx ctx;
    gcm_init(&ctx, key, nonce);
    /* GHASH the AAD: full blocks, then one zero-padded partial block. */
    size_t i;
    for (i = 0; i + 16 <= aadlen; i += 16)
        gcm_ghash_update(&ctx, aad + i);
    if (i < aadlen) {
        uint8_t pad[16] = {0};
        memcpy(pad, aad + i, aadlen - i);
        gcm_ghash_update(&ctx, pad);
    }
    ctx.aad_len = aadlen;
    /* GHASH the ciphertext — authenticate before decrypting. */
    for (i = 0; i + 16 <= ctlen; i += 16)
        gcm_ghash_update(&ctx, ct + i);
    if (i < ctlen) {
        uint8_t pad[16] = {0};
        memcpy(pad, ct + i, ctlen - i);
        gcm_ghash_update(&ctx, pad);
    }
    ctx.ct_len = ctlen;
    /* Final GHASH block: len(A) || len(C) in bits, big-endian 64-bit. */
    uint8_t len_block[16] = {0};
    uint64_t aad_bits = ctx.aad_len * 8;
    uint64_t ct_bits = ctx.ct_len * 8;
    for (int j = 0; j < 8; j++) {
        len_block[j] = (uint8_t)(aad_bits >> (56 - 8 * j));
        len_block[8 + j] = (uint8_t)(ct_bits >> (56 - 8 * j));
    }
    gcm_ghash_update(&ctx, len_block);
    /* Expected tag = AES_K(J0) XOR GHASH */
    uint8_t s[16];
    gcm_encrypt_block(&ctx, s, ctx.J0);
    uint8_t expected_tag[16];
    for (int j = 0; j < 16; j++) expected_tag[j] = s[j] ^ ctx.ghash_state[j];
    /* Constant-time tag comparison. */
    uint8_t diff = 0;
    for (int j = 0; j < 16; j++) diff |= expected_tag[j] ^ tag[j];
    volatile uint8_t *v;
    size_t z;
    if (diff != 0) {
        /* Never release unauthenticated data; guard against
         * memset(NULL, 0, 0) when ctlen == 0. */
        if (ctlen > 0)
            memset(pt, 0, ctlen);
        /* Volatile wipes: plain memset of dying locals can be elided. */
        v = (volatile uint8_t *)&ctx;
        for (z = 0; z < sizeof ctx; z++) v[z] = 0;
        v = (volatile uint8_t *)s;
        for (z = 0; z < sizeof s; z++) v[z] = 0;
        v = (volatile uint8_t *)expected_tag;
        for (z = 0; z < sizeof expected_tag; z++) v[z] = 0;
        return -1;
    }
    /* CTR decrypt (same keystream as encryption), counter starts at J0+1. */
    uint8_t ctr[16], keystream[16];
    memcpy(ctr, ctx.J0, 16);
    inc32(ctr);
    for (i = 0; i + 16 <= ctlen; i += 16) {
        gcm_encrypt_block(&ctx, keystream, ctr);
        for (int j = 0; j < 16; j++) pt[i + j] = ct[i + j] ^ keystream[j];
        inc32(ctr);
    }
    if (i < ctlen) {
        gcm_encrypt_block(&ctx, keystream, ctr);
        for (size_t j = 0; j < ctlen - i; j++) pt[i + j] = ct[i + j] ^ keystream[j];
    }
    /* Zeroize all key-dependent material with volatile stores. */
    v = (volatile uint8_t *)&ctx;
    for (z = 0; z < sizeof ctx; z++) v[z] = 0;
    v = (volatile uint8_t *)keystream;
    for (z = 0; z < sizeof keystream; z++) v[z] = 0;
    v = (volatile uint8_t *)s;
    for (z = 0; z < sizeof s; z++) v[z] = 0;
    v = (volatile uint8_t *)expected_tag;
    for (z = 0; z < sizeof expected_tag; z++) v[z] = 0;
    return 0;
}