diff --git a/src/AsyncTCP.cpp b/src/AsyncTCP.cpp index 3d81963..f083995 100644 --- a/src/AsyncTCP.cpp +++ b/src/AsyncTCP.cpp @@ -294,6 +294,7 @@ AsyncClient::AsyncClient(int sockfd) , _handshake_done(true) , _psk_ident(0) , _psk(0) +, _sslctx(NULL) #endif // ASYNC_TCP_SSL_ENABLED , _writeSpaceRemaining(TCP_SND_BUF) , _conn_state(0) @@ -561,6 +562,7 @@ bool AsyncClient::connect(const char* host, uint16_t port){ if(err == ERR_OK) { log_v("\taddr resolved as %08x, connecting...", addr.u_addr.ip4.addr); #if ASYNC_TCP_SSL_ENABLED + _hostname = host; return connect(IPAddress(addr.u_addr.ip4.addr), port, secure); #else return connect(IPAddress(addr.u_addr.ip4.addr), port); @@ -570,6 +572,7 @@ bool AsyncClient::connect(const char* host, uint16_t port){ _conn_state = 1; _connect_port = port; #if ASYNC_TCP_SSL_ENABLED + _hostname = host; _secure = secure; _handshake_done = !secure; #endif // ASYNC_TCP_SSL_ENABLED @@ -601,7 +604,11 @@ void _tcpsock_dns_found(const char * name, struct ip_addr * ipaddr, void * arg) void AsyncClient::_sockDelayedConnect(void) { if (_connect_addr.u_addr.ip4.addr) { +#if ASYNC_TCP_SSL_ENABLED + connect(IPAddress(_connect_addr.u_addr.ip4.addr), _connect_port, _secure); +#else connect(IPAddress(_connect_addr.u_addr.ip4.addr), _connect_port); +#endif } else { _conn_state = 0; if(_error_cb) { @@ -613,6 +620,32 @@ void AsyncClient::_sockDelayedConnect(void) } } +#if ASYNC_TCP_SSL_ENABLED +int AsyncClient::_runSSLHandshakeLoop(void) +{ + int res = 0; + + while (!_handshake_done) { + res = _sslctx->runSSLHandshake(); + if (res == 0) { + // Handshake successful + _handshake_done = true; + } else if (ASYNCTCP_TLS_CAN_RETRY(res)) { + // Ran out of readable data or writable space on socket, must continue later + break; + } else { + // SSL handshake for AsyncTCP does not inform SSL errors + log_e("TLS setup failed with error %d, closing socket...", res); + _close(); + // _sslctx should be NULL after this + break; + } + } + + return res; +} +#endif + bool AsyncClient::_sockIsWriteable(void) { int res; @@ -635,6 +668,46 @@ bool AsyncClient::_sockIsWriteable(void) } else if (sockerr != 0) { _error(sockerr); } else { +#if ASYNC_TCP_SSL_ENABLED + if (_secure) { + int res = 0; + + if (_sslctx == NULL) { + String remIP_str = remoteIP().toString(); + const char * host_or_ip = _hostname.isEmpty() + ? remIP_str.c_str() + : _hostname.c_str(); + + _sslctx = new AsyncTCP_TLS_Context(); + if (_root_ca != NULL) { + res = _sslctx->startSSLClient(_socket, host_or_ip, + (const unsigned char *)_root_ca, _root_ca_len, + (const unsigned char *)_cli_cert, _cli_cert_len, + (const unsigned char *)_cli_key, _cli_key_len); + } else if (_psk_ident != NULL) { + res = _sslctx->startSSLClient(_socket, host_or_ip, + _psk_ident, _psk); + } else { + res = _sslctx->startSSLClientInsecure(_socket, host_or_ip); + } + + if (res != 0) { + // SSL setup for AsyncTCP does not inform SSL errors + log_e("TLS setup failed with error %d, closing socket...", res); + _close(); + // _sslctx should be NULL after this + } + } + + // _handshake_done is set to FALSE on connect() if encrypted connection + if (_sslctx != NULL && res == 0) res = _runSSLHandshakeLoop(); + + if (!_handshake_done) return ASYNCTCP_TLS_CAN_RETRY(res); + + // Fallthrough to ordinary successful connection + } +#endif + // Socket is now fully connected _conn_state = 4; activity = true; @@ -681,7 +754,27 @@ bool AsyncClient::_flushWriteQueue(void) do { uint8_t * p = it->data + it->written; size_t n = it->length - it->written; - errno = 0; ssize_t r = lwip_write(_socket, p, n); + errno = 0; + ssize_t r; + +#if ASYNC_TCP_SSL_ENABLED + if (_sslctx != NULL) { + r = _sslctx->write(p, n); + if (ASYNCTCP_TLS_CAN_RETRY(r)) { + r = -1; + errno = EAGAIN; + } else if (ASYNCTCP_TLS_EOF(r)) { + r = -1; + errno = EPIPE; + } else if (r < 0) { + if (errno == 0) errno = EIO; + } + } else { +#endif + r = lwip_write(_socket, p, n); +#if ASYNC_TCP_SSL_ENABLED + } +#endif if (r >= 0) { // Written some data into the socket @@ -755,7 +848,38 @@ void AsyncClient::_notifyWrittenBuffers(std::deque & notifyqueu void AsyncClient::_sockIsReadable(void) { _rx_last_packet = millis(); - errno = 0; ssize_t r = lwip_read(_socket, _readBuffer, MAX_PAYLOAD_SIZE); + errno = 0; + ssize_t r; + +#if ASYNC_TCP_SSL_ENABLED + if (_sslctx != NULL) { + if (!_handshake_done) { + // Handshake process has stopped for want of data, must be + // continued here for connection to complete. + _runSSLHandshakeLoop(); + + // If handshake was successful, this will be recognized when the socket + // next becomes writable. No other read operation should be done here. + return; + } else { + r = _sslctx->read(_readBuffer, MAX_PAYLOAD_SIZE); + if (ASYNCTCP_TLS_CAN_RETRY(r)) { + r = -1; + errno = EAGAIN; + } else if (ASYNCTCP_TLS_EOF(r)) { + // Simulate "successful" end-of-stream condition + r = 0; + } else if (r < 0) { + if (errno == 0) errno = EIO; + } + } + } else { +#endif + r = lwip_read(_socket, _readBuffer, MAX_PAYLOAD_SIZE); +#if ASYNC_TCP_SSL_ENABLED + } +#endif + if (r > 0) { if(_recv_cb) { _recv_cb(_recv_cb_arg, this, _readBuffer, r); @@ -844,6 +968,12 @@ void AsyncClient::_close(void) _conn_state = 0; ::close(_socket); _socket = -1; +#if ASYNC_TCP_SSL_ENABLED + if (_sslctx != NULL) { + delete _sslctx; + _sslctx = NULL; + } +#endif xSemaphoreGiveRecursive(_asyncsock_mutex); _clearWriteQueue(); @@ -856,6 +986,12 @@ void AsyncClient::_error(int8_t err) _conn_state = 0; ::close(_socket); _socket = -1; +#if ASYNC_TCP_SSL_ENABLED + if (_sslctx != NULL) { + delete _sslctx; + _sslctx = NULL; + } +#endif xSemaphoreGiveRecursive(_asyncsock_mutex); _clearWriteQueue(); diff --git a/src/AsyncTCP.h b/src/AsyncTCP.h index 1dcfcec..fe7c12c 100644 --- a/src/AsyncTCP.h +++ b/src/AsyncTCP.h @@ -203,6 +203,9 @@ class AsyncClient : public AsyncSocketBase bool _handshake_done; const char* _psk_ident; const char* _psk; + + String _hostname; + AsyncTCP_TLS_Context * _sslctx; #endif // ASYNC_TCP_SSL_ENABLED // The following private struct represents a buffer enqueued with the add() @@ -244,6 +247,10 @@ class AsyncClient : public AsyncSocketBase void _collectNotifyWrittenBuffers(std::deque &, int &); void _notifyWrittenBuffers(std::deque &, int); +#if ASYNC_TCP_SSL_ENABLED + int _runSSLHandshakeLoop(void); +#endif + friend void _tcpsock_dns_found(const char * name, struct ip_addr * ipaddr, void * arg); };