aboutsummaryrefslogtreecommitdiff
path: root/src/os/linux/linux_net_secure_stream.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/os/linux/linux_net_secure_stream.c')
-rw-r--r--src/os/linux/linux_net_secure_stream.c579
1 files changed, 579 insertions, 0 deletions
diff --git a/src/os/linux/linux_net_secure_stream.c b/src/os/linux/linux_net_secure_stream.c
new file mode 100644
index 0000000..cfacfc3
--- /dev/null
+++ b/src/os/linux/linux_net_secure_stream.c
@@ -0,0 +1,579 @@
+#include <os/os.h>
+#include <basic/basic.h>
+#include <basic/arena.h>
+#include <crypto/rsa.h>
+#include <crypto/aes_gcm.h>
+
+#include <string.h>
+
+#include <openssl/aes.h>
+#include <openssl/rsa.h>
+#include <openssl/rand.h>
+
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <sys/socket.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <errno.h>
+
+typedef struct {
+ // space left for 512 - 64*2 - 2 = 382 for 4096-bit rsa and SHA-512
+
+ AesGcmKey aes_key_client_encrypt; // 32 bytes
+ AesGcmIv aes_iv_client_encrypt; // 12 bytes
+
+ AesGcmKey aes_key_client_decrypt; // 32 bytes
+ AesGcmIv aes_iv_client_decrypt; // 12 bytes
+
+ // space left for 382 - 2*32 - 2*12 = 294 for 4096-bit rsa and SHA-512
+
+ u8 client_random[256];
+} CM_Handshake;
+
+
+typedef struct {
+ u8 client_random[256];
+} SM_Handshake;
+
+
+typedef struct {
+ int fd;
+ OSNetSecureStreamStatus status;
+
+ EVP_PKEY *rsa_key;
+
+ AesGcmKey recv_aes_key; // recvd with CM_Handshake
+ AesGcmIv recv_aes_iv; // recvd_with CM_Handshake
+ AesGcmHeader recv_aes_header; // recvd with each aes message
+ b32 recv_aes_header_initialized;
+ size_t recv_ciphertext_size_filled;
+ u8 recv_ciphertext[1408];
+
+ size_t recv_plaintext_size_filled;
+ size_t recv_plaintext_size_processed;
+ u8 recv_plaintext[1408];
+
+ AesGcmKey send_aes_key; // recvd with CM_Handshake
+ AesGcmIv send_aes_iv; // recvd with CM_Handshake
+ size_t send_ciphertext_size_used;
+ u8 send_ciphertext[1408];
+} OSNetSecureStream;
+
+
+internal_var u32 s_max_secure_stream_count;
+internal_var OSNetSecureStream *s_secure_streams;
+
+internal_var u32 s_free_id_count;
+internal_var u32 *s_free_ids;
+
+
+
+OSNetSecureStreamStatus
+os_net_secure_stream_get_status(u32 id)
+{
+ return s_secure_streams[id].status;
+}
+
+internal_fn b32
+recv_and_decrypt_aes_package(OSNetSecureStream *secure_stream)
+{
+ i64 size_to_recv;
+ i64 size_recvd;
+
+ // recv aes header
+ if (!secure_stream->recv_aes_header_initialized) {
+ void *dest = secure_stream->recv_ciphertext + secure_stream->recv_ciphertext_size_filled;
+ size_to_recv = sizeof(AesGcmHeader) - secure_stream->recv_ciphertext_size_filled;
+ size_recvd = recv(secure_stream->fd, dest, size_to_recv, 0);
+ if (size_recvd < 0) {
+ if (errno != EAGAIN) {
+ printf("error: recv aes header failed with errno = %d\n", errno);
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ }
+ return false;
+ }
+ else if (size_recvd == 0) {
+ secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED;
+ return false;
+ }
+ else if (size_recvd < size_to_recv) {
+ secure_stream->recv_ciphertext_size_filled += size_recvd;
+ return false;
+ }
+
+ AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext;
+ if (aes_header->payload_size > sizeof(secure_stream->recv_ciphertext)) {
+ printf("error: aes header has invalid payload size\n");
+ secure_stream->recv_ciphertext_size_filled = 0;
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ return false;
+ }
+
+ secure_stream->recv_aes_header = *(AesGcmHeader*)(secure_stream->recv_ciphertext);
+ secure_stream->recv_aes_header_initialized = true;
+ secure_stream->recv_ciphertext_size_filled = 0;
+ }
+
+
+ // recv aes payload
+ size_to_recv = secure_stream->recv_aes_header.payload_size - secure_stream->recv_ciphertext_size_filled;
+ size_recvd = recv(secure_stream->fd, secure_stream->recv_ciphertext, size_to_recv, 0);
+ if (size_recvd < 0) {
+ if (errno != EAGAIN) {
+ printf("error: recv aes payload failed with errno = %d\n", errno);
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ }
+ return false;
+ }
+ else if (size_recvd == 0) {
+ secure_stream->status = OS_NET_SECURE_STREAM_DISCONNECTED;
+ return false;
+ }
+ else if (size_recvd < size_to_recv) {
+ secure_stream->recv_ciphertext_size_filled += size_recvd;
+ return false;
+ }
+
+ secure_stream->recv_ciphertext_size_filled += size_recvd;
+
+
+ // decrypt
+ b32 decrypted = aes_gcm_decrypt(&secure_stream->recv_aes_key,
+ &secure_stream->recv_aes_iv,
+ secure_stream->recv_plaintext,
+ secure_stream->recv_ciphertext,
+ secure_stream->recv_ciphertext_size_filled,
+ secure_stream->recv_aes_header.tag,
+ sizeof(secure_stream->recv_aes_header.tag));
+ if (!decrypted) {
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ return false;
+ }
+
+ secure_stream->recv_plaintext_size_filled = secure_stream->recv_ciphertext_size_filled;
+ secure_stream->recv_aes_header_initialized = false;
+ secure_stream->recv_plaintext_size_processed = 0;
+ secure_stream->recv_ciphertext_size_filled = 0;
+
+ return true;
+}
+
+
+i64
+os_net_secure_stream_recv(u32 id, u8 *buff, size_t size)
+{
+ OSNetSecureStream *secure_stream = &s_secure_streams[id];
+ if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) {
+ return -1;
+ }
+
+ size_t size_delivered = 0;
+ while (size_delivered < size) {
+ size_t plaintext_size_avail = secure_stream->recv_plaintext_size_filled - secure_stream->recv_plaintext_size_processed;
+ if (plaintext_size_avail > 0) {
+ u8 *src = secure_stream->recv_plaintext + secure_stream->recv_plaintext_size_processed;
+ u8 *dest = buff + size_delivered;
+ size_t size_to_copy = MIN(size, plaintext_size_avail);
+ memcpy(dest, src, size_to_copy);
+ secure_stream->recv_plaintext_size_processed += size_to_copy;
+ size_delivered += size_to_copy;
+ }
+ else {
+ b32 recvd_and_decrypted = recv_and_decrypt_aes_package(secure_stream);
+ if (!recvd_and_decrypted) {
+ if (secure_stream->status == OS_NET_SECURE_STREAM_ERROR) {
+ return -1;
+ }
+ else {
+ break;
+ }
+ }
+ }
+ }
+
+ return size_delivered;
+}
+
+i64
+os_net_secure_stream_send(u32 id, u8 *buff, size_t size)
+{
+ OSNetSecureStream *secure_stream = &s_secure_streams[id];
+ if (secure_stream->status != OS_NET_SECURE_STREAM_CONNECTED) {
+ return -1;
+ }
+
+
+ size_t size_original = size;
+
+ while (size) {
+ AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext;
+ void *aes_payload = secure_stream->send_ciphertext + sizeof(*aes_header);
+ i64 aes_payload_size = MIN(size, sizeof(secure_stream->send_ciphertext) - sizeof(*aes_header));
+
+ aes_header->payload_size = aes_payload_size;
+ b32 encrypted = aes_gcm_encrypt(&secure_stream->send_aes_key,
+ &secure_stream->send_aes_iv,
+ aes_payload, buff, aes_payload_size,
+ aes_header->tag, sizeof(aes_header->tag));
+ if (!encrypted) {
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ return -1;
+ }
+
+ size_t size_to_send = sizeof(*aes_header) + aes_payload_size;
+ i64 size_sent = send(secure_stream->fd, secure_stream->send_ciphertext, size_to_send, 0);
+ if (size_sent != size_to_send) {
+ printf("error: send only sent %ld/%ld bytes\n", size_to_send, size_sent);
+ secure_stream->status = OS_NET_SECURE_STREAM_ERROR;
+ return -1;
+ }
+
+ buff += aes_payload_size;
+ size -= aes_payload_size;
+ }
+
+ i64 result = size_original - size;
+ return result;
+}
+
+int
+os_net_secure_stream_get_fd(u32 id)
+{
+ OSNetSecureStream *secure_stream = &s_secure_streams[id];
+ return secure_stream->fd;
+}
+
+
+internal_fn void
+os_net_secure_stream_free(u32 id)
+{
+ s_free_ids[s_free_id_count] = id;
+ s_free_id_count += 1;
+}
+
+
+internal_fn u32
+os_net_secure_stream_alloc(void)
+{
+ assert(s_free_id_count > 0);
+ if (s_free_id_count == 0) {
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ u32 id = s_free_ids[s_free_id_count-1];
+ s_free_id_count -= 1;
+
+ return id;
+}
+
+
+void
+os_net_secure_stream_close(u32 id)
+{
+ OSNetSecureStream *secure_stream = &s_secure_streams[id];
+
+ close(secure_stream->fd);
+ memset(secure_stream, 0, sizeof(*secure_stream));
+
+ os_net_secure_stream_free(id);
+}
+
+u32
+os_net_secure_stream_listen(u16 port, EVP_PKEY *rsa_pri)
+{
+ int fd = socket(AF_INET, SOCK_STREAM, 0);
+ if (fd == -1) {
+ perror("socket()");
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ int enable_reuse = 1;
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable_reuse, sizeof(int)) < 0) {
+ perror("setsockopt(SO_REUSEADDR)");
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ struct sockaddr_in local_addr;
+ local_addr.sin_family = AF_INET;
+ local_addr.sin_port = htons(port);
+ local_addr.sin_addr.s_addr = INADDR_ANY;
+ if (bind(fd, (struct sockaddr *)&local_addr, sizeof(local_addr)) < 0) {
+ perror("bind()");
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ int backlog = 128;
+ if (listen(fd, backlog) == -1) {
+ perror("listen()");
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ u32 id = os_net_secure_stream_alloc();
+
+ OSNetSecureStream *secure_stream = &s_secure_streams[id];
+ secure_stream->fd = fd;
+ secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
+ secure_stream->rsa_key = rsa_pri;
+
+
+ return id;
+}
+
+
+internal_fn b32
+set_socket_nonblocking(int fd)
+{
+ int flags;
+ if ((flags = fcntl(fd, F_GETFL, 0)) < 0) {
+ perror("fcntl(F_GETFL)");
+ return false;
+ }
+ if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) {
+ perror("fcntl(F_SETFL)");
+ return false;
+ }
+ return true;
+}
+
+u32
+os_net_secure_stream_accept(u32 listener_id)
+{
+ OSNetSecureStream *listener = &s_secure_streams[listener_id];
+
+
+ struct sockaddr_in addr;
+ socklen_t addr_size = sizeof(addr);
+
+ int fd = accept(listener->fd, (struct sockaddr*)&addr, &addr_size);
+ if (fd == -1) {
+ printf("accept() failed\n");
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ u32 secure_stream_id = os_net_secure_stream_alloc();
+ if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) {
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id];
+ secure_stream->fd = fd;
+ secure_stream->status = OS_NET_SECURE_STREAM_HANDSHAKING;
+
+
+ // recv rsa request
+ u8 encrypted_cm_handshake[512]; // Todo: use secure_stream->recv_buff
+ i64 recvd_size = recv(secure_stream->fd, encrypted_cm_handshake, sizeof(encrypted_cm_handshake), MSG_WAITALL);
+ if (recvd_size != sizeof(encrypted_cm_handshake)) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ // decrypt rsa request
+ CM_Handshake cm_handshake;
+ if (!rsa_decrypt(listener->rsa_key, &cm_handshake, encrypted_cm_handshake, sizeof(encrypted_cm_handshake))) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ // process request
+ memcpy(&secure_stream->recv_aes_key, &cm_handshake.aes_key_client_encrypt, sizeof(cm_handshake.aes_key_client_encrypt));
+ memcpy(&secure_stream->recv_aes_iv, &cm_handshake.aes_iv_client_encrypt, sizeof(cm_handshake.aes_iv_client_encrypt));
+ memcpy(&secure_stream->send_aes_key, &cm_handshake.aes_key_client_decrypt, sizeof(cm_handshake.aes_key_client_decrypt));
+ memcpy(&secure_stream->send_aes_iv, &cm_handshake.aes_iv_client_decrypt, sizeof(cm_handshake.aes_iv_client_decrypt));
+
+
+ // prepare aes response
+ AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->send_ciphertext;
+
+ SM_Handshake sm_handshake;
+ memcpy(sm_handshake.client_random, cm_handshake.client_random, sizeof(cm_handshake.client_random));
+
+ // encrypt aes response
+ if (!aes_gcm_encrypt(&secure_stream->send_aes_key,
+ &secure_stream->send_aes_iv,
+ secure_stream->send_ciphertext + sizeof(*aes_header), (u8*)&sm_handshake, sizeof(sm_handshake),
+ aes_header->tag, sizeof(aes_header->tag))) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+ aes_header->payload_size = sizeof(sm_handshake);
+ secure_stream->send_ciphertext_size_used = sizeof(*aes_header) + sizeof(sm_handshake);
+
+ // send response
+ i64 sent_size = send(secure_stream->fd, secure_stream->send_ciphertext, secure_stream->send_ciphertext_size_used, 0);
+ if (sent_size != secure_stream->send_ciphertext_size_used) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ if (!set_socket_nonblocking(fd)) {
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
+ return secure_stream_id;
+}
+
+u32
+os_net_secure_stream_connect(char *address, u16 port, EVP_PKEY *server_rsa_pub)
+{
+ int fd = socket(AF_INET, SOCK_STREAM, 0);
+ if (fd == -1) {
+ printf("cant open socket\n");
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ struct sockaddr_in target_addr;
+ memset(&target_addr, 0, sizeof(target_addr));
+ target_addr.sin_family = AF_INET;
+ target_addr.sin_port = htons(port);
+ target_addr.sin_addr.s_addr = inet_addr(address);
+ if (connect(fd, (struct sockaddr*)&target_addr, sizeof(target_addr)) == -1) {
+ printf("connect failed\n");
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+
+ u32 secure_stream_id = os_net_secure_stream_alloc();
+ if (secure_stream_id == OS_NET_SECURE_STREAM_ID_INVALID) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ OSNetSecureStream *secure_stream = &s_secure_streams[secure_stream_id];
+ secure_stream->fd = fd;
+ secure_stream->rsa_key = server_rsa_pub;
+
+
+ if (!aes_gcm_key_init_random(&secure_stream->send_aes_key)) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+ if (!aes_gcm_iv_init(&secure_stream->send_aes_iv)) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+ if (!aes_gcm_key_init_random(&secure_stream->recv_aes_key)) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+ if (!aes_gcm_iv_init(&secure_stream->recv_aes_iv)) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ /* Request */
+
+ // prepare rsa request
+ CM_Handshake cm_handshake;
+ memcpy(&cm_handshake.aes_key_client_encrypt, &secure_stream->send_aes_key, sizeof(secure_stream->send_aes_key));
+ memcpy(&cm_handshake.aes_iv_client_encrypt, &secure_stream->send_aes_iv, sizeof(secure_stream->send_aes_iv));
+ memcpy(&cm_handshake.aes_key_client_decrypt, &secure_stream->recv_aes_key, sizeof(secure_stream->recv_aes_key));
+ memcpy(&cm_handshake.aes_iv_client_decrypt, &secure_stream->recv_aes_iv, sizeof(secure_stream->recv_aes_iv));
+ if (RAND_bytes(cm_handshake.client_random, sizeof(cm_handshake.client_random)) != 1) {
+ printf("RAND_bytes failed at %s:%d", __FILE__, __LINE__);
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ // encrypt rsa request
+ void *encrypted_req = secure_stream->send_ciphertext;
+ if (!rsa_encrypt(server_rsa_pub, encrypted_req, &cm_handshake, sizeof(cm_handshake))) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ // send rsa request
+ i64 sent_size = send(secure_stream->fd, encrypted_req, 512, 0);
+ if (sent_size != 512) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ /* Response */
+
+ // recv aes response
+ AesGcmHeader *aes_header = (AesGcmHeader*)secure_stream->recv_ciphertext;
+ void *aes_payload = secure_stream->recv_ciphertext + sizeof(*aes_header);
+
+ size_t response_size = sizeof(*aes_header) + sizeof(SM_Handshake);
+
+ i64 recvd_size = recv(secure_stream->fd, secure_stream->recv_ciphertext, response_size, 0);
+ if (recvd_size != response_size) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ // decrypt aes response
+ SM_Handshake sm_handshake;
+ if (!aes_gcm_decrypt(&secure_stream->recv_aes_key,
+ &secure_stream->recv_aes_iv,
+ (u8*)&sm_handshake, aes_payload, sizeof(sm_handshake),
+ aes_header->tag, sizeof(aes_header->tag))) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+ // verify aes response
+ assert(sizeof(cm_handshake.client_random) == sizeof(sm_handshake.client_random));
+ if (memcmp(cm_handshake.client_random, sm_handshake.client_random, sizeof(cm_handshake.client_random)) != 0) {
+ close(fd);
+ os_net_secure_stream_free(secure_stream_id);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ if (!set_socket_nonblocking(fd)) {
+ close(fd);
+ return OS_NET_SECURE_STREAM_ID_INVALID;
+ }
+
+
+ secure_stream->status = OS_NET_SECURE_STREAM_CONNECTED;
+ secure_stream->rsa_key = server_rsa_pub;
+ return secure_stream_id;
+}
+
+void
+os_net_secure_streams_init(Arena *arena, size_t max_count)
+{
+ s_max_secure_stream_count = max_count;
+ s_secure_streams = arena_push(arena, max_count * sizeof(OSNetSecureStream));
+ memset(s_secure_streams, 0, max_count * sizeof(OSNetSecureStream));
+
+ s_free_id_count = max_count;
+ s_free_ids = arena_push(arena, max_count * sizeof(u32));
+ for (size_t i = 0; i < max_count; i++) {
+ s_free_ids[i] = i;
+ }
+}
+