#include "proxy.h"

#include <ws2tcpip.h>

#include <algorithm>
#include <cctype>
#include <deque>
#include <exception>
#include <string>
#include <thread>

#include <openssl/ssl.h>

namespace {

/// Owns an SSL* and releases it with SSL_free on destruction.
/// Copying is disabled so the same handle can never be freed twice.
class ScopedSSL {
    SSL* s;

public:
    explicit ScopedSSL(SSL* val = nullptr) : s(val) {}
    ~ScopedSSL() {
        if (s) SSL_free(s);
    }
    ScopedSSL(const ScopedSSL&) = delete;
    ScopedSSL& operator=(const ScopedSSL&) = delete;

    /// Frees the currently held handle (if any) and takes ownership of `val`.
    void reset(SSL* val) {
        if (s) SSL_free(s);
        s = val;
    }
    operator SSL*() const { return s; }
};

/// Owns a SOCKET and closes it with closesocket on destruction.
/// Copying is disabled so the same socket can never be closed twice.
class ScopedSocket {
    SOCKET s;

public:
    explicit ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
    ~ScopedSocket() {
        if (s != INVALID_SOCKET) closesocket(s);
    }
    ScopedSocket(const ScopedSocket&) = delete;
    ScopedSocket& operator=(const ScopedSocket&) = delete;

    /// Closes the currently held socket (if any) and takes ownership of `val`.
    void reset(SOCKET val) {
        if (s != INVALID_SOCKET) closesocket(s);
        s = val;
    }
    operator SOCKET() const { return s; }
    SOCKET get() const { return s; }
};

/// Parses `s` as an integer in `base`, returning `default_val` when `s` is
/// empty or std::stoi rejects it (no digits, out of int range).
int safe_stoi(const std::string& s, int default_val = 0, int base = 10) {
    if (s.empty()) return default_val;
    try {
        return std::stoi(s, nullptr, base);
    } catch (...) {
        return default_val;
    }
}

/// Returns a lowercased copy of `s`. Each char goes through unsigned char
/// because passing a negative char to std::tolower is undefined behaviour.
std::string toLower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return s;
}

/// Removes every header line in `h` whose field name equals `key`
/// (case-insensitive), together with its trailing CRLF. A match must start a
/// line and be immediately followed by ':', so removing "Expect" leaves an
/// "Expect-CT" header untouched.
void removeHeader(std::string& h, const std::string& key) {
    std::string h_lower = toLower(h);
    const std::string needle = toLower(key) + ":";
    size_t pos = 0;
    while ((pos = h_lower.find(needle, pos)) != std::string::npos) {
        if (pos == 0 || h_lower[pos - 1] == '\n') {
            const size_t end = h_lower.find("\r\n", pos);
            if (end != std::string::npos) {
                // Erase from both copies so offsets stay in sync.
                h.erase(pos, end - pos + 2);
                h_lower.erase(pos, end - pos + 2);
                continue;
            }
        }
        pos++;
    }
}

/// Returns the value of the first header named `key` (case-insensitive) in a
/// CRLF-delimited header block, with leading spaces/tabs skipped. Returns ""
/// when the header is absent or its line is not CRLF-terminated. The search is
/// anchored after '\n', so the start-line is never matched.
std::string getHeaderValue(const std::string& headers, const std::string& key) {
    const std::string h_lower = toLower(headers);
    const std::string search = "\n" + toLower(key) + ":";
    const size_t pos = h_lower.find(search);
    if (pos == std::string::npos) return "";
    size_t vStart = pos + search.length();
    while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t')) vStart++;
    const size_t vEnd = headers.find("\r\n", vStart);
    if (vEnd == std::string::npos) return "";
    return headers.substr(vStart, vEnd - vStart);
}

/// Extracts the 3-digit status code from a response start-line
/// ("HTTP/1.1 200 OK"); returns 0 when it cannot be parsed.
int statusCode(const std::string& headers) {
    const size_t sp = headers.find(' ');
    return (sp != std::string::npos) ? safe_stoi(headers.substr(sp + 1, 3)) : 0;
}

/// Result of decodeChunked.
enum class ChunkStatus { Complete, Incomplete, Malformed };

