Aktualizace na verzi 3.3.7

This commit is contained in:
Pavel Brychta 2024-10-01 13:38:03 +02:00
parent 3ba4d9d10d
commit 2333497adc
29 changed files with 5128 additions and 827 deletions

1870
README.md

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -20,7 +20,6 @@ class CaptiveRequestHandler : public AsyncWebHandler {
virtual ~CaptiveRequestHandler() {}
bool canHandle(__unused AsyncWebServerRequest* request) {
// request->addInterestingHeader("ANY");
return true;
}

View File

@ -22,7 +22,6 @@ class CaptiveRequestHandler : public AsyncWebHandler {
virtual ~CaptiveRequestHandler() {}
bool canHandle(__unused AsyncWebServerRequest* request) {
// request->addInterestingHeader("ANY");
return true;
}

View File

@ -19,15 +19,62 @@
#include <ESPAsyncWebServer.h>
#include <ArduinoJson.h>
#include <AsyncJson.h>
#include <AsyncMessagePack.h>
#if ASYNC_JSON_SUPPORT == 1
#include <ArduinoJson.h>
#include <AsyncJson.h>
#include <AsyncMessagePack.h>
#endif
#include <LittleFS.h>
AsyncWebServer server(80);
AsyncEventSource events("/events");
AsyncWebSocket ws("/ws");
/////////////////////////////////////////////////////////////////////////////////////////////////////
// Middlewares
/////////////////////////////////////////////////////////////////////////////////////////////////////
// log incoming requests
LoggingMiddleware requestLogger;
// CORS
CorsMiddleware cors;
// maximum 5 requests per 10 seconds
RateLimitMiddleware rateLimit;
// filter out specific headers from the incoming request
HeaderFilterMiddleware headerFilter;
// remove all headers from the incoming request except the ones provided in the constructor
HeaderFreeMiddleware headerFree;
// basicAuth
AuthenticationMiddleware basicAuth;
AuthenticationMiddleware basicAuthHash;
// simple digest authentication
AuthenticationMiddleware digestAuth;
AuthenticationMiddleware digestAuthHash;
// complex authentication which adds request attributes for the next middlewares and handler
AsyncMiddlewareFunction complexAuth([](AsyncWebServerRequest* request, ArMiddlewareNext next) {
if (!request->authenticate("user", "password")) {
return request->requestAuthentication();
}
request->setAttribute("user", "Mathieu");
request->setAttribute("role", "staff");
next();
request->getResponse()->addHeader("X-Rate-Limit", "200");
});
AuthorizationMiddleware authz([](AsyncWebServerRequest* request) { return request->getAttribute("role") == "staff"; });
/////////////////////////////////////////////////////////////////////////////////////////////////////
const char* PARAM_MESSAGE PROGMEM = "message";
const char* SSE_HTLM PROGMEM = R"(
<!DOCTYPE html>
@ -64,8 +111,10 @@ void notFound(AsyncWebServerRequest* request) {
request->send(404, "text/plain", "Not found");
}
#if ASYNC_JSON_SUPPORT == 1
AsyncCallbackJsonWebHandler* jsonHandler = new AsyncCallbackJsonWebHandler("/json2");
AsyncCallbackMessagePackWebHandler* msgPackHandler = new AsyncCallbackMessagePackWebHandler("/msgpack2");
#endif
void setup() {
@ -85,6 +134,171 @@ void setup() {
WiFi.softAP("esp-captive");
#endif
// curl -v -X GET http://192.168.4.1/handler-not-sending-response
server.on("/handler-not-sending-response", HTTP_GET, [](AsyncWebServerRequest* request) {
// handler forgot to send a response to the client => 501 Not Implemented
});
// This is possible to replace a response.
// the previous one will be deleted.
// response sending happens when the handler returns.
// curl -v -X GET http://192.168.4.1/replace
server.on("/replace", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world");
// oups! finally we want to send a different response
request->send(400, "text/plain", "validation error");
#ifndef TARGET_RP2040
Serial.printf("Free heap: %" PRIu32 "\n", ESP.getFreeHeap());
#endif
});
///////////////////////////////////////////////////////////////////////
// Request header manipulations
///////////////////////////////////////////////////////////////////////
// curl -v -X GET -H "x-remove-me: value" http://192.168.4.1/headers
server.on("/headers", HTTP_GET, [](AsyncWebServerRequest* request) {
Serial.printf("Request Headers:\n");
for (auto& h : request->getHeaders())
Serial.printf("Request Header: %s = %s\n", h.name().c_str(), h.value().c_str());
// remove x-remove-me header
request->removeHeader("x-remove-me");
Serial.printf("Request Headers:\n");
for (auto& h : request->getHeaders())
Serial.printf("Request Header: %s = %s\n", h.name().c_str(), h.value().c_str());
std::vector<const char*> headers;
request->getHeaderNames(headers);
for (auto& h : headers)
Serial.printf("Request Header Name: %s\n", h);
request->send(200);
});
///////////////////////////////////////////////////////////////////////
// Middlewares at server level (will apply to all requests)
///////////////////////////////////////////////////////////////////////
requestLogger.setOutput(Serial);
basicAuth.setUsername("admin");
basicAuth.setPassword("admin");
basicAuth.setRealm("MyApp");
basicAuth.setAuthFailureMessage("Authentication failed");
basicAuth.setAuthType(AsyncAuthType::AUTH_BASIC);
basicAuth.generateHash();
basicAuthHash.setUsername("admin");
basicAuthHash.setPasswordHash("YWRtaW46YWRtaW4="); // BASE64(admin:admin)
basicAuthHash.setRealm("MyApp");
basicAuthHash.setAuthFailureMessage("Authentication failed");
basicAuthHash.setAuthType(AsyncAuthType::AUTH_BASIC);
digestAuth.setUsername("admin");
digestAuth.setPassword("admin");
digestAuth.setRealm("MyApp");
digestAuth.setAuthFailureMessage("Authentication failed");
digestAuth.setAuthType(AsyncAuthType::AUTH_DIGEST);
digestAuth.generateHash();
digestAuthHash.setUsername("admin");
digestAuthHash.setPasswordHash("f499b71f9a36d838b79268e145e132f7"); // MD5(user:realm:pass)
digestAuthHash.setRealm("MyApp");
digestAuthHash.setAuthFailureMessage("Authentication failed");
digestAuthHash.setAuthType(AsyncAuthType::AUTH_DIGEST);
rateLimit.setMaxRequests(5);
rateLimit.setWindowSize(10);
headerFilter.filter("X-Remove-Me");
headerFree.keep("X-Keep-Me");
headerFree.keep("host");
// global middleware
server.addMiddleware(&requestLogger);
server.addMiddlewares({&rateLimit, &cors, &headerFilter});
cors.setOrigin("http://192.168.4.1");
cors.setMethods("POST, GET, OPTIONS, DELETE");
cors.setHeaders("X-Custom-Header");
cors.setAllowCredentials(false);
cors.setMaxAge(600);
// Test CORS preflight request
// curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/middleware/cors
server.on("/middleware/cors", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
});
// curl -v -X GET -H "x-remove-me: value" http://192.168.4.1/middleware/test-header-filter
// - requestLogger will log the incoming headers (including x-remove-me)
// - headerFilter will remove x-remove-me header
// - handler will log the remaining headers
server.on("/middleware/test-header-filter", HTTP_GET, [](AsyncWebServerRequest* request) {
for (auto& h : request->getHeaders())
Serial.printf("Request Header: %s = %s\n", h.name().c_str(), h.value().c_str());
request->send(200);
});
// curl -v -X GET -H "x-keep-me: value" http://192.168.4.1/middleware/test-header-free
// - requestLogger will log the incoming headers (including x-keep-me)
// - headerFree will remove all headers except x-keep-me and host
// - handler will log the remaining headers (x-keep-me and host)
server.on("/middleware/test-header-free", HTTP_GET, [](AsyncWebServerRequest* request) {
for (auto& h : request->getHeaders())
Serial.printf("Request Header: %s = %s\n", h.name().c_str(), h.value().c_str());
request->send(200);
})
.addMiddleware(&headerFree);
// basic authentication method
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic
server.on("/middleware/auth-basic", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuth);
// basic authentication method with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic-hash
server.on("/middleware/auth-basic-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuthHash);
// digest authentication
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest
server.on("/middleware/auth-digest", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuth);
// digest authentication with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest-hash
server.on("/middleware/auth-digest-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuthHash);
// test digest auth with cors
// curl -v -X GET -H "origin: http://192.168.4.1" --digest -u user:password http://192.168.4.1/middleware/auth-custom
server.on("/middleware/auth-custom", HTTP_GET, [](AsyncWebServerRequest* request) {
String buffer = "Hello ";
buffer.concat(request->getAttribute("user"));
buffer.concat(" with role: ");
buffer.concat(request->getAttribute("role"));
request->send(200, "text/plain", buffer);
})
.addMiddlewares({&complexAuth, &authz});
///////////////////////////////////////////////////////////////////////
// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
// curl -v -X POST -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
server.on("/redirect", HTTP_GET | HTTP_POST, [](AsyncWebServerRequest* request) {
request->redirect("/");
});
server.on("/", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world");
});
@ -139,6 +353,7 @@ void setup() {
request->send(200, "text/plain", "Hello, POST: " + message);
});
#if ASYNC_JSON_SUPPORT == 1
// JSON
// receives JSON and sends JSON
@ -184,6 +399,7 @@ void setup() {
response->setLength();
request->send(response);
});
#endif
events.onConnect([](AsyncEventSourceClient* client) {
if (client->lastId()) {
@ -197,7 +413,7 @@ void setup() {
});
ws.onEvent([](AsyncWebSocket* server, AsyncWebSocketClient* client, AwsEventType type, void* arg, uint8_t* data, size_t len) {
(void) len;
(void)len;
if (type == WS_EVT_CONNECT) {
Serial.println("ws connect");
client->setCloseClientOnQueueFull(false);
@ -222,8 +438,11 @@ void setup() {
server.addHandler(&events);
server.addHandler(&ws);
#if ASYNC_JSON_SUPPORT == 1
server.addHandler(jsonHandler);
server.addHandler(msgPackHandler);
#endif
server.onNotFound(notFound);
@ -244,6 +463,10 @@ void loop() {
}
if (now - lastWS >= deltaWS) {
ws.printfAll("kp%.4f", (10.0 / 3.0));
// ws.getClients
for (auto& client : ws.getClients()) {
client.text("kp%.4f", (10.0 / 3.0));
}
lastWS = millis();
}
}

View File

@ -1,6 +1,6 @@
{
"name": "ESPAsyncWebServer",
"version": "3.2.4",
"version": "3.3.7",
"description": "Asynchronous HTTP and WebSocket Server Library for ESP32, ESP8266 and RP2040. Supports: WebSocket, SSE, Authentication, Arduino Json 7, File Upload, Static File serving, URL Rewrite, URL Redirect, etc.",
"keywords": "http,async,websocket,webserver",
"homepage": "https://github.com/mathieucarbou/ESPAsyncWebServer",

View File

@ -1,5 +1,5 @@
name=ESPAsyncWebServer
version=3.2.4
version=3.3.7
author=Me-No-Dev
maintainer=Mathieu Carbou <mathieu.carbou@gmail.com>
sentence=Asynchronous HTTP and WebSocket Server Library for ESP32, ESP8266 and RP2040

View File

@ -1,3 +1,13 @@
[platformio]
default_envs = arduino-2, arduino-3, arduino-310rc1, esp8266, raspberrypi
lib_dir = .
; src_dir = examples/CaptivePortal
src_dir = examples/SimpleServer
; src_dir = examples/StreamFiles
; src_dir = examples/Filters
; src_dir = examples/Draft
; src_dir = examples/issues/Issue14
[env]
framework = arduino
build_flags =
@ -13,69 +23,89 @@ build_flags =
upload_protocol = esptool
monitor_speed = 115200
monitor_filters = esp32_exception_decoder, log2file
[platformio]
lib_dir = .
; src_dir = examples/CaptivePortal
src_dir = examples/SimpleServer
; src_dir = examples/StreamFiles
; src_dir = examples/Filters
; src_dir = examples/Draft
; src_dir = examples/issues/Issue14
lib_deps =
; bblanchon/ArduinoJson @ 5.13.4
; bblanchon/ArduinoJson @ 6.21.5
bblanchon/ArduinoJson @ 7.2.0
mathieucarbou/AsyncTCP @ 3.2.5
board = esp32dev
[env:arduino-2]
platform = espressif32@6.8.1
board = esp32dev
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
mathieucarbou/AsyncTCP @ 3.2.5
platform = espressif32@6.9.0
[env:arduino-3]
platform = espressif32
platform_packages=
platformio/framework-arduinoespressif32 @ https://github.com/espressif/arduino-esp32.git#3.0.4
platformio/framework-arduinoespressif32-libs @ https://github.com/espressif/arduino-esp32/releases/download/3.0.4/esp32-arduino-libs-3.0.4.zip
board = esp32dev
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.05/platform-espressif32.zip
; board = esp32-s3-devkitc-1
; board = esp32-c6-devkitc-1
[env:arduino-3-no-json]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.05/platform-espressif32.zip
; board = esp32-s3-devkitc-1
; board = esp32-c6-devkitc-1
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
mathieucarbou/AsyncTCP @ 3.2.5
[env:arduino-310rc1]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/53.03.10-rc1/platform-espressif32.zip
; board = esp32-s3-devkitc-1
; board = esp32-c6-devkitc-1
[env:esp8266]
platform = espressif8266
board = huzzah
; board = d1_mini
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
bblanchon/ArduinoJson @ 7.2.0
esphome/ESPAsyncTCP-esphome @ 2.0.0
; PlatformIO support for Raspberry Pi Pico is not official
; https://github.com/platformio/platform-raspberrypi/pull/36
; https://github.com/earlephilhower/arduino-pico/blob/master/docs/platformio.rst
; board settings: https://github.com/earlephilhower/arduino-pico/blob/master/tools/json/rpipico.json
[env:rpipicow]
[env:raspberrypi]
upload_protocol = picotool
platform = https://github.com/maxgerhardt/platform-raspberrypi.git
; platform = raspberrypi
platform = https://github.com/maxgerhardt/platform-raspberrypi.git#f2687073f73d554c9db41f29b4769fd9703f4e55
board = rpipicow
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
bblanchon/ArduinoJson @ 7.2.0
khoih-prog/AsyncTCP_RP2040W @ 1.2.0
build_flags = ${env.build_flags}
-Wno-missing-field-initializers
[env:pioarduino-esp32dev]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.04/platform-espressif32.zip
board = esp32dev
; CI
[env:ci-arduino-2]
platform = espressif32@6.9.0
board = ${sysenv.PIO_BOARD}
[env:ci-arduino-3]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.05/platform-espressif32.zip
board = ${sysenv.PIO_BOARD}
[env:ci-arduino-3-no-json]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.05/platform-espressif32.zip
board = ${sysenv.PIO_BOARD}
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
mathieucarbou/AsyncTCP @ 3.2.5
[env:pioarduino-c6]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.04/platform-espressif32.zip
board = esp32-c6-devkitc-1
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
mathieucarbou/AsyncTCP @ 3.2.5
[env:ci-arduino-310rc1]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/53.03.10-rc1/platform-espressif32.zip
board = ${sysenv.PIO_BOARD}
[env:pioarduino-h2]
platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.04/platform-espressif32.zip
board = esp32-h2-devkitm-1
[env:ci-esp8266]
platform = espressif8266
board = ${sysenv.PIO_BOARD}
lib_deps =
bblanchon/ArduinoJson @ 7.1.0
mathieucarbou/AsyncTCP @ 3.2.5
bblanchon/ArduinoJson @ 7.2.0
esphome/ESPAsyncTCP-esphome @ 2.0.0
[env:ci-raspberrypi]
; platform = raspberrypi
platform = https://github.com/maxgerhardt/platform-raspberrypi.git#f2687073f73d554c9db41f29b4769fd9703f4e55
board = ${sysenv.PIO_BOARD}
lib_deps =
bblanchon/ArduinoJson @ 7.2.0
khoih-prog/AsyncTCP_RP2040W @ 1.2.0
build_flags = ${env.build_flags}
-Wno-missing-field-initializers

View File

@ -22,7 +22,6 @@
#include <rom/ets_sys.h>
#endif
#include "AsyncEventSource.h"
#include "literals.h"
using namespace asyncsrv;
@ -209,7 +208,7 @@ void AsyncEventSourceClient::_queueMessage(const char* message, size_t len) {
}
}
void AsyncEventSourceClient::_onAck(size_t len, uint32_t time) {
void AsyncEventSourceClient::_onAck(size_t len __attribute__((unused)), uint32_t time __attribute__((unused))) {
#ifdef ESP32
// Same here, acquiring the lock early
std::lock_guard<std::mutex> lock(_lockmq);
@ -289,7 +288,9 @@ void AsyncEventSource::onConnect(ArEventHandlerFunction cb) {
}
void AsyncEventSource::authorizeConnect(ArAuthorizeConnectHandler cb) {
_authorizeConnectHandler = cb;
AuthorizationMiddleware* m = new AuthorizationMiddleware(401, cb);
m->_freeOnRemoval = true;
addMiddleware(m);
}
void AsyncEventSource::_addClient(AsyncEventSourceClient* client) {
@ -374,20 +375,10 @@ bool AsyncEventSource::canHandle(AsyncWebServerRequest* request) {
if (request->method() != HTTP_GET || !request->url().equals(_url)) {
return false;
}
request->addInterestingHeader(T_Last_Event_ID);
request->addInterestingHeader(T_Cookie);
return true;
}
void AsyncEventSource::handleRequest(AsyncWebServerRequest* request) {
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())) {
return request->requestAuthentication();
}
if (_authorizeConnectHandler != NULL) {
if (!_authorizeConnectHandler(request)) {
return request->send(401);
}
}
request->send(new AsyncEventSourceResponse(this));
}

View File

@ -21,7 +21,6 @@
#define ASYNCEVENTSOURCE_H_
#include <Arduino.h>
#include <list>
#ifdef ESP32
#include <AsyncTCP.h>
#include <mutex>
@ -53,7 +52,7 @@ class AsyncEventSource;
class AsyncEventSourceResponse;
class AsyncEventSourceClient;
using ArEventHandlerFunction = std::function<void(AsyncEventSourceClient* client)>;
using ArAuthorizeConnectHandler = std::function<bool(AsyncWebServerRequest* request)>;
using ArAuthorizeConnectHandler = ArAuthorizeFunction;
class AsyncEventSourceMessage {
private:
@ -116,7 +115,6 @@ class AsyncEventSource : public AsyncWebHandler {
mutable std::mutex _client_queue_lock;
#endif
ArEventHandlerFunction _connectcb{nullptr};
ArAuthorizeConnectHandler _authorizeConnectHandler;
public:
AsyncEventSource(const String& url) : _url(url) {};

155
src/AsyncJson.cpp Normal file
View File

@ -0,0 +1,155 @@
#include "AsyncJson.h"
#if ASYNC_JSON_SUPPORT == 1
#if ARDUINOJSON_VERSION_MAJOR == 5
AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_json;
if (isArray)
_root = _jsonBuffer.createArray();
else
_root = _jsonBuffer.createObject();
}
#elif ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse::AsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : _jsonBuffer(maxJsonBufferSize), _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_json;
if (isArray)
_root = _jsonBuffer.createNestedArray();
else
_root = _jsonBuffer.createNestedObject();
}
#else
AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_json;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
#endif
size_t AsyncJsonResponse::setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measureLength();
#else
_contentLength = measureJson(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t AsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.printTo(dest);
#else
serializeJson(_root, dest);
#endif
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : AsyncJsonResponse{isArray, maxJsonBufferSize} {}
#else
PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray) : AsyncJsonResponse{isArray} {}
#endif
size_t PrettyAsyncJsonResponse::setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measurePrettyLength();
#else
_contentLength = measureJsonPretty(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t PrettyAsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.prettyPrintTo(dest);
#else
serializeJsonPretty(_root, dest);
#endif
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest, size_t maxJsonBufferSize)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {}
#else
AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
#endif
bool AsyncCallbackJsonWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_json))
return false;
return true;
}
void AsyncCallbackJsonWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
#if ARDUINOJSON_VERSION_MAJOR == 5
DynamicJsonBuffer jsonBuffer;
JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject));
if (json.success()) {
#elif ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize);
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#else
JsonDocument jsonBuffer;
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#endif
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
void AsyncCallbackJsonWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
#endif // ASYNC_JSON_SUPPORT

View File

@ -34,222 +34,98 @@
*/
#ifndef ASYNC_JSON_H_
#define ASYNC_JSON_H_
#include <ArduinoJson.h>
#include <ESPAsyncWebServer.h>
#include "ChunkPrint.h"
#if __has_include("ArduinoJson.h")
#include <ArduinoJson.h>
#if ARDUINOJSON_VERSION_MAJOR >= 5
#define ASYNC_JSON_SUPPORT 1
#else
#define ASYNC_JSON_SUPPORT 0
#endif // ARDUINOJSON_VERSION_MAJOR >= 5
#endif // __has_include("ArduinoJson.h")
#if ARDUINOJSON_VERSION_MAJOR == 6
#ifndef DYNAMIC_JSON_DOCUMENT_SIZE
#define DYNAMIC_JSON_DOCUMENT_SIZE 1024
#if ASYNC_JSON_SUPPORT == 1
#include <ESPAsyncWebServer.h>
#include "ChunkPrint.h"
#if ARDUINOJSON_VERSION_MAJOR == 6
#ifndef DYNAMIC_JSON_DOCUMENT_SIZE
#define DYNAMIC_JSON_DOCUMENT_SIZE 1024
#endif
#endif
#endif
constexpr const char* JSON_MIMETYPE = "application/json";
/*
* Json Response
* */
class AsyncJsonResponse : public AsyncAbstractResponse {
protected:
#if ARDUINOJSON_VERSION_MAJOR == 5
#if ARDUINOJSON_VERSION_MAJOR == 5
DynamicJsonBuffer _jsonBuffer;
#elif ARDUINOJSON_VERSION_MAJOR == 6
#elif ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument _jsonBuffer;
#else
#else
JsonDocument _jsonBuffer;
#endif
#endif
JsonVariant _root;
bool _isValid;
public:
#if ARDUINOJSON_VERSION_MAJOR == 5
AsyncJsonResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createArray();
else
_root = _jsonBuffer.createObject();
}
#elif ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : _jsonBuffer(maxJsonBufferSize), _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createNestedArray();
else
_root = _jsonBuffer.createNestedObject();
}
#else
AsyncJsonResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
#endif
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncJsonResponse(bool isArray = false);
#endif
JsonVariant& getRoot() { return _root; }
bool _sourceValid() const { return _isValid; }
size_t setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measureLength();
#else
_contentLength = measureJson(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t setLength();
size_t getSize() const { return _jsonBuffer.size(); }
#if ARDUINOJSON_VERSION_MAJOR >= 6
size_t _fillBuffer(uint8_t* data, size_t len);
#if ARDUINOJSON_VERSION_MAJOR >= 6
bool overflowed() const { return _jsonBuffer.overflowed(); }
#endif
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.printTo(dest);
#else
serializeJson(_root, dest);
#endif
return len;
}
#endif
};
class PrettyAsyncJsonResponse : public AsyncJsonResponse {
public:
#if ARDUINOJSON_VERSION_MAJOR == 6
PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : AsyncJsonResponse{isArray, maxJsonBufferSize} {}
#else
PrettyAsyncJsonResponse(bool isArray = false) : AsyncJsonResponse{isArray} {}
#endif
size_t setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measurePrettyLength();
#else
_contentLength = measureJsonPretty(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.prettyPrintTo(dest);
#else
serializeJsonPretty(_root, dest);
#endif
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
PrettyAsyncJsonResponse(bool isArray = false);
#endif
size_t setLength();
size_t _fillBuffer(uint8_t* data, size_t len);
};
typedef std::function<void(AsyncWebServerRequest* request, JsonVariant& json)> ArJsonRequestHandlerFunction;
class AsyncCallbackJsonWebHandler : public AsyncWebHandler {
private:
protected:
const String _uri;
String _uri;
WebRequestMethodComposite _method;
ArJsonRequestHandlerFunction _onRequest;
size_t _contentLength;
#if ARDUINOJSON_VERSION_MAJOR == 6
const size_t maxJsonBufferSize;
#endif
#if ARDUINOJSON_VERSION_MAJOR == 6
size_t maxJsonBufferSize;
#endif
size_t _maxContentLength;
public:
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {}
#else
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
#endif
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr);
#endif
void setMethod(WebRequestMethodComposite method) { _method = method; }
void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; }
void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(JSON_MIMETYPE))
return false;
request->addInterestingHeader("ANY");
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
#if ARDUINOJSON_VERSION_MAJOR == 5
DynamicJsonBuffer jsonBuffer;
JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject));
if (json.success()) {
#elif ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize);
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#else
JsonDocument jsonBuffer;
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#endif
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {
}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return !_onRequest; }
};
#endif
#endif // ASYNC_JSON_SUPPORT == 1
#endif // ASYNC_JSON_H_

