/* Source path: src/orchestrator/orchestrator.c */

/*
 * QR-NSP Volcanic Edition — Unified Orchestrator
 * Module 8: Pipeline controller
 *
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include "qrnsp_orchestrator.h"
#include "qrnsp_aead.h"
#include "mlkem_params.h"
#include <string.h>
#include <stdlib.h>

/* ═════════════════════════════════════════════
 * Init
 * ═════════════════════════════════════════════ */

/*
 * Reset a session object and prepare it for a later handshake.
 *
 * The whole structure is zeroed, the role is recorded, and a fresh hybrid
 * keypair is generated into sess->pk / sess->sk. The morph engine is set up
 * for the requested cover profile with an all-zero placeholder seed and then
 * marked inactive; orch_handshake() re-initialises it with the real shared
 * secret.
 *
 * Returns 0 on success, or -1 if keypair generation fails (in which case the
 * session has been zeroed and only role/state were set).
 */
int
orch_init(orch_session_t *sess, int role, morph_cover_t cover)
{
    uint8_t placeholder_seed[32];

    memset(sess, 0, sizeof(*sess));
    sess->state = ORCH_STATE_INIT;
    sess->role = role;

    /* Local long-term keypair for this session */
    if (hybrid_keypair_generate(sess->pk, sess->sk) != 0) {
        return -1;
    }

    /* Morph gets a throwaway zero seed now; it stays switched off until
     * the handshake has produced a real secret to seed it with. */
    memset(placeholder_seed, 0, sizeof(placeholder_seed));
    morph_init(&sess->morph, cover, placeholder_seed);
    sess->morph.active = 0;

    sess->next_msg_id = 1;
    sess->cover_interval_ns = 100000000ULL; /* 100 ms */

    return 0;
}

/*
 * Copy this session's public key into a caller-supplied buffer.
 *
 * sess: session whose public key is read (not modified).
 * pk:   destination, must hold QRNSP_CRYPTO_PUBLICKEYBYTES bytes.
 *
 * Returns 0 on success, -1 if either pointer is NULL.
 */
int
orch_get_pubkey(const orch_session_t *sess,
                uint8_t pk[QRNSP_CRYPTO_PUBLICKEYBYTES])
{
    /* The function already reports a status; use it for bad arguments
     * instead of dereferencing a NULL pointer. */
    if (sess == NULL || pk == NULL)
        return -1;

    memcpy(pk, sess->pk, QRNSP_CRYPTO_PUBLICKEYBYTES);
    return 0;
}

/* ═════════════════════════════════════════════
 * Handshake: KEM exchange → derive all session keys
 * ═════════════════════════════════════════════ */

/*
 * Run the KEM step of the handshake and bring every sub-module online.
 *
 * sess:          session previously set up with orch_init().
 * peer_pk:       the peer's public key; a copy is kept in sess->peer_pk.
 * ct:            KEM ciphertext. Written when is_ciphertext == 0
 *                (encapsulation), read when is_ciphertext != 0
 *                (decapsulation).
 * is_ciphertext: 0 -> encapsulate to peer_pk and emit ct;
 *                non-zero -> decapsulate the received ct with our own key.
 *
 * On success the shared secret seeds the stego, jitter, fallback and morph
 * modules, morph is activated, the initial capability set is
 * QUIC padding + jitter, and the session state becomes ORCH_STATE_ACTIVE.
 *
 * Returns 0 on success, -1 if the KEM operation fails. On failure the state
 * is left at ORCH_STATE_HANDSHAKE and the stored peer key is kept.
 */
