From f328baf897380a0fa72fdef9f0e20c57f39d71f7 Mon Sep 17 00:00:00 2001 From: neru Date: Wed, 13 May 2026 12:07:50 -0300 Subject: [PATCH] fix: refactor and fix client handler --- src/proxy/tinymitm/proxy.cpp | 139 +++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 46 deletions(-) diff --git a/src/proxy/tinymitm/proxy.cpp b/src/proxy/tinymitm/proxy.cpp index 4fa6fff..2d0e3f8 100644 --- a/src/proxy/tinymitm/proxy.cpp +++ b/src/proxy/tinymitm/proxy.cpp @@ -356,23 +356,47 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) buf[n] = '\0'; std::string req(buf); - if (req.find("CONNECT ") != 0) + bool isConnect = (req.find("CONNECT ") != std::string::npos); + + std::string host, port; + if (isConnect) { - TINYMITM_WRITELOG(error, "handleClient was fed a request that was not a CONNECT request"); - return; + size_t connectPos = req.find("CONNECT "); + size_t endOfHost = req.find_first_of(" \r\n", connectPos + 8); + std::string fullHost = req.substr(connectPos + 8, endOfHost - (connectPos + 8)); + size_t colon = fullHost.find(':'); + host = (colon != std::string::npos) ? fullHost.substr(0, colon) : fullHost; + port = (colon != std::string::npos) ? fullHost.substr(colon + 1) : "443"; + } + else + { + size_t hostStart = req.find("http://"); + if (hostStart != std::string::npos) + { + hostStart += 7; + size_t hostEnd = req.find_first_of(":/ \r\n", hostStart); + host = req.substr(hostStart, hostEnd - hostStart); + if (req[hostEnd] == ':') { + size_t portEnd = req.find_first_of("/ \r\n", hostEnd + 1); + port = req.substr(hostEnd + 1, portEnd - (hostEnd + 1)); + } else port = "80"; + } + else + { + host = getHeader(req, "Host"); + size_t colon = host.find(':'); + if (colon != std::string::npos) { + port = host.substr(colon + 1); + host = host.substr(0, colon); + } else port = "80"; + } + if (host.empty()) { + TINYMITM_WRITELOG(error, "Unable to parse host from request: {}", req.substr(0, 100)); + return; + } } /* - port parsing - */ - size_t endOfHost = req.find_first_of(" \r\n", 8); - std::string fullHost = req.substr(8, endOfHost - 8); - size_t colon = fullHost.find(':'); - std::string host = (colon != std::string::npos) ? fullHost.substr(0, colon) : fullHost; - std::string port = (colon != std::string::npos) ? fullHost.substr(colon + 1) : "443"; - - /* - remote connection remote connection */ addrinfo hints{}, *rawRes; @@ -384,44 +408,67 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, 0)); if (connect(remoteGuard, res->ai_addr, static_cast(res->ai_addrlen)) != 0) return; - const char* connEstablished = "HTTP/1.1 200 Connection Established\r\n\r\n"; - send(clientGuard, connEstablished, static_cast(strlen(connEstablished)), 0); + if (isConnect) + { + const char* connEstablished = "HTTP/1.1 200 Connection Established\r\n\r\n"; + send(clientGuard, connEstablished, static_cast(strlen(connEstablished)), 0); + } - /* - wolfss setup - */ - WOLFSSL_CTX* hostCtx = _certManager.createHostContext(host); - if (!hostCtx) return; + WOLF_ptr clientSSL(nullptr), remoteSSL(nullptr); + if (isConnect) + { + WOLFSSL_CTX* hostCtx = _certManager.createHostContext(host); + if (!hostCtx) return; - WOLF_ptr clientSSL(wolfSSL_new(hostCtx)); - WOLF_ptr remoteSSL(wolfSSL_new(_clientCtx)); - - wolfSSL_set_fd(clientSSL.get(), (int)clientGuard); - wolfSSL_set_fd(remoteSSL.get(), (int)remoteGuard); + clientSSL.reset(wolfSSL_new(hostCtx)); + remoteSSL.reset(wolfSSL_new(_clientCtx)); // temporarily removed alpn //char alpnList[] = "\x08http/1.1"; //wolfSSL_UseALPN(remoteSSL.get(), alpnList, sizeof(alpnList) - 1, 0); //wolfSSL_UseALPN(clientSSL.get(), alpnList, sizeof(alpnList) - 1, 0); + wolfSSL_set_fd(clientSSL.get(), (int)clientGuard); + wolfSSL_set_fd(remoteSSL.get(), (int)remoteGuard); - wolfSSL_UseSNI(remoteSSL.get(), WOLFSSL_SNI_HOST_NAME, host.c_str(), (unsigned short)host.size()); + wolfSSL_UseSNI(remoteSSL.get(), WOLFSSL_SNI_HOST_NAME, host.c_str(), (unsigned short)host.size()); - setNonBlocking(clientGuard, true); - setNonBlocking(remoteGuard, true); + setNonBlocking(clientGuard, true); + setNonBlocking(remoteGuard, true); - TINYMITM_WRITELOG(verbose, "Starting handshakes for {}", host); - if (!doHandshake(clientSSL.get(), clientGuard, true)) - { - TINYMITM_WRITELOG(error, "Client handshake failed for: {}", host); - return; + TINYMITM_WRITELOG(verbose, "Starting handshakes for {}", host); + if (!doHandshake(clientSSL.get(), clientGuard, true)) return; + if (!doHandshake(remoteSSL.get(), remoteGuard, false)) return; } - if (!doHandshake(remoteSSL.get(), remoteGuard, false)) + else { - TINYMITM_WRITELOG(error, "Remote handshake failed for: {}", host); - return; + setNonBlocking(clientGuard, true); + setNonBlocking(remoteGuard, true); + if (::send(remoteGuard, buf, n, 0) <= 0) return; } - TINYMITM_WRITELOG(verbose, "Established tunnel to {}", host); + TINYMITM_WRITELOG(verbose, "Established tunnel to {}", host); + + auto sslRead = [&](WOLFSSL* ssl, SOCKET s, char* b, int sz) -> int { + if (isConnect) return wolfSSL_read(ssl, b, sz); + return ::recv(s, b, sz, 0); + }; + auto sslWrite = [&](WOLFSSL* ssl, SOCKET s, const char* b, int sz) -> int { + if (isConnect) return wolfSSL_write(ssl, b, sz); + return ::send(s, b, sz, 0); + }; + auto sslPending = [&](WOLFSSL* ssl) -> int { + if (isConnect) return wolfSSL_pending(ssl); + return 0; + }; + auto sslError = [&](WOLFSSL* ssl, int ret) -> int { + if (isConnect) return wolfSSL_get_error(ssl, ret); +#ifdef _WIN32 + if (ret < 0 && WSAGetLastError() == WSAEWOULDBLOCK) return WOLFSSL_ERROR_WANT_READ; +#else + if (ret < 0 && (errno == EWOULDBLOCK || errno == EAGAIN)) return WOLFSSL_ERROR_WANT_READ; +#endif + return 0; + }; /* traffic loop @@ -440,7 +487,7 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) FD_SET(remoteGuard, &r_fds); struct timeval tv{0, 1000}; - bool hasBuffered = (wolfSSL_pending(clientSSL.get()) > 0 || wolfSSL_pending(remoteSSL.get()) > 0); + bool hasBuffered = (sslPending(clientSSL.get()) > 0 || sslPending(remoteSSL.get()) > 0); if (!hasBuffered) { @@ -455,12 +502,12 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) /* client -> server */ - if (FD_ISSET(clientGuard, &r_fds) || wolfSSL_pending(clientSSL.get())) + if (FD_ISSET(clientGuard, &r_fds) || sslPending(clientSSL.get())) { - int rd = wolfSSL_read(clientSSL.get(), buf, TINYMITM_CLIENT_BUFF_SIZE); + int rd = sslRead(clientSSL.get(), clientGuard, buf, TINYMITM_CLIENT_BUFF_SIZE); if (rd <= 0) { - if (wolfSSL_get_error(clientSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ) break; + if (sslError(clientSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ) break; } else { @@ -483,7 +530,7 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) if (s1 != std::string::npos && s2 != std::string::npos) path = headers.substr(s1 + 1, s2 - s1 - 1); - pendingUrls.push_back("https://" + host + path); + pendingUrls.push_back((isConnect ? "https://" : "http://") + host + path); } std::string fullBody; @@ -593,10 +640,10 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) /* server -> client */ - if (FD_ISSET(remoteGuard, &r_fds) || wolfSSL_pending(remoteSSL.get())) + if (FD_ISSET(remoteGuard, &r_fds) || sslPending(remoteSSL.get())) { - int rd = wolfSSL_read(remoteSSL.get(), buf, TINYMITM_CLIENT_BUFF_SIZE); - bool closed = (rd <= 0 && wolfSSL_get_error(remoteSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ); + int rd = sslRead(remoteSSL.get(), remoteGuard, buf, TINYMITM_CLIENT_BUFF_SIZE); + bool closed = (rd <= 0 && sslError(remoteSSL.get(), rd) != WOLFSSL_ERROR_WANT_READ); if (rd > 0) { @@ -663,7 +710,7 @@ void TinyMITMProxy::handleClient(SOCKET clientSocket) if (complete) { - std::string url = inFlightUrls.empty() ? "https://" + host : inFlightUrls.front(); + std::string url = inFlightUrls.empty() ? ((isConnect ? "https://" : "http://") + host) : inFlightUrls.front(); if (!inFlightUrls.empty()) inFlightUrls.pop_front(); std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4);