106
src/AsyncMessagePack.cpp Normal file
View File

@ -0,0 +1,106 @@
#include "AsyncMessagePack.h"
#if ASYNC_MSG_PACK_SUPPORT == 1
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncMessagePackResponse::AsyncMessagePackResponse(bool isArray, size_t maxJsonBufferSize) : _jsonBuffer(maxJsonBufferSize), _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_msgpack;
if (isArray)
_root = _jsonBuffer.createNestedArray();
else
_root = _jsonBuffer.createNestedObject();
}
#else
AsyncMessagePackResponse::AsyncMessagePackResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_msgpack;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
#endif
size_t AsyncMessagePackResponse::setLength() {
_contentLength = measureMsgPack(_root);
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t AsyncMessagePackResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
serializeMsgPack(_root, dest);
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackMessagePackWebHandler::AsyncCallbackMessagePackWebHandler(const String& uri, ArMessagePackRequestHandlerFunction onRequest, size_t maxJsonBufferSize)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {}
#else
AsyncCallbackMessagePackWebHandler::AsyncCallbackMessagePackWebHandler(const String& uri, ArMessagePackRequestHandlerFunction onRequest)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
#endif
bool AsyncCallbackMessagePackWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack))
return false;
return true;
}
void AsyncCallbackMessagePackWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
#if ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize);
DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#else
JsonDocument jsonBuffer;
DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#endif
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
void AsyncCallbackMessagePackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
#endif // ASYNC_MSG_PACK_SUPPORT

