feat: refactor proxy
This commit is contained in:
+324
-390
@@ -4,181 +4,174 @@
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#include <nerutils/log.h>
|
||||
#include <deque>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <string>
|
||||
#include <deque>
|
||||
|
||||
namespace
|
||||
/*
|
||||
memory helpers
|
||||
*/
|
||||
template <typename T, void (*f)(T*)> struct Deleter
|
||||
{
|
||||
class ScopedSSL
|
||||
{
|
||||
SSL* s;
|
||||
|
||||
public:
|
||||
ScopedSSL(SSL* val = nullptr) : s(val) {}
|
||||
~ScopedSSL()
|
||||
{
|
||||
if (s) SSL_free(s);
|
||||
}
|
||||
void reset(SSL* val)
|
||||
{
|
||||
if (s) SSL_free(s);
|
||||
s = val;
|
||||
}
|
||||
operator SSL*() const { return s; }
|
||||
};
|
||||
|
||||
class ScopedSocket
|
||||
{
|
||||
SOCKET s;
|
||||
|
||||
public:
|
||||
ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
|
||||
~ScopedSocket()
|
||||
{
|
||||
if (s != INVALID_SOCKET) closesocket(s);
|
||||
}
|
||||
void reset(SOCKET val)
|
||||
{
|
||||
if (s != INVALID_SOCKET) closesocket(s);
|
||||
s = val;
|
||||
}
|
||||
operator SOCKET() const { return s; }
|
||||
SOCKET get() const { return s; }
|
||||
};
|
||||
|
||||
int safe_stoi(const std::string& s, int default_val = 0, int base = 10)
|
||||
{
|
||||
if (s.empty()) return default_val;
|
||||
try
|
||||
void operator()(T* p) const
|
||||
{
|
||||
return std::stoi(s, nullptr, base);
|
||||
if (p) f(p);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return default_val;
|
||||
}
|
||||
}
|
||||
};
|
||||
using SSL_ptr = std::unique_ptr<SSL, Deleter<SSL, SSL_free>>;
|
||||
|
||||
void removeHeader(std::string& h, const std::string& key)
|
||||
{
|
||||
std::string h_lower = h;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
std::string key_lower = key;
|
||||
std::transform(key_lower.begin(), key_lower.end(), key_lower.begin(), ::tolower);
|
||||
|
||||
size_t pos = 0;
|
||||
while ((pos = h_lower.find(key_lower, pos)) != std::string::npos)
|
||||
{
|
||||
if (pos == 0 || h_lower[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h_lower.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
h_lower.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
|
||||
std::string getHeaderValue(const std::string& headers, const std::string& key)
|
||||
{
|
||||
std::string h_lower = headers;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
std::string search = "\n" + key;
|
||||
std::transform(search.begin(), search.end(), search.begin(), ::tolower);
|
||||
search += ":";
|
||||
|
||||
size_t pos = h_lower.find(search);
|
||||
if (pos == std::string::npos) return "";
|
||||
|
||||
size_t vStart = pos + search.length();
|
||||
while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t'))
|
||||
vStart++;
|
||||
size_t vEnd = headers.find("\r\n", vStart);
|
||||
if (vEnd == std::string::npos) return "";
|
||||
return headers.substr(vStart, vEnd - vStart);
|
||||
}
|
||||
|
||||
struct HttpStream
|
||||
{
|
||||
std::string buffer;
|
||||
bool isReceivingBody = false;
|
||||
bool isChunked = false;
|
||||
int contentLength = -1;
|
||||
size_t headersEnd = 0;
|
||||
|
||||
void reset()
|
||||
{
|
||||
isReceivingBody = false;
|
||||
isChunked = false;
|
||||
contentLength = -1;
|
||||
headersEnd = 0;
|
||||
}
|
||||
|
||||
bool parseHeaders()
|
||||
{
|
||||
if (isReceivingBody) return true;
|
||||
headersEnd = buffer.find("\r\n\r\n");
|
||||
if (headersEnd == std::string::npos) return false;
|
||||
|
||||
std::string headers = buffer.substr(0, headersEnd + 4);
|
||||
|
||||
std::string te = getHeaderValue(headers, "Transfer-Encoding");
|
||||
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
|
||||
isChunked = (te.find("chunked") != std::string::npos);
|
||||
|
||||
if (!isChunked)
|
||||
{
|
||||
std::string cl = getHeaderValue(headers, "Content-Length");
|
||||
contentLength = safe_stoi(cl, -1);
|
||||
}
|
||||
else
|
||||
{
|
||||
contentLength = -1;
|
||||
}
|
||||
|
||||
isReceivingBody = true;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool Proxy::initSSL()
|
||||
struct AutoSocket
|
||||
{
|
||||
_clientCtx = SSL_CTX_new(TLS_client_method());
|
||||
if (!_clientCtx) return false;
|
||||
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
||||
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
|
||||
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
|
||||
return true;
|
||||
}
|
||||
SOCKET s;
|
||||
AutoSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
|
||||
~AutoSocket()
|
||||
{
|
||||
if (s != INVALID_SOCKET) closesocket(s);
|
||||
}
|
||||
operator SOCKET() const { return s; }
|
||||
};
|
||||
|
||||
void Proxy::cleanupSSL()
|
||||
/*
|
||||
helper functions
|
||||
*/
|
||||
int stoiSafe(const std::string& s, int default_val = 0, int base = 10)
|
||||
{
|
||||
if (_clientCtx)
|
||||
if (s.empty()) return default_val;
|
||||
try
|
||||
{
|
||||
SSL_CTX_free(_clientCtx);
|
||||
_clientCtx = nullptr;
|
||||
return std::stoi(s, nullptr, base);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return default_val;
|
||||
}
|
||||
}
|
||||
|
||||
void removeHeader(std::string& headers, const std::string& key)
|
||||
{
|
||||
std::string result;
|
||||
size_t start = 0;
|
||||
size_t end;
|
||||
|
||||
std::string keyLower = key;
|
||||
std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), ::tolower);
|
||||
if (keyLower.back() != ':') keyLower += ':';
|
||||
|
||||
while ((end = headers.find('\n', start)) != std::string::npos)
|
||||
{
|
||||
std::string line = headers.substr(start, end - start + 1);
|
||||
std::string lineLower = line;
|
||||
std::transform(lineLower.begin(), lineLower.end(), lineLower.begin(), ::tolower);
|
||||
|
||||
if (lineLower.compare(0, keyLower.length(), keyLower) != 0) result += line;
|
||||
|
||||
start = end + 1;
|
||||
}
|
||||
|
||||
if (start < headers.length())
|
||||
{
|
||||
std::string line = headers.substr(start);
|
||||
std::string lineLower = line;
|
||||
if (lineLower.compare(0, keyLower.length(), keyLower) != 0) result += line;
|
||||
}
|
||||
|
||||
headers = std::move(result);
|
||||
}
|
||||
|
||||
std::string getHeaderValue(const std::string& headers, const std::string& key)
|
||||
{
|
||||
size_t start = 0;
|
||||
size_t end;
|
||||
|
||||
std::string keyLower = key;
|
||||
std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), ::tolower);
|
||||
if (keyLower.empty()) return "";
|
||||
if (keyLower.back() != ':') keyLower += ':';
|
||||
|
||||
while (start < headers.length())
|
||||
{
|
||||
end = headers.find('\n', start);
|
||||
|
||||
std::string line = (end == std::string::npos) ? headers.substr(start) : headers.substr(start, end - start);
|
||||
std::string lineLower = line;
|
||||
std::transform(lineLower.begin(), lineLower.end(), lineLower.begin(), ::tolower);
|
||||
|
||||
if (lineLower.compare(0, keyLower.length(), keyLower) == 0)
|
||||
{
|
||||
size_t valueStart = keyLower.length();
|
||||
while (valueStart < line.length() && (line[valueStart] == ' ' || line[valueStart] == '\t'))
|
||||
valueStart++;
|
||||
|
||||
size_t valueEnd = line.length();
|
||||
if (valueEnd > valueStart && line[valueEnd - 1] == '\r') valueEnd--;
|
||||
|
||||
return line.substr(valueStart, valueEnd - valueStart);
|
||||
}
|
||||
|
||||
if (end == std::string::npos) break;
|
||||
start = end + 1;
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
/*
|
||||
http stream helper class
|
||||
*/
|
||||
struct HttpStream
|
||||
{
|
||||
std::string buffer;
|
||||
bool isReceivingBody = false;
|
||||
bool isChunked = false;
|
||||
int contentLength = -1;
|
||||
size_t headersEnd = std::string::npos;
|
||||
|
||||
void reset()
|
||||
{
|
||||
isReceivingBody = false;
|
||||
isChunked = false;
|
||||
contentLength = -1;
|
||||
headersEnd = std::string::npos;
|
||||
}
|
||||
|
||||
bool parseHeaders()
|
||||
{
|
||||
headersEnd = buffer.find("\r\n\r\n");
|
||||
if (headersEnd == std::string::npos) return false;
|
||||
|
||||
std::string headers = buffer.substr(0, headersEnd + 4);
|
||||
std::string te = getHeaderValue(headers, "Transfer-Encoding");
|
||||
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
|
||||
isChunked = (te.find("chunked") != std::string::npos);
|
||||
|
||||
if (!isChunked)
|
||||
{
|
||||
std::string cl = getHeaderValue(headers, "Content-Length");
|
||||
contentLength = stoiSafe(cl, -1);
|
||||
}
|
||||
isReceivingBody = true;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
proxy impl
|
||||
*/
|
||||
Proxy::Proxy() {}
|
||||
|
||||
Proxy::~Proxy()
|
||||
{
|
||||
Shutdown();
|
||||
shutdown();
|
||||
}
|
||||
|
||||
bool Proxy::Init()
|
||||
bool Proxy::init()
|
||||
{
|
||||
if (!_certManager.init()) return false;
|
||||
|
||||
initSSL();
|
||||
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;
|
||||
if (!_certManager.Init()) return false;
|
||||
if (!initSSL()) return false;
|
||||
|
||||
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
||||
sockaddr_in serverAddr = {};
|
||||
@@ -188,15 +181,18 @@ bool Proxy::Init()
|
||||
|
||||
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false;
|
||||
listen(_listenSocket, SOMAXCONN);
|
||||
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
|
||||
|
||||
_running = true;
|
||||
_workerThread = std::thread(&Proxy::loop, this);
|
||||
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Proxy::Shutdown()
|
||||
void Proxy::shutdown()
|
||||
{
|
||||
if (!_running) return;
|
||||
|
||||
_running = false;
|
||||
if (_listenSocket != INVALID_SOCKET)
|
||||
{
|
||||
@@ -212,22 +208,22 @@ void Proxy::loop()
|
||||
{
|
||||
while (_running)
|
||||
{
|
||||
SOCKET hClient = accept(_listenSocket, NULL, NULL);
|
||||
SOCKET client = accept(_listenSocket, NULL, NULL);
|
||||
if (!_running)
|
||||
{
|
||||
if (hClient != INVALID_SOCKET) closesocket(hClient);
|
||||
if (client != INVALID_SOCKET) closesocket(client);
|
||||
break;
|
||||
}
|
||||
if (hClient == INVALID_SOCKET) continue;
|
||||
std::thread([this, hClient]() { this->handleClient(hClient); }).detach();
|
||||
if (client == INVALID_SOCKET) continue;
|
||||
std::thread([this, client]() { this->handleClient(client); }).detach();
|
||||
}
|
||||
}
|
||||
|
||||
void Proxy::handleClient(SOCKET hClientSocket)
|
||||
void Proxy::handleClient(SOCKET clientSocket)
|
||||
{
|
||||
ScopedSocket clientGuard(hClientSocket);
|
||||
char buffer[32768];
|
||||
AutoSocket clientGuard(clientSocket);
|
||||
|
||||
char buffer[32768];
|
||||
int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0);
|
||||
if (bytesRead <= 0) return;
|
||||
buffer[bytesRead] = '\0';
|
||||
@@ -235,6 +231,9 @@ void Proxy::handleClient(SOCKET hClientSocket)
|
||||
std::string initialReq(buffer);
|
||||
if (initialReq.find("CONNECT ") != 0) return;
|
||||
|
||||
/*
|
||||
host info
|
||||
*/
|
||||
size_t hostStart = 8;
|
||||
size_t hostEnd = initialReq.find(' ', hostStart);
|
||||
if (hostEnd == std::string::npos) return;
|
||||
@@ -246,7 +245,7 @@ void Proxy::handleClient(SOCKET hClientSocket)
|
||||
if (colon != std::string::npos)
|
||||
{
|
||||
host = fullHost.substr(0, colon);
|
||||
port = safe_stoi(fullHost.substr(colon + 1), 443);
|
||||
port = stoiSafe(fullHost.substr(colon + 1), 443);
|
||||
}
|
||||
|
||||
struct addrinfo hints = {}, *res = nullptr;
|
||||
@@ -254,35 +253,39 @@ void Proxy::handleClient(SOCKET hClientSocket)
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
|
||||
|
||||
ScopedSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0)
|
||||
{
|
||||
freeaddrinfo(res);
|
||||
return;
|
||||
}
|
||||
/*
|
||||
establish connection to host
|
||||
*/
|
||||
AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0) return freeaddrinfo(res);
|
||||
freeaddrinfo(res);
|
||||
|
||||
send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0);
|
||||
|
||||
SSL_CTX* hostCtx = _certManager.CreateHostContext(host);
|
||||
/*
|
||||
SSL
|
||||
*/
|
||||
SSL_CTX* hostCtx = _certManager.createHostContext(host);
|
||||
if (!hostCtx) return;
|
||||
|
||||
ScopedSSL clientSSL(SSL_new(hostCtx));
|
||||
SSL_set_fd(clientSSL, (int)clientGuard.get());
|
||||
if (SSL_accept(clientSSL) <= 0) return;
|
||||
SSL_ptr clientSSL(SSL_new(hostCtx));
|
||||
SSL_set_fd(clientSSL.get(), (int)clientGuard);
|
||||
if (SSL_accept(clientSSL.get()) <= 0) return;
|
||||
|
||||
ScopedSSL remoteSSL(SSL_new(_clientCtx));
|
||||
SSL_set_fd(remoteSSL, (int)remoteGuard.get());
|
||||
SSL_set_tlsext_host_name(remoteSSL, host.c_str());
|
||||
if (SSL_connect(remoteSSL) <= 0) return;
|
||||
SSL_ptr remoteSSL(SSL_new(_clientCtx));
|
||||
SSL_set_fd(remoteSSL.get(), (int)remoteGuard);
|
||||
SSL_set_tlsext_host_name(remoteSSL.get(), host.c_str());
|
||||
if (SSL_connect(remoteSSL.get()) <= 0) return;
|
||||
|
||||
/*
|
||||
traffic handler
|
||||
*/
|
||||
HttpStream clientStream, serverStream;
|
||||
std::deque<std::string> pendingUrls;
|
||||
bool tunnelMode = false;
|
||||
fd_set readfds;
|
||||
|
||||
while (_running)
|
||||
{
|
||||
fd_set readfds;
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(clientGuard, &readfds);
|
||||
FD_SET(remoteGuard, &readfds);
|
||||
@@ -290,108 +293,82 @@ void Proxy::handleClient(SOCKET hClientSocket)
|
||||
struct timeval tv = {0, 50000};
|
||||
if (select(0, &readfds, NULL, NULL, &tv) < 0) break;
|
||||
|
||||
if (tunnelMode)
|
||||
/*
|
||||
client -> server
|
||||
*/
|
||||
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL.get()) > 0)
|
||||
{
|
||||
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
{
|
||||
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
SSL_write(remoteSSL, buffer, n);
|
||||
}
|
||||
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
{
|
||||
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
SSL_write(clientSSL, buffer, n);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
{
|
||||
int n = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
int n = SSL_read(clientSSL.get(), buffer, sizeof(buffer));
|
||||
if (n <= 0) break;
|
||||
clientStream.buffer.append(buffer, n);
|
||||
|
||||
while (!clientStream.buffer.empty())
|
||||
if (tunnelMode)
|
||||
SSL_write(remoteSSL.get(), buffer, n);
|
||||
else
|
||||
{
|
||||
if (!clientStream.isReceivingBody)
|
||||
clientStream.buffer.append(buffer, n);
|
||||
while (true)
|
||||
{
|
||||
if (!clientStream.parseHeaders()) break;
|
||||
|
||||
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
||||
std::string url = "https://" + host;
|
||||
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
|
||||
if (s1 != std::string::npos && s2 != std::string::npos)
|
||||
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
|
||||
|
||||
pendingUrls.push_back(url);
|
||||
|
||||
removeHeader(headers, "Accept-Encoding");
|
||||
removeHeader(headers, "Expect");
|
||||
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
||||
|
||||
if (clientStream.contentLength == 0 || (clientStream.contentLength < 0 && !clientStream.isChunked))
|
||||
if (!clientStream.isReceivingBody)
|
||||
{
|
||||
std::string emptyBody = "";
|
||||
OnClientRequest.run(url, emptyBody, headers);
|
||||
if (!clientStream.parseHeaders()) break;
|
||||
|
||||
if (!pendingUrls.empty()) pendingUrls.back() = url;
|
||||
|
||||
SSL_write(remoteSSL, headers.data(), (int)headers.size());
|
||||
clientStream.buffer.erase(0, clientStream.headersEnd + 4);
|
||||
clientStream.reset();
|
||||
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
||||
std::string url = "https://" + host;
|
||||
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
|
||||
if (s1 != std::string::npos && s2 != std::string::npos)
|
||||
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
|
||||
pendingUrls.push_back(url);
|
||||
}
|
||||
}
|
||||
|
||||
if (clientStream.isReceivingBody)
|
||||
{
|
||||
size_t bodyStart = clientStream.headersEnd + 4;
|
||||
std::string url = pendingUrls.back();
|
||||
std::string headers = clientStream.buffer.substr(0, bodyStart);
|
||||
removeHeader(headers, "Accept-Encoding");
|
||||
removeHeader(headers, "Expect");
|
||||
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
||||
|
||||
bool complete = false;
|
||||
std::string body;
|
||||
std::string fullBody;
|
||||
size_t totalRequestSize = 0;
|
||||
|
||||
if (clientStream.isChunked)
|
||||
{
|
||||
size_t idx = bodyStart;
|
||||
size_t idx = clientStream.headersEnd + 4;
|
||||
while (idx < clientStream.buffer.size())
|
||||
{
|
||||
size_t le = clientStream.buffer.find("\r\n", idx);
|
||||
if (le == std::string::npos) break;
|
||||
int cs = safe_stoi(clientStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
int cs = stoiSafe(clientStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
if (idx + (le - idx) + 2 + cs + 2 > clientStream.buffer.size()) break;
|
||||
body.append(clientStream.buffer, le + 2, cs);
|
||||
if (cs > 0) fullBody.append(clientStream.buffer, le + 2, cs);
|
||||
idx = le + 2 + cs + 2;
|
||||
if (cs == 0)
|
||||
{
|
||||
complete = true;
|
||||
totalRequestSize = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (clientStream.contentLength >= 0)
|
||||
else
|
||||
{
|
||||
if (clientStream.buffer.size() >= bodyStart + clientStream.contentLength)
|
||||
int cl = clientStream.contentLength;
|
||||
if (cl < 0) cl = 0;
|
||||
if (clientStream.buffer.size() >= (clientStream.headersEnd + 4 + cl))
|
||||
{
|
||||
body = clientStream.buffer.substr(bodyStart, clientStream.contentLength);
|
||||
fullBody = clientStream.buffer.substr(clientStream.headersEnd + 4, cl);
|
||||
complete = true;
|
||||
totalRequestSize = clientStream.headersEnd + 4 + cl;
|
||||
}
|
||||
}
|
||||
|
||||
if (complete)
|
||||
{
|
||||
OnClientRequest.run(url, body, headers);
|
||||
if (!pendingUrls.empty() && pendingUrls.back() != url) pendingUrls.back() = url;
|
||||
std::string url = pendingUrls.back();
|
||||
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
||||
|
||||
SSL_write(remoteSSL, headers.data(), (int)headers.size());
|
||||
SSL_write(remoteSSL, clientStream.buffer.data() + bodyStart,
|
||||
(int)(clientStream.buffer.size() - bodyStart));
|
||||
clientStream.buffer.clear();
|
||||
removeHeader(headers, "Accept-Encoding");
|
||||
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
||||
|
||||
OnClientRequest.run(url, fullBody, headers);
|
||||
|
||||
SSL_write(remoteSSL.get(), headers.data(), (int)headers.size());
|
||||
SSL_write(remoteSSL.get(), fullBody.data(), (int)fullBody.size());
|
||||
|
||||
clientStream.buffer.erase(0, totalRequestSize);
|
||||
clientStream.reset();
|
||||
}
|
||||
else
|
||||
@@ -400,169 +377,126 @@ void Proxy::handleClient(SOCKET hClientSocket)
|
||||
}
|
||||
}
|
||||
|
||||
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
/*
|
||||
server -> client
|
||||
*/
|
||||
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL.get()) > 0)
|
||||
{
|
||||
int n = SSL_read(remoteSSL, buffer, sizeof(buffer));
|
||||
bool connectionClosed = (n <= 0);
|
||||
if (!connectionClosed)
|
||||
{
|
||||
serverStream.buffer.append(buffer, n);
|
||||
}
|
||||
int n = SSL_read(remoteSSL.get(), buffer, sizeof(buffer));
|
||||
bool closed = (n <= 0);
|
||||
if (!closed) serverStream.buffer.append(buffer, n);
|
||||
|
||||
while (!serverStream.buffer.empty() || connectionClosed)
|
||||
while (true)
|
||||
{
|
||||
if (!serverStream.isReceivingBody)
|
||||
if (!serverStream.parseHeaders()) break;
|
||||
|
||||
bool complete = false;
|
||||
std::string fullBody;
|
||||
size_t totalResponseSize = 0;
|
||||
|
||||
if (serverStream.isChunked)
|
||||
{
|
||||
if (connectionClosed) break;
|
||||
|
||||
serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n");
|
||||
if (serverStream.headersEnd == std::string::npos) break;
|
||||
|
||||
std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
|
||||
size_t s1 = headers.find(' ');
|
||||
int sCode = (s1 != std::string::npos) ? safe_stoi(headers.substr(s1 + 1, 3)) : 0;
|
||||
|
||||
if (sCode == 101)
|
||||
size_t idx = serverStream.headersEnd + 4;
|
||||
while (idx < serverStream.buffer.size())
|
||||
{
|
||||
SSL_write(clientSSL, serverStream.buffer.data(), (int)serverStream.buffer.size());
|
||||
serverStream.buffer.clear();
|
||||
clientStream.buffer.clear();
|
||||
tunnelMode = true;
|
||||
break;
|
||||
size_t le = serverStream.buffer.find("\r\n", idx);
|
||||
if (le == std::string::npos) break;
|
||||
int cs = stoiSafe(serverStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
if (idx + (le - idx) + 2 + cs + 2 > serverStream.buffer.size()) break;
|
||||
if (cs > 0) fullBody.append(serverStream.buffer, le + 2, cs);
|
||||
idx = le + 2 + cs + 2;
|
||||
if (cs == 0)
|
||||
{
|
||||
complete = true;
|
||||
totalResponseSize = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (sCode >= 100 && sCode < 200)
|
||||
}
|
||||
else if (serverStream.contentLength >= 0)
|
||||
{
|
||||
if (serverStream.buffer.size() >= (serverStream.headersEnd + 4 + serverStream.contentLength))
|
||||
{
|
||||
SSL_write(clientSSL, headers.data(), (int)headers.size());
|
||||
serverStream.buffer.erase(0, serverStream.headersEnd + 4);
|
||||
serverStream.isReceivingBody = false;
|
||||
continue;
|
||||
fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4, serverStream.contentLength);
|
||||
complete = true;
|
||||
totalResponseSize = serverStream.headersEnd + 4 + serverStream.contentLength;
|
||||
}
|
||||
|
||||
serverStream.isReceivingBody = true;
|
||||
std::string h_lower = headers;
|
||||
std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower);
|
||||
serverStream.isChunked = (h_lower.find("transfer-encoding: chunked") != std::string::npos);
|
||||
|
||||
size_t clPos = h_lower.find("content-length:");
|
||||
if (clPos != std::string::npos)
|
||||
{
|
||||
size_t vStart = clPos + 15;
|
||||
while (vStart < h_lower.size() && (h_lower[vStart] == ' ' || h_lower[vStart] == '\t'))
|
||||
vStart++;
|
||||
serverStream.contentLength =
|
||||
safe_stoi(h_lower.substr(vStart, h_lower.find("\r\n", vStart) - vStart), -1);
|
||||
}
|
||||
else if (sCode == 204 || sCode == 304 || sCode == 205)
|
||||
serverStream.contentLength = 0;
|
||||
else if (!serverStream.isChunked)
|
||||
serverStream.contentLength = -1;
|
||||
}
|
||||
else if (closed)
|
||||
{
|
||||
fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4);
|
||||
complete = true;
|
||||
totalResponseSize = serverStream.buffer.size();
|
||||
}
|
||||
|
||||
if (serverStream.isReceivingBody)
|
||||
if (complete)
|
||||
{
|
||||
bool complete = false;
|
||||
std::string body;
|
||||
size_t bStart = serverStream.headersEnd + 4;
|
||||
size_t processed = bStart;
|
||||
std::string url = pendingUrls.empty() ? "https://" + host : pendingUrls.front();
|
||||
if (!pendingUrls.empty()) pendingUrls.pop_front();
|
||||
|
||||
if (serverStream.isChunked)
|
||||
{
|
||||
size_t idx = bStart;
|
||||
bool ok = true;
|
||||
while (idx < serverStream.buffer.size())
|
||||
{
|
||||
size_t le = serverStream.buffer.find("\r\n", idx);
|
||||
if (le == std::string::npos)
|
||||
{
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
int cs = safe_stoi(serverStream.buffer.substr(idx, le - idx), 0, 16);
|
||||
idx = le + 2;
|
||||
if (cs == 0)
|
||||
{
|
||||
idx += 2;
|
||||
complete = true;
|
||||
processed = idx;
|
||||
break;
|
||||
}
|
||||
if (idx + cs + 2 > serverStream.buffer.size())
|
||||
{
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
body.append(serverStream.buffer, idx, cs);
|
||||
idx += cs + 2;
|
||||
}
|
||||
if (!ok)
|
||||
{
|
||||
if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
processed = serverStream.buffer.size();
|
||||
}
|
||||
else
|
||||
complete = false;
|
||||
}
|
||||
}
|
||||
else if (serverStream.contentLength >= 0)
|
||||
{
|
||||
if (serverStream.buffer.size() >= bStart + serverStream.contentLength)
|
||||
{
|
||||
complete = true;
|
||||
processed = bStart + serverStream.contentLength;
|
||||
body = serverStream.buffer.substr(bStart, serverStream.contentLength);
|
||||
}
|
||||
else if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
processed = serverStream.buffer.size();
|
||||
body = serverStream.buffer.substr(bStart);
|
||||
}
|
||||
}
|
||||
std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
|
||||
|
||||
OnServerResponse.run(url, fullBody, respHeaders);
|
||||
|
||||
std::string packet;
|
||||
int statusCode = 200;
|
||||
size_t space = respHeaders.find(' ');
|
||||
if (space != std::string::npos) statusCode = stoiSafe(respHeaders.substr(space + 1, 3));
|
||||
|
||||
if (tunnelMode || statusCode == 204 || statusCode == 304 || (statusCode >= 100 && statusCode < 200))
|
||||
packet = respHeaders + fullBody;
|
||||
else
|
||||
{
|
||||
if (connectionClosed)
|
||||
{
|
||||
complete = true;
|
||||
processed = serverStream.buffer.size();
|
||||
body = serverStream.buffer.substr(bStart);
|
||||
}
|
||||
else
|
||||
complete = false;
|
||||
}
|
||||
|
||||
if (complete)
|
||||
{
|
||||
std::string url = pendingUrls.empty() ? ("https://" + host) : pendingUrls.front();
|
||||
if (!pendingUrls.empty()) pendingUrls.pop_front();
|
||||
|
||||
std::string respHeaders = serverStream.buffer.substr(0, bStart);
|
||||
OnServerResponse.run(url, body, respHeaders);
|
||||
|
||||
removeHeader(respHeaders, "Transfer-Encoding");
|
||||
removeHeader(respHeaders, "Content-Length");
|
||||
|
||||
size_t fs = respHeaders.find(' ');
|
||||
int scFinal = (fs != std::string::npos) ? safe_stoi(respHeaders.substr(fs + 1, 3)) : 0;
|
||||
|
||||
if (scFinal != 204 && scFinal != 304 && scFinal != 205)
|
||||
respHeaders.insert(respHeaders.size() - 2,
|
||||
"Content-Length: " + std::to_string(body.size()) + "\r\n");
|
||||
|
||||
std::string packet = respHeaders + body;
|
||||
SSL_write(clientSSL, packet.data(), (int)packet.size());
|
||||
|
||||
serverStream.buffer.erase(0, processed);
|
||||
serverStream.reset();
|
||||
clientStream.reset();
|
||||
respHeaders.insert(respHeaders.size() - 2,
|
||||
"Content-Length: " + std::to_string(fullBody.size()) + "\r\n");
|
||||
packet = respHeaders + fullBody;
|
||||
}
|
||||
else
|
||||
break;
|
||||
|
||||
SSL_write(clientSSL.get(), packet.data(), (int)packet.size());
|
||||
|
||||
serverStream.buffer.erase(0, totalResponseSize);
|
||||
|
||||
if (tunnelMode)
|
||||
{
|
||||
if (serverStream.buffer.size() > 0)
|
||||
{
|
||||
SSL_write(clientSSL.get(), serverStream.buffer.data(), (int)serverStream.buffer.size());
|
||||
serverStream.buffer.clear();
|
||||
}
|
||||
if (clientStream.buffer.size() > 0)
|
||||
{
|
||||
SSL_write(remoteSSL.get(), clientStream.buffer.data(), (int)clientStream.buffer.size());
|
||||
clientStream.buffer.clear();
|
||||
}
|
||||
}
|
||||
|
||||
serverStream.reset();
|
||||
if (tunnelMode) break;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
if (connectionClosed) break;
|
||||
if (closed) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Proxy::initSSL()
|
||||
{
|
||||
_clientCtx = SSL_CTX_new(TLS_client_method());
|
||||
if (!_clientCtx) return false;
|
||||
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
||||
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
|
||||
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
|
||||
return true;
|
||||
}
|
||||
|
||||
void Proxy::cleanupSSL()
|
||||
{
|
||||
if (!_clientCtx) return;
|
||||
SSL_CTX_free(_clientCtx);
|
||||
_clientCtx = nullptr;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <string>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/err.h>
|
||||
#include "cert_manager.h"
|
||||
#include "ssl.h"
|
||||
#include <nerutils/callback.h>
|
||||
|
||||
/*
|
||||
@@ -22,8 +22,8 @@ class Proxy
|
||||
Proxy();
|
||||
~Proxy();
|
||||
|
||||
bool Init();
|
||||
void Shutdown();
|
||||
bool init();
|
||||
void shutdown();
|
||||
|
||||
CallbackEvent<std::string&, const std::string&, std::string&> OnClientRequest;
|
||||
CallbackEvent<const std::string&, std::string&, std::string&> OnServerResponse;
|
||||
|
||||
Reference in New Issue
Block a user