int
orch_handshake(orch_session_t *sess,
               const uint8_t peer_pk[QRNSP_CRYPTO_PUBLICKEYBYTES],
               uint8_t ct[QRNSP_CRYPTO_CIPHERTEXTBYTES],
               int is_ciphertext)
{
    memcpy(sess->peer_pk, peer_pk, QRNSP_CRYPTO_PUBLICKEYBYTES);
    sess->state = ORCH_STATE_HANDSHAKE;

    if (is_ciphertext) {
        /* A ciphertext was handed to us: recover the secret from it */
        if (hybrid_decapsulate(sess->shared_secret, ct,
                               sess->sk, sess->pk) != 0) {
            return -1;
        }
    } else {
        /* No ciphertext yet: produce one for the peer's public key */
        if (hybrid_encapsulate(ct, sess->shared_secret, peer_pk) != 0) {
            return -1;
        }
    }

    sess->kem_complete = 1;

    /* Every sub-module below is keyed from the same shared secret. */

    /* QUIC steganography (module 3) */
    stego_session_init(&sess->stego, sess->shared_secret, sess->role);

    /* Jitter channel, both directions (module 4) */
    jitter_tx_init(&sess->jitter_tx, sess->shared_secret);
    jitter_rx_init(&sess->jitter_rx, sess->shared_secret);

    /* TCP fallback (module 6) */
    fallback_init(&sess->fallback, sess->shared_secret);

    /* Morph (module 7): same cover profile as before, real seed this time */
    morph_init(&sess->morph, sess->morph.profile.type, sess->shared_secret);
    sess->morph.active = 1;

    sess->state = ORCH_STATE_ACTIVE;
    sess->capabilities = ORCH_CAP_QUIC_PADDING | ORCH_CAP_JITTER;

    return 0;
}

/* ═════════════════════════════════════════════
 * Chunking: split messages into packet-sized pieces
 * ═════════════════════════════════════════════ */

/*
 * Split a message into consecutive chunks of at most max_chunk_data bytes.
 *
 * msg, msg_len:   message to split.
 * msg_id:         identifier stamped on every chunk.
 * max_chunk_data: requested payload bytes per chunk; clamped to 1400.
 * chunks:         output array with room for max_chunks entries.
 * max_chunks:     capacity of the output array.
 *
 * Returns the number of chunks written (0 for an empty message), or -1 when
 * the arguments are invalid or the message needs more chunks than either
 * max_chunks or the 16-bit chunk counter can represent.
 */
static int
chunk_message(const uint8_t *msg, size_t msg_len,
              uint32_t msg_id, size_t max_chunk_data,
              orch_chunk_t *chunks, int max_chunks)
{
    if (msg_len == 0) return 0;
    if (msg == NULL || chunks == NULL || max_chunks <= 0) return -1;
    if (max_chunk_data == 0) return -1;          /* would divide by zero */
    if (max_chunk_data > 1400) max_chunk_data = 1400;

    /* Ceiling division without the (msg_len + max - 1) addition, which can
     * wrap for very large msg_len. Kept in size_t so the capacity checks
     * see the real count rather than a value truncated to 16 bits. */
    size_t total = msg_len / max_chunk_data +
                   (msg_len % max_chunk_data != 0 ? 1u : 0u);
    if (total > (size_t)max_chunks) return -1;
    if (total > 0xFFFFu) return -1;              /* chunk fields are 16-bit */

    for (size_t i = 0; i < total; i++) {
        orch_chunk_t *out = &chunks[i];
        size_t offset = i * max_chunk_data;
        size_t remaining = msg_len - offset;
        size_t take = remaining < max_chunk_data ? remaining : max_chunk_data;

        out->msg_id = msg_id;
        out->chunk_index = (uint16_t)i;
        out->total_chunks = (uint16_t)total;
        out->data_len = (uint16_t)take;
        memcpy(out->data, msg + offset, take);
    }

    return (int)total;
}

/*
 * Write a chunk to buf in wire order and return the number of bytes written.
 *
 * Layout (big-endian fields):
 *   [msg_id : 4][chunk_index : 2][total_chunks : 2][data : data_len]
 *
 * The caller must provide at least ORCH_CHUNK_HEADER + c->data_len bytes;
 * no bounds check is performed here.
 */
