#include "proxy.h"

/*
 * NOTE(review): the original system include targets were lost (only bare
 * `#include` tokens survived). The list below is reconstructed from the names
 * this file uses; confirm against the original and against what proxy.h
 * already provides.
 */
#include <winsock2.h>
#include <ws2tcpip.h>

#include <openssl/ssl.h>

#include <algorithm>
#include <cctype>
#include <deque>
#include <memory>
#include <string>
#include <thread>

/* memory helpers */

/*
 * Stateless deleter for std::unique_ptr that hands a non-null pointer to the
 * free function `f`.
 */
template <typename T, void (*f)(T*)>
struct Deleter {
    void operator()(T* p) const {
        if (p) f(p);
    }
};

using SSL_ptr = std::unique_ptr<SSL, Deleter<SSL, SSL_free>>;

/*
 * Owns a SOCKET and closes it on destruction. Copying is disabled because a
 * copy would close the same handle twice.
 */
struct AutoSocket {
    SOCKET s;

    AutoSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
    ~AutoSocket() {
        if (s != INVALID_SOCKET) closesocket(s);
    }

    AutoSocket(const AutoSocket&) = delete;
    AutoSocket& operator=(const AutoSocket&) = delete;

    operator SOCKET() const { return s; }
};

/* helper functions */

/*
 * Parses `s` as an integer in the given base via std::stoi.
 * Returns `default_val` when `s` is empty or std::stoi throws (not a number,
 * out of range).
 */
int stoiSafe(const std::string& s, int default_val = 0, int base = 10) {
    if (s.empty()) return default_val;
    try {
        return std::stoi(s, nullptr, base);
    } catch (...) {
        return default_val;
    }
}

namespace {

/* Lowercases one character; the unsigned char cast keeps std::tolower defined for bytes >= 0x80. */
char lowerChar(char c) {
    return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
}

/* Returns a lowercased copy of `text`. */
std::string toLower(std::string text) {
    std::transform(text.begin(), text.end(), text.begin(), lowerChar);
    return text;
}

/*
 * Turns a header name into the lowercase "name:" prefix used for matching.
 * Returns an empty string for an empty key.
 */
std::string headerPrefix(const std::string& key) {
    std::string prefix = toLower(key);
    if (!prefix.empty() && prefix.back() != ':') prefix += ':';
    return prefix;
}

/*
 * True when the `len` characters of `text` starting at `pos` begin with
 * `lowerPrefix`, ignoring case. `lowerPrefix` must already be lowercase.
 */
bool lineStartsWith(const std::string& text, size_t pos, size_t len, const std::string& lowerPrefix) {
    if (len < lowerPrefix.size()) return false;
    return std::equal(lowerPrefix.begin(), lowerPrefix.end(), text.begin() + pos,
                      [](char expected, char actual) { return expected == lowerChar(actual); });
}

}  // namespace

/*
 * Removes every line of `headers` whose name matches `key` (case-insensitive,
 * trailing ':' optional in `key`). Lines are split on '\n'; a final line
 * without '\n' is matched the same way as the others. An empty `key` leaves
 * `headers` untouched.
 */
void removeHeader(std::string& headers, const std::string& key) {
    const std::string prefix = headerPrefix(key);
    if (prefix.empty()) return;

    std::string kept;
    kept.reserve(headers.size());

    size_t pos = 0;
    while (pos < headers.size()) {
        const size_t newline = headers.find('\n', pos);
        // The line includes its terminating '\n' when there is one.
        const size_t next = (newline == std::string::npos) ? headers.size() : newline + 1;
        if (!lineStartsWith(headers, pos, next - pos, prefix)) kept.append(headers, pos, next - pos);
        pos = next;
    }
    headers = std::move(kept);
}

/*
 * Returns the value of the first header line whose name matches `key`
 * (case-insensitive, trailing ':' optional in `key`). Leading spaces/tabs and
 * one trailing '\r' are stripped from the value. Returns "" when the header is
 * absent or `key` is empty.
 */
