fix: refactor everything, fix everything

This commit is contained in:
2026-03-20 15:19:19 -03:00
parent 74e3087295
commit 013755bd15
3 changed files with 473 additions and 505 deletions
+1 -1
View File
@@ -86,7 +86,7 @@ add_custom_command(TARGET dbd-unlocker POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different COMMAND ${CMAKE_COMMAND} -E copy_if_different
${JSON_RES_FILES} ${JSON_RES_FILES}
"$<TARGET_FILE_DIR:dbd-unlocker>/" "$<TARGET_FILE_DIR:dbd-unlocker>/"
COMMENT "Copying JSON resources to executable directory" COMMENT "copying json sources"
) )
# ------------------------------ # ------------------------------
+404 -437
View File
@@ -5,16 +5,156 @@
#include <nerutils/log.h> #include <nerutils/log.h>
#include <deque> #include <deque>
#include <algorithm>
#include <atomic>
#include <string>
namespace
{
class ScopedSSL
{
SSL* s;
public:
ScopedSSL(SSL* val = nullptr) : s(val) {}
~ScopedSSL()
{
if (s) SSL_free(s);
}
void reset(SSL* val)
{
if (s) SSL_free(s);
s = val;
}
operator SSL*() const { return s; }
};
class ScopedSocket
{
SOCKET s;
public:
ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
~ScopedSocket()
{
if (s != INVALID_SOCKET) closesocket(s);
}
void reset(SOCKET val)
{
if (s != INVALID_SOCKET) closesocket(s);
s = val;
}
operator SOCKET() const { return s; }
SOCKET get() const { return s; }
};
int safe_stoi(const std::string& s, int default_val = 0, int base = 10)
{
if (s.empty()) return default_val;
try
{
return std::stoi(s, nullptr, base);
}
catch (...)
{
return default_val;
}
}
void removeHeader(std::string& h, const std::string& key)
{
std::string h_lower = h;
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
std::string key_lower = key;
std::transform(key_lower.begin(), key_lower.end(), key_lower.begin(), ::tolower);
size_t pos = 0;
while ((pos = h_lower.find(key_lower, pos)) != std::string::npos)
{
if (pos == 0 || h_lower[pos - 1] == '\n')
{
size_t end = h_lower.find("\r\n", pos);
if (end != std::string::npos)
{
h.erase(pos, end - pos + 2);
h_lower.erase(pos, end - pos + 2);
continue;
}
}
pos++;
}
}
std::string getHeaderValue(const std::string& headers, const std::string& key)
{
std::string h_lower = headers;
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
std::string search = "\n" + key;
std::transform(search.begin(), search.end(), search.begin(), ::tolower);
search += ":";
size_t pos = h_lower.find(search);
if (pos == std::string::npos) return ""; // Not found
size_t vStart = pos + search.length();
while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t'))
vStart++;
size_t vEnd = headers.find("\r\n", vStart);
if (vEnd == std::string::npos) return "";
return headers.substr(vStart, vEnd - vStart);
}
struct HttpStream
{
std::string buffer;
bool isReceivingBody = false;
bool isChunked = false;
int contentLength = -1;
size_t headersEnd = 0;
void reset()
{
isReceivingBody = false;
isChunked = false;
contentLength = -1;
headersEnd = 0;
}
bool parseHeaders()
{
if (isReceivingBody) return true;
headersEnd = buffer.find("\r\n\r\n");
if (headersEnd == std::string::npos) return false;
std::string headers = buffer.substr(0, headersEnd + 4);
std::string te = getHeaderValue(headers, "Transfer-Encoding");
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
isChunked = (te.find("chunked") != std::string::npos);
if (!isChunked)
{
std::string cl = getHeaderValue(headers, "Content-Length");
contentLength = safe_stoi(cl, -1);
}
else
{
contentLength = -1;
}
isReceivingBody = true;
return true;
}
};
} // namespace
bool Proxy::initSSL() bool Proxy::initSSL()
{ {
_clientCtx = SSL_CTX_new(TLS_client_method()); _clientCtx = SSL_CTX_new(TLS_client_method());
if (!_clientCtx) if (!_clientCtx) return false;
{
Log::error("Failed to create client SSL context");
return false;
}
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr); SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
return true; return true;
} }
@@ -28,71 +168,43 @@ void Proxy::cleanupSSL()
} }
Proxy::Proxy() {} Proxy::Proxy() {}
Proxy::~Proxy()
Proxy::~Proxy() {} {
Shutdown();
}
bool Proxy::Init() bool Proxy::Init()
{ {
WSADATA wsaData; WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;
{ if (!_certManager.Init()) return false;
Log::error("WSAStartup failed"); if (!initSSL()) return false;
return false;
}
if (!_certManager.Init())
{
return false;
}
if (!initSSL())
{
return false;
}
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); _listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (_listenSocket == INVALID_SOCKET) sockaddr_in serverAddr = {};
{
Log::error("Error creating listen socket: {0:x}", WSAGetLastError());
Shutdown();
return false;
}
sockaddr_in serverAddr;
serverAddr.sin_family = AF_INET; serverAddr.sin_family = AF_INET;
serverAddr.sin_port = htons(PROXY_PORT); serverAddr.sin_port = htons(PROXY_PORT);
inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr); inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false;
{
Log::error("Error binding listen socket: {0:x}", WSAGetLastError());
Shutdown();
return false;
}
listen(_listenSocket, SOMAXCONN); listen(_listenSocket, SOMAXCONN);
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
Log::verbose("Listening on 127.0.0.1:{}", PROXY_PORT);
_running = true; _running = true;
_workerThread = std::thread(&Proxy::loop, this); _workerThread = std::thread(&Proxy::loop, this);
return true; return true;
} }
void Proxy::Shutdown() void Proxy::Shutdown()
{ {
_running = false; _running = false;
if (_listenSocket != INVALID_SOCKET)
if (_listenSocket != INVALID_SOCKET && _listenSocket != 0)
{ {
closesocket(_listenSocket); closesocket(_listenSocket);
_listenSocket = 0; _listenSocket = INVALID_SOCKET;
} }
if (_workerThread.joinable()) _workerThread.join(); if (_workerThread.joinable()) _workerThread.join();
WSACleanup(); WSACleanup();
cleanupSSL(); cleanupSSL();
} }
@@ -100,500 +212,355 @@ void Proxy::loop()
{ {
while (_running) while (_running)
{ {
SOCKET clientSocket = accept(_listenSocket, NULL, NULL); SOCKET hClient = accept(_listenSocket, NULL, NULL);
if (!_running) if (!_running)
{ {
if (clientSocket != INVALID_SOCKET) closesocket(clientSocket); if (hClient != INVALID_SOCKET) closesocket(hClient);
break; break;
} }
if (hClient == INVALID_SOCKET) continue;
if (clientSocket == INVALID_SOCKET) continue; std::thread([this, hClient]() { this->handleClient(hClient); }).detach();
std::thread([this, clientSocket]() { this->handleClient(clientSocket); }).detach();
} }
} }
void Proxy::handleClient(SOCKET clientSocket) void Proxy::handleClient(SOCKET hClientSocket)
{ {
char buffer[8192]; ScopedSocket clientGuard(hClientSocket);
int bytesReceived = recv(clientSocket, buffer, sizeof(buffer) - 1, 0); char buffer[32768];
if (bytesReceived <= 0) int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0);
{ if (bytesRead <= 0) return;
closesocket(clientSocket); buffer[bytesRead] = '\0';
return;
}
buffer[bytesReceived] = '\0'; std::string initialReq(buffer);
std::string request(buffer, bytesReceived); if (initialReq.find("CONNECT ") != 0) return;
std::string method, url; size_t hostStart = 8;
size_t space1 = request.find(' '); size_t hostEnd = initialReq.find(' ', hostStart);
if (space1 != std::string::npos) if (hostEnd == std::string::npos) return;
{
method = request.substr(0, space1);
size_t space2 = request.find(' ', space1 + 1);
if (space2 != std::string::npos)
{
url = request.substr(space1 + 1, space2 - space1 - 1);
}
}
if (method.empty() || url.empty()) std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
{ std::string host = fullHost;
closesocket(clientSocket); int port = 443;
return; size_t colon = fullHost.find(':');
}
std::string host;
std::string port = "80";
bool isConnect = (method == "CONNECT");
if (isConnect)
{
size_t colon = url.find(':');
if (colon != std::string::npos) if (colon != std::string::npos)
{ {
host = url.substr(0, colon); host = fullHost.substr(0, colon);
port = url.substr(colon + 1); port = safe_stoi(fullHost.substr(colon + 1), 443);
}
else
{
host = url;
port = "443";
}
}
else
{
size_t hostPos = request.find("Host: ");
if (hostPos != std::string::npos)
{
size_t endPos = request.find("\r\n", hostPos);
if (endPos != std::string::npos)
{
std::string hostHeader = request.substr(hostPos + 6, endPos - (hostPos + 6));
size_t colon = hostHeader.find(':');
if (colon != std::string::npos)
{
host = hostHeader.substr(0, colon);
port = hostHeader.substr(colon + 1);
}
else
{
host = hostHeader;
}
}
}
} }
if (host.empty()) struct addrinfo hints = {}, *res = nullptr;
{
closesocket(clientSocket);
return;
}
struct addrinfo hints = {}, *res;
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = SOCK_STREAM;
if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
if (getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0) ScopedSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0)
{ {
Log::error("Could not resolve host: {}", host);
closesocket(clientSocket);
return;
}
SOCKET remoteSocket = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (connect(remoteSocket, res->ai_addr, (int)res->ai_addrlen) == SOCKET_ERROR)
{
Log::error("Connection to {}:{} failed", host, port);
freeaddrinfo(res); freeaddrinfo(res);
closesocket(clientSocket);
return; return;
} }
freeaddrinfo(res); freeaddrinfo(res);
if (isConnect) send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0);
{
const char* reply = "HTTP/1.1 200 Connection Established\r\n\r\n";
send(clientSocket, reply, static_cast<int>(strlen(reply)), 0);
SSL_CTX* serverCtx = _certManager.CreateHostContext(host); SSL_CTX* hostCtx = _certManager.CreateHostContext(host);
if (!serverCtx) if (!hostCtx) return;
{
Log::error("Failed to generate dynamic cert for {}", host);
closesocket(clientSocket);
closesocket(remoteSocket);
return;
}
SSL* clientSSL = SSL_new(serverCtx); ScopedSSL clientSSL(SSL_new(hostCtx));
SSL_set_fd(clientSSL, static_cast<int>(clientSocket)); SSL_set_fd(clientSSL, (int)clientGuard.get());
if (SSL_accept(clientSSL) <= 0) return;
if (SSL_accept(clientSSL) <= 0) ScopedSSL remoteSSL(SSL_new(_clientCtx));
{ SSL_set_fd(remoteSSL, (int)remoteGuard.get());
Log::error("SSL_accept failed on client");
SSL_free(clientSSL);
closesocket(clientSocket);
closesocket(remoteSocket);
return;
}
SSL* remoteSSL = SSL_new(_clientCtx);
SSL_set_fd(remoteSSL, static_cast<int>(remoteSocket));
SSL_set_tlsext_host_name(remoteSSL, host.c_str()); SSL_set_tlsext_host_name(remoteSSL, host.c_str());
if (SSL_connect(remoteSSL) <= 0) return;
if (SSL_connect(remoteSSL) <= 0) HttpStream clientStream, serverStream;
{
Log::error("SSL_connect failed on remote server");
SSL_free(remoteSSL);
SSL_free(clientSSL);
closesocket(clientSocket);
closesocket(remoteSocket);
return;
}
std::deque<std::string> pendingUrls; std::deque<std::string> pendingUrls;
std::string serverBuffer; bool tunnelMode = false;
bool isReceivingBody = false;
int expectedLength = -1;
bool isChunked = false;
size_t headersEnd = 0;
fd_set readfds; fd_set readfds;
while (_running) while (_running)
{ {
FD_ZERO(&readfds); FD_ZERO(&readfds);
FD_SET(clientSocket, &readfds); FD_SET(clientGuard, &readfds);
FD_SET(remoteSocket, &readfds); FD_SET(remoteGuard, &readfds);
struct timeval tv; struct timeval tv = {0, 50000};
tv.tv_sec = 0; if (select(0, &readfds, NULL, NULL, &tv) < 0) break;
tv.tv_usec = 100000;
int ret = select(0, &readfds, NULL, NULL, &tv); if (tunnelMode)
if (ret < 0) break;
if (FD_ISSET(clientSocket, &readfds) || SSL_pending(clientSSL) > 0)
{ {
int bytes = SSL_read(clientSSL, buffer, sizeof(buffer)); if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
if (bytes <= 0) break;
std::string data(buffer, bytes);
size_t reqStart = 0;
while (reqStart < data.size())
{ {
size_t nextReq = std::string::npos; int n = SSL_read(clientSSL, buffer, sizeof(buffer));
const char* methods[] = {"GET ", "POST ", "PUT ", "DELETE ", "PATCH ", "OPTIONS ", "HEAD "}; if (n <= 0) break;
for (const char* m : methods) SSL_write(remoteSSL, buffer, n);
}
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
{ {
size_t found = data.find(m, reqStart + 1); int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
if (found != std::string::npos && (nextReq == std::string::npos || found < nextReq)) if (n <= 0) break;
nextReq = found; SSL_write(clientSSL, buffer, n);
}
continue;
} }
std::string singleReq = (nextReq == std::string::npos) ? data.substr(reqStart) if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
: data.substr(reqStart, nextReq - reqStart);
std::string url;
size_t pathSpace1 = singleReq.find(' ');
size_t pathSpace2 = singleReq.find(' ', pathSpace1 + 1);
if (pathSpace1 != std::string::npos && pathSpace2 != std::string::npos)
{ {
url = "https://" + host + singleReq.substr(pathSpace1 + 1, pathSpace2 - pathSpace1 - 1); int n = SSL_read(clientSSL, buffer, sizeof(buffer));
if (n <= 0) break;
clientStream.buffer.append(buffer, n);
while (!clientStream.buffer.empty())
{
if (!clientStream.isReceivingBody)
{
if (!clientStream.parseHeaders()) break;
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
std::string url = "https://" + host;
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
if (s1 != std::string::npos && s2 != std::string::npos)
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
pendingUrls.push_back(url); pendingUrls.push_back(url);
size_t aePos = singleReq.find("Accept-Encoding:"); removeHeader(headers, "Accept-Encoding");
if (aePos == std::string::npos) aePos = singleReq.find("accept-encoding:"); removeHeader(headers, "Expect");
if (aePos != std::string::npos) headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
OnClientRequest.run(url, headers);
SSL_write(remoteSSL, headers.data(), (int)headers.size());
clientStream.buffer.erase(0, clientStream.headersEnd + 4);
}
if (clientStream.isReceivingBody)
{ {
size_t aeEndPos = singleReq.find("\r\n", aePos); if (clientStream.isChunked)
if (aeEndPos != std::string::npos) {
singleReq.replace(aePos, aeEndPos - aePos, "Accept-Encoding: identity"); size_t idx = 0;
while (idx < clientStream.buffer.size())
{
size_t le = clientStream.buffer.find("\r\n", idx);
if (le == std::string::npos) break;
int chunkSz = safe_stoi(clientStream.buffer.substr(idx, le - idx), 0, 16);
size_t totalChunkSz = (le - idx) + 2 + chunkSz + 2;
if (idx + totalChunkSz > clientStream.buffer.size()) break;
SSL_write(remoteSSL, clientStream.buffer.data() + idx, (int)totalChunkSz);
idx += totalChunkSz;
if (chunkSz == 0)
{
clientStream.reset();
break;
}
}
if (idx > 0) clientStream.buffer.erase(0, idx);
if (!clientStream.isReceivingBody) continue;
break;
}
else if (clientStream.contentLength >= 0)
{
size_t ts = (std::min)((size_t)clientStream.contentLength, clientStream.buffer.size());
if (ts > 0)
{
SSL_write(remoteSSL, clientStream.buffer.data(), (int)ts);
clientStream.buffer.erase(0, ts);
clientStream.contentLength -= (int)ts;
}
if (clientStream.contentLength <= 0)
clientStream.reset();
else
break;
}
else
{
clientStream.reset();
}
}
} }
} }
OnClientRequest.run(url.empty() ? ("https://" + host) : url, singleReq); if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
{
if (nextReq == std::string::npos) break; int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
reqStart = nextReq; bool connectionClosed = (n <= 0);
if (!connectionClosed)
{
serverStream.buffer.append(buffer, n);
} }
int sent = SSL_write(remoteSSL, data.data(), static_cast<int>(data.size())); while (!serverStream.buffer.empty() || connectionClosed)
if (sent <= 0) break; {
if (!serverStream.isReceivingBody)
{
if (connectionClosed) break;
serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n");
if (serverStream.headersEnd == std::string::npos) break;
std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
size_t s1 = headers.find(' ');
int sCode = (s1 != std::string::npos) ? safe_stoi(headers.substr(s1 + 1, 3)) : 0;
if (sCode == 101)
{
SSL_write(clientSSL, serverStream.buffer.data(), (int)serverStream.buffer.size());
serverStream.buffer.clear();
clientStream.buffer.clear();
tunnelMode = true;
break;
} }
if (FD_ISSET(remoteSocket, &readfds) || SSL_pending(remoteSSL) > 0) if (sCode >= 100 && sCode < 200)
{ {
int bytes = SSL_read(remoteSSL, buffer, sizeof(buffer) - 1); SSL_write(clientSSL, headers.data(), (int)headers.size());
if (bytes <= 0) break; serverStream.buffer.erase(0, serverStream.headersEnd + 4);
serverStream.isReceivingBody = false;
continue;
}
serverBuffer.append(buffer, bytes); serverStream.isReceivingBody = true;
std::string h_lower = headers;
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
serverStream.isChunked = (h_lower.find("transfer-encoding: chunked") != std::string::npos);
while (!serverBuffer.empty()) size_t clPos = h_lower.find("content-length:");
{
if (!isReceivingBody)
{
headersEnd = serverBuffer.find("\r\n\r\n");
if (headersEnd != std::string::npos)
{
isReceivingBody = true;
std::string headers = serverBuffer.substr(0, headersEnd + 4);
size_t clPos = headers.find("Content-Length: ");
if (clPos == std::string::npos) clPos = headers.find("content-length: ");
if (clPos != std::string::npos) if (clPos != std::string::npos)
{ {
size_t clEnd = headers.find("\r\n", clPos); size_t vStart = clPos + 15;
expectedLength = std::stoi(headers.substr(clPos + 16, clEnd - clPos - 16)); while (vStart < h_lower.size() && (h_lower[vStart] == ' ' || h_lower[vStart] == '\t'))
vStart++;
serverStream.contentLength =
safe_stoi(h_lower.substr(vStart, h_lower.find("\r\n", vStart) - vStart), -1);
} }
else else if (sCode == 204 || sCode == 304 || sCode == 205)
expectedLength = -1;
isChunked = (headers.find("chunked") != std::string::npos);
}
else
{ {
break; // need more data serverStream.contentLength = 0;
}
else if (!serverStream.isChunked)
{
serverStream.contentLength = -1;
} }
} }
if (isReceivingBody) if (serverStream.isReceivingBody)
{ {
bool complete = false; bool complete = false;
std::string fullBody; std::string body;
size_t bodyStart = headersEnd + 4; size_t bStart = serverStream.headersEnd + 4;
size_t totalProcessed = bodyStart; size_t processed = bStart;
if (isChunked) if (serverStream.isChunked)
{ {
size_t idx = bodyStart; size_t idx = bStart;
bool parseOk = true; bool ok = true;
while (idx < serverBuffer.size()) while (idx < serverStream.buffer.size())
{ {
size_t lineEnd = serverBuffer.find("\r\n", idx); size_t le = serverStream.buffer.find("\r\n", idx);
if (lineEnd == std::string::npos) if (le == std::string::npos)
{ {
parseOk = false; ok = false;
break; break;
} }
std::string hexStr = serverBuffer.substr(idx, lineEnd - idx); int cs = safe_stoi(serverStream.buffer.substr(idx, le - idx), 0, 16);
int chunkSize = 0; idx = le + 2;
try if (cs == 0)
{ {
chunkSize = std::stoi(hexStr, nullptr, 16); idx += 2;
}
catch (...)
{
parseOk = false;
break;
}
idx = lineEnd + 2;
if (chunkSize == 0)
{
idx += 2; // skip terminal \r\n
complete = true; complete = true;
totalProcessed = idx; processed = idx;
break; break;
} }
if (idx + (size_t)chunkSize + 2 > serverBuffer.size()) if (idx + cs + 2 > serverStream.buffer.size())
{ {
parseOk = false; ok = false;
break; break;
} }
fullBody.append(serverBuffer, idx, chunkSize); body.append(serverStream.buffer, idx, cs);
idx += chunkSize + 2; idx += cs + 2;
} }
if (!parseOk) complete = false; if (!ok)
}
else if (expectedLength >= 0)
{ {
if (serverBuffer.size() >= bodyStart + expectedLength) if (connectionClosed)
{ {
complete = true; complete = true;
totalProcessed = bodyStart + expectedLength; processed = serverStream.buffer.size();
fullBody = serverBuffer.substr(bodyStart, expectedLength); }
else
{
complete = false;
}
}
}
else if (serverStream.contentLength >= 0)
{
if (serverStream.buffer.size() >= bStart + serverStream.contentLength)
{
complete = true;
processed = bStart + serverStream.contentLength;
body = serverStream.buffer.substr(bStart, serverStream.contentLength);
}
else if (connectionClosed)
{
complete = true;
processed = serverStream.buffer.size();
body = serverStream.buffer.substr(bStart);
} }
} }
else else
{ {
std::string peekBuffer = serverBuffer.substr(0, bodyStart); if (connectionClosed)
bool isCloseConn = peekBuffer.find("Connection: close") != std::string::npos ||
peekBuffer.find("connection: close") != std::string::npos;
bool isNoBodyStatus = peekBuffer.find("HTTP/1.1 204") != std::string::npos ||
peekBuffer.find("HTTP/1.1 304") != std::string::npos ||
peekBuffer.find("HTTP/1.0 204") != std::string::npos ||
peekBuffer.find("HTTP/1.0 304") != std::string::npos;
if (isCloseConn)
break;
else
{ {
complete = true; complete = true;
fullBody = ""; processed = serverStream.buffer.size();
totalProcessed = bodyStart; body = serverStream.buffer.substr(bStart);
}
else
{
complete = false;
} }
} }
if (complete) if (complete)
{ {
std::string headers = serverBuffer.substr(0, bodyStart); std::string url = pendingUrls.empty() ? ("https://" + host) : pendingUrls.front();
std::string responseData = fullBody; if (!pendingUrls.empty()) pendingUrls.pop_front();
std::string currentUrl = "https://" + host; std::string respHeaders = serverStream.buffer.substr(0, bStart);
if (!pendingUrls.empty()) size_t firstSpace = respHeaders.find(' ');
int sc =
(firstSpace != std::string::npos) ? safe_stoi(respHeaders.substr(firstSpace + 1, 3)) : 0;
OnServerResponse.run(url, body);
removeHeader(respHeaders, "Transfer-Encoding");
removeHeader(respHeaders, "Content-Length");
if (sc != 204 && sc != 304 && sc != 205)
{ {
currentUrl = pendingUrls.front(); respHeaders.insert(respHeaders.size() - 2,
pendingUrls.pop_front(); "Content-Length: " + std::to_string(body.size()) + "\r\n");
} }
OnServerResponse.run(currentUrl, responseData); std::string packet = respHeaders + body;
SSL_write(clientSSL, packet.data(), (int)packet.size());
auto removeHeader = [&](std::string& h, const std::string& key) { serverStream.buffer.erase(0, processed);
size_t pos = 0; serverStream.reset();
while (true) clientStream.reset();
{
pos = h.find(key, pos);
if (pos == std::string::npos)
{
std::string lowerKey = key;
for (char& c : lowerKey)
c = (char)tolower(c);
pos = h.find(lowerKey, 0);
if (pos == std::string::npos) break;
}
if (pos == 0 || h[pos - 1] == '\n')
{
size_t end = h.find("\r\n", pos);
if (end != std::string::npos)
{
h.erase(pos, end - pos + 2);
continue;
}
}
pos++;
}
};
removeHeader(headers, "Transfer-Encoding");
removeHeader(headers, "Content-Length");
headers.insert(headers.size() - 2,
"Content-Length: " + std::to_string(responseData.size()) + "\r\n");
std::string finalPacket = headers + responseData;
int sent = SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
serverBuffer.erase(0, totalProcessed);
isReceivingBody = false;
expectedLength = -1;
headersEnd = 0;
isChunked = false;
if (sent <= 0) break;
}
else
break; // wait for more streaming packets
}
}
}
}
if (isReceivingBody && expectedLength < 0 && !isChunked && serverBuffer.size() > headersEnd + 4)
{
std::string headers = serverBuffer.substr(0, headersEnd + 4);
std::string responseData = serverBuffer.substr(headersEnd + 4);
std::string finalUrl = "https://" + host;
if (!pendingUrls.empty())
{
finalUrl = pendingUrls.front();
pendingUrls.pop_front();
}
OnServerResponse.run(finalUrl, responseData);
auto removeHeader = [&](std::string& h, const std::string& key) {
size_t pos = 0;
while (true)
{
pos = h.find(key, pos);
if (pos == std::string::npos)
{
std::string lowerKey = key;
for (char& c : lowerKey)
c = (char)tolower(c);
pos = h.find(lowerKey, 0);
if (pos == std::string::npos) break;
}
if (pos == 0 || h[pos - 1] == '\n')
{
size_t end = h.find("\r\n", pos);
if (end != std::string::npos)
{
h.erase(pos, end - pos + 2);
continue;
}
}
pos++;
}
};
removeHeader(headers, "Transfer-Encoding");
removeHeader(headers, "Content-Length");
headers.insert(headers.size() - 2, "Content-Length: " + std::to_string(responseData.size()) + "\r\n");
std::string finalPacket = headers + responseData;
SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
}
SSL_shutdown(clientSSL);
SSL_free(clientSSL);
SSL_shutdown(remoteSSL);
SSL_free(remoteSSL);
} }
else else
{ {
send(remoteSocket, buffer, bytesReceived, 0); break;
fd_set readfds;
while (_running)
{
FD_ZERO(&readfds);
FD_SET(clientSocket, &readfds);
FD_SET(remoteSocket, &readfds);
struct timeval tv;
tv.tv_sec = 1;
tv.tv_usec = 0;
int ret = select(0, &readfds, NULL, NULL, &tv);
if (ret < 0) break;
if (ret == 0) continue;
if (FD_ISSET(clientSocket, &readfds))
{
int bytes = recv(clientSocket, buffer, sizeof(buffer), 0);
if (bytes <= 0) break;
int sent = send(remoteSocket, buffer, bytes, 0);
if (sent == SOCKET_ERROR) break;
}
if (FD_ISSET(remoteSocket, &readfds))
{
int bytes = recv(remoteSocket, buffer, sizeof(buffer), 0);
if (bytes <= 0) break;
int sent = send(clientSocket, buffer, bytes, 0);
if (sent == SOCKET_ERROR) break;
} }
} }
} }
if (connectionClosed) break;
closesocket(remoteSocket); }
closesocket(clientSocket); }
} }
+3 -2
View File
@@ -12,7 +12,8 @@
typedef unsigned __int64 SOCKET; typedef unsigned __int64 SOCKET;
class Proxy { class Proxy
{
public: public:
Proxy(); Proxy();
~Proxy(); ~Proxy();
@@ -32,7 +33,7 @@ private:
SOCKET _listenSocket = 0; SOCKET _listenSocket = 0;
std::thread _workerThread; std::thread _workerThread;
std::atomic<bool> _running; std::atomic<bool> _running = false;
CertManager _certManager; CertManager _certManager;
SSL_CTX* _clientCtx = nullptr; SSL_CTX* _clientCtx = nullptr;