diff --git a/src/unlocker/proxy.cpp b/src/unlocker/proxy.cpp index eedd2a2..06ac8c0 100644 --- a/src/unlocker/proxy.cpp +++ b/src/unlocker/proxy.cpp @@ -4,181 +4,174 @@ #include #include -#include + #include -#include -#include +#include -namespace +/* + memory helpers +*/ +template struct Deleter { - class ScopedSSL - { - SSL* s; - - public: - ScopedSSL(SSL* val = nullptr) : s(val) {} - ~ScopedSSL() - { - if (s) SSL_free(s); - } - void reset(SSL* val) - { - if (s) SSL_free(s); - s = val; - } - operator SSL*() const { return s; } - }; - - class ScopedSocket - { - SOCKET s; - - public: - ScopedSocket(SOCKET val = INVALID_SOCKET) : s(val) {} - ~ScopedSocket() - { - if (s != INVALID_SOCKET) closesocket(s); - } - void reset(SOCKET val) - { - if (s != INVALID_SOCKET) closesocket(s); - s = val; - } - operator SOCKET() const { return s; } - SOCKET get() const { return s; } - }; - - int safe_stoi(const std::string& s, int default_val = 0, int base = 10) - { - if (s.empty()) return default_val; - try + void operator()(T* p) const { - return std::stoi(s, nullptr, base); + if (p) f(p); } - catch (...) - { - return default_val; - } - } +}; +using SSL_ptr = std::unique_ptr>; - void removeHeader(std::string& h, const std::string& key) - { - std::string h_lower = h; - std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower); - std::string key_lower = key; - std::transform(key_lower.begin(), key_lower.end(), key_lower.begin(), ::tolower); - - size_t pos = 0; - while ((pos = h_lower.find(key_lower, pos)) != std::string::npos) - { - if (pos == 0 || h_lower[pos - 1] == '\n') - { - size_t end = h_lower.find("\r\n", pos); - if (end != std::string::npos) - { - h.erase(pos, end - pos + 2); - h_lower.erase(pos, end - pos + 2); - continue; - } - } - pos++; - } - } - - std::string getHeaderValue(const std::string& headers, const std::string& key) - { - std::string h_lower = headers; - std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower); - std::string search = "\n" + key; - std::transform(search.begin(), search.end(), search.begin(), ::tolower); - search += ":"; - - size_t pos = h_lower.find(search); - if (pos == std::string::npos) return ""; - - size_t vStart = pos + search.length(); - while (vStart < headers.size() && (headers[vStart] == ' ' || headers[vStart] == '\t')) - vStart++; - size_t vEnd = headers.find("\r\n", vStart); - if (vEnd == std::string::npos) return ""; - return headers.substr(vStart, vEnd - vStart); - } - - struct HttpStream - { - std::string buffer; - bool isReceivingBody = false; - bool isChunked = false; - int contentLength = -1; - size_t headersEnd = 0; - - void reset() - { - isReceivingBody = false; - isChunked = false; - contentLength = -1; - headersEnd = 0; - } - - bool parseHeaders() - { - if (isReceivingBody) return true; - headersEnd = buffer.find("\r\n\r\n"); - if (headersEnd == std::string::npos) return false; - - std::string headers = buffer.substr(0, headersEnd + 4); - - std::string te = getHeaderValue(headers, "Transfer-Encoding"); - std::transform(te.begin(), te.end(), te.begin(), ::tolower); - isChunked = (te.find("chunked") != std::string::npos); - - if (!isChunked) - { - std::string cl = getHeaderValue(headers, "Content-Length"); - contentLength = safe_stoi(cl, -1); - } - else - { - contentLength = -1; - } - - isReceivingBody = true; - return true; - } - }; -} // namespace - -bool Proxy::initSSL() +struct AutoSocket { - _clientCtx = SSL_CTX_new(TLS_client_method()); - if (!_clientCtx) return false; - SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr); - const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'}; - SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos)); - return true; -} + SOCKET s; + AutoSocket(SOCKET val = INVALID_SOCKET) : s(val) {} + ~AutoSocket() + { + if (s != INVALID_SOCKET) closesocket(s); + } + operator SOCKET() const { return s; } +}; -void Proxy::cleanupSSL() +/* + helper functions +*/ +int stoiSafe(const std::string& s, int default_val = 0, int base = 10) { - if (_clientCtx) + if (s.empty()) return default_val; + try { - SSL_CTX_free(_clientCtx); - _clientCtx = nullptr; + return std::stoi(s, nullptr, base); + } + catch (...) + { + return default_val; } } +void removeHeader(std::string& headers, const std::string& key) +{ + std::string result; + size_t start = 0; + size_t end; + + std::string keyLower = key; + std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), ::tolower); + if (keyLower.back() != ':') keyLower += ':'; + + while ((end = headers.find('\n', start)) != std::string::npos) + { + std::string line = headers.substr(start, end - start + 1); + std::string lineLower = line; + std::transform(lineLower.begin(), lineLower.end(), lineLower.begin(), ::tolower); + + if (lineLower.compare(0, keyLower.length(), keyLower) != 0) result += line; + + start = end + 1; + } + + if (start < headers.length()) + { + std::string line = headers.substr(start); + std::string lineLower = line; + if (lineLower.compare(0, keyLower.length(), keyLower) != 0) result += line; + } + + headers = std::move(result); +} + +std::string getHeaderValue(const std::string& headers, const std::string& key) +{ + size_t start = 0; + size_t end; + + std::string keyLower = key; + std::transform(keyLower.begin(), keyLower.end(), keyLower.begin(), ::tolower); + if (keyLower.empty()) return ""; + if (keyLower.back() != ':') keyLower += ':'; + + while (start < headers.length()) + { + end = headers.find('\n', start); + + std::string line = (end == std::string::npos) ? headers.substr(start) : headers.substr(start, end - start); + std::string lineLower = line; + std::transform(lineLower.begin(), lineLower.end(), lineLower.begin(), ::tolower); + + if (lineLower.compare(0, keyLower.length(), keyLower) == 0) + { + size_t valueStart = keyLower.length(); + while (valueStart < line.length() && (line[valueStart] == ' ' || line[valueStart] == '\t')) + valueStart++; + + size_t valueEnd = line.length(); + if (valueEnd > valueStart && line[valueEnd - 1] == '\r') valueEnd--; + + return line.substr(valueStart, valueEnd - valueStart); + } + + if (end == std::string::npos) break; + start = end + 1; + } + + return ""; +} + +/* + http stream helper class +*/ +struct HttpStream +{ + std::string buffer; + bool isReceivingBody = false; + bool isChunked = false; + int contentLength = -1; + size_t headersEnd = std::string::npos; + + void reset() + { + isReceivingBody = false; + isChunked = false; + contentLength = -1; + headersEnd = std::string::npos; + } + + bool parseHeaders() + { + headersEnd = buffer.find("\r\n\r\n"); + if (headersEnd == std::string::npos) return false; + + std::string headers = buffer.substr(0, headersEnd + 4); + std::string te = getHeaderValue(headers, "Transfer-Encoding"); + std::transform(te.begin(), te.end(), te.begin(), ::tolower); + isChunked = (te.find("chunked") != std::string::npos); + + if (!isChunked) + { + std::string cl = getHeaderValue(headers, "Content-Length"); + contentLength = stoiSafe(cl, -1); + } + isReceivingBody = true; + return true; + } +}; + +/* + proxy impl +*/ Proxy::Proxy() {} + Proxy::~Proxy() { - Shutdown(); + shutdown(); } -bool Proxy::Init() +bool Proxy::init() { + if (!_certManager.init()) return false; + + initSSL(); + WSADATA wsaData; if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) return false; - if (!_certManager.Init()) return false; - if (!initSSL()) return false; _listenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); sockaddr_in serverAddr = {}; @@ -188,15 +181,18 @@ bool Proxy::Init() if (bind(_listenSocket, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) return false; listen(_listenSocket, SOMAXCONN); - Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT); _running = true; _workerThread = std::thread(&Proxy::loop, this); + Log::verbose("Proxy active on 127.0.0.1:{}", PROXY_PORT); + return true; } -void Proxy::Shutdown() +void Proxy::shutdown() { + if (!_running) return; + _running = false; if (_listenSocket != INVALID_SOCKET) { @@ -212,22 +208,22 @@ void Proxy::loop() { while (_running) { - SOCKET hClient = accept(_listenSocket, NULL, NULL); + SOCKET client = accept(_listenSocket, NULL, NULL); if (!_running) { - if (hClient != INVALID_SOCKET) closesocket(hClient); + if (client != INVALID_SOCKET) closesocket(client); break; } - if (hClient == INVALID_SOCKET) continue; - std::thread([this, hClient]() { this->handleClient(hClient); }).detach(); + if (client == INVALID_SOCKET) continue; + std::thread([this, client]() { this->handleClient(client); }).detach(); } } -void Proxy::handleClient(SOCKET hClientSocket) +void Proxy::handleClient(SOCKET clientSocket) { - ScopedSocket clientGuard(hClientSocket); - char buffer[32768]; + AutoSocket clientGuard(clientSocket); + char buffer[32768]; int bytesRead = recv(clientGuard, buffer, sizeof(buffer) - 1, 0); if (bytesRead <= 0) return; buffer[bytesRead] = '\0'; @@ -235,6 +231,9 @@ void Proxy::handleClient(SOCKET hClientSocket) std::string initialReq(buffer); if (initialReq.find("CONNECT ") != 0) return; + /* + host info + */ size_t hostStart = 8; size_t hostEnd = initialReq.find(' ', hostStart); if (hostEnd == std::string::npos) return; @@ -246,7 +245,7 @@ void Proxy::handleClient(SOCKET hClientSocket) if (colon != std::string::npos) { host = fullHost.substr(0, colon); - port = safe_stoi(fullHost.substr(colon + 1), 443); + port = stoiSafe(fullHost.substr(colon + 1), 443); } struct addrinfo hints = {}, *res = nullptr; @@ -254,35 +253,39 @@ void Proxy::handleClient(SOCKET hClientSocket) hints.ai_socktype = SOCK_STREAM; if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res) != 0) return; - ScopedSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); - if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0) - { - freeaddrinfo(res); - return; - } + /* + establish connection to host + */ + AutoSocket remoteGuard(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + if (connect(remoteGuard, res->ai_addr, (int)res->ai_addrlen) != 0) return freeaddrinfo(res); freeaddrinfo(res); - send(clientGuard, "HTTP/1.1 200 Connection Established\r\n\r\n", 39, 0); - SSL_CTX* hostCtx = _certManager.CreateHostContext(host); + /* + SSL + */ + SSL_CTX* hostCtx = _certManager.createHostContext(host); if (!hostCtx) return; - ScopedSSL clientSSL(SSL_new(hostCtx)); - SSL_set_fd(clientSSL, (int)clientGuard.get()); - if (SSL_accept(clientSSL) <= 0) return; + SSL_ptr clientSSL(SSL_new(hostCtx)); + SSL_set_fd(clientSSL.get(), (int)clientGuard); + if (SSL_accept(clientSSL.get()) <= 0) return; - ScopedSSL remoteSSL(SSL_new(_clientCtx)); - SSL_set_fd(remoteSSL, (int)remoteGuard.get()); - SSL_set_tlsext_host_name(remoteSSL, host.c_str()); - if (SSL_connect(remoteSSL) <= 0) return; + SSL_ptr remoteSSL(SSL_new(_clientCtx)); + SSL_set_fd(remoteSSL.get(), (int)remoteGuard); + SSL_set_tlsext_host_name(remoteSSL.get(), host.c_str()); + if (SSL_connect(remoteSSL.get()) <= 0) return; + /* + traffic handler + */ HttpStream clientStream, serverStream; std::deque pendingUrls; bool tunnelMode = false; - fd_set readfds; while (_running) { + fd_set readfds; FD_ZERO(&readfds); FD_SET(clientGuard, &readfds); FD_SET(remoteGuard, &readfds); @@ -290,108 +293,82 @@ void Proxy::handleClient(SOCKET hClientSocket) struct timeval tv = {0, 50000}; if (select(0, &readfds, NULL, NULL, &tv) < 0) break; - if (tunnelMode) + /* + client -> server + */ + if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL.get()) > 0) { - if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0) - { - int n = SSL_read(clientSSL, buffer, sizeof(buffer)); - if (n <= 0) break; - SSL_write(remoteSSL, buffer, n); - } - if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0) - { - int n = SSL_read(remoteSSL, buffer, sizeof(buffer)); - if (n <= 0) break; - SSL_write(clientSSL, buffer, n); - } - continue; - } - - if (FD_ISSET(clientGuard, &readfds) || SSL_pending(clientSSL) > 0) - { - int n = SSL_read(clientSSL, buffer, sizeof(buffer)); + int n = SSL_read(clientSSL.get(), buffer, sizeof(buffer)); if (n <= 0) break; - clientStream.buffer.append(buffer, n); - while (!clientStream.buffer.empty()) + if (tunnelMode) + SSL_write(remoteSSL.get(), buffer, n); + else { - if (!clientStream.isReceivingBody) + clientStream.buffer.append(buffer, n); + while (true) { - if (!clientStream.parseHeaders()) break; - - std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4); - std::string url = "https://" + host; - size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1); - if (s1 != std::string::npos && s2 != std::string::npos) - url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1); - - pendingUrls.push_back(url); - - removeHeader(headers, "Accept-Encoding"); - removeHeader(headers, "Expect"); - headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n"); - - if (clientStream.contentLength == 0 || (clientStream.contentLength < 0 && !clientStream.isChunked)) + if (!clientStream.isReceivingBody) { - std::string emptyBody = ""; - OnClientRequest.run(url, emptyBody, headers); + if (!clientStream.parseHeaders()) break; - if (!pendingUrls.empty()) pendingUrls.back() = url; - - SSL_write(remoteSSL, headers.data(), (int)headers.size()); - clientStream.buffer.erase(0, clientStream.headersEnd + 4); - clientStream.reset(); + std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4); + std::string url = "https://" + host; + size_t s1 = headers.find(' '), s2 = headers.find(' ', s1 + 1); + if (s1 != std::string::npos && s2 != std::string::npos) + url = "https://" + host + headers.substr(s1 + 1, s2 - s1 - 1); + pendingUrls.push_back(url); } - } - - if (clientStream.isReceivingBody) - { - size_t bodyStart = clientStream.headersEnd + 4; - std::string url = pendingUrls.back(); - std::string headers = clientStream.buffer.substr(0, bodyStart); - removeHeader(headers, "Accept-Encoding"); - removeHeader(headers, "Expect"); - headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n"); bool complete = false; - std::string body; + std::string fullBody; + size_t totalRequestSize = 0; if (clientStream.isChunked) { - size_t idx = bodyStart; + size_t idx = clientStream.headersEnd + 4; while (idx < clientStream.buffer.size()) { size_t le = clientStream.buffer.find("\r\n", idx); if (le == std::string::npos) break; - int cs = safe_stoi(clientStream.buffer.substr(idx, le - idx), 0, 16); + int cs = stoiSafe(clientStream.buffer.substr(idx, le - idx), 0, 16); if (idx + (le - idx) + 2 + cs + 2 > clientStream.buffer.size()) break; - body.append(clientStream.buffer, le + 2, cs); + if (cs > 0) fullBody.append(clientStream.buffer, le + 2, cs); idx = le + 2 + cs + 2; if (cs == 0) { complete = true; + totalRequestSize = idx; break; } } } - else if (clientStream.contentLength >= 0) + else { - if (clientStream.buffer.size() >= bodyStart + clientStream.contentLength) + int cl = clientStream.contentLength; + if (cl < 0) cl = 0; + if (clientStream.buffer.size() >= (clientStream.headersEnd + 4 + cl)) { - body = clientStream.buffer.substr(bodyStart, clientStream.contentLength); + fullBody = clientStream.buffer.substr(clientStream.headersEnd + 4, cl); complete = true; + totalRequestSize = clientStream.headersEnd + 4 + cl; } } if (complete) { - OnClientRequest.run(url, body, headers); - if (!pendingUrls.empty() && pendingUrls.back() != url) pendingUrls.back() = url; + std::string url = pendingUrls.back(); + std::string headers = clientStream.buffer.substr(0, clientStream.headersEnd + 4); - SSL_write(remoteSSL, headers.data(), (int)headers.size()); - SSL_write(remoteSSL, clientStream.buffer.data() + bodyStart, - (int)(clientStream.buffer.size() - bodyStart)); - clientStream.buffer.clear(); + removeHeader(headers, "Accept-Encoding"); + headers.insert(headers.size() - 2, "Accept-Encoding: identity\r\n"); + + OnClientRequest.run(url, fullBody, headers); + + SSL_write(remoteSSL.get(), headers.data(), (int)headers.size()); + SSL_write(remoteSSL.get(), fullBody.data(), (int)fullBody.size()); + + clientStream.buffer.erase(0, totalRequestSize); clientStream.reset(); } else @@ -400,169 +377,126 @@ void Proxy::handleClient(SOCKET hClientSocket) } } - if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL) > 0) + /* + server -> client + */ + if (FD_ISSET(remoteGuard, &readfds) || SSL_pending(remoteSSL.get()) > 0) { - int n = SSL_read(remoteSSL, buffer, sizeof(buffer)); - bool connectionClosed = (n <= 0); - if (!connectionClosed) - { - serverStream.buffer.append(buffer, n); - } + int n = SSL_read(remoteSSL.get(), buffer, sizeof(buffer)); + bool closed = (n <= 0); + if (!closed) serverStream.buffer.append(buffer, n); - while (!serverStream.buffer.empty() || connectionClosed) + while (true) { if (!serverStream.isReceivingBody) + if (!serverStream.parseHeaders()) break; + + bool complete = false; + std::string fullBody; + size_t totalResponseSize = 0; + + if (serverStream.isChunked) { - if (connectionClosed) break; - - serverStream.headersEnd = serverStream.buffer.find("\r\n\r\n"); - if (serverStream.headersEnd == std::string::npos) break; - - std::string headers = serverStream.buffer.substr(0, serverStream.headersEnd + 4); - size_t s1 = headers.find(' '); - int sCode = (s1 != std::string::npos) ? safe_stoi(headers.substr(s1 + 1, 3)) : 0; - - if (sCode == 101) + size_t idx = serverStream.headersEnd + 4; + while (idx < serverStream.buffer.size()) { - SSL_write(clientSSL, serverStream.buffer.data(), (int)serverStream.buffer.size()); - serverStream.buffer.clear(); - clientStream.buffer.clear(); - tunnelMode = true; - break; + size_t le = serverStream.buffer.find("\r\n", idx); + if (le == std::string::npos) break; + int cs = stoiSafe(serverStream.buffer.substr(idx, le - idx), 0, 16); + if (idx + (le - idx) + 2 + cs + 2 > serverStream.buffer.size()) break; + if (cs > 0) fullBody.append(serverStream.buffer, le + 2, cs); + idx = le + 2 + cs + 2; + if (cs == 0) + { + complete = true; + totalResponseSize = idx; + break; + } } - - if (sCode >= 100 && sCode < 200) + } + else if (serverStream.contentLength >= 0) + { + if (serverStream.buffer.size() >= (serverStream.headersEnd + 4 + serverStream.contentLength)) { - SSL_write(clientSSL, headers.data(), (int)headers.size()); - serverStream.buffer.erase(0, serverStream.headersEnd + 4); - serverStream.isReceivingBody = false; - continue; + fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4, serverStream.contentLength); + complete = true; + totalResponseSize = serverStream.headersEnd + 4 + serverStream.contentLength; } - - serverStream.isReceivingBody = true; - std::string h_lower = headers; - std::transform(h_lower.begin(), h_lower.end(), h_lower.begin(), ::tolower); - serverStream.isChunked = (h_lower.find("transfer-encoding: chunked") != std::string::npos); - - size_t clPos = h_lower.find("content-length:"); - if (clPos != std::string::npos) - { - size_t vStart = clPos + 15; - while (vStart < h_lower.size() && (h_lower[vStart] == ' ' || h_lower[vStart] == '\t')) - vStart++; - serverStream.contentLength = - safe_stoi(h_lower.substr(vStart, h_lower.find("\r\n", vStart) - vStart), -1); - } - else if (sCode == 204 || sCode == 304 || sCode == 205) - serverStream.contentLength = 0; - else if (!serverStream.isChunked) - serverStream.contentLength = -1; + } + else if (closed) + { + fullBody = serverStream.buffer.substr(serverStream.headersEnd + 4); + complete = true; + totalResponseSize = serverStream.buffer.size(); } - if (serverStream.isReceivingBody) + if (complete) { - bool complete = false; - std::string body; - size_t bStart = serverStream.headersEnd + 4; - size_t processed = bStart; + std::string url = pendingUrls.empty() ? "https://" + host : pendingUrls.front(); + if (!pendingUrls.empty()) pendingUrls.pop_front(); - if (serverStream.isChunked) - { - size_t idx = bStart; - bool ok = true; - while (idx < serverStream.buffer.size()) - { - size_t le = serverStream.buffer.find("\r\n", idx); - if (le == std::string::npos) - { - ok = false; - break; - } - int cs = safe_stoi(serverStream.buffer.substr(idx, le - idx), 0, 16); - idx = le + 2; - if (cs == 0) - { - idx += 2; - complete = true; - processed = idx; - break; - } - if (idx + cs + 2 > serverStream.buffer.size()) - { - ok = false; - break; - } - body.append(serverStream.buffer, idx, cs); - idx += cs + 2; - } - if (!ok) - { - if (connectionClosed) - { - complete = true; - processed = serverStream.buffer.size(); - } - else - complete = false; - } - } - else if (serverStream.contentLength >= 0) - { - if (serverStream.buffer.size() >= bStart + serverStream.contentLength) - { - complete = true; - processed = bStart + serverStream.contentLength; - body = serverStream.buffer.substr(bStart, serverStream.contentLength); - } - else if (connectionClosed) - { - complete = true; - processed = serverStream.buffer.size(); - body = serverStream.buffer.substr(bStart); - } - } + std::string respHeaders = serverStream.buffer.substr(0, serverStream.headersEnd + 4); + + OnServerResponse.run(url, fullBody, respHeaders); + + std::string packet; + int statusCode = 200; + size_t space = respHeaders.find(' '); + if (space != std::string::npos) statusCode = stoiSafe(respHeaders.substr(space + 1, 3)); + + if (tunnelMode || statusCode == 204 || statusCode == 304 || (statusCode >= 100 && statusCode < 200)) + packet = respHeaders + fullBody; else { - if (connectionClosed) - { - complete = true; - processed = serverStream.buffer.size(); - body = serverStream.buffer.substr(bStart); - } - else - complete = false; - } - - if (complete) - { - std::string url = pendingUrls.empty() ? ("https://" + host) : pendingUrls.front(); - if (!pendingUrls.empty()) pendingUrls.pop_front(); - - std::string respHeaders = serverStream.buffer.substr(0, bStart); - OnServerResponse.run(url, body, respHeaders); - removeHeader(respHeaders, "Transfer-Encoding"); removeHeader(respHeaders, "Content-Length"); - - size_t fs = respHeaders.find(' '); - int scFinal = (fs != std::string::npos) ? safe_stoi(respHeaders.substr(fs + 1, 3)) : 0; - - if (scFinal != 204 && scFinal != 304 && scFinal != 205) - respHeaders.insert(respHeaders.size() - 2, - "Content-Length: " + std::to_string(body.size()) + "\r\n"); - - std::string packet = respHeaders + body; - SSL_write(clientSSL, packet.data(), (int)packet.size()); - - serverStream.buffer.erase(0, processed); - serverStream.reset(); - clientStream.reset(); + respHeaders.insert(respHeaders.size() - 2, + "Content-Length: " + std::to_string(fullBody.size()) + "\r\n"); + packet = respHeaders + fullBody; } - else - break; + + SSL_write(clientSSL.get(), packet.data(), (int)packet.size()); + + serverStream.buffer.erase(0, totalResponseSize); + + if (tunnelMode) + { + if (serverStream.buffer.size() > 0) + { + SSL_write(clientSSL.get(), serverStream.buffer.data(), (int)serverStream.buffer.size()); + serverStream.buffer.clear(); + } + if (clientStream.buffer.size() > 0) + { + SSL_write(remoteSSL.get(), clientStream.buffer.data(), (int)clientStream.buffer.size()); + clientStream.buffer.clear(); + } + } + + serverStream.reset(); + if (tunnelMode) break; } + else + break; } - if (connectionClosed) break; + if (closed) break; } } } + +bool Proxy::initSSL() +{ + _clientCtx = SSL_CTX_new(TLS_client_method()); + if (!_clientCtx) return false; + SSL_CTX_set_verify(_clientCtx, SSL_VERIFY_NONE, nullptr); + const unsigned char alpn_protos[] = {8, 'h', 't', 't', 'p', '/', '1', '.', '1'}; + SSL_CTX_set_alpn_protos(_clientCtx, alpn_protos, sizeof(alpn_protos)); + return true; +} + +void Proxy::cleanupSSL() +{ + if (!_clientCtx) return; + SSL_CTX_free(_clientCtx); + _clientCtx = nullptr; +} diff --git a/src/unlocker/proxy.h b/src/unlocker/proxy.h index 7e0441b..6332c83 100644 --- a/src/unlocker/proxy.h +++ b/src/unlocker/proxy.h @@ -5,7 +5,7 @@ #include #include #include -#include "cert_manager.h" +#include "ssl.h" #include /* @@ -22,8 +22,8 @@ class Proxy Proxy(); ~Proxy(); - bool Init(); - void Shutdown(); + bool init(); + void shutdown(); CallbackEvent OnClientRequest; CallbackEvent OnServerResponse;