fix: refactor everything, fix everything
This commit is contained in:
+1
-1
@@ -86,7 +86,7 @@ add_custom_command(TARGET dbd-unlocker POST_BUILD
|
|||||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||||
${JSON_RES_FILES}
|
${JSON_RES_FILES}
|
||||||
"$<TARGET_FILE_DIR:dbd-unlocker>/"
|
"$<TARGET_FILE_DIR:dbd-unlocker>/"
|
||||||
COMMENT "Copying JSON resources to executable directory"
|
COMMENT "copying json sources"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
|
|||||||
+404
-437
@@ -5,16 +5,156 @@
|
|||||||
|
|
||||||
#include <nerutils/log.h>
|
#include <nerutils/log.h>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <atomic>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
class ScopedSSL
|
||||||
|
{
|
||||||
|
SSL* s;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ScopedSSL(SSL* val = nullptr) : s(val) {}
|
||||||
|
~ScopedSSL()
|
||||||
|
{
|
||||||
|
if (s) SSL_free(s);
|
||||||
|
}
|
||||||
|
void reset(SSL* val)
|
||||||
|
{
|
||||||
|
if (s) SSL_free(s);
|
||||||
|
s = val;
|
||||||
|
}
|
||||||
|
operator SSL*() const { return s; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScopedSocket
|
||||||
|
{
|
||||||
|
SOCKET s;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
|
||||||
|
~ScopedSocket()
|
||||||
|
{
|
||||||
|
if (s != INVALID_SOCKET) closesocket(s);
|
||||||
|
}
|
||||||
|
void reset(SOCKET val)
|
||||||
|
{
|
||||||
|
if (s != INVALID_SOCKET) closesocket(s);
|
||||||
|
s = val;
|
||||||
|
}
|
||||||
|
operator SOCKET() const { return s; }
|
||||||
|
SOCKET get() const { return s; }
|
||||||
|
};
|
||||||
|
|
||||||
|
int safe_stoi(const std::string& s, int default_val = 0, int base = 10)
|
||||||
|
{
|
||||||
|
if (s.empty()) return default_val;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
return std::stoi(s, nullptr, base);
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
return default_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void removeHeader(std::string& h, const std::string& key)
|
||||||
|
{
|
||||||
|
std::string h_lower = h;
|
||||||
|
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||||
|
std::string key_lower = key;
|
||||||
|
std::transform(key_lower.begin(), key_lower.end(), key_lower.begin(), ::tolower);
|
||||||
|
|
||||||
|
size_t pos = 0;
|
||||||
|
while ((pos = h_lower.find(key_lower, pos)) != std::string::npos)
|
||||||
|
{
|
||||||
|
if (pos == 0 || h_lower[pos - 1] == '\n')
|
||||||
|
{
|
||||||
|
size_t end = h_lower.find("\r\n", pos);
|
||||||
|
if (end != std::string::npos)
|
||||||
|
{
|
||||||
|
h.erase(pos, end - pos + 2);
|
||||||
|
h_lower.erase(pos, end - pos + 2);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string getHeaderValue(const std::string& headers, const std::string& key)
|
||||||
|
{
|
||||||
|
std::string h_lower = headers;
|
||||||
|
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||||
|
std::string search = "\n" + key;
|
||||||
|
std::transform(search.begin(), search.end(), search.begin(), ::tolower);
|
||||||
|
search += ":";
|
||||||
|
|
||||||
|
size_t pos = h_lower.find(search);
|
||||||
|
if (pos == std::string::npos) return ""; // Not found
|
||||||
|
|
||||||
|
size_t vStart = pos + search.length();
|
||||||
|
while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t'))
|
||||||
|
vStart++;
|
||||||
|
size_t vEnd = headers.find("\r\n", vStart);
|
||||||
|
if (vEnd == std::string::npos) return "";
|
||||||
|
return headers.substr(vStart, vEnd - vStart);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HttpStream
|
||||||
|
{
|
||||||
|
std::string buffer;
|
||||||
|
bool isReceivingBody = false;
|
||||||
|
bool isChunked = false;
|
||||||
|
int contentLength = -1;
|
||||||
|
size_t headersEnd = 0;
|
||||||
|
|
||||||
|
void reset()
|
||||||
|
{
|
||||||
|
isReceivingBody = false;
|
||||||
|
isChunked = false;
|
||||||
|
contentLength = -1;
|
||||||
|
headersEnd = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool parseHeaders()
|
||||||
|
{
|
||||||
|
if (isReceivingBody) return true;
|
||||||
|
headersEnd = buffer.find("\r\n\r\n");
|
||||||
|
if (headersEnd == std::string::npos) return false;
|
||||||
|
|
||||||
|
std::string headers = buffer.substr(0, headersEnd + 4);
|
||||||
|
|
||||||
|
std::string te = getHeaderValue(headers, "Transfer-Encoding");
|
||||||
|
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
|
||||||
|
isChunked = (te.find("chunked") != std::string::npos);
|
||||||
|
|
||||||
|
if (!isChunked)
|
||||||
|
{
|
||||||
|
std::string cl = getHeaderValue(headers, "Content-Length");
|
||||||
|
contentLength = safe_stoi(cl, -1);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
contentLength = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
isReceivingBody = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
bool Proxy::initSSL()
|
bool Proxy::initSSL()
|
||||||
{
|
{
|
||||||
_clientCtx = SSL_CTX_new(TLS_client_method());
|
_clientCtx = SSL_CTX_new(TLS_client_method());
|
||||||
if (!_clientCtx)
|
if (!_clientCtx) return false;
|
||||||
{
|
|
||||||
Log::error("Failed to create client SSL context");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
||||||
|
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
|
||||||
|
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,71 +168,43 @@ void Proxy::cleanupSSL()
|
|||||||
}
|
}
|
||||||
|
|
||||||
Proxy::Proxy() {}
|
Proxy::Proxy() {}
|
||||||
|
Proxy::~Proxy()
|
||||||
Proxy::~Proxy() {}
|
{
|
||||||
|
Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
bool Proxy::Init()
|
bool Proxy::Init()
|
||||||
{
|
{
|
||||||
WSADATA wsaData;
|
WSADATA wsaData;
|
||||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
|
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;
|
||||||
{
|
if (!_certManager.Init()) return false;
|
||||||
Log::error("WSAStartup failed");
|
if (!initSSL()) return false;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!_certManager.Init())
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!initSSL())
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
||||||
if (_listenSocket == INVALID_SOCKET)
|
sockaddr_in serverAddr = {};
|
||||||
{
|
|
||||||
Log::error("Error creating listen socket: {0:x}", WSAGetLastError());
|
|
||||||
Shutdown();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
sockaddr_in serverAddr;
|
|
||||||
serverAddr.sin_family = AF_INET;
|
serverAddr.sin_family = AF_INET;
|
||||||
serverAddr.sin_port = htons(PROXY_PORT);
|
serverAddr.sin_port = htons(PROXY_PORT);
|
||||||
inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
|
inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
|
||||||
|
|
||||||
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR)
|
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false;
|
||||||
{
|
|
||||||
Log::error("Error binding listen socket: {0:x}", WSAGetLastError());
|
|
||||||
Shutdown();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
listen(_listenSocket, SOMAXCONN);
|
listen(_listenSocket, SOMAXCONN);
|
||||||
|
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
|
||||||
Log::verbose("Listening on 127.0.0.1:{}", PROXY_PORT);
|
|
||||||
|
|
||||||
_running = true;
|
_running = true;
|
||||||
_workerThread = std::thread(&Proxy::loop, this);
|
_workerThread = std::thread(&Proxy::loop, this);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Proxy::Shutdown()
|
void Proxy::Shutdown()
|
||||||
{
|
{
|
||||||
_running = false;
|
_running = false;
|
||||||
|
if (_listenSocket != INVALID_SOCKET)
|
||||||
if (_listenSocket != INVALID_SOCKET && _listenSocket != 0)
|
|
||||||
{
|
{
|
||||||
closesocket(_listenSocket);
|
closesocket(_listenSocket);
|
||||||
_listenSocket = 0;
|
_listenSocket = INVALID_SOCKET;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_workerThread.joinable()) _workerThread.join();
|
if (_workerThread.joinable()) _workerThread.join();
|
||||||
|
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
|
|
||||||
cleanupSSL();
|
cleanupSSL();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,500 +212,355 @@ void Proxy::loop()
|
|||||||
{
|
{
|
||||||
while (_running)
|
while (_running)
|
||||||
{
|
{
|
||||||
SOCKET clientSocket = accept(_listenSocket, NULL, NULL);
|
SOCKET hClient = accept(_listenSocket, NULL, NULL);
|
||||||
|
|
||||||
if (!_running)
|
if (!_running)
|
||||||
{
|
{
|
||||||
if (clientSocket != INVALID_SOCKET) closesocket(clientSocket);
|
if (hClient != INVALID_SOCKET) closesocket(hClient);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (hClient == INVALID_SOCKET) continue;
|
||||||
if (clientSocket == INVALID_SOCKET) continue;
|
std::thread([this, hClient]() { this->handleClient(hClient); }).detach();
|
||||||
|
|
||||||
std::thread([this, clientSocket]() { this->handleClient(clientSocket); }).detach();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Proxy::handleClient(SOCKET clientSocket)
|
void Proxy::handleClient(SOCKET hClientSocket)
|
||||||
{
|
{
|
||||||
char buffer[8192];
|
ScopedSocket clientGuard(hClientSocket);
|
||||||
int bytesReceived = recv(clientSocket, buffer, sizeof(buffer) - 1, 0);
|
char buffer[32768];
|
||||||
|
|
||||||
if (bytesReceived <= 0)
|
int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0);
|
||||||
{
|
if (bytesRead <= 0) return;
|
||||||
closesocket(clientSocket);
|
buffer[bytesRead] = '\0';
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer[bytesReceived] = '\0';
|
std::string initialReq(buffer);
|
||||||
std::string request(buffer, bytesReceived);
|
if (initialReq.find("CONNECT ") != 0) return;
|
||||||
|
|
||||||
std::string method, url;
|
size_t hostStart = 8;
|
||||||
size_t space1 = request.find(' ');
|
size_t hostEnd = initialReq.find(' ', hostStart);
|
||||||
if (space1 != std::string::npos)
|
if (hostEnd == std::string::npos) return;
|
||||||
{
|
|
||||||
method = request.substr(0, space1);
|
|
||||||
size_t space2 = request.find(' ', space1 + 1);
|
|
||||||
if (space2 != std::string::npos)
|
|
||||||
{
|
|
||||||
url = request.substr(space1 + 1, space2 - space1 - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (method.empty() || url.empty())
|
std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
|
||||||
{
|
std::string host = fullHost;
|
||||||
closesocket(clientSocket);
|
int port = 443;
|
||||||
return;
|
size_t colon = fullHost.find(':');
|
||||||
}
|
|
||||||
|
|
||||||
std::string host;
|
|
||||||
std::string port = "80";
|
|
||||||
bool isConnect = (method == "CONNECT");
|
|
||||||
|
|
||||||
if (isConnect)
|
|
||||||
{
|
|
||||||
size_t colon = url.find(':');
|
|
||||||
if (colon != std::string::npos)
|
if (colon != std::string::npos)
|
||||||
{
|
{
|
||||||
host = url.substr(0, colon);
|
host = fullHost.substr(0, colon);
|
||||||
port = url.substr(colon + 1);
|
port = safe_stoi(fullHost.substr(colon + 1), 443);
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
host = url;
|
|
||||||
port = "443";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
size_t hostPos = request.find("Host: ");
|
|
||||||
if (hostPos != std::string::npos)
|
|
||||||
{
|
|
||||||
size_t endPos = request.find("\r\n", hostPos);
|
|
||||||
if (endPos != std::string::npos)
|
|
||||||
{
|
|
||||||
std::string hostHeader = request.substr(hostPos + 6, endPos - (hostPos + 6));
|
|
||||||
size_t colon = hostHeader.find(':');
|
|
||||||
if (colon != std::string::npos)
|
|
||||||
{
|
|
||||||
host = hostHeader.substr(0, colon);
|
|
||||||
port = hostHeader.substr(colon + 1);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
host = hostHeader;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (host.empty())
|
struct addrinfo hints = {}, *res = nullptr;
|
||||||
{
|
|
||||||
closesocket(clientSocket);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct addrinfo hints = {}, *res;
|
|
||||||
hints.ai_family = AF_INET;
|
hints.ai_family = AF_INET;
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
|
||||||
|
|
||||||
if (getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0)
|
ScopedSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||||
|
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0)
|
||||||
{
|
{
|
||||||
Log::error("Could not resolve host: {}", host);
|
|
||||||
closesocket(clientSocket);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
SOCKET remoteSocket = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
|
|
||||||
if (connect(remoteSocket, res->ai_addr, (int)res->ai_addrlen) == SOCKET_ERROR)
|
|
||||||
{
|
|
||||||
Log::error("Connection to {}:{} failed", host, port);
|
|
||||||
freeaddrinfo(res);
|
freeaddrinfo(res);
|
||||||
closesocket(clientSocket);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
freeaddrinfo(res);
|
freeaddrinfo(res);
|
||||||
|
|
||||||
if (isConnect)
|
send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0);
|
||||||
{
|
|
||||||
const char* reply = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
|
||||||
send(clientSocket, reply, static_cast<int>(strlen(reply)), 0);
|
|
||||||
|
|
||||||
SSL_CTX* serverCtx = _certManager.CreateHostContext(host);
|
SSL_CTX* hostCtx = _certManager.CreateHostContext(host);
|
||||||
if (!serverCtx)
|
if (!hostCtx) return;
|
||||||
{
|
|
||||||
Log::error("Failed to generate dynamic cert for {}", host);
|
|
||||||
closesocket(clientSocket);
|
|
||||||
closesocket(remoteSocket);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
SSL* clientSSL = SSL_new(serverCtx);
|
ScopedSSL clientSSL(SSL_new(hostCtx));
|
||||||
SSL_set_fd(clientSSL, static_cast<int>(clientSocket));
|
SSL_set_fd(clientSSL, (int)clientGuard.get());
|
||||||
|
if (SSL_accept(clientSSL) <= 0) return;
|
||||||
|
|
||||||
if (SSL_accept(clientSSL) <= 0)
|
ScopedSSL remoteSSL(SSL_new(_clientCtx));
|
||||||
{
|
SSL_set_fd(remoteSSL, (int)remoteGuard.get());
|
||||||
Log::error("SSL_accept failed on client");
|
|
||||||
SSL_free(clientSSL);
|
|
||||||
closesocket(clientSocket);
|
|
||||||
closesocket(remoteSocket);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
SSL* remoteSSL = SSL_new(_clientCtx);
|
|
||||||
SSL_set_fd(remoteSSL, static_cast<int>(remoteSocket));
|
|
||||||
SSL_set_tlsext_host_name(remoteSSL, host.c_str());
|
SSL_set_tlsext_host_name(remoteSSL, host.c_str());
|
||||||
|
if (SSL_connect(remoteSSL) <= 0) return;
|
||||||
|
|
||||||
if (SSL_connect(remoteSSL) <= 0)
|
HttpStream clientStream, serverStream;
|
||||||
{
|
|
||||||
Log::error("SSL_connect failed on remote server");
|
|
||||||
SSL_free(remoteSSL);
|
|
||||||
SSL_free(clientSSL);
|
|
||||||
closesocket(clientSocket);
|
|
||||||
closesocket(remoteSocket);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::deque<std::string> pendingUrls;
|
std::deque<std::string> pendingUrls;
|
||||||
std::string serverBuffer;
|
bool tunnelMode = false;
|
||||||
bool isReceivingBody = false;
|
|
||||||
int expectedLength = -1;
|
|
||||||
bool isChunked = false;
|
|
||||||
size_t headersEnd = 0;
|
|
||||||
|
|
||||||
fd_set readfds;
|
fd_set readfds;
|
||||||
|
|
||||||
while (_running)
|
while (_running)
|
||||||
{
|
{
|
||||||
FD_ZERO(&readfds);
|
FD_ZERO(&readfds);
|
||||||
FD_SET(clientSocket, &readfds);
|
FD_SET(clientGuard, &readfds);
|
||||||
FD_SET(remoteSocket, &readfds);
|
FD_SET(remoteGuard, &readfds);
|
||||||
|
|
||||||
struct timeval tv;
|
struct timeval tv = {0, 50000};
|
||||||
tv.tv_sec = 0;
|
if (select(0, &readfds, NULL, NULL, &tv) < 0) break;
|
||||||
tv.tv_usec = 100000;
|
|
||||||
|
|
||||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
if (tunnelMode)
|
||||||
if (ret < 0) break;
|
|
||||||
|
|
||||||
if (FD_ISSET(clientSocket, &readfds) || SSL_pending(clientSSL) > 0)
|
|
||||||
{
|
{
|
||||||
int bytes = SSL_read(clientSSL, buffer, sizeof(buffer));
|
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||||
if (bytes <= 0) break;
|
|
||||||
|
|
||||||
std::string data(buffer, bytes);
|
|
||||||
|
|
||||||
size_t reqStart = 0;
|
|
||||||
while (reqStart < data.size())
|
|
||||||
{
|
{
|
||||||
size_t nextReq = std::string::npos;
|
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||||
const char* methods[] = {"GET ", "POST ", "PUT ", "DELETE ", "PATCH ", "OPTIONS ", "HEAD "};
|
if (n <= 0) break;
|
||||||
for (const char* m : methods)
|
SSL_write(remoteSSL, buffer, n);
|
||||||
|
}
|
||||||
|
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||||
{
|
{
|
||||||
size_t found = data.find(m, reqStart + 1);
|
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||||
if (found != std::string::npos && (nextReq == std::string::npos || found < nextReq))
|
if (n <= 0) break;
|
||||||
nextReq = found;
|
SSL_write(clientSSL, buffer, n);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string singleReq = (nextReq == std::string::npos) ? data.substr(reqStart)
|
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||||
: data.substr(reqStart, nextReq - reqStart);
|
|
||||||
|
|
||||||
std::string url;
|
|
||||||
size_t pathSpace1 = singleReq.find(' ');
|
|
||||||
size_t pathSpace2 = singleReq.find(' ', pathSpace1 + 1);
|
|
||||||
if (pathSpace1 != std::string::npos && pathSpace2 != std::string::npos)
|
|
||||||
{
|
{
|
||||||
url = "https://" + host + singleReq.substr(pathSpace1 + 1, pathSpace2 - pathSpace1 - 1);
|
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||||
|
if (n <= 0) break;
|
||||||
|
clientStream.buffer.append(buffer, n);
|
||||||
|
|
||||||
|
while (!clientStream.buffer.empty())
|
||||||
|
{
|
||||||
|
if (!clientStream.isReceivingBody)
|
||||||
|
{
|
||||||
|
if (!clientStream.parseHeaders()) break;
|
||||||
|
|
||||||
|
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
||||||
|
std::string url = "https://" + host;
|
||||||
|
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
|
||||||
|
if (s1 != std::string::npos && s2 != std::string::npos)
|
||||||
|
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
|
||||||
|
|
||||||
pendingUrls.push_back(url);
|
pendingUrls.push_back(url);
|
||||||
|
|
||||||
size_t aePos = singleReq.find("Accept-Encoding:");
|
removeHeader(headers, "Accept-Encoding");
|
||||||
if (aePos == std::string::npos) aePos = singleReq.find("accept-encoding:");
|
removeHeader(headers, "Expect");
|
||||||
if (aePos != std::string::npos)
|
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
||||||
|
|
||||||
|
OnClientRequest.run(url, headers);
|
||||||
|
SSL_write(remoteSSL, headers.data(), (int)headers.size());
|
||||||
|
|
||||||
|
clientStream.buffer.erase(0, clientStream.headersEnd + 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (clientStream.isReceivingBody)
|
||||||
{
|
{
|
||||||
size_t aeEndPos = singleReq.find("\r\n", aePos);
|
if (clientStream.isChunked)
|
||||||
if (aeEndPos != std::string::npos)
|
{
|
||||||
singleReq.replace(aePos, aeEndPos - aePos, "Accept-Encoding: identity");
|
size_t idx = 0;
|
||||||
|
while (idx < clientStream.buffer.size())
|
||||||
|
{
|
||||||
|
size_t le = clientStream.buffer.find("\r\n", idx);
|
||||||
|
if (le == std::string::npos) break;
|
||||||
|
|
||||||
|
int chunkSz = safe_stoi(clientStream.buffer.substr(idx, le - idx), 0, 16);
|
||||||
|
size_t totalChunkSz = (le - idx) + 2 + chunkSz + 2;
|
||||||
|
if (idx + totalChunkSz > clientStream.buffer.size()) break;
|
||||||
|
|
||||||
|
SSL_write(remoteSSL, clientStream.buffer.data() + idx, (int)totalChunkSz);
|
||||||
|
idx += totalChunkSz;
|
||||||
|
|
||||||
|
if (chunkSz == 0)
|
||||||
|
{
|
||||||
|
clientStream.reset();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (idx > 0) clientStream.buffer.erase(0, idx);
|
||||||
|
if (!clientStream.isReceivingBody) continue;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else if (clientStream.contentLength >= 0)
|
||||||
|
{
|
||||||
|
size_t ts = (std::min)((size_t)clientStream.contentLength, clientStream.buffer.size());
|
||||||
|
if (ts > 0)
|
||||||
|
{
|
||||||
|
SSL_write(remoteSSL, clientStream.buffer.data(), (int)ts);
|
||||||
|
clientStream.buffer.erase(0, ts);
|
||||||
|
clientStream.contentLength -= (int)ts;
|
||||||
|
}
|
||||||
|
if (clientStream.contentLength <= 0)
|
||||||
|
clientStream.reset();
|
||||||
|
else
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
clientStream.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OnClientRequest.run(url.empty() ? ("https://" + host) : url, singleReq);
|
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||||
|
{
|
||||||
if (nextReq == std::string::npos) break;
|
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||||
reqStart = nextReq;
|
bool connectionClosed = (n <= 0);
|
||||||
|
if (!connectionClosed)
|
||||||
|
{
|
||||||
|
serverStream.buffer.append(buffer, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
int sent = SSL_write(remoteSSL, data.data(), static_cast<int>(data.size()));
|
while (!serverStream.buffer.empty() || connectionClosed)
|
||||||
if (sent <= 0) break;
|
{
|
||||||
|
if (!serverStream.isReceivingBody)
|
||||||
|
{
|
||||||
|
if (connectionClosed) break;
|
||||||
|
|
||||||
|
serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n");
|
||||||
|
if (serverStream.headersEnd == std::string::npos) break;
|
||||||
|
|
||||||
|
std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
|
||||||
|
size_t s1 = headers.find(' ');
|
||||||
|
int sCode = (s1 != std::string::npos) ? safe_stoi(headers.substr(s1 + 1, 3)) : 0;
|
||||||
|
|
||||||
|
if (sCode == 101)
|
||||||
|
{
|
||||||
|
SSL_write(clientSSL, serverStream.buffer.data(), (int)serverStream.buffer.size());
|
||||||
|
serverStream.buffer.clear();
|
||||||
|
clientStream.buffer.clear();
|
||||||
|
tunnelMode = true;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (FD_ISSET(remoteSocket, &readfds) || SSL_pending(remoteSSL) > 0)
|
if (sCode >= 100 && sCode < 200)
|
||||||
{
|
{
|
||||||
int bytes = SSL_read(remoteSSL, buffer, sizeof(buffer) - 1);
|
SSL_write(clientSSL, headers.data(), (int)headers.size());
|
||||||
if (bytes <= 0) break;
|
serverStream.buffer.erase(0, serverStream.headersEnd + 4);
|
||||||
|
serverStream.isReceivingBody = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
serverBuffer.append(buffer, bytes);
|
serverStream.isReceivingBody = true;
|
||||||
|
std::string h_lower = headers;
|
||||||
|
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||||
|
serverStream.isChunked = (h_lower.find("transfer-encoding: chunked") != std::string::npos);
|
||||||
|
|
||||||
while (!serverBuffer.empty())
|
size_t clPos = h_lower.find("content-length:");
|
||||||
{
|
|
||||||
if (!isReceivingBody)
|
|
||||||
{
|
|
||||||
headersEnd = serverBuffer.find("\r\n\r\n");
|
|
||||||
if (headersEnd != std::string::npos)
|
|
||||||
{
|
|
||||||
isReceivingBody = true;
|
|
||||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
|
||||||
|
|
||||||
size_t clPos = headers.find("Content-Length: ");
|
|
||||||
if (clPos == std::string::npos) clPos = headers.find("content-length: ");
|
|
||||||
if (clPos != std::string::npos)
|
if (clPos != std::string::npos)
|
||||||
{
|
{
|
||||||
size_t clEnd = headers.find("\r\n", clPos);
|
size_t vStart = clPos + 15;
|
||||||
expectedLength = std::stoi(headers.substr(clPos + 16, clEnd - clPos - 16));
|
while (vStart < h_lower.size() && (h_lower[vStart] == ' ' || h_lower[vStart] == '\t'))
|
||||||
|
vStart++;
|
||||||
|
serverStream.contentLength =
|
||||||
|
safe_stoi(h_lower.substr(vStart, h_lower.find("\r\n", vStart) - vStart), -1);
|
||||||
}
|
}
|
||||||
else
|
else if (sCode == 204 || sCode == 304 || sCode == 205)
|
||||||
expectedLength = -1;
|
|
||||||
|
|
||||||
isChunked = (headers.find("chunked") != std::string::npos);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
{
|
||||||
break; // need more data
|
serverStream.contentLength = 0;
|
||||||
|
}
|
||||||
|
else if (!serverStream.isChunked)
|
||||||
|
{
|
||||||
|
serverStream.contentLength = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isReceivingBody)
|
if (serverStream.isReceivingBody)
|
||||||
{
|
{
|
||||||
bool complete = false;
|
bool complete = false;
|
||||||
std::string fullBody;
|
std::string body;
|
||||||
size_t bodyStart = headersEnd + 4;
|
size_t bStart = serverStream.headersEnd + 4;
|
||||||
size_t totalProcessed = bodyStart;
|
size_t processed = bStart;
|
||||||
|
|
||||||
if (isChunked)
|
if (serverStream.isChunked)
|
||||||
{
|
{
|
||||||
size_t idx = bodyStart;
|
size_t idx = bStart;
|
||||||
bool parseOk = true;
|
bool ok = true;
|
||||||
while (idx < serverBuffer.size())
|
while (idx < serverStream.buffer.size())
|
||||||
{
|
{
|
||||||
size_t lineEnd = serverBuffer.find("\r\n", idx);
|
size_t le = serverStream.buffer.find("\r\n", idx);
|
||||||
if (lineEnd == std::string::npos)
|
if (le == std::string::npos)
|
||||||
{
|
{
|
||||||
parseOk = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
std::string hexStr = serverBuffer.substr(idx, lineEnd - idx);
|
int cs = safe_stoi(serverStream.buffer.substr(idx, le - idx), 0, 16);
|
||||||
int chunkSize = 0;
|
idx = le + 2;
|
||||||
try
|
if (cs == 0)
|
||||||
{
|
{
|
||||||
chunkSize = std::stoi(hexStr, nullptr, 16);
|
idx += 2;
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
parseOk = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
idx = lineEnd + 2;
|
|
||||||
if (chunkSize == 0)
|
|
||||||
{
|
|
||||||
idx += 2; // skip terminal \r\n
|
|
||||||
complete = true;
|
complete = true;
|
||||||
totalProcessed = idx;
|
processed = idx;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (idx + (size_t)chunkSize + 2 > serverBuffer.size())
|
if (idx + cs + 2 > serverStream.buffer.size())
|
||||||
{
|
{
|
||||||
parseOk = false;
|
ok = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
fullBody.append(serverBuffer, idx, chunkSize);
|
body.append(serverStream.buffer, idx, cs);
|
||||||
idx += chunkSize + 2;
|
idx += cs + 2;
|
||||||
}
|
}
|
||||||
if (!parseOk) complete = false;
|
if (!ok)
|
||||||
}
|
|
||||||
else if (expectedLength >= 0)
|
|
||||||
{
|
{
|
||||||
if (serverBuffer.size() >= bodyStart + expectedLength)
|
if (connectionClosed)
|
||||||
{
|
{
|
||||||
complete = true;
|
complete = true;
|
||||||
totalProcessed = bodyStart + expectedLength;
|
processed = serverStream.buffer.size();
|
||||||
fullBody = serverBuffer.substr(bodyStart, expectedLength);
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
complete = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (serverStream.contentLength >= 0)
|
||||||
|
{
|
||||||
|
if (serverStream.buffer.size() >= bStart + serverStream.contentLength)
|
||||||
|
{
|
||||||
|
complete = true;
|
||||||
|
processed = bStart + serverStream.contentLength;
|
||||||
|
body = serverStream.buffer.substr(bStart, serverStream.contentLength);
|
||||||
|
}
|
||||||
|
else if (connectionClosed)
|
||||||
|
{
|
||||||
|
complete = true;
|
||||||
|
processed = serverStream.buffer.size();
|
||||||
|
body = serverStream.buffer.substr(bStart);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
std::string peekBuffer = serverBuffer.substr(0, bodyStart);
|
if (connectionClosed)
|
||||||
bool isCloseConn = peekBuffer.find("Connection: close") != std::string::npos ||
|
|
||||||
peekBuffer.find("connection: close") != std::string::npos;
|
|
||||||
|
|
||||||
bool isNoBodyStatus = peekBuffer.find("HTTP/1.1 204") != std::string::npos ||
|
|
||||||
peekBuffer.find("HTTP/1.1 304") != std::string::npos ||
|
|
||||||
peekBuffer.find("HTTP/1.0 204") != std::string::npos ||
|
|
||||||
peekBuffer.find("HTTP/1.0 304") != std::string::npos;
|
|
||||||
|
|
||||||
if (isCloseConn)
|
|
||||||
break;
|
|
||||||
else
|
|
||||||
{
|
{
|
||||||
complete = true;
|
complete = true;
|
||||||
fullBody = "";
|
processed = serverStream.buffer.size();
|
||||||
totalProcessed = bodyStart;
|
body = serverStream.buffer.substr(bStart);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
complete = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (complete)
|
if (complete)
|
||||||
{
|
{
|
||||||
std::string headers = serverBuffer.substr(0, bodyStart);
|
std::string url = pendingUrls.empty() ? ("https://" + host) : pendingUrls.front();
|
||||||
std::string responseData = fullBody;
|
if (!pendingUrls.empty()) pendingUrls.pop_front();
|
||||||
|
|
||||||
std::string currentUrl = "https://" + host;
|
std::string respHeaders = serverStream.buffer.substr(0, bStart);
|
||||||
if (!pendingUrls.empty())
|
size_t firstSpace = respHeaders.find(' ');
|
||||||
|
int sc =
|
||||||
|
(firstSpace != std::string::npos) ? safe_stoi(respHeaders.substr(firstSpace + 1, 3)) : 0;
|
||||||
|
|
||||||
|
OnServerResponse.run(url, body);
|
||||||
|
|
||||||
|
removeHeader(respHeaders, "Transfer-Encoding");
|
||||||
|
removeHeader(respHeaders, "Content-Length");
|
||||||
|
if (sc != 204 && sc != 304 && sc != 205)
|
||||||
{
|
{
|
||||||
currentUrl = pendingUrls.front();
|
respHeaders.insert(respHeaders.size() - 2,
|
||||||
pendingUrls.pop_front();
|
"Content-Length: " + std::to_string(body.size()) + "\r\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
OnServerResponse.run(currentUrl, responseData);
|
std::string packet = respHeaders + body;
|
||||||
|
SSL_write(clientSSL, packet.data(), (int)packet.size());
|
||||||
|
|
||||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
serverStream.buffer.erase(0, processed);
|
||||||
size_t pos = 0;
|
serverStream.reset();
|
||||||
while (true)
|
clientStream.reset();
|
||||||
{
|
|
||||||
pos = h.find(key, pos);
|
|
||||||
if (pos == std::string::npos)
|
|
||||||
{
|
|
||||||
std::string lowerKey = key;
|
|
||||||
for (char& c : lowerKey)
|
|
||||||
c = (char)tolower(c);
|
|
||||||
pos = h.find(lowerKey, 0);
|
|
||||||
if (pos == std::string::npos) break;
|
|
||||||
}
|
|
||||||
if (pos == 0 || h[pos - 1] == '\n')
|
|
||||||
{
|
|
||||||
size_t end = h.find("\r\n", pos);
|
|
||||||
if (end != std::string::npos)
|
|
||||||
{
|
|
||||||
h.erase(pos, end - pos + 2);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
removeHeader(headers, "Transfer-Encoding");
|
|
||||||
removeHeader(headers, "Content-Length");
|
|
||||||
|
|
||||||
headers.insert(headers.size() - 2,
|
|
||||||
"Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
|
||||||
|
|
||||||
std::string finalPacket = headers + responseData;
|
|
||||||
int sent = SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
|
||||||
|
|
||||||
serverBuffer.erase(0, totalProcessed);
|
|
||||||
|
|
||||||
isReceivingBody = false;
|
|
||||||
expectedLength = -1;
|
|
||||||
headersEnd = 0;
|
|
||||||
isChunked = false;
|
|
||||||
|
|
||||||
if (sent <= 0) break;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
break; // wait for more streaming packets
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isReceivingBody && expectedLength < 0 && !isChunked && serverBuffer.size() > headersEnd + 4)
|
|
||||||
{
|
|
||||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
|
||||||
std::string responseData = serverBuffer.substr(headersEnd + 4);
|
|
||||||
|
|
||||||
std::string finalUrl = "https://" + host;
|
|
||||||
if (!pendingUrls.empty())
|
|
||||||
{
|
|
||||||
finalUrl = pendingUrls.front();
|
|
||||||
pendingUrls.pop_front();
|
|
||||||
}
|
|
||||||
|
|
||||||
OnServerResponse.run(finalUrl, responseData);
|
|
||||||
|
|
||||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
|
||||||
size_t pos = 0;
|
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
pos = h.find(key, pos);
|
|
||||||
if (pos == std::string::npos)
|
|
||||||
{
|
|
||||||
std::string lowerKey = key;
|
|
||||||
for (char& c : lowerKey)
|
|
||||||
c = (char)tolower(c);
|
|
||||||
pos = h.find(lowerKey, 0);
|
|
||||||
if (pos == std::string::npos) break;
|
|
||||||
}
|
|
||||||
if (pos == 0 || h[pos - 1] == '\n')
|
|
||||||
{
|
|
||||||
size_t end = h.find("\r\n", pos);
|
|
||||||
if (end != std::string::npos)
|
|
||||||
{
|
|
||||||
h.erase(pos, end - pos + 2);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
removeHeader(headers, "Transfer-Encoding");
|
|
||||||
removeHeader(headers, "Content-Length");
|
|
||||||
headers.insert(headers.size() - 2, "Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
|
||||||
|
|
||||||
std::string finalPacket = headers + responseData;
|
|
||||||
SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
SSL_shutdown(clientSSL);
|
|
||||||
SSL_free(clientSSL);
|
|
||||||
|
|
||||||
SSL_shutdown(remoteSSL);
|
|
||||||
SSL_free(remoteSSL);
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
send(remoteSocket, buffer, bytesReceived, 0);
|
break;
|
||||||
|
|
||||||
fd_set readfds;
|
|
||||||
while (_running)
|
|
||||||
{
|
|
||||||
FD_ZERO(&readfds);
|
|
||||||
FD_SET(clientSocket, &readfds);
|
|
||||||
FD_SET(remoteSocket, &readfds);
|
|
||||||
|
|
||||||
struct timeval tv;
|
|
||||||
tv.tv_sec = 1;
|
|
||||||
tv.tv_usec = 0;
|
|
||||||
|
|
||||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
|
||||||
if (ret < 0) break;
|
|
||||||
if (ret == 0) continue;
|
|
||||||
|
|
||||||
if (FD_ISSET(clientSocket, &readfds))
|
|
||||||
{
|
|
||||||
int bytes = recv(clientSocket, buffer, sizeof(buffer), 0);
|
|
||||||
if (bytes <= 0) break;
|
|
||||||
int sent = send(remoteSocket, buffer, bytes, 0);
|
|
||||||
if (sent == SOCKET_ERROR) break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (FD_ISSET(remoteSocket, &readfds))
|
|
||||||
{
|
|
||||||
int bytes = recv(remoteSocket, buffer, sizeof(buffer), 0);
|
|
||||||
if (bytes <= 0) break;
|
|
||||||
int sent = send(clientSocket, buffer, bytes, 0);
|
|
||||||
if (sent == SOCKET_ERROR) break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (connectionClosed) break;
|
||||||
closesocket(remoteSocket);
|
}
|
||||||
closesocket(clientSocket);
|
}
|
||||||
}
|
}
|
||||||
@@ -12,7 +12,8 @@
|
|||||||
|
|
||||||
typedef unsigned __int64 SOCKET;
|
typedef unsigned __int64 SOCKET;
|
||||||
|
|
||||||
class Proxy {
|
class Proxy
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
Proxy();
|
Proxy();
|
||||||
~Proxy();
|
~Proxy();
|
||||||
@@ -32,7 +33,7 @@ private:
|
|||||||
|
|
||||||
SOCKET _listenSocket = 0;
|
SOCKET _listenSocket = 0;
|
||||||
std::thread _workerThread;
|
std::thread _workerThread;
|
||||||
std::atomic<bool> _running;
|
std::atomic<bool> _running = false;
|
||||||
|
|
||||||
CertManager _certManager;
|
CertManager _certManager;
|
||||||
SSL_CTX* _clientCtx = nullptr;
|
SSL_CTX* _clientCtx = nullptr;
|
||||||
|
|||||||
Reference in New Issue
Block a user