#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
// Project headers (not shown) provide u8/u16/u32/i64/b32, internal_fn/internal_var, MIN,
// Arena/arena_push, the AesGcm* types and aes_gcm_* helpers, rsa_encrypt/rsa_decrypt,
// and the OSNetSecureStreamStatus / OS_NET_SECURE_STREAM_* declarations.

// Client -> server handshake, RSA-encrypted with the server's public key.
typedef struct {
    // RSA-OAEP plaintext limit for 4096-bit RSA with SHA-512: 512 - 2*64 - 2 = 382 bytes
    AesGcmKey aes_key_client_encrypt; // 32 bytes
    AesGcmIv  aes_iv_client_encrypt;  // 12 bytes
    AesGcmKey aes_key_client_decrypt; // 32 bytes
    AesGcmIv  aes_iv_client_decrypt;  // 12 bytes
    // space left: 382 - 2*32 - 2*12 = 294 bytes
    u8 client_random[256];
} CM_Handshake;

_Static_assert(sizeof(CM_Handshake) <= 382, "CM_Handshake must fit in one RSA-4096 OAEP/SHA-512 block");

// Server -> client handshake reply, AES-GCM encrypted. Echoing client_random proves
// that the server could decrypt CM_Handshake, i.e. that it holds the private key.
typedef struct {
    u8 client_random[256];
} SM_Handshake;

typedef struct {
    int fd;
    OSNetSecureStreamStatus status;
    EVP_PKEY *rsa_key; // listener: server private key; client: server public key

    // Note: AES-GCM must never reuse an IV with the same key. The same recv/send IVs
    // are passed for every message, which is only safe if the aes_gcm_* helpers
    // advance them internally.
    AesGcmKey recv_aes_key;          // exchanged via CM_Handshake
    AesGcmIv recv_aes_iv;            // exchanged via CM_Handshake
    AesGcmHeader recv_aes_header;    // received in front of each AES message
    b32 recv_aes_header_initialized;
    size_t recv_ciphertext_size_filled;
    u8 recv_ciphertext[1408];
    size_t recv_plaintext_size_filled;
    size_t recv_plaintext_size_processed;
    u8 recv_plaintext[1408];

    AesGcmKey send_aes_key;          // exchanged via CM_Handshake
    AesGcmIv send_aes_iv;            // exchanged via CM_Handshake
    size_t send_ciphertext_size_used;
    u8 send_ciphertext[1408];
} OSNetSecureStream;

internal_var u32 s_max_secure_stream_count;
internal_var OSNetSecureStream *s_secure_streams;
internal_var u32 s_free_id_count;
internal_var u32 *s_free_ids;

OSNetSecureStreamStatus os_net_secure_stream_get_status(u32 id)
{
    return s_secure_streams[id].status;
}

// Receives one AES-GCM message (header + payload) and decrypts it into recv_plaintext.
// Returns true once a complete message has been decrypted. Returns false if more data
// is needed (the socket would block) or if the stream failed; in the latter case status
// is set to OS_NET_SECURE_STREAM_ERROR or OS_NET_SECURE_STREAM_DISCONNECTED.
internal_fn b32 recv_and_decrypt_aes_package(OSNetSecureStream *secure_stream)
{
    i64 size_to_recv;
    i64 size_recvd;

    // recv aes header (it may arrive in several pieces)
    if (!secure_stream->recv_aes_header_initialized) {
        void *dest = secure_stream->recv_ciphertext + secure_stream->recv_ciphertext_size_filled;
        size_to_recv = sizeof(AesGcmHeader) - secure_stream->recv_ciphertext_size_filled;
        size_recvd = recv(secure_stream->fd, dest, size_to_recv, 0);
        if (size_recvd < 0) {
            if (errno != EAGAIN && errno != EWOULDBLOCK) {
                printf("error: recv aes header failed with errno = %d\n", errno);
                secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
            }
            return false;
        } else if (size_recvd == 0) {
            secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED;
            return false;
        } else if (size_recvd < size_to_recv) {
            secure_stream->recv_ciphertext_size_filled += size_recvd;
            return false;
        }

        AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext;
        if (aes_header->payload_size == 0 ||
            aes_header->payload_size > sizeof(secure_stream->recv_ciphertext)) {
            printf("error: aes header has invalid payload size\n");
            secure_stream->recv_ciphertext_size_filled = 0;
            secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
            return false;
        }

        secure_stream->recv_aes_header = *aes_header;
        secure_stream->recv_aes_header_initialized = true;
        secure_stream->recv_ciphertext_size_filled = 0;
    }

    // recv aes payload, appending to whatever part of it arrived in earlier calls
    void *payload_dest = secure_stream->recv_ciphertext + secure_stream->recv_ciphertext_size_filled;
    size_to_recv = secure_stream->recv_aes_header.payload_size - secure_stream->recv_ciphertext_size_filled;
    size_recvd = recv(secure_stream->fd, payload_dest, size_to_recv, 0);
    if (size_recvd < 0) {
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            printf("error: recv aes payload failed with errno = %d\n", errno);
            secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
        }
        return false;
    } else if (size_recvd == 0) {
        secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED;
        return false;
    } else if (size_recvd < size_to_recv) {
        secure_stream->recv_ciphertext_size_filled += size_recvd;
        return false;
    }
    secure_stream->recv_ciphertext_size_filled += size_recvd;

    // decrypt
    b32 decrypted = aes_gcm_decrypt(&secure_stream->recv_aes_key, &secure_stream->recv_aes_iv,
                                    secure_stream->recv_plaintext,
                                    secure_stream->recv_ciphertext, secure_stream->recv_ciphertext_size_filled,
                                    secure_stream->recv_aes_header.tag, sizeof(secure_stream->recv_aes_header.tag));
    if (!decrypted) {
        secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
        return false;
    }

    secure_stream->recv_plaintext_size_filled = secure_stream->recv_ciphertext_size_filled;
    secure_stream->recv_aes_header_initialized = false;
    secure_stream->recv_plaintext_size_processed = 0;
    secure_stream->recv_ciphertext_size_filled = 0;
    return true;
}

// Copies up to `size` bytes of decrypted data into buff. Returns the number of bytes
// delivered, which may be 0 when no complete message is available yet or the peer
// disconnected (check os_net_secure_stream_get_status), or -1 if the stream is not
// connected or failed.
i64 os_net_secure_stream_recv(u32 id, u8 *buff, size_t size)
{
    OSNetSecureStream *secure_stream = &s_secure_streams[id];
    if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) {
        return -1;
    }

    size_t size_delivered = 0;
    while (size_delivered < size) {
        size_t plaintext_size_avail = secure_stream->recv_plaintext_size_filled - secure_stream->recv_plaintext_size_processed;
        if (plaintext_size_avail > 0) {
            u8 *src = secure_stream->recv_plaintext + secure_stream->recv_plaintext_size_processed;
            u8 *dest = buff + size_delivered;
            // never copy more than the space left in buff
            size_t size_to_copy = MIN(size - size_delivered, plaintext_size_avail);
            memcpy(dest, src, size_to_copy);
            secure_stream->recv_plaintext_size_processed += size_to_copy;
            size_delivered += size_to_copy;
        } else {
            b32 recvd_and_decrypted = recv_and_decrypt_aes_package(secure_stream);
            if (!recvd_and_decrypted) {
                if (secure_stream->status == OS_NET_SECURE_STREAM_ERROR) {
                    return -1;
                }
                // would block or peer disconnected: return what was delivered so far
                break;
            }
        }
    }

    return size_delivered;
}

// Encrypts buff in chunks that fit into send_ciphertext and sends each chunk as one
// AES-GCM message (header + payload). Returns size on success, -1 on failure.
i64 os_net_secure_stream_send(u32 id, u8 *buff, size_t size)
{
    OSNetSecureStream *secure_stream = &s_secure_streams[id];
    if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) {
        return -1;
    }

    size_t size_original = size;
    while (size) {
        AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext;
        void *aes_payload = secure_stream->send_ciphertext + sizeof(*aes_header);
        size_t aes_payload_size = MIN(size, sizeof(secure_stream->send_ciphertext) - sizeof(*aes_header));
        aes_header->payload_size = aes_payload_size;

        b32 encrypted = aes_gcm_encrypt(&secure_stream->send_aes_key, &secure_stream->send_aes_iv,
                                        aes_payload, buff, aes_payload_size,
                                        aes_header->tag, sizeof(aes_header->tag));
        if (!encrypted) {
            secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
            return -1;
        }

        size_t size_to_send = sizeof(*aes_header) + aes_payload_size;
        i64 size_sent = send(secure_stream->fd, secure_stream->send_ciphertext, size_to_send, 0);
        if (size_sent != (i64)size_to_send) {
            printf("error: send only sent %lld/%zu bytes\n", (long long)size_sent, size_to_send);
            secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
            return -1;
        }

        buff += aes_payload_size;
        size -= aes_payload_size;
    }

    i64 result = size_original - size;
    return result;
}

int os_net_secure_stream_get_fd(u32 id)
{
    OSNetSecureStream *secure_stream = &s_secure_streams[id];
    return secure_stream->fd;
}

internal_fn void os_net_secure_stream_free(u32 id)
{
    s_free_ids[s_free_id_count] = id;
    s_free_id_count += 1;
}

internal_fn u32 os_net_secure_stream_alloc(void)
{
    // running out of streams is an expected condition (e.g. too many clients),
    // so it is reported to the caller instead of asserting
    if (s_free_id_count == 0) {
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    u32 id = s_free_ids[s_free_id_count-1];
    s_free_id_count -= 1;
    return id;
}

void os_net_secure_stream_close(u32 id)
{
    OSNetSecureStream *secure_stream = &s_secure_streams[id];
    close(secure_stream->fd);
    memset(secure_stream, 0, sizeof(*secure_stream)); // also wipes the key material
    os_net_secure_stream_free(id);
}

u32 os_net_secure_stream_listen(u16 port, EVP_PKEY *rsa_pri)
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        perror("socket()");
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    int enable_reuse = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable_reuse, sizeof(enable_reuse)) < 0) {
        perror("setsockopt(SO_REUSEADDR)");
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    struct sockaddr_in local_addr;
    memset(&local_addr, 0, sizeof(local_addr));
    local_addr.sin_family = AF_INET;
    local_addr.sin_port = htons(port);
    local_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(fd, (struct sockaddr *)&local_addr, sizeof(local_addr)) < 0) {
        perror("bind()");
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    int backlog = 128;
    if (listen(fd, backlog) == -1) {
        perror("listen()");
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    u32 id = os_net_secure_stream_alloc();
    if (id == OS_NET_SECURE_STREAM_ID_INVALID) {
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    OSNetSecureStream *secure_stream = &s_secure_streams[id];
    secure_stream->fd = fd;
    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
    secure_stream->rsa_key = rsa_pri;
    return id;
}

internal_fn b32 set_socket_nonblocking(int fd)
{
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags < 0) {
        perror("fcntl(F_GETFL)");
        return false;
    }
    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) {
        perror("fcntl(F_SETFL)");
        return false;
    }
    return true;
}

// Accepts one connection on the listener and runs the server side of the handshake
// (blocking). The new stream is switched to nonblocking mode once the handshake succeeds.
u32 os_net_secure_stream_accept(u32 listener_id)
{
    OSNetSecureStream *listener = &s_secure_streams[listener_id];

    struct sockaddr_in addr;
    socklen_t addr_size = sizeof(addr);
    int fd = accept(listener->fd, (struct sockaddr*)&addr, &addr_size);
    if (fd == -1) {
        perror("accept()");
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    u32 secure_stream_id = os_net_secure_stream_alloc();
    if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) {
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id];
    secure_stream->fd = fd;
    secure_stream->status = OS_NET_SECURE_STREAM_HANDSHAKING;

    // recv rsa request (one RSA-4096 ciphertext, always 512 bytes)
    u8 encrypted_cm_handshake[512]; // Todo: use secure_stream->recv_ciphertext instead of a stack buffer
    i64 recvd_size = recv(secure_stream->fd, encrypted_cm_handshake, sizeof(encrypted_cm_handshake), MSG_WAITALL);
    if (recvd_size != (i64)sizeof(encrypted_cm_handshake)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // decrypt rsa request
    CM_Handshake cm_handshake;
    if (!rsa_decrypt(listener->rsa_key, &cm_handshake, encrypted_cm_handshake, sizeof(encrypted_cm_handshake))) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // process request: the client's encrypt key/iv are our recv key/iv, and vice versa
    memcpy(&secure_stream->recv_aes_key, &cm_handshake.aes_key_client_encrypt, sizeof(cm_handshake.aes_key_client_encrypt));
    memcpy(&secure_stream->recv_aes_iv, &cm_handshake.aes_iv_client_encrypt, sizeof(cm_handshake.aes_iv_client_encrypt));
    memcpy(&secure_stream->send_aes_key, &cm_handshake.aes_key_client_decrypt, sizeof(cm_handshake.aes_key_client_decrypt));
    memcpy(&secure_stream->send_aes_iv, &cm_handshake.aes_iv_client_decrypt, sizeof(cm_handshake.aes_iv_client_decrypt));

    // prepare aes response: echo client_random back
    AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext;
    SM_Handshake sm_handshake;
    memcpy(sm_handshake.client_random, cm_handshake.client_random, sizeof(cm_handshake.client_random));

    // encrypt aes response
    if (!aes_gcm_encrypt(&secure_stream->send_aes_key, &secure_stream->send_aes_iv,
                         secure_stream->send_ciphertext + sizeof(*aes_header),
                         (u8*)&sm_handshake, sizeof(sm_handshake),
                         aes_header->tag, sizeof(aes_header->tag))) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    aes_header->payload_size = sizeof(sm_handshake);
    secure_stream->send_ciphertext_size_used = sizeof(*aes_header) + sizeof(sm_handshake);

    // send response
    i64 sent_size = send(secure_stream->fd, secure_stream->send_ciphertext, secure_stream->send_ciphertext_size_used, 0);
    if (sent_size != (i64)secure_stream->send_ciphertext_size_used) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    if (!set_socket_nonblocking(fd)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
    return secure_stream_id;
}

// Connects to address:port (dotted-decimal IPv4) and runs the client side of the
// handshake (blocking). server_rsa_pub must be the public half of the listener's key.
u32 os_net_secure_stream_connect(char *address, u16 port, EVP_PKEY *server_rsa_pub)
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        perror("socket()");
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    struct sockaddr_in target_addr;
    memset(&target_addr, 0, sizeof(target_addr));
    target_addr.sin_family = AF_INET;
    target_addr.sin_port = htons(port);
    if (inet_pton(AF_INET, address, &target_addr.sin_addr) != 1) {
        printf("error: invalid IPv4 address '%s'\n", address);
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    if (connect(fd, (struct sockaddr*)&target_addr, sizeof(target_addr)) == -1) {
        perror("connect()");
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    u32 secure_stream_id = os_net_secure_stream_alloc();
    if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) {
        // no id was handed out, so only the socket needs cleaning up
        close(fd);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id];
    secure_stream->fd = fd;
    secure_stream->status = OS_NET_SECURE_STREAM_HANDSHAKING;
    secure_stream->rsa_key = server_rsa_pub;

    if (!aes_gcm_key_init_random(&secure_stream->send_aes_key)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    if (!aes_gcm_iv_init(&secure_stream->send_aes_iv)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    if (!aes_gcm_key_init_random(&secure_stream->recv_aes_key)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }
    if (!aes_gcm_iv_init(&secure_stream->recv_aes_iv)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    /* Request */

    // prepare rsa request: our send key/iv become the server's recv key/iv, and vice versa
    CM_Handshake cm_handshake;
    memcpy(&cm_handshake.aes_key_client_encrypt, &secure_stream->send_aes_key, sizeof(secure_stream->send_aes_key));
    memcpy(&cm_handshake.aes_iv_client_encrypt, &secure_stream->send_aes_iv, sizeof(secure_stream->send_aes_iv));
    memcpy(&cm_handshake.aes_key_client_decrypt, &secure_stream->recv_aes_key, sizeof(secure_stream->recv_aes_key));
    memcpy(&cm_handshake.aes_iv_client_decrypt, &secure_stream->recv_aes_iv, sizeof(secure_stream->recv_aes_iv));
    if (RAND_bytes(cm_handshake.client_random, sizeof(cm_handshake.client_random)) != 1) {
        printf("error: RAND_bytes failed at %s:%d\n", __FILE__, __LINE__);
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // encrypt rsa request (an RSA-4096 ciphertext is always 512 bytes)
    void *encrypted_req = secure_stream->send_ciphertext;
    if (!rsa_encrypt(server_rsa_pub, encrypted_req, &cm_handshake, sizeof(cm_handshake))) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // send rsa request
    i64 sent_size = send(secure_stream->fd, encrypted_req, 512, 0);
    if (sent_size != 512) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    /* Response */

    // recv aes response; the socket is still blocking, so MSG_WAITALL waits for all of it
    AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext;
    void *aes_payload = secure_stream->recv_ciphertext + sizeof(*aes_header);
    size_t response_size = sizeof(*aes_header) + sizeof(SM_Handshake);
    i64 recvd_size = recv(secure_stream->fd, secure_stream->recv_ciphertext, response_size, MSG_WAITALL);
    if (recvd_size != (i64)response_size) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // decrypt aes response
    SM_Handshake sm_handshake;
    if (!aes_gcm_decrypt(&secure_stream->recv_aes_key, &secure_stream->recv_aes_iv,
                         (u8*)&sm_handshake, aes_payload, sizeof(sm_handshake),
                         aes_header->tag, sizeof(aes_header->tag))) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    // verify aes response: only the holder of the server's private key can echo client_random
    assert(sizeof(cm_handshake.client_random) == sizeof(sm_handshake.client_random));
    if (memcmp(cm_handshake.client_random, sm_handshake.client_random, sizeof(cm_handshake.client_random)) != 0) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    if (!set_socket_nonblocking(fd)) {
        close(fd);
        os_net_secure_stream_free(secure_stream_id);
        return OS_NET_SECURE_STREAM_ID_INVALID;
    }

    secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
    return secure_stream_id;
}

void os_net_secure_streams_init(Arena *arena, size_t max_count)
{
    s_max_secure_stream_count = (u32)max_count;
    s_secure_streams = arena_push(arena, max_count * sizeof(OSNetSecureStream));
    memset(s_secure_streams, 0, max_count * sizeof(OSNetSecureStream));

    // every id starts out free; alloc hands them out from the end of the list
    s_free_id_count = (u32)max_count;
    s_free_ids = arena_push(arena, max_count * sizeof(u32));
    for (size_t i = 0; i < max_count; i++) {
        s_free_ids[i] = (u32)i;
    }
}
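
/*
 * Usage sketch, illustrative only. It assumes the caller already has an Arena and the
 * server's RSA-4096 key pair as EVP_PKEY (for example loaded with OpenSSL's
 * PEM_read_PrivateKey / PEM_read_PUBKEY). Port 4000, the pool size 64 and the
 * buffer size are arbitrary example values. Server and client must run in separate
 * processes (or threads): both handshakes block until the other side answers.
 *
 * Server side:
 *
 *   os_net_secure_streams_init(arena, 64);
 *   u32 listener = os_net_secure_stream_listen(4000, server_rsa_pri);
 *   u32 client = os_net_secure_stream_accept(listener); // blocks, then nonblocking I/O
 *
 *   u8 buff[1024];
 *   i64 n = os_net_secure_stream_recv(client, buff, sizeof(buff));
 *   if (n > 0) {
 *       // handle n bytes of plaintext in buff
 *   }
 *   // 0 bytes means "nothing yet" or "peer disconnected"; the status tells them apart
 *   if (n < 0 || os_net_secure_stream_get_status(client) != OS_NET_SECURE_STREAM_CONNECTED) {
 *       os_net_secure_stream_close(client);
 *   }
 *
 * Client side:
 *
 *   os_net_secure_streams_init(arena, 64);
 *   u32 stream = os_net_secure_stream_connect("127.0.0.1", 4000, server_rsa_pub);
 *   if (stream != OS_NET_SECURE_STREAM_ID_INVALID) {
 *       os_net_secure_stream_send(stream, (u8 *)"hello", 5);
 *       os_net_secure_stream_close(stream);
 *   }
 */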