std::string getHeaderValue(const std::string& headers, const std::string& key) {
    const std::string prefix = headerPrefix(key);
    if (prefix.empty()) return "";

    size_t pos = 0;
    while (pos < headers.size()) {
        const size_t newline = headers.find('\n', pos);
        const size_t lineEnd = (newline == std::string::npos) ? headers.size() : newline;

        if (lineStartsWith(headers, pos, lineEnd - pos, prefix)) {
            size_t valueStart = pos + prefix.size();
            while (valueStart < lineEnd && (headers[valueStart] == ' ' || headers[valueStart] == '\t')) ++valueStart;
            size_t valueEnd = lineEnd;
            if (valueEnd > valueStart && headers[valueEnd - 1] == '\r') --valueEnd;
            return headers.substr(valueStart, valueEnd - valueStart);
        }

        if (newline == std::string::npos) break;
        pos = newline + 1;
    }
    return "";
}

/* http stream helper class */

/*
 * Accumulates bytes of one direction of an HTTP/1.x connection and records the
 * framing (chunked / Content-Length) of the message at the front of `buffer`.
 */
struct HttpStream {
    std::string buffer;                        // unconsumed bytes, oldest message first
    bool isReceivingBody = false;              // headers of the front message have been parsed
    bool isChunked = false;                    // Transfer-Encoding contains "chunked"
    int contentLength = -1;                    // Content-Length value, -1 when absent/unparseable
    size_t headersEnd = std::string::npos;     // offset of the "\r\n\r\n" ending the headers

    /* Clears the per-message state; `buffer` is deliberately left as is. */
    void reset() {
        isReceivingBody = false;
        isChunked = false;
        contentLength = -1;
        headersEnd = std::string::npos;
    }

    /*
     * Looks for a complete header block at the start of `buffer`.
     * Returns false when "\r\n\r\n" has not arrived yet; otherwise fills in
     * the framing fields and returns true.
     */
    bool parseHeaders() {
        headersEnd = buffer.find("\r\n\r\n");
        if (headersEnd == std::string::npos) return false;

        const std::string headers = buffer.substr(0, headersEnd + 4);
        const std::string te = toLower(getHeaderValue(headers, "Transfer-Encoding"));
        isChunked = (te.find("chunked") != std::string::npos);
        if (!isChunked) contentLength = stoiSafe(getHeaderValue(headers, "Content-Length"), -1);

        isReceivingBody = true;
        return true;
    }
};

