fix: refactor everything, fix everything
This commit is contained in:
+1
-1
@@ -86,7 +86,7 @@ add_custom_command(TARGET dbd-unlocker POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
${JSON_RES_FILES}
|
||||
"$<TARGET_FILE_DIR:dbd-unlocker>/"
|
||||
COMMENT "Copying JSON resources to executable directory"
|
||||
COMMENT "copying json sources"
|
||||
)
|
||||
|
||||
# ------------------------------
|
||||
|
||||
+404
-437
@@ -5,16 +5,156 @@
|
||||
|
||||
#include <nerutils/log.h>
|
||||
#include <deque>
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <string>
|
||||
|
||||
namespace
|
||||
{
|
||||
class ScopedSSL
|
||||
{
|
||||
SSL* s;
|
||||
|
||||
public:
|
||||
ScopedSSL(SSL* val = nullptr) : s(val) {}
|
||||
~ScopedSSL()
|
||||
{
|
||||
if (s) SSL_free(s);
|
||||
}
|
||||
void reset(SSL* val)
|
||||
{
|
||||
if (s) SSL_free(s);
|
||||
s = val;
|
||||
}
|
||||
operator SSL*() const { return s; }
|
||||
};
|
||||
|
||||
class ScopedSocket
|
||||
{
|
||||
SOCKET s;
|
||||
|
||||
public:
|
||||
ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
|
||||
~ScopedSocket()
|
||||
{
|
||||
if (s != INVALID_SOCKET) closesocket(s);
|
||||
}
|
||||
void reset(SOCKET val)
|
||||
{
|
||||
if (s != INVALID_SOCKET) closesocket(s);
|
||||
s = val;
|
||||
}
|
||||
operator SOCKET() const { return s; }
|
||||
SOCKET get() const { return s; }
|
||||
};
|
||||
|
||||
int safe_stoi(const std::string& s, int default_val = 0, int base = 10)
|
||||
{
|
||||
if (s.empty()) return default_val;
|
||||
try
|
||||
{
|
||||
return std::stoi(s, nullptr, base);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return default_val;
|
||||
}
|
||||
}
|
||||
|
||||
void removeHeader(std::string& h, const std::string& key)
|
||||
{
|
||||
std::string h_lower = h;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
std::string key_lower = key;
|
||||
std::transform(key_lower.begin(), key_lower.end(), key_lower.begin(), ::tolower);
|
||||
|
||||
size_t pos = 0;
|
||||
while ((pos = h_lower.find(key_lower, pos)) != std::string::npos)
|
||||
{
|
||||
if (pos == 0 || h_lower[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h_lower.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
h_lower.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
|
||||
std::string getHeaderValue(const std::string& headers, const std::string& key)
|
||||
{
|
||||
std::string h_lower = headers;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
std::string search = "\n" + key;
|
||||
std::transform(search.begin(), search.end(), search.begin(), ::tolower);
|
||||
search += ":";
|
||||
|
||||
size_t pos = h_lower.find(search);
|
||||
if (pos == std::string::npos) return ""; // Not found
|
||||
|
||||
size_t vStart = pos + search.length();
|
||||
while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t'))
|
||||
vStart++;
|
||||
size_t vEnd = headers.find("\r\n", vStart);
|
||||
if (vEnd == std::string::npos) return "";
|
||||
return headers.substr(vStart, vEnd - vStart);
|
||||
}
|
||||
|
||||
struct HttpStream
|
||||
{
|
||||
std::string buffer;
|
||||
bool isReceivingBody = false;
|
||||
bool isChunked = false;
|
||||
int contentLength = -1;
|
||||
size_t headersEnd = 0;
|
||||
|
||||
void reset()
|
||||
{
|
||||
isReceivingBody = false;
|
||||
isChunked = false;
|
||||
contentLength = -1;
|
||||
headersEnd = 0;
|
||||
}
|
||||
|
||||
bool parseHeaders()
|
||||
{
|
||||
if (isReceivingBody) return true;
|
||||
headersEnd = buffer.find("\r\n\r\n");
|
||||
if (headersEnd == std::string::npos) return false;
|
||||
|
||||
std::string headers = buffer.substr(0, headersEnd + 4);
|
||||
|
||||
std::string te = getHeaderValue(headers, "Transfer-Encoding");
|
||||
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
|
||||
isChunked = (te.find("chunked") != std::string::npos);
|
||||
|
||||
if (!isChunked)
|
||||
{
|
||||
std::string cl = getHeaderValue(headers, "Content-Length");
|
||||
contentLength = safe_stoi(cl, -1);
|
||||
}
|
||||
else
|
||||
{
|
||||
contentLength = -1;
|
||||
}
|
||||
|
||||
isReceivingBody = true;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool Proxy::initSSL()
|
||||
{
|
||||
_clientCtx = SSL_CTX_new(TLS_client_method());
|
||||
if (!_clientCtx)
|
||||
{
|
||||
Log::error("Failed to create client SSL context");
|
||||
return false;
|
||||
}
|
||||
if (!_clientCtx) return false;
|
||||
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
||||
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
|
||||
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -28,71 +168,43 @@ void Proxy::cleanupSSL()
|
||||
}
|
||||
|
||||
Proxy::Proxy() {}
|
||||
|
||||
Proxy::~Proxy() {}
|
||||
Proxy::~Proxy()
|
||||
{
|
||||
Shutdown();
|
||||
}
|
||||
|
||||
bool Proxy::Init()
|
||||
{
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
|
||||
{
|
||||
Log::error("WSAStartup failed");
|
||||
return false;
|
||||
}
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;
|
||||
if (!_certManager.Init()) return false;
|
||||
if (!initSSL()) return false;
|
||||
|
||||
if (!_certManager.Init())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!initSSL())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
||||
if (_listenSocket == INVALID_SOCKET)
|
||||
{
|
||||
Log::error("Error creating listen socket: {0:x}", WSAGetLastError());
|
||||
Shutdown();
|
||||
return false;
|
||||
}
|
||||
|
||||
sockaddr_in serverAddr;
|
||||
sockaddr_in serverAddr = {};
|
||||
serverAddr.sin_family = AF_INET;
|
||||
serverAddr.sin_port = htons(PROXY_PORT);
|
||||
inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
|
||||
|
||||
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR)
|
||||
{
|
||||
Log::error("Error binding listen socket: {0:x}", WSAGetLastError());
|
||||
Shutdown();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false;
|
||||
listen(_listenSocket, SOMAXCONN);
|
||||
|
||||
Log::verbose("Listening on 127.0.0.1:{}", PROXY_PORT);
|
||||
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
|
||||
|
||||
_running = true;
|
||||
_workerThread = std::thread(&Proxy::loop, this);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Proxy::Shutdown()
|
||||
{
|
||||
_running = false;
|
||||
|
||||
if (_listenSocket != INVALID_SOCKET && _listenSocket != 0)
|
||||
if (_listenSocket != INVALID_SOCKET)
|
||||
{
|
||||
closesocket(_listenSocket);
|
||||
_listenSocket = 0;
|
||||
_listenSocket = INVALID_SOCKET;
|
||||
}
|
||||
|
||||
if (_workerThread.joinable()) _workerThread.join();
|
||||
|
||||
WSACleanup();
|
||||
|
||||
cleanupSSL();
|
||||
}
|
||||
|
||||
@@ -100,500 +212,355 @@ void Proxy::loop()
|
||||
{
|
||||
while (_running)
|
||||
{
|
||||
SOCKET clientSocket = accept(_listenSocket, NULL, NULL);
|
||||
|
||||
SOCKET hClient = accept(_listenSocket, NULL, NULL);
|
||||
if (!_running)
|
||||
{
|
||||
if (clientSocket != INVALID_SOCKET) closesocket(clientSocket);
|
||||
if (hClient != INVALID_SOCKET) closesocket(hClient);
|
||||
break;
|
||||
}
|
||||
|
||||
if (clientSocket == INVALID_SOCKET) continue;
|
||||
|
||||
std::thread([this, clientSocket]() { this->handleClient(clientSocket); }).detach();
|
||||
if (hClient == INVALID_SOCKET) continue;
|
||||
std::thread([this, hClient]() { this->handleClient(hClient); }).detach();
|
||||
}
|
||||
}
|
||||
|
||||
void Proxy::handleClient(SOCKET clientSocket)
|
||||
void Proxy::handleClient(SOCKET hClientSocket)
|
||||
{
|
||||
char buffer[8192];
|
||||
int bytesReceived = recv(clientSocket, buffer, sizeof(buffer) - 1, 0);
|
||||
ScopedSocket clientGuard(hClientSocket);
|
||||
char buffer[32768];
|
||||
|
||||
if (bytesReceived <= 0)
|
||||
{
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0);
|
||||
if (bytesRead <= 0) return;
|
||||
buffer[bytesRead] = '\0';
|
||||
|
||||
buffer[bytesReceived] = '\0';
|
||||
std::string request(buffer, bytesReceived);
|
||||
std::string initialReq(buffer);
|
||||
if (initialReq.find("CONNECT ") != 0) return;
|
||||
|
||||
std::string method, url;
|
||||
size_t space1 = request.find(' ');
|
||||
if (space1 != std::string::npos)
|
||||
{
|
||||
method = request.substr(0, space1);
|
||||
size_t space2 = request.find(' ', space1 + 1);
|
||||
if (space2 != std::string::npos)
|
||||
{
|
||||
url = request.substr(space1 + 1, space2 - space1 - 1);
|
||||
}
|
||||
}
|
||||
size_t hostStart = 8;
|
||||
size_t hostEnd = initialReq.find(' ', hostStart);
|
||||
if (hostEnd == std::string::npos) return;
|
||||
|
||||
if (method.empty() || url.empty())
|
||||
{
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string host;
|
||||
std::string port = "80";
|
||||
bool isConnect = (method == "CONNECT");
|
||||
|
||||
if (isConnect)
|
||||
{
|
||||
size_t colon = url.find(':');
|
||||
std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
|
||||
std::string host = fullHost;
|
||||
int port = 443;
|
||||
size_t colon = fullHost.find(':');
|
||||
if (colon != std::string::npos)
|
||||
{
|
||||
host = url.substr(0, colon);
|
||||
port = url.substr(colon + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
host = url;
|
||||
port = "443";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t hostPos = request.find("Host: ");
|
||||
if (hostPos != std::string::npos)
|
||||
{
|
||||
size_t endPos = request.find("\r\n", hostPos);
|
||||
if (endPos != std::string::npos)
|
||||
{
|
||||
std::string hostHeader = request.substr(hostPos + 6, endPos - (hostPos + 6));
|
||||
size_t colon = hostHeader.find(':');
|
||||
if (colon != std::string::npos)
|
||||
{
|
||||
host = hostHeader.substr(0, colon);
|
||||
port = hostHeader.substr(colon + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
host = hostHeader;
|
||||
}
|
||||
}
|
||||
}
|
||||
host = fullHost.substr(0, colon);
|
||||
port = safe_stoi(fullHost.substr(colon + 1), 443);
|
||||
}
|
||||
|
||||
if (host.empty())
|
||||
{
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
struct addrinfo hints = {}, *res;
|
||||
struct addrinfo hints = {}, *res = nullptr;
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
|
||||
|
||||
if (getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0)
|
||||
ScopedSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0)
|
||||
{
|
||||
Log::error("Could not resolve host: {}", host);
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
SOCKET remoteSocket = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
|
||||
if (connect(remoteSocket, res->ai_addr, (int)res->ai_addrlen) == SOCKET_ERROR)
|
||||
{
|
||||
Log::error("Connection to {}:{} failed", host, port);
|
||||
freeaddrinfo(res);
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
freeaddrinfo(res);
|
||||
|
||||
if (isConnect)
|
||||
{
|
||||
const char* reply = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
send(clientSocket, reply, static_cast<int>(strlen(reply)), 0);
|
||||
send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0);
|
||||
|
||||
SSL_CTX* serverCtx = _certManager.CreateHostContext(host);
|
||||
if (!serverCtx)
|
||||
{
|
||||
Log::error("Failed to generate dynamic cert for {}", host);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
SSL_CTX* hostCtx = _certManager.CreateHostContext(host);
|
||||
if (!hostCtx) return;
|
||||
|
||||
SSL* clientSSL = SSL_new(serverCtx);
|
||||
SSL_set_fd(clientSSL, static_cast<int>(clientSocket));
|
||||
ScopedSSL clientSSL(SSL_new(hostCtx));
|
||||
SSL_set_fd(clientSSL, (int)clientGuard.get());
|
||||
if (SSL_accept(clientSSL) <= 0) return;
|
||||
|
||||
if (SSL_accept(clientSSL) <= 0)
|
||||
{
|
||||
Log::error("SSL_accept failed on client");
|
||||
SSL_free(clientSSL);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
SSL* remoteSSL = SSL_new(_clientCtx);
|
||||
SSL_set_fd(remoteSSL, static_cast<int>(remoteSocket));
|
||||
ScopedSSL remoteSSL(SSL_new(_clientCtx));
|
||||
SSL_set_fd(remoteSSL, (int)remoteGuard.get());
|
||||
SSL_set_tlsext_host_name(remoteSSL, host.c_str());
|
||||
if (SSL_connect(remoteSSL) <= 0) return;
|
||||
|
||||
if (SSL_connect(remoteSSL) <= 0)
|
||||
{
|
||||
Log::error("SSL_connect failed on remote server");
|
||||
SSL_free(remoteSSL);
|
||||
SSL_free(clientSSL);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
HttpStream clientStream, serverStream;
|
||||
std::deque<std::string> pendingUrls;
|
||||
std::string serverBuffer;
|
||||
bool isReceivingBody = false;
|
||||
int expectedLength = -1;
|
||||
bool isChunked = false;
|
||||
size_t headersEnd = 0;
|
||||
|
||||
bool tunnelMode = false;
|
||||
fd_set readfds;
|
||||
|
||||
while (_running)
|
||||
{
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(clientSocket, &readfds);
|
||||
FD_SET(remoteSocket, &readfds);
|
||||
FD_SET(clientGuard, &readfds);
|
||||
FD_SET(remoteGuard, &readfds);
|
||||
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 0;
|
||||
tv.tv_usec = 100000;
|
||||
struct timeval tv = {0, 50000};
|
||||
if (select(0, &readfds, NULL, NULL, &tv) < 0) break;
|
||||
|
||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
||||
if (ret < 0) break;
|
||||
|
||||
if (FD_ISSET(clientSocket, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
if (tunnelMode)
|
||||
{
|
||||
int bytes = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
if (bytes <= 0) break;
|
||||
|
||||
std::string data(buffer, bytes);
|
||||
|
||||
size_t reqStart = 0;
|
||||
while (reqStart < data.size())
|
||||
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
{
|
||||
size_t nextReq = std::string::npos;
|
||||
const char* methods[] = {"GET ", "POST ", "PUT ", "DELETE ", "PATCH ", "OPTIONS ", "HEAD "};
|
||||
for (const char* m : methods)
|
||||
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
SSL_write(remoteSSL, buffer, n);
|
||||
}
|
||||
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
{
|
||||
size_t found = data.find(m, reqStart + 1);
|
||||
if (found != std::string::npos && (nextReq == std::string::npos || found < nextReq))
|
||||
nextReq = found;
|
||||
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
SSL_write(clientSSL, buffer, n);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string singleReq = (nextReq == std::string::npos) ? data.substr(reqStart)
|
||||
: data.substr(reqStart, nextReq - reqStart);
|
||||
|
||||
std::string url;
|
||||
size_t pathSpace1 = singleReq.find(' ');
|
||||
size_t pathSpace2 = singleReq.find(' ', pathSpace1 + 1);
|
||||
if (pathSpace1 != std::string::npos && pathSpace2 != std::string::npos)
|
||||
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
{
|
||||
url = "https://" + host + singleReq.substr(pathSpace1 + 1, pathSpace2 - pathSpace1 - 1);
|
||||
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
clientStream.buffer.append(buffer, n);
|
||||
|
||||
while (!clientStream.buffer.empty())
|
||||
{
|
||||
if (!clientStream.isReceivingBody)
|
||||
{
|
||||
if (!clientStream.parseHeaders()) break;
|
||||
|
||||
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
||||
std::string url = "https://" + host;
|
||||
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
|
||||
if (s1 != std::string::npos && s2 != std::string::npos)
|
||||
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
|
||||
|
||||
pendingUrls.push_back(url);
|
||||
|
||||
size_t aePos = singleReq.find("Accept-Encoding:");
|
||||
if (aePos == std::string::npos) aePos = singleReq.find("accept-encoding:");
|
||||
if (aePos != std::string::npos)
|
||||
removeHeader(headers, "Accept-Encoding");
|
||||
removeHeader(headers, "Expect");
|
||||
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
||||
|
||||
OnClientRequest.run(url, headers);
|
||||
SSL_write(remoteSSL, headers.data(), (int)headers.size());
|
||||
|
||||
clientStream.buffer.erase(0, clientStream.headersEnd + 4);
|
||||
}
|
||||
|
||||
if (clientStream.isReceivingBody)
|
||||
{
|
||||
size_t aeEndPos = singleReq.find("\r\n", aePos);
|
||||
if (aeEndPos != std::string::npos)
|
||||
singleReq.replace(aePos, aeEndPos - aePos, "Accept-Encoding: identity");
|
||||
if (clientStream.isChunked)
|
||||
{
|
||||
size_t idx = 0;
|
||||
while (idx < clientStream.buffer.size())
|
||||
{
|
||||
size_t le = clientStream.buffer.find("\r\n", idx);
|
||||
if (le == std::string::npos) break;
|
||||
|
||||
int chunkSz = safe_stoi(clientStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
size_t totalChunkSz = (le - idx) + 2 + chunkSz + 2;
|
||||
if (idx + totalChunkSz > clientStream.buffer.size()) break;
|
||||
|
||||
SSL_write(remoteSSL, clientStream.buffer.data() + idx, (int)totalChunkSz);
|
||||
idx += totalChunkSz;
|
||||
|
||||
if (chunkSz == 0)
|
||||
{
|
||||
clientStream.reset();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx > 0) clientStream.buffer.erase(0, idx);
|
||||
if (!clientStream.isReceivingBody) continue;
|
||||
break;
|
||||
}
|
||||
else if (clientStream.contentLength >= 0)
|
||||
{
|
||||
size_t ts = (std::min)((size_t)clientStream.contentLength, clientStream.buffer.size());
|
||||
if (ts > 0)
|
||||
{
|
||||
SSL_write(remoteSSL, clientStream.buffer.data(), (int)ts);
|
||||
clientStream.buffer.erase(0, ts);
|
||||
clientStream.contentLength -= (int)ts;
|
||||
}
|
||||
if (clientStream.contentLength <= 0)
|
||||
clientStream.reset();
|
||||
else
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
clientStream.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OnClientRequest.run(url.empty() ? ("https://" + host) : url, singleReq);
|
||||
|
||||
if (nextReq == std::string::npos) break;
|
||||
reqStart = nextReq;
|
||||
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
{
|
||||
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||
bool connectionClosed = (n <= 0);
|
||||
if (!connectionClosed)
|
||||
{
|
||||
serverStream.buffer.append(buffer, n);
|
||||
}
|
||||
|
||||
int sent = SSL_write(remoteSSL, data.data(), static_cast<int>(data.size()));
|
||||
if (sent <= 0) break;
|
||||
while (!serverStream.buffer.empty() || connectionClosed)
|
||||
{
|
||||
if (!serverStream.isReceivingBody)
|
||||
{
|
||||
if (connectionClosed) break;
|
||||
|
||||
serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n");
|
||||
if (serverStream.headersEnd == std::string::npos) break;
|
||||
|
||||
std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
|
||||
size_t s1 = headers.find(' ');
|
||||
int sCode = (s1 != std::string::npos) ? safe_stoi(headers.substr(s1 + 1, 3)) : 0;
|
||||
|
||||
if (sCode == 101)
|
||||
{
|
||||
SSL_write(clientSSL, serverStream.buffer.data(), (int)serverStream.buffer.size());
|
||||
serverStream.buffer.clear();
|
||||
clientStream.buffer.clear();
|
||||
tunnelMode = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(remoteSocket, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
if (sCode >= 100 && sCode < 200)
|
||||
{
|
||||
int bytes = SSL_read(remoteSSL, buffer, sizeof(buffer) - 1);
|
||||
if (bytes <= 0) break;
|
||||
SSL_write(clientSSL, headers.data(), (int)headers.size());
|
||||
serverStream.buffer.erase(0, serverStream.headersEnd + 4);
|
||||
serverStream.isReceivingBody = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
serverBuffer.append(buffer, bytes);
|
||||
serverStream.isReceivingBody = true;
|
||||
std::string h_lower = headers;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
serverStream.isChunked = (h_lower.find("transfer-encoding: chunked") != std::string::npos);
|
||||
|
||||
while (!serverBuffer.empty())
|
||||
{
|
||||
if (!isReceivingBody)
|
||||
{
|
||||
headersEnd = serverBuffer.find("\r\n\r\n");
|
||||
if (headersEnd != std::string::npos)
|
||||
{
|
||||
isReceivingBody = true;
|
||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
||||
|
||||
size_t clPos = headers.find("Content-Length: ");
|
||||
if (clPos == std::string::npos) clPos = headers.find("content-length: ");
|
||||
size_t clPos = h_lower.find("content-length:");
|
||||
if (clPos != std::string::npos)
|
||||
{
|
||||
size_t clEnd = headers.find("\r\n", clPos);
|
||||
expectedLength = std::stoi(headers.substr(clPos + 16, clEnd - clPos - 16));
|
||||
size_t vStart = clPos + 15;
|
||||
while (vStart < h_lower.size() && (h_lower[vStart] == ' ' || h_lower[vStart] == '\t'))
|
||||
vStart++;
|
||||
serverStream.contentLength =
|
||||
safe_stoi(h_lower.substr(vStart, h_lower.find("\r\n", vStart) - vStart), -1);
|
||||
}
|
||||
else
|
||||
expectedLength = -1;
|
||||
|
||||
isChunked = (headers.find("chunked") != std::string::npos);
|
||||
}
|
||||
else
|
||||
else if (sCode == 204 || sCode == 304 || sCode == 205)
|
||||
{
|
||||
break; // need more data
|
||||
serverStream.contentLength = 0;
|
||||
}
|
||||
else if (!serverStream.isChunked)
|
||||
{
|
||||
serverStream.contentLength = -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (isReceivingBody)
|
||||
if (serverStream.isReceivingBody)
|
||||
{
|
||||
bool complete = false;
|
||||
std::string fullBody;
|
||||
size_t bodyStart = headersEnd + 4;
|
||||
size_t totalProcessed = bodyStart;
|
||||
std::string body;
|
||||
size_t bStart = serverStream.headersEnd + 4;
|
||||
size_t processed = bStart;
|
||||
|
||||
if (isChunked)
|
||||
if (serverStream.isChunked)
|
||||
{
|
||||
size_t idx = bodyStart;
|
||||
bool parseOk = true;
|
||||
while (idx < serverBuffer.size())
|
||||
size_t idx = bStart;
|
||||
bool ok = true;
|
||||
while (idx < serverStream.buffer.size())
|
||||
{
|
||||
size_t lineEnd = serverBuffer.find("\r\n", idx);
|
||||
if (lineEnd == std::string::npos)
|
||||
size_t le = serverStream.buffer.find("\r\n", idx);
|
||||
if (le == std::string::npos)
|
||||
{
|
||||
parseOk = false;
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
std::string hexStr = serverBuffer.substr(idx, lineEnd - idx);
|
||||
int chunkSize = 0;
|
||||
try
|
||||
int cs = safe_stoi(serverStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
idx = le + 2;
|
||||
if (cs == 0)
|
||||
{
|
||||
chunkSize = std::stoi(hexStr, nullptr, 16);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
parseOk = false;
|
||||
break;
|
||||
}
|
||||
idx = lineEnd + 2;
|
||||
if (chunkSize == 0)
|
||||
{
|
||||
idx += 2; // skip terminal \r\n
|
||||
idx += 2;
|
||||
complete = true;
|
||||
totalProcessed = idx;
|
||||
processed = idx;
|
||||
break;
|
||||
}
|
||||
if (idx + (size_t)chunkSize + 2 > serverBuffer.size())
|
||||
if (idx + cs + 2 > serverStream.buffer.size())
|
||||
{
|
||||
parseOk = false;
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
fullBody.append(serverBuffer, idx, chunkSize);
|
||||
idx += chunkSize + 2;
|
||||
body.append(serverStream.buffer, idx, cs);
|
||||
idx += cs + 2;
|
||||
}
|
||||
if (!parseOk) complete = false;
|
||||
}
|
||||
else if (expectedLength >= 0)
|
||||
if (!ok)
|
||||
{
|
||||
if (serverBuffer.size() >= bodyStart + expectedLength)
|
||||
if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
totalProcessed = bodyStart + expectedLength;
|
||||
fullBody = serverBuffer.substr(bodyStart, expectedLength);
|
||||
processed = serverStream.buffer.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
complete = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (serverStream.contentLength >= 0)
|
||||
{
|
||||
if (serverStream.buffer.size() >= bStart + serverStream.contentLength)
|
||||
{
|
||||
complete = true;
|
||||
processed = bStart + serverStream.contentLength;
|
||||
body = serverStream.buffer.substr(bStart, serverStream.contentLength);
|
||||
}
|
||||
else if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
processed = serverStream.buffer.size();
|
||||
body = serverStream.buffer.substr(bStart);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string peekBuffer = serverBuffer.substr(0, bodyStart);
|
||||
bool isCloseConn = peekBuffer.find("Connection: close") != std::string::npos ||
|
||||
peekBuffer.find("connection: close") != std::string::npos;
|
||||
|
||||
bool isNoBodyStatus = peekBuffer.find("HTTP/1.1 204") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.1 304") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.0 204") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.0 304") != std::string::npos;
|
||||
|
||||
if (isCloseConn)
|
||||
break;
|
||||
else
|
||||
if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
fullBody = "";
|
||||
totalProcessed = bodyStart;
|
||||
processed = serverStream.buffer.size();
|
||||
body = serverStream.buffer.substr(bStart);
|
||||
}
|
||||
else
|
||||
{
|
||||
complete = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (complete)
|
||||
{
|
||||
std::string headers = serverBuffer.substr(0, bodyStart);
|
||||
std::string responseData = fullBody;
|
||||
std::string url = pendingUrls.empty() ? ("https://" + host) : pendingUrls.front();
|
||||
if (!pendingUrls.empty()) pendingUrls.pop_front();
|
||||
|
||||
std::string currentUrl = "https://" + host;
|
||||
if (!pendingUrls.empty())
|
||||
std::string respHeaders = serverStream.buffer.substr(0, bStart);
|
||||
size_t firstSpace = respHeaders.find(' ');
|
||||
int sc =
|
||||
(firstSpace != std::string::npos) ? safe_stoi(respHeaders.substr(firstSpace + 1, 3)) : 0;
|
||||
|
||||
OnServerResponse.run(url, body);
|
||||
|
||||
removeHeader(respHeaders, "Transfer-Encoding");
|
||||
removeHeader(respHeaders, "Content-Length");
|
||||
if (sc != 204 && sc != 304 && sc != 205)
|
||||
{
|
||||
currentUrl = pendingUrls.front();
|
||||
pendingUrls.pop_front();
|
||||
respHeaders.insert(respHeaders.size() - 2,
|
||||
"Content-Length: " + std::to_string(body.size()) + "\r\n");
|
||||
}
|
||||
|
||||
OnServerResponse.run(currentUrl, responseData);
|
||||
std::string packet = respHeaders + body;
|
||||
SSL_write(clientSSL, packet.data(), (int)packet.size());
|
||||
|
||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
||||
size_t pos = 0;
|
||||
while (true)
|
||||
{
|
||||
pos = h.find(key, pos);
|
||||
if (pos == std::string::npos)
|
||||
{
|
||||
std::string lowerKey = key;
|
||||
for (char& c : lowerKey)
|
||||
c = (char)tolower(c);
|
||||
pos = h.find(lowerKey, 0);
|
||||
if (pos == std::string::npos) break;
|
||||
}
|
||||
if (pos == 0 || h[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
};
|
||||
|
||||
removeHeader(headers, "Transfer-Encoding");
|
||||
removeHeader(headers, "Content-Length");
|
||||
|
||||
headers.insert(headers.size() - 2,
|
||||
"Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
||||
|
||||
std::string finalPacket = headers + responseData;
|
||||
int sent = SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
||||
|
||||
serverBuffer.erase(0, totalProcessed);
|
||||
|
||||
isReceivingBody = false;
|
||||
expectedLength = -1;
|
||||
headersEnd = 0;
|
||||
isChunked = false;
|
||||
|
||||
if (sent <= 0) break;
|
||||
}
|
||||
else
|
||||
break; // wait for more streaming packets
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isReceivingBody && expectedLength < 0 && !isChunked && serverBuffer.size() > headersEnd + 4)
|
||||
{
|
||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
||||
std::string responseData = serverBuffer.substr(headersEnd + 4);
|
||||
|
||||
std::string finalUrl = "https://" + host;
|
||||
if (!pendingUrls.empty())
|
||||
{
|
||||
finalUrl = pendingUrls.front();
|
||||
pendingUrls.pop_front();
|
||||
}
|
||||
|
||||
OnServerResponse.run(finalUrl, responseData);
|
||||
|
||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
||||
size_t pos = 0;
|
||||
while (true)
|
||||
{
|
||||
pos = h.find(key, pos);
|
||||
if (pos == std::string::npos)
|
||||
{
|
||||
std::string lowerKey = key;
|
||||
for (char& c : lowerKey)
|
||||
c = (char)tolower(c);
|
||||
pos = h.find(lowerKey, 0);
|
||||
if (pos == std::string::npos) break;
|
||||
}
|
||||
if (pos == 0 || h[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
};
|
||||
|
||||
removeHeader(headers, "Transfer-Encoding");
|
||||
removeHeader(headers, "Content-Length");
|
||||
headers.insert(headers.size() - 2, "Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
||||
|
||||
std::string finalPacket = headers + responseData;
|
||||
SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
||||
}
|
||||
|
||||
SSL_shutdown(clientSSL);
|
||||
SSL_free(clientSSL);
|
||||
|
||||
SSL_shutdown(remoteSSL);
|
||||
SSL_free(remoteSSL);
|
||||
serverStream.buffer.erase(0, processed);
|
||||
serverStream.reset();
|
||||
clientStream.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
send(remoteSocket, buffer, bytesReceived, 0);
|
||||
|
||||
fd_set readfds;
|
||||
while (_running)
|
||||
{
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(clientSocket, &readfds);
|
||||
FD_SET(remoteSocket, &readfds);
|
||||
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 1;
|
||||
tv.tv_usec = 0;
|
||||
|
||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
||||
if (ret < 0) break;
|
||||
if (ret == 0) continue;
|
||||
|
||||
if (FD_ISSET(clientSocket, &readfds))
|
||||
{
|
||||
int bytes = recv(clientSocket, buffer, sizeof(buffer), 0);
|
||||
if (bytes <= 0) break;
|
||||
int sent = send(remoteSocket, buffer, bytes, 0);
|
||||
if (sent == SOCKET_ERROR) break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(remoteSocket, &readfds))
|
||||
{
|
||||
int bytes = recv(remoteSocket, buffer, sizeof(buffer), 0);
|
||||
if (bytes <= 0) break;
|
||||
int sent = send(clientSocket, buffer, bytes, 0);
|
||||
if (sent == SOCKET_ERROR) break;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
closesocket(remoteSocket);
|
||||
closesocket(clientSocket);
|
||||
if (connectionClosed) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
typedef unsigned __int64 SOCKET;
|
||||
|
||||
class Proxy {
|
||||
class Proxy
|
||||
{
|
||||
public:
|
||||
Proxy();
|
||||
~Proxy();
|
||||
@@ -32,7 +33,7 @@ private:
|
||||
|
||||
SOCKET _listenSocket = 0;
|
||||
std::thread _workerThread;
|
||||
std::atomic<bool> _running;
|
||||
std::atomic<bool> _running = false;
|
||||
|
||||
CertManager _certManager;
|
||||
SSL_CTX* _clientCtx = nullptr;
|
||||
|
||||
Reference in New Issue
Block a user