726 lines
23 KiB
C++
726 lines
23 KiB
C++
#include "proxy.h"
|
|
|
|
#include <winsock2.h>
|
|
#include <ws2tcpip.h>
|
|
|
|
#include <nerutils/log.h>
|
|
|
|
#include <algorithm>
|
|
#include <deque>
|
|
|
|
/*
|
|
memory helpers
|
|
*/
|
|
/*
    stateless deleter for std::unique_ptr: hands a non-null pointer to the C-style
    release function supplied as the FreeFn template argument
*/
template <typename T, void (*FreeFn)(T*)> struct Deleter
{
    void operator()(T* ptr) const
    {
        if (ptr == nullptr) return;
        FreeFn(ptr);
    }
};
|
|
// owning handle for an OpenSSL SSL object; SSL_free runs when the handle goes out of scope
using SSL_ptr = std::unique_ptr<SSL, Deleter<SSL, &SSL_free>>;
|
|
|
|
struct AutoSocket
|
|
{
|
|
SOCKET s;
|
|
AutoSocket(SOCKET val = INVALID_SOCKET) : s(val) {}
|
|
~AutoSocket()
|
|
{
|
|
if (s != INVALID_SOCKET) closesocket(s);
|
|
}
|
|
operator SOCKET() const { return s; }
|
|
};
|
|
|
|
/*
|
|
helper functions
|
|
*/
|
|
/*
    parses `s` as an int in `base` using std::stoi.
    returns `default_val` if `s` is empty, does not start with a number, or the number does
    not fit in an int. like std::stoi, leading whitespace is skipped and any trailing
    characters after the number are ignored (e.g. "1a;ext" in base 16 yields 26).
*/
int stoiSafe(const std::string& s, int default_val = 0, int base = 10)
{
    int parsed = default_val;
    if (!s.empty())
    {
        try
        {
            parsed = std::stoi(s, nullptr, base);
        }
        catch (...)
        {
            parsed = default_val;
        }
    }
    return parsed;
}
|
|
|
|
/*
    removes every header line whose field name equals `key` (ASCII case-insensitive).
    optional spaces/tabs between the field name and ':' are tolerated.

    if `headers` is non-empty and does not end with '\n', "\r\n" is appended first so the
    final line is handled like every other line. an empty `key` removes nothing.
*/
void removeHeader(std::string& headers, const std::string& key)
{
    if (!headers.empty() && headers.back() != '\n') headers += "\r\n";
    if (key.empty()) return;

    // go through unsigned char: passing a negative char (non-ASCII byte) to tolower is UB
    auto lower = [](char c) { return static_cast<char>(::tolower(static_cast<unsigned char>(c))); };

    // does the line [lineStart, lineEnd) (lineEnd = index of its '\n') start with "key<OWS>:"?
    auto isKeyLine = [&](size_t lineStart, size_t lineEnd) {
        if (lineEnd - lineStart < key.size()) return false;
        for (size_t i = 0; i < key.size(); ++i)
            if (lower(headers[lineStart + i]) != lower(key[i])) return false;

        size_t pos = lineStart + key.size();
        while (pos < lineEnd && (headers[pos] == ' ' || headers[pos] == '\t')) ++pos;
        return pos < lineEnd && headers[pos] == ':';
    };

    std::string kept;
    kept.reserve(headers.size());

    size_t lineStart = 0;
    for (size_t nl = headers.find('\n'); nl != std::string::npos; nl = headers.find('\n', lineStart))
    {
        if (!isKeyLine(lineStart, nl)) kept.append(headers, lineStart, nl - lineStart + 1);
        lineStart = nl + 1;
    }

    headers = std::move(kept);
}
|
|
|
|
/*
    returns the value of the first header line whose field name equals `key`
    (ASCII case-insensitive, optional spaces/tabs before ':').
    leading and trailing spaces/tabs and a trailing '\r' are stripped from the value.
    returns "" when `key` is empty or no such header exists.
*/
std::string getHeaderValue(const std::string& headers, const std::string& key)
{
    if (key.empty()) return "";

    // go through unsigned char: passing a negative char (non-ASCII byte) to tolower is UB
    auto lower = [](char c) { return static_cast<char>(::tolower(static_cast<unsigned char>(c))); };
    auto isBlank = [](char c) { return c == ' ' || c == '\t'; };

    size_t lineStart = 0;
    while (lineStart < headers.length())
    {
        size_t nl = headers.find('\n', lineStart);
        size_t lineEnd = (nl == std::string::npos) ? headers.length() : nl;

        bool nameMatches = lineEnd - lineStart >= key.length();
        for (size_t i = 0; nameMatches && i < key.length(); ++i)
            nameMatches = lower(headers[lineStart + i]) == lower(key[i]);

        if (nameMatches)
        {
            size_t pos = lineStart + key.length();
            while (pos < lineEnd && isBlank(headers[pos])) ++pos;

            if (pos < lineEnd && headers[pos] == ':')
            {
                size_t valueStart = pos + 1;
                while (valueStart < lineEnd && isBlank(headers[valueStart])) ++valueStart;

                size_t valueEnd = lineEnd;
                if (valueEnd > valueStart && headers[valueEnd - 1] == '\r') --valueEnd;
                // trailing optional whitespace is not part of the field value
                while (valueEnd > valueStart && isBlank(headers[valueEnd - 1])) --valueEnd;

                return headers.substr(valueStart, valueEnd - valueStart);
            }
        }

        if (nl == std::string::npos) break;
        lineStart = nl + 1;
    }

    return "";
}
|
|
|
|
/*
|
|
http stream helper class
|
|
*/
|
|
struct HttpStream
|
|
{
|
|
std::string buffer;
|
|
bool isReceivingBody = false;
|
|
bool isChunked = false;
|
|
int contentLength = -1;
|
|
size_t headersEnd = std::string::npos;
|
|
int statusCode = 0;
|
|
|
|
size_t currentChunkIdx = 0;
|
|
std::string payload;
|
|
|
|
void reset()
|
|
{
|
|
isReceivingBody = false;
|
|
isChunked = false;
|
|
contentLength = -1;
|
|
headersEnd = std::string::npos;
|
|
statusCode = 0;
|
|
currentChunkIdx = 0;
|
|
payload.clear();
|
|
}
|
|
|
|
bool parseHeaders()
|
|
{
|
|
headersEnd = buffer.find("\r\n\r\n");
|
|
if (headersEnd == std::string::npos) return false;
|
|
|
|
std::string headers = buffer.substr(0, headersEnd + 4);
|
|
std::string te = getHeaderValue(headers, "Transfer-Encoding");
|
|
std::transform(te.begin(), te.end(), te.begin(), ::tolower);
|
|
isChunked = (te.find("chunked") != std::string::npos);
|
|
|
|
if (!isChunked)
|
|
{
|
|
std::string cl = getHeaderValue(headers, "Content-Length");
|
|
contentLength = stoiSafe(cl, -1);
|
|
}
|
|
|
|
if (headers.compare(0, 5, "HTTP/") == 0)
|
|
{
|
|
size_t space = headers.find(' ');
|
|
if (space != std::string::npos) statusCode = stoiSafe(headers.substr(space + 1, 3));
|
|
}
|
|
|
|
isReceivingBody = true;
|
|
return true;
|
|
}
|
|
};
|
|
|
|
/*
|
|
proxy impl
|
|
*/
|
|
// nothing to set up eagerly: sockets, threads and TLS state are created in init()
Proxy::Proxy() = default;
|
|
|
|
// stops the proxy on destruction; shutdown() returns immediately when it is not running
Proxy::~Proxy()
{
    this->shutdown();
}
|
|
|
|
void Proxy::addWhitelistDomain(const std::string& domain)
|
|
{
|
|
_whitelistDomains.push_back(domain);
|
|
}
|
|
|
|
bool Proxy::init()
|
|
{
|
|
if (!_certManager.init()) return false;
|
|
|
|
initSSL();
|
|
|
|
WSADATA wsaData;
|
|
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false;
|
|
|
|
_listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
|
|
sockaddr_in serverAddr = {};
|
|
serverAddr.sin_family = AF_INET;
|
|
serverAddr.sin_port = htons(PROXY_PORT);
|
|
inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr);
|
|
|
|
if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false;
|
|
listen(_listenSocket, SOMAXCONN);
|
|
|
|
_running = true;
|
|
for (int i = 0; i < PROXY_THREAD_COUNT; ++i)
|
|
{
|
|
_poolThreads.emplace_back([this]() {
|
|
while (_running)
|
|
{
|
|
SOCKET client;
|
|
{
|
|
std::unique_lock<std::mutex> lock(_queueMutex);
|
|
_queueCond.wait(lock, [this]() { return !_clientQueue.empty() || !_running; });
|
|
if (!_running && _clientQueue.empty()) return;
|
|
client = _clientQueue.front();
|
|
_clientQueue.pop();
|
|
}
|
|
this->handleClient(client);
|
|
}
|
|
});
|
|
}
|
|
|
|
_workerThread = std::thread(&Proxy::loop, this);
|
|
Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT);
|
|
|
|
return true;
|
|
}
|
|
|
|
void Proxy::shutdown()
|
|
{
|
|
if (!_running) return;
|
|
|
|
_running = false;
|
|
if (_listenSocket != INVALID_SOCKET)
|
|
{
|
|
closesocket(_listenSocket);
|
|
_listenSocket = INVALID_SOCKET;
|
|
}
|
|
if (_workerThread.joinable()) _workerThread.join();
|
|
|
|
_queueCond.notify_all();
|
|
for (auto& t : _poolThreads)
|
|
{
|
|
if (t.joinable()) t.join();
|
|
}
|
|
_poolThreads.clear();
|
|
|
|
WSACleanup();
|
|
cleanupSSL();
|
|
}
|
|
|
|
void Proxy::loop()
|
|
{
|
|
while (_running)
|
|
{
|
|
SOCKET client = accept(_listenSocket, NULL, NULL);
|
|
if (!_running)
|
|
{
|
|
if (client != INVALID_SOCKET) closesocket(client);
|
|
break;
|
|
}
|
|
if (client == INVALID_SOCKET) continue;
|
|
|
|
{
|
|
std::lock_guard<std::mutex> lock(_queueMutex);
|
|
_clientQueue.push(client);
|
|
}
|
|
_queueCond.notify_one();
|
|
}
|
|
}
|
|
|
|
void Proxy::handleClient(SOCKET clientSocket)
|
|
{
|
|
AutoSocket clientGuard(clientSocket);
|
|
|
|
char buffer[32768];
|
|
fd_set initialFds;
|
|
FD_ZERO(&initialFds);
|
|
FD_SET(clientGuard, &initialFds);
|
|
struct timeval initialTv = {5, 0};
|
|
if (select(0, &initialFds, NULL, NULL, &initialTv) <= 0) return;
|
|
|
|
int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0);
|
|
if (bytesRead <= 0) return;
|
|
buffer[bytesRead] = '\0';
|
|
|
|
std::string initialReq(buffer);
|
|
if (initialReq.find("CONNECT ") != 0) return;
|
|
|
|
/*
|
|
host info
|
|
*/
|
|
size_t hostStart = 8;
|
|
size_t hostEnd = initialReq.find(' ', hostStart);
|
|
if (hostEnd == std::string::npos) return;
|
|
|
|
std::string fullHost = initialReq.substr(hostStart, hostEnd - hostStart);
|
|
std::string host = fullHost;
|
|
int port = 443;
|
|
size_t colon = fullHost.find(':');
|
|
if (colon != std::string::npos)
|
|
{
|
|
host = fullHost.substr(0, colon);
|
|
port = stoiSafe(fullHost.substr(colon + 1), 443);
|
|
}
|
|
|
|
struct addrinfo hints = {}, *res = nullptr;
|
|
hints.ai_family = AF_INET;
|
|
hints.ai_socktype = SOCK_STREAM;
|
|
if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return;
|
|
|
|
/*
|
|
establish connection to host
|
|
*/
|
|
AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
|
if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0) return freeaddrinfo(res);
|
|
freeaddrinfo(res);
|
|
send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0);
|
|
|
|
/*
|
|
whitelist check
|
|
*/
|
|
bool isWhitelisted = _whitelistDomains.empty();
|
|
for (const auto& d : _whitelistDomains)
|
|
{
|
|
if (host.find(d) != std::string::npos)
|
|
{
|
|
isWhitelisted = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!isWhitelisted)
|
|
{
|
|
int tunnelIdleTimeouts = 0;
|
|
char buf[32768];
|
|
while (_running)
|
|
{
|
|
fd_set readfds;
|
|
FD_ZERO(&readfds);
|
|
FD_SET(clientGuard, &readfds);
|
|
FD_SET(remoteGuard, &readfds);
|
|
struct timeval tv = {1, 0};
|
|
|
|
int sel = select(0, &readfds, NULL, NULL, &tv);
|
|
if (sel < 0) break;
|
|
if (sel == 0)
|
|
{
|
|
tunnelIdleTimeouts++;
|
|
if (tunnelIdleTimeouts > 30) break;
|
|
continue;
|
|
}
|
|
tunnelIdleTimeouts = 0;
|
|
|
|
if (FD_ISSET(clientGuard, &readfds))
|
|
{
|
|
int n = recv(clientGuard, buf, sizeof(buf), 0);
|
|
if (n <= 0) break;
|
|
send(remoteGuard, buf, n, 0);
|
|
}
|
|
if (FD_ISSET(remoteGuard, &readfds))
|
|
{
|
|
int n = recv(remoteGuard, buf, sizeof(buf), 0);
|
|
if (n <= 0) break;
|
|
send(clientGuard, buf, n, 0);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
|
|
/*
|
|
SSL
|
|
*/
|
|
SSL_CTX* hostCtx = _certManager.createHostContext(host);
|
|
if (!hostCtx) return;
|
|
|
|
auto sslHandshake = [](SSL* ssl, bool isAccept, SOCKET s) -> bool {
|
|
while (true)
|
|
{
|
|
int ret = isAccept ? SSL_accept(ssl) : SSL_connect(ssl);
|
|
if (ret > 0) return true;
|
|
int err = SSL_get_error(ssl, ret);
|
|
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
|
|
{
|
|
fd_set fds;
|
|
FD_ZERO(&fds);
|
|
FD_SET(s, &fds);
|
|
struct timeval tv = {1, 0};
|
|
if (err == SSL_ERROR_WANT_READ)
|
|
select(0, &fds, NULL, NULL, &tv);
|
|
else
|
|
select(0, NULL, &fds, NULL, &tv);
|
|
continue;
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
SSL_ptr clientSSL(SSL_new(hostCtx));
|
|
BIO* clientBio = BIO_new_socket((int)clientGuard, BIO_NOCLOSE);
|
|
SSL_set_bio(clientSSL.get(), clientBio, clientBio);
|
|
if (!sslHandshake(clientSSL.get(), true, clientGuard)) return;
|
|
|
|
SSL_ptr remoteSSL(SSL_new(_clientCtx));
|
|
BIO* remoteBio = BIO_new_socket((int)remoteGuard, BIO_NOCLOSE);
|
|
SSL_set_bio(remoteSSL.get(), remoteBio, remoteBio);
|
|
SSL_set_tlsext_host_name(remoteSSL.get(), host.c_str());
|
|
if (!sslHandshake(remoteSSL.get(), false, remoteGuard)) return;
|
|
|
|
/*
|
|
traffic handler
|
|
*/
|
|
HttpStream clientStream, serverStream;
|
|
std::deque<std::string> pendingUrls;
|
|
bool tunnelMode = false;
|
|
|
|
int idleTimeouts = 0;
|
|
while (_running)
|
|
{
|
|
fd_set readfds;
|
|
FD_ZERO(&readfds);
|
|
FD_SET(clientGuard, &readfds);
|
|
FD_SET(remoteGuard, &readfds);
|
|
|
|
struct timeval tv = {0, 50000};
|
|
|
|
bool hasBuffered = (SSL_pending(clientSSL.get()) > 0 || SSL_pending(remoteSSL.get()) > 0);
|
|
|
|
if (!hasBuffered)
|
|
{
|
|
int sel = select(0, &readfds, NULL, NULL, &tv);
|
|
if (sel < 0) break;
|
|
if (sel == 0)
|
|
{
|
|
idleTimeouts++;
|
|
if (idleTimeouts > 600) break;
|
|
continue;
|
|
}
|
|
}
|
|
else
|
|
FD_ZERO(&readfds);
|
|
idleTimeouts = 0;
|
|
|
|
/*
|
|
client -> server
|
|
*/
|
|
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL.get()) > 0)
|
|
{
|
|
int n = SSL_read(clientSSL.get(), buffer, sizeof(buffer));
|
|
if (n <= 0) break;
|
|
|
|
if (tunnelMode)
|
|
SSL_write(remoteSSL.get(), buffer, n);
|
|
else
|
|
{
|
|
clientStream.buffer.append(buffer, n);
|
|
while (true)
|
|
{
|
|
if (!clientStream.isReceivingBody)
|
|
{
|
|
if (!clientStream.parseHeaders()) break;
|
|
|
|
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
|
std::string url = "https://" + host;
|
|
size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1);
|
|
if (s1 != std::string::npos && s2 != std::string::npos)
|
|
url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1);
|
|
pendingUrls.push_back(url);
|
|
}
|
|
|
|
bool complete = false;
|
|
std::string fullBody;
|
|
size_t totalRequestSize = 0;
|
|
|
|
if (clientStream.isChunked)
|
|
{
|
|
if (clientStream.currentChunkIdx == 0)
|
|
clientStream.currentChunkIdx = clientStream.headersEnd + 4;
|
|
|
|
while (clientStream.currentChunkIdx < clientStream.buffer.size())
|
|
{
|
|
size_t idx = clientStream.currentChunkIdx;
|
|
size_t le = clientStream.buffer.find("\r\n", idx);
|
|
if (le == std::string::npos) break;
|
|
int cs = stoiSafe(clientStream.buffer.substr(idx, le - idx), -1, 16);
|
|
if (cs < 0) return;
|
|
|
|
if (idx + (le - idx) + 2 + cs + 2 > clientStream.buffer.size()) break;
|
|
if (cs > 0) clientStream.payload.append(clientStream.buffer, le + 2, cs);
|
|
clientStream.currentChunkIdx = le + 2 + cs + 2;
|
|
if (cs == 0)
|
|
{
|
|
fullBody = std::move(clientStream.payload);
|
|
complete = true;
|
|
totalRequestSize = clientStream.currentChunkIdx;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{
|
|
int cl = clientStream.contentLength;
|
|
if (cl < 0) cl = 0;
|
|
if (clientStream.buffer.size() >= (clientStream.headersEnd + 4 + cl))
|
|
{
|
|
fullBody = clientStream.buffer.substr(clientStream.headersEnd + 4, cl);
|
|
complete = true;
|
|
totalRequestSize = clientStream.headersEnd + 4 + cl;
|
|
}
|
|
}
|
|
|
|
if (complete)
|
|
{
|
|
std::string url = pendingUrls.back();
|
|
std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4);
|
|
|
|
removeHeader(headers, "Accept-Encoding");
|
|
headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n");
|
|
|
|
OnClientRequest.run(url, fullBody, headers);
|
|
|
|
SSL_write(remoteSSL.get(), headers.data(), (int)headers.size());
|
|
SSL_write(remoteSSL.get(), fullBody.data(), (int)fullBody.size());
|
|
|
|
clientStream.buffer.erase(0, totalRequestSize);
|
|
clientStream.reset();
|
|
}
|
|
else
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
server -> client
|
|
*/
|
|
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL.get()) > 0)
|
|
{
|
|
int n = SSL_read(remoteSSL.get(), buffer, sizeof(buffer));
|
|
bool closed = (n <= 0);
|
|
if (!closed) serverStream.buffer.append(buffer, n);
|
|
|
|
while (true)
|
|
{
|
|
if (!serverStream.isReceivingBody)
|
|
if (!serverStream.parseHeaders()) break;
|
|
|
|
bool complete = false;
|
|
std::string fullBody;
|
|
size_t totalResponseSize = 0;
|
|
|
|
if (serverStream.statusCode == 204 || serverStream.statusCode == 304 ||
|
|
(serverStream.statusCode >= 100 && serverStream.statusCode < 200))
|
|
{
|
|
fullBody = "";
|
|
complete = true;
|
|
totalResponseSize = serverStream.headersEnd + 4;
|
|
}
|
|
else if (serverStream.isChunked)
|
|
{
|
|
if (serverStream.currentChunkIdx == 0) serverStream.currentChunkIdx = serverStream.headersEnd + 4;
|
|
|
|
while (serverStream.currentChunkIdx < serverStream.buffer.size())
|
|
{
|
|
size_t idx = serverStream.currentChunkIdx;
|
|
size_t le = serverStream.buffer.find("\r\n", idx);
|
|
if (le == std::string::npos) break;
|
|
int cs = stoiSafe(serverStream.buffer.substr(idx, le - idx), -1, 16);
|
|
if (cs < 0) return;
|
|
|
|
if (idx + (le - idx) + 2 + cs + 2 > serverStream.buffer.size()) break;
|
|
if (cs > 0) serverStream.payload.append(serverStream.buffer, le + 2, cs);
|
|
serverStream.currentChunkIdx = le + 2 + cs + 2;
|
|
if (cs == 0)
|
|
{
|
|
fullBody = std::move(serverStream.payload);
|
|
complete = true;
|
|
totalResponseSize = serverStream.currentChunkIdx;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
else if (serverStream.contentLength >= 0)
|
|
{
|
|
if (serverStream.buffer.size() >= (serverStream.headersEnd + 4 + serverStream.contentLength))
|
|
{
|
|
fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4, serverStream.contentLength);
|
|
complete = true;
|
|
totalResponseSize = serverStream.headersEnd + 4 + serverStream.contentLength;
|
|
}
|
|
}
|
|
else if (closed)
|
|
{
|
|
fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4);
|
|
complete = true;
|
|
totalResponseSize = serverStream.buffer.size();
|
|
}
|
|
|
|
if (complete)
|
|
{
|
|
std::string url = pendingUrls.empty() ? "https://" + host : pendingUrls.front();
|
|
if (!pendingUrls.empty()) pendingUrls.pop_front();
|
|
|
|
std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4);
|
|
|
|
int statusCode = 200;
|
|
size_t space = respHeaders.find(' ');
|
|
if (space != std::string::npos) statusCode = stoiSafe(respHeaders.substr(space + 1, 3));
|
|
|
|
if (statusCode == 101) tunnelMode = true;
|
|
|
|
OnServerResponse.run(url, fullBody, respHeaders);
|
|
|
|
std::string packet;
|
|
if (tunnelMode || statusCode == 204 || statusCode == 304 || (statusCode >= 100 && statusCode < 200))
|
|
packet = respHeaders + fullBody;
|
|
else
|
|
{
|
|
removeHeader(respHeaders, "Transfer-Encoding");
|
|
removeHeader(respHeaders, "Content-Length");
|
|
respHeaders.insert(respHeaders.size() - 2,
|
|
"Content-Length: " + std::to_string(fullBody.size()) + "\r\n");
|
|
packet = respHeaders + fullBody;
|
|
}
|
|
|
|
SSL_write(clientSSL.get(), packet.data(), (int)packet.size());
|
|
serverStream.buffer.erase(0, totalResponseSize);
|
|
|
|
if (tunnelMode)
|
|
{
|
|
if (serverStream.buffer.size() > 0)
|
|
{
|
|
SSL_write(clientSSL.get(), serverStream.buffer.data(), (int)serverStream.buffer.size());
|
|
serverStream.buffer.clear();
|
|
}
|
|
if (clientStream.buffer.size() > 0)
|
|
{
|
|
SSL_write(remoteSSL.get(), clientStream.buffer.data(), (int)clientStream.buffer.size());
|
|
clientStream.buffer.clear();
|
|
}
|
|
}
|
|
|
|
serverStream.reset();
|
|
if (tunnelMode) break;
|
|
}
|
|
else
|
|
break;
|
|
}
|
|
if (closed || tunnelMode) break;
|
|
}
|
|
}
|
|
if (tunnelMode && _running)
|
|
{
|
|
int tunnelIdleTimeouts = 0;
|
|
while (_running)
|
|
{
|
|
fd_set readfds;
|
|
FD_ZERO(&readfds);
|
|
FD_SET(clientGuard, &readfds);
|
|
FD_SET(remoteGuard, &readfds);
|
|
struct timeval tv = {1, 0};
|
|
|
|
bool hasBuffered = (SSL_pending(clientSSL.get()) > 0 || SSL_pending(remoteSSL.get()) > 0);
|
|
|
|
if (!hasBuffered)
|
|
{
|
|
int sel = select(0, &readfds, NULL, NULL, &tv);
|
|
if (sel < 0) break;
|
|
if (sel == 0)
|
|
{
|
|
tunnelIdleTimeouts++;
|
|
if (tunnelIdleTimeouts > 30) break;
|
|
continue;
|
|
}
|
|
}
|
|
else
|
|
FD_ZERO(&readfds);
|
|
tunnelIdleTimeouts = 0;
|
|
|
|
/*
|
|
client -> server
|
|
*/
|
|
if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL.get()) > 0)
|
|
{
|
|
int n = SSL_read(clientSSL.get(), buffer, sizeof(buffer));
|
|
if (n <= 0) break;
|
|
SSL_write(remoteSSL.get(), buffer, n);
|
|
}
|
|
|
|
/*
|
|
server -> client
|
|
*/
|
|
if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL.get()) > 0)
|
|
{
|
|
int n = SSL_read(remoteSSL.get(), buffer, sizeof(buffer));
|
|
if (n <= 0) break;
|
|
SSL_write(clientSSL.get(), buffer, n);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
bool Proxy::initSSL()
|
|
{
|
|
_clientCtx = SSL_CTX_new(TLS_client_method());
|
|
if (!_clientCtx) return false;
|
|
SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr);
|
|
const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'};
|
|
SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos));
|
|
return true;
|
|
}
|
|
|
|
void Proxy::cleanupSSL()
|
|
{
|
|
if (!_clientCtx) return;
|
|
SSL_CTX_free(_clientCtx);
|
|
_clientCtx = nullptr;
|
|
}
|