Fix BIO_do_connect freeze

This commit is contained in:
2024-01-10 17:33:01 +08:00
parent db452fce0f
commit c6105db80d
4 changed files with 186 additions and 75 deletions

View File

@@ -5,6 +5,7 @@ option(ENABLE_STANDALONE "Build utils standalone" OFF)
option(ENABLE_CXX17 "Enable C++ 17" OFF)
option(INSTALL_DEP_FILES "Install a file with dependences." OFF)
option(ENABLE_SSL "Enable SSL" OFF)
option(ENABLE_ZLIB "Use Zlib to uncompress http data." OFF)
if (ENABLE_STANDALONE)
project(utils)
@@ -29,6 +30,12 @@ if (ENABLE_SSL)
set(HAVE_OPENSSL 1)
endif()
if (ENABLE_ZLIB)
find_package(ZLIB REQUIRED)
set(HAVE_ZLIB 1)
endif()
if (Iconv_FOUND)
set(HAVE_ICONV 1)
endif()
@@ -132,6 +139,9 @@ endif()
if (ENABLE_SSL)
target_link_libraries(utils OpenSSL::SSL OpenSSL::Crypto)
endif()
if (ENABLE_ZLIB)
target_link_libraries(utils ZLIB::ZLIB)
endif()
if (ENABLE_CXX17)
target_compile_features(utils PRIVATE cxx_std_17)
endif()

View File