static size_t
serialize_chunk(const orch_chunk_t *c, uint8_t *buf)
{
    const uint32_t id    = c->msg_id;
    const uint16_t index = c->chunk_index;
    const uint16_t count = c->total_chunks;

    /* 32-bit message id, most significant byte first */
    buf[0] = (uint8_t)((id >> 24) & 0xFFu);
    buf[1] = (uint8_t)((id >> 16) & 0xFFu);
    buf[2] = (uint8_t)((id >> 8) & 0xFFu);
    buf[3] = (uint8_t)(id & 0xFFu);

    /* 16-bit chunk index, then 16-bit chunk count */
    buf[4] = (uint8_t)((index >> 8) & 0xFFu);
    buf[5] = (uint8_t)(index & 0xFFu);
    buf[6] = (uint8_t)((count >> 8) & 0xFFu);
    buf[7] = (uint8_t)(count & 0xFFu);

    /* Payload follows the header directly */
    memcpy(&buf[ORCH_CHUNK_HEADER], c->data, c->data_len);

    return (size_t)ORCH_CHUNK_HEADER + c->data_len;
}

/*
 * Parse a serialized chunk ([MSG_ID:4][CHUNK:2][TOTAL:2][DATA:N], big-endian)
 * from buf into *c.
 *
 * buf, len: received bytes; everything after the header is payload.
 * c:        output chunk; only written when the input is accepted.
 *
 * Returns 0 on success, -1 when the input is shorter than the header, the
 * payload does not fit c->data, or the header is inconsistent
 * (total_chunks == 0 or chunk_index >= total_chunks).
 */
static int
deserialize_chunk(const uint8_t *buf, size_t len, orch_chunk_t *c)
{
    if (buf == NULL || c == NULL) return -1;
    if (len < ORCH_CHUNK_HEADER) return -1;

    /* Bounds-check the payload length as size_t, before it is narrowed to
     * the 16-bit data_len field; narrowing first would let an oversized
     * input wrap around and pass the check. */
    size_t payload_len = len - ORCH_CHUNK_HEADER;
    if (payload_len > sizeof(c->data) || payload_len > 0xFFFFu) return -1;

    uint32_t msg_id = ((uint32_t)buf[0] << 24) | ((uint32_t)buf[1] << 16) |
                      ((uint32_t)buf[2] << 8) | (uint32_t)buf[3];
    uint16_t index  = (uint16_t)(((uint16_t)buf[4] << 8) | buf[5]);
    uint16_t total  = (uint16_t)(((uint16_t)buf[6] << 8) | buf[7]);

    /* A well-formed chunk belongs to a message with at least one chunk and
     * has an index inside that message. */
    if (total == 0 || index >= total) return -1;

    c->msg_id       = msg_id;
    c->chunk_index  = index;
    c->total_chunks = total;
    c->data_len     = (uint16_t)payload_len;
    memcpy(c->data, buf + ORCH_CHUNK_HEADER, payload_len);
    return 0;
}

/* ═════════════════════════════════════════════
 * Send: message → chunk → serialize → morph (size/delay) → pad → packets
 * (encryption and stego injection are not performed by this code yet)
 * ═════════════════════════════════════════════ */

/*
 * Turn one message into a sequence of shaped, padded output packets.
 *
 * sess:       session in ORCH_STATE_ACTIVE or ORCH_STATE_FALLBACK.
 * msg:        message bytes (may be NULL only when msg_len == 0).
 * msg_len:    message length in bytes.
 * packets:    output array of 2048-byte packet buffers.
 * pkt_sizes:  output, size of each produced packet.
 * pkt_delays: output, delay in nanoseconds suggested by morph_shape()
 *             for each packet.
 * max_pkts:   capacity of the three output arrays.
 * n_pkts:     output, number of packets produced; set to 0 on every
 *             failure path.
 *
 * Returns 0 on success, -1 on invalid arguments, wrong session state, or
 * when the message needs more chunks than the local limit (256) or than
 * max_pkts allows. A message id is consumed even when chunking fails.
 *
 * NOTE(review): messages are chunked here with a 1024-byte stride, while
 * orch_receive() in this file places chunk i at offset i * 1400. As written,
 * a multi-chunk message would be reassembled with gaps between chunks —
 * confirm the intended stride and align both sides.
 *
 * NOTE(review): chunks[] is a 256-entry array of orch_chunk_t on the stack;
 * if each entry carries a ~1400-byte payload this is several hundred KB —
 * verify against the stack size of calling threads.
 */
