diff options
Diffstat (limited to 'src/os/linux/linux_net_secure_stream.c')
| -rw-r--r-- | src/os/linux/linux_net_secure_stream.c | 579 | 
1 files changed, 579 insertions, 0 deletions
diff --git a/src/os/linux/linux_net_secure_stream.c b/src/os/linux/linux_net_secure_stream.c new file mode 100644 index 0000000..cfacfc3 --- /dev/null +++ b/src/os/linux/linux_net_secure_stream.c @@ -0,0 +1,579 @@ +#include <os/os.h> +#include <basic/basic.h> +#include <basic/arena.h> +#include <crypto/rsa.h> +#include <crypto/aes_gcm.h> + +#include <string.h> + +#include <openssl/aes.h> +#include <openssl/rsa.h> +#include <openssl/rand.h> + +#include <netinet/in.h> +#include <arpa/inet.h> +#include <sys/socket.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> + +typedef struct { +    // space left for 512 - 64*2 - 2 = 382 for 4096-bit rsa and SHA-512 + +    AesGcmKey aes_key_client_encrypt; // 32 bytes +    AesGcmIv  aes_iv_client_encrypt;  // 12 bytes + +    AesGcmKey aes_key_client_decrypt; // 32 bytes +    AesGcmIv  aes_iv_client_decrypt;  // 12 bytes + +    // space left for 382 - 2*32 - 2*12 = 294 for 4096-bit rsa and SHA-512 + +    u8 client_random[256]; +} CM_Handshake; + + +typedef struct { +    u8 client_random[256]; +} SM_Handshake; + + +typedef struct { +    int fd; +    OSNetSecureStreamStatus status; + +    EVP_PKEY *rsa_key; + +    AesGcmKey    recv_aes_key;    // recvd with CM_Handshake +    AesGcmIv     recv_aes_iv;     // recvd_with CM_Handshake +    AesGcmHeader recv_aes_header; // recvd with each aes message +    b32          recv_aes_header_initialized; +    size_t       recv_ciphertext_size_filled; +    u8           recv_ciphertext[1408]; + +    size_t       recv_plaintext_size_filled; +    size_t       recv_plaintext_size_processed; +    u8           recv_plaintext[1408]; + +    AesGcmKey    send_aes_key; // recvd with CM_Handshake +    AesGcmIv     send_aes_iv;  // recvd with CM_Handshake +    size_t       send_ciphertext_size_used; +    u8           send_ciphertext[1408]; +} OSNetSecureStream; + + +internal_var u32 s_max_secure_stream_count; +internal_var OSNetSecureStream *s_secure_streams; + +internal_var u32 s_free_id_count; +internal_var u32 *s_free_ids; + + + +OSNetSecureStreamStatus +os_net_secure_stream_get_status(u32 id) +{ +    return s_secure_streams[id].status; +} + +internal_fn b32 +recv_and_decrypt_aes_package(OSNetSecureStream *secure_stream) +{ +    i64 size_to_recv; +    i64 size_recvd; + +    // recv aes header +    if (!secure_stream->recv_aes_header_initialized) { +        void *dest = secure_stream->recv_ciphertext + secure_stream->recv_ciphertext_size_filled; +        size_to_recv = sizeof(AesGcmHeader) - secure_stream->recv_ciphertext_size_filled; +        size_recvd = recv(secure_stream->fd, dest, size_to_recv, 0); +        if (size_recvd < 0) { +            if (errno != EAGAIN) { +                printf("error: recv aes header failed with errno = %d\n", errno); +                secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +            } +            return false; +        } +        else if (size_recvd == 0) { +            secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED; +            return false; +        } +        else if (size_recvd < size_to_recv) { +            secure_stream->recv_ciphertext_size_filled += size_recvd; +            return false; +        } + +        AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext; +        if (aes_header->payload_size > sizeof(secure_stream->recv_ciphertext)) { +            printf("error: aes header has invalid payload size\n"); +            secure_stream->recv_ciphertext_size_filled = 0; +            secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +            return false; +        } + +        secure_stream->recv_aes_header = *(AesGcmHeader*)(secure_stream->recv_ciphertext); +        secure_stream->recv_aes_header_initialized = true; +        secure_stream->recv_ciphertext_size_filled = 0; +    } + + +    // recv aes payload +    size_to_recv = secure_stream->recv_aes_header.payload_size - secure_stream->recv_ciphertext_size_filled; +    size_recvd = recv(secure_stream->fd, secure_stream->recv_ciphertext, size_to_recv, 0); +    if (size_recvd < 0) { +        if (errno != EAGAIN) { +            printf("error: recv aes payload failed with errno = %d\n", errno); +            secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +        } +        return false; +    } +    else if (size_recvd == 0) { +        secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED; +        return false; +    } +    else if (size_recvd < size_to_recv) { +        secure_stream->recv_ciphertext_size_filled += size_recvd; +        return false; +    } + +    secure_stream->recv_ciphertext_size_filled += size_recvd; + + +    // decrypt +    b32 decrypted = aes_gcm_decrypt(&secure_stream->recv_aes_key, +                                    &secure_stream->recv_aes_iv, +                                    secure_stream->recv_plaintext, +                                    secure_stream->recv_ciphertext, +                                    secure_stream->recv_ciphertext_size_filled, +                                    secure_stream->recv_aes_header.tag, +                                    sizeof(secure_stream->recv_aes_header.tag)); +    if (!decrypted) { +        secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +        return false; +    } + +    secure_stream->recv_plaintext_size_filled = secure_stream->recv_ciphertext_size_filled; +    secure_stream->recv_aes_header_initialized = false; +    secure_stream->recv_plaintext_size_processed = 0; +    secure_stream->recv_ciphertext_size_filled = 0; + +    return true; +} + + +i64 +os_net_secure_stream_recv(u32 id, u8 *buff, size_t size) +{ +    OSNetSecureStream *secure_stream = &s_secure_streams[id]; +    if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) { +        return -1; +    } + +    size_t size_delivered = 0; +    while (size_delivered < size) { +        size_t plaintext_size_avail = secure_stream->recv_plaintext_size_filled - secure_stream->recv_plaintext_size_processed; +        if (plaintext_size_avail > 0) { +            u8 *src = secure_stream->recv_plaintext + secure_stream->recv_plaintext_size_processed; +            u8 *dest = buff + size_delivered; +            size_t size_to_copy = MIN(size, plaintext_size_avail); +            memcpy(dest, src, size_to_copy); +            secure_stream->recv_plaintext_size_processed += size_to_copy; +            size_delivered += size_to_copy; +        } +        else { +            b32 recvd_and_decrypted = recv_and_decrypt_aes_package(secure_stream); +            if (!recvd_and_decrypted) { +                if (secure_stream->status == OS_NET_SECURE_STREAM_ERROR) { +                    return -1; +                } +                else { +                    break; +                } +            } +        } +    } + +    return size_delivered; +} + +i64 +os_net_secure_stream_send(u32 id, u8 *buff, size_t size) +{ +    OSNetSecureStream *secure_stream = &s_secure_streams[id]; +    if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) { +        return -1; +    } + + +    size_t size_original = size; + +    while (size) { +        AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext; +        void *aes_payload = secure_stream->send_ciphertext + sizeof(*aes_header); +        i64 aes_payload_size = MIN(size, sizeof(secure_stream->send_ciphertext) - sizeof(*aes_header)); + +        aes_header->payload_size = aes_payload_size; +        b32 encrypted = aes_gcm_encrypt(&secure_stream->send_aes_key, +                                        &secure_stream->send_aes_iv, +                                        aes_payload, buff, aes_payload_size, +                                        aes_header->tag, sizeof(aes_header->tag)); +        if (!encrypted) { +            secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +            return -1; +        } + +        size_t size_to_send = sizeof(*aes_header) + aes_payload_size; +        i64 size_sent = send(secure_stream->fd, secure_stream->send_ciphertext, size_to_send, 0); +        if (size_sent != size_to_send) { +            printf("error: send only sent %ld/%ld bytes\n", size_to_send, size_sent); +            secure_stream->status = OS_NET_SECURE_STREAM_ERROR; +            return -1; +        } + +        buff += aes_payload_size; +        size -= aes_payload_size; +    } + +    i64 result = size_original - size; +    return result; +} + +int +os_net_secure_stream_get_fd(u32 id) +{ +    OSNetSecureStream *secure_stream = &s_secure_streams[id]; +    return secure_stream->fd; +} + + +internal_fn void +os_net_secure_stream_free(u32 id) +{ +    s_free_ids[s_free_id_count] = id; +    s_free_id_count += 1; +} + + +internal_fn u32 +os_net_secure_stream_alloc(void) +{ +    assert(s_free_id_count > 0); +    if (s_free_id_count == 0) { +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    u32 id = s_free_ids[s_free_id_count-1]; +    s_free_id_count -= 1; + +    return id; +} + + +void +os_net_secure_stream_close(u32 id) +{ +    OSNetSecureStream *secure_stream = &s_secure_streams[id]; + +    close(secure_stream->fd); +    memset(secure_stream, 0, sizeof(*secure_stream)); + +    os_net_secure_stream_free(id); +} + +u32 +os_net_secure_stream_listen(u16 port, EVP_PKEY *rsa_pri) +{ +    int fd = socket(AF_INET, SOCK_STREAM, 0); +    if (fd == -1) { +        perror("socket()"); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    int enable_reuse = 1; +    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable_reuse, sizeof(int)) < 0) { +        perror("setsockopt(SO_REUSEADDR)"); +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    struct sockaddr_in local_addr; +    local_addr.sin_family = AF_INET; +    local_addr.sin_port = htons(port); +    local_addr.sin_addr.s_addr = INADDR_ANY; +    if (bind(fd, (struct sockaddr *)&local_addr, sizeof(local_addr)) < 0) { +        perror("bind()"); +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    int backlog = 128; +    if (listen(fd, backlog) == -1) { +        perror("listen()"); +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    u32 id = os_net_secure_stream_alloc(); + +    OSNetSecureStream *secure_stream = &s_secure_streams[id]; +    secure_stream->fd = fd; +    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED; +    secure_stream->rsa_key = rsa_pri; + + +    return id; +} + + +internal_fn b32 +set_socket_nonblocking(int fd) +{ +    int flags; +    if ((flags = fcntl(fd, F_GETFL, 0)) < 0) { +        perror("fcntl(F_GETFL)"); +        return false; +    } +    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) { +        perror("fcntl(F_SETFL)"); +        return false; +    } +    return true; +} + +u32 +os_net_secure_stream_accept(u32 listener_id) +{ +    OSNetSecureStream *listener = &s_secure_streams[listener_id]; + + +    struct sockaddr_in addr; +    socklen_t addr_size = sizeof(addr); + +    int fd = accept(listener->fd, (struct sockaddr*)&addr, &addr_size); +    if (fd == -1) { +        printf("accept() failed\n"); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    u32 secure_stream_id = os_net_secure_stream_alloc(); +    if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) { +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id]; +    secure_stream->fd = fd; +    secure_stream->status = OS_NET_SECURE_STREAM_HANDSHAKING; + + +    // recv rsa request +    u8 encrypted_cm_handshake[512]; // Todo: use secure_stream->recv_buff +    i64 recvd_size = recv(secure_stream->fd, encrypted_cm_handshake, sizeof(encrypted_cm_handshake), MSG_WAITALL); +    if (recvd_size != sizeof(encrypted_cm_handshake)) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    // decrypt rsa request +    CM_Handshake cm_handshake; +    if (!rsa_decrypt(listener->rsa_key, &cm_handshake, encrypted_cm_handshake, sizeof(encrypted_cm_handshake))) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    // process request +    memcpy(&secure_stream->recv_aes_key, &cm_handshake.aes_key_client_encrypt, sizeof(cm_handshake.aes_key_client_encrypt)); +    memcpy(&secure_stream->recv_aes_iv,  &cm_handshake.aes_iv_client_encrypt,  sizeof(cm_handshake.aes_iv_client_encrypt)); +    memcpy(&secure_stream->send_aes_key, &cm_handshake.aes_key_client_decrypt, sizeof(cm_handshake.aes_key_client_decrypt)); +    memcpy(&secure_stream->send_aes_iv,  &cm_handshake.aes_iv_client_decrypt,  sizeof(cm_handshake.aes_iv_client_decrypt)); + + +    // prepare aes response +    AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext; + +    SM_Handshake sm_handshake; +    memcpy(sm_handshake.client_random, cm_handshake.client_random, sizeof(cm_handshake.client_random)); + +    // encrypt aes response +    if (!aes_gcm_encrypt(&secure_stream->send_aes_key, +                         &secure_stream->send_aes_iv, +                         secure_stream->send_ciphertext + sizeof(*aes_header), (u8*)&sm_handshake, sizeof(sm_handshake), +                         aes_header->tag, sizeof(aes_header->tag))) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } +    aes_header->payload_size = sizeof(sm_handshake); +    secure_stream->send_ciphertext_size_used = sizeof(*aes_header) + sizeof(sm_handshake); + +    // send response +    i64 sent_size = send(secure_stream->fd, secure_stream->send_ciphertext, secure_stream->send_ciphertext_size_used, 0); +    if (sent_size != secure_stream->send_ciphertext_size_used) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    if (!set_socket_nonblocking(fd)) { +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED; +    return secure_stream_id; +} + +u32 +os_net_secure_stream_connect(char *address, u16 port, EVP_PKEY *server_rsa_pub) +{ +    int fd = socket(AF_INET, SOCK_STREAM, 0); +    if (fd == -1) { +        printf("cant open socket\n"); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    struct sockaddr_in target_addr; +    memset(&target_addr, 0, sizeof(target_addr)); +    target_addr.sin_family = AF_INET; +    target_addr.sin_port = htons(port); +    target_addr.sin_addr.s_addr = inet_addr(address); +    if (connect(fd, (struct sockaddr*)&target_addr, sizeof(target_addr)) == -1) { +        printf("connect failed\n"); +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + + +    u32 secure_stream_id = os_net_secure_stream_alloc(); +    if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id]; +    secure_stream->fd = fd; +    secure_stream->rsa_key = server_rsa_pub; + + +    if (!aes_gcm_key_init_random(&secure_stream->send_aes_key)) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } +    if (!aes_gcm_iv_init(&secure_stream->send_aes_iv)) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } +    if (!aes_gcm_key_init_random(&secure_stream->recv_aes_key)) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } +    if (!aes_gcm_iv_init(&secure_stream->recv_aes_iv)) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    /* Request */ + +    // prepare rsa request +    CM_Handshake cm_handshake; +    memcpy(&cm_handshake.aes_key_client_encrypt, &secure_stream->send_aes_key, sizeof(secure_stream->send_aes_key)); +    memcpy(&cm_handshake.aes_iv_client_encrypt,  &secure_stream->send_aes_iv,  sizeof(secure_stream->send_aes_iv)); +    memcpy(&cm_handshake.aes_key_client_decrypt, &secure_stream->recv_aes_key, sizeof(secure_stream->recv_aes_key)); +    memcpy(&cm_handshake.aes_iv_client_decrypt,  &secure_stream->recv_aes_iv,  sizeof(secure_stream->recv_aes_iv)); +    if (RAND_bytes(cm_handshake.client_random, sizeof(cm_handshake.client_random)) != 1) { +        printf("RAND_bytes failed at %s:%d", __FILE__, __LINE__); +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    // encrypt rsa request +    void *encrypted_req = secure_stream->send_ciphertext; +    if (!rsa_encrypt(server_rsa_pub, encrypted_req, &cm_handshake, sizeof(cm_handshake))) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    // send rsa request +    i64 sent_size = send(secure_stream->fd, encrypted_req, 512, 0); +    if (sent_size != 512) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    /* Response */ + +    // recv aes response +    AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext; +    void *aes_payload = secure_stream->recv_ciphertext + sizeof(*aes_header); + +    size_t response_size = sizeof(*aes_header) + sizeof(SM_Handshake); + +    i64 recvd_size = recv(secure_stream->fd, secure_stream->recv_ciphertext, response_size, 0); +    if (recvd_size != response_size) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    // decrypt aes response +    SM_Handshake sm_handshake; +    if (!aes_gcm_decrypt(&secure_stream->recv_aes_key, +                         &secure_stream->recv_aes_iv, +                         (u8*)&sm_handshake, aes_payload, sizeof(sm_handshake), +                         aes_header->tag, sizeof(aes_header->tag))) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + +    // verify aes response +    assert(sizeof(cm_handshake.client_random) == sizeof(sm_handshake.client_random)); +    if (memcmp(cm_handshake.client_random, sm_handshake.client_random, sizeof(cm_handshake.client_random)) != 0) { +        close(fd); +        os_net_secure_stream_free(secure_stream_id); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    if (!set_socket_nonblocking(fd)) { +        close(fd); +        return OS_NET_SECURE_STREAM_ID_INVALID; +    } + + +    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED; +    secure_stream->rsa_key = server_rsa_pub; +    return secure_stream_id; +} + +void +os_net_secure_streams_init(Arena *arena, size_t max_count) +{ +    s_max_secure_stream_count = max_count; +    s_secure_streams = arena_push(arena, max_count * sizeof(OSNetSecureStream)); +    memset(s_secure_streams, 0, max_count * sizeof(OSNetSecureStream)); + +    s_free_id_count = max_count; +    s_free_ids = arena_push(arena, max_count * sizeof(u32)); +    for (size_t i = 0; i < max_count; i++) { +        s_free_ids[i] = i; +    } +} +  | 