@@ -98,19 +98,15 @@ const char* SocketError::what() {
return this->message.c_str();
}
Socket::Socket(std::string host, std::string protocol) {
Socket::Socket(std::string host, std::string port, bool https) {
this->host = host;
this->protocol = protocol;
#if HAVE_OPENSSL
if (protocol == "https") {
return;
}
#endif
this->port = port;
this->https = https;
addrinfo hints = { 0 };
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
int re = getaddrinfo(host.c_str(), protocol.c_str(), &hints, &this->addr);
int re = getaddrinfo(host.c_str(), port.c_str(), &hints, &this->addr);
if (re) {
throw AIException(re);
}
@@ -118,17 +114,17 @@ Socket::Socket(std::string host, std::string protocol) {
Socket::~Socket() {
if (this->moved) return;
this->close();
#if HAVE_OPENSSL
if (this->ssl) {
SSL_free(this->ssl);
this->ssl = nullptr;
}
if (this->ssl_ctx) {
SSL_CTX_free(this->ssl_ctx);
this->ssl_ctx = nullptr;
}
if (this->web) {
BIO_free_all(this->web);
this->web = nullptr;
}
#endif
this->close();
if (this->addr) {
freeaddrinfo(this->addr);
this->addr = nullptr;
@@ -136,45 +132,6 @@ Socket::~Socket() {
}
void Socket::connect() {
#if HAVE_OPENSSL
if (this->protocol == "https") {
const SSL_METHOD* method = SSLv23_client_method();
if (!method) {
throw std::runtime_error("SSLv23_client_method failed");
}
this->ssl_ctx = SSL_CTX_new(method);
if (!this->ssl_ctx) {
throw std::runtime_error("SSL_CTX_new failed");
}
const long flags = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION;
SSL_CTX_set_options(this->ssl_ctx, flags);
this->web = BIO_new_ssl_connect(this->ssl_ctx);
if (!this->web) {
throw std::runtime_error("BIO_new_ssl_connect failed");
}
int res =BIO_set_conn_hostname(this->web, this->host.c_str());
if (res == -1) {
throw std::runtime_error("BIO_set_conn_hostname failed");
}
BIO_get_ssl(this->web, &this->ssl);
if (!this->ssl) {
throw std::runtime_error("BIO_get_ssl failed");
}
res = SSL_set_tlsext_host_name(this->ssl, this->host.c_str());
if (res == -1) {
throw std::runtime_error("SSL_set_tlsext_host_name failed");
}
res = BIO_do_connect(this->web);
if (res != 1) {
throw std::runtime_error("BIO_do_connect failed");
}
res = BIO_do_handshake(this->web);
if (res != 1) {
throw std::runtime_error("BIO_do_handshake failed");
}
return;
}
#endif
if (this->socket != -1 || !this->addr) {
return;
}
@@ -194,21 +151,39 @@ void Socket::connect() {
#endif
throw SocketError();
}
#if HAVE_OPENSSL
if (this->https) {
this->ssl_ctx = SSL_CTX_new(TLS_client_method());
if (!this->ssl_ctx) {
throw std::runtime_error("SSL_CTX_new failed");
}
this->ssl = SSL_new(this->ssl_ctx);
if (!this->ssl) {
throw std::runtime_error("SSL_new failed");
}
SSL_set_fd(this->ssl, this->socket);
SSL_set_tlsext_host_name(this->ssl, this->host.c_str());
re = SSL_connect(this->ssl);
if (re != 1) {
throw std::runtime_error("SSL_connect failed");
}
}
#endif
}
size_t Socket::send(const char* data, size_t len, int flags) {
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if HAVE_OPENSSL
if (this->protocol == "https") {
int sented = BIO_write(this->web, data, (int)len);
if (this->https) {
int sented = SSL_write(this->ssl, data, (int)len);
if (sented <= 0) {
throw std::runtime_error("BIO_write failed");
}
return sented;
}
#endif
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if _WIN32
int sented = ::send(this->socket, data, (int)len, flags);
if (sented == SOCKET_ERROR) {
@@ -226,18 +201,18 @@ size_t Socket::send(std::string data, int flags) {
}
size_t Socket::recv(char* data, size_t len, int flags) {
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if HAVE_OPENSSL
if (this->protocol == "https") {
int recved = BIO_read(this->web, data, (int)len);
if (recved <= 0) {
throw std::runtime_error("BIO_read failed");
if (this->https) {
int recved = SSL_read(this->ssl, data, (int)len);
if (recved < 0) {
throw std::runtime_error("SSL_read failed");
}
return recved;
}
#endif
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if _WIN32
int recved = ::recv(this->socket, data, (int)len, flags);
if (recved == SOCKET_ERROR) {
@@ -265,6 +240,12 @@ std::string Socket::recv(size_t len, int flags) {
void Socket::close() {
if (this->socket != -1 && !closed) {
#if HAVE_OPENSSL
if (this->https) {
SSL_set_shutdown(this->ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
SSL_shutdown(this->ssl);
}
#endif
#if _WIN32
closesocket(this->socket);
#else
@@ -274,8 +255,10 @@ void Socket::close() {
}
}
Request::Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) {
Request::Request(std::string host, std::string port, bool https, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) {
this->host = host;
this->port = port;
this->https = https;
this->path = path;
this->method = method;
this->headers = headers;
@@ -285,7 +268,7 @@ Request::Request(std::string host, std::string path, std::string method, HeaderM
Response Request::send() {
std::string data;
data += this->method + " " + this->path + " HTTP/1.1\r\n";
auto hasBody = !this->body || !this->body->isFinished();
auto hasBody = this->body && !this->body->isFinished();
if (hasBody) {
auto type = this->body->contentType();
if (!type.empty()) {
@@ -300,7 +283,7 @@ Response Request::send() {
}
data += "\r\n";
#if HAVE_OPENSSL
Socket socket(this->host, this->options.https ? "https" : "http");
Socket socket(this->host, this->port, this->https);
#else
Socket socket(this->host, "http");
#endif
@@ -323,7 +306,29 @@ void Request::setBody(HttpBody* body) {
}
HttpClient::HttpClient(std::string host) {
this->host = host;
auto pos = host.find("://");
bool is_number_port = false;
if (pos == std::string::npos) {
this->host = host;
} else {
this->host = host.substr(pos + 3);
auto protocol = host.substr(0, pos);
if (!cstr_stricmp(protocol.c_str(), "http")) {
this->https = false;
} else if (!cstr_stricmp(protocol.c_str(), "https")) {
this->https = true;
} else {
throw std::runtime_error("Unspported protocol");
}
}
pos = this->host.find(":");
if (pos != std::string::npos) {
this->port = this->host.substr(pos + 1);
this->host = this->host.substr(0, pos);
is_number_port = true;
} else {
this->port = this->https ? "https" : "http";
}
#if _WIN32
if (!inited) {
WSAStartup(MAKEWORD(2, 2), &wsaData);
@@ -340,13 +345,22 @@ HttpClient::HttpClient(std::string host) {
ssl_inited = true;
}
#endif
this->headers["Host"] = host;
if (is_number_port) {
this->headers["Host"] = this->host + ":" + this->port;
} else {
this->headers["Host"] = this->host;
}
this->headers["User-Agent"] = "simple-http-client";
this->headers["Accept"] = "*/*";
std::string ae = "";
#if HAVE_ZLIB
ae += "deflate, gzip";
#endif
this->headers["Accept-Encoding"] = ae;
}
Request HttpClient::request(std::string path, std::string method) {
return Request(this->host, path, method, this->headers, this->options);
return Request(this->host, this->port, this->https, path, method, this->headers, this->options);
}
void HttpFullBody::pull() {
@@ -435,6 +449,9 @@ Request::~Request() {
}
Response::Response(Socket socket): socket(socket) {
#if HAVE_ZLIB
memset(&this->zstream, 0, sizeof(z_stream));
#endif
parseStatus();
parseHeader();
}
@@ -479,10 +496,35 @@ void Response::parseHeader() {
if (!cstr_stricmp(it.c_str(), "chunked")) {
this->chunked = true;
break;
#if HAVE_ZLIB
} else if (!cstr_stricmp(it.c_str(), "gzip")) {
this->gzip = true;
inflateInit2(&this->zstream, 16 + MAX_WBITS);
break;
} else if (!cstr_stricmp(it.c_str(), "deflate")) {
this->deflate = true;
inflateInit2(&this->zstream, -MAX_WBITS);
break;
#endif
} else {
throw std::runtime_error("Unspported transfer-encoding");
}
}
} else if (!cstr_stricmp(kv[0].c_str(), "content-encoding")) {
#if HAVE_ZLIB
if (kv[1] == "gzip") {
this->gzip = true;
inflateInit2(&this->zstream, 16 + MAX_WBITS);
} else if (kv[1] == "deflate") {
this->deflate = true;
inflateInit2(&this->zstream, -MAX_WBITS);
} else {
throw std::runtime_error("Unspported content-encoding");
}
#else
throw std::runtime_error("Unspported content-encoding");
#endif
}
line = this->readLine();
}
@@ -527,11 +569,21 @@ std::string Response::read() {
throw std::runtime_error("Chunk size != data length");
}
this->readLine();
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
return this->inflate(data);
}
#endif
return data;
} else {
if (this->pullData()) return "";
auto data = this->buff;
this->buff.clear();
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
return this->inflate(data);
}
#endif
return data;
}
}
@@ -545,3 +597,37 @@ std::string Response::readAll() {
}
return data;
}
Response::~Response() {
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
inflateEnd(&this->zstream);
}
#endif
}
#if HAVE_ZLIB
std::string Response::inflate(std::string data) {
this->zstream.next_in = (Bytef*)data.c_str();
this->zstream.avail_in = data.length();
std::string re;
while (true) {
char buf[10240];
this->zstream.next_out = (Bytef*)buf;
this->zstream.avail_out = 10240;
int res = ::inflate(&this->zstream, Z_NO_FLUSH);
if (res == Z_STREAM_END) {
re += std::string(buf, 10240 - this->zstream.avail_out);
break;
} else if (res == Z_BUF_ERROR) {
re += std::string(buf, 10240 - this->zstream.avail_out);
} else if (res != Z_OK) {
throw std::runtime_error("inflate failed");
} else {
re += std::string(buf, 10240 - this->zstream.avail_out);
break;
}
}
return re;
}
#endif

View File

@@ -23,6 +23,10 @@
#include "openssl/ssl.h"
#endif
#if HAVE_ZLIB
#include "zlib.h"
#endif
struct HeaderNameCompare {
bool operator()(const std::string& a, const std::string& b) const {
return cstr_stricmp(a.c_str(), b.c_str()) < 0;
@@ -35,7 +39,6 @@ std::string decodeURIComponent(std::string str);
class HttpClientOptions {
public:
bool https = false;
};
class HttpBody {
@@ -95,7 +98,7 @@ private:
class Socket {
public:
Socket(std::string host, std::string protocol);
Socket(std::string host, std::string port, bool https = false);
Socket(Socket& socket) {
*this = socket;
socket.moved = true;
@@ -118,12 +121,12 @@ public:
private:
#if HAVE_OPENSSL
SSL_CTX* ssl_ctx = nullptr;
BIO* web = nullptr;
SSL* ssl = nullptr;
#endif
bool moved = false;
std::string host;
std::string protocol;
std::string port;
bool https = false;
int socket = -1;
bool closed = false;
addrinfo* addr = nullptr;
@@ -133,13 +136,15 @@ typedef struct Response Response;
class Request {
public:
Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options);
Request(std::string host, std::string port, bool https, std::string path, std::string method, HeaderMap headers, HttpClientOptions options);
~Request();
Response send();
HeaderMap headers;
HttpClientOptions options;
void setBody(HttpBody* body);
std::string host;
std::string port;
bool https = false;
std::string path;
std::string method;
private:
@@ -150,6 +155,7 @@ class Response {
public:
Response() = delete;
explicit Response(Socket socket);
~Response();
HeaderMap headers;
uint16_t code = 0;
std::string reason;
@@ -162,6 +168,12 @@ private:
bool pullData();
bool headerParsed = false;
bool chunked = false;
#if HAVE_ZLIB
bool gzip = false;
bool deflate = false;
z_stream zstream;
std::string inflate(std::string data);
#endif
Socket socket;
std::string buff;
};
@@ -174,6 +186,8 @@ public:
HeaderMap headers;
private:
std::string host;
std::string port;
bool https = false;
};
#endif

View File

@@ -18,3 +18,4 @@
#cmakedefine HAVE_STRNCASECMP @HAVE_STRNCASECMP@
#cmakedefine HAVE_FCLOSEALL @HAVE_FCLOSEALL@
#cmakedefine HAVE_OPENSSL @HAVE_OPENSSL@
#cmakedefine HAVE_ZLIB @HAVE_ZLIB@