/// Decodes a chunked message body starting at offset `start` of `buf`.
///   Complete   - `body` holds all chunk payloads; `end` is the offset just past
///                the message (the final CRLF, after any trailer fields).
///   Incomplete - more bytes are needed; `body` holds the chunks decoded so far
///                and `end` is left untouched.
///   Malformed  - a chunk-size line is not a non-negative hex number that fits
///                in int, or a chunk's data is not followed by CRLF.
ChunkStatus decodeChunked(const std::string& buf, size_t start, std::string& body, size_t& end) {
    body.clear();
    size_t idx = start;
    for (;;) {
        const size_t lineEnd = buf.find("\r\n", idx);
        if (lineEnd == std::string::npos) return ChunkStatus::Incomplete;
        // chunk-size is hex; stoi stops at ';', so chunk extensions are ignored.
        const int size = safe_stoi(buf.substr(idx, lineEnd - idx), -1, 16);
        if (size < 0) return ChunkStatus::Malformed;
        if (size == 0) {
            // last-chunk: the message ends at the first empty line. Without
            // trailers that empty line directly follows the size line, so the
            // search starts at the size line's own CRLF.
            const size_t trailerEnd = buf.find("\r\n\r\n", lineEnd);
            if (trailerEnd == std::string::npos) return ChunkStatus::Incomplete;
            end = trailerEnd + 4;
            return ChunkStatus::Complete;
        }
        const size_t dataStart = lineEnd + 2;  // <= buf.size(): the CRLF was found
        const size_t dataLen = static_cast<size_t>(size);
        // Written as a subtraction so the bound check cannot overflow.
        if (buf.size() - dataStart < dataLen + 2) return ChunkStatus::Incomplete;
        if (buf.compare(dataStart + dataLen, 2, "\r\n") != 0) return ChunkStatus::Malformed;
        body.append(buf, dataStart, dataLen);
        idx = dataStart + dataLen + 2;
    }
}

/// Returns `headers` (a full header block ending in CRLFCRLF) with
/// Accept-Encoding and Expect removed and "Accept-Encoding: identity" added, so
/// the origin replies uncompressed and without a 100-continue handshake.
std::string prepareRequestHeaders(std::string headers) {
    removeHeader(headers, "Accept-Encoding");
    removeHeader(headers, "Expect");
    // size() - 2 lands just before the blank line that ends the block.
    headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
    return headers;
}

/// Writes all `len` bytes to `ssl`. SSL_write takes an int length, so large
/// buffers are sent in bounded pieces. Returns false on the first failed write.
bool sslSend(SSL* ssl, const char* data, size_t len) {
    constexpr size_t kMaxPiece = 1u << 20;
    while (len > 0) {
        const int piece = static_cast<int>(std::min(len, kMaxPiece));
        const int written = SSL_write(ssl, data, piece);
        if (written <= 0) return false;
        data += written;
        len -= static_cast<size_t>(written);
    }
    return true;
}

/// Convenience overload of sslSend for a whole string.
bool sslSend(SSL* ssl, const std::string& data) {
    return sslSend(ssl, data.data(), data.size());
}

/// Moves one SSL_read worth of bytes from `from` to `to` using `buf` as
/// scratch. Returns false if the read reports EOF/error or the write fails.
bool relay(SSL* from, SSL* to, char* buf, int cap) {
    const int n = SSL_read(from, buf, cap);
    return n > 0 && sslSend(to, buf, static_cast<size_t>(n));
}

/// Accumulates bytes of one direction of an HTTP/1.1 connection and tracks the
/// framing of the message currently at the front of `buffer`.
struct HttpStream {
    std::string buffer;
    bool isReceivingBody = false;  // headers of the front message have been parsed
    bool isChunked = false;        // Transfer-Encoding lists "chunked"
    int contentLength = -1;        // -1 when absent/unparsable or when chunked
    size_t headersEnd = 0;         // offset of the CRLFCRLF ending the headers

    /// Clears the per-message framing state; `buffer` is left as-is.
    void reset() {
        isReceivingBody = false;
        isChunked = false;
        contentLength = -1;
        headersEnd = 0;
    }

