feat: add full HTTP/HTTPs support
This commit is contained in:
+469
-25
@@ -4,6 +4,28 @@
|
||||
#include <ws2tcpip.h>
|
||||
|
||||
#include <nerutils/log.h>
|
||||
#include <deque>
|
||||
|
||||
bool Proxy::initSSL()
|
||||
{
|
||||
_clientCtx = SSL_CTX_new(TLS_client_method());
|
||||
if (!_clientCtx)
|
||||
{
|
||||
Log::error("Failed to create client SSL context");
|
||||
return false;
|
||||
}
|
||||
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
void Proxy::cleanupSSL()
|
||||
{
|
||||
if (_clientCtx)
|
||||
{
|
||||
SSL_CTX_free(_clientCtx);
|
||||
_clientCtx = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Proxy::Proxy() {}
|
||||
|
||||
@@ -18,6 +40,15 @@ bool Proxy::Init()
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_certManager.Init())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!initSSL())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
||||
if (_listenSocket == INVALID_SOCKET)
|
||||
{
|
||||
@@ -61,6 +92,8 @@ void Proxy::Shutdown()
|
||||
if (_workerThread.joinable()) _workerThread.join();
|
||||
|
||||
WSACleanup();
|
||||
|
||||
cleanupSSL();
|
||||
}
|
||||
|
||||
void Proxy::loop()
|
||||
@@ -92,17 +125,66 @@ void Proxy::handleClient(SOCKET clientSocket)
|
||||
return;
|
||||
}
|
||||
|
||||
buffer[bytesReceived] = '\0';
|
||||
std::string request(buffer, bytesReceived);
|
||||
|
||||
/*
|
||||
get host
|
||||
*/
|
||||
std::string host;
|
||||
size_t hostPos = request.find("Host: ");
|
||||
if (hostPos != std::string::npos)
|
||||
std::string method, url;
|
||||
size_t space1 = request.find(' ');
|
||||
if (space1 != std::string::npos)
|
||||
{
|
||||
size_t endPos = request.find("\r\n", hostPos);
|
||||
host = request.substr(hostPos + 6, endPos - (hostPos + 6));
|
||||
method = request.substr(0, space1);
|
||||
size_t space2 = request.find(' ', space1 + 1);
|
||||
if (space2 != std::string::npos)
|
||||
{
|
||||
url = request.substr(space1 + 1, space2 - space1 - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (method.empty() || url.empty())
|
||||
{
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string host;
|
||||
std::string port = "80";
|
||||
bool isConnect = (method == "CONNECT");
|
||||
|
||||
if (isConnect)
|
||||
{
|
||||
size_t colon = url.find(':');
|
||||
if (colon != std::string::npos)
|
||||
{
|
||||
host = url.substr(0, colon);
|
||||
port = url.substr(colon + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
host = url;
|
||||
port = "443";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t hostPos = request.find("Host: ");
|
||||
if (hostPos != std::string::npos)
|
||||
{
|
||||
size_t endPos = request.find("\r\n", hostPos);
|
||||
if (endPos != std::string::npos)
|
||||
{
|
||||
std::string hostHeader = request.substr(hostPos + 6, endPos - (hostPos + 6));
|
||||
size_t colon = hostHeader.find(':');
|
||||
if (colon != std::string::npos)
|
||||
{
|
||||
host = hostHeader.substr(0, colon);
|
||||
port = hostHeader.substr(colon + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
host = hostHeader;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (host.empty())
|
||||
@@ -111,14 +193,11 @@ void Proxy::handleClient(SOCKET clientSocket)
|
||||
return;
|
||||
}
|
||||
|
||||
/*
|
||||
handle remote
|
||||
*/
|
||||
struct addrinfo hints = {}, *res;
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
if (getaddrinfo(host.c_str(), "80", &hints, &res) != 0)
|
||||
if (getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0)
|
||||
{
|
||||
Log::error("Could not resolve host: {}", host);
|
||||
closesocket(clientSocket);
|
||||
@@ -128,26 +207,391 @@ void Proxy::handleClient(SOCKET clientSocket)
|
||||
SOCKET remoteSocket = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
|
||||
if (connect(remoteSocket, res->ai_addr, (int)res->ai_addrlen) == SOCKET_ERROR)
|
||||
{
|
||||
Log::error("Connection to {} failed", host);
|
||||
Log::error("Connection to {}:{} failed", host, port);
|
||||
freeaddrinfo(res);
|
||||
closesocket(clientSocket);
|
||||
return;
|
||||
}
|
||||
freeaddrinfo(res);
|
||||
|
||||
/*
|
||||
fwd
|
||||
*/
|
||||
send(remoteSocket, buffer, bytesReceived, 0);
|
||||
|
||||
/*
|
||||
recv
|
||||
*/
|
||||
int remoteBytes;
|
||||
while ((remoteBytes = recv(remoteSocket, buffer, sizeof(buffer), 0)) > 0)
|
||||
if (isConnect)
|
||||
{
|
||||
// Log::verbose("Forwarding {} bytes from server back to client", remoteBytes);
|
||||
send(clientSocket, buffer, remoteBytes, 0);
|
||||
const char* reply = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
send(clientSocket, reply, static_cast<int>(strlen(reply)), 0);
|
||||
|
||||
SSL_CTX* serverCtx = _certManager.CreateHostContext(host);
|
||||
if (!serverCtx)
|
||||
{
|
||||
Log::error("Failed to generate dynamic cert for {}", host);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
SSL* clientSSL = SSL_new(serverCtx);
|
||||
SSL_set_fd(clientSSL, static_cast<int>(clientSocket));
|
||||
|
||||
if (SSL_accept(clientSSL) <= 0)
|
||||
{
|
||||
Log::error("SSL_accept failed on client");
|
||||
SSL_free(clientSSL);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
SSL* remoteSSL = SSL_new(_clientCtx);
|
||||
SSL_set_fd(remoteSSL, static_cast<int>(remoteSocket));
|
||||
SSL_set_tlsext_host_name(remoteSSL, host.c_str());
|
||||
|
||||
if (SSL_connect(remoteSSL) <= 0)
|
||||
{
|
||||
Log::error("SSL_connect failed on remote server");
|
||||
SSL_free(remoteSSL);
|
||||
SSL_free(clientSSL);
|
||||
closesocket(clientSocket);
|
||||
closesocket(remoteSocket);
|
||||
return;
|
||||
}
|
||||
|
||||
std::deque<std::string> pendingUrls;
|
||||
std::string serverBuffer;
|
||||
bool isReceivingBody = false;
|
||||
int expectedLength = -1;
|
||||
bool isChunked = false;
|
||||
size_t headersEnd = 0;
|
||||
|
||||
fd_set readfds;
|
||||
while (_running)
|
||||
{
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(clientSocket, &readfds);
|
||||
FD_SET(remoteSocket, &readfds);
|
||||
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 0;
|
||||
tv.tv_usec = 100000;
|
||||
|
||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
||||
if (ret < 0) break;
|
||||
|
||||
if (FD_ISSET(clientSocket, &readfds) || SSL_pending(clientSSL) > 0)
|
||||
{
|
||||
int bytes = SSL_read(clientSSL, buffer, sizeof(buffer));
|
||||
if (bytes <= 0) break;
|
||||
|
||||
std::string data(buffer, bytes);
|
||||
|
||||
size_t reqStart = 0;
|
||||
while (reqStart < data.size())
|
||||
{
|
||||
size_t nextReq = std::string::npos;
|
||||
const char* methods[] = {"GET ", "POST ", "PUT ", "DELETE ", "PATCH ", "OPTIONS ", "HEAD "};
|
||||
for (const char* m : methods)
|
||||
{
|
||||
size_t found = data.find(m, reqStart + 1);
|
||||
if (found != std::string::npos && (nextReq == std::string::npos || found < nextReq))
|
||||
nextReq = found;
|
||||
}
|
||||
|
||||
std::string singleReq = (nextReq == std::string::npos) ? data.substr(reqStart)
|
||||
: data.substr(reqStart, nextReq - reqStart);
|
||||
|
||||
std::string url;
|
||||
size_t pathSpace1 = singleReq.find(' ');
|
||||
size_t pathSpace2 = singleReq.find(' ', pathSpace1 + 1);
|
||||
if (pathSpace1 != std::string::npos && pathSpace2 != std::string::npos)
|
||||
{
|
||||
url = "https://" + host + singleReq.substr(pathSpace1 + 1, pathSpace2 - pathSpace1 - 1);
|
||||
pendingUrls.push_back(url);
|
||||
|
||||
size_t aePos = singleReq.find("Accept-Encoding:");
|
||||
if (aePos == std::string::npos) aePos = singleReq.find("accept-encoding:");
|
||||
if (aePos != std::string::npos)
|
||||
{
|
||||
size_t aeEndPos = singleReq.find("\r\n", aePos);
|
||||
if (aeEndPos != std::string::npos)
|
||||
singleReq.replace(aePos, aeEndPos - aePos, "Accept-Encoding: identity");
|
||||
}
|
||||
}
|
||||
|
||||
OnClientRequest.run(url.empty() ? ("https://" + host) : url, singleReq);
|
||||
|
||||
if (nextReq == std::string::npos) break;
|
||||
reqStart = nextReq;
|
||||
}
|
||||
|
||||
int sent = SSL_write(remoteSSL, data.data(), static_cast<int>(data.size()));
|
||||
if (sent <= 0) break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(remoteSocket, &readfds) || SSL_pending(remoteSSL) > 0)
|
||||
{
|
||||
int bytes = SSL_read(remoteSSL, buffer, sizeof(buffer) - 1);
|
||||
if (bytes <= 0) break;
|
||||
|
||||
serverBuffer.append(buffer, bytes);
|
||||
|
||||
while (!serverBuffer.empty())
|
||||
{
|
||||
if (!isReceivingBody)
|
||||
{
|
||||
headersEnd = serverBuffer.find("\r\n\r\n");
|
||||
if (headersEnd != std::string::npos)
|
||||
{
|
||||
isReceivingBody = true;
|
||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
||||
|
||||
size_t clPos = headers.find("Content-Length: ");
|
||||
if (clPos == std::string::npos) clPos = headers.find("content-length: ");
|
||||
if (clPos != std::string::npos)
|
||||
{
|
||||
size_t clEnd = headers.find("\r\n", clPos);
|
||||
expectedLength = std::stoi(headers.substr(clPos + 16, clEnd - clPos - 16));
|
||||
}
|
||||
else
|
||||
expectedLength = -1;
|
||||
|
||||
isChunked = (headers.find("chunked") != std::string::npos);
|
||||
}
|
||||
else
|
||||
{
|
||||
break; // need more data
|
||||
}
|
||||
}
|
||||
|
||||
if (isReceivingBody)
|
||||
{
|
||||
bool complete = false;
|
||||
std::string fullBody;
|
||||
size_t bodyStart = headersEnd + 4;
|
||||
size_t totalProcessed = bodyStart;
|
||||
|
||||
if (isChunked)
|
||||
{
|
||||
size_t idx = bodyStart;
|
||||
bool parseOk = true;
|
||||
while (idx < serverBuffer.size())
|
||||
{
|
||||
size_t lineEnd = serverBuffer.find("\r\n", idx);
|
||||
if (lineEnd == std::string::npos)
|
||||
{
|
||||
parseOk = false;
|
||||
break;
|
||||
}
|
||||
std::string hexStr = serverBuffer.substr(idx, lineEnd - idx);
|
||||
int chunkSize = 0;
|
||||
try
|
||||
{
|
||||
chunkSize = std::stoi(hexStr, nullptr, 16);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
parseOk = false;
|
||||
break;
|
||||
}
|
||||
idx = lineEnd + 2;
|
||||
if (chunkSize == 0)
|
||||
{
|
||||
idx += 2; // skip terminal \r\n
|
||||
complete = true;
|
||||
totalProcessed = idx;
|
||||
break;
|
||||
}
|
||||
if (idx + (size_t)chunkSize + 2 > serverBuffer.size())
|
||||
{
|
||||
parseOk = false;
|
||||
break;
|
||||
}
|
||||
fullBody.append(serverBuffer, idx, chunkSize);
|
||||
idx += chunkSize + 2;
|
||||
}
|
||||
if (!parseOk) complete = false;
|
||||
}
|
||||
else if (expectedLength >= 0)
|
||||
{
|
||||
if (serverBuffer.size() >= bodyStart + expectedLength)
|
||||
{
|
||||
complete = true;
|
||||
totalProcessed = bodyStart + expectedLength;
|
||||
fullBody = serverBuffer.substr(bodyStart, expectedLength);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string peekBuffer = serverBuffer.substr(0, bodyStart);
|
||||
bool isCloseConn = peekBuffer.find("Connection: close") != std::string::npos ||
|
||||
peekBuffer.find("connection: close") != std::string::npos;
|
||||
|
||||
bool isNoBodyStatus = peekBuffer.find("HTTP/1.1 204") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.1 304") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.0 204") != std::string::npos ||
|
||||
peekBuffer.find("HTTP/1.0 304") != std::string::npos;
|
||||
|
||||
if (isCloseConn)
|
||||
break;
|
||||
else
|
||||
{
|
||||
complete = true;
|
||||
fullBody = "";
|
||||
totalProcessed = bodyStart;
|
||||
}
|
||||
}
|
||||
|
||||
if (complete)
|
||||
{
|
||||
std::string headers = serverBuffer.substr(0, bodyStart);
|
||||
std::string responseData = fullBody;
|
||||
|
||||
std::string currentUrl = "https://" + host;
|
||||
if (!pendingUrls.empty())
|
||||
{
|
||||
currentUrl = pendingUrls.front();
|
||||
pendingUrls.pop_front();
|
||||
}
|
||||
|
||||
OnServerResponse.run(currentUrl, responseData);
|
||||
|
||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
||||
size_t pos = 0;
|
||||
while (true)
|
||||
{
|
||||
pos = h.find(key, pos);
|
||||
if (pos == std::string::npos)
|
||||
{
|
||||
std::string lowerKey = key;
|
||||
for (char& c : lowerKey)
|
||||
c = (char)tolower(c);
|
||||
pos = h.find(lowerKey, 0);
|
||||
if (pos == std::string::npos) break;
|
||||
}
|
||||
if (pos == 0 || h[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
};
|
||||
|
||||
removeHeader(headers, "Transfer-Encoding");
|
||||
removeHeader(headers, "Content-Length");
|
||||
|
||||
headers.insert(headers.size() - 2,
|
||||
"Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
||||
|
||||
std::string finalPacket = headers + responseData;
|
||||
int sent = SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
||||
|
||||
serverBuffer.erase(0, totalProcessed);
|
||||
|
||||
isReceivingBody = false;
|
||||
expectedLength = -1;
|
||||
headersEnd = 0;
|
||||
isChunked = false;
|
||||
|
||||
if (sent <= 0) break;
|
||||
}
|
||||
else
|
||||
break; // wait for more streaming packets
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isReceivingBody && expectedLength < 0 && !isChunked && serverBuffer.size() > headersEnd + 4)
|
||||
{
|
||||
std::string headers = serverBuffer.substr(0, headersEnd + 4);
|
||||
std::string responseData = serverBuffer.substr(headersEnd + 4);
|
||||
|
||||
std::string finalUrl = "https://" + host;
|
||||
if (!pendingUrls.empty())
|
||||
{
|
||||
finalUrl = pendingUrls.front();
|
||||
pendingUrls.pop_front();
|
||||
}
|
||||
|
||||
OnServerResponse.run(finalUrl, responseData);
|
||||
|
||||
auto removeHeader = [&](std::string& h, const std::string& key) {
|
||||
size_t pos = 0;
|
||||
while (true)
|
||||
{
|
||||
pos = h.find(key, pos);
|
||||
if (pos == std::string::npos)
|
||||
{
|
||||
std::string lowerKey = key;
|
||||
for (char& c : lowerKey)
|
||||
c = (char)tolower(c);
|
||||
pos = h.find(lowerKey, 0);
|
||||
if (pos == std::string::npos) break;
|
||||
}
|
||||
if (pos == 0 || h[pos - 1] == '\n')
|
||||
{
|
||||
size_t end = h.find("\r\n", pos);
|
||||
if (end != std::string::npos)
|
||||
{
|
||||
h.erase(pos, end - pos + 2);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
}
|
||||
};
|
||||
|
||||
removeHeader(headers, "Transfer-Encoding");
|
||||
removeHeader(headers, "Content-Length");
|
||||
headers.insert(headers.size() - 2, "Content-Length: " + std::to_string(responseData.size()) + "\r\n");
|
||||
|
||||
std::string finalPacket = headers + responseData;
|
||||
SSL_write(clientSSL, finalPacket.data(), static_cast<int>(finalPacket.size()));
|
||||
}
|
||||
|
||||
SSL_shutdown(clientSSL);
|
||||
SSL_free(clientSSL);
|
||||
|
||||
SSL_shutdown(remoteSSL);
|
||||
SSL_free(remoteSSL);
|
||||
}
|
||||
else
|
||||
{
|
||||
send(remoteSocket, buffer, bytesReceived, 0);
|
||||
|
||||
fd_set readfds;
|
||||
while (_running)
|
||||
{
|
||||
FD_ZERO(&readfds);
|
||||
FD_SET(clientSocket, &readfds);
|
||||
FD_SET(remoteSocket, &readfds);
|
||||
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 1;
|
||||
tv.tv_usec = 0;
|
||||
|
||||
int ret = select(0, &readfds, NULL, NULL, &tv);
|
||||
if (ret < 0) break;
|
||||
if (ret == 0) continue;
|
||||
|
||||
if (FD_ISSET(clientSocket, &readfds))
|
||||
{
|
||||
int bytes = recv(clientSocket, buffer, sizeof(buffer), 0);
|
||||
if (bytes <= 0) break;
|
||||
int sent = send(remoteSocket, buffer, bytes, 0);
|
||||
if (sent == SOCKET_ERROR) break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(remoteSocket, &readfds))
|
||||
{
|
||||
int bytes = recv(remoteSocket, buffer, sizeof(buffer), 0);
|
||||
if (bytes <= 0) break;
|
||||
int sent = send(clientSocket, buffer, bytes, 0);
|
||||
if (sent == SOCKET_ERROR) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
closesocket(remoteSocket);
|
||||
|
||||
Reference in New Issue
Block a user