namespace {

enum class BodyState { Incomplete, Complete, Malformed };

/*
 * Decodes a chunked body that starts at `bodyStart` in `buf`.
 * On Complete, `body` holds the concatenated chunk data and `messageEnd` is the
 * offset just past the terminating zero-size chunk ("0\r\n\r\n"; trailers are
 * not supported). On Incomplete more bytes are needed and the outputs must be
 * ignored. Malformed means a chunk-size line could not be parsed as hex.
 */
BodyState decodeChunkedBody(const std::string& buf, size_t bodyStart, std::string& body, size_t& messageEnd) {
    body.clear();
    size_t idx = bodyStart;
    while (idx < buf.size()) {
        const size_t lineEnd = buf.find("\r\n", idx);
        if (lineEnd == std::string::npos) return BodyState::Incomplete;

        const int chunkSize = stoiSafe(buf.substr(idx, lineEnd - idx), -1, 16);
        if (chunkSize < 0) return BodyState::Malformed;

        const size_t dataStart = lineEnd + 2;
        const size_t next = dataStart + static_cast<size_t>(chunkSize) + 2;  // data + trailing CRLF
        if (next > buf.size()) return BodyState::Incomplete;

        if (chunkSize == 0) {
            messageEnd = next;
            return BodyState::Complete;
        }
        body.append(buf, dataStart, static_cast<size_t>(chunkSize));
        idx = next;
    }
    return BodyState::Incomplete;
}

/*
 * Extracts the three-digit status code following the first space of `head`.
 * Returns 200 when there is no space before `limit`, and 0 when the digits do
 * not parse.
 */
int parseStatusCode(const std::string& head, size_t limit = std::string::npos) {
    const size_t space = head.find(' ');
    if (space == std::string::npos || space >= limit) return 200;
    return stoiSafe(head.substr(space + 1, 3));
}

/* 1xx responses are interim: the final response for the same request still follows. */
bool isInterimStatus(int status) {
    return status >= 100 && status < 200;
}

/* 1xx, 204 and 304 responses never carry a message body, whatever their headers say. */
bool isBodilessStatus(int status) {
    return status == 204 || status == 304 || isInterimStatus(status);
}

/*
 * Inserts `field` (which must end in "\r\n") just before the blank line that
 * terminates the header block. Falls back to appending when `headers` is too
 * short to contain that blank line.
 */
void appendHeaderField(std::string& headers, const std::string& field) {
    const size_t at = headers.size() >= 2 ? headers.size() - 2 : headers.size();
    headers.insert(at, field);
}

/* Replaces any Transfer-Encoding / Content-Length framing with an exact Content-Length. */
void setContentLength(std::string& headers, size_t bodySize) {
    removeHeader(headers, "Transfer-Encoding");
    removeHeader(headers, "Content-Length");
    appendHeaderField(headers, "Content-Length: " + std::to_string(bodySize) + "\r\n");
}

/*
 * Writes `data` to `ssl`. Empty data is skipped because SSL_write with a zero
 * length is documented as undefined. Returns false when SSL_write fails.
 */
bool sslSend(SSL* ssl, const std::string& data) {
    if (data.empty()) return true;
    return SSL_write(ssl, data.data(), static_cast<int>(data.size())) > 0;
}

}  // namespace

/* proxy impl */

Proxy::Proxy() {}

Proxy::~Proxy() { shutdown(); }

/*
 * Initialises the certificate manager, the upstream SSL context and Winsock,
 * binds the listening socket to 127.0.0.1:PROXY_PORT and starts the accept
 * thread. Returns false on any failure, releasing whatever was set up so far
 * (shutdown() does nothing until _running is set).
 */
bool Proxy::init() {
    if (!_certManager.init()) return false;
    if (!initSSL()) return false;

    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        cleanupSSL();
        return false;
    }

    // Only used after WSAStartup succeeded and _listenSocket has been assigned.
    auto fail = [this]() {
        if (_listenSocket != INVALID_SOCKET) {
            closesocket(_listenSocket);
            _listenSocket = INVALID_SOCKET;
        }
        WSACleanup();
        cleanupSSL();
        return false;
    };

    _listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (_listenSocket == INVALID_SOCKET) return fail();

    sockaddr_in serverAddr = {};
    serverAddr.sin_family = AF_INET;
    serverAddr.sin_port = htons(PROXY_PORT);
    inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);

    if (bind(_listenSocket, reinterpret_cast<sockaddr*>(&serverAddr), sizeof(serverAddr)) == SOCKET_ERROR) return fail();
    if (listen(_listenSocket, SOMAXCONN) == SOCKET_ERROR) return fail();

    _running = true;
    _workerThread = std::thread(&Proxy::loop, this);
    Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
    return true;
}

/*
 * Stops the accept thread and releases Winsock and the SSL context.
 * Does nothing when the proxy is not running.
 */
void Proxy::shutdown() {
    if (!_running) return;
    _running = false;

    // Closing the listening socket makes the pending accept() in loop() return.
    if (_listenSocket != INVALID_SOCKET) {
        closesocket(_listenSocket);
        _listenSocket = INVALID_SOCKET;
    }
    if (_workerThread.joinable()) _workerThread.join();

    WSACleanup();
    cleanupSSL();
}

