diff options
Diffstat (limited to 'libraries/ESP_Async_WebServer/src/Middleware.cpp')
| -rw-r--r-- | libraries/ESP_Async_WebServer/src/Middleware.cpp | 287 |
1 files changed, 287 insertions, 0 deletions
diff --git a/libraries/ESP_Async_WebServer/src/Middleware.cpp b/libraries/ESP_Async_WebServer/src/Middleware.cpp new file mode 100644 index 0000000..890303d --- /dev/null +++ b/libraries/ESP_Async_WebServer/src/Middleware.cpp @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Copyright 2016-2025 Hristo Gochkov, Mathieu Carbou, Emil Muratov + +#include "WebAuthentication.h" +#include <ESPAsyncWebServer.h> + +AsyncMiddlewareChain::~AsyncMiddlewareChain() { + for (AsyncMiddleware *m : _middlewares) { + if (m->_freeOnRemoval) { + delete m; + } + } +} + +void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) { + AsyncMiddlewareFunction *m = new AsyncMiddlewareFunction(fn); + m->_freeOnRemoval = true; + _middlewares.emplace_back(m); +} + +void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware *middleware) { + if (middleware) { + _middlewares.emplace_back(middleware); + } +} + +void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware *> middlewares) { + for (AsyncMiddleware *m : middlewares) { + addMiddleware(m); + } +} + +bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware *middleware) { + // remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector. + const size_t size = _middlewares.size(); + _middlewares.erase( + std::remove_if( + _middlewares.begin(), _middlewares.end(), + [middleware](AsyncMiddleware *m) { + if (m == middleware) { + if (m->_freeOnRemoval) { + delete m; + } + return true; + } + return false; + } + ), + _middlewares.end() + ); + return size != _middlewares.size(); +} + +void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest *request, ArMiddlewareNext finalizer) { + if (!_middlewares.size()) { + return finalizer(); + } + ArMiddlewareNext next; + std::list<AsyncMiddleware *>::iterator it = _middlewares.begin(); + next = [this, &next, &it, request, finalizer]() { + if (it == _middlewares.end()) { + return finalizer(); + } + AsyncMiddleware *m = *it; + it++; + return m->run(request, next); + }; + return next(); +} + +void AsyncAuthenticationMiddleware::setUsername(const char *username) { + _username = username; + _hasCreds = _username.length() && _credentials.length(); +} + +void AsyncAuthenticationMiddleware::setPassword(const char *password) { + _credentials = password; + _hash = false; + _hasCreds = _username.length() && _credentials.length(); +} + +void AsyncAuthenticationMiddleware::setPasswordHash(const char *hash) { + _credentials = hash; + _hash = _credentials.length(); + _hasCreds = _username.length() && _credentials.length(); +} + +bool AsyncAuthenticationMiddleware::generateHash() { + // ensure we have all the necessary data + if (!_hasCreds) { + return false; + } + + // if we already have a hash, do nothing + if (_hash) { + return false; + } + + switch (_authMethod) { + case AsyncAuthType::AUTH_DIGEST: + _credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str()); + if (_credentials.length()) { + _hash = true; + return true; + } else { + return false; + } + + case AsyncAuthType::AUTH_BASIC: + _credentials = generateBasicHash(_username.c_str(), _credentials.c_str()); + if (_credentials.length()) { + _hash = true; + return true; + } else { + return false; + } + + default: return false; + } +} + +bool AsyncAuthenticationMiddleware::allowed(AsyncWebServerRequest *request) const { + if (_authMethod == AsyncAuthType::AUTH_NONE) { + return true; + } + + if (_authMethod == AsyncAuthType::AUTH_DENIED) { + return false; + } + + if (!_hasCreds) { + return true; + } + + return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash); +} + +void AsyncAuthenticationMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str()); +} + +void AsyncHeaderFreeMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + std::list<const char *> toRemove; + for (auto &h : request->getHeaders()) { + bool keep = false; + for (const char *k : _toKeep) { + if (strcasecmp(h.name().c_str(), k) == 0) { + keep = true; + break; + } + } + if (!keep) { + toRemove.push_back(h.name().c_str()); + } + } + for (const char *h : toRemove) { + request->removeHeader(h); + } + next(); +} + +void AsyncHeaderFilterMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it) { + request->removeHeader(*it); + } + next(); +} + +void AsyncLoggingMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + if (!isEnabled()) { + next(); + return; + } + _out->print(F("* Connection from ")); + _out->print(request->client()->remoteIP().toString()); + _out->print(':'); + _out->println(request->client()->remotePort()); + _out->print('>'); + _out->print(' '); + _out->print(request->methodToString()); + _out->print(' '); + _out->print(request->url().c_str()); + _out->print(F(" HTTP/1.")); + _out->println(request->version()); + for (auto &h : request->getHeaders()) { + if (h.value().length()) { + _out->print('>'); + _out->print(' '); + _out->print(h.name()); + _out->print(':'); + _out->print(' '); + _out->println(h.value()); + } + } + _out->println(F(">")); + uint32_t elapsed = millis(); + next(); + elapsed = millis() - elapsed; + AsyncWebServerResponse *response = request->getResponse(); + if (response) { + _out->print(F("* Processed in ")); + _out->print(elapsed); + _out->println(F(" ms")); + _out->print('<'); + _out->print(F(" HTTP/1.")); + _out->print(request->version()); + _out->print(' '); + _out->print(response->code()); + _out->print(' '); + _out->println(AsyncWebServerResponse::responseCodeToString(response->code())); + for (auto &h : response->getHeaders()) { + if (h.value().length()) { + _out->print('<'); + _out->print(' '); + _out->print(h.name()); + _out->print(':'); + _out->print(' '); + _out->println(h.value()); + } + } + _out->println('<'); + } else { + _out->println(F("* Connection closed!")); + } +} + +void AsyncCorsMiddleware::addCORSHeaders(AsyncWebServerResponse *response) { + response->addHeader(asyncsrv::T_CORS_ACAO, _origin.c_str()); + response->addHeader(asyncsrv::T_CORS_ACAM, _methods.c_str()); + response->addHeader(asyncsrv::T_CORS_ACAH, _headers.c_str()); + response->addHeader(asyncsrv::T_CORS_ACAC, _credentials ? asyncsrv::T_TRUE : asyncsrv::T_FALSE); + response->addHeader(asyncsrv::T_CORS_ACMA, String(_maxAge).c_str()); +} + +void AsyncCorsMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + // Origin header ? => CORS handling + if (request->hasHeader(asyncsrv::T_CORS_O)) { + // check if this is a preflight request => handle it and return + if (request->method() == HTTP_OPTIONS) { + AsyncWebServerResponse *response = request->beginResponse(200); + addCORSHeaders(response); + request->send(response); + return; + } + + // CORS request, no options => let the request pass and add CORS headers after + next(); + AsyncWebServerResponse *response = request->getResponse(); + if (response) { + addCORSHeaders(response); + } + + } else { + // NO Origin header => no CORS handling + next(); + } +} + +bool AsyncRateLimitMiddleware::isRequestAllowed(uint32_t &retryAfterSeconds) { + uint32_t now = millis(); + + while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis) { + _requestTimes.pop_front(); + } + + _requestTimes.push_back(now); + + if (_requestTimes.size() > _maxRequests) { + _requestTimes.pop_front(); + retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1; + return false; + } + + retryAfterSeconds = 0; + return true; +} + +void AsyncRateLimitMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) { + uint32_t retryAfterSeconds; + if (isRequestAllowed(retryAfterSeconds)) { + next(); + } else { + AsyncWebServerResponse *response = request->beginResponse(429); + response->addHeader(asyncsrv::T_retry_after, retryAfterSeconds); + request->send(response); + } +} |