View File

@ -21,125 +21,82 @@
server.addHandler(handler);
*/
#include <ArduinoJson.h>
#include <ESPAsyncWebServer.h>
#if __has_include("ArduinoJson.h")
#include <ArduinoJson.h>
#if ARDUINOJSON_VERSION_MAJOR >= 6
#define ASYNC_MSG_PACK_SUPPORT 1
#else
#define ASYNC_MSG_PACK_SUPPORT 0
#endif // ARDUINOJSON_VERSION_MAJOR >= 6
#endif // __has_include("ArduinoJson.h")
#include "ChunkPrint.h"
#include "literals.h"
#if ASYNC_MSG_PACK_SUPPORT == 1
#include <ESPAsyncWebServer.h>
#include "ChunkPrint.h"
#if ARDUINOJSON_VERSION_MAJOR == 6
#ifndef DYNAMIC_JSON_DOCUMENT_SIZE
#define DYNAMIC_JSON_DOCUMENT_SIZE 1024
#endif
#endif
class AsyncMessagePackResponse : public AsyncAbstractResponse {
protected:
#if ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument _jsonBuffer;
#else
JsonDocument _jsonBuffer;
#endif
JsonVariant _root;
bool _isValid;
public:
AsyncMessagePackResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_msgpack;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncMessagePackResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncMessagePackResponse(bool isArray = false);
#endif
JsonVariant& getRoot() { return _root; }
bool _sourceValid() const { return _isValid; }
size_t setLength() {
_contentLength = measureMsgPack(_root);
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t setLength();
size_t getSize() const { return _jsonBuffer.size(); }
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
serializeMsgPack(_root, dest);
return len;
}
size_t _fillBuffer(uint8_t* data, size_t len);
#if ARDUINOJSON_VERSION_MAJOR >= 6
bool overflowed() const { return _jsonBuffer.overflowed(); }
#endif
};
class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler {
public:
typedef std::function<void(AsyncWebServerRequest* request, JsonVariant& json)> ArJsonRequestHandlerFunction;
typedef std::function<void(AsyncWebServerRequest* request, JsonVariant& json)> ArMessagePackRequestHandlerFunction;
class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler {
protected:
const String _uri;
String _uri;
WebRequestMethodComposite _method;
ArJsonRequestHandlerFunction _onRequest;
ArMessagePackRequestHandlerFunction _onRequest;
size_t _contentLength;
#if ARDUINOJSON_VERSION_MAJOR == 6
size_t maxJsonBufferSize;
#endif
size_t _maxContentLength;
public:
AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackMessagePackWebHandler(const String& uri, ArMessagePackRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncCallbackMessagePackWebHandler(const String& uri, ArMessagePackRequestHandlerFunction onRequest = nullptr);
#endif
void setMethod(WebRequestMethodComposite method) { _method = method; }
void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; }
void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack))
return false;
request->addInterestingHeader("ANY");
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
JsonDocument jsonBuffer;
DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
void onRequest(ArMessagePackRequestHandlerFunction fn) { _onRequest = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return !_onRequest; }
};
#endif // ASYNC_MSG_PACK_SUPPORT == 1

22
src/AsyncWebHeader.cpp Normal file
View File

@ -0,0 +1,22 @@
#include <ESPAsyncWebServer.h>
AsyncWebHeader::AsyncWebHeader(const String& data) {
if (!data)
return;
int index = data.indexOf(':');
if (index < 0)
return;
_name = data.substring(0, index);
_value = data.substring(index + 2);
}
String AsyncWebHeader::toString() const {
String str;
str.reserve(_name.length() + _value.length() + 2);
str.concat(_name);
str.concat((char)0x3a);
str.concat((char)0x20);
str.concat(_value);
str.concat(asyncsrv::T_rn);
return str;
}

View File