/* Accept loop: hands every accepted connection to handleClient on its own detached thread. */
void Proxy::loop() {
    while (_running) {
        SOCKET client = accept(_listenSocket, nullptr, nullptr);
        if (!_running) {
            if (client != INVALID_SOCKET) closesocket(client);
            break;
        }
        if (client == INVALID_SOCKET) continue;

        // NOTE(review): the detached thread captures `this` and is not joined by
        // shutdown(); confirm the Proxy object outlives all client threads.
        std::thread([this, client]() { this->handleClient(client); }).detach();
    }
}

/*
 * Serves one client connection (takes ownership of `clientSocket`):
 *  1. expects "CONNECT host[:port] ..." and connects to that host,
 *  2. performs a TLS handshake with the client using a context from
 *     _certManager and with the remote host using _clientCtx,
 *  3. relays HTTP/1.x messages in both directions, buffering each message
 *     completely so OnClientRequest / OnServerResponse see the whole body.
 * Returns when either side closes, a write fails, a chunk header is malformed,
 * or the proxy stops running.
 */
void Proxy::handleClient(SOCKET clientSocket) {
    AutoSocket clientGuard(clientSocket);
    char buffer[32768];

    const int bytesRead = recv(clientGuard.s, buffer, static_cast<int>(sizeof(buffer) - 1), 0);
    if (bytesRead <= 0) return;
    buffer[bytesRead] = '\0';

    const std::string initialReq(buffer);
    if (initialReq.compare(0, 8, "CONNECT ") != 0) return;

    /* host info */
    const size_t hostStart = 8;
    const size_t hostEnd = initialReq.find(' ', hostStart);
    if (hostEnd == std::string::npos) return;

    const std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
    std::string host = fullHost;
    int port = 443;
    const size_t colon = fullHost.find(':');
    if (colon != std::string::npos) {
        host = fullHost.substr(0, colon);
        port = stoiSafe(fullHost.substr(colon + 1), 443);
    }

    addrinfo hints = {};
    addrinfo* res = nullptr;
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0 || res == nullptr) return;

    /* establish connection to host */
    AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
    const bool connected = remoteGuard.s != INVALID_SOCKET &&
                           connect(remoteGuard.s, res->ai_addr, static_cast<int>(res->ai_addrlen)) == 0;
    freeaddrinfo(res);
    if (!connected) return;

    static const char kEstablished[] = "HTTP/1.1 200 Connection Established\r\n\r\n";
    if (send(clientGuard.s, kEstablished, static_cast<int>(sizeof(kEstablished) - 1), 0) == SOCKET_ERROR) return;

    /* SSL */
    // NOTE(review): ownership of hostCtx is not visible here; it is never freed
    // in this function, presumably because _certManager keeps it. Confirm.
    SSL_CTX* hostCtx = _certManager.createHostContext(host);
    if (!hostCtx) return;

    // Runs SSL_accept / SSL_connect until it succeeds or fails hard, waiting up
    // to one second on the socket whenever OpenSSL asks for more I/O.
    auto sslHandshake = [](SSL* ssl, bool isAccept, SOCKET s) -> bool {
        while (true) {
            const int ret = isAccept ? SSL_accept(ssl) : SSL_connect(ssl);
            if (ret > 0) return true;

            const int err = SSL_get_error(ssl, ret);
            if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) return false;

            fd_set fds;
            FD_ZERO(&fds);
            FD_SET(s, &fds);
            timeval tv = {1, 0};
            if (err == SSL_ERROR_WANT_READ)
                select(0, &fds, nullptr, nullptr, &tv);
            else
                select(0, nullptr, &fds, nullptr, &tv);
        }
    };

    SSL_ptr clientSSL(SSL_new(hostCtx));
    if (!clientSSL) return;
    SSL_set_fd(clientSSL.get(), static_cast<int>(clientGuard.s));
    if (!sslHandshake(clientSSL.get(), true, clientGuard.s)) return;

    SSL_ptr remoteSSL(SSL_new(_clientCtx));
    if (!remoteSSL) return;
    SSL_set_fd(remoteSSL.get(), static_cast<int>(remoteGuard.s));
    SSL_set_tlsext_host_name(remoteSSL.get(), host.c_str());
    if (!sslHandshake(remoteSSL.get(), false, remoteGuard.s)) return;

    /* traffic handler */
    HttpStream clientStream, serverStream;
    std::deque<std::string> pendingUrls;  // URLs of requests whose response has not been relayed yet
    // NOTE(review): nothing in this function ever sets tunnelMode to true, so
    // the tunnel branches below are currently unreachable. Kept as found.
    bool tunnelMode = false;

    while (_running) {
        fd_set readfds;
        FD_ZERO(&readfds);
        FD_SET(clientGuard.s, &readfds);
        FD_SET(remoteGuard.s, &readfds);
        timeval tv = {0, 50000};
        if (select(0, &readfds, nullptr, nullptr, &tv) < 0) break;

        /* client -> server */
        if (FD_ISSET(clientGuard.s, &readfds) || SSL_pending(clientSSL.get()) > 0) {
            const int n = SSL_read(clientSSL.get(), buffer, static_cast<int>(sizeof(buffer)));
            if (n <= 0) break;

            if (tunnelMode) {
                if (SSL_write(remoteSSL.get(), buffer, n) <= 0) break;
            } else {
                clientStream.buffer.append(buffer, static_cast<size_t>(n));

                // Forward every complete request currently sitting in the buffer.
                while (true) {
                    if (!clientStream.isReceivingBody) {
                        if (!clientStream.parseHeaders()) break;

                        // The request target is the text between the first two spaces.
                        const std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
                        std::string url = "https://" + host;
                        const size_t s1 = headers.find(' ');
                        const size_t s2 = headers.find(' ', s1 + 1);
                        if (s1 != std::string::npos && s2 != std::string::npos)
                            url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
                        pendingUrls.push_back(url);
                    }

                    const size_t bodyStart = clientStream.headersEnd + 4;
                    bool complete = false;
                    std::string fullBody;
                    size_t totalRequestSize = 0;

                    if (clientStream.isChunked) {
                        const BodyState state = decodeChunkedBody(clientStream.buffer, bodyStart, fullBody, totalRequestSize);
                        if (state == BodyState::Malformed) return;
                        complete = (state == BodyState::Complete);
                    } else {
                        // A missing or negative Content-Length means "no body" for a request.
                        const size_t cl = clientStream.contentLength > 0 ? static_cast<size_t>(clientStream.contentLength) : 0;
                        if (clientStream.buffer.size() >= bodyStart + cl) {
                            fullBody = clientStream.buffer.substr(bodyStart, cl);
                            complete = true;
                            totalRequestSize = bodyStart + cl;
                        }
                    }
                    if (!complete) break;

                    std::string url = pendingUrls.back();
                    std::string headers = clientStream.buffer.substr(0, bodyStart);

                    // Ask upstream for an uncompressed response.
                    removeHeader(headers, "Accept-Encoding");
                    appendHeaderField(headers, "Accept-Encoding: identity\r\n");

                    OnClientRequest.run(url, fullBody, headers);

                    // The body forwarded below is the decoded payload, so the framing
                    // headers must describe it: chunked framing is gone and the length
                    // is taken from the body actually sent. Requests that had neither
                    // framing header nor a body are left untouched.
                    if (clientStream.isChunked || clientStream.contentLength >= 0 || !fullBody.empty())
                        setContentLength(headers, fullBody.size());

                    if (!sslSend(remoteSSL.get(), headers) || !sslSend(remoteSSL.get(), fullBody)) return;

                    clientStream.buffer.erase(0, totalRequestSize);
                    clientStream.reset();
                }
            }
        }

        /* server -> client */
        if (FD_ISSET(remoteGuard.s, &readfds) || SSL_pending(remoteSSL.get()) > 0) {
            const int n = SSL_read(remoteSSL.get(), buffer, static_cast<int>(sizeof(buffer)));
            const bool closed = (n <= 0);
            if (!closed) serverStream.buffer.append(buffer, static_cast<size_t>(n));

            // Relay every complete response currently sitting in the buffer.
            while (true) {
                if (!serverStream.isReceivingBody && !serverStream.parseHeaders()) break;

                const size_t bodyStart = serverStream.headersEnd + 4;
                const int wireStatus = parseStatusCode(serverStream.buffer, serverStream.headersEnd);

                bool complete = false;
                std::string fullBody;
                size_t totalResponseSize = 0;

                if (isBodilessStatus(wireStatus)) {
                    // These responses end at the blank line; waiting for a length or
                    // for the connection to close would stall keep-alive connections.
                    // NOTE(review): responses to HEAD requests are also body-less but
                    // the request method is not tracked here.
                    complete = true;
                    totalResponseSize = bodyStart;
                } else if (serverStream.isChunked) {
                    const BodyState state = decodeChunkedBody(serverStream.buffer, bodyStart, fullBody, totalResponseSize);
                    if (state == BodyState::Malformed) return;
                    complete = (state == BodyState::Complete);
                } else if (serverStream.contentLength >= 0) {
                    const size_t cl = static_cast<size_t>(serverStream.contentLength);
                    if (serverStream.buffer.size() >= bodyStart + cl) {
                        fullBody = serverStream.buffer.substr(bodyStart, cl);
                        complete = true;
                        totalResponseSize = bodyStart + cl;
                    }
                } else if (closed) {
                    // No framing information: the body runs until the server closes.
                    fullBody = serverStream.buffer.substr(bodyStart);
                    complete = true;
                    totalResponseSize = serverStream.buffer.size();
                }
                if (!complete) break;

                const bool haveUrl = !pendingUrls.empty();
                std::string url = haveUrl ? pendingUrls.front() : "https://" + host;
                // An interim 1xx response does not finish its request, so the URL
                // stays queued for the final response.
                if (haveUrl && !isInterimStatus(wireStatus)) pendingUrls.pop_front();

                std::string respHeaders = serverStream.buffer.substr(0, bodyStart);
                OnServerResponse.run(url, fullBody, respHeaders);

                // Re-read the status after the hook ran, then re-frame the (possibly
                // modified) body unless the response is passed through untouched.
                const int statusCode = parseStatusCode(respHeaders);
                if (!tunnelMode && !isBodilessStatus(statusCode)) setContentLength(respHeaders, fullBody.size());

                if (!sslSend(clientSSL.get(), respHeaders) || !sslSend(clientSSL.get(), fullBody)) return;
                serverStream.buffer.erase(0, totalResponseSize);

                if (tunnelMode) {
                    // Flush whatever was buffered in either direction before tunnelling.
                    if (!sslSend(clientSSL.get(), serverStream.buffer)) return;
                    serverStream.buffer.clear();
                    if (!sslSend(remoteSSL.get(), clientStream.buffer)) return;
                    clientStream.buffer.clear();
                }

                serverStream.reset();
                if (tunnelMode) break;
            }
            if (closed) break;
        }
    }
}

/*
 * Creates the SSL context used for connections to remote hosts and advertises
 * "http/1.1" via ALPN. Returns false when the context cannot be created.
 */
bool Proxy::initSSL() {
    _clientCtx = SSL_CTX_new(TLS_client_method());
    if (!_clientCtx) return false;

    // NOTE(review): upstream certificates are not verified (SSL_VERIFY_NONE).
    // Confirm this is intentional.
    SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);

    // ALPN wire format: one length byte followed by the protocol name.
    const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
    SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
    return true;
}

/* Frees the upstream SSL context if one exists. */
void Proxy::cleanupSSL() {
    if (!_clientCtx) return;
    SSL_CTX_free(_clientCtx);
    _clientCtx = nullptr;
}