int
orch_send(orch_session_t *sess,
          const uint8_t *msg, size_t msg_len,
          uint8_t packets[][2048], size_t *pkt_sizes,
          uint64_t *pkt_delays, int max_pkts, int *n_pkts)
{
    if (n_pkts == NULL) return -1;
    *n_pkts = 0; /* give the caller a defined count on every return path */

    if (sess == NULL) return -1;
    if (msg == NULL && msg_len != 0) return -1;

    if (sess->state != ORCH_STATE_ACTIVE &&
        sess->state != ORCH_STATE_FALLBACK)
        return -1;

    /* Determine max chunk payload based on available stego capacity */
    size_t max_chunk_data = 1024; /* Conservative default */

    /* Chunk the message */
    orch_chunk_t chunks[256];
    int nchunks = chunk_message(msg, msg_len, sess->next_msg_id++,
                                max_chunk_data, chunks, 256);
    if (nchunks < 0) return -1;
    if (nchunks > max_pkts) return -1;

    /* Output arrays are only required when there is something to emit */
    if (nchunks > 0 &&
        (packets == NULL || pkt_sizes == NULL || pkt_delays == NULL))
        return -1;

    /* Process each chunk */
    for (int i = 0; i < nchunks; i++) {
        uint8_t *pkt = packets[*n_pkts];

        /* Serialize chunk */
        uint8_t chunk_buf[1500];
        size_t chunk_len = serialize_chunk(&chunks[i], chunk_buf);

        /* Determine morphed packet size and delay. Initialised so the
         * values are defined even if morph_shape() leaves them untouched. */
        size_t target_size = 0;
        uint64_t delay_ns = 0;
        morph_shape(&sess->morph, chunk_len + 64, &target_size, &delay_ns);

        /* Keep the shaped size inside the packet buffer, and never below
         * the payload itself: a smaller target would truncate the chunk
         * and hand morph_pad() a range whose end precedes its start. */
        if (target_size > sizeof(packets[0])) target_size = sizeof(packets[0]);
        if (target_size < chunk_len) target_size = chunk_len;

        /* Build output packet: for now, chunk_buf IS the payload.
         * In production, this would be injected into a QUIC PADDING
         * region of a real cover traffic packet. */
        memset(pkt, 0, target_size);
        memcpy(pkt, chunk_buf, chunk_len);

        /* Pad to target size */
        morph_pad(&sess->morph, pkt, chunk_len, target_size);

        pkt_sizes[*n_pkts] = target_size;
        pkt_delays[*n_pkts] = delay_ns;

        morph_record_tx(&sess->morph, target_size, 0 /* timestamp filled by caller */);

        sess->chunks_sent++;
        (*n_pkts)++;
    }

    sess->msgs_sent++;
    sess->bytes_sent += msg_len;

    return 0;
}

/* ═════════════════════════════════════════════
 * Receive: packet → extract → dechunk → reassemble
 * ═════════════════════════════════════════════ */

/*
 * Look up the reassembly slot for msg_id, creating one if none exists.
 *
 * sess:   session owning the reassembly table (at most 4 live entries).
 * msg_id: message identifier taken from the chunk header.
 * total:  number of chunks the message claims to have; sizes the buffer
 *         (total * 1400 bytes) when a new slot is created.
 *
 * When the table is full, the oldest entry (index 0) is evicted and its
 * buffer freed — including a completed message that was never read.
 *
 * Returns the slot, or NULL when total is 0 or the buffer cannot be
 * allocated. On failure the table is left exactly as it was.
 */