@ -381,7 +381,7 @@ size_t AsyncWebSocketClient::queueLen() const {
std::lock_guard<std::mutex> lock(_lock);
#endif
return _messageQueue.size() + _controlQueue.size();
return _messageQueue.size();
}
bool AsyncWebSocketClient::canSend() const {
@ -1078,13 +1078,6 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest* request) {
if (request->method() != HTTP_GET || !request->url().equals(_url) || !request->isExpectedRequestedConnType(RCT_WS))
return false;
request->addInterestingHeader(WS_STR_CONNECTION);
request->addInterestingHeader(WS_STR_UPGRADE);
request->addInterestingHeader(WS_STR_ORIGIN);
request->addInterestingHeader(WS_STR_COOKIE);
request->addInterestingHeader(WS_STR_VERSION);
request->addInterestingHeader(WS_STR_KEY);
request->addInterestingHeader(WS_STR_PROTOCOL);
return true;
}
@ -1093,9 +1086,6 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest* request) {
request->send(400);
return;
}
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())) {
return request->requestAuthentication();
}
if (_handshakeHandler != nullptr) {
if (!_handshakeHandler(request)) {
request->send(401);

View File

@ -42,8 +42,6 @@
#include <ESPAsyncWebServer.h>
#include <deque>
#include <list>
#include <memory>
#ifdef ESP8266
@ -281,7 +279,7 @@ class AsyncWebSocket : public AsyncWebHandler {
public:
explicit AsyncWebSocket(const char* url) : _url(url), _cNextId(1), _enabled(true) {}
AsyncWebSocket(const String& url) : _url(url), _cNextId(1), _enabled(true) {}
~AsyncWebSocket(){};
~AsyncWebSocket() {};
const char* url() const { return _url.c_str(); }
void enable(bool e) { _enabled = e; }
bool enabled() const { return _enabled; }
@ -339,15 +337,8 @@ class AsyncWebSocket : public AsyncWebHandler {
size_t printfAll_P(PGM_P formatP, ...) __attribute__((format(printf, 2, 3)));
#endif
// event listener
void onEvent(AwsEventHandler handler) {
_eventHandler = handler;
}
// Handshake Handler
void handleHandshake(AwsHandshakeHandler handler) {
_handshakeHandler = handler;
}
void onEvent(AwsEventHandler handler) { _eventHandler = handler; }
void handleHandshake(AwsHandshakeHandler handler) { _handshakeHandler = handler; }
// system callbacks (do not call)
uint32_t _getNextId() { return _cNextId++; }
@ -360,7 +351,7 @@ class AsyncWebSocket : public AsyncWebHandler {
AsyncWebSocketMessageBuffer* makeBuffer(size_t size = 0);
AsyncWebSocketMessageBuffer* makeBuffer(const uint8_t* data, size_t size);
const std::list<AsyncWebSocketClient>& getClients() const { return _clients; }
std::list<AsyncWebSocketClient>& getClients() { return _clients; }
};
// WebServer response to authenticate the socket and detach the tcp client from the web server request

16
src/ChunkPrint.cpp Normal file
View File

@ -0,0 +1,16 @@
#include <ChunkPrint.h>
ChunkPrint::ChunkPrint(uint8_t* destination, size_t from, size_t len)
: _destination(destination), _to_skip(from), _to_write(len), _pos{0} {}
size_t ChunkPrint::write(uint8_t c) {
if (_to_skip > 0) {
_to_skip--;
return 1;
} else if (_to_write > 0) {
_to_write--;
_destination[_pos++] = c;
return 1;
}
return 0;
}

View File

@ -11,22 +11,9 @@ class ChunkPrint : public Print {
size_t _pos;
public:
ChunkPrint(uint8_t* destination, size_t from, size_t len)
: _destination(destination), _to_skip(from), _to_write(len), _pos{0} {}
ChunkPrint(uint8_t* destination, size_t from, size_t len);
virtual ~ChunkPrint() {}
size_t write(uint8_t c) {
if (_to_skip > 0) {
_to_skip--;
return 1;
} else if (_to_write > 0) {
_to_write--;
_destination[_pos++] = c;
return 1;
}
return 0;
}
size_t write(const uint8_t* buffer, size_t size) {
return this->Print::write(buffer, size);
}
size_t write(uint8_t c);
size_t write(const uint8_t* buffer, size_t size) { return this->Print::write(buffer, size); }
};
#endif

View File

@ -24,8 +24,11 @@
#include "Arduino.h"
#include "FS.h"
#include <algorithm>
#include <deque>
#include <functional>
#include <list>
#include <unordered_map>
#include <vector>
#ifdef ESP32
@ -45,10 +48,10 @@
#include "literals.h"
#define ASYNCWEBSERVER_VERSION "3.2.4"
#define ASYNCWEBSERVER_VERSION "3.3.7"
#define ASYNCWEBSERVER_VERSION_MAJOR 3
#define ASYNCWEBSERVER_VERSION_MINOR 2
#define ASYNCWEBSERVER_VERSION_REVISION 4
#define ASYNCWEBSERVER_VERSION_MINOR 3
#define ASYNCWEBSERVER_VERSION_REVISION 7
#define ASYNCWEBSERVER_FORK_mathieucarbou
#ifdef ASYNCWEBSERVER_REGEX
@ -67,6 +70,7 @@ class AsyncWebHandler;
class AsyncStaticWebHandler;
class AsyncCallbackWebHandler;
class AsyncResponseStream;
class AsyncMiddlewareChain;
#if defined(TARGET_RP2040)
typedef enum http_method WebRequestMethod;
@ -99,7 +103,8 @@ namespace fs {
#endif
// if this value is returned when asked for data, packet will not be sent and you will be asked for data again
#define RESPONSE_TRY_AGAIN 0xFFFFFFFF
#define RESPONSE_TRY_AGAIN 0xFFFFFFFF
#define RESPONSE_STREAM_BUFFER_SIZE 1460
typedef uint8_t WebRequestMethodComposite;
typedef std::function<void(void)> ArDisconnectHandler;
@ -137,33 +142,15 @@ class AsyncWebHeader {
public:
AsyncWebHeader() = default;
AsyncWebHeader(const AsyncWebHeader&) = default;
AsyncWebHeader(const char* name, const char* value) : _name(name), _value(value) {}
AsyncWebHeader(const String& name, const String& value) : _name(name), _value(value) {}
AsyncWebHeader(const String& data) {
if (!data)
return;
int index = data.indexOf(':');
if (index < 0)
return;
_name = data.substring(0, index);
_value = data.substring(index + 2);
}
AsyncWebHeader(const String& data);
AsyncWebHeader& operator=(const AsyncWebHeader&) = default;
const String& name() const { return _name; }
const String& value() const { return _value; }
String toString() const {
String str;
str.reserve(_name.length() + _value.length() + 2);
str.concat(_name);
str.concat((char)0x3a);
str.concat((char)0x20);
str.concat(_value);
str.concat(asyncsrv::T_rn);
return str;
}
String toString() const;
};
/*
@ -177,6 +164,15 @@ typedef enum { RCT_NOT_USED = -1,
RCT_EVENT,
RCT_MAX } RequestedConnectionType;
// this enum is similar to Arduino WebServer's AsyncAuthType and PsychicHttp
typedef enum {
AUTH_NONE = 0,
AUTH_BASIC,
AUTH_DIGEST,
AUTH_BEARER,
AUTH_OTHER,
} AsyncAuthType;
typedef std::function<size_t(uint8_t*, size_t, size_t)> AwsResponseFiller;
typedef std::function<String(const String&)> AwsTemplateProcessor;
@ -191,9 +187,11 @@ class AsyncWebServerRequest {
AsyncWebServer* _server;
AsyncWebHandler* _handler;
AsyncWebServerResponse* _response;
std::vector<String> _interestingHeaders;
ArDisconnectHandler _onDisconnectfn;
// response is sent
bool _sent = false;
String _temp;
uint8_t _parseState;
@ -205,8 +203,7 @@ class AsyncWebServerRequest {
String _boundary;
String _authorization;
RequestedConnectionType _reqconntype;
void _removeNotInterestingHeaders();
bool _isDigest;
AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
bool _isMultipart;
bool _isPlainPost;
bool _expectingContinue;
@ -217,6 +214,8 @@ class AsyncWebServerRequest {
std::list<AsyncWebParameter> _params;
std::vector<String> _pathParams;
std::unordered_map<const char*, String, std::hash<const char*>, std::equal_to<const char*>> _attributes;
uint8_t _multiParseState;
uint8_t _boundaryPosition;
size_t _itemStartIndex;
@ -281,31 +280,38 @@ class AsyncWebServerRequest {
// base64(user:pass) for basic or
// user:realm:md5(user:realm:pass) for digest
bool authenticate(const char* hash);
bool authenticate(const char* username, const char* password, const char* realm = NULL, bool passwordIsHash = false);
void requestAuthentication(const char* realm = NULL, bool isDigest = true);
bool authenticate(const char* username, const char* credentials, const char* realm = NULL, bool isHash = false);
void requestAuthentication(const char* realm = nullptr, bool isDigest = true) { requestAuthentication(isDigest ? AsyncAuthType::AUTH_DIGEST : AsyncAuthType::AUTH_BASIC, realm); }
void requestAuthentication(AsyncAuthType method, const char* realm = nullptr, const char* _authFailMsg = nullptr);
void setHandler(AsyncWebHandler* handler) { _handler = handler; }
/**
* @brief add header to collect from a response
*
* @param name
*/
void addInterestingHeader(const char* name);
void addInterestingHeader(const String& name) { return addInterestingHeader(name.c_str()); };
#ifndef ESP8266
[[deprecated("All headers are now collected. Use removeHeader(name) or HeaderFreeMiddleware if you really need to free some headers.")]]
#endif
void addInterestingHeader(__unused const char* name) {
}
#ifndef ESP8266
[[deprecated("All headers are now collected. Use removeHeader(name) or HeaderFreeMiddleware if you really need to free some headers.")]]
#endif
void addInterestingHeader(__unused const String& name) {
}
/**
* @brief issue 302 redirect response
* @brief issue HTTP redirect responce with Location header
*
* @param url
* @param url - url to redirect to
* @param code - responce code, default is 302 : temporary redirect
*/
void redirect(const char* url);
void redirect(const String& url) { return redirect(url.c_str()); };
void redirect(const char* url, int code = 302);
void redirect(const String& url, int code = 302) { return redirect(url.c_str(), code); };
void send(AsyncWebServerResponse* response);
AsyncWebServerResponse* getResponse() const { return _response; }
void send(int code, const char* contentType = asyncsrv::empty, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType, content, callback)); }
void send(int code, const String& contentType, const String& content = emptyString, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType, content, callback)); }
void send(int code, const String& contentType, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType.c_str(), content, callback)); }
void send(int code, const String& contentType, const String& content, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType.c_str(), content.c_str(), callback)); }
void send(int code, const char* contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType, content, len, callback)); }
void send(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType, content, len, callback)); }
@ -336,24 +342,25 @@ class AsyncWebServerRequest {
void sendChunked(const String& contentType, AwsResponseFiller callback, AwsTemplateProcessor templateCallback = nullptr) { send(beginChunkedResponse(contentType, callback, templateCallback)); }
#ifndef ESP8266
[[deprecated("Replaced by send(...)")]]
[[deprecated("Replaced by send(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr)")]]
#endif
void send_P(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr) {
send(code, contentType, content, len, callback);
}
#ifndef ESP8266
[[deprecated("Replaced by send(...)")]]
#endif
[[deprecated("Replaced by send(int code, const String& contentType, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr)")]]
void send_P(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr) {
send(code, contentType, content, callback);
}
#ifdef ESP8266
void send(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr) { send(beginResponse(code, contentType, content, callback)); }
#else
void send_P(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr) {
send(beginResponse_P(code, contentType, content, callback));
}
#endif
AsyncWebServerResponse* beginResponse(int code, const char* contentType = asyncsrv::empty, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr);
AsyncWebServerResponse* beginResponse(int code, const String& contentType, const String& content = emptyString, AwsTemplateProcessor callback = nullptr) { return beginResponse(code, contentType.c_str(), content.c_str(), callback); }
AsyncWebServerResponse* beginResponse(int code, const String& contentType, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr) { return beginResponse(code, contentType.c_str(), content, callback); }
AsyncWebServerResponse* beginResponse(int code, const String& contentType, const String& content, AwsTemplateProcessor callback = nullptr) { return beginResponse(code, contentType.c_str(), content.c_str(), callback); }
AsyncWebServerResponse* beginResponse(int code, const char* contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr);
AsyncWebServerResponse* beginResponse(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr) { return beginResponse(code, contentType.c_str(), content, len, callback); }
@ -373,48 +380,19 @@ class AsyncWebServerRequest {
AsyncWebServerResponse* beginChunkedResponse(const char* contentType, AwsResponseFiller callback, AwsTemplateProcessor templateCallback = nullptr);
AsyncWebServerResponse* beginChunkedResponse(const String& contentType, AwsResponseFiller callback, AwsTemplateProcessor templateCallback = nullptr);
AsyncResponseStream* beginResponseStream(const char* contentType, size_t bufferSize = 1460);
AsyncResponseStream* beginResponseStream(const String& contentType, size_t bufferSize = 1460) { return beginResponseStream(contentType.c_str(), bufferSize); }
AsyncResponseStream* beginResponseStream(const char* contentType, size_t bufferSize = RESPONSE_STREAM_BUFFER_SIZE);
AsyncResponseStream* beginResponseStream(const String& contentType, size_t bufferSize = RESPONSE_STREAM_BUFFER_SIZE) { return beginResponseStream(contentType.c_str(), bufferSize); }
#ifndef ESP8266
[[deprecated("Replaced by beginResponse(...)")]]
[[deprecated("Replaced by beginResponse(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr)")]]
#endif
AsyncWebServerResponse* beginResponse_P(int code, const String& contentType, const uint8_t* content, size_t len, AwsTemplateProcessor callback = nullptr) {
return beginResponse(code, contentType.c_str(), content, len, callback);
}
#ifndef ESP8266
[[deprecated("Replaced by beginResponse(...)")]]
#endif
AsyncWebServerResponse* beginResponse_P(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr) {
return beginResponse(code, contentType.c_str(), content, callback);
}
#ifdef ESP8266
AsyncWebServerResponse* beginResponse(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr);
#endif
size_t headers() const; // get header count
// check if header exists
bool hasHeader(const char* name) const;
bool hasHeader(const String& name) const { return hasHeader(name.c_str()); };
#ifdef ESP8266
bool hasHeader(const __FlashStringHelper* data) const; // check if header exists
#endif
const AsyncWebHeader* getHeader(const char* name) const;
const AsyncWebHeader* getHeader(const String& name) const { return getHeader(name.c_str()); };
#ifdef ESP8266
const AsyncWebHeader* getHeader(const __FlashStringHelper* data) const;
#endif
const AsyncWebHeader* getHeader(size_t num) const;
size_t params() const; // get arguments count
bool hasParam(const char* name, bool post = false, bool file = false) const;
bool hasParam(const String& name, bool post = false, bool file = false) const { return hasParam(name.c_str(), post, file); };
#ifdef ESP8266
bool hasParam(const __FlashStringHelper* data, bool post = false, bool file = false) const { return hasParam(String(data).c_str(), post, file); };
[[deprecated("Replaced by beginResponse(int code, const String& contentType, const char* content = asyncsrv::empty, AwsTemplateProcessor callback = nullptr)")]]
#endif
AsyncWebServerResponse* beginResponse_P(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback = nullptr);
/**
* @brief Get the Request parameter by name
@ -469,6 +447,56 @@ class AsyncWebServerRequest {
const String& header(size_t i) const; // get request header value by number
const String& headerName(size_t i) const; // get request header name by number
size_t headers() const; // get header count
// check if header exists
bool hasHeader(const char* name) const;
bool hasHeader(const String& name) const { return hasHeader(name.c_str()); };
#ifdef ESP8266
bool hasHeader(const __FlashStringHelper* data) const; // check if header exists
#endif
const AsyncWebHeader* getHeader(const char* name) const;
const AsyncWebHeader* getHeader(const String& name) const { return getHeader(name.c_str()); };
#ifdef ESP8266
const AsyncWebHeader* getHeader(const __FlashStringHelper* data) const;
#endif
const AsyncWebHeader* getHeader(size_t num) const;
const std::list<AsyncWebHeader>& getHeaders() const { return _headers; }
size_t getHeaderNames(std::vector<const char*>& names) const;
// Remove a header from the request.
// It will free the memory and prevent the header to be seen during request processing.
bool removeHeader(const char* name);
// Remove all request headers.
void removeHeaders() { _headers.clear(); }
size_t params() const; // get arguments count
bool hasParam(const char* name, bool post = false, bool file = false) const;
bool hasParam(const String& name, bool post = false, bool file = false) const { return hasParam(name.c_str(), post, file); };
#ifdef ESP8266
bool hasParam(const __FlashStringHelper* data, bool post = false, bool file = false) const { return hasParam(String(data).c_str(), post, file); };
#endif
// REQUEST ATTRIBUTES
void setAttribute(const char* name, const char* value) { _attributes[name] = value; }
void setAttribute(const char* name, bool value) { _attributes[name] = value ? "1" : emptyString; }
void setAttribute(const char* name, long value) { _attributes[name] = String(value); }
void setAttribute(const char* name, float value, unsigned int decimalPlaces = 2) { _attributes[name] = String(value, decimalPlaces); }
void setAttribute(const char* name, double value, unsigned int decimalPlaces = 2) { _attributes[name] = String(value, decimalPlaces); }
bool hasAttribute(const char* name) const { return _attributes.find(name) != _attributes.end(); }
const String& getAttribute(const char* name, const String& defaultValue = emptyString) const;
bool getAttribute(const char* name, bool defaultValue) const;
long getAttribute(const char* name, long defaultValue) const;
float getAttribute(const char* name, float defaultValue) const;
double getAttribute(const char* name, double defaultValue) const;
String urlDecode(const String& text) const;
};
@ -482,6 +510,176 @@ bool ON_STA_FILTER(AsyncWebServerRequest* request);
bool ON_AP_FILTER(AsyncWebServerRequest* request);
/*
* MIDDLEWARE :: Request interceptor, assigned to a AsyncWebHandler (or the server), which can be used:
* 1. to run some code before the final handler is executed (e.g. check authentication)
* 2. decide whether to proceed or not with the next handler
* */
using ArMiddlewareNext = std::function<void(void)>;
using ArMiddlewareCallback = std::function<void(AsyncWebServerRequest* request, ArMiddlewareNext next)>;
// Middleware is a base class for all middleware
class AsyncMiddleware {
public:
virtual ~AsyncMiddleware() {}
virtual void run(__unused AsyncWebServerRequest* request, __unused ArMiddlewareNext next) {
return next();
};
private:
friend class AsyncWebHandler;
friend class AsyncEventSource;
friend class AsyncMiddlewareChain;
bool _freeOnRemoval = false;
};
// Create a custom middleware by providing an anonymous callback function
class AsyncMiddlewareFunction : public AsyncMiddleware {
public:
AsyncMiddlewareFunction(ArMiddlewareCallback fn) : _fn(fn) {}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) override { return _fn(request, next); };
private:
ArMiddlewareCallback _fn;
};
// For internal use only: super class to add/remove middleware to server or handlers
class AsyncMiddlewareChain {
public:
virtual ~AsyncMiddlewareChain();
void addMiddleware(ArMiddlewareCallback fn);
void addMiddleware(AsyncMiddleware* middleware);
void addMiddlewares(std::vector<AsyncMiddleware*> middlewares);
bool removeMiddleware(AsyncMiddleware* middleware);
// For internal use only
void _runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer);
protected:
std::list<AsyncMiddleware*> _middlewares;
};
// AuthenticationMiddleware is a middleware that checks if the request is authenticated
class AuthenticationMiddleware : public AsyncMiddleware {
public:
void setUsername(const char* username);
void setPassword(const char* password);
void setPasswordHash(const char* hash);
void setRealm(const char* realm) { _realm = realm; }
void setAuthFailureMessage(const char* message) { _authFailMsg = message; }
void setAuthType(AsyncAuthType authMethod) { _authMethod = authMethod; }
// precompute and store the hash value based on the username, realm, and authMethod
// returns true if the hash was successfully generated and replaced
bool generateHash();
bool allowed(AsyncWebServerRequest* request);
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
String _username;
String _credentials;
bool _hash = false;
String _realm = asyncsrv::T_LOGIN_REQ;
AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
String _authFailMsg;
bool _hasCreds = false;
};
using ArAuthorizeFunction = std::function<bool(AsyncWebServerRequest* request)>;
// AuthorizationMiddleware is a middleware that checks if the request is authorized
class AuthorizationMiddleware : public AsyncMiddleware {
public:
AuthorizationMiddleware(ArAuthorizeFunction authorizeConnectHandler) : _code(403), _authz(authorizeConnectHandler) {}
AuthorizationMiddleware(int code, ArAuthorizeFunction authorizeConnectHandler) : _code(code), _authz(authorizeConnectHandler) {}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return _authz && !_authz(request) ? request->send(_code) : next(); }
private:
int _code;
ArAuthorizeFunction _authz;
};
// remove all headers from the incoming request except the ones provided in the constructor
class HeaderFreeMiddleware : public AsyncMiddleware {
public:
void keep(const char* name) { _toKeep.push_back(name); }
void unKeep(const char* name) { _toKeep.erase(std::remove(_toKeep.begin(), _toKeep.end(), name), _toKeep.end()); }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
std::vector<const char*> _toKeep;
};
// filter out specific headers from the incoming request
class HeaderFilterMiddleware : public AsyncMiddleware {
public:
void filter(const char* name) { _toRemove.push_back(name); }
void unFilter(const char* name) { _toRemove.erase(std::remove(_toRemove.begin(), _toRemove.end(), name), _toRemove.end()); }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
std::vector<const char*> _toRemove;
};
// curl-like logging of incoming requests
class LoggingMiddleware : public AsyncMiddleware {
public:
void setOutput(Print& output) { _out = &output; }
void setEnabled(bool enabled) { _enabled = enabled; }
bool isEnabled() { return _enabled && _out; }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
Print* _out = nullptr;
bool _enabled = true;
};
// CORS Middleware
class CorsMiddleware : public AsyncMiddleware {
public:
void setOrigin(const char* origin) { _origin = origin; }
void setMethods(const char* methods) { _methods = methods; }
void setHeaders(const char* headers) { _headers = headers; }
void setAllowCredentials(bool credentials) { _credentials = credentials; }
void setMaxAge(uint32_t seconds) { _maxAge = seconds; }
void addCORSHeaders(AsyncWebServerResponse* response);
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
String _origin = "*";
String _methods = "*";
String _headers = "*";
bool _credentials = true;
uint32_t _maxAge = 86400;
};
// Rate limit Middleware
class RateLimitMiddleware : public AsyncMiddleware {
public:
void setMaxRequests(size_t maxRequests) { _maxRequests = maxRequests; }
void setWindowSize(uint32_t seconds) { _windowSizeMillis = seconds * 1000; }
bool isRequestAllowed(uint32_t& retryAfterSeconds);
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
size_t _maxRequests = 0;
uint32_t _windowSizeMillis = 0;
std::list<uint32_t> _requestTimes;
};
/*
* REWRITE :: One instance can be handle any Request (done by the Server)
* */
@ -517,36 +715,24 @@ class AsyncWebRewrite {
* HANDLER :: One instance can be attached to any Request (done by the Server)
* */
class AsyncWebHandler {
class AsyncWebHandler : public AsyncMiddlewareChain {
protected:
ArRequestFilterFunction _filter{nullptr};
String _username;
String _password;
ArRequestFilterFunction _filter = nullptr;
AuthenticationMiddleware* _authMiddleware = nullptr;
public:
AsyncWebHandler() {}
AsyncWebHandler& setFilter(ArRequestFilterFunction fn) {
_filter = fn;
return *this;
}
AsyncWebHandler& setAuthentication(const char* username, const char* password) {
_username = username;
_password = password;
return *this;
};
AsyncWebHandler& setAuthentication(const String& username, const String& password) {
_username = username;
_password = password;
return *this;
};
AsyncWebHandler& setFilter(ArRequestFilterFunction fn);
AsyncWebHandler& setAuthentication(const char* username, const char* password);
AsyncWebHandler& setAuthentication(const String& username, const String& password) { return setAuthentication(username.c_str(), password.c_str()); };
bool filter(AsyncWebServerRequest* request) { return _filter == NULL || _filter(request); }
virtual ~AsyncWebHandler() {}
virtual bool canHandle(AsyncWebServerRequest* request __attribute__((unused))) {
return false;
}
virtual void handleRequest(AsyncWebServerRequest* request __attribute__((unused))) {}
virtual void handleUpload(AsyncWebServerRequest* request __attribute__((unused)), const String& filename __attribute__((unused)), size_t index __attribute__((unused)), uint8_t* data __attribute__((unused)), size_t len __attribute__((unused)), bool final __attribute__((unused))) {}
virtual void handleBody(AsyncWebServerRequest* request __attribute__((unused)), uint8_t* data __attribute__((unused)), size_t len __attribute__((unused)), size_t index __attribute__((unused)), size_t total __attribute__((unused))) {}
virtual void handleRequest(__unused AsyncWebServerRequest* request) {}
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) {}
virtual void handleBody(__unused AsyncWebServerRequest* request, __unused uint8_t* data, __unused size_t len, __unused size_t index, __unused size_t total) {}
virtual bool isRequestHandlerTrivial() { return true; }
};
@ -588,6 +774,7 @@ class AsyncWebServerResponse {
AsyncWebServerResponse();
virtual ~AsyncWebServerResponse();
virtual void setCode(int code);
int code() const { return _code; }
virtual void setContentLength(size_t len);
void setContentType(const String& type) { setContentType(type.c_str()); }
virtual void setContentType(const char* type);
@ -597,6 +784,7 @@ class AsyncWebServerResponse {
bool addHeader(const String& name, long value, bool replaceExisting = true) { return addHeader(name.c_str(), value, replaceExisting); }
virtual bool removeHeader(const char* name);
virtual const AsyncWebHeader* getHeader(const char* name) const;
const std::list<AsyncWebHeader>& getHeaders() const { return _headers; }
#ifndef ESP8266
[[deprecated("Use instead: _assembleHead(String& buffer, uint8_t version)")]]
@ -624,7 +812,7 @@ typedef std::function<void(AsyncWebServerRequest* request)> ArRequestHandlerFunc
typedef std::function<void(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final)> ArUploadHandlerFunction;
typedef std::function<void(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total)> ArBodyHandlerFunction;
class AsyncWebServer {
class AsyncWebServer : public AsyncMiddlewareChain {
protected:
AsyncServer _server;
std::list<std::shared_ptr<AsyncWebRewrite>> _rewrites;
@ -647,7 +835,7 @@ class AsyncWebServer {
/**
* @brief (compat) Add url rewrite rule by pointer
* a deep copy of the pounter object will be created,
* a deep copy of the pointer object will be created,
* it is up to user to manage further lifetime of the object in argument
*
* @param rewrite pointer to rewrite object to copy setting from
@ -687,10 +875,8 @@ class AsyncWebServer {
AsyncWebHandler& addHandler(AsyncWebHandler* handler);
bool removeHandler(AsyncWebHandler* handler);
AsyncCallbackWebHandler& on(const char* uri, ArRequestHandlerFunction onRequest);
AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest);
AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload);
AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload, ArBodyHandlerFunction onBody);
AsyncCallbackWebHandler& on(const char* uri, ArRequestHandlerFunction onRequest) { return on(uri, HTTP_ANY, onRequest); }
AsyncCallbackWebHandler& on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload = nullptr, ArBodyHandlerFunction onBody = nullptr);
AsyncStaticWebHandler& serveStatic(const char* uri, fs::FS& fs, const char* path, const char* cache_control = NULL);

253
src/Middleware.cpp Normal file
View File

@ -0,0 +1,253 @@
#include "WebAuthentication.h"
#include <ESPAsyncWebServer.h>
AsyncMiddlewareChain::~AsyncMiddlewareChain() {
for (AsyncMiddleware* m : _middlewares)
if (m->_freeOnRemoval)
delete m;
}
void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) {
AsyncMiddlewareFunction* m = new AsyncMiddlewareFunction(fn);
m->_freeOnRemoval = true;
_middlewares.emplace_back(m);
}
void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware* middleware) {
if (middleware)
_middlewares.emplace_back(middleware);
}
void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware*> middlewares) {
for (AsyncMiddleware* m : middlewares)
addMiddleware(m);
}
bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware* middleware) {
// remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector.
const size_t size = _middlewares.size();
_middlewares.erase(std::remove_if(_middlewares.begin(), _middlewares.end(), [middleware](AsyncMiddleware* m) {
if (m == middleware) {
if (m->_freeOnRemoval)
delete m;
return true;
}
return false;
}),
_middlewares.end());
return size != _middlewares.size();
}
void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) {
if (!_middlewares.size())
return finalizer();
ArMiddlewareNext next;
std::list<AsyncMiddleware*>::iterator it = _middlewares.begin();
next = [this, &next, &it, request, finalizer]() {
if (it == _middlewares.end())
return finalizer();
AsyncMiddleware* m = *it;
it++;
return m->run(request, next);
};
return next();
}
void AuthenticationMiddleware::setUsername(const char* username) {
_username = username;
_hasCreds = _username.length() && _credentials.length();
}
void AuthenticationMiddleware::setPassword(const char* password) {
_credentials = password;
_hash = false;
_hasCreds = _username.length() && _credentials.length();
}
void AuthenticationMiddleware::setPasswordHash(const char* hash) {
_credentials = hash;
_hash = true;
_hasCreds = _username.length() && _credentials.length();
}
bool AuthenticationMiddleware::generateHash() {
// ensure we have all the necessary data
if (!_hasCreds)
return false;
// if we already have a hash, do nothing
if (_hash)
return false;
switch (_authMethod) {
case AsyncAuthType::AUTH_DIGEST:
_credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
_hash = true;
return true;
case AsyncAuthType::AUTH_BASIC:
_credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
_hash = true;
return true;
default:
return false;
}
}
bool AuthenticationMiddleware::allowed(AsyncWebServerRequest* request) {
if (_authMethod == AsyncAuthType::AUTH_NONE)
return true;
if (!_hasCreds)
return false;
return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
}
void AuthenticationMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
}
void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders);
for (const char* h : reqHeaders) {
bool keep = false;
for (const char* k : _toKeep) {
if (strcasecmp(h, k) == 0) {
keep = true;
break;
}
}
if (!keep) {
request->removeHeader(h);
}
}
next();
}
void HeaderFilterMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it)
request->removeHeader(*it);
next();
}
void LoggingMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
if (!isEnabled()) {
next();
return;
}
_out->print(F("* Connection from "));
_out->print(request->client()->remoteIP().toString());
_out->print(':');
_out->println(request->client()->remotePort());
_out->print('>');
_out->print(' ');
_out->print(request->methodToString());
_out->print(' ');
_out->print(request->url().c_str());
_out->print(F(" HTTP/1."));
_out->println(request->version());
for (auto& h : request->getHeaders()) {
if (h.value().length()) {
_out->print('>');
_out->print(' ');
_out->print(h.name());
_out->print(':');
_out->print(' ');
_out->println(h.value());
}
}
_out->println(F(">"));
uint32_t elapsed = millis();
next();
elapsed = millis() - elapsed;
AsyncWebServerResponse* response = request->getResponse();
if (response) {
_out->print(F("* Processed in "));
_out->print(elapsed);
_out->println(F(" ms"));
_out->print('<');
_out->print(F(" HTTP/1."));
_out->print(request->version());
_out->print(' ');
_out->print(response->code());
_out->print(' ');
_out->println(AsyncWebServerResponse::responseCodeToString(response->code()));
for (auto& h : response->getHeaders()) {
if (h.value().length()) {
_out->print('<');
_out->print(' ');
_out->print(h.name());
_out->print(':');
_out->print(' ');
_out->println(h.value());
}
}
_out->println('<');
} else {
_out->println(F("* Connection closed!"));
}
}
void CorsMiddleware::addCORSHeaders(AsyncWebServerResponse* response) {
response->addHeader(F("Access-Control-Allow-Origin"), _origin.c_str());
response->addHeader(F("Access-Control-Allow-Methods"), _methods.c_str());
response->addHeader(F("Access-Control-Allow-Headers"), _headers.c_str());
response->addHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false"));
response->addHeader(F("Access-Control-Max-Age"), String(_maxAge).c_str());
}
void CorsMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
// Origin header ? => CORS handling
if (request->hasHeader(F("Origin"))) {
// check if this is a preflight request => handle it and return
if (request->method() == HTTP_OPTIONS) {
AsyncWebServerResponse* response = request->beginResponse(200);
addCORSHeaders(response);
request->send(response);
return;
}
// CORS request, no options => let the request pass and add CORS headers after
next();
AsyncWebServerResponse* response = request->getResponse();
if (response) {
addCORSHeaders(response);
}
} else {
// NO Origin header => no CORS handling
next();
}
}
bool RateLimitMiddleware::isRequestAllowed(uint32_t& retryAfterSeconds) {
uint32_t now = millis();
while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis)
_requestTimes.pop_front();
_requestTimes.push_back(now);
if (_requestTimes.size() > _maxRequests) {
_requestTimes.pop_front();
retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
return false;
}
retryAfterSeconds = 0;
return true;
}
void RateLimitMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
uint32_t retryAfterSeconds;
if (isRequestAllowed(retryAfterSeconds)) {
next();
} else {
AsyncWebServerResponse* response = request->beginResponse(429);
response->addHeader(F("Retry-After"), retryAfterSeconds);
request->send(response);
}
}

View File

@ -34,36 +34,34 @@ using namespace asyncsrv;
bool checkBasicAuthentication(const char* hash, const char* username, const char* password) {
if (username == NULL || password == NULL || hash == NULL)
return false;
return generateBasicHash(username, password).equalsIgnoreCase(hash);
}
String generateBasicHash(const char* username, const char* password) {
if (username == NULL || password == NULL)
return emptyString;
size_t toencodeLen = strlen(username) + strlen(password) + 1;
size_t encodedLen = base64_encode_expected_len(toencodeLen);
if (strlen(hash) != encodedLen)
// Fix from https://github.com/me-no-dev/ESPAsyncWebServer/issues/667
#ifdef ARDUINO_ARCH_ESP32
if (strlen(hash) != encodedLen)
#else
if (strlen(hash) != encodedLen - 1)
#endif
return false;
char* toencode = new char[toencodeLen + 1];
if (toencode == NULL) {
return false;
return emptyString;
}
char* encoded = new char[base64_encode_expected_len(toencodeLen) + 1];
if (encoded == NULL) {
delete[] toencode;
return false;
return emptyString;
}
sprintf_P(toencode, PSTR("%s:%s"), username, password);
if (base64_encode_chars(toencode, toencodeLen, encoded) > 0 && memcmp(hash, encoded, encodedLen) == 0) {
if (base64_encode_chars(toencode, toencodeLen, encoded) > 0) {
String res = String(encoded);
delete[] toencode;
delete[] encoded;
return true;
return res;
}
delete[] toencode;
delete[] encoded;
return false;
return emptyString;
}
static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or more
@ -94,7 +92,7 @@ static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or m
return true;
}
static String genRandomMD5() {
String genRandomMD5() {
#ifdef ESP8266
uint32_t r = RANDOM_REG32;
#else
@ -122,31 +120,21 @@ String generateDigestHash(const char* username, const char* password, const char
return emptyString;
}
char* out = (char*)malloc(33);
String res = String(username);
res += ':';
res.concat(realm);
res += ':';
String in = res;
String in;
in.reserve(strlen(username) + strlen(realm) + strlen(password) + 2);
in.concat(username);
in.concat(':');
in.concat(realm);
in.concat(':');
in.concat(password);
if (out == NULL || !getMD5((uint8_t*)(in.c_str()), in.length(), out))
return emptyString;
res.concat(out);
free(out);
return res;
}
String requestDigestAuthentication(const char* realm) {
String header(T_realm__);
if (realm == NULL)
header.concat(T_asyncesp);
else
header.concat(realm);
header.concat(T_auth_nonce);
header.concat(genRandomMD5());
header.concat(T__opaque);
header.concat(genRandomMD5());
header += (char)0x22; // '"'
return header;
in = String(out);
free(out);
return in;
}
#ifndef ESP8266
@ -235,9 +223,9 @@ bool checkDigestAuthentication(const char* header, const __FlashStringHelper* me
}
} while (nextBreak > 0);
String ha1 = (passwordIsHash) ? String(password) : stringMD5(myUsername + ':' + myRealm + ':' + password);
String ha2 = String(method) + ':' + myUri;
String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + stringMD5(ha2);
String ha1 = passwordIsHash ? password : stringMD5(myUsername + ':' + myRealm + ':' + password).c_str();
String ha2 = stringMD5(String(method) + ':' + myUri);
String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + ha2;
if (myResponse.equals(stringMD5(response))) {
// os_printf("AUTH SUCCESS\n");

View File

@ -25,7 +25,6 @@
#include "Arduino.h"
bool checkBasicAuthentication(const char* header, const char* username, const char* password);
String requestDigestAuthentication(const char* realm);
bool checkDigestAuthentication(const char* header, const char* method, const char* username, const char* password, const char* realm, bool passwordIsHash, const char* nonce, const char* opaque, const char* uri);
@ -36,4 +35,8 @@ bool checkDigestAuthentication(const char* header, const __FlashStringHelper* me
// for storing hashed versions on the device that can be authenticated against
String generateDigestHash(const char* username, const char* password, const char* realm);
String generateBasicHash(const char* username, const char* password);
String genRandomMD5();
#endif

View File

@ -63,10 +63,7 @@ class AsyncStaticWebHandler : public AsyncWebHandler {
AsyncStaticWebHandler& setLastModified(time_t last_modified);
AsyncStaticWebHandler& setLastModified(); // sets to current time. Make sure sntp is runing and time is updated
#endif
AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback) {
_callback = newCallback;
return *this;
}
AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback);
};
class AsyncCallbackWebHandler : public AsyncWebHandler {
@ -81,75 +78,17 @@ class AsyncCallbackWebHandler : public AsyncWebHandler {
public:
AsyncCallbackWebHandler() : _uri(), _method(HTTP_ANY), _onRequest(NULL), _onUpload(NULL), _onBody(NULL), _isRegex(false) {}
void setUri(const String& uri) {
_uri = uri;
_isRegex = uri.startsWith("^") && uri.endsWith("$");
}
void setUri(const String& uri);
void setMethod(WebRequestMethodComposite method) { _method = method; }
void onRequest(ArRequestHandlerFunction fn) { _onRequest = fn; }
void onUpload(ArUploadHandlerFunction fn) { _onUpload = fn; }
void onBody(ArBodyHandlerFunction fn) { _onBody = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
if (!(_method & request->method()))
return false;
#ifdef ASYNCWEBSERVER_REGEX
if (_isRegex) {
std::regex pattern(_uri.c_str());
std::smatch matches;
std::string s(request->url().c_str());
if (std::regex_search(s, matches, pattern)) {
for (size_t i = 1; i < matches.size(); ++i) { // start from 1
request->_addPathParam(matches[i].str().c_str());
}
} else {
return false;
}
} else
#endif
if (_uri.length() && _uri.startsWith("/*.")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf("."));
if (!request->url().endsWith(uriTemplate))
return false;
} else if (_uri.length() && _uri.endsWith("*")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1);
if (!request->url().startsWith(uriTemplate))
return false;
} else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
request->addInterestingHeader("ANY");
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (_onRequest)
_onRequest(request);
else
request->send(500);
}
virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final {
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (_onUpload)
_onUpload(request, filename, index, data, len, final);
}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (_onBody)
_onBody(request, data, len, index, total);
}
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final;
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return !_onRequest; }
};
#endif /* ASYNCWEBSERVERHANDLERIMPL_H_ */

View File

@ -23,6 +23,20 @@
using namespace asyncsrv;
AsyncWebHandler& AsyncWebHandler::setFilter(ArRequestFilterFunction fn) {
_filter = fn;
return *this;
}
AsyncWebHandler& AsyncWebHandler::setAuthentication(const char* username, const char* password) {
if (!_authMiddleware) {
_authMiddleware = new AuthenticationMiddleware();
_authMiddleware->_freeOnRemoval = true;
addMiddleware(_authMiddleware);
}
_authMiddleware->setUsername(username);
_authMiddleware->setPassword(password);
return *this;
};
AsyncStaticWebHandler::AsyncStaticWebHandler(const char* uri, FS& fs, const char* path, const char* cache_control)
: _fs(fs), _uri(uri), _path(path), _default_file(F("index.htm")), _cache_control(cache_control), _last_modified(), _callback(nullptr) {
@ -94,18 +108,7 @@ bool AsyncStaticWebHandler::canHandle(AsyncWebServerRequest* request) {
if (request->method() != HTTP_GET || !request->url().startsWith(_uri) || !request->isExpectedRequestedConnType(RCT_DEFAULT, RCT_HTTP)) {
return false;
}
if (_getFile(request)) {
// We interested in "If-Modified-Since" header to check if file was modified
if (_last_modified.length())
request->addInterestingHeader(F("If-Modified-Since"));
if (_cache_control.length())
request->addInterestingHeader(F("If-None-Match"));
return true;
}
return false;
return _getFile(request);
}
bool AsyncStaticWebHandler::_getFile(AsyncWebServerRequest* request) {
@ -204,8 +207,6 @@ void AsyncStaticWebHandler::handleRequest(AsyncWebServerRequest* request) {
String filename = String((char*)request->_tempObject);
free(request->_tempObject);
request->_tempObject = NULL;
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
if (request->_tempFile == true) {
time_t lw = request->_tempFile.getLastWrite(); // get last file mod time (if supported by FS)
@ -248,3 +249,65 @@ void AsyncStaticWebHandler::handleRequest(AsyncWebServerRequest* request) {
request->send(404);
}
}
AsyncStaticWebHandler& AsyncStaticWebHandler::setTemplateProcessor(AwsTemplateProcessor newCallback) {
_callback = newCallback;
return *this;
}
void AsyncCallbackWebHandler::setUri(const String& uri) {
_uri = uri;
_isRegex = uri.startsWith("^") && uri.endsWith("$");
}
bool AsyncCallbackWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
if (!(_method & request->method()))
return false;
#ifdef ASYNCWEBSERVER_REGEX
if (_isRegex) {
std::regex pattern(_uri.c_str());
std::smatch matches;
std::string s(request->url().c_str());
if (std::regex_search(s, matches, pattern)) {
for (size_t i = 1; i < matches.size(); ++i) { // start from 1
request->_addPathParam(matches[i].str().c_str());
}
} else {
return false;
}
} else
#endif
if (_uri.length() && _uri.startsWith("/*.")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf("."));
if (!request->url().endsWith(uriTemplate))
return false;
} else if (_uri.length() && _uri.endsWith("*")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1);
if (!request->url().startsWith(uriTemplate))
return false;
} else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
return true;
}
void AsyncCallbackWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest)
_onRequest(request);
else
request->send(500);
}
void AsyncCallbackWebHandler::handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) {
if (_onUpload)
_onUpload(request, filename, index, data, len, final);
}
void AsyncCallbackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onBody)
_onBody(request, data, len, index, total);
}

View File

@ -35,7 +35,7 @@ enum { PARSE_REQ_START,
PARSE_REQ_FAIL };
AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c)
: _client(c), _server(s), _handler(NULL), _response(NULL), _temp(), _parseState(0), _version(0), _method(HTTP_ANY), _url(), _host(), _contentType(), _boundary(), _authorization(), _reqconntype(RCT_HTTP), _isDigest(false), _isMultipart(false), _isPlainPost(false), _expectingContinue(false), _contentLength(0), _parsedLength(0), _multiParseState(0), _boundaryPosition(0), _itemStartIndex(0), _itemSize(0), _itemName(), _itemFilename(), _itemType(), _itemValue(), _itemBuffer(0), _itemBufferIndex(0), _itemIsFile(false), _tempObject(NULL) {
: _client(c), _server(s), _handler(NULL), _response(NULL), _temp(), _parseState(0), _version(0), _method(HTTP_ANY), _url(), _host(), _contentType(), _boundary(), _authorization(), _reqconntype(RCT_HTTP), _authMethod(AsyncAuthType::AUTH_NONE), _isMultipart(false), _isPlainPost(false), _expectingContinue(false), _contentLength(0), _parsedLength(0), _multiParseState(0), _boundaryPosition(0), _itemStartIndex(0), _itemSize(0), _itemName(), _itemFilename(), _itemType(), _itemValue(), _itemBuffer(0), _itemBufferIndex(0), _itemIsFile(false), _tempObject(NULL) {
c->onError([](void* r, AsyncClient* c, int8_t error) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onError(error); }, this);
c->onAck([](void* r, AsyncClient* c, size_t len, uint32_t time) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onAck(len, time); }, this);
c->onDisconnect([](void* r, AsyncClient* c) { AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onDisconnect(); delete c; }, this);
@ -49,8 +49,6 @@ AsyncWebServerRequest::~AsyncWebServerRequest() {
_pathParams.clear();
_interestingHeaders.clear();
if (_response != NULL) {
delete _response;
}
@ -125,7 +123,6 @@ void AsyncWebServerRequest::_onData(void* buf, size_t len) {
}
}
if (!_isPlainPost) {
// check if authenticated before calling the body
if (_handler)
_handler->handleBody(this, (uint8_t*)buf, len, _parsedLength, _contentLength);
_parsedLength += len;
@ -141,31 +138,20 @@ void AsyncWebServerRequest::_onData(void* buf, size_t len) {
}
if (_parsedLength == _contentLength) {
_parseState = PARSE_REQ_END;
// check if authenticated before calling handleRequest and request auth instead
if (_handler)
_handler->handleRequest(this);
else
send(501);
_server->_runChain(this, [this]() { return _handler ? _handler->_runChain(this, [this]() { _handler->handleRequest(this); }) : send(501); });
if (!_sent) {
if (!_response)
send(501, T_text_plain, "Handler did not handle the request");
_client->setRxTimeout(0);
_response->_respond(this);
_sent = true;
}
}
}
break;
}
}
void AsyncWebServerRequest::_removeNotInterestingHeaders() {
if (std::any_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), [](const String& str) { return str.equalsIgnoreCase(T_ANY); }))
return; // nothing to do
for (auto iter = std::begin(_headers); iter != std::end(_headers);) {
const auto name = iter->name();
if (std::none_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), [&name](const String& str) { return str.equalsIgnoreCase(name); }))
iter = _headers.erase(iter);
else
iter++;
}
}
void AsyncWebServerRequest::_onPoll() {
// os_printf("p\n");
if (_response != NULL && _client != NULL && _client->canSend()) {
@ -299,9 +285,16 @@ bool AsyncWebServerRequest::_parseReqHeader() {
} else if (name.equalsIgnoreCase(T_AUTH)) {
if (value.length() > 5 && value.substring(0, 5).equalsIgnoreCase(T_BASIC)) {
_authorization = value.substring(6);
_authMethod = AsyncAuthType::AUTH_BASIC;
} else if (value.length() > 6 && value.substring(0, 6).equalsIgnoreCase(T_DIGEST)) {
_isDigest = true;
_authMethod = AsyncAuthType::AUTH_DIGEST;
_authorization = value.substring(7);
} else if (value.length() > 6 && value.substring(0, 6).equalsIgnoreCase(T_BEARER)) {
_authMethod = AsyncAuthType::AUTH_BEARER;
_authorization = value.substring(7);
} else {
_authorization = value;
_authMethod = AsyncAuthType::AUTH_OTHER;
}
} else {
if (name.equalsIgnoreCase(T_UPGRADE) && value.equalsIgnoreCase(T_WS)) {
@ -356,7 +349,7 @@ void AsyncWebServerRequest::_parsePlainPostChar(uint8_t data) {
void AsyncWebServerRequest::_handleUploadByte(uint8_t data, bool last) {
_itemBuffer[_itemBufferIndex++] = data;
if (last || _itemBufferIndex == 1460) {
if (last || _itemBufferIndex == RESPONSE_STREAM_BUFFER_SIZE) {
// check if authenticated before calling the upload
if (_handler)
_handler->handleUpload(this, _itemFilename, _itemSize - _itemBufferIndex, _itemBuffer, _itemBufferIndex, false);
@ -460,7 +453,7 @@ void AsyncWebServerRequest::_parseMultipartPostByte(uint8_t data, bool last) {
if (_itemIsFile) {
if (_itemBuffer)
free(_itemBuffer);
_itemBuffer = (uint8_t*)malloc(1460);
_itemBuffer = (uint8_t*)malloc(RESPONSE_STREAM_BUFFER_SIZE);
if (_itemBuffer == NULL) {
_multiParseState = PARSE_ERROR;
return;
@ -514,7 +507,6 @@ void AsyncWebServerRequest::_parseMultipartPostByte(uint8_t data, bool last) {
_params.emplace_back(_itemName, _itemValue, true);
} else {
if (_itemSize) {
// check if authenticated before calling the upload
if (_handler)
_handler->handleUpload(this, _itemFilename, _itemSize - _itemBufferIndex, _itemBuffer, _itemBufferIndex, true);
_itemBufferIndex = 0;
@ -583,20 +575,22 @@ void AsyncWebServerRequest::_parseLine() {
// end of headers
_server->_rewriteRequest(this);
_server->_attachHandler(this);
_removeNotInterestingHeaders();
if (_expectingContinue) {
String response(T_HTTP_100_CONT);
_client->write(response.c_str(), response.length());
}
// check handler for authentication
if (_contentLength) {
_parseState = PARSE_REQ_BODY;
} else {
_parseState = PARSE_REQ_END;
if (_handler)
_handler->handleRequest(this);
else
send(501);
_server->_runChain(this, [this]() { return _handler ? _handler->_runChain(this, [this]() { _handler->handleRequest(this); }) : send(501); });
if (!_sent) {
if (!_response)
send(501, T_text_plain, "Handler did not handle the request");
_client->setRxTimeout(0);
_response->_respond(this);
_sent = true;
}
}
} else
_parseReqHeader();
@ -649,6 +643,21 @@ const AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) const {
return &(*std::next(_headers.cbegin(), num));
}
size_t AsyncWebServerRequest::getHeaderNames(std::vector<const char*>& names) const {
const size_t size = _headers.size();
names.reserve(size);
for (const auto& h : _headers) {
names.push_back(h.name().c_str());
}
return size;
}
bool AsyncWebServerRequest::removeHeader(const char* name) {
const size_t size = _headers.size();
_headers.remove_if([name](const AsyncWebHeader& header) { return header.name().equalsIgnoreCase(name); });
return size != _headers.size();
}
size_t AsyncWebServerRequest::params() const {
return _params.size();
}
@ -683,9 +692,25 @@ const AsyncWebParameter* AsyncWebServerRequest::getParam(size_t num) const {
return &(*std::next(_params.cbegin(), num));
}
void AsyncWebServerRequest::addInterestingHeader(const char* name) {
if (std::none_of(std::begin(_interestingHeaders), std::end(_interestingHeaders), [&name](const String& str) { return str.equalsIgnoreCase(name); }))
_interestingHeaders.emplace_back(name);
const String& AsyncWebServerRequest::getAttribute(const char* name, const String& defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second : defaultValue;
}
bool AsyncWebServerRequest::getAttribute(const char* name, bool defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second == "1" : defaultValue;
}
long AsyncWebServerRequest::getAttribute(const char* name, long defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toInt() : defaultValue;
}
float AsyncWebServerRequest::getAttribute(const char* name, float defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toFloat() : defaultValue;
}
double AsyncWebServerRequest::getAttribute(const char* name, double defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toDouble() : defaultValue;
}
AsyncWebServerResponse* AsyncWebServerRequest::beginResponse(int code, const char* contentType, const char* content, AwsTemplateProcessor callback) {
@ -728,38 +753,35 @@ AsyncResponseStream* AsyncWebServerRequest::beginResponseStream(const char* cont
return new AsyncResponseStream(contentType, bufferSize);
}
#ifdef ESP8266
AsyncWebServerResponse* AsyncWebServerRequest::beginResponse(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback) {
AsyncWebServerResponse* AsyncWebServerRequest::beginResponse_P(int code, const String& contentType, PGM_P content, AwsTemplateProcessor callback) {
return new AsyncProgmemResponse(code, contentType, (const uint8_t*)content, strlen_P(content), callback);
}
#endif
void AsyncWebServerRequest::send(AsyncWebServerResponse* response) {
if (_sent)
return;
if (_response)
delete _response;
_response = response;
if (_response == NULL) {
_client->close(true);
_onDisconnect();
_sent = true;
return;
}
if (!_response->_sourceValid()) {
delete response;
_response = NULL;
if (!_response->_sourceValid())
send(500);
} else {
_client->setRxTimeout(0);
_response->_respond(this);
}
}
void AsyncWebServerRequest::redirect(const char* url) {
AsyncWebServerResponse* response = beginResponse(302);
void AsyncWebServerRequest::redirect(const char* url, int code) {
AsyncWebServerResponse* response = beginResponse(code);
response->addHeader(T_LOCATION, url);
send(response);
}
bool AsyncWebServerRequest::authenticate(const char* username, const char* password, const char* realm, bool passwordIsHash) {
if (_authorization.length()) {
if (_isDigest)
if (_authMethod == AsyncAuthType::AUTH_DIGEST)
return checkDigestAuthentication(_authorization.c_str(), methodToString(), username, password, realm, passwordIsHash, NULL, NULL, NULL);
else if (!passwordIsHash)
return checkBasicAuthentication(_authorization.c_str(), username, password);
@ -773,7 +795,7 @@ bool AsyncWebServerRequest::authenticate(const char* hash) {
if (!_authorization.length() || hash == NULL)
return false;
if (_isDigest) {
if (_authMethod == AsyncAuthType::AUTH_DIGEST) {
String hStr = String(hash);
int separator = hStr.indexOf(':');
if (separator <= 0)
@ -788,23 +810,45 @@ bool AsyncWebServerRequest::authenticate(const char* hash) {
return checkDigestAuthentication(_authorization.c_str(), methodToString(), username.c_str(), hStr.c_str(), realm.c_str(), true, NULL, NULL, NULL);
}
// Basic Auth, Bearer Auth, or other
return (_authorization.equals(hash));
}
void AsyncWebServerRequest::requestAuthentication(const char* realm, bool isDigest) {
AsyncWebServerResponse* r = beginResponse(401);
if (!isDigest && realm == NULL) {
r->addHeader(T_WWW_AUTH, T_BASIC_REALM_LOGIN_REQ);
} else if (!isDigest) {
String header(T_BASIC_REALM);
header.concat(realm);
header += '"';
r->addHeader(T_WWW_AUTH, header.c_str());
} else {
String header(T_DIGEST_);
header.concat(requestDigestAuthentication(realm));
r->addHeader(T_WWW_AUTH, header.c_str());
void AsyncWebServerRequest::requestAuthentication(AsyncAuthType method, const char* realm, const char* _authFailMsg) {
if (!realm)
realm = T_LOGIN_REQ;
AsyncWebServerResponse* r = _authFailMsg ? beginResponse(401, T_text_html, _authFailMsg) : beginResponse(401);
switch (method) {
case AsyncAuthType::AUTH_BASIC: {
String header;
header.reserve(strlen(T_BASIC_REALM) + strlen(realm) + 1);
header.concat(T_BASIC_REALM);
header.concat(realm);
header.concat('"');
r->addHeader(T_WWW_AUTH, header.c_str());
break;
}
case AsyncAuthType::AUTH_DIGEST: {
constexpr size_t len = strlen(T_DIGEST_) + strlen(T_realm__) + strlen(T_auth_nonce) + 32 + strlen(T__opaque) + 32 + 1;
String header;
header.reserve(len + strlen(realm));
header.concat(T_DIGEST_);
header.concat(T_realm__);
header.concat(realm);
header.concat(T_auth_nonce);
header.concat(genRandomMD5());
header.concat(T__opaque);
header.concat(genRandomMD5());
header.concat((char)0x22); // '"'
r->addHeader(T_WWW_AUTH, header.c_str());
break;
}
default:
break;
}
send(r);
}

View File

@ -109,6 +109,8 @@ const char* AsyncWebServerResponse::responseCodeToString(int code) {
return T_HTTP_CODE_416;
case 417:
return T_HTTP_CODE_417;
case 429:
return T_HTTP_CODE_429;
case 500:
return T_HTTP_CODE_500;
case 501:
@ -196,6 +198,8 @@ const __FlashStringHelper* AsyncWebServerResponse::responseCodeToString(int code
return FPSTR(T_HTTP_CODE_416);
case 417:
return FPSTR(T_HTTP_CODE_417);
case 429:
return FPSTR(T_HTTP_CODE_429);
case 500:
return FPSTR(T_HTTP_CODE_500);
case 501:
@ -318,7 +322,6 @@ void AsyncWebServerResponse::_assembleHead(String& buffer, uint8_t version) {
buffer.concat(header.value());
buffer.concat(T_rn);
}
_headers.clear();
buffer.concat(T_rn);
_headLength = buffer.length();
@ -492,7 +495,7 @@ size_t AsyncAbstractResponse::_ack(AsyncWebServerRequest* request, size_t len, u
free(buf);
return 0;
}
outLen = sprintf((char*)buf+headLen, "%04x", readLen) + headLen;
outLen = sprintf((char*)buf + headLen, "%04x", readLen) + headLen;
buf[outLen++] = '\r';
buf[outLen++] = '\n';
outLen += readLen;

View File

@ -24,19 +24,19 @@
using namespace asyncsrv;
bool ON_STA_FILTER(AsyncWebServerRequest* request) {
#ifndef CONFIG_IDF_TARGET_ESP32H2
#ifndef CONFIG_IDF_TARGET_ESP32H2
return WiFi.localIP() == request->client()->localIP();
#else
#else
return false;
#endif
#endif
}
bool ON_AP_FILTER(AsyncWebServerRequest* request) {
#ifndef CONFIG_IDF_TARGET_ESP32H2
#ifndef CONFIG_IDF_TARGET_ESP32H2
return WiFi.localIP() != request->client()->localIP();
#else
#else
return false;
#endif
#endif
}
#ifndef HAVE_FS_FILE_OPEN_MODE
@ -155,7 +155,6 @@ void AsyncWebServer::_attachHandler(AsyncWebServerRequest* request) {
}
}
request->addInterestingHeader(T_ANY);
request->setHandler(_catchAllHandler);
}
@ -170,33 +169,6 @@ AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodCom
return *handler;
}
AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest, ArUploadHandlerFunction onUpload) {
AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler();
handler->setUri(uri);
handler->setMethod(method);
handler->onRequest(onRequest);
handler->onUpload(onUpload);
addHandler(handler);
return *handler;
}
AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, WebRequestMethodComposite method, ArRequestHandlerFunction onRequest) {
AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler();
handler->setUri(uri);
handler->setMethod(method);
handler->onRequest(onRequest);
addHandler(handler);
return *handler;
}
AsyncCallbackWebHandler& AsyncWebServer::on(const char* uri, ArRequestHandlerFunction onRequest) {
AsyncCallbackWebHandler* handler = new AsyncCallbackWebHandler();
handler->setUri(uri);
handler->onRequest(onRequest);
addHandler(handler);
return *handler;
}
AsyncStaticWebHandler& AsyncWebServer::serveStatic(const char* uri, fs::FS& fs, const char* path, const char* cache_control) {
AsyncStaticWebHandler* handler = new AsyncStaticWebHandler(uri, fs, path, cache_control);
addHandler(handler);

View File

@ -12,7 +12,7 @@ static constexpr const char* T_app_xform_urlencoded = "application/x-www-form-ur
static constexpr const char* T_AUTH = "Authorization";
static constexpr const char* T_BASIC = "Basic";
static constexpr const char* T_BASIC_REALM = "Basic realm=\"";
static constexpr const char* T_BASIC_REALM_LOGIN_REQ = "Basic realm=\"Login Required\"";
static constexpr const char* T_LOGIN_REQ = "Login Required";
static constexpr const char* T_BODY = "body";
static constexpr const char* T_Cache_Control = "Cache-Control";
static constexpr const char* T_chunked = "chunked";
@ -25,6 +25,7 @@ static constexpr const char* T_Content_Type = "Content-Type";
static constexpr const char* T_Cookie = "Cookie";
static constexpr const char* T_DIGEST = "Digest";
static constexpr const char* T_DIGEST_ = "Digest ";
static constexpr const char* T_BEARER = "Bearer";
static constexpr const char* T_ETag = "ETag";
static constexpr const char* T_EXPECT = "Expect";
static constexpr const char* T_HTTP_1_0 = "HTTP/1.0";
@ -137,6 +138,7 @@ static constexpr const char* T_HTTP_CODE_414 = "Request-URI Too Large";
static constexpr const char* T_HTTP_CODE_415 = "Unsupported Media Type";
static constexpr const char* T_HTTP_CODE_416 = "Requested range not satisfiable";
static constexpr const char* T_HTTP_CODE_417 = "Expectation Failed";
static constexpr const char* T_HTTP_CODE_429 = "Too Many Requests";
static constexpr const char* T_HTTP_CODE_500 = "Internal Server Error";
static constexpr const char* T_HTTP_CODE_501 = "Not Implemented";
static constexpr const char* T_HTTP_CODE_502 = "Bad Gateway";
@ -148,7 +150,6 @@ static constexpr const char* T_HTTP_CODE_ANY = "Unknown code";
// other
static constexpr const char* T__opaque = "\", opaque=\"";
static constexpr const char* T_13 = "13";
static constexpr const char* T_asyncesp = "asyncesp";
static constexpr const char* T_auth_nonce = "\", qop=\"auth\", nonce=\"";
static constexpr const char* T_cnonce = "cnonce";
static constexpr const char* T_data_ = "data: ";
@ -181,7 +182,7 @@ static const char T_app_xform_urlencoded[] PROGMEM = "application/x-www-form-url
static const char T_AUTH[] PROGMEM = "Authorization";
static const char T_BASIC[] PROGMEM = "Basic";
static const char T_BASIC_REALM[] PROGMEM = "Basic realm=\"";
static const char T_BASIC_REALM_LOGIN_REQ[] PROGMEM = "Basic realm=\"Login Required\"";
static const char T_LOGIN_REQ[] PROGMEM = "Login Required";
static const char T_BODY[] PROGMEM = "body";
static const char T_Cache_Control[] PROGMEM = "Cache-Control";
static const char T_chunked[] PROGMEM = "chunked";
@ -194,6 +195,7 @@ static const char T_Content_Type[] PROGMEM = "Content-Type";
static const char T_Cookie[] PROGMEM = "Cookie";
static const char T_DIGEST[] PROGMEM = "Digest";
static const char T_DIGEST_[] PROGMEM = "Digest ";
static const char T_BEARER[] PROGMEM = "Bearer";
static const char T_ETag[] PROGMEM = "ETag";
static const char T_EXPECT[] PROGMEM = "Expect";
static const char T_HTTP_1_0[] PROGMEM = "HTTP/1.0";
@ -306,6 +308,7 @@ static const char T_HTTP_CODE_414[] PROGMEM = "Request-URI Too Large";
static const char T_HTTP_CODE_415[] PROGMEM = "Unsupported Media Type";
static const char T_HTTP_CODE_416[] PROGMEM = "Requested range not satisfiable";
static const char T_HTTP_CODE_417[] PROGMEM = "Expectation Failed";
static const char T_HTTP_CODE_429[] PROGMEM = "Too Many Requests";
static const char T_HTTP_CODE_500[] PROGMEM = "Internal Server Error";
static const char T_HTTP_CODE_501[] PROGMEM = "Not Implemented";
static const char T_HTTP_CODE_502[] PROGMEM = "Bad Gateway";
@ -317,7 +320,6 @@ static const char T_HTTP_CODE_ANY[] PROGMEM = "Unknown code";
// other
static const char T__opaque[] PROGMEM = "\", opaque=\"";
static const char T_13[] PROGMEM = "13";
static const char T_asyncesp[] PROGMEM = "asyncesp";
static const char T_auth_nonce[] PROGMEM = "\", qop=\"auth\", nonce=\"";
static const char T_cnonce[] PROGMEM = "cnonce";
static const char T_data_[] PROGMEM = "data: ";