diff --git a/src/proxy/tinymitm/proxy.cpp b/src/proxy/tinymitm/proxy.cpp new file mode 100644 index 0000000..dc51660 --- /dev/null +++ b/src/proxy/tinymitm/proxy.cpp @@ -0,0 +1,640 @@ + +#include "proxy.h" + +#include +#include + +#if defined(_WIN64) || defined(_WIN32) + #define FD_SETSIZE 1024 + #define NOMINMAX + + #include + #include + + #define CLOSE_SOCKET closesocket + #define SHUT_RDWR SD_BOTH +#endif + +#include +#include +#include "ssl.h" + +#include "raai-helper.h" + +/* + RAAI helpers +*/ +using WOLF_ptr = std::unique_ptr; +using AddrInfo_Ptr = std::unique_ptr; + +struct AutoSocket +{ + SOCKET s; + AutoSocket(SOCKET val = INVALID_SOCKET) : s(val) {} + ~AutoSocket() + { + if (s != INVALID_SOCKET) CLOSE_SOCKET(s); + } + operator SOCKET() const { return s; } + SOCKET release() + { + SOCKET tmp = s; + s = INVALID_SOCKET; + return tmp; + } +}; + +/* + HTTPStream +*/ +struct HttpStream +{ + std::string buffer; + bool isReceivingBody = false; + bool isChunked = false; + long long contentLength = -1; + size_t headersEnd = std::string::npos; + int statusCode = 0; + size_t currentChunkIdx = 0; + std::string payload; + + void reset() + { + isReceivingBody = false; + isChunked = false; + contentLength = -1; + headersEnd = std::string::npos; + statusCode = 0; + currentChunkIdx = 0; + payload.clear(); + } + + bool parseHeaders() + { + headersEnd = buffer.find("\r\n\r\n"); + if (headersEnd == std::string::npos) return false; + std::string headers = buffer.substr(0, headersEnd + 4); + std::string te = TinyMITMProxy::getHeader(headers, "Transfer-Encoding"); + std::transform(te.begin(), te.end(), te.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + isChunked = (te.find("chunked") != std::string::npos); + if (!isChunked) + { + std::string cl = TinyMITMProxy::getHeader(headers, "Content-Length"); + contentLength = TinyMITMProxy::stollSafe(cl, -1); + } + if (headers.compare(0, 5, "HTTP/") == 0) + { + size_t space = headers.find(' '); + if (space != std::string::npos) + statusCode = static_cast(TinyMITMProxy::stollSafe(headers.substr(space + 1, 3))); + } + isReceivingBody = true; + return true; + } +}; + +/* + platform specific stuff +*/ +void setNonBlocking(SOCKET s) +{ +#ifdef _WIN32 + unsigned long mode = 1; + ioctlsocket(s, FIONBIO, &mode); +#else + fcntl(s, F_SETFL, fcntl(s, F_GETFL, 0) | O_NONBLOCK); +#endif +} + +/* + TinyMITMProxy implementation +*/ +TinyMITMProxy::~TinyMITMProxy() +{ + shutdown(); +} + +bool TinyMITMProxy::init() +{ + _running = true; + + if (!_certManager.init()) return false; + + // wolfssl setup + wolfSSL_Init(); + + _clientCtx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); + if (!_clientCtx) return false; + + wolfSSL_CTX_set_verify(_clientCtx, WOLFSSL_VERIFY_NONE, nullptr); + + wolfSSL_CTX_set_alpn_select_cb( + _clientCtx, + [](WOLFSSL* /*ssl*/, const unsigned char** out, unsigned char* outLen, const unsigned char* /*in*/, + unsigned int /*inLen*/, void* /*arg*/) { + static const unsigned char forcedProtocol[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'}; + *out = forcedProtocol; + *outLen = sizeof(forcedProtocol); + return 0; + }, + nullptr); + + // socket setup +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false; +#endif + + _listenSocket = socket(AF_INET, SOCK_STREAM, 0); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(_port); + addr.sin_addr.s_addr = INADDR_ANY; + + if (bind(_listenSocket, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) return false; + listen(_listenSocket, SOMAXCONN); + + // handler threads + for (unsigned char i = 0; i < _threadCount; i++) + { + _poolThreads.emplace_back([this]() { + while (_running) + { + SOCKET client; + std::unique_lock lock(_queueMutex); + _queueCond.wait(lock, [this]() { return !_clientQueue.empty() || !_running; }); + if (!_running && _clientQueue.empty()) return; + client = _clientQueue.front(); + _clientQueue.pop(); + this->handleClient(client); + } + }); + } + + // dispatcher thread + _dispatchThread = std::thread([this] { + while (_running) + { + SOCKET client = accept(_listenSocket, nullptr, nullptr); + if (client == INVALID_SOCKET) continue; + { + std::lock_guard lock(_queueMutex); + _clientQueue.push(client); + } + _queueCond.notify_one(); + } + }); + + return true; +} + +void TinyMITMProxy::shutdown() +{ + if (!_running) return; + _running = false; + + _queueCond.notify_all(); + + if (_listenSocket != INVALID_SOCKET) + { + ::shutdown(_listenSocket, SHUT_RDWR); + CLOSE_SOCKET(_listenSocket); + _listenSocket = INVALID_SOCKET; + } + + if (_dispatchThread.joinable()) _dispatchThread.join(); + + for (auto& t : _poolThreads) + if (t.joinable()) t.join(); + + if (_clientCtx) + { + wolfSSL_CTX_free(_clientCtx); + _clientCtx = nullptr; + } + + wolfSSL_Cleanup(); + +#ifdef _WIN32 + WSACleanup(); +#endif +} + +void TinyMITMProxy::handleClient(SOCKET clientSocket) +{ + AutoSocket clientGuard(clientSocket); + + auto bufPtr = std::make_unique(TINYMITM_CLIENT_BUFF_SIZE); + char* buf = bufPtr.get(); + + /* + initial CONNECT peek + */ + int n = recv(clientGuard, buf, TINYMITM_CLIENT_BUFF_SIZE - 1, 0); + if (n <= 0) return; + buf[n] = '\0'; + + std::string req(buf); + if (req.find("CONNECT ") != 0) return; + + /* + port parsing + */ + size_t space = req.find(' ', 8); + std::string fullHost = req.substr(8, space - 8); + size_t colon = fullHost.find(':'); + std::string host = (colon != std::string::npos) ? fullHost.substr(0, colon) : fullHost; + std::string port = (colon != std::string::npos) ? fullHost.substr(colon + 1) : "443"; + + /* + remote connection + */ + addrinfo hints{}, *rawRes; + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + if (getaddrinfo(host.c_str(), port.c_str(), &hints, &rawRes) != 0) return; + AddrInfo_Ptr res(rawRes); + + AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, 0)); + if (connect(remoteGuard, res->ai_addr, static_cast(res->ai_addrlen)) != 0) return; + + const char* connEstablished = "HTTP/1.1 200 Connection Established\r\n\r\n"; + send(clientGuard, connEstablished, static_cast(strlen(connEstablished)), 0); + + /* + wolfss setup + */ + WOLFSSL_CTX* hostCtx = _certManager.createHostContext(host); + if (!hostCtx) return; + + WOLF_ptr clientSSL(wolfSSL_new(hostCtx)); + WOLF_ptr remoteSSL(wolfSSL_new(_clientCtx)); + + wolfSSL_set_fd(clientSSL.get(), (int)clientGuard); + wolfSSL_set_fd(remoteSSL.get(), (int)remoteGuard); + + char alpnList[] = "\x08http/1.1"; + wolfSSL_UseALPN(remoteSSL.get(), alpnList, sizeof(alpnList) - 1, 0); + wolfSSL_UseSNI(remoteSSL.get(), WOLFSSL_SNI_HOST_NAME, host.c_str(), (unsigned short)host.size()); + + setNonBlocking(clientGuard); + setNonBlocking(remoteGuard); + + if (!doHandshake(clientSSL.get(), clientGuard, true)) return; + if (!doHandshake(remoteSSL.get(), remoteGuard, false)) return; + + /* + traffic loop + */ + HttpStream clientStream, serverStream; + std::deque pendingUrls; + bool tunnelMode = false; + + while (_running) + { + + fd_set r_fds; + FD_ZERO(&r_fds); + FD_SET(clientGuard, &r_fds); + FD_SET(remoteGuard, &r_fds); + + struct timeval tv{0, 50000}; + bool hasBuffered = (wolfSSL_pending(clientSSL.get()) > 0 || wolfSSL_pending(remoteSSL.get()) > 0); + + if (!hasBuffered) + { +#ifdef _WIN32 + int nfds = 0; +#else + int nfds = static_cast(std::max(clientGuard.s, remoteGuard.s)) + 1; +#endif + if (select(nfds, &r_fds, nullptr, nullptr, &tv) < 0) break; + } + + /* + client -> server + */ + if (FD_ISSET(clientGuard, &r_fds) || wolfSSL_pending(clientSSL.get())) + { + int rd = wolfSSL_read(clientSSL.get(), buf, TINYMITM_CLIENT_BUFF_SIZE); + if (rd <= 0) + { + if (wolfSSL_get_error(clientSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ) break; + } + else + { + if (tunnelMode) + { + wolfSSL_write(remoteSSL.get(), buf, rd); + } + else + { + clientStream.buffer.append(buf, rd); + while (true) + { + if (!clientStream.isReceivingBody) + { + if (!clientStream.parseHeaders()) break; + std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4); + std::string path = "/"; + size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1); + if (s1 != std::string::npos && s2 != std::string::npos) + path = headers.substr(s1 + 1, s2 - s1 - 1); + pendingUrls.push_back("https://" + host + path); + } + + std::string fullBody; + size_t totalRequestSize = 0; + bool complete = false; + + if (clientStream.isChunked) + { + if (clientStream.currentChunkIdx == 0) + clientStream.currentChunkIdx = clientStream.headersEnd + 4; + while (clientStream.currentChunkIdx < clientStream.buffer.size()) + { + size_t idx = clientStream.currentChunkIdx; + size_t le = clientStream.buffer.find("\r\n", idx); + if (le == std::string::npos) break; + long long cs = stollSafe(clientStream.buffer.substr(idx, le - idx), -1, 16); + if (cs < 0 || idx + (le - idx) + 2 + cs + 2 > clientStream.buffer.size()) break; + if (cs > 0) clientStream.payload.append(clientStream.buffer, le + 2, cs); + clientStream.currentChunkIdx = le + 2 + cs + 2; + if (cs == 0) + { + fullBody = std::move(clientStream.payload); + complete = true; + totalRequestSize = clientStream.currentChunkIdx; + break; + } + } + } + else + { + long long cl = (clientStream.contentLength < 0) ? 0 : clientStream.contentLength; + if (clientStream.buffer.size() >= (clientStream.headersEnd + 4 + cl)) + { + fullBody = clientStream.buffer.substr(clientStream.headersEnd + 4, cl); + complete = true; + totalRequestSize = clientStream.headersEnd + 4 + cl; + } + } + + if (complete) + { + std::string url = pendingUrls.back(); + std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4); + + removeHeader(headers, "Accept-Encoding"); + headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n"); + + bool blockOutgoing = false; + onClientRequest.run(url, fullBody, headers, blockOutgoing); + + if (blockOutgoing) + { + std::string mockHeaders = "HTTP/1.1 500 Internal Server Error\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n\r\n"; + std::string mockBody = "Request blocked by proxy."; + + onServerResponse.run(url, mockBody, mockHeaders, true); + + removeHeader(mockHeaders, "Content-Length"); + removeHeader(mockHeaders, "Transfer-Encoding"); + mockHeaders.insert(mockHeaders.size() - 2, + "Content-Length: " + std::to_string(mockBody.size()) + "\r\n"); + + std::string packet = mockHeaders + mockBody; + if (wolfSSL_write(clientSSL.get(), packet.data(), (int)packet.size()) <= 0) break; + + if (!pendingUrls.empty()) pendingUrls.pop_back(); + } + else + { + removeHeader(headers, "Transfer-Encoding"); + removeHeader(headers, "Content-Length"); + headers.insert(headers.size() - 2, + "Content-Length: " + std::to_string(fullBody.size()) + "\r\n"); + + if (wolfSSL_write(remoteSSL.get(), headers.data(), (int)headers.size()) <= 0) break; + if (wolfSSL_write(remoteSSL.get(), fullBody.data(), (int)fullBody.size()) <= 0) break; + } + + clientStream.buffer.erase(0, totalRequestSize); + clientStream.reset(); + } + } + } + } + } + + /* + server -> client + */ + if (FD_ISSET(remoteGuard, &r_fds) || wolfSSL_pending(remoteSSL.get())) + { + int rd = wolfSSL_read(remoteSSL.get(), buf, TINYMITM_CLIENT_BUFF_SIZE); + bool closed = (rd <= 0 && wolfSSL_get_error(remoteSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ); + + if (rd > 0) + { + if (tunnelMode) + { + wolfSSL_write(clientSSL.get(), buf, rd); + } + else + { + serverStream.buffer.append(buf, rd); + while (true) + { + if (!serverStream.isReceivingBody && !serverStream.parseHeaders()) break; + + std::string fullBody; + size_t totalResponseSize = 0; + bool complete = false; + + if (serverStream.statusCode == 204 || serverStream.statusCode == 304 || + (serverStream.statusCode >= 100 && serverStream.statusCode < 200)) + { + complete = true; + totalResponseSize = serverStream.headersEnd + 4; + } + else if (serverStream.isChunked) + { + if (serverStream.currentChunkIdx == 0) + serverStream.currentChunkIdx = serverStream.headersEnd + 4; + while (serverStream.currentChunkIdx < serverStream.buffer.size()) + { + size_t idx = serverStream.currentChunkIdx; + size_t le = serverStream.buffer.find("\r\n", idx); + if (le == std::string::npos) break; + long long cs = stollSafe(serverStream.buffer.substr(idx, le - idx), -1, 16); + if (cs < 0 || idx + (le - idx) + 2 + cs + 2 > serverStream.buffer.size()) break; + if (cs > 0) serverStream.payload.append(serverStream.buffer, le + 2, cs); + serverStream.currentChunkIdx = le + 2 + cs + 2; + if (cs == 0) + { + fullBody = std::move(serverStream.payload); + complete = true; + totalResponseSize = serverStream.currentChunkIdx; + break; + } + } + } + else if (serverStream.contentLength >= 0) + { + if (serverStream.buffer.size() >= + (serverStream.headersEnd + 4 + serverStream.contentLength)) + { + fullBody = + serverStream.buffer.substr(serverStream.headersEnd + 4, serverStream.contentLength); + complete = true; + totalResponseSize = serverStream.headersEnd + 4 + serverStream.contentLength; + } + } + else if (closed) + { + fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4); + complete = true; + totalResponseSize = serverStream.buffer.size(); + } + + if (complete) + { + std::string url = pendingUrls.empty() ? "https://" + host : pendingUrls.front(); + if (!pendingUrls.empty()) pendingUrls.pop_front(); + std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4); + + if (serverStream.statusCode == 101) tunnelMode = true; + + onServerResponse.run(url, fullBody, respHeaders, false); + + removeHeader(respHeaders, "Transfer-Encoding"); + removeHeader(respHeaders, "Content-Length"); + respHeaders.insert(respHeaders.size() - 2, + "Content-Length: " + std::to_string(fullBody.size()) + "\r\n"); + + std::string packet = respHeaders + fullBody; + wolfSSL_write(clientSSL.get(), packet.data(), (int)packet.size()); + + serverStream.buffer.erase(0, totalResponseSize); + + serverStream.reset(); + if (tunnelMode) break; + } + else + break; + } + } + } + if (closed) break; + } + } +} + +long long TinyMITMProxy::stollSafe(const std::string& s, int def, int base) +{ + if (s.empty()) return def; + try + { + return std::stoll(s, nullptr, base); + } + catch (...) + { + return def; + } +} + +std::string TinyMITMProxy::getHeader(const std::string& headers, std::string key) +{ + std::string keyLower = key; + std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + size_t pos = 0; + while ((pos = headers.find("\r\n", pos)) != std::string::npos) + { + pos += 2; + size_t lineEnd = headers.find("\r\n", pos); + if (lineEnd == std::string::npos) break; + + std::string line = headers.substr(pos, lineEnd - pos); + size_t colon = line.find(':'); + if (colon == std::string::npos) continue; + + std::string k = line.substr(0, colon); + std::transform(k.begin(), k.end(), k.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + k.erase(k.find_last_not_of(" \t") + 1); + + if (k == keyLower) + { + std::string val = line.substr(colon + 1); + val.erase(0, val.find_first_not_of(" \t")); + val.erase(val.find_last_not_of(" \t") + 1); + return val; + } + } + return ""; +} + +void TinyMITMProxy::removeHeader(std::string& headers, const std::string& key) +{ + if (!headers.empty() && headers.back() != '\n') headers += "\r\n"; + std::string result, keyLower = key; + std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + size_t start = 0, end; + while ((end = headers.find('\n', start)) != std::string::npos) + { + std::string line = headers.substr(start, end - start + 1); + std::string lineLower = line; + std::transform(lineLower.begin(), lineLower.end(), lineLower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + bool match = false; + if (lineLower.compare(0, keyLower.length(), keyLower) == 0) + { + size_t pos = keyLower.length(); + while (pos < lineLower.length() && (lineLower[pos] == ' ' || lineLower[pos] == '\t')) + pos++; + if (pos < lineLower.length() && lineLower[pos] == ':') match = true; + } + if (!match) result += line; + start = end + 1; + } + headers = std::move(result); +} + +bool TinyMITMProxy::doHandshake(WOLFSSL* ssl, SOCKET socket, bool isAccept) +{ + while (true) + { + int ret = isAccept ? wolfSSL_accept(ssl) : wolfSSL_connect(ssl); + if (ret == WOLFSSL_SUCCESS) return true; + + int err = wolfSSL_get_error(ssl, ret); + if (err == WOLFSSL_ERROR_WANT_READ || err == WOLFSSL_ERROR_WANT_WRITE) + { + fd_set fds; + FD_ZERO(&fds); + FD_SET(socket, &fds); + struct timeval tv{TINYMTM_HANDSHAKE_TIMEOUT, 0}; + +#ifdef _WIN32 + int nfds = 0; +#else + int nfds = static_cast(socket) + 1; +#endif + + int sel = select(nfds, (err == WOLFSSL_ERROR_WANT_READ) ? &fds : nullptr, + (err == WOLFSSL_ERROR_WANT_WRITE) ? &fds : nullptr, nullptr, &tv); + if (sel <= 0) return false; + + continue; + } + return false; + } +} diff --git a/src/proxy/tinymitm/proxy.h b/src/proxy/tinymitm/proxy.h new file mode 100644 index 0000000..a55fe2c --- /dev/null +++ b/src/proxy/tinymitm/proxy.h @@ -0,0 +1,100 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "ssl.h" + +#if defined(_WIN64) || defined(_WIN32) +typedef unsigned __int64 SOCKET; +#else +typedef int SOCKET; +#endif + +struct WOLFSSL_CTX; +struct WOLFSSL; + +#ifndef TINYMITM_CLIENT_BUFF_SIZE + #define TINYMITM_CLIENT_BUFF_SIZE 16384 +#endif + +#ifndef TINYMTM_HANDSHAKE_TIMEOUT + #define TINYMTM_HANDSHAKE_TIMEOUT 5 +#endif + +class TinyMITMProxy +{ + public: + TinyMITMProxy(unsigned short port = 44444, unsigned char threadCount = 255) + : _port(port), _threadCount(threadCount) {}; + ~TinyMITMProxy(); + + bool init(); + void shutdown(); + + inline unsigned short getPort() { return _port; } + inline bool getRunning() { return _running; } + + /* + onClientRequest is emitted every time the proxy attempts to start a request to a domain + + Setting blockOutgoing to true will result in the request being ignored, onServerResponse will still be fired and a mock response can be generated there. + If response is not edited on onServerResponse, the response will be set to a generic 500 error + + Arguments are: + * const std::string& url + * std::string& body + * std::string& headers + * bool& blockOutgoing + */ + seallib::Event onClientRequest; + + /* + onServerResponse is emitted once the proxy receives a response to a request + + Arguments are: + * const std::string& url + * std::string& body + * std::string& headers + * bool wasBlocked + */ + seallib::Event onServerResponse; + + protected: + // helper functions + static long long stollSafe(const std::string& s, int def = 0, int base = 10); + static std::string getHeader(const std::string& headers, std::string key); + static void removeHeader(std::string& headers, const std::string& key); + + friend struct HttpStream; + + private: + TinyMITMProxy() = delete; + + void handleClient(SOCKET clientSocket); + + static bool doHandshake(WOLFSSL* ssl, SOCKET socket, bool isAccept); + + unsigned short _port; + unsigned char _threadCount; + + SOCKET _listenSocket = 0; + + std::atomic _running = false; + + std::thread _dispatchThread; + + std::mutex _queueMutex; + std::condition_variable _queueCond; + std::queue _clientQueue; + + std::vector _poolThreads; + + WOLFSSL_CTX* _clientCtx = nullptr; + + CertificateManager _certManager; +};