summaryrefslogtreecommitdiff
path: root/libraries/ESP_Async_WebServer/src/Middleware.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'libraries/ESP_Async_WebServer/src/Middleware.cpp')
-rw-r--r--libraries/ESP_Async_WebServer/src/Middleware.cpp287
1 files changed, 287 insertions, 0 deletions
diff --git a/libraries/ESP_Async_WebServer/src/Middleware.cpp b/libraries/ESP_Async_WebServer/src/Middleware.cpp
new file mode 100644
index 0000000..890303d
--- /dev/null
+++ b/libraries/ESP_Async_WebServer/src/Middleware.cpp
@@ -0,0 +1,287 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright 2016-2025 Hristo Gochkov, Mathieu Carbou, Emil Muratov
+
+#include "WebAuthentication.h"
+#include <ESPAsyncWebServer.h>
+
+AsyncMiddlewareChain::~AsyncMiddlewareChain() {
+ for (AsyncMiddleware *m : _middlewares) {
+ if (m->_freeOnRemoval) {
+ delete m;
+ }
+ }
+}
+
+void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) {
+ AsyncMiddlewareFunction *m = new AsyncMiddlewareFunction(fn);
+ m->_freeOnRemoval = true;
+ _middlewares.emplace_back(m);
+}
+
+void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware *middleware) {
+ if (middleware) {
+ _middlewares.emplace_back(middleware);
+ }
+}
+
+void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware *> middlewares) {
+ for (AsyncMiddleware *m : middlewares) {
+ addMiddleware(m);
+ }
+}
+
+bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware *middleware) {
+ // remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector.
+ const size_t size = _middlewares.size();
+ _middlewares.erase(
+ std::remove_if(
+ _middlewares.begin(), _middlewares.end(),
+ [middleware](AsyncMiddleware *m) {
+ if (m == middleware) {
+ if (m->_freeOnRemoval) {
+ delete m;
+ }
+ return true;
+ }
+ return false;
+ }
+ ),
+ _middlewares.end()
+ );
+ return size != _middlewares.size();
+}
+
+void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest *request, ArMiddlewareNext finalizer) {
+ if (!_middlewares.size()) {
+ return finalizer();
+ }
+ ArMiddlewareNext next;
+ std::list<AsyncMiddleware *>::iterator it = _middlewares.begin();
+ next = [this, &next, &it, request, finalizer]() {
+ if (it == _middlewares.end()) {
+ return finalizer();
+ }
+ AsyncMiddleware *m = *it;
+ it++;
+ return m->run(request, next);
+ };
+ return next();
+}
+
+void AsyncAuthenticationMiddleware::setUsername(const char *username) {
+ _username = username;
+ _hasCreds = _username.length() && _credentials.length();
+}
+
+void AsyncAuthenticationMiddleware::setPassword(const char *password) {
+ _credentials = password;
+ _hash = false;
+ _hasCreds = _username.length() && _credentials.length();
+}
+
+void AsyncAuthenticationMiddleware::setPasswordHash(const char *hash) {
+ _credentials = hash;
+ _hash = _credentials.length();
+ _hasCreds = _username.length() && _credentials.length();
+}
+
+bool AsyncAuthenticationMiddleware::generateHash() {
+ // ensure we have all the necessary data
+ if (!_hasCreds) {
+ return false;
+ }
+
+ // if we already have a hash, do nothing
+ if (_hash) {
+ return false;
+ }
+
+ switch (_authMethod) {
+ case AsyncAuthType::AUTH_DIGEST:
+ _credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
+ if (_credentials.length()) {
+ _hash = true;
+ return true;
+ } else {
+ return false;
+ }
+
+ case AsyncAuthType::AUTH_BASIC:
+ _credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
+ if (_credentials.length()) {
+ _hash = true;
+ return true;
+ } else {
+ return false;
+ }
+
+ default: return false;
+ }
+}
+
+bool AsyncAuthenticationMiddleware::allowed(AsyncWebServerRequest *request) const {
+ if (_authMethod == AsyncAuthType::AUTH_NONE) {
+ return true;
+ }
+
+ if (_authMethod == AsyncAuthType::AUTH_DENIED) {
+ return false;
+ }
+
+ if (!_hasCreds) {
+ return true;
+ }
+
+ return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
+}
+
+void AsyncAuthenticationMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
+}
+
+void AsyncHeaderFreeMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ std::list<const char *> toRemove;
+ for (auto &h : request->getHeaders()) {
+ bool keep = false;
+ for (const char *k : _toKeep) {
+ if (strcasecmp(h.name().c_str(), k) == 0) {
+ keep = true;
+ break;
+ }
+ }
+ if (!keep) {
+ toRemove.push_back(h.name().c_str());
+ }
+ }
+ for (const char *h : toRemove) {
+ request->removeHeader(h);
+ }
+ next();
+}
+
+void AsyncHeaderFilterMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it) {
+ request->removeHeader(*it);
+ }
+ next();
+}
+
+void AsyncLoggingMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ if (!isEnabled()) {
+ next();
+ return;
+ }
+ _out->print(F("* Connection from "));
+ _out->print(request->client()->remoteIP().toString());
+ _out->print(':');
+ _out->println(request->client()->remotePort());
+ _out->print('>');
+ _out->print(' ');
+ _out->print(request->methodToString());
+ _out->print(' ');
+ _out->print(request->url().c_str());
+ _out->print(F(" HTTP/1."));
+ _out->println(request->version());
+ for (auto &h : request->getHeaders()) {
+ if (h.value().length()) {
+ _out->print('>');
+ _out->print(' ');
+ _out->print(h.name());
+ _out->print(':');
+ _out->print(' ');
+ _out->println(h.value());
+ }
+ }
+ _out->println(F(">"));
+ uint32_t elapsed = millis();
+ next();
+ elapsed = millis() - elapsed;
+ AsyncWebServerResponse *response = request->getResponse();
+ if (response) {
+ _out->print(F("* Processed in "));
+ _out->print(elapsed);
+ _out->println(F(" ms"));
+ _out->print('<');
+ _out->print(F(" HTTP/1."));
+ _out->print(request->version());
+ _out->print(' ');
+ _out->print(response->code());
+ _out->print(' ');
+ _out->println(AsyncWebServerResponse::responseCodeToString(response->code()));
+ for (auto &h : response->getHeaders()) {
+ if (h.value().length()) {
+ _out->print('<');
+ _out->print(' ');
+ _out->print(h.name());
+ _out->print(':');
+ _out->print(' ');
+ _out->println(h.value());
+ }
+ }
+ _out->println('<');
+ } else {
+ _out->println(F("* Connection closed!"));
+ }
+}
+
+void AsyncCorsMiddleware::addCORSHeaders(AsyncWebServerResponse *response) {
+ response->addHeader(asyncsrv::T_CORS_ACAO, _origin.c_str());
+ response->addHeader(asyncsrv::T_CORS_ACAM, _methods.c_str());
+ response->addHeader(asyncsrv::T_CORS_ACAH, _headers.c_str());
+ response->addHeader(asyncsrv::T_CORS_ACAC, _credentials ? asyncsrv::T_TRUE : asyncsrv::T_FALSE);
+ response->addHeader(asyncsrv::T_CORS_ACMA, String(_maxAge).c_str());
+}
+
+void AsyncCorsMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ // Origin header ? => CORS handling
+ if (request->hasHeader(asyncsrv::T_CORS_O)) {
+ // check if this is a preflight request => handle it and return
+ if (request->method() == HTTP_OPTIONS) {
+ AsyncWebServerResponse *response = request->beginResponse(200);
+ addCORSHeaders(response);
+ request->send(response);
+ return;
+ }
+
+ // CORS request, no options => let the request pass and add CORS headers after
+ next();
+ AsyncWebServerResponse *response = request->getResponse();
+ if (response) {
+ addCORSHeaders(response);
+ }
+
+ } else {
+ // NO Origin header => no CORS handling
+ next();
+ }
+}
+
+bool AsyncRateLimitMiddleware::isRequestAllowed(uint32_t &retryAfterSeconds) {
+ uint32_t now = millis();
+
+ while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis) {
+ _requestTimes.pop_front();
+ }
+
+ _requestTimes.push_back(now);
+
+ if (_requestTimes.size() > _maxRequests) {
+ _requestTimes.pop_front();
+ retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
+ return false;
+ }
+
+ retryAfterSeconds = 0;
+ return true;
+}
+
+void AsyncRateLimitMiddleware::run(AsyncWebServerRequest *request, ArMiddlewareNext next) {
+ uint32_t retryAfterSeconds;
+ if (isRequestAllowed(retryAfterSeconds)) {
+ next();
+ } else {
+ AsyncWebServerResponse *response = request->beginResponse(429);
+ response->addHeader(asyncsrv::T_retry_after, retryAfterSeconds);
+ request->send(response);
+ }
+}