From 71ab0b8e8c4b25916f4db0cbd65baf7690996a74 Mon Sep 17 00:00:00 2001 From: Pavel Brychta Date: Wed, 22 Jun 2022 11:39:36 +0200 Subject: [PATCH] Pokus o nejnovejsi upravy kvuli ESP32 - nutno overit na ESP8266!! --- src/AsyncEventSource.cpp | 104 ++- src/AsyncEventSource.h | 11 +- src/AsyncJson.h | 4 +- src/AsyncWebSocket.cpp | 1646 +++++++++++++++------------------ src/AsyncWebSocket.h | 197 ++-- src/AsyncWebSynchronization.h | 71 +- src/ESPAsyncWebServer.h | 66 +- src/SPIFFSEditor.cpp | 2 - src/StringArray.h | 19 - src/WebAuthentication.cpp | 6 +- src/WebRequest.cpp | 138 +-- src/WebResponseImpl.h | 4 +- src/WebResponses.cpp | 24 +- 13 files changed, 1063 insertions(+), 1229 deletions(-) diff --git a/src/AsyncEventSource.cpp b/src/AsyncEventSource.cpp index 1325310..b5d72fa 100644 --- a/src/AsyncEventSource.cpp +++ b/src/AsyncEventSource.cpp @@ -137,16 +137,17 @@ size_t AsyncEventSourceMessage::ack(size_t len, uint32_t time) { return 0; } +// This could also return void as the return value is not used. +// Leaving as-is for compatibility... size_t AsyncEventSourceMessage::send(AsyncClient *client) { - const size_t len = _len - _sent; - if(client->space() < len){ - return 0; - } - size_t sent = client->add((const char *)_data, len); - if(client->canSend()) - client->send(); - _sent += sent; - return sent; + if (_sent >= _len) { + return 0; + } + const size_t len_to_send = _len - _sent; + auto position = reinterpret_cast(_data + _sent); + const size_t sent_now = client->write(position, len_to_send); + _sent += sent_now; + return sent_now; } // Client @@ -173,7 +174,9 @@ AsyncEventSourceClient::AsyncEventSourceClient(AsyncWebServerRequest *request, A } AsyncEventSourceClient::~AsyncEventSourceClient(){ - _messageQueue.free(); + _lockmq.lock(); + _messageQueue.free(); + _lockmq.unlock(); close(); } @@ -184,33 +187,41 @@ void AsyncEventSourceClient::_queueMessage(AsyncEventSourceMessage *dataMessage) delete dataMessage; return; } + //length() is not thread-safe, thus acquiring the lock before this call.. + _lockmq.lock(); if(_messageQueue.length() >= SSE_MAX_QUEUED_MESSAGES){ ets_printf(String(F("ERROR: Too many messages queued\n")).c_str()); delete dataMessage; } else { - _messageQueue.add(dataMessage); + _messageQueue.add(dataMessage); + // runqueue trigger when new messages added + if(_client->canSend()) { + _runQueue(); + } } - if(_client->canSend()) - _runQueue(); + _lockmq.unlock(); } void AsyncEventSourceClient::_onAck(size_t len, uint32_t time){ + // Same here, acquiring the lock early + _lockmq.lock(); while(len && !_messageQueue.isEmpty()){ len = _messageQueue.front()->ack(len, time); if(_messageQueue.front()->finished()) _messageQueue.remove(_messageQueue.front()); } - _runQueue(); + _lockmq.unlock(); } void AsyncEventSourceClient::_onPoll(){ + _lockmq.lock(); if(!_messageQueue.isEmpty()){ _runQueue(); } + _lockmq.unlock(); } - void AsyncEventSourceClient::_onTimeout(uint32_t time __attribute__((unused))){ _client->close(true); } @@ -225,7 +236,7 @@ void AsyncEventSourceClient::close(){ _client->close(); } -void AsyncEventSourceClient::write(const char * message, size_t len){ +void AsyncEventSourceClient::_write(const char * message, size_t len){ _queueMessage(new AsyncEventSourceMessage(message, len)); } @@ -234,15 +245,23 @@ void AsyncEventSourceClient::send(const char *message, const char *event, uint32 _queueMessage(new AsyncEventSourceMessage(ev.c_str(), ev.length())); } -void AsyncEventSourceClient::_runQueue(){ - while(!_messageQueue.isEmpty() && _messageQueue.front()->finished()){ - _messageQueue.remove(_messageQueue.front()); - } +size_t AsyncEventSourceClient::packetsWaiting() const { + size_t len; + _lockmq.lock(); + len = _messageQueue.length(); + _lockmq.unlock(); + return len; +} - for(auto i = _messageQueue.begin(); i != _messageQueue.end(); ++i) - { - if(!(*i)->sent()) +void AsyncEventSourceClient::_runQueue() { + // Calls to this private method now already protected by _lockmq acquisition + // so no extra call of _lockmq.lock() here.. + for (auto i = _messageQueue.begin(); i != _messageQueue.end(); ++i) { + // If it crashes here, iterator (i) has been invalidated as _messageQueue + // has been changed... (UL 2020-11-15: Not supposed to happen any more ;-) ) + if (!(*i)->sent()) { (*i)->send(_client); + } } } @@ -280,17 +299,22 @@ void AsyncEventSource::_addClient(AsyncEventSourceClient * client){ client->write((const char *)temp, 2053); free(temp); }*/ - + AsyncWebLockGuard l(_client_queue_lock); _clients.add(client); if(_connectcb) _connectcb(client); } void AsyncEventSource::_handleDisconnect(AsyncEventSourceClient * client){ + AsyncWebLockGuard l(_client_queue_lock); _clients.remove(client); } void AsyncEventSource::close(){ + // While the whole loop is not done, the linked list is locked and so the + // iterator should remain valid even when AsyncEventSource::_handleDisconnect() + // is called very early + AsyncWebLockGuard l(_client_queue_lock); for(const auto &c: _clients){ if(c->connected()) c->close(); @@ -299,37 +323,39 @@ void AsyncEventSource::close(){ // pmb fix size_t AsyncEventSource::avgPacketsWaiting() const { - if(_clients.isEmpty()) + size_t aql = 0; + uint32_t nConnectedClients = 0; + AsyncWebLockGuard l(_client_queue_lock); + if (_clients.isEmpty()) { return 0; - - size_t aql=0; - uint32_t nConnectedClients=0; - + } for(const auto &c: _clients){ if(c->connected()) { - aql+=c->packetsWaiting(); + aql += c->packetsWaiting(); ++nConnectedClients; } } -// return aql / nConnectedClients; - return ((aql) + (nConnectedClients/2))/(nConnectedClients); // round up + return ((aql) + (nConnectedClients/2)) / (nConnectedClients); // round up } -void AsyncEventSource::send(const char *message, const char *event, uint32_t id, uint32_t reconnect){ - - +void AsyncEventSource::send( + const char *message, const char *event, uint32_t id, uint32_t reconnect){ String ev = generateEventMessage(message, event, id, reconnect); + AsyncWebLockGuard l(_client_queue_lock); for(const auto &c: _clients){ if(c->connected()) { - c->write(ev.c_str(), ev.length()); + c->_write(ev.c_str(), ev.length()); } } } size_t AsyncEventSource::count() const { - return _clients.count_if([](AsyncEventSourceClient *c){ - return c->connected(); - }); + size_t n_clients; + AsyncWebLockGuard l(_client_queue_lock); + n_clients = _clients.count_if([](AsyncEventSourceClient *c){ + return c->connected(); + }); + return n_clients; } bool AsyncEventSource::canHandle(AsyncWebServerRequest *request){ diff --git a/src/AsyncEventSource.h b/src/AsyncEventSource.h index a350e7f..abb1829 100644 --- a/src/AsyncEventSource.h +++ b/src/AsyncEventSource.h @@ -73,6 +73,8 @@ class AsyncEventSourceClient { AsyncEventSource *_server; uint32_t _lastId; LinkedList _messageQueue; + // ArFi 2020-08-27 for protecting/serializing _messageQueue + AsyncPlainLock _lockmq; void _queueMessage(AsyncEventSourceMessage *dataMessage); void _runQueue(); @@ -83,12 +85,12 @@ class AsyncEventSourceClient { AsyncClient* client(){ return _client; } void close(); - void write(const char * message, size_t len); void send(const char *message, const char *event=NULL, uint32_t id=0, uint32_t reconnect=0); bool connected() const { return (_client != NULL) && _client->connected(); } uint32_t lastId() const { return _lastId; } - size_t packetsWaiting() const { return _messageQueue.length(); } + size_t packetsWaiting() const; + void _write(const char * message, size_t len); //system callbacks (do not call) void _onAck(size_t len, uint32_t time); void _onPoll(); @@ -100,6 +102,9 @@ class AsyncEventSource: public AsyncWebHandler { private: String _url; LinkedList _clients; + // Same as for individual messages, protect mutations of _clients list + // since simultaneous access from different tasks is possible + AsyncWebLock _client_queue_lock; ArEventHandlerFunction _connectcb; ArAuthorizeConnectHandler _authorizeConnectHandler; public: @@ -111,7 +116,7 @@ class AsyncEventSource: public AsyncWebHandler { void onConnect(ArEventHandlerFunction cb); void authorizeConnect(ArAuthorizeConnectHandler cb); void send(const char *message, const char *event=NULL, uint32_t id=0, uint32_t reconnect=0); - size_t count() const; //number clinets connected + size_t count() const; //number clients connected size_t avgPacketsWaiting() const; //system callbacks (do not call) diff --git a/src/AsyncJson.h b/src/AsyncJson.h index 27b4a26..2fa6a2d 100644 --- a/src/AsyncJson.h +++ b/src/AsyncJson.h @@ -41,7 +41,9 @@ #if ARDUINOJSON_VERSION_MAJOR == 5 #define ARDUINOJSON_5_COMPATIBILITY #else - #define DYNAMIC_JSON_DOCUMENT_SIZE 1024 + #ifndef DYNAMIC_JSON_DOCUMENT_SIZE + #define DYNAMIC_JSON_DOCUMENT_SIZE 1024 + #endif #endif constexpr const char* JSON_MIMETYPE = "application/json"; diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 9ebab12..12be5f8 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -21,21 +21,12 @@ #include "Arduino.h" #include "AsyncWebSocket.h" +#include + #include #ifndef ESP8266 -extern "C" { -typedef struct { - uint32_t state[5]; - uint32_t count[2]; - unsigned char buffer[64]; -} SHA1_CTX; - -void SHA1Transform(uint32_t state[5], const unsigned char buffer[64]); -void SHA1Init(SHA1_CTX* context); -void SHA1Update(SHA1_CTX* context, const unsigned char* data, uint32_t len); -void SHA1Final(unsigned char digest[20], SHA1_CTX* context); -} +#include "mbedtls/sha1.h" #else #include #endif @@ -131,387 +122,129 @@ size_t webSocketSendFrame(AsyncClient *client, bool final, uint8_t opcode, bool } -/* - * AsyncWebSocketMessageBuffer - */ - - - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() - :_data(nullptr) - ,_len(0) - ,_lock(false) - ,_count(0) -{ - -} - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t * data, size_t size) - :_data(nullptr) - ,_len(size) - ,_lock(false) - ,_count(0) -{ - - if (!data) { - return; - } - - _data = new uint8_t[_len + 1]; - - if (_data) { - // Serial.println("BUFF alloc"); - memcpy(_data, data, _len); - _data[_len] = 0; - } -} - - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size) - :_data(nullptr) - ,_len(size) - ,_lock(false) - ,_count(0) -{ - _data = new uint8_t[_len + 1]; - - if (_data) { - // Serial.println("BUFF alloc"); - _data[_len] = 0; - } - -} - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(const AsyncWebSocketMessageBuffer & copy) - :_data(nullptr) - ,_len(0) - ,_lock(false) - ,_count(0) -{ - _len = copy._len; - _lock = copy._lock; - _count = 0; - - if (_len) { - _data = new uint8_t[_len + 1]; - _data[_len] = 0; - } - - if (_data) { - // Serial.println("BUFF alloc"); - memcpy(_data, copy._data, _len); - _data[_len] = 0; - } - -} - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(AsyncWebSocketMessageBuffer && copy) - :_data(nullptr) - ,_len(0) - ,_lock(false) - ,_count(0) -{ - _len = copy._len; - _lock = copy._lock; - _count = 0; - - if (copy._data) { - // Serial.println("BUFF alloc"); - _data = copy._data; - copy._data = nullptr; - } - -} - -AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer() -{ - if (_data) { - // Serial.println("BUFF free"); - delete[] _data; - } -} - -bool AsyncWebSocketMessageBuffer::reserve(size_t size) -{ - _len = size; - - if (_data) { - delete[] _data; - _data = nullptr; - } - - _data = new uint8_t[_len + 1]; - - if (_data) { - _data[_len] = 0; - return true; - } else { - return false; - } - -} - - /* * Control Frame */ class AsyncWebSocketControl { - private: +private: uint8_t _opcode; uint8_t *_data; size_t _len; bool _mask; bool _finished; - public: - AsyncWebSocketControl(uint8_t opcode, uint8_t *data=NULL, size_t len=0, bool mask=false) + +public: + AsyncWebSocketControl(uint8_t opcode, const uint8_t *data=NULL, size_t len=0, bool mask=false) :_opcode(opcode) ,_len(len) ,_mask(len && mask) ,_finished(false) - { - if(data == NULL) - _len = 0; - if(_len){ - if(_len > 125) - _len = 125; - _data = (uint8_t*)malloc(_len); - if(_data == NULL) - _len = 0; - else memcpy(_data, data, len); - } else _data = NULL; + { + if (data == NULL) + _len = 0; + if (_len) + { + if (_len > 125) + _len = 125; + + _data = (uint8_t*)malloc(_len); + + if(_data == NULL) + _len = 0; + else + memcpy(_data, data, len); + } + else + _data = NULL; } - virtual ~AsyncWebSocketControl(){ - if(_data != NULL) - free(_data); + + virtual ~AsyncWebSocketControl() + { + if (_data != NULL) + free(_data); } + virtual bool finished() const { return _finished; } uint8_t opcode(){ return _opcode; } uint8_t len(){ return _len + 2; } size_t send(AsyncClient *client){ - _finished = true; - return webSocketSendFrame(client, true, _opcode & 0x0F, _mask, _data, _len); + _finished = true; + return webSocketSendFrame(client, true, _opcode & 0x0F, _mask, _data, _len); } }; + /* - * Basic Buffered Message + * AsyncWebSocketMessage Message */ -AsyncWebSocketBasicMessage::AsyncWebSocketBasicMessage(const char * data, size_t len, uint8_t opcode, bool mask) - :_len(len) - ,_sent(0) - ,_ack(0) - ,_acked(0) +AsyncWebSocketMessage::AsyncWebSocketMessage(std::shared_ptr> buffer, uint8_t opcode, bool mask) : + _WSbuffer{buffer}, + _opcode(opcode & 0x07), + _mask{mask}, + _status{_WSbuffer?WS_MSG_SENDING:WS_MSG_ERROR} { - _opcode = opcode & 0x07; - _mask = mask; - _data = (uint8_t*)malloc(_len+1); - // Serial.println("MSG alloc"); - if(_data == NULL){ - _len = 0; - _status = WS_MSG_ERROR; - } else { - _status = WS_MSG_SENDING; - memcpy(_data, data, _len); - _data[_len] = 0; - } -} -AsyncWebSocketBasicMessage::AsyncWebSocketBasicMessage(uint8_t opcode, bool mask) - :_len(0) - ,_sent(0) - ,_ack(0) - ,_acked(0) - ,_data(NULL) +} + +void AsyncWebSocketMessage::ack(size_t len, uint32_t time) { - _opcode = opcode & 0x07; - _mask = mask; - + (void)time; + _acked += len; + if (_sent >= _WSbuffer->size() && _acked >= _ack) + { + _status = WS_MSG_SENT; + } + //ets_printf("A: %u\n", len); } - -AsyncWebSocketBasicMessage::~AsyncWebSocketBasicMessage() { - if(_data != NULL) { - // Serial.println("MSG free"); - free(_data); - } -} - - void AsyncWebSocketBasicMessage::ack(size_t len, uint32_t time) { - (void)time; - _acked += len; - // Serial.printf("ACK %u = %u | %u = %u\n", _sent, _len, _acked, _ack); - if(_sent == _len && _acked == _ack){ - // Serial.println("ACK end"); - _status = WS_MSG_SENT; - } -} - size_t AsyncWebSocketBasicMessage::send(AsyncClient *client) { - if(_status != WS_MSG_SENDING){ - // Serial.println("MS 1"); - return 0; - } - if(_acked < _ack){ - // Serial.println("MS 2"); - return 0; - } - if(_sent == _len){ - // Serial.println("MS 3"); - _status = WS_MSG_SENT; - return 0; - } - if(_sent > _len){ - // Serial.println("MS 4"); - _status = WS_MSG_ERROR; - return 0; - } - size_t toSend = _len - _sent; - size_t window = webSocketSendFrameWindow(client); - // Serial.printf("Send %u %u %u\n", _len, _sent, toSend); - - if(window < toSend) { - toSend = window; - } - - _sent += toSend; - _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); - - bool final = (_sent == _len); - uint8_t* dPtr = (uint8_t*)(_data + (_sent - toSend)); - uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - - size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); - _status = WS_MSG_SENDING; - if(toSend && sent != toSend){ - size_t delta = (toSend - sent); - // Serial.printf("\ns:%u a:%u d:%u\n", _sent, _ack, delta); - _sent -= delta; - _ack -= delta + ((delta < 126)?2:4) + (_mask * 4); - // Serial.printf("s:%u a:%u\n", _sent, _ack); - if (!sent) { +size_t AsyncWebSocketMessage::send(AsyncClient *client) +{ + if (_status != WS_MSG_SENDING) + return 0; + if (_acked < _ack){ + return 0; + } + if (_sent == _WSbuffer->size()) + { + if(_acked == _ack) + _status = WS_MSG_SENT; + return 0; + } + if (_sent > _WSbuffer->size()) + { _status = WS_MSG_ERROR; - } - } - return sent; -} + //ets_printf("E: %u > %u\n", _sent, _WSbuffer->length()); + return 0; + } -// bool AsyncWebSocketBasicMessage::reserve(size_t size) { -// if (size) { -// _data = (uint8_t*)malloc(size +1); -// if (_data) { -// memset(_data, 0, size); -// _len = size; -// _status = WS_MSG_SENDING; -// return true; -// } -// } -// return false; -// } + size_t toSend = _WSbuffer->size() - _sent; + size_t window = webSocketSendFrameWindow(client); + if (window < toSend) { + toSend = window; + } -/* - * AsyncWebSocketMultiMessage Message - */ + _sent += toSend; + _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); + //ets_printf("W: %u %u\n", _sent - toSend, toSend); -AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode, bool mask) - :_len(0) - ,_sent(0) - ,_ack(0) - ,_acked(0) - ,_WSbuffer(nullptr) -{ + bool final = (_sent == _WSbuffer->size()); + uint8_t* dPtr = (uint8_t*)(_WSbuffer->data() + (_sent - toSend)); + uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - _opcode = opcode & 0x07; - _mask = mask; - - if (buffer) { - _WSbuffer = buffer; - (*_WSbuffer)++; - // Serial.printf("INC WSbuffer == %u\n", _WSbuffer->count()); - _data = buffer->get(); - _len = buffer->length(); + size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); _status = WS_MSG_SENDING; - //ets_printf("M: %u\n", _len); - } else { - // Serial.println("BUFF ERROR"); - _status = WS_MSG_ERROR; - } - -} - - -AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() { - if (_WSbuffer) { - (*_WSbuffer)--; // decreases the counter. - // Serial.printf("DEC WSbuffer == %u\n", _WSbuffer->count()); - } -} - - void AsyncWebSocketMultiMessage::ack(size_t len, uint32_t time) { - (void)time; - _acked += len; - // Serial.printf("ACK %u = %u | %u = %u\n", _sent, _len, _acked, _ack); - if(_sent >= _len && _acked >= _ack){ - // Serial.println("ACK end"); - _status = WS_MSG_SENT; - } - //ets_printf("A: %u\n", len); -} - size_t AsyncWebSocketMultiMessage::send(AsyncClient *client) { - if(_status != WS_MSG_SENDING) { - // Serial.println("MS 1"); - return 0; - } - if(_acked < _ack){ - // Serial.println("MS 2"); - return 0; - } - if(_sent == _len){ - // Serial.println("MS 3"); - _status = WS_MSG_SENT; - return 0; - } - if(_sent > _len){ - // Serial.println("MS 4"); - _status = WS_MSG_ERROR; - //ets_printf("E: %u > %u\n", _sent, _len); - return 0; - } - size_t toSend = _len - _sent; - size_t window = webSocketSendFrameWindow(client); - // Serial.printf("Send %u %u %u\n", _len, _sent, toSend); - - if(window < toSend) { - toSend = window; - } - // Serial.printf("s:%u a:%u t:%u\n", _sent, _ack, toSend); - _sent += toSend; - _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); - - //ets_printf("W: %u %u\n", _sent - toSend, toSend); - - bool final = (_sent == _len); - uint8_t* dPtr = (uint8_t*)(_data + (_sent - toSend)); - uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - - size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); - _status = WS_MSG_SENDING; - if(toSend && sent != toSend){ - //ets_printf("E: %u != %u\n", toSend, sent); - size_t delta = (toSend - sent); - // Serial.printf("\ns:%u a:%u d:%u\n", _sent, _ack, delta); - _sent -= delta; - _ack -= delta + ((delta < 126)?2:4) + (_mask * 4); - // Serial.printf("s:%u a:%u\n", _sent, _ack); - if (!sent) { - _status = WS_MSG_ERROR; - } - } - //ets_printf("S: %u %u\n", _sent, sent); - return sent; + if (toSend && sent != toSend){ + //ets_printf("E: %u != %u\n", toSend, sent); + _sent -= (toSend - sent); + _ack -= (toSend - sent); + } + //ets_printf("S: %u %u\n", _sent, sent); + return sent; } @@ -522,185 +255,237 @@ AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() { const size_t AWSC_PING_PAYLOAD_LEN = 22; AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, AsyncWebSocket *server) - : _controlQueue(LinkedList([](AsyncWebSocketControl *c){ delete c; })) - , _messageQueue(LinkedList([](AsyncWebSocketMessage *m){ delete m; })) - , _tempObject(NULL) + : _tempObject(NULL) { - _client = request->client(); - _server = server; - _clientId = _server->_getNextId(); - _status = WS_CONNECTED; - _pstate = 0; - _lastMessageTime = millis(); - _keepAlivePeriod = 0; - _client->setRxTimeout(0); - _client->onError([](void *r, AsyncClient* c, int8_t error){ (void)c; ((AsyncWebSocketClient*)(r))->_onError(error); }, this); - _client->onAck([](void *r, AsyncClient* c, size_t len, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onAck(len, time); }, this); - _client->onDisconnect([](void *r, AsyncClient* c){ ((AsyncWebSocketClient*)(r))->_onDisconnect(); delete c; }, this); - _client->onTimeout([](void *r, AsyncClient* c, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onTimeout(time); }, this); - _client->onData([](void *r, AsyncClient* c, void *buf, size_t len){ (void)c; ((AsyncWebSocketClient*)(r))->_onData(buf, len); }, this); - _client->onPoll([](void *r, AsyncClient* c){ (void)c; ((AsyncWebSocketClient*)(r))->_onPoll(); }, this); - _server->_addClient(this); - _server->_handleEvent(this, WS_EVT_CONNECT, request, NULL, 0); - delete request; - memset(&_pinfo,0,sizeof(_pinfo)); + _client = request->client(); + _server = server; + _clientId = _server->_getNextId(); + _status = WS_CONNECTED; + _pstate = 0; + _lastMessageTime = millis(); + _keepAlivePeriod = 0; + _client->setRxTimeout(0); + _client->onError([](void *r, AsyncClient* c, int8_t error){ (void)c; ((AsyncWebSocketClient*)(r))->_onError(error); }, this); + _client->onAck([](void *r, AsyncClient* c, size_t len, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onAck(len, time); }, this); + _client->onDisconnect([](void *r, AsyncClient* c){ ((AsyncWebSocketClient*)(r))->_onDisconnect(); delete c; }, this); + _client->onTimeout([](void *r, AsyncClient* c, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onTimeout(time); }, this); + _client->onData([](void *r, AsyncClient* c, void *buf, size_t len){ (void)c; ((AsyncWebSocketClient*)(r))->_onData(buf, len); }, this); + _client->onPoll([](void *r, AsyncClient* c){ (void)c; ((AsyncWebSocketClient*)(r))->_onPoll(); }, this); + _server->_handleEvent(this, WS_EVT_CONNECT, request, NULL, 0); + delete request; + memset(&_pinfo,0,sizeof(_pinfo)); } -AsyncWebSocketClient::~AsyncWebSocketClient(){ - // Serial.printf("%u FREE Q\n", id()); - _messageQueue.free(); - _controlQueue.free(); - _server->_cleanBuffers(); - _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); +AsyncWebSocketClient::~AsyncWebSocketClient() +{ + { + AsyncWebLockGuard l(_lock); + + _messageQueue.clear(); + _controlQueue.clear(); + } + _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); } -void AsyncWebSocketClient::_clearQueue(){ - while(!_messageQueue.isEmpty() && _messageQueue.front()->finished()){ - _messageQueue.remove(_messageQueue.front()); - } +void AsyncWebSocketClient::_clearQueue() +{ + while (!_messageQueue.empty() && _messageQueue.front().finished()) + _messageQueue.pop_front(); } void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ - // Serial.printf("%u onAck\n", id()); - _lastMessageTime = millis(); - if(!_controlQueue.isEmpty()){ - auto head = _controlQueue.front(); - if(head->finished()){ - len -= head->len(); - if(_status == WS_DISCONNECTING && head->opcode() == WS_DISCONNECT){ - _controlQueue.remove(head); - _status = WS_DISCONNECTED; - _client->close(true); + _lastMessageTime = millis(); + + AsyncWebLockGuard l(_lock); + + if (!_controlQueue.empty()) { + auto &head = _controlQueue.front(); + if (head.finished()){ + len -= head.len(); + if (_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ + _controlQueue.pop_front(); + _status = WS_DISCONNECTED; + l.unlock(); + if (_client) _client->close(true); + return; + } + _controlQueue.pop_front(); + } + } + + if(len && !_messageQueue.empty()){ + _messageQueue.front().ack(len, time); + } + + _clearQueue(); + + _runQueue(); +} + +void AsyncWebSocketClient::_onPoll() +{ + if (!_client) return; - } - _controlQueue.remove(head); + + AsyncWebLockGuard l(_lock); + if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) + { + l.unlock(); + _runQueue(); } - } - - if(len && !_messageQueue.isEmpty()){ - _messageQueue.front()->ack(len, time); - } - - _clearQueue(); - - _server->_cleanBuffers(); - // Serial.println("RUN 1"); - _runQueue(); -} - -void AsyncWebSocketClient::_onPoll(){ - if(_client->canSend() && (!_controlQueue.isEmpty() || !_messageQueue.isEmpty())){ - // Serial.println("RUN 2"); - _runQueue(); - } else if(_keepAlivePeriod > 0 && _controlQueue.isEmpty() && _messageQueue.isEmpty() && (millis() - _lastMessageTime) >= _keepAlivePeriod){ - ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); - } -} - -void AsyncWebSocketClient::_runQueue(){ - _clearQueue(); - - //size_t m0 = _messageQueue.isEmpty()? 0 : _messageQueue.length(); - //size_t m1 = _messageQueue.isEmpty()? 0 : _messageQueue.front()->betweenFrames(); - // Serial.printf("%u R C = %u %u\n", _clientId, m0, m1); - if(!_controlQueue.isEmpty() && (_messageQueue.isEmpty() || _messageQueue.front()->betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front()->len() - 1)){ - // Serial.printf("%u R S C\n", _clientId); - _controlQueue.front()->send(_client); - } else if(!_messageQueue.isEmpty() && _messageQueue.front()->betweenFrames() && webSocketSendFrameWindow(_client)){ - // Serial.printf("%u R S M = ", _clientId); - _messageQueue.front()->send(_client); - } - - _clearQueue(); -} - -bool AsyncWebSocketClient::queueIsFull(){ - if((_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED) ) return true; - return false; -} - -void AsyncWebSocketClient::_queueMessage(AsyncWebSocketMessage *dataMessage){ - if(dataMessage == NULL){ - // Serial.printf("%u Q1\n", _clientId); - return; - } - if(_status != WS_CONNECTED){ - // Serial.printf("%u Q2\n", _clientId); - delete dataMessage; - return; - } - if(_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES){ - ets_printf(String(F("ERROR: Too many messages queued\n")).c_str()); - // Serial.printf("%u Q3\n", _clientId); - delete dataMessage; - } else { - _messageQueue.add(dataMessage); - // Serial.printf("%u Q A %u\n", _clientId, _messageQueue.length()); - } - if(_client->canSend()) { - // Serial.printf("%u Q S\n", _clientId); - // Serial.println("RUN 3"); - _runQueue(); - } -} - -void AsyncWebSocketClient::_queueControl(AsyncWebSocketControl *controlMessage){ - if(controlMessage == NULL) - return; - _controlQueue.add(controlMessage); - if(_client->canSend()) { - // Serial.println("RUN 4"); - _runQueue(); - } -} - -void AsyncWebSocketClient::close(uint16_t code, const char * message){ - if(_status != WS_CONNECTED) - return; - if(code){ - uint8_t packetLen = 2; - if(message != NULL){ - size_t mlen = strlen(message); - if(mlen > 123) mlen = 123; - packetLen += mlen; + else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) + { + l.unlock(); + ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); } - char * buf = (char*)malloc(packetLen); - if(buf != NULL){ - buf[0] = (uint8_t)(code >> 8); - buf[1] = (uint8_t)(code & 0xFF); - if(message != NULL){ - memcpy(buf+2, message, packetLen -2); - } - _queueControl(new AsyncWebSocketControl(WS_DISCONNECT,(uint8_t*)buf,packetLen)); - free(buf); - return; +} + +void AsyncWebSocketClient::_runQueue() +{ + if (!_client) + return; + + AsyncWebLockGuard l(_lock); + + _clearQueue(); + + if (!_controlQueue.empty() && (_messageQueue.empty() || _messageQueue.front().betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)) + { + //l.unlock(); + _controlQueue.front().send(_client); + } + else if (!_messageQueue.empty() && _messageQueue.front().betweenFrames() && webSocketSendFrameWindow(_client)) + { + //l.unlock(); + _messageQueue.front().send(_client); } - } - _queueControl(new AsyncWebSocketControl(WS_DISCONNECT)); } -void AsyncWebSocketClient::ping(uint8_t *data, size_t len){ - if(_status == WS_CONNECTED) - _queueControl(new AsyncWebSocketControl(WS_PING, data, len)); +bool AsyncWebSocketClient::queueIsFull() const +{ + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return (size >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED); } -void AsyncWebSocketClient::_onError(int8_t){ - //Serial.println("onErr"); +size_t AsyncWebSocketClient::queueLen() const +{ + AsyncWebLockGuard l(_lock); + + return _messageQueue.size() + _controlQueue.size(); } -void AsyncWebSocketClient::_onTimeout(uint32_t time){ - // Serial.println("onTime"); - (void)time; - _client->close(true); +bool AsyncWebSocketClient::canSend() const +{ + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return size < WS_MAX_QUEUED_MESSAGES; } -void AsyncWebSocketClient::_onDisconnect(){ - // Serial.println("onDis"); - _client = NULL; - _server->_handleDisconnect(this); +void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) +{ + if (!_client) + return; + + { + AsyncWebLockGuard l(_lock); + _controlQueue.emplace_back(opcode, data, len, mask); + } + + if (_client && _client->canSend()) + _runQueue(); } -void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ +void AsyncWebSocketClient::_queueMessage(std::shared_ptr> buffer, uint8_t opcode, bool mask) +{ + if(_status != WS_CONNECTED) + return; + + if (!_client) + return; + + { + AsyncWebLockGuard l(_lock); + if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) + { + l.unlock(); + ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n"); + _status = WS_DISCONNECTED; + if (_client) _client->close(true); + return; + } + else + { + _messageQueue.emplace_back(buffer, opcode, mask); + } + } + + if (_client && _client->canSend()) + _runQueue(); +} + +void AsyncWebSocketClient::close(uint16_t code, const char * message) +{ + if(_status != WS_CONNECTED) + return; + + if(code) + { + uint8_t packetLen = 2; + if (message != NULL) + { + size_t mlen = strlen(message); + if(mlen > 123) mlen = 123; + packetLen += mlen; + } + char * buf = (char*)malloc(packetLen); + if (buf != NULL) + { + buf[0] = (uint8_t)(code >> 8); + buf[1] = (uint8_t)(code & 0xFF); + if(message != NULL){ + memcpy(buf+2, message, packetLen -2); + } + _queueControl(WS_DISCONNECT, (uint8_t*)buf, packetLen); + free(buf); + return; + } + } + _queueControl(WS_DISCONNECT); +} + +void AsyncWebSocketClient::ping(const uint8_t *data, size_t len) +{ + if (_status == WS_CONNECTED) + _queueControl(WS_PING, data, len); +} + +void AsyncWebSocketClient::_onError(int8_t) +{ + //Serial.println("onErr"); +} + +void AsyncWebSocketClient::_onTimeout(uint32_t time) +{ + // Serial.println("onTime"); + (void)time; + _client->close(true); +} + +void AsyncWebSocketClient::_onDisconnect() +{ + // Serial.println("onDis"); + _client = NULL; +} + +void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) +{ // Serial.println("onData"); _lastMessageTime = millis(); uint8_t *data = (uint8_t*)pbuf; @@ -767,10 +552,10 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ } else { _status = WS_DISCONNECTING; _client->ackLater(); - _queueControl(new AsyncWebSocketControl(WS_DISCONNECT, data, datalen)); + _queueControl(WS_DISCONNECT, data, datalen); } } else if(_pinfo.opcode == WS_PING){ - _queueControl(new AsyncWebSocketControl(WS_PONG, data, datalen)); + _queueControl(WS_PONG, data, datalen); } else if(_pinfo.opcode == WS_PONG){ if(datalen != AWSC_PING_PAYLOAD_LEN || memcmp(AWSC_PING_PAYLOAD, data, AWSC_PING_PAYLOAD_LEN) != 0) _server->_handleEvent(this, WS_EVT_PONG, NULL, data, datalen); @@ -794,7 +579,8 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ } } -size_t AsyncWebSocketClient::printf(const char *format, ...) { +size_t AsyncWebSocketClient::printf(const char *format, ...) +{ va_list arg; va_start(arg, format); char* temp = new char[MAX_PRINTF_LEN]; @@ -825,7 +611,8 @@ size_t AsyncWebSocketClient::printf(const char *format, ...) { } #ifndef ESP32 -size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { +size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) +{ va_list arg; va_start(arg, formatP); char* temp = new char[MAX_PRINTF_LEN]; @@ -856,70 +643,110 @@ size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { } #endif -void AsyncWebSocketClient::text(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len)); -} -void AsyncWebSocketClient::text(const char * message){ - text(message, strlen(message)); -} -void AsyncWebSocketClient::text(uint8_t * message, size_t len){ - text((const char *)message, len); -} -void AsyncWebSocketClient::text(char * message){ - text(message, strlen(message)); -} -void AsyncWebSocketClient::text(const String &message){ - text(message.c_str(), message.length()); -} -void AsyncWebSocketClient::text(const __FlashStringHelper *data){ - text(String(data)); -} -void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer) +namespace { +std::shared_ptr> makeBuffer(const uint8_t *message, size_t len) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer)); + auto buffer = std::make_shared>(len); + std::memcpy(buffer->data(), message, len); + return buffer; +} } -void AsyncWebSocketClient::binary(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len, WS_BINARY)); -} -void AsyncWebSocketClient::binary(const char * message){ - binary(message, strlen(message)); -} -void AsyncWebSocketClient::binary(uint8_t * message, size_t len){ - binary((const char *)message, len); -} -void AsyncWebSocketClient::binary(char * message){ - binary(message, strlen(message)); -} -void AsyncWebSocketClient::binary(const String &message){ - binary(message.c_str(), message.length()); -} -void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len){ - PGM_P p = reinterpret_cast(data); - char * message = (char*) malloc(len); - if(message){ - memcpy_P(message, p, len); - binary(message, len); - free(message); - } - -} -void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer) +void AsyncWebSocketClient::text(std::shared_ptr> buffer) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer, WS_BINARY)); + _queueMessage(buffer); } -IPAddress AsyncWebSocketClient::remoteIP() { - if(!_client) { - return IPAddress(0U); +void AsyncWebSocketClient::text(const uint8_t *message, size_t len) +{ + text(makeBuffer(message, len)); +} + +void AsyncWebSocketClient::text(const char *message, size_t len) +{ + text((const uint8_t *)message, len); +} + +void AsyncWebSocketClient::text(const char *message) +{ + text(message, strlen(message)); +} + +void AsyncWebSocketClient::text(const String &message) +{ + text(message.c_str(), message.length()); +} + +void AsyncWebSocketClient::text(const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (1) + { + if (pgm_read_byte(p+n) == 0) break; + n += 1; } + + char * message = (char*) malloc(n+1); + if(message) + { + memcpy_P(message, p, n); + message[n] = 0; + text(message, n); + free(message); + } +} + +void AsyncWebSocketClient::binary(std::shared_ptr> buffer) +{ + _queueMessage(buffer, WS_BINARY); +} + +void AsyncWebSocketClient::binary(const uint8_t *message, size_t len) +{ + binary(makeBuffer(message, len)); +} + +void AsyncWebSocketClient::binary(const char *message, size_t len) +{ + binary((const uint8_t *)message, len); +} + +void AsyncWebSocketClient::binary(const char *message) +{ + binary(message, strlen(message)); +} + +void AsyncWebSocketClient::binary(const String &message) +{ + binary(message.c_str(), message.length()); +} + +void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) +{ + PGM_P p = reinterpret_cast(data); + char *message = (char*) malloc(len); + if (message) { + memcpy_P(message, p, len); + binary(message, len); + free(message); + } +} + +IPAddress AsyncWebSocketClient::remoteIP() const +{ + if (!_client) + return IPAddress(0U); + return _client->remoteIP(); } -uint16_t AsyncWebSocketClient::remotePort() { - if(!_client) { +uint16_t AsyncWebSocketClient::remotePort() const +{ + if(!_client) return 0; - } + return _client->remotePort(); } @@ -931,10 +758,8 @@ uint16_t AsyncWebSocketClient::remotePort() { AsyncWebSocket::AsyncWebSocket(const String& url) :_url(url) - ,_clients(LinkedList([](AsyncWebSocketClient *c){ delete c; })) ,_cNextId(1) ,_enabled(true) - ,_buffers(LinkedList([](AsyncWebSocketMessageBuffer *b){ delete b; })) { _eventHandler = NULL; } @@ -947,176 +772,264 @@ void AsyncWebSocket::_handleEvent(AsyncWebSocketClient * client, AwsEventType ty } } -void AsyncWebSocket::_addClient(AsyncWebSocketClient * client){ - _clients.add(client); +AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) +{ + _clients.emplace_back(request, this); + return &_clients.back(); } -void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){ - - _clients.remove_first([=](AsyncWebSocketClient * c){ - return c->id() == client->id(); - }); +bool AsyncWebSocket::availableForWriteAll() +{ + return std::none_of(std::begin(_clients), std::end(_clients), + [](const AsyncWebSocketClient &c){ return c.queueIsFull(); }); } -bool AsyncWebSocket::availableForWriteAll(){ - for(const auto& c: _clients){ - if(c->queueIsFull()) return false; - } - return true; +bool AsyncWebSocket::availableForWrite(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const AsyncWebSocketClient &c){ return c.id() == id; }); + if (iter == std::end(_clients)) + return true; + return !iter->queueIsFull(); } -bool AsyncWebSocket::availableForWrite(uint32_t id){ - for(const auto& c: _clients){ - if(c->queueIsFull() && (c->id() == id )) return false; - } - return true; +size_t AsyncWebSocket::count() const +{ + return std::count_if(std::begin(_clients), std::end(_clients), + [](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; }); } -size_t AsyncWebSocket::count() const { - return _clients.count_if([](AsyncWebSocketClient * c){ - return c->status() == WS_CONNECTED; - }); -} +AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; }); + if (iter == std::end(_clients)) + return nullptr; -AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ - for(const auto &c: _clients){ - if(c->id() == id && c->status() == WS_CONNECTED){ - return c; - } - } - return nullptr; + return &(*iter); } -void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message){ - AsyncWebSocketClient * c = client(id); - if(c) - c->close(code, message); +void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message) +{ + if (AsyncWebSocketClient *c = client(id)) + c->close(code, message); } -void AsyncWebSocket::closeAll(uint16_t code, const char * message){ - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c->close(code, message); - } +void AsyncWebSocket::closeAll(uint16_t code, const char * message) +{ + for (auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.close(code, message); } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - if (count() > maxClients){ - _clients.front()->close(); - } -} + if (count() > maxClients) + _clients.front().close(); -void AsyncWebSocket::ping(uint32_t id, uint8_t *data, size_t len){ - AsyncWebSocketClient * c = client(id); - if(c) - c->ping(data, len); -} - -void AsyncWebSocket::pingAll(uint8_t *data, size_t len){ - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c->ping(data, len); - } -} - -void AsyncWebSocket::text(uint32_t id, const char * message, size_t len){ - AsyncWebSocketClient * c = client(id); - if(c) - c->text(message, len); -} - -void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer){ - if (!buffer) return; - buffer->lock(); - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED){ - c->text(buffer); + for (auto iter = std::begin(_clients); iter != std::end(_clients);) + { + if (iter->shouldBeDeleted()) + iter = _clients.erase(iter); + else + iter++; } - } - buffer->unlock(); - _cleanBuffers(); } - -void AsyncWebSocket::textAll(const char * message, size_t len){ - //if (_buffers.length()) return; - AsyncWebSocketMessageBuffer * WSBuffer = makeBuffer((uint8_t *)message, len); - textAll(WSBuffer); -} - -void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len){ - AsyncWebSocketClient * c = client(id); - if(c) - c->binary(message, len); -} - -void AsyncWebSocket::binaryAll(const char * message, size_t len){ - AsyncWebSocketMessageBuffer * buffer = makeBuffer((uint8_t *)message, len); - binaryAll(buffer); -} - -void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer) +void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) { - if (!buffer) return; - buffer->lock(); - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c->binary(buffer); - } - buffer->unlock(); - _cleanBuffers(); + if (AsyncWebSocketClient * c = client(id)) + c->ping(data, len); } -void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessage *message){ - AsyncWebSocketClient * c = client(id); - if(c) - c->message(message); +void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) +{ + for (auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.ping(data, len); } -void AsyncWebSocket::messageAll(AsyncWebSocketMultiMessage *message){ - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c->message(message); - } - _cleanBuffers(); +void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) +{ + if (AsyncWebSocketClient * c = client(id)) + c->text(makeBuffer(message, len)); +} +void AsyncWebSocket::text(uint32_t id, const char *message, size_t len) +{ + text(id, (const uint8_t *)message, len); +} +void AsyncWebSocket::text(uint32_t id, const char * message) +{ + text(id, message, strlen(message)); +} +void AsyncWebSocket::text(uint32_t id, const String &message) +{ + text(id, message.c_str(), message.length()); +} +void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (true) + { + if (pgm_read_byte(p+n) == 0) + break; + n += 1; + } + + char * message = (char*) malloc(n+1); + if (message) + { + memcpy_P(message, p, n); + message[n] = 0; + text(id, message, n); + free(message); + } +} + +void AsyncWebSocket::textAll(std::shared_ptr> buffer) +{ + for (auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.text(buffer); +} +void AsyncWebSocket::textAll(const uint8_t *message, size_t len) +{ + textAll(makeBuffer(message, len)); +} +void AsyncWebSocket::textAll(const char * message, size_t len) +{ + textAll((const uint8_t *)message, len); +} +void AsyncWebSocket::textAll(const char *message) +{ + textAll(message, strlen(message)); +} +void AsyncWebSocket::textAll(const String &message) +{ + textAll(message.c_str(), message.length()); +} +void AsyncWebSocket::textAll(const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (1) + { + if (pgm_read_byte(p+n) == 0) break; + n += 1; + } + + char *message = (char*)malloc(n+1); + if(message) + { + memcpy_P(message, p, n); + message[n] = 0; + textAll(message, n); + free(message); + } +} + +void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) +{ + if (AsyncWebSocketClient *c = client(id)) + c->binary(makeBuffer(message, len)); +} +void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) +{ + binary(id, (const uint8_t *)message, len); +} +void AsyncWebSocket::binary(uint32_t id, const char * message) +{ + binary(id, message, strlen(message)); +} +void AsyncWebSocket::binary(uint32_t id, const String &message) +{ + binary(id, message.c_str(), message.length()); +} +void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t len) +{ + PGM_P p = reinterpret_cast(data); + char *message = (char*) malloc(len); + if (message) + { + memcpy_P(message, p, len); + binary(id, message, len); + free(message); + } +} + +void AsyncWebSocket::binaryAll(std::shared_ptr> buffer) +{ + for (auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.binary(buffer); +} + +void AsyncWebSocket::binaryAll(const uint8_t *message, size_t len) +{ + binaryAll(makeBuffer(message, len)); +} + +void AsyncWebSocket::binaryAll(const char *message, size_t len) +{ + binaryAll((const uint8_t *)message, len); +} +void AsyncWebSocket::binaryAll(const char *message) +{ + binaryAll(message, strlen(message)); +} +void AsyncWebSocket::binaryAll(const String &message) +{ + binaryAll(message.c_str(), message.length()); +} +void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len) +{ + PGM_P p = reinterpret_cast(data); + char * message = (char*) malloc(len); + if(message) + { + memcpy_P(message, p, len); + binaryAll(message, len); + free(message); + } } size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ - AsyncWebSocketClient * c = client(id); - if(c){ - va_list arg; - va_start(arg, format); - size_t len = c->printf(format, arg); - va_end(arg); - return len; - } - return 0; + AsyncWebSocketClient * c = client(id); + if (c) + { + va_list arg; + va_start(arg, format); + size_t len = c->printf(format, arg); + va_end(arg); + return len; + } + return 0; } -size_t AsyncWebSocket::printfAll(const char *format, ...) { - va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if(!temp){ - return 0; - } - va_start(arg, format); - size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); - va_end(arg); - delete[] temp; +size_t AsyncWebSocket::printfAll(const char *format, ...) +{ + va_list arg; + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) + return 0; - AsyncWebSocketMessageBuffer * buffer = makeBuffer(len); - if (!buffer) { - return 0; - } + va_start(arg, format); + size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); + va_end(arg); + delete[] temp; - va_start(arg, format); - vsnprintf( (char *)buffer->get(), len + 1, format, arg); - va_end(arg); + std::shared_ptr> buffer = std::make_shared>(len); - textAll(buffer); - return len; + va_start(arg, format); + vsnprintf( (char *)buffer->data(), len + 1, format, arg); + va_end(arg); + + textAll(buffer); + return len; } #ifndef ESP32 @@ -1133,100 +1046,27 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){ } #endif -size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { - va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if(!temp){ - return 0; - } - va_start(arg, formatP); - size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); - va_end(arg); - delete[] temp; +size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) +{ + va_list arg; + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) + return 0; - AsyncWebSocketMessageBuffer * buffer = makeBuffer(len + 1); - if (!buffer) { - return 0; - } + va_start(arg, formatP); + size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); + va_end(arg); + delete[] temp; - va_start(arg, formatP); - vsnprintf_P((char *)buffer->get(), len + 1, formatP, arg); - va_end(arg); + std::shared_ptr> buffer = std::make_shared>(len + 1); - textAll(buffer); - return len; -} + va_start(arg, formatP); + vsnprintf_P((char *)buffer->data(), len + 1, formatP, arg); + va_end(arg); -void AsyncWebSocket::text(uint32_t id, const char * message){ - text(id, message, strlen(message)); + textAll(buffer); + return len; } -void AsyncWebSocket::text(uint32_t id, uint8_t * message, size_t len){ - text(id, (const char *)message, len); -} -void AsyncWebSocket::text(uint32_t id, char * message){ - text(id, message, strlen(message)); -} -void AsyncWebSocket::text(uint32_t id, const String &message){ - text(id, message.c_str(), message.length()); -} -void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *message){ - AsyncWebSocketClient * c = client(id); - if(c != NULL) - c->text(message); -} -void AsyncWebSocket::textAll(const char * message){ - textAll(message, strlen(message)); -} -void AsyncWebSocket::textAll(uint8_t * message, size_t len){ - textAll((const char *)message, len); -} -void AsyncWebSocket::textAll(char * message){ - textAll(message, strlen(message)); -} -void AsyncWebSocket::textAll(const String &message){ - textAll(message.c_str(), message.length()); -} -void AsyncWebSocket::textAll(const __FlashStringHelper *message){ - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c->text(message); - } -} -void AsyncWebSocket::binary(uint32_t id, const char * message){ - binary(id, message, strlen(message)); -} -void AsyncWebSocket::binary(uint32_t id, uint8_t * message, size_t len){ - binary(id, (const char *)message, len); -} -void AsyncWebSocket::binary(uint32_t id, char * message){ - binary(id, message, strlen(message)); -} -void AsyncWebSocket::binary(uint32_t id, const String &message){ - binary(id, message.c_str(), message.length()); -} -void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *message, size_t len){ - AsyncWebSocketClient * c = client(id); - if(c != NULL) - c-> binary(message, len); -} -void AsyncWebSocket::binaryAll(const char * message){ - binaryAll(message, strlen(message)); -} -void AsyncWebSocket::binaryAll(uint8_t * message, size_t len){ - binaryAll((const char *)message, len); -} -void AsyncWebSocket::binaryAll(char * message){ - binaryAll(message, strlen(message)); -} -void AsyncWebSocket::binaryAll(const String &message){ - binaryAll(message.c_str(), message.length()); -} -void AsyncWebSocket::binaryAll(const __FlashStringHelper *message, size_t len){ - for(const auto& c: _clients){ - if(c->status() == WS_CONNECTED) - c-> binary(message, len); - } - } const char __WS_STR_CONNECTION[] PROGMEM = { "Connection" }; const char __WS_STR_UPGRADE[] PROGMEM = { "Upgrade" }; @@ -1249,91 +1089,56 @@ const char __WS_STR_UUID[] PROGMEM = { "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" }; #define WS_STR_UUID FPSTR(__WS_STR_UUID) bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){ - if(!_enabled) - return false; + if(!_enabled) + return false; - if(request->method() != HTTP_GET || !request->url().equals(_url) || !request->isExpectedRequestedConnType(RCT_WS)) - return false; + if(request->method() != HTTP_GET || !request->url().equals(_url) || !request->isExpectedRequestedConnType(RCT_WS)) + return false; - request->addInterestingHeader(WS_STR_CONNECTION); - request->addInterestingHeader(WS_STR_UPGRADE); - request->addInterestingHeader(WS_STR_ORIGIN); - request->addInterestingHeader(WS_STR_COOKIE); - request->addInterestingHeader(WS_STR_VERSION); - request->addInterestingHeader(WS_STR_KEY); - request->addInterestingHeader(WS_STR_PROTOCOL); - return true; + request->addInterestingHeader(WS_STR_CONNECTION); + request->addInterestingHeader(WS_STR_UPGRADE); + request->addInterestingHeader(WS_STR_ORIGIN); + request->addInterestingHeader(WS_STR_COOKIE); + request->addInterestingHeader(WS_STR_VERSION); + request->addInterestingHeader(WS_STR_KEY); + request->addInterestingHeader(WS_STR_PROTOCOL); + return true; } -void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){ - if(!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){ - request->send(400); - return; - } - if((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())){ - return request->requestAuthentication(); - } -////////////////////////////////////////// - if(_handshakeHandler != nullptr){ - if(!_handshakeHandler(request)){ - request->send(401); - return; +void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) +{ + if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)) + { + request->send(400); + return; + } + if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())) + { + return request->requestAuthentication(); + } + if (_handshakeHandler != nullptr){ + if(!_handshakeHandler(request)){ + request->send(401); + return; + } + } + AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); + if (version->value().toInt() != 13) + { + AsyncWebServerResponse *response = request->beginResponse(400); + response->addHeader(WS_STR_VERSION, F("13")); + request->send(response); + return; + } + AsyncWebHeader* key = request->getHeader(WS_STR_KEY); + AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); + if (request->hasHeader(WS_STR_PROTOCOL)) + { + AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); + //ToDo: check protocol + response->addHeader(WS_STR_PROTOCOL, protocol->value()); } - } -////////////////////////////////////////// - AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); - if(version->value().toInt() != 13){ - AsyncWebServerResponse *response = request->beginResponse(400); - response->addHeader(WS_STR_VERSION, F("13")); request->send(response); - return; - } - AsyncWebHeader* key = request->getHeader(WS_STR_KEY); - AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); - if(request->hasHeader(WS_STR_PROTOCOL)){ - AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); - //ToDo: check protocol - response->addHeader(WS_STR_PROTOCOL, protocol->value()); - } - request->send(response); -} - -AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size) -{ - AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(size); - if (buffer) { - AsyncWebLockGuard l(_lock); - _buffers.add(buffer); - } - return buffer; -} - -AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size) -{ - AsyncWebSocketMessageBuffer * buffer = new AsyncWebSocketMessageBuffer(data, size); - - if (buffer) { - AsyncWebLockGuard l(_lock); - // Serial.printf("Add to global buffers = %u\n", _buffers.length() + 1); - _buffers.add(buffer); - } - - return buffer; -} - -void AsyncWebSocket::_cleanBuffers() -{ - AsyncWebLockGuard l(_lock); - for(AsyncWebSocketMessageBuffer * c: _buffers){ - if(c && c->canDelete()){ - // Serial.printf("Remove from global buffers = %u\n", _buffers.length() - 1); - _buffers.remove(c); - } - } -} - -AsyncWebSocket::AsyncWebSocketClientLinkedList AsyncWebSocket::getClients() const { - return _clients; } /* @@ -1341,56 +1146,65 @@ AsyncWebSocket::AsyncWebSocketClientLinkedList AsyncWebSocket::getClients() cons * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 */ -AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){ - _server = server; - _code = 101; - _sendContentLength = false; +AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server) +{ + _server = server; + _code = 101; + _sendContentLength = false; - uint8_t * hash = (uint8_t*)malloc(20); - if(hash == NULL){ - _state = RESPONSE_FAILED; - return; - } - char * buffer = (char *) malloc(33); - if(buffer == NULL){ - free(hash); - _state = RESPONSE_FAILED; - return; - } + uint8_t * hash = (uint8_t*)malloc(20); + if(hash == NULL) + { + _state = RESPONSE_FAILED; + return; + } + char * buffer = (char *) malloc(33); + if(buffer == NULL) + { + free(hash); + _state = RESPONSE_FAILED; + return; + } #ifdef ESP8266 - sha1(key + WS_STR_UUID, hash); + sha1(key + WS_STR_UUID, hash); #else - (String&)key += WS_STR_UUID; - SHA1_CTX ctx; - SHA1Init(&ctx); - SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); - SHA1Final(hash, &ctx); + (String&)key += WS_STR_UUID; + mbedtls_sha1_context ctx; + mbedtls_sha1_init(&ctx); + mbedtls_sha1_starts_ret(&ctx); + mbedtls_sha1_update_ret(&ctx, (const unsigned char*)key.c_str(), key.length()); + mbedtls_sha1_finish_ret(&ctx, hash); + mbedtls_sha1_free(&ctx); #endif - base64_encodestate _state; - base64_init_encodestate(&_state); - int len = base64_encode_block((const char *) hash, 20, buffer, &_state); - len = base64_encode_blockend((buffer + len), &_state); - addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); - addHeader(WS_STR_UPGRADE, F("websocket")); - addHeader(WS_STR_ACCEPT,buffer); - free(buffer); - free(hash); + base64_encodestate _state; + base64_init_encodestate(&_state); + int len = base64_encode_block((const char *) hash, 20, buffer, &_state); + len = base64_encode_blockend((buffer + len), &_state); + addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); + addHeader(WS_STR_UPGRADE, F("websocket")); + addHeader(WS_STR_ACCEPT,buffer); + free(buffer); + free(hash); } -void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){ - if(_state == RESPONSE_FAILED){ - request->client()->close(true); - return; - } - String out = _assembleHead(request->version()); - request->client()->write(out.c_str(), _headLength); - _state = RESPONSE_WAIT_ACK; +void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request) +{ + if(_state == RESPONSE_FAILED) + { + request->client()->close(true); + return; + } + String out = _assembleHead(request->version()); + request->client()->write(out.c_str(), _headLength); + _state = RESPONSE_WAIT_ACK; } -size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){ - (void)time; - if(len){ - new AsyncWebSocketClient(request, _server); - } - return 0; +size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time) +{ + (void)time; + + if(len) + _server->_newClient(request); + + return 0; } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 5ebf1cc..9a0a3b4 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -24,7 +24,7 @@ #include #ifdef ESP32 #include -#define WS_MAX_QUEUED_MESSAGES 32 +#define WS_MAX_QUEUED_MESSAGES 16 #else #include #define WS_MAX_QUEUED_MESSAGES 8 @@ -33,6 +33,10 @@ #include "AsyncWebSynchronization.h" +#include +#include +#include + #ifdef ESP8266 #include #ifdef CRYPTO_HASH_h // include Hash.h from espressif framework if the first include was from the crypto library @@ -80,78 +84,25 @@ typedef enum { WS_CONTINUATION, WS_TEXT, WS_BINARY, WS_DISCONNECT = 0x08, WS_PIN typedef enum { WS_MSG_SENDING, WS_MSG_SENT, WS_MSG_ERROR } AwsMessageStatus; typedef enum { WS_EVT_CONNECT, WS_EVT_DISCONNECT, WS_EVT_PONG, WS_EVT_ERROR, WS_EVT_DATA } AwsEventType; -class AsyncWebSocketMessageBuffer { - private: - uint8_t * _data; - size_t _len; - bool _lock; - uint32_t _count; +class AsyncWebSocketMessage +{ +private: + std::shared_ptr> _WSbuffer; + uint8_t _opcode{WS_TEXT}; + bool _mask{false}; + AwsMessageStatus _status{WS_MSG_ERROR}; + size_t _sent{}; + size_t _ack{}; + size_t _acked{}; - public: - AsyncWebSocketMessageBuffer(); - AsyncWebSocketMessageBuffer(size_t size); - AsyncWebSocketMessageBuffer(uint8_t * data, size_t size); - AsyncWebSocketMessageBuffer(const AsyncWebSocketMessageBuffer &); - AsyncWebSocketMessageBuffer(AsyncWebSocketMessageBuffer &&); - ~AsyncWebSocketMessageBuffer(); - void operator ++(int i) { (void)i; _count++; } - void operator --(int i) { (void)i; if (_count > 0) { _count--; } ; } - bool reserve(size_t size); - void lock() { _lock = true; } - void unlock() { _lock = false; } - uint8_t * get() { return _data; } - size_t length() { return _len; } - uint32_t count() { return _count; } - bool canDelete() { return (!_count && !_lock); } - - friend AsyncWebSocket; - -}; - -class AsyncWebSocketMessage { - protected: - uint8_t _opcode; - bool _mask; - AwsMessageStatus _status; - public: - AsyncWebSocketMessage():_opcode(WS_TEXT),_mask(false),_status(WS_MSG_ERROR){} - virtual ~AsyncWebSocketMessage(){} - virtual void ack(size_t len __attribute__((unused)), uint32_t time __attribute__((unused))){} - virtual size_t send(AsyncClient *client __attribute__((unused))){ return 0; } - virtual bool finished(){ return _status != WS_MSG_SENDING; } - virtual bool betweenFrames() const { return false; } -}; - -class AsyncWebSocketBasicMessage: public AsyncWebSocketMessage { - private: - size_t _len; - size_t _sent; - size_t _ack; - size_t _acked; - uint8_t * _data; public: - AsyncWebSocketBasicMessage(const char * data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); - AsyncWebSocketBasicMessage(uint8_t opcode=WS_TEXT, bool mask=false); - virtual ~AsyncWebSocketBasicMessage() override; - virtual bool betweenFrames() const override { return _acked == _ack; } - virtual void ack(size_t len, uint32_t time) override ; - virtual size_t send(AsyncClient *client) override ; -}; + AsyncWebSocketMessage(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false); -class AsyncWebSocketMultiMessage: public AsyncWebSocketMessage { - private: - uint8_t * _data; - size_t _len; - size_t _sent; - size_t _ack; - size_t _acked; - AsyncWebSocketMessageBuffer * _WSbuffer; -public: - AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode=WS_TEXT, bool mask=false); - virtual ~AsyncWebSocketMultiMessage() override; - virtual bool betweenFrames() const override { return _acked == _ack; } - virtual void ack(size_t len, uint32_t time) override ; - virtual size_t send(AsyncClient *client) override ; + bool finished() const { return _status != WS_MSG_SENDING; } + bool betweenFrames() const { return _acked == _ack; } + + void ack(size_t len, uint32_t time); + size_t send(AsyncClient *client); }; class AsyncWebSocketClient { @@ -161,8 +112,10 @@ class AsyncWebSocketClient { uint32_t _clientId; AwsClientStatus _status; - LinkedList _controlQueue; - LinkedList _messageQueue; + AsyncWebLock _lock; + + std::deque _controlQueue; + std::deque _messageQueue; uint8_t _pstate; AwsFrameInfo _pinfo; @@ -170,8 +123,8 @@ class AsyncWebSocketClient { uint32_t _lastMessageTime; uint32_t _keepAlivePeriod; - void _queueMessage(AsyncWebSocketMessage *dataMessage); - void _queueControl(AsyncWebSocketControl *controlMessage); + void _queueControl(uint8_t opcode, const uint8_t *data=NULL, size_t len=0, bool mask=false); + void _queueMessage(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false); void _runQueue(); void _clearQueue(); @@ -182,18 +135,22 @@ class AsyncWebSocketClient { ~AsyncWebSocketClient(); //client id increments for the given server - uint32_t id(){ return _clientId; } - AwsClientStatus status(){ return _status; } - AsyncClient* client(){ return _client; } + uint32_t id() const { return _clientId; } + AwsClientStatus status() const { return _status; } + AsyncClient* client() { return _client; } + const AsyncClient* client() const { return _client; } AsyncWebSocket *server(){ return _server; } + const AsyncWebSocket *server() const { return _server; } AwsFrameInfo const &pinfo() const { return _pinfo; } - IPAddress remoteIP(); - uint16_t remotePort(); + IPAddress remoteIP() const; + uint16_t remotePort() const; + + bool shouldBeDeleted() const { return !_client; } //control frames void close(uint16_t code=0, const char * message=NULL); - void ping(uint8_t *data=NULL, size_t len=0); + void ping(const uint8_t *data=NULL, size_t len=0); //set auto-ping period in seconds. disabled if zero (default) void keepAlivePeriod(uint16_t seconds){ @@ -204,31 +161,30 @@ class AsyncWebSocketClient { } //data packets - void message(AsyncWebSocketMessage *message){ _queueMessage(message); } - bool queueIsFull(); - size_t queueLen() { return _messageQueue.length() + _controlQueue.length(); } + void message(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(buffer, opcode, mask); } + bool queueIsFull() const; + size_t queueLen() const; size_t printf(const char *format, ...) __attribute__ ((format (printf, 2, 3))); #ifndef ESP32 size_t printf_P(PGM_P formatP, ...) __attribute__ ((format (printf, 2, 3))); #endif - void text(const char * message, size_t len); - void text(const char * message); - void text(uint8_t * message, size_t len); - void text(char * message); - void text(const String &message); - void text(const __FlashStringHelper *data); - void text(AsyncWebSocketMessageBuffer *buffer); + void text(std::shared_ptr> buffer); + void text(const uint8_t *message, size_t len); + void text(const char *message, size_t len); + void text(const char *message); + void text(const String &message); + void text(const __FlashStringHelper *message); + + void binary(std::shared_ptr> buffer); + void binary(const uint8_t *message, size_t len); void binary(const char * message, size_t len); void binary(const char * message); - void binary(uint8_t * message, size_t len); - void binary(char * message); void binary(const String &message); - void binary(const __FlashStringHelper *data, size_t len); - void binary(AsyncWebSocketMessageBuffer *buffer); + void binary(const __FlashStringHelper *message, size_t len); - bool canSend() { return _messageQueue.length() < WS_MAX_QUEUED_MESSAGES; } + bool canSend() const; //system callbacks (do not call) void _onAck(size_t len, uint32_t time); @@ -244,11 +200,9 @@ typedef std::function AsyncWebSocketClientLinkedList; private: String _url; - AsyncWebSocketClientLinkedList _clients; + std::list _clients; uint32_t _cNextId; AwsEventHandler _eventHandler; AwsHandshakeHandler _handshakeHandler; @@ -272,41 +226,34 @@ class AsyncWebSocket: public AsyncWebHandler { void closeAll(uint16_t code=0, const char * message=NULL); void cleanupClients(uint16_t maxClients = DEFAULT_MAX_WS_CLIENTS); - void ping(uint32_t id, uint8_t *data=NULL, size_t len=0); - void pingAll(uint8_t *data=NULL, size_t len=0); // done + void ping(uint32_t id, const uint8_t *data=NULL, size_t len=0); + void pingAll(const uint8_t *data=NULL, size_t len=0); // done - void text(uint32_t id, const char * message, size_t len); - void text(uint32_t id, const char * message); - void text(uint32_t id, uint8_t * message, size_t len); - void text(uint32_t id, char * message); + void text(uint32_t id, const uint8_t * message, size_t len); + void text(uint32_t id, const char *message, size_t len); + void text(uint32_t id, const char *message); void text(uint32_t id, const String &message); void text(uint32_t id, const __FlashStringHelper *message); + void textAll(std::shared_ptr> buffer); + void textAll(const uint8_t *message, size_t len); void textAll(const char * message, size_t len); void textAll(const char * message); - void textAll(uint8_t * message, size_t len); - void textAll(char * message); void textAll(const String &message); void textAll(const __FlashStringHelper *message); // need to convert - void textAll(AsyncWebSocketMessageBuffer * buffer); - void binary(uint32_t id, const char * message, size_t len); - void binary(uint32_t id, const char * message); - void binary(uint32_t id, uint8_t * message, size_t len); - void binary(uint32_t id, char * message); + void binary(uint32_t id, const uint8_t *message, size_t len); + void binary(uint32_t id, const char *message, size_t len); + void binary(uint32_t id, const char *message); void binary(uint32_t id, const String &message); void binary(uint32_t id, const __FlashStringHelper *message, size_t len); - void binaryAll(const char * message, size_t len); - void binaryAll(const char * message); - void binaryAll(uint8_t * message, size_t len); - void binaryAll(char * message); + void binaryAll(std::shared_ptr> buffer); + void binaryAll(const uint8_t *message, size_t len); + void binaryAll(const char *message, size_t len); + void binaryAll(const char *message); void binaryAll(const String &message); void binaryAll(const __FlashStringHelper *message, size_t len); - void binaryAll(AsyncWebSocketMessageBuffer * buffer); - - void message(uint32_t id, AsyncWebSocketMessage *message); - void messageAll(AsyncWebSocketMultiMessage *message); size_t printf(uint32_t id, const char *format, ...) __attribute__ ((format (printf, 3, 4))); size_t printfAll(const char *format, ...) __attribute__ ((format (printf, 2, 3))); @@ -327,20 +274,12 @@ class AsyncWebSocket: public AsyncWebHandler { //system callbacks (do not call) uint32_t _getNextId(){ return _cNextId++; } - void _addClient(AsyncWebSocketClient * client); - void _handleDisconnect(AsyncWebSocketClient * client); + AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request); void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final; - - // messagebuffer functions/objects. - AsyncWebSocketMessageBuffer * makeBuffer(size_t size = 0); - AsyncWebSocketMessageBuffer * makeBuffer(uint8_t * data, size_t size); - LinkedList _buffers; - void _cleanBuffers(); - - AsyncWebSocketClientLinkedList getClients() const; + const std::list &getClients() const { return _clients; } }; //WebServer response to authenticate the socket and detach the tcp client from the web server request diff --git a/src/AsyncWebSynchronization.h b/src/AsyncWebSynchronization.h index f36c52d..0ff8ab6 100644 --- a/src/AsyncWebSynchronization.h +++ b/src/AsyncWebSynchronization.h @@ -8,39 +8,75 @@ #ifdef ESP32 // This is the ESP32 version of the Sync Lock, using the FreeRTOS Semaphore -class AsyncWebLock +// Modified 'AsyncWebLock' to just only use mutex since pxCurrentTCB is not +// always available. According to example by Arjan Filius, changed name, +// added unimplemented version for ESP8266 +class AsyncPlainLock { private: SemaphoreHandle_t _lock; - mutable void *_lockedBy; public: - AsyncWebLock() { + AsyncPlainLock() { _lock = xSemaphoreCreateBinary(); - _lockedBy = NULL; + // In this fails, the system is likely that much out of memory that + // we should abort anyways. If assertions are disabled, nothing is lost.. + assert(_lock); xSemaphoreGive(_lock); } - ~AsyncWebLock() { + ~AsyncPlainLock() { vSemaphoreDelete(_lock); } bool lock() const { - extern void *pxCurrentTCB; - if (_lockedBy != pxCurrentTCB) { xSemaphoreTake(_lock, portMAX_DELAY); - _lockedBy = pxCurrentTCB; return true; - } - return false; } void unlock() const { - _lockedBy = NULL; xSemaphoreGive(_lock); } }; +// This is the ESP32 version of the Sync Lock, using the FreeRTOS Semaphore +class AsyncWebLock +{ +private: + SemaphoreHandle_t _lock; + mutable TaskHandle_t _lockedBy{}; + +public: + AsyncWebLock() + { + _lock = xSemaphoreCreateBinary(); + // In this fails, the system is likely that much out of memory that + // we should abort anyways. If assertions are disabled, nothing is lost.. + assert(_lock); + _lockedBy = NULL; + xSemaphoreGive(_lock); + } + + ~AsyncWebLock() { + vSemaphoreDelete(_lock); + } + + bool lock() const { + const auto currentTask = xTaskGetCurrentTaskHandle(); + if (_lockedBy != currentTask) { + xSemaphoreTake(_lock, portMAX_DELAY); + _lockedBy = currentTask; + return true; + } + return false; + } + + void unlock() const { + _lockedBy = NULL; + xSemaphoreGive(_lock); + } +}; + #else // This is the 8266 version of the Sync Lock which is currently unimplemented @@ -61,6 +97,10 @@ public: void unlock() const { } }; + +// Same for AsyncPlainLock, for ESP8266 this is just the unimplemented version above. +using AsyncPlainLock = AsyncWebLock; + #endif class AsyncWebLockGuard @@ -82,6 +122,13 @@ public: _lock->unlock(); } } + + void unlock() { + if (_lock) { + _lock->unlock(); + _lock = NULL; + } + } }; -#endif // ASYNCWEBSYNCHRONIZATION_H_ \ No newline at end of file +#endif // ASYNCWEBSYNCHRONIZATION_H_ diff --git a/src/ESPAsyncWebServer.h b/src/ESPAsyncWebServer.h index 498ae62..1d3c404 100644 --- a/src/ESPAsyncWebServer.h +++ b/src/ESPAsyncWebServer.h @@ -24,6 +24,8 @@ #include "Arduino.h" #include +#include +#include #include "FS.h" #include "StringArray.h" @@ -59,22 +61,14 @@ class AsyncResponseStream; #ifndef WEBSERVER_H typedef enum { - HTTP_GET = 0b0000000000000001, - HTTP_POST = 0b0000000000000010, - HTTP_DELETE = 0b0000000000000100, - HTTP_PUT = 0b0000000000001000, - HTTP_PATCH = 0b0000000000010000, - HTTP_HEAD = 0b0000000000100000, - HTTP_OPTIONS = 0b0000000001000000, - HTTP_PROPFIND = 0b0000000010000000, - HTTP_LOCK = 0b0000000100000000, - HTTP_UNLOCK = 0b0000001000000000, - HTTP_PROPPATCH = 0b0000010000000000, - HTTP_MKCOL = 0b0000100000000000, - HTTP_MOVE = 0b0001000000000000, - HTTP_COPY = 0b0010000000000000, - HTTP_RESERVED = 0b0100000000000000, - HTTP_ANY = 0b0111111111111111, + HTTP_GET = 0b00000001, + HTTP_POST = 0b00000010, + HTTP_DELETE = 0b00000100, + HTTP_PUT = 0b00001000, + HTTP_PATCH = 0b00010000, + HTTP_HEAD = 0b00100000, + HTTP_OPTIONS = 0b01000000, + HTTP_ANY = 0b01111111, } WebRequestMethod; #endif @@ -94,7 +88,7 @@ namespace fs { //if this value is returned when asked for data, packet will not be sent and you will be asked for data again #define RESPONSE_TRY_AGAIN 0xFFFFFFFF -typedef uint16_t WebRequestMethodComposite; +typedef uint8_t WebRequestMethodComposite; typedef std::function ArDisconnectHandler; /* @@ -129,6 +123,9 @@ class AsyncWebHeader { String _value; public: + AsyncWebHeader() = default; + AsyncWebHeader(const AsyncWebHeader &) = default; + AsyncWebHeader(const String& name, const String& value): _name(name), _value(value){} AsyncWebHeader(const String& data): _name(), _value(){ if(!data) return; @@ -137,10 +134,12 @@ class AsyncWebHeader { _name = data.substring(0, index); _value = data.substring(index + 2); } - ~AsyncWebHeader(){} + + AsyncWebHeader &operator=(const AsyncWebHeader &) = default; + const String& name() const { return _name; } const String& value() const { return _value; } - String toString() const { return String(_name + F(": ") + _value + F("\r\n")); } + String toString() const { return _name + F(": ") + _value + F("\r\n"); } }; /* @@ -162,7 +161,7 @@ class AsyncWebServerRequest { AsyncWebServer* _server; AsyncWebHandler* _handler; AsyncWebServerResponse* _response; - StringArray _interestingHeaders; + std::vector _interestingHeaders; ArDisconnectHandler _onDisconnectfn; String _temp; @@ -184,9 +183,9 @@ class AsyncWebServerRequest { size_t _contentLength; size_t _parsedLength; - LinkedList _headers; + std::list _headers; LinkedList _params; - LinkedList _pathParams; + std::vector _pathParams; uint8_t _multiParseState; uint8_t _boundaryPosition; @@ -278,9 +277,12 @@ class AsyncWebServerRequest { bool hasHeader(const String& name) const; // check if header exists bool hasHeader(const __FlashStringHelper * data) const; // check if header exists - AsyncWebHeader* getHeader(const String& name) const; - AsyncWebHeader* getHeader(const __FlashStringHelper * data) const; - AsyncWebHeader* getHeader(size_t num) const; + AsyncWebHeader* getHeader(const String& name); + const AsyncWebHeader* getHeader(const String& name) const; + AsyncWebHeader* getHeader(const __FlashStringHelper * data); + const AsyncWebHeader* getHeader(const __FlashStringHelper * data) const; + AsyncWebHeader* getHeader(size_t num); + const AsyncWebHeader* getHeader(size_t num) const; size_t params() const; // get arguments count bool hasParam(const String& name, bool post=false, bool file=false) const; @@ -379,7 +381,7 @@ typedef enum { class AsyncWebServerResponse { protected: int _code; - LinkedList _headers; + std::list _headers; String _contentType; size_t _contentLength; bool _sendContentLength; @@ -462,17 +464,16 @@ class AsyncWebServer { }; class DefaultHeaders { - using headers_t = LinkedList; + using headers_t = std::list; headers_t _headers; - DefaultHeaders() - :_headers(headers_t([](AsyncWebHeader *h){ delete h; })) - {} public: - using ConstIterator = headers_t::ConstIterator; + DefaultHeaders() = default; + + using ConstIterator = headers_t::const_iterator; void addHeader(const String& name, const String& value){ - _headers.add(new AsyncWebHeader(name, value)); + _headers.emplace_back(name, value); } ConstIterator begin() const { return _headers.begin(); } @@ -480,6 +481,7 @@ public: DefaultHeaders(DefaultHeaders const &) = delete; DefaultHeaders &operator=(DefaultHeaders const &) = delete; + static DefaultHeaders &Instance() { static DefaultHeaders instance; return instance; diff --git a/src/SPIFFSEditor.cpp b/src/SPIFFSEditor.cpp index 343ed79..f0d143c 100644 --- a/src/SPIFFSEditor.cpp +++ b/src/SPIFFSEditor.cpp @@ -1,8 +1,6 @@ #include "SPIFFSEditor.h" #include -#define EDFS - #ifndef EDFS #include "edit.htm.gz.h" #endif diff --git a/src/StringArray.h b/src/StringArray.h index 4c0aa70..d5096e6 100644 --- a/src/StringArray.h +++ b/src/StringArray.h @@ -171,23 +171,4 @@ class LinkedList { } }; - -class StringArray : public LinkedList { -public: - - StringArray() : LinkedList(nullptr) {} - - bool containsIgnoreCase(const String& str){ - for (const auto& s : *this) { - if (str.equalsIgnoreCase(s)) { - return true; - } - } - return false; - } -}; - - - - #endif /* STRINGARRAY_H_ */ diff --git a/src/WebAuthentication.cpp b/src/WebAuthentication.cpp index 2d72de9..1a22afd 100644 --- a/src/WebAuthentication.cpp +++ b/src/WebAuthentication.cpp @@ -77,9 +77,9 @@ static bool getMD5(uint8_t * data, uint16_t len, char * output){//33 bytes or mo memset(_buf, 0x00, 16); #ifdef ESP32 mbedtls_md5_init(&_ctx); - mbedtls_md5_starts(&_ctx); - mbedtls_md5_update(&_ctx, data, len); - mbedtls_md5_finish(&_ctx, _buf); + mbedtls_md5_starts_ret(&_ctx); + mbedtls_md5_update_ret(&_ctx, data, len); + mbedtls_md5_finish_ret(&_ctx, _buf); #else MD5Init(&_ctx); MD5Update(&_ctx, data, len); diff --git a/src/WebRequest.cpp b/src/WebRequest.cpp index 025bc3b..31613df 100644 --- a/src/WebRequest.cpp +++ b/src/WebRequest.cpp @@ -51,9 +51,7 @@ AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) , _expectingContinue(false) , _contentLength(0) , _parsedLength(0) - , _headers(LinkedList([](AsyncWebHeader *h){ delete h; })) , _params(LinkedList([](AsyncWebParameter *p){ delete p; })) - , _pathParams(LinkedList([](String *p){ delete p; })) , _multiParseState(0) , _boundaryPosition(0) , _itemStartIndex(0) @@ -76,12 +74,12 @@ AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) } AsyncWebServerRequest::~AsyncWebServerRequest(){ - _headers.free(); + _headers.clear(); _params.free(); - _pathParams.free(); + _pathParams.clear(); - _interestingHeaders.free(); + _interestingHeaders.clear(); if(_response != NULL){ delete _response; @@ -182,11 +180,19 @@ void AsyncWebServerRequest::_onData(void *buf, size_t len){ } void AsyncWebServerRequest::_removeNotInterestingHeaders(){ - if (_interestingHeaders.containsIgnoreCase(F("ANY"))) return; // nothing to do - for(const auto& header: _headers){ - if(!_interestingHeaders.containsIgnoreCase(header->name().c_str())){ - _headers.remove(header); - } + if (std::any_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), + [](const String &str){ return str.equalsIgnoreCase(F("ANY")); })) + return; // nothing to do + + for(auto iter = std::begin(_headers); iter != std::end(_headers); ) + { + const auto name = iter->name(); + + if (std::none_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), + [&name](const String &str){ return str.equalsIgnoreCase(name); })) + iter = _headers.erase(iter); + else + iter++; } } @@ -247,7 +253,7 @@ void AsyncWebServerRequest::_addParam(AsyncWebParameter *p){ } void AsyncWebServerRequest::_addPathParam(const char *p){ - _pathParams.add(new String(p)); + _pathParams.emplace_back(p); } void AsyncWebServerRequest::_addGetParams(const String& params){ @@ -286,24 +292,6 @@ bool AsyncWebServerRequest::_parseReqHead(){ _method = HTTP_HEAD; } else if(m == F("OPTIONS")){ _method = HTTP_OPTIONS; - } else if(m == F("PROPFIND")){ - _method = HTTP_PROPFIND; - } else if(m == F("LOCK")){ - _method = HTTP_LOCK; - } else if(m == F("UNLOCK")){ - _method = HTTP_UNLOCK; - } else if(m == F("PROPPATCH")){ - _method = HTTP_PROPPATCH; - } else if(m == F("MKCOL")){ - _method = HTTP_MKCOL; - } else if(m == F("MOVE")){ - _method = HTTP_MOVE; - } else if(m == F("COPY")){ - _method = HTTP_COPY; - } else if(m == F("RESERVED")){ - _method = HTTP_RESERVED; - } else if(m == F("ANY")){ - _method = HTTP_ANY; } String g; @@ -379,7 +367,7 @@ bool AsyncWebServerRequest::_parseReqHeader(){ } } } - _headers.add(new AsyncWebHeader(name, value)); + _headers.emplace_back(name, value); } _temp = String(); return true; @@ -620,12 +608,12 @@ void AsyncWebServerRequest::_parseLine(){ } size_t AsyncWebServerRequest::headers() const{ - return _headers.length(); + return _headers.size(); } bool AsyncWebServerRequest::hasHeader(const String& name) const { for(const auto& h: _headers){ - if(h->name().equalsIgnoreCase(name)){ + if(h.name().equalsIgnoreCase(name)){ return true; } } @@ -636,22 +624,64 @@ bool AsyncWebServerRequest::hasHeader(const __FlashStringHelper * data) const { return hasHeader(String(data)); } -AsyncWebHeader* AsyncWebServerRequest::getHeader(const String& name) const { - for(const auto& h: _headers){ - if(h->name().equalsIgnoreCase(name)){ - return h; - } +AsyncWebHeader* AsyncWebServerRequest::getHeader(const String& name) { + auto iter = std::find_if(std::begin(_headers), std::end(_headers), + [&name](const AsyncWebHeader &header){ return header.name().equalsIgnoreCase(name); }); + + if (iter == std::end(_headers)) + return nullptr; + + return &(*iter); +} + +const AsyncWebHeader* AsyncWebServerRequest::getHeader(const String& name) const { + auto iter = std::find_if(std::begin(_headers), std::end(_headers), + [&name](const AsyncWebHeader &header){ return header.name().equalsIgnoreCase(name); }); + + if (iter == std::end(_headers)) + return nullptr; + + return &(*iter); +} + +AsyncWebHeader* AsyncWebServerRequest::getHeader(const __FlashStringHelper * data) { + PGM_P p = reinterpret_cast(data); + size_t n = strlen_P(p); + char * name = (char*) malloc(n+1); + if (name) { + strcpy_P(name, p); + AsyncWebHeader* result = getHeader( String(name)); + free(name); + return result; + } else { + return nullptr; } - return nullptr; } -AsyncWebHeader* AsyncWebServerRequest::getHeader(const __FlashStringHelper * data) const { - return getHeader(String(data)); +const AsyncWebHeader* AsyncWebServerRequest::getHeader(const __FlashStringHelper * data) const { + PGM_P p = reinterpret_cast(data); + size_t n = strlen_P(p); + char * name = (char*) malloc(n+1); + if (name) { + strcpy_P(name, p); + const AsyncWebHeader* result = getHeader( String(name)); + free(name); + return result; + } else { + return nullptr; + } } -AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) const { - auto header = _headers.nth(num); - return header ? *header : nullptr; +AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) { + if (num >= _headers.size()) + return nullptr; + return &(*std::next(std::begin(_headers), num)); +} + +const AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) const { + if (num >= _headers.size()) + return nullptr; + return &(*std::next(std::begin(_headers), num)); } size_t AsyncWebServerRequest::params() const { @@ -690,8 +720,9 @@ AsyncWebParameter* AsyncWebServerRequest::getParam(size_t num) const { } void AsyncWebServerRequest::addInterestingHeader(const String& name){ - if(!_interestingHeaders.containsIgnoreCase(name)) - _interestingHeaders.add(name); + if(std::none_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), + [&name](const String &str){ return str.equalsIgnoreCase(name); })) + _interestingHeaders.push_back(name); } void AsyncWebServerRequest::send(AsyncWebServerResponse *response){ @@ -883,12 +914,11 @@ const String& AsyncWebServerRequest::argName(size_t i) const { } const String& AsyncWebServerRequest::pathArg(size_t i) const { - auto param = _pathParams.nth(i); - return param ? **param : emptyString; + return i < _pathParams.size() ? _pathParams[i] : emptyString; } const String& AsyncWebServerRequest::header(const char* name) const { - AsyncWebHeader* h = getHeader(String(name)); + const AsyncWebHeader* h = getHeader(String(name)); return h ? h->value() : emptyString; } @@ -898,12 +928,12 @@ const String& AsyncWebServerRequest::header(const __FlashStringHelper * data) co const String& AsyncWebServerRequest::header(size_t i) const { - AsyncWebHeader* h = getHeader(i); + const AsyncWebHeader* h = getHeader(i); return h ? h->value() : emptyString; } const String& AsyncWebServerRequest::headerName(size_t i) const { - AsyncWebHeader* h = getHeader(i); + const AsyncWebHeader* h = getHeader(i); return h ? h->name() : emptyString; } @@ -940,14 +970,6 @@ const __FlashStringHelper *AsyncWebServerRequest::methodToString() const { else if(_method & HTTP_PATCH) return F("PATCH"); else if(_method & HTTP_HEAD) return F("HEAD"); else if(_method & HTTP_OPTIONS) return F("OPTIONS"); - else if(_method & HTTP_PROPFIND) return F("PROPFIND"); - else if(_method & HTTP_LOCK) return F("LOCK"); - else if(_method & HTTP_UNLOCK) return F("UNLOCK"); - else if(_method & HTTP_PROPPATCH) return F("PROPPATCH"); - else if(_method & HTTP_MKCOL) return F("MKCOL"); - else if(_method & HTTP_MOVE) return F("MOVE"); - else if(_method & HTTP_COPY) return F("COPY"); - else if(_method & HTTP_RESERVED) return F("RESERVED"); return F("UNKNOWN"); } diff --git a/src/WebResponseImpl.h b/src/WebResponseImpl.h index 9a64e3a..4a47225 100644 --- a/src/WebResponseImpl.h +++ b/src/WebResponseImpl.h @@ -27,6 +27,8 @@ #undef max #endif #include +#include + // It is possible to restore these defines, but one can use _min and _max instead. Or std::min, std::max. class AsyncBasicResponse: public AsyncWebServerResponse { @@ -122,7 +124,7 @@ class cbuf; class AsyncResponseStream: public AsyncAbstractResponse, public Print { private: - cbuf *_content; + std::unique_ptr _content; public: AsyncResponseStream(const String& contentType, size_t bufferSize); ~AsyncResponseStream(); diff --git a/src/WebResponses.cpp b/src/WebResponses.cpp index adeab98..22a549f 100644 --- a/src/WebResponses.cpp +++ b/src/WebResponses.cpp @@ -88,7 +88,6 @@ const __FlashStringHelper *AsyncWebServerResponse::responseCodeToString(int code AsyncWebServerResponse::AsyncWebServerResponse() : _code(0) - , _headers(LinkedList([](AsyncWebHeader *h){ delete h; })) , _contentType() , _contentLength(0) , _sendContentLength(true) @@ -99,14 +98,12 @@ AsyncWebServerResponse::AsyncWebServerResponse() , _writtenLength(0) , _state(RESPONSE_SETUP) { - for(auto header: DefaultHeaders::Instance()) { - _headers.add(new AsyncWebHeader(header->name(), header->value())); + for(const auto &header: DefaultHeaders::Instance()) { + _headers.emplace_back(header); } } -AsyncWebServerResponse::~AsyncWebServerResponse(){ - _headers.free(); -} +AsyncWebServerResponse::~AsyncWebServerResponse() = default; void AsyncWebServerResponse::setCode(int code){ if(_state == RESPONSE_SETUP) @@ -124,7 +121,7 @@ void AsyncWebServerResponse::setContentType(const String& type){ } void AsyncWebServerResponse::addHeader(const String& name, const String& value){ - _headers.add(new AsyncWebHeader(name, value)); + _headers.emplace_back(name, value); } String AsyncWebServerResponse::_assembleHead(uint8_t version){ @@ -150,10 +147,10 @@ String AsyncWebServerResponse::_assembleHead(uint8_t version){ } for(const auto& header: _headers){ - snprintf_P(buf, bufSize, PSTR("%s: %s\r\n"), header->name().c_str(), header->value().c_str()); + snprintf_P(buf, bufSize, PSTR("%s: %s\r\n"), header.name().c_str(), header.value().c_str()); out.concat(buf); } - _headers.free(); + _headers.clear(); out.concat(F("\r\n")); _headLength = out.length(); @@ -675,16 +672,15 @@ size_t AsyncProgmemResponse::_fillBuffer(uint8_t *data, size_t len){ * Response Stream (You can print/write/printf to it, up to the contentLen bytes) * */ -AsyncResponseStream::AsyncResponseStream(const String& contentType, size_t bufferSize){ +AsyncResponseStream::AsyncResponseStream(const String& contentType, size_t bufferSize) +{ _code = 200; _contentLength = 0; _contentType = contentType; - _content = new cbuf(bufferSize); + _content = std::unique_ptr(new cbuf(bufferSize)); //std::make_unique(bufferSize); } -AsyncResponseStream::~AsyncResponseStream(){ - delete _content; -} +AsyncResponseStream::~AsyncResponseStream() = default; size_t AsyncResponseStream::_fillBuffer(uint8_t *buf, size_t maxLen){ return _content->read((char*)buf, maxLen);