    /// Parses request framing once the full header block is buffered.
    /// Returns false while the header block is still incomplete.
    bool parseHeaders() {
        if (isReceivingBody) return true;
        headersEnd = buffer.find("\r\n\r\n");
        if (headersEnd == std::string::npos) return false;
        const std::string headers = buffer.substr(0, headersEnd + 4);
        const std::string te = toLower(getHeaderValue(headers, "Transfer-Encoding"));
        isChunked = (te.find("chunked") != std::string::npos);
        if (!isChunked) {
            contentLength = safe_stoi(getHeaderValue(headers, "Content-Length"), -1);
        } else {
            contentLength = -1;
        }
        isReceivingBody = true;
        return true;
    }
};

}  // namespace

bool Proxy::initSSL() {
    _clientCtx = SSL_CTX_new(TLS_client_method());
    if (!_clientCtx) return false;
    // NOTE(review): upstream certificates are not verified, so anything between
    // the proxy and the origin can impersonate the origin. Consider
    // SSL_VERIFY_PEER with a trust store.
    SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
    // Offer only HTTP/1.1 upstream: handleClient parses HTTP/1.1 text framing.
    const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
    SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
    return true;
}

void Proxy::cleanupSSL() {
    if (_clientCtx) {
        SSL_CTX_free(_clientCtx);
        _clientCtx = nullptr;
    }
}

Proxy::Proxy() {}

Proxy::~Proxy() { Shutdown(); }

/// Starts Winsock, the certificate manager and the upstream TLS context, then
/// listens on 127.0.0.1:PROXY_PORT and starts the accept thread.
/// On failure, everything acquired here is released before returning false.
bool Proxy::Init() {
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;

    // Shutdown releases Winsock only for a running proxy, so a failed Init must
    // drop its own WSAStartup reference here.
    auto fail = [this]() {
        if (_listenSocket != INVALID_SOCKET) {
            closesocket(_listenSocket);
            _listenSocket = INVALID_SOCKET;
        }
        cleanupSSL();
        WSACleanup();
        return false;
    };

    if (!_certManager.Init()) return fail();
    if (!initSSL()) return fail();

    _listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (_listenSocket == INVALID_SOCKET) return fail();

    sockaddr_in serverAddr = {};
    serverAddr.sin_family = AF_INET;
    serverAddr.sin_port = htons(PROXY_PORT);
    inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
    if (bind(_listenSocket, reinterpret_cast<sockaddr*>(&serverAddr), sizeof(serverAddr)) == SOCKET_ERROR)
        return fail();
    if (listen(_listenSocket, SOMAXCONN) == SOCKET_ERROR) return fail();

    Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
    _running = true;
    _workerThread = std::thread(&Proxy::loop, this);
    return true;
}

/// Stops the accept thread and releases Winsock and the upstream TLS context.
/// Safe to call more than once (the destructor calls it too).
void Proxy::Shutdown() {
    // Only a successful Init leaves _running set; keying WSACleanup off it keeps
    // WSAStartup/WSACleanup balanced across repeated Shutdown calls.
    const bool wasRunning = _running;
    _running = false;
    if (_listenSocket != INVALID_SOCKET) {
        // Closing the listener unblocks accept() in loop().
        closesocket(_listenSocket);
        _listenSocket = INVALID_SOCKET;
    }
    if (_workerThread.joinable()) _workerThread.join();
    // NOTE(review): detached handleClient threads are not joined and may still
    // use `this`, _certManager and _clientCtx after this point.
    if (wasRunning) WSACleanup();
    cleanupSSL();
}

/// Accept loop: hands every accepted connection to a detached worker thread
/// until Shutdown clears _running.
void Proxy::loop() {
    while (_running) {
        SOCKET hClient = accept(_listenSocket, nullptr, nullptr);
        if (!_running) {
            if (hClient != INVALID_SOCKET) closesocket(hClient);
            break;
        }
        if (hClient == INVALID_SOCKET) continue;
        try {
            std::thread([this, hClient]() {
                // Exceptions (e.g. from OnClientRequest/OnServerResponse
                // handlers) must not escape a thread, or std::terminate runs.
                try {
                    this->handleClient(hClient);
                } catch (const std::exception& e) {
                    Log::verbose("Proxy connection handler failed: {}", e.what());
                } catch (...) {
                    Log::verbose("Proxy connection handler failed: {}", "unknown exception");
                }
            }).detach();
        } catch (const std::exception&) {
            // Thread creation failed, so no worker took ownership of the socket.
            closesocket(hClient);
        }
    }
}

