254 lines
7.5 KiB
C++
254 lines
7.5 KiB
C++
#include "WebAuthentication.h"
|
|
#include <ESPAsyncWebServer.h>
|
|
|
|
// Tears down the chain. Only middlewares flagged _freeOnRemoval (for example the
// wrappers created by addMiddleware(ArMiddlewareCallback)) are deleted here;
// every other registered middleware is left for its owner to release.
AsyncMiddlewareChain::~AsyncMiddlewareChain() {
  for (auto it = _middlewares.begin(); it != _middlewares.end(); ++it) {
    AsyncMiddleware* middleware = *it;
    if (!middleware->_freeOnRemoval)
      continue;
    delete middleware;
  }
}
|
|
|
|
// Wraps a plain callback in a heap-allocated AsyncMiddlewareFunction and appends
// it to the end of the chain. The wrapper is flagged _freeOnRemoval so the chain
// itself deletes it on removal and in its destructor.
void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) {
  auto* wrapper = new AsyncMiddlewareFunction(fn);
  wrapper->_freeOnRemoval = true;
  _middlewares.push_back(wrapper);
}
|
|
|
|
// Appends an existing middleware to the end of the chain. Null pointers are
// silently ignored; the middleware's _freeOnRemoval flag is left as-is.
void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware* middleware) {
  if (middleware == nullptr)
    return;
  _middlewares.push_back(middleware);
}
|
|
|
|
// Appends each middleware in `middlewares`, preserving their order, by
// delegating to addMiddleware(AsyncMiddleware*).
void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware*> middlewares) {
  for (size_t i = 0; i < middlewares.size(); ++i)
    addMiddleware(middlewares[i]);
}
|
|
|
|
// Removes every occurrence of `middleware` from the chain.
// If the middleware is flagged _freeOnRemoval it is deleted exactly once, after
// all of its entries have been unlinked.
// Returns true if at least one entry was removed, false otherwise (including
// when `middleware` is null).
bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware* middleware) {
  if (middleware == nullptr)
    return false;

  const size_t sizeBefore = _middlewares.size();
  // _middlewares is a std::list: its member remove() unlinks all matching nodes
  // in place, with no separate erase step.
  _middlewares.remove(middleware);
  const bool removed = _middlewares.size() != sizeBefore;

  // Delete outside of any predicate and only once. Deleting inside a per-element
  // predicate would read and free the same object again if the pointer had been
  // registered more than once.
  if (removed && middleware->_freeOnRemoval)
    delete middleware;

  return removed;
}
|
|
|
|
// Runs the registered middlewares in insertion order and finally `finalizer`.
// Each middleware receives a callback that advances to the following middleware,
// or to `finalizer` once the list is exhausted. A middleware that does not invoke
// the callback stops the chain there.
// NOTE(review): the callback captures local variables by reference, so it must be
// invoked before this function returns (i.e. synchronously).
void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) {
  if (_middlewares.empty())
    return finalizer();

  auto cursor = _middlewares.begin();
  ArMiddlewareNext step;
  step = [this, &step, &cursor, request, finalizer]() {
    if (cursor == _middlewares.end())
      return finalizer();
    AsyncMiddleware* current = *cursor;
    ++cursor;
    return current->run(request, step);
  };
  return step();
}
|
|
|
|
// Stores the username and recomputes _hasCreds. _hasCreds is true only when both
// the username and the stored password/hash are non-empty.
void AuthenticationMiddleware::setUsername(const char* username) {
  _username = username;
  const bool haveUser = _username.length() != 0;
  const bool haveSecret = _credentials.length() != 0;
  _hasCreds = haveUser && haveSecret;
}
|
|
|
|
// Stores a plain-text password, marks the credentials as not yet hashed, and
// recomputes _hasCreds (true only when username and password are both non-empty).
void AuthenticationMiddleware::setPassword(const char* password) {
  _hash = false;
  _credentials = password;
  _hasCreds = (_credentials.length() > 0) && (_username.length() > 0);
}
|
|
|
|
// Stores an already-computed credential hash. Setting _hash makes generateHash()
// leave it alone. Also recomputes _hasCreds (true only when username and hash are
// both non-empty).
void AuthenticationMiddleware::setPasswordHash(const char* hash) {
  _hash = true;
  _credentials = hash;
  _hasCreds = !(_username.length() == 0 || _credentials.length() == 0);
}
|
|
|
|
bool AuthenticationMiddleware::generateHash() {
|
|
// ensure we have all the necessary data
|
|
if (!_hasCreds)
|
|
return false;
|
|
|
|
// if we already have a hash, do nothing
|
|
if (_hash)
|
|
return false;
|
|
|
|
switch (_authMethod) {
|
|
case AsyncAuthType::AUTH_DIGEST:
|
|
_credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
|
|
_hash = true;
|
|
return true;
|
|
|
|
case AsyncAuthType::AUTH_BASIC:
|
|
_credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
|
|
_hash = true;
|
|
return true;
|
|
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Decides whether `request` may proceed.
// - AUTH_NONE always passes.
// - Any other method requires complete credentials (otherwise every request is
//   rejected) and a successful request->authenticate() against the stored
//   username, credentials and realm.
bool AuthenticationMiddleware::allowed(AsyncWebServerRequest* request) {
  if (_authMethod != AsyncAuthType::AUTH_NONE) {
    return _hasCreds && request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
  }
  return true;
}
|
|
|
|
// Continues the chain for allowed requests. Otherwise it calls
// request->requestAuthentication() with the configured method, realm and failure
// message, and does not call next().
void AuthenticationMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  if (!allowed(request)) {
    request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
    return;
  }
  next();
}
|
|
|
|
// Removes every request header whose name does not appear in _toKeep (compared
// case-insensitively), then continues the chain.
void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  std::vector<const char*> names;
  request->getHeaderNames(names);

  // True when `name` matches one of the whitelisted header names.
  auto isKept = [this](const char* name) {
    for (const char* keep : _toKeep)
      if (strcasecmp(name, keep) == 0)
        return true;
    return false;
  };

  for (const char* name : names)
    if (!isKept(name))
      request->removeHeader(name);

  next();
}
|
|
|
|
// Removes each header listed in _toRemove from the request, then continues the chain.
void HeaderFilterMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  for (auto& headerName : _toRemove)
    request->removeHeader(headerName);
  next();
}
|
|
|
|
// Writes a trace of the request/response exchange to _out.
// - First the client address, the request line and every non-empty request
//   header, each prefixed with '>'.
// - Then it runs the rest of the chain, timing it with millis().
// - If a response is attached afterwards, it prints the elapsed time, the status
//   line and every non-empty response header, each prefixed with '<'. Otherwise
//   it prints "* Connection closed!".
// When isEnabled() is false it only forwards to next().
void LoggingMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  if (!isEnabled())
    return next();

  // ---- incoming side ----
  _out->print(F("* Connection from "));
  _out->print(request->client()->remoteIP().toString());
  _out->print(':');
  _out->println(request->client()->remotePort());

  _out->print('>');
  _out->print(' ');
  _out->print(request->methodToString());
  _out->print(' ');
  _out->print(request->url().c_str());
  _out->print(F(" HTTP/1."));
  _out->println(request->version());

  for (auto& header : request->getHeaders()) {
    if (!header.value().length())
      continue;
    _out->print('>');
    _out->print(' ');
    _out->print(header.name());
    _out->print(':');
    _out->print(' ');
    _out->println(header.value());
  }
  _out->println(F(">"));

  // Time only the downstream part of the chain.
  const uint32_t startedAt = millis();
  next();
  const uint32_t elapsedMs = millis() - startedAt;

  // ---- outgoing side ----
  AsyncWebServerResponse* response = request->getResponse();
  if (!response) {
    // No response object is attached to the request after the chain ran.
    _out->println(F("* Connection closed!"));
    return;
  }

  _out->print(F("* Processed in "));
  _out->print(elapsedMs);
  _out->println(F(" ms"));

  _out->print('<');
  _out->print(F(" HTTP/1."));
  _out->print(request->version());
  _out->print(' ');
  _out->print(response->code());
  _out->print(' ');
  _out->println(AsyncWebServerResponse::responseCodeToString(response->code()));

  for (auto& header : response->getHeaders()) {
    if (!header.value().length())
      continue;
    _out->print('<');
    _out->print(' ');
    _out->print(header.name());
    _out->print(':');
    _out->print(' ');
    _out->println(header.value());
  }
  _out->println('<');
}
|
|
|
|
// Adds the configured CORS headers to `response`: allowed origin, methods and
// headers, the credentials flag ("true"/"false"), and the max-age value.
void CorsMiddleware::addCORSHeaders(AsyncWebServerResponse* response) {
  const String maxAge(_maxAge);

  response->addHeader(F("Access-Control-Allow-Origin"), _origin.c_str());
  response->addHeader(F("Access-Control-Allow-Methods"), _methods.c_str());
  response->addHeader(F("Access-Control-Allow-Headers"), _headers.c_str());
  response->addHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false"));
  response->addHeader(F("Access-Control-Max-Age"), maxAge.c_str());
}
|
|
|
|
// CORS handling, applied only to requests that carry an Origin header:
//  - an OPTIONS request is answered immediately with 200 plus the CORS headers,
//    and the chain is not continued;
//  - any other method continues the chain, then gets the CORS headers appended
//    to whatever response ends up attached to the request.
// Requests without an Origin header simply continue the chain.
void CorsMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  if (!request->hasHeader(F("Origin"))) {
    next();
    return;
  }

  if (request->method() == HTTP_OPTIONS) {
    AsyncWebServerResponse* preflight = request->beginResponse(200);
    addCORSHeaders(preflight);
    request->send(preflight);
    return;
  }

  next();
  if (AsyncWebServerResponse* response = request->getResponse())
    addCORSHeaders(response);
}
|
|
|
|
bool RateLimitMiddleware::isRequestAllowed(uint32_t& retryAfterSeconds) {
|
|
uint32_t now = millis();
|
|
|
|
while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis)
|
|
_requestTimes.pop_front();
|
|
|
|
_requestTimes.push_back(now);
|
|
|
|
if (_requestTimes.size() > _maxRequests) {
|
|
_requestTimes.pop_front();
|
|
retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
|
|
return false;
|
|
}
|
|
|
|
retryAfterSeconds = 0;
|
|
return true;
|
|
}
|
|
|
|
// Continues the chain when the rate limiter admits the request. Otherwise it
// answers with 429 and a Retry-After header, and does not call next().
void RateLimitMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
  uint32_t waitSeconds = 0;
  if (!isRequestAllowed(waitSeconds)) {
    AsyncWebServerResponse* tooMany = request->beginResponse(429);
    tooMany->addHeader(F("Retry-After"), waitSeconds);
    request->send(tooMany);
    return;
  }
  next();
}
|