/* Asynchronous TCP library for Espressif MCUs Copyright (c) 2016 Hristo Gochkov. All rights reserved. This file is part of the esp8266 core for Arduino environment. This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA */ /* * Compatibility for AxTLS with LWIP raw tcp mode (http://lwip.wikia.com/wiki/Raw/TCP) * Original Code and Inspiration: Slavey Karadzhov */ // To handle all the definitions needed for debug printing, we need to delay // macro definitions till later. #define DEBUG_SKIP__DEBUG_PRINT_MACROS 1 #include #undef DEBUG_SKIP__DEBUG_PRINT_MACROS #if ASYNC_TCP_SSL_ENABLED #include "lwip/opt.h" #include "lwip/tcp.h" #include "lwip/inet.h" #include #include #include #include #include // ets_uart_printf is defined in esp8266_undocumented.h, in newer Arduino ESP8266 Core. extern int ets_uart_printf(const char *format, ...) __attribute__ ((format (printf, 1, 2))); #include #ifndef TCP_SSL_DEBUG #define TCP_SSL_DEBUG(...) do { (void)0;} while(false) #endif uint8_t * default_private_key = NULL; uint16_t default_private_key_len = 0; uint8_t * default_certificate = NULL; uint16_t default_certificate_len = 0; static uint8_t _tcp_ssl_has_client = 0; SSL_CTX * tcp_ssl_new_server_ctx(const char *cert, const char *private_key_file, const char *password){ uint32_t options = SSL_CONNECT_IN_PARTS; SSL_CTX *ssl_ctx; if(private_key_file){ options |= SSL_NO_DEFAULT_KEY; } if ((ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_SVR_SESS)) == NULL){ TCP_SSL_DEBUG("tcp_ssl_new_server_ctx: failed to allocate context\n"); return NULL; } if (private_key_file){ int obj_type = SSL_OBJ_RSA_KEY; if (strstr(private_key_file, ".p8")) obj_type = SSL_OBJ_PKCS8; else if (strstr(private_key_file, ".p12")) obj_type = SSL_OBJ_PKCS12; if (ssl_obj_load(ssl_ctx, obj_type, private_key_file, password)){ TCP_SSL_DEBUG("tcp_ssl_new_server_ctx: load private key '%s' failed\n", private_key_file); return NULL; } } if (cert){ if (ssl_obj_load(ssl_ctx, SSL_OBJ_X509_CERT, cert, NULL)){ TCP_SSL_DEBUG("tcp_ssl_new_server_ctx: load certificate '%s' failed\n", cert); return NULL; } } return ssl_ctx; } struct tcp_ssl_pcb { struct tcp_pcb *tcp; int fd; SSL_CTX* ssl_ctx; SSL *ssl; uint8_t type; int handshake; void * arg; tcp_ssl_data_cb_t on_data; tcp_ssl_handshake_cb_t on_handshake; tcp_ssl_error_cb_t on_error; int last_wr; struct pbuf *tcp_pbuf; int pbuf_offset; struct tcp_ssl_pcb * next; }; typedef struct tcp_ssl_pcb tcp_ssl_t; static tcp_ssl_t * tcp_ssl_array = NULL; static int tcp_ssl_next_fd = 0; uint8_t tcp_ssl_has_client(){ return _tcp_ssl_has_client; } tcp_ssl_t * tcp_ssl_new(struct tcp_pcb *tcp) { if(tcp_ssl_next_fd < 0){ tcp_ssl_next_fd = 0;//overflow } tcp_ssl_t * new_item = (tcp_ssl_t*)malloc(sizeof(tcp_ssl_t)); if(!new_item){ TCP_SSL_DEBUG("tcp_ssl_new: failed to allocate tcp_ssl\n"); return NULL; } new_item->tcp = tcp; new_item->handshake = SSL_NOT_OK; new_item->arg = NULL; new_item->on_data = NULL; new_item->on_handshake = NULL; new_item->on_error = NULL; new_item->tcp_pbuf = NULL; new_item->pbuf_offset = 0; new_item->next = NULL; new_item->ssl_ctx = NULL; new_item->ssl = NULL; new_item->type = TCP_SSL_TYPE_CLIENT; new_item->fd = tcp_ssl_next_fd++; if(tcp_ssl_array == NULL){ tcp_ssl_array = new_item; } else { tcp_ssl_t * item = tcp_ssl_array; while(item->next != NULL) item = item->next; item->next = new_item; } TCP_SSL_DEBUG("tcp_ssl_new: %d\n", new_item->fd); return new_item; } tcp_ssl_t* tcp_ssl_get(struct tcp_pcb *tcp) { if(tcp == NULL) { return NULL; } tcp_ssl_t * item = tcp_ssl_array; while(item && item->tcp != tcp){ item = item->next; } return item; } int tcp_ssl_new_client(struct tcp_pcb *tcp){ SSL_CTX* ssl_ctx; tcp_ssl_t * tcp_ssl; if(tcp == NULL) { return -1; } if(tcp_ssl_get(tcp) != NULL){ TCP_SSL_DEBUG("tcp_ssl_new_client: tcp_ssl already exists\n"); return -1; } ssl_ctx = ssl_ctx_new(SSL_CONNECT_IN_PARTS | SSL_SERVER_VERIFY_LATER, 1); if(ssl_ctx == NULL){ TCP_SSL_DEBUG("tcp_ssl_new_client: failed to allocate ssl context\n"); return -1; } tcp_ssl = tcp_ssl_new(tcp); if(tcp_ssl == NULL){ ssl_ctx_free(ssl_ctx); return -1; } tcp_ssl->ssl_ctx = ssl_ctx; tcp_ssl->ssl = ssl_client_new(ssl_ctx, tcp_ssl->fd, NULL, 0, NULL); if(tcp_ssl->ssl == NULL){ TCP_SSL_DEBUG("tcp_ssl_new_client: failed to allocate ssl\n"); tcp_ssl_free(tcp); return -1; } return tcp_ssl->fd; } int tcp_ssl_new_server(struct tcp_pcb *tcp, SSL_CTX* ssl_ctx){ tcp_ssl_t * tcp_ssl; if(tcp == NULL) { return -1; } if(ssl_ctx == NULL){ return -1; } if(tcp_ssl_get(tcp) != NULL){ TCP_SSL_DEBUG("tcp_ssl_new_server: tcp_ssl already exists\n"); return -1; } tcp_ssl = tcp_ssl_new(tcp); if(tcp_ssl == NULL){ return -1; } tcp_ssl->type = TCP_SSL_TYPE_SERVER; tcp_ssl->ssl_ctx = ssl_ctx; _tcp_ssl_has_client = 1; tcp_ssl->ssl = ssl_server_new(ssl_ctx, tcp_ssl->fd); if(tcp_ssl->ssl == NULL){ TCP_SSL_DEBUG("tcp_ssl_new_server: failed to allocate ssl\n"); tcp_ssl_free(tcp); return -1; } return tcp_ssl->fd; } int tcp_ssl_free(struct tcp_pcb *tcp) { if(tcp == NULL) { return -1; } tcp_ssl_t * item = tcp_ssl_array; if(item->tcp == tcp){ tcp_ssl_array = tcp_ssl_array->next; if(item->tcp_pbuf != NULL){ pbuf_free(item->tcp_pbuf); } TCP_SSL_DEBUG("tcp_ssl_free: %d\n", item->fd); if(item->ssl) ssl_free(item->ssl); if(item->type == TCP_SSL_TYPE_CLIENT && item->ssl_ctx) ssl_ctx_free(item->ssl_ctx); if(item->type == TCP_SSL_TYPE_SERVER) _tcp_ssl_has_client = 0; free(item); return 0; } while(item->next && item->next->tcp != tcp) item = item->next; if(item->next == NULL){ return ERR_TCP_SSL_INVALID_CLIENTFD_DATA;//item not found } tcp_ssl_t * i = item->next; item->next = i->next; if(i->tcp_pbuf != NULL){ pbuf_free(i->tcp_pbuf); } TCP_SSL_DEBUG("tcp_ssl_free: %d\n", i->fd); if(i->ssl) ssl_free(i->ssl); if(i->type == TCP_SSL_TYPE_CLIENT && i->ssl_ctx) ssl_ctx_free(i->ssl_ctx); if(i->type == TCP_SSL_TYPE_SERVER) _tcp_ssl_has_client = 0; free(i); return 0; } #ifdef AXTLS_2_0_0_SNDBUF int tcp_ssl_sndbuf(struct tcp_pcb *tcp){ int expected; int available; int result = -1; if(tcp == NULL) { return result; } tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); if(!tcp_ssl){ TCP_SSL_DEBUG("tcp_ssl_sndbuf: tcp_ssl is NULL\n"); return result; } available = tcp_sndbuf(tcp); if(!available){ TCP_SSL_DEBUG("tcp_ssl_sndbuf: tcp_sndbuf is zero\n"); return 0; } result = available; while((expected = ssl_calculate_write_length(tcp_ssl->ssl, result)) > available){ result -= (expected - available) + 4; } if(expected > 0){ //TCP_SSL_DEBUG("tcp_ssl_sndbuf: tcp_sndbuf is %d from %d\n", result, available); return result; } return 0; } #endif int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len) { if(tcp == NULL) { return -1; } tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); if(!tcp_ssl){ TCP_SSL_DEBUG("tcp_ssl_write: tcp_ssl is NULL\n"); return 0; } tcp_ssl->last_wr = 0; #ifdef AXTLS_2_0_0_SNDBUF int expected_len = ssl_calculate_write_length(tcp_ssl->ssl, len); int available_len = tcp_sndbuf(tcp); if(expected_len < 0 || expected_len > available_len){ TCP_SSL_DEBUG("tcp_ssl_write: data will not fit! %u < %d(%u)\r\n", available_len, expected_len, len); return -1; } #endif int rc = ssl_write(tcp_ssl->ssl, data, len); //TCP_SSL_DEBUG("tcp_ssl_write: %u -> %d (%d)\r\n", len, tcp_ssl->last_wr, rc); if (rc < 0){ if(rc != SSL_CLOSE_NOTIFY) { TCP_SSL_DEBUG("tcp_ssl_write error: %d\r\n", rc); } return rc; } return tcp_ssl->last_wr; } /** * Reads data from the SSL over TCP stream. Returns decrypted data. * @param tcp_pcb *tcp - pointer to the raw tcp object * @param pbuf *p - pointer to the buffer with the TCP packet data * * @return int * 0 - when everything is fine but there are no symbols to process yet * < 0 - when there is an error * > 0 - the length of the clear text characters that were read */ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { if(tcp == NULL) { return -1; } tcp_ssl_t* fd_data = NULL; int read_bytes = 0; int total_bytes = 0; uint8_t *read_buf; fd_data = tcp_ssl_get(tcp); if(fd_data == NULL) { TCP_SSL_DEBUG("tcp_ssl_read: tcp_ssl is NULL\n"); return ERR_TCP_SSL_INVALID_CLIENTFD_DATA; } if(p == NULL) { TCP_SSL_DEBUG("tcp_ssl_read:p == NULL\n"); return ERR_TCP_SSL_INVALID_DATA; } //TCP_SSL_DEBUG("READY TO READ SOME DATA\n"); fd_data->tcp_pbuf = p; fd_data->pbuf_offset = 0; do { read_bytes = ssl_read(fd_data->ssl, &read_buf); TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes); if(read_bytes < SSL_OK) { if(read_bytes != SSL_CLOSE_NOTIFY) { TCP_SSL_DEBUG("tcp_ssl_read: read error: %d\n", read_bytes); } total_bytes = read_bytes; break; } else if(read_bytes > 0){ if(fd_data->on_data){ fd_data->on_data(fd_data->arg, tcp, read_buf, read_bytes); // fd_data may have been freed in callback fd_data = tcp_ssl_get(tcp); if(NULL == fd_data) return SSL_CLOSE_NOTIFY; } total_bytes+= read_bytes; } else { if(fd_data->handshake != SSL_OK) { // fd_data may be freed in callbacks. int handshake = fd_data->handshake = ssl_handshake_status(fd_data->ssl); if(handshake == SSL_OK){ TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n"); if(fd_data->on_handshake) fd_data->on_handshake(fd_data->arg, fd_data->tcp, fd_data->ssl); fd_data = tcp_ssl_get(tcp); if(NULL == fd_data) return SSL_CLOSE_NOTIFY; } else if(handshake != SSL_NOT_OK){ TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", handshake); if(fd_data->on_error) fd_data->on_error(fd_data->arg, fd_data->tcp, handshake); return handshake; // With current code APP gets called twice at onError handler. // Once here and again after return when handshake != SSL_CLOSE_NOTIFY. // As always APP must never free resources at onError only at onDisconnect. } } } } while (p->tot_len - fd_data->pbuf_offset > 0); tcp_recved(tcp, p->tot_len); fd_data->tcp_pbuf = NULL; pbuf_free(p); return total_bytes; } SSL * tcp_ssl_get_ssl(struct tcp_pcb *tcp){ tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); if(tcp_ssl){ return tcp_ssl->ssl; } return NULL; } bool tcp_ssl_has(struct tcp_pcb *tcp){ return tcp_ssl_get(tcp) != NULL; } int tcp_ssl_is_server(struct tcp_pcb *tcp){ tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); if(tcp_ssl){ return tcp_ssl->type; } return -1; } void tcp_ssl_arg(struct tcp_pcb *tcp, void * arg){ tcp_ssl_t * item = tcp_ssl_get(tcp); if(item) { item->arg = arg; } } void tcp_ssl_data(struct tcp_pcb *tcp, tcp_ssl_data_cb_t arg){ tcp_ssl_t * item = tcp_ssl_get(tcp); if(item) { item->on_data = arg; } } void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t arg){ tcp_ssl_t * item = tcp_ssl_get(tcp); if(item) { item->on_handshake = arg; } } void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg){ tcp_ssl_t * item = tcp_ssl_get(tcp); if(item) { item->on_error = arg; } } static tcp_ssl_file_cb_t _tcp_ssl_file_cb = NULL; static void * _tcp_ssl_file_arg = NULL; void tcp_ssl_file(tcp_ssl_file_cb_t cb, void * arg){ _tcp_ssl_file_cb = cb; _tcp_ssl_file_arg = arg; } int ax_get_file(const char *filename, uint8_t **buf) { //TCP_SSL_DEBUG("ax_get_file: %s\n", filename); if(_tcp_ssl_file_cb){ return _tcp_ssl_file_cb(_tcp_ssl_file_arg, filename, buf); } *buf = 0; return 0; } tcp_ssl_t* tcp_ssl_get_by_fd(int fd) { tcp_ssl_t * item = tcp_ssl_array; while(item && item->fd != fd){ item = item->next; } return item; } /* * The LWIP tcp raw version of the SOCKET_WRITE(A, B, C) */ int ax_port_write(int fd, uint8_t *data, uint16_t len) { tcp_ssl_t *fd_data = NULL; int tcp_len = 0; err_t err = ERR_OK; //TCP_SSL_DEBUG("ax_port_write: %d, %d\n", fd, len); fd_data = tcp_ssl_get_by_fd(fd); if(fd_data == NULL) { //TCP_SSL_DEBUG("ax_port_write: tcp_ssl[%d] is NULL\n", fd); return ERR_MEM; } if (data == NULL || len == 0) { return 0; } if (tcp_sndbuf(fd_data->tcp) < len) { tcp_len = tcp_sndbuf(fd_data->tcp); if(tcp_len == 0) { TCP_SSL_DEBUG("ax_port_write: tcp_sndbuf is zero: %d\n", len); return ERR_MEM; } } else { tcp_len = len; } if (tcp_len > 2 * fd_data->tcp->mss) { tcp_len = 2 * fd_data->tcp->mss; } err = tcp_write(fd_data->tcp, data, tcp_len, TCP_WRITE_FLAG_COPY); if(err < ERR_OK) { if (err == ERR_MEM) { TCP_SSL_DEBUG("ax_port_write: No memory %d (%d)\n", tcp_len, len); return err; } TCP_SSL_DEBUG("ax_port_write: tcp_write error: %ld\n", err); return err; } else if (err == ERR_OK) { //TCP_SSL_DEBUG("ax_port_write: tcp_output: %d / %d\n", tcp_len, len); err = tcp_output(fd_data->tcp); if(err != ERR_OK) { TCP_SSL_DEBUG("ax_port_write: tcp_output err: %ld\n", err); return err; } } fd_data->last_wr += tcp_len; return tcp_len; } /* * The LWIP tcp raw version of the SOCKET_READ(A, B, C) */ int ax_port_read(int fd, uint8_t *data, int len) { tcp_ssl_t *fd_data = NULL; uint8_t *read_buf = NULL; uint8_t *pread_buf = NULL; u16_t recv_len = 0; //TCP_SSL_DEBUG("ax_port_read: %d, %d\n", fd, len); fd_data = tcp_ssl_get_by_fd(fd); if (fd_data == NULL) { TCP_SSL_DEBUG("ax_port_read: tcp_ssl[%d] is NULL\n", fd); return ERR_TCP_SSL_INVALID_CLIENTFD_DATA; } if(fd_data->tcp_pbuf == NULL || fd_data->tcp_pbuf->tot_len == 0) { return 0; } read_buf =(uint8_t*)calloc(fd_data->tcp_pbuf->len + 1, sizeof(uint8_t)); pread_buf = read_buf; if (pread_buf != NULL){ recv_len = pbuf_copy_partial(fd_data->tcp_pbuf, read_buf, len, fd_data->pbuf_offset); fd_data->pbuf_offset += recv_len; } if (recv_len != 0) { memcpy(data, read_buf, recv_len); } if(len < recv_len) { TCP_SSL_DEBUG("ax_port_read: got %d bytes more than expected\n", recv_len - len); } free(pread_buf); pread_buf = NULL; return recv_len; } void ax_wdt_feed() {} #endif