/// Serves one client connection:
///   1. reads the plaintext "CONNECT host[:port]" request and connects upstream;
///   2. terminates the client's TLS with a certificate for `host` and opens TLS
///      to the origin;
///   3. relays HTTP/1.1 messages in both directions, handing each complete
///      request and response to OnClientRequest / OnServerResponse. After a
///      101 Switching Protocols, it relays raw bytes instead.
void Proxy::handleClient(SOCKET hClientSocket) {
    ScopedSocket clientGuard(hClientSocket);
    char buffer[32768];
    const int bufferSize = static_cast<int>(sizeof(buffer));

    const int bytesRead = recv(clientGuard.get(), buffer, bufferSize, 0);
    if (bytesRead <= 0) return;
    const std::string initialReq(buffer, static_cast<size_t>(bytesRead));
    if (initialReq.rfind("CONNECT ", 0) != 0) return;

    const size_t hostStart = 8;  // strlen("CONNECT ")
    const size_t hostEnd = initialReq.find(' ', hostStart);
    if (hostEnd == std::string::npos) return;
    const std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
    std::string host = fullHost;
    int port = 443;
    const size_t colon = fullHost.find(':');
    if (colon != std::string::npos) {
        host = fullHost.substr(0, colon);
        port = safe_stoi(fullHost.substr(colon + 1), 443);
    }

    // Resolve the origin (IPv4 only) and connect, trying each address in turn.
    addrinfo hints = {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    addrinfo* res = nullptr;
    if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
    ScopedSocket remoteGuard;
    for (const addrinfo* ai = res; ai != nullptr; ai = ai->ai_next) {
        remoteGuard.reset(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
        if (remoteGuard.get() == INVALID_SOCKET) break;
        if (connect(remoteGuard.get(), ai->ai_addr, static_cast<int>(ai->ai_addrlen)) == 0) break;
        remoteGuard.reset(INVALID_SOCKET);
    }
    freeaddrinfo(res);
    if (remoteGuard.get() == INVALID_SOCKET) return;

    static const char kEstablished[] = "HTTP/1.1 200 Connection Established\r\n\r\n";
    if (send(clientGuard.get(), kEstablished, static_cast<int>(sizeof(kEstablished) - 1), 0) == SOCKET_ERROR)
        return;

    // Client-facing TLS, using a context produced by the cert manager for `host`.
    // NOTE(review): ownership of hostCtx is not visible here; it is never freed.
    SSL_CTX* hostCtx = _certManager.CreateHostContext(host);
    if (!hostCtx) return;
    SSL* rawClientSSL = SSL_new(hostCtx);
    if (!rawClientSSL) return;
    ScopedSSL clientSSL(rawClientSSL);
    SSL_set_fd(clientSSL, static_cast<int>(clientGuard.get()));
    if (SSL_accept(clientSSL) <= 0) return;

    // Origin-facing TLS with SNI set to the CONNECT host.
    SSL* rawRemoteSSL = SSL_new(_clientCtx);
    if (!rawRemoteSSL) return;
    ScopedSSL remoteSSL(rawRemoteSSL);
    SSL_set_fd(remoteSSL, static_cast<int>(remoteGuard.get()));
    SSL_set_tlsext_host_name(remoteSSL, host.c_str());
    if (SSL_connect(remoteSSL) <= 0) return;

    // One entry per request forwarded upstream, matched FIFO against responses.
    struct PendingRequest {
        std::string url;
        bool isHead;
    };
    std::deque<PendingRequest> pending;
    // URL of the request whose body is still arriving. It is kept separately
    // because an early response may already have popped its `pending` entry.
    std::string currentRequestUrl;
    HttpStream clientStream, serverStream;
    bool tunnelMode = false;
    bool failed = false;

    while (_running && !failed) {
        // Don't sit in select() while OpenSSL already holds decrypted bytes.
        const bool buffered = SSL_pending(clientSSL) > 0 || SSL_pending(remoteSSL) > 0;
        fd_set readfds;
        FD_ZERO(&readfds);
        FD_SET(clientGuard.get(), &readfds);
        FD_SET(remoteGuard.get(), &readfds);
        timeval tv = {0, buffered ? 0 : 50000};
        // Winsock ignores select()'s first argument.
        if (select(0, &readfds, nullptr, nullptr, &tv) < 0) break;

        const bool clientReady = FD_ISSET(clientGuard.get(), &readfds) || SSL_pending(clientSSL) > 0;
        const bool remoteReady = FD_ISSET(remoteGuard.get(), &readfds) || SSL_pending(remoteSSL) > 0;

        if (tunnelMode) {
            if (clientReady && !relay(clientSSL, remoteSSL, buffer, bufferSize)) break;
            if (remoteReady && !relay(remoteSSL, clientSSL, buffer, bufferSize)) break;
            continue;
        }

        // ---- client -> origin ----
        if (clientReady) {
            const int n = SSL_read(clientSSL, buffer, bufferSize);
            if (n <= 0) break;
            clientStream.buffer.append(buffer, static_cast<size_t>(n));

            while (!clientStream.buffer.empty()) {
                if (!clientStream.isReceivingBody) {
                    if (!clientStream.parseHeaders()) break;
                    const std::string rawHeaders = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
                    const size_t s1 = rawHeaders.find(' ');
                    const size_t s2 = (s1 == std::string::npos) ? std::string::npos : rawHeaders.find(' ', s1 + 1);
                    currentRequestUrl = "https://" + host;
                    if (s2 != std::string::npos) currentRequestUrl += rawHeaders.substr(s1 + 1, s2 - s1 - 1);
                    const bool isHead = s1 != std::string::npos && rawHeaders.compare(0, s1, "HEAD") == 0;
                    pending.push_back(PendingRequest{currentRequestUrl, isHead});

                    const bool hasBody = clientStream.isChunked || clientStream.contentLength > 0;
                    if (!hasBody) {
                        std::string headers = prepareRequestHeaders(rawHeaders);
                        std::string url = currentRequestUrl;
                        std::string emptyBody;
                        OnClientRequest.run(url, emptyBody, headers);
                        if (!sslSend(remoteSSL, headers)) {
                            failed = true;
                            break;
                        }
                        clientStream.buffer.erase(0, clientStream.headersEnd + 4);
                        clientStream.reset();
                        continue;
                    }
                }

                // Body phase: wait until this request's body is fully buffered.
                const size_t bodyStart = clientStream.headersEnd + 4;
                std::string body;
                size_t bodyEnd = 0;
                if (clientStream.isChunked) {
                    const ChunkStatus status = decodeChunked(clientStream.buffer, bodyStart, body, bodyEnd);
                    if (status == ChunkStatus::Malformed) {
                        failed = true;
                        break;
                    }
                    if (status == ChunkStatus::Incomplete) break;
                } else {
                    // The no-body cases were handled above, so contentLength > 0 here.
                    const size_t length = static_cast<size_t>(clientStream.contentLength);
                    if (clientStream.buffer.size() - bodyStart < length) break;
                    body = clientStream.buffer.substr(bodyStart, length);
                    bodyEnd = bodyStart + length;
                }

                std::string headers = prepareRequestHeaders(clientStream.buffer.substr(0, bodyStart));
                std::string url = currentRequestUrl;
                OnClientRequest.run(url, body, headers);
                // Forward this request's body bytes with their original framing.
                // Bytes past bodyEnd belong to the next pipelined request and
                // are parsed on the next loop pass instead of being sent raw.
                // NOTE(review): edits a handler makes to `body` are not
                // forwarded (the raw bytes are) — confirm this is intended.
                if (!sslSend(remoteSSL, headers) ||
                    !sslSend(remoteSSL, clientStream.buffer.data() + bodyStart, bodyEnd - bodyStart)) {
                    failed = true;
                    break;
                }
                clientStream.buffer.erase(0, bodyEnd);
                clientStream.reset();
            }
            if (failed) break;
        }

        // ---- origin -> client ----
        if (remoteReady) {
            const int n = SSL_read(remoteSSL, buffer, bufferSize);
            const bool connectionClosed = (n <= 0);
            if (!connectionClosed) serverStream.buffer.append(buffer, static_cast<size_t>(n));

            while (!serverStream.buffer.empty() || connectionClosed) {
                if (!serverStream.isReceivingBody) {
                    if (connectionClosed) break;
                    serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n");
                    if (serverStream.headersEnd == std::string::npos) break;
                    const std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
                    const int sCode = statusCode(headers);

                    if (sCode == 101) {
                        // Protocol switch: forward everything buffered on both
                        // sides (client bytes sent after the Upgrade request
                        // included) and relay raw bytes from now on.
                        if (!sslSend(clientSSL, serverStream.buffer) || !sslSend(remoteSSL, clientStream.buffer)) {
                            failed = true;
                            break;
                        }
                        serverStream.buffer.clear();
                        clientStream.buffer.clear();
                        tunnelMode = true;
                        break;
                    }
                    if (sCode >= 100 && sCode < 200) {
                        // Interim response: pass through; the final one follows.
                        if (!sslSend(clientSSL, headers)) {
                            failed = true;
                            break;
                        }
                        serverStream.buffer.erase(0, serverStream.headersEnd + 4);
                        continue;
                    }

                    const bool isHeadResponse = !pending.empty() && pending.front().isHead;
                    const std::string te = toLower(getHeaderValue(headers, "Transfer-Encoding"));
                    serverStream.isChunked = (te.find("chunked") != std::string::npos);
                    if (isHeadResponse || sCode == 204 || sCode == 304) {
                        // RFC 9112 §6.3: these responses end at the blank line
                        // whatever Content-Length/Transfer-Encoding say;
                        // waiting for a body would hang.
                        serverStream.isChunked = false;
                        serverStream.contentLength = 0;
                    } else if (!serverStream.isChunked) {
                        serverStream.contentLength = safe_stoi(getHeaderValue(headers, "Content-Length"), -1);
                        if (serverStream.contentLength < 0 && sCode == 205) serverStream.contentLength = 0;
                    }
                    serverStream.isReceivingBody = true;
                }

                // Body phase: complete when framed fully or the origin hung up.
                const size_t bodyStart = serverStream.headersEnd + 4;
                std::string body;
                size_t bodyEnd = bodyStart;
                bool complete = false;
                if (serverStream.isChunked) {
                    const ChunkStatus status = decodeChunked(serverStream.buffer, bodyStart, body, bodyEnd);
                    if (status == ChunkStatus::Malformed) {
                        failed = true;
                        break;
                    }
                    complete = (status == ChunkStatus::Complete);
                    if (!complete && connectionClosed) {
                        // Deliver what was decoded before the origin closed.
                        complete = true;
                        bodyEnd = serverStream.buffer.size();
                    }
                } else if (serverStream.contentLength >= 0) {
                    const size_t length = static_cast<size_t>(serverStream.contentLength);
                    if (serverStream.buffer.size() - bodyStart >= length) {
                        complete = true;
                        body = serverStream.buffer.substr(bodyStart, length);
                        bodyEnd = bodyStart + length;
                    } else if (connectionClosed) {
                        complete = true;
                        body = serverStream.buffer.substr(bodyStart);
                        bodyEnd = serverStream.buffer.size();
                    }
                } else if (connectionClosed) {
                    // No framing headers: the body is delimited by connection close.
                    complete = true;
                    body = serverStream.buffer.substr(bodyStart);
                    bodyEnd = serverStream.buffer.size();
                }
                if (!complete) break;

                PendingRequest request{"https://" + host, false};
                if (!pending.empty()) {
                    request = pending.front();
                    pending.pop_front();
                }
                std::string respHeaders = serverStream.buffer.substr(0, bodyStart);
                OnServerResponse.run(request.url, body, respHeaders);

                std::string packet;
                if (request.isHead) {
                    // A HEAD response's framing headers describe the would-be
                    // GET body; forward them untouched and never attach a body.
                    packet = respHeaders;
                } else {
                    // `body` is de-chunked and possibly edited by the handler,
                    // so re-frame it with an explicit Content-Length.
                    removeHeader(respHeaders, "Transfer-Encoding");
                    removeHeader(respHeaders, "Content-Length");
                    const int scFinal = statusCode(respHeaders);
                    if (scFinal != 204 && scFinal != 304 && scFinal != 205)
                        respHeaders.insert(respHeaders.size() - 2,
                                           "Content-Length: " + std::to_string(body.size()) + "\r\n");
                    packet = respHeaders + body;
                }
                if (!sslSend(clientSSL, packet)) {
                    failed = true;
                    break;
                }
                serverStream.buffer.erase(0, bodyEnd);
                serverStream.reset();
                // clientStream is deliberately NOT reset here: a response can
                // arrive while the client is still uploading a request body.
            }
            if (failed || connectionClosed) break;
        }
    }
}