From 1edefb8147f526db4beba71e73ceaf6227c44db8 Mon Sep 17 00:00:00 2001 From: neru Date: Fri, 20 Mar 2026 12:38:30 -0300 Subject: [PATCH] feat: add full HTTP/HTTPs support --- src/unlocker/proxy.cpp | 494 ++++++++++++++++++++++++++++++++++++++--- src/unlocker/proxy.h | 14 ++ 2 files changed, 483 insertions(+), 25 deletions(-) diff --git a/src/unlocker/proxy.cpp b/src/unlocker/proxy.cpp index 060a030..5615cd9 100644 --- a/src/unlocker/proxy.cpp +++ b/src/unlocker/proxy.cpp @@ -4,6 +4,28 @@ #include #include +#include + +bool Proxy::initSSL() +{ + _clientCtx = SSL_CTX_new(TLS_client_method()); + if (!_clientCtx) + { + Log::error("Failed to create client SSL context"); + return false; + } + SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr); + return true; +} + +void Proxy::cleanupSSL() +{ + if (_clientCtx) + { + SSL_CTX_free(_clientCtx); + _clientCtx = nullptr; + } +} Proxy::Proxy() {} @@ -18,6 +40,15 @@ bool Proxy::Init() return false; } + if (!_certManager.Init()) + { + return false; + } + + if (!initSSL()) + { + return false; + } _listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (_listenSocket == INVALID_SOCKET) { @@ -61,6 +92,8 @@ void Proxy::Shutdown() if (_workerThread.joinable()) _workerThread.join(); WSACleanup(); + + cleanupSSL(); } void Proxy::loop() @@ -92,17 +125,66 @@ void Proxy::handleClient(SOCKET clientSocket) return; } + buffer[bytesReceived] = '\0'; std::string request(buffer, bytesReceived); - /* - get host - */ - std::string host; - size_t hostPos = request.find("Host: "); - if (hostPos != std::string::npos) + std::string method, url; + size_t space1 = request.find(' '); + if (space1 != std::string::npos) { - size_t endPos = request.find("\r\n", hostPos); - host = request.substr(hostPos + 6, endPos - (hostPos + 6)); + method = request.substr(0, space1); + size_t space2 = request.find(' ', space1 + 1); + if (space2 != std::string::npos) + { + url = request.substr(space1 + 1, space2 - space1 - 1); + } + } + + if (method.empty() || url.empty()) + { + closesocket(clientSocket); + return; + } + + std::string host; + std::string port = "80"; + bool isConnect = (method == "CONNECT"); + + if (isConnect) + { + size_t colon = url.find(':'); + if (colon != std::string::npos) + { + host = url.substr(0, colon); + port = url.substr(colon + 1); + } + else + { + host = url; + port = "443"; + } + } + else + { + size_t hostPos = request.find("Host: "); + if (hostPos != std::string::npos) + { + size_t endPos = request.find("\r\n", hostPos); + if (endPos != std::string::npos) + { + std::string hostHeader = request.substr(hostPos + 6, endPos - (hostPos + 6)); + size_t colon = hostHeader.find(':'); + if (colon != std::string::npos) + { + host = hostHeader.substr(0, colon); + port = hostHeader.substr(colon + 1); + } + else + { + host = hostHeader; + } + } + } } if (host.empty()) @@ -111,14 +193,11 @@ void Proxy::handleClient(SOCKET clientSocket) return; } - /* - handle remote - */ struct addrinfo hints = {}, *res; hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; - if (getaddrinfo(host.c_str(), "80", &hints, &res) != 0) + if (getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0) { Log::error("Could not resolve host: {}", host); closesocket(clientSocket); @@ -128,26 +207,391 @@ void Proxy::handleClient(SOCKET clientSocket) SOCKET remoteSocket = socket(res->ai_family, res->ai_socktype, res->ai_protocol); if (connect(remoteSocket, res->ai_addr, (int)res->ai_addrlen) == SOCKET_ERROR) { - Log::error("Connection to {} failed", host); + Log::error("Connection to {}:{} failed", host, port); freeaddrinfo(res); closesocket(clientSocket); return; } freeaddrinfo(res); - /* - fwd - */ - send(remoteSocket, buffer, bytesReceived, 0); - - /* - recv - */ - int remoteBytes; - while ((remoteBytes = recv(remoteSocket, buffer, sizeof(buffer), 0)) > 0) + if (isConnect) { - // Log::verbose("Forwarding {} bytes from server back to client", remoteBytes); - send(clientSocket, buffer, remoteBytes, 0); + const char* reply = "HTTP/1.1 200 Connection Established\r\n\r\n"; + send(clientSocket, reply, static_cast(strlen(reply)), 0); + + SSL_CTX* serverCtx = _certManager.CreateHostContext(host); + if (!serverCtx) + { + Log::error("Failed to generate dynamic cert for {}", host); + closesocket(clientSocket); + closesocket(remoteSocket); + return; + } + + SSL* clientSSL = SSL_new(serverCtx); + SSL_set_fd(clientSSL, static_cast(clientSocket)); + + if (SSL_accept(clientSSL) <= 0) + { + Log::error("SSL_accept failed on client"); + SSL_free(clientSSL); + closesocket(clientSocket); + closesocket(remoteSocket); + return; + } + + SSL* remoteSSL = SSL_new(_clientCtx); + SSL_set_fd(remoteSSL, static_cast(remoteSocket)); + SSL_set_tlsext_host_name(remoteSSL, host.c_str()); + + if (SSL_connect(remoteSSL) <= 0) + { + Log::error("SSL_connect failed on remote server"); + SSL_free(remoteSSL); + SSL_free(clientSSL); + closesocket(clientSocket); + closesocket(remoteSocket); + return; + } + + std::deque pendingUrls; + std::string serverBuffer; + bool isReceivingBody = false; + int expectedLength = -1; + bool isChunked = false; + size_t headersEnd = 0; + + fd_set readfds; + while (_running) + { + FD_ZERO(&readfds); + FD_SET(clientSocket, &readfds); + FD_SET(remoteSocket, &readfds); + + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 100000; + + int ret = select(0, &readfds, NULL, NULL, &tv); + if (ret < 0) break; + + if (FD_ISSET(clientSocket, &readfds) || SSL_pending(clientSSL) > 0) + { + int bytes = SSL_read(clientSSL, buffer, sizeof(buffer)); + if (bytes <= 0) break; + + std::string data(buffer, bytes); + + size_t reqStart = 0; + while (reqStart < data.size()) + { + size_t nextReq = std::string::npos; + const char* methods[] = {"GET ", "POST ", "PUT ", "DELETE ", "PATCH ", "OPTIONS ", "HEAD "}; + for (const char* m : methods) + { + size_t found = data.find(m, reqStart + 1); + if (found != std::string::npos && (nextReq == std::string::npos || found < nextReq)) + nextReq = found; + } + + std::string singleReq = (nextReq == std::string::npos) ? data.substr(reqStart) + : data.substr(reqStart, nextReq - reqStart); + + std::string url; + size_t pathSpace1 = singleReq.find(' '); + size_t pathSpace2 = singleReq.find(' ', pathSpace1 + 1); + if (pathSpace1 != std::string::npos && pathSpace2 != std::string::npos) + { + url = "https://" + host + singleReq.substr(pathSpace1 + 1, pathSpace2 - pathSpace1 - 1); + pendingUrls.push_back(url); + + size_t aePos = singleReq.find("Accept-Encoding:"); + if (aePos == std::string::npos) aePos = singleReq.find("accept-encoding:"); + if (aePos != std::string::npos) + { + size_t aeEndPos = singleReq.find("\r\n", aePos); + if (aeEndPos != std::string::npos) + singleReq.replace(aePos, aeEndPos - aePos, "Accept-Encoding: identity"); + } + } + + OnClientRequest.run(url.empty() ? ("https://" + host) : url, singleReq); + + if (nextReq == std::string::npos) break; + reqStart = nextReq; + } + + int sent = SSL_write(remoteSSL, data.data(), static_cast(data.size())); + if (sent <= 0) break; + } + + if (FD_ISSET(remoteSocket, &readfds) || SSL_pending(remoteSSL) > 0) + { + int bytes = SSL_read(remoteSSL, buffer, sizeof(buffer) - 1); + if (bytes <= 0) break; + + serverBuffer.append(buffer, bytes); + + while (!serverBuffer.empty()) + { + if (!isReceivingBody) + { + headersEnd = serverBuffer.find("\r\n\r\n"); + if (headersEnd != std::string::npos) + { + isReceivingBody = true; + std::string headers = serverBuffer.substr(0, headersEnd + 4); + + size_t clPos = headers.find("Content-Length: "); + if (clPos == std::string::npos) clPos = headers.find("content-length: "); + if (clPos != std::string::npos) + { + size_t clEnd = headers.find("\r\n", clPos); + expectedLength = std::stoi(headers.substr(clPos + 16, clEnd - clPos - 16)); + } + else + expectedLength = -1; + + isChunked = (headers.find("chunked") != std::string::npos); + } + else + { + break; // need more data + } + } + + if (isReceivingBody) + { + bool complete = false; + std::string fullBody; + size_t bodyStart = headersEnd + 4; + size_t totalProcessed = bodyStart; + + if (isChunked) + { + size_t idx = bodyStart; + bool parseOk = true; + while (idx < serverBuffer.size()) + { + size_t lineEnd = serverBuffer.find("\r\n", idx); + if (lineEnd == std::string::npos) + { + parseOk = false; + break; + } + std::string hexStr = serverBuffer.substr(idx, lineEnd - idx); + int chunkSize = 0; + try + { + chunkSize = std::stoi(hexStr, nullptr, 16); + } + catch (...) + { + parseOk = false; + break; + } + idx = lineEnd + 2; + if (chunkSize == 0) + { + idx += 2; // skip terminal \r\n + complete = true; + totalProcessed = idx; + break; + } + if (idx + (size_t)chunkSize + 2 > serverBuffer.size()) + { + parseOk = false; + break; + } + fullBody.append(serverBuffer, idx, chunkSize); + idx += chunkSize + 2; + } + if (!parseOk) complete = false; + } + else if (expectedLength >= 0) + { + if (serverBuffer.size() >= bodyStart + expectedLength) + { + complete = true; + totalProcessed = bodyStart + expectedLength; + fullBody = serverBuffer.substr(bodyStart, expectedLength); + } + } + else + { + std::string peekBuffer = serverBuffer.substr(0, bodyStart); + bool isCloseConn = peekBuffer.find("Connection: close") != std::string::npos || + peekBuffer.find("connection: close") != std::string::npos; + + bool isNoBodyStatus = peekBuffer.find("HTTP/1.1 204") != std::string::npos || + peekBuffer.find("HTTP/1.1 304") != std::string::npos || + peekBuffer.find("HTTP/1.0 204") != std::string::npos || + peekBuffer.find("HTTP/1.0 304") != std::string::npos; + + if (isCloseConn) + break; + else + { + complete = true; + fullBody = ""; + totalProcessed = bodyStart; + } + } + + if (complete) + { + std::string headers = serverBuffer.substr(0, bodyStart); + std::string responseData = fullBody; + + std::string currentUrl = "https://" + host; + if (!pendingUrls.empty()) + { + currentUrl = pendingUrls.front(); + pendingUrls.pop_front(); + } + + OnServerResponse.run(currentUrl, responseData); + + auto removeHeader = [&](std::string& h, const std::string& key) { + size_t pos = 0; + while (true) + { + pos = h.find(key, pos); + if (pos == std::string::npos) + { + std::string lowerKey = key; + for (char& c : lowerKey) + c = (char)tolower(c); + pos = h.find(lowerKey, 0); + if (pos == std::string::npos) break; + } + if (pos == 0 || h[pos - 1] == '\n') + { + size_t end = h.find("\r\n", pos); + if (end != std::string::npos) + { + h.erase(pos, end - pos + 2); + continue; + } + } + pos++; + } + }; + + removeHeader(headers, "Transfer-Encoding"); + removeHeader(headers, "Content-Length"); + + headers.insert(headers.size() - 2, + "Content-Length: " + std::to_string(responseData.size()) + "\r\n"); + + std::string finalPacket = headers + responseData; + int sent = SSL_write(clientSSL, finalPacket.data(), static_cast(finalPacket.size())); + + serverBuffer.erase(0, totalProcessed); + + isReceivingBody = false; + expectedLength = -1; + headersEnd = 0; + isChunked = false; + + if (sent <= 0) break; + } + else + break; // wait for more streaming packets + } + } + } + } + + if (isReceivingBody && expectedLength < 0 && !isChunked && serverBuffer.size() > headersEnd + 4) + { + std::string headers = serverBuffer.substr(0, headersEnd + 4); + std::string responseData = serverBuffer.substr(headersEnd + 4); + + std::string finalUrl = "https://" + host; + if (!pendingUrls.empty()) + { + finalUrl = pendingUrls.front(); + pendingUrls.pop_front(); + } + + OnServerResponse.run(finalUrl, responseData); + + auto removeHeader = [&](std::string& h, const std::string& key) { + size_t pos = 0; + while (true) + { + pos = h.find(key, pos); + if (pos == std::string::npos) + { + std::string lowerKey = key; + for (char& c : lowerKey) + c = (char)tolower(c); + pos = h.find(lowerKey, 0); + if (pos == std::string::npos) break; + } + if (pos == 0 || h[pos - 1] == '\n') + { + size_t end = h.find("\r\n", pos); + if (end != std::string::npos) + { + h.erase(pos, end - pos + 2); + continue; + } + } + pos++; + } + }; + + removeHeader(headers, "Transfer-Encoding"); + removeHeader(headers, "Content-Length"); + headers.insert(headers.size() - 2, "Content-Length: " + std::to_string(responseData.size()) + "\r\n"); + + std::string finalPacket = headers + responseData; + SSL_write(clientSSL, finalPacket.data(), static_cast(finalPacket.size())); + } + + SSL_shutdown(clientSSL); + SSL_free(clientSSL); + + SSL_shutdown(remoteSSL); + SSL_free(remoteSSL); + } + else + { + send(remoteSocket, buffer, bytesReceived, 0); + + fd_set readfds; + while (_running) + { + FD_ZERO(&readfds); + FD_SET(clientSocket, &readfds); + FD_SET(remoteSocket, &readfds); + + struct timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + + int ret = select(0, &readfds, NULL, NULL, &tv); + if (ret < 0) break; + if (ret == 0) continue; + + if (FD_ISSET(clientSocket, &readfds)) + { + int bytes = recv(clientSocket, buffer, sizeof(buffer), 0); + if (bytes <= 0) break; + int sent = send(remoteSocket, buffer, bytes, 0); + if (sent == SOCKET_ERROR) break; + } + + if (FD_ISSET(remoteSocket, &readfds)) + { + int bytes = recv(remoteSocket, buffer, sizeof(buffer), 0); + if (bytes <= 0) break; + int sent = send(clientSocket, buffer, bytes, 0); + if (sent == SOCKET_ERROR) break; + } + } } closesocket(remoteSocket); diff --git a/src/unlocker/proxy.h b/src/unlocker/proxy.h index 6b4fa4a..cc2f8f9 100644 --- a/src/unlocker/proxy.h +++ b/src/unlocker/proxy.h @@ -2,6 +2,11 @@ #include #include +#include +#include +#include +#include "cert_manager.h" +#include #define PROXY_PORT 1337 @@ -15,11 +20,20 @@ public: bool Init(); void Shutdown(); + CallbackEvent OnClientRequest; + CallbackEvent OnServerResponse; + private: void loop(); void handleClient(SOCKET clientSocket); + bool initSSL(); + void cleanupSSL(); + SOCKET _listenSocket = 0; std::thread _workerThread; std::atomic _running; + + CertManager _certManager; + SSL_CTX* _clientCtx = nullptr; };