diff --git a/src/unlocker/proxy.cpp b/src/unlocker/proxy.cpp index 4bf4a47..b3bb9e9 100644 --- a/src/unlocker/proxy.cpp +++ b/src/unlocker/proxy.cpp @@ -262,6 +262,12 @@ void Proxy::handleClient(SOCKET clientSocket) AutoSocket clientGuard(clientSocket); char buffer[32768]; + fd_set initialFds; + FD_ZERO(&initialFds); + FD_SET(clientGuard, &initialFds); + struct timeval initialTv = {5, 0}; + if (select(0, &initialFds, NULL, NULL, &initialTv) <= 0) return; + int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0); if (bytesRead <= 0) return; buffer[bytesRead] = '\0'; @@ -328,11 +334,13 @@ void Proxy::handleClient(SOCKET clientSocket) }; SSL_ptr clientSSL(SSL_new(hostCtx)); - SSL_set_fd(clientSSL.get(), (int)clientGuard); + BIO* clientBio = BIO_new_socket((int)clientGuard, BIO_NOCLOSE); + SSL_set_bio(clientSSL.get(), clientBio, clientBio); if (!sslHandshake(clientSSL.get(), true, clientGuard)) return; SSL_ptr remoteSSL(SSL_new(_clientCtx)); - SSL_set_fd(remoteSSL.get(), (int)remoteGuard); + BIO* remoteBio = BIO_new_socket((int)remoteGuard, BIO_NOCLOSE); + SSL_set_bio(remoteSSL.get(), remoteBio, remoteBio); SSL_set_tlsext_host_name(remoteSSL.get(), host.c_str()); if (!sslHandshake(remoteSSL.get(), false, remoteGuard)) return; @@ -343,6 +351,7 @@ void Proxy::handleClient(SOCKET clientSocket) std::deque pendingUrls; bool tunnelMode = false; + int idleTimeouts = 0; while (_running) { fd_set readfds; @@ -351,7 +360,21 @@ void Proxy::handleClient(SOCKET clientSocket) FD_SET(remoteGuard, &readfds); struct timeval tv = {0, 50000}; - if (select(0, &readfds, NULL, NULL, &tv) < 0) break; + + bool hasBuffered = (SSL_pending(clientSSL.get()) > 0 || SSL_pending(remoteSSL.get()) > 0); + + if (!hasBuffered) + { + int sel = select(0, &readfds, NULL, NULL, &tv); + if (sel < 0) break; + if (sel == 0) + { + idleTimeouts++; + if (idleTimeouts > 600) break; + continue; + } + } + idleTimeouts = 0; /* client -> server @@ -501,13 +524,15 @@ void Proxy::handleClient(SOCKET clientSocket) std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4); - OnServerResponse.run(url, fullBody, respHeaders); - - std::string packet; int statusCode = 200; size_t space = respHeaders.find(' '); if (space != std::string::npos) statusCode = stoiSafe(respHeaders.substr(space + 1, 3)); + if (statusCode == 101) tunnelMode = true; + + OnServerResponse.run(url, fullBody, respHeaders); + + std::string packet; if (tunnelMode || statusCode == 204 || statusCode == 304 || (statusCode >= 100 && statusCode < 200)) packet = respHeaders + fullBody; else @@ -520,7 +545,6 @@ void Proxy::handleClient(SOCKET clientSocket) } SSL_write(clientSSL.get(), packet.data(), (int)packet.size()); - serverStream.buffer.erase(0, totalResponseSize); if (tunnelMode) @@ -546,6 +570,53 @@ void Proxy::handleClient(SOCKET clientSocket) if (closed) break; } } + if (tunnelMode && _running) + { + int tunnelIdleTimeouts = 0; + while (_running) + { + fd_set readfds; + FD_ZERO(&readfds); + FD_SET(clientGuard, &readfds); + FD_SET(remoteGuard, &readfds); + struct timeval tv = {1, 0}; + + bool hasBuffered = (SSL_pending(clientSSL.get()) > 0 || SSL_pending(remoteSSL.get()) > 0); + + if (!hasBuffered) + { + int sel = select(0, &readfds, NULL, NULL, &tv); + if (sel < 0) break; + if (sel == 0) + { + tunnelIdleTimeouts++; + if (tunnelIdleTimeouts > 30) break; + continue; + } + } + tunnelIdleTimeouts = 0; + + /* + client -> server + */ + if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL.get()) > 0) + { + int n = SSL_read(clientSSL.get(), buffer, sizeof(buffer)); + if (n <= 0) break; + SSL_write(remoteSSL.get(), buffer, n); + } + + /* + server -> client + */ + if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL.get()) > 0) + { + int n = SSL_read(remoteSSL.get(), buffer, sizeof(buffer)); + if (n <= 0) break; + SSL_write(clientSSL.get(), buffer, n); + } + } + } } bool Proxy::initSSL()