static orch_reassembly_t *
find_or_create_reassembly(orch_session_t *sess, uint32_t msg_id, uint16_t total)
{
    /* A message with zero chunks has nothing to reassemble, and
     * calloc(1, 0) has an implementation-defined result. */
    if (total == 0)
        return NULL;

    /* Find existing */
    for (int i = 0; i < sess->reassembly_count; i++) {
        if (sess->reassembly[i].msg_id == msg_id)
            return &sess->reassembly[i];
    }

    /* Allocate before touching the table: if this fails, no entry is
     * evicted and no half-initialised slot (buffer == NULL) is left behind
     * to shadow later chunks of the same message. */
    size_t buffer_size = (size_t)total * 1400;
    void *buffer = calloc(1, buffer_size);
    if (buffer == NULL)
        return NULL;

    /* Make room by evicting the oldest entry */
    if (sess->reassembly_count >= 4) {
        free(sess->reassembly[0].buffer);
        memmove(&sess->reassembly[0], &sess->reassembly[1],
                3 * sizeof(orch_reassembly_t));
        sess->reassembly_count = 3;
    }

    orch_reassembly_t *r = &sess->reassembly[sess->reassembly_count++];
    memset(r, 0, sizeof(*r));
    r->msg_id = msg_id;
    r->total_chunks = total;
    r->buffer_size = buffer_size;
    r->buffer = buffer;

    return r;
}

/*
 * Process one received packet: extract a hidden chunk (if any) and add it
 * to the matching reassembly buffer.
 *
 * sess:         session with a completed KEM exchange.
 * pkt, pkt_len: raw received packet.
 * timestamp_ns: arrival time; forwarded to the jitter receiver and the
 *               fallback metrics for every packet, hidden payload or not.
 *
 * Returns  1 when the chunk's message is fully received (fetch it with
 *            orch_recv_msg()),
 *          0 when the chunk was accepted but the message is incomplete,
 *         -1 when the session is not ready, the packet carries no
 *            extractable chunk, or the chunk is malformed or inconsistent
 *            with the reassembly state.
 *
 * NOTE(review): chunks are placed at chunk_index * 1400, whereas orch_send()
 * in this file chunks with a 1024-byte stride — confirm the intended stride.
 */
int
orch_receive(orch_session_t *sess,
             const uint8_t *pkt, size_t pkt_len,
             uint64_t timestamp_ns)
{
    if (sess == NULL) return -1;
    if (!sess->kem_complete) return -1;

    /* Try QUIC PADDING extraction first */
    uint8_t extracted[2048];
    size_t extracted_len = 0;
    int found = 0;

    if (sess->capabilities & ORCH_CAP_QUIC_PADDING) {
        stego_result_t sr = stego_extract(&sess->stego, pkt, pkt_len,
                                          extracted, sizeof(extracted),
                                          &extracted_len);
        if (sr == STEGO_OK) found = 1;
    }

    /* Feed to jitter receiver regardless */
    if (sess->capabilities & ORCH_CAP_JITTER) {
        jitter_rx_feed(&sess->jitter_rx, timestamp_ns);
    }

    /* Update fallback metrics */
    fallback_quic_rx(&sess->fallback, timestamp_ns);

    if (!found) return -1; /* Not a QR-NSP packet */

    /* Deserialize chunk */
    orch_chunk_t chunk;
    if (deserialize_chunk(extracted, extracted_len, &chunk) != 0)
        return -1;

    sess->chunks_received++;

    /* Header sanity. The bitmap below is indexed as 64 words of 16 bits,
     * so a message claiming more than 64 * 16 chunks could never complete
     * and would only pin a large buffer. A zero total would be reported
     * complete immediately, and an index at or past the total could mark
     * a bit that does not belong to the message. */
    if (chunk.total_chunks == 0 ||
        chunk.total_chunks > 64 * 16 ||
        chunk.chunk_index >= chunk.total_chunks)
        return -1;

    /* Reassemble */
    orch_reassembly_t *r = find_or_create_reassembly(
        sess, chunk.msg_id, chunk.total_chunks);
    if (!r || !r->buffer) return -1;

    /* The slot may predate this chunk; its header must agree with it */
    if (chunk.total_chunks != r->total_chunks) return -1;

    /* Store chunk */
    size_t offset = (size_t)chunk.chunk_index * 1400;
    if (offset + chunk.data_len > r->buffer_size) return -1;
    memcpy(r->buffer + offset, chunk.data, chunk.data_len);

    /* Mark received; duplicates are stored again but counted once */
    unsigned word = chunk.chunk_index / 16u;
    unsigned bit_mask = 1u << (chunk.chunk_index % 16u);
    int newly_marked = 0;
    if (!(r->received_mask[word] & bit_mask)) {
        r->received_mask[word] |= bit_mask;
        r->chunks_received++;
        newly_marked = 1;
    }

    if (r->first_chunk_ns == 0)
        r->first_chunk_ns = timestamp_ns;

    /* Check if message is complete */
    if (r->chunks_received >= r->total_chunks) {
        /* Count the message only on the chunk that completed it, not on
         * every duplicate that arrives before the message is fetched. */
        if (newly_marked)
            sess->msgs_received++;
        return 1; /* Message complete */
    }

    return 0; /* Still assembling */
}

/*
 * Hand the first fully reassembled message to the caller and release its
 * reassembly slot.
 *
 * sess:    session whose reassembly table is scanned in order.
 * out:     destination buffer.
 * out_cap: capacity of out in bytes.
 * out_len: receives the number of bytes copied.
 *
 * The reported length is total_chunks * 1400, capped at out_cap. It is
 * therefore the padded buffer length rather than the exact message length,
 * and a message larger than out_cap is truncated and still consumed.
 *
 * Returns 0 when a message was delivered, -1 when none is complete.
 */
int
orch_recv_msg(orch_session_t *sess,
              uint8_t *out, size_t out_cap, size_t *out_len)
{
    for (int idx = 0; idx < sess->reassembly_count; idx++) {
        orch_reassembly_t *slot = &sess->reassembly[idx];

        /* Skip slots that are still waiting for chunks or have no buffer */
        if (slot->chunks_received < slot->total_chunks || !slot->buffer)
            continue;

        /* Padded length of the message, limited to the caller's capacity */
        size_t copy_len = (size_t)slot->total_chunks * 1400;
        copy_len = (copy_len > out_cap) ? out_cap : copy_len;

        memcpy(out, slot->buffer, copy_len);
        *out_len = copy_len;
        sess->bytes_received += copy_len;

        /* Release the buffer and close the gap in the table */
        free(slot->buffer);
        slot->buffer = NULL;
        memmove(slot, slot + 1,
                (sess->reassembly_count - idx - 1) * sizeof(orch_reassembly_t));
        sess->reassembly_count--;

        return 0;
    }

    return -1;
}

/* ═════════════════════════════════════════════
 * Tick: periodic maintenance
 * ═════════════════════════════════════════════ */

/*
 * Periodic maintenance: advance the fallback transport state machine and
 * mirror its result into the session state and capability bits.
 *
 * Does nothing while the session state is below ORCH_STATE_ACTIVE.
 * Always returns 0.
 *
 * Transport result -> effect:
 *   QUIC_OK                 : FALLBACK returns to ACTIVE and QUIC padding
 *                             is re-enabled; other states are untouched.
 *   HTTP2_ACTIVE / DEGRADED : state FALLBACK; QUIC padding off, TCP chaff
 *                             and H2 padding on.
 *   BLOCKED                 : state JITTER_ONLY; capabilities reduced to
 *                             the jitter channel alone.
 *   anything else           : no change.
 *
 * NOTE(review): QUIC_OK does not lift a session out of JITTER_ONLY, and the
 * TCP chaff / H2 padding bits are not cleared on recovery to ACTIVE —
 * confirm both are intended.
 */
int
orch_tick(orch_session_t *sess, uint64_t now_ns)
{
    if (sess->state < ORCH_STATE_ACTIVE) {
        return 0;
    }

    const transport_state_t transport = fallback_tick(&sess->fallback, now_ns);

    if (transport == TRANSPORT_QUIC_OK) {
        /* QUIC is healthy again: leave fallback mode if we were in it */
        if (sess->state == ORCH_STATE_FALLBACK) {
            sess->state = ORCH_STATE_ACTIVE;
            sess->capabilities |= ORCH_CAP_QUIC_PADDING;
        }
    } else if (transport == TRANSPORT_HTTP2_ACTIVE ||
               transport == TRANSPORT_HTTP2_DEGRADED) {
        /* Running over HTTP/2: swap QUIC padding for the TCP-side channels */
        sess->state = ORCH_STATE_FALLBACK;
        sess->capabilities &= ~ORCH_CAP_QUIC_PADDING;
        sess->capabilities |= ORCH_CAP_TCP_CHAFF | ORCH_CAP_H2_PADDING;
    } else if (transport == TRANSPORT_BLOCKED) {
        /* Nothing but timing is left */
        sess->state = ORCH_STATE_JITTER_ONLY;
        sess->capabilities = ORCH_CAP_JITTER;
    }

    return 0;
}

/* ═════════════════════════════════════════════
 * Destroy
 * ═════════════════════════════════════════════ */

/*
 * Tear down a session: free reassembly buffers, destroy every sub-module,
 * wipe the whole structure (it holds the secret key and shared secret), and
 * leave it in ORCH_STATE_CLOSED.
 *
 * Passing NULL is a no-op.
 */
void
orch_destroy(orch_session_t *sess)
{
    if (sess == NULL)
        return;

    /* Free reassembly buffers (free(NULL) is a no-op) */
    for (int i = 0; i < sess->reassembly_count; i++) {
        free(sess->reassembly[i].buffer);
        sess->reassembly[i].buffer = NULL;
    }

    stego_session_destroy(&sess->stego);
    jitter_tx_destroy(&sess->jitter_tx);
    jitter_rx_destroy(&sess->jitter_rx);
    fallback_destroy(&sess->fallback);
    morph_destroy(&sess->morph);

    /* Wipe through a volatile pointer: a plain memset() of memory that is
     * about to be discarded may be removed by the optimizer, which would
     * leave key material behind. */
    volatile unsigned char *wipe = (volatile unsigned char *)sess;
    for (size_t remaining = sizeof(*sess); remaining > 0; remaining--)
        *wipe++ = 0;

    sess->state = ORCH_STATE_CLOSED;
}

/* ═════════════════════════════════════════════
 * Stats
 * ═════════════════════════════════════════════ */

/*
 * Fill *stats with a snapshot of the session's counters and status.
 *
 * Only the fields assigned below are written; any other members of *stats
 * keep whatever value the caller left in them.
 */
void
orch_get_stats(const orch_session_t *sess, orch_stats_t *stats)
{
    /* Session status */
    stats->capabilities = sess->capabilities;
    stats->state        = sess->state;

    /* Traffic counters, send side then receive side */
    stats->bytes_sent     = sess->bytes_sent;
    stats->msgs_sent      = sess->msgs_sent;
    stats->bytes_received = sess->bytes_received;
    stats->msgs_received  = sess->msgs_received;
    stats->cover_packets  = sess->cover_packets_sent;

    /* Transport fallback history */
    stats->recovery_count = sess->fallback.recovery_count;
    stats->fallback_count = sess->fallback.fallback_count;

    /* Quality figure as reported by the morph engine */
    stats->morph_quality = morph_quality(&sess->morph);
}