Fix BIO_do_connect freeze

This commit is contained in:
2024-01-10 17:33:01 +08:00
parent db452fce0f
commit c6105db80d
4 changed files with 186 additions and 75 deletions

View File

@@ -98,19 +98,15 @@ const char* SocketError::what() {
return this->message.c_str();
}
Socket::Socket(std::string host, std::string protocol) {
Socket::Socket(std::string host, std::string port, bool https) {
this->host = host;
this->protocol = protocol;
#if HAVE_OPENSSL
if (protocol == "https") {
return;
}
#endif
this->port = port;
this->https = https;
addrinfo hints = { 0 };
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
int re = getaddrinfo(host.c_str(), protocol.c_str(), &hints, &this->addr);
int re = getaddrinfo(host.c_str(), port.c_str(), &hints, &this->addr);
if (re) {
throw AIException(re);
}
@@ -118,17 +114,17 @@ Socket::Socket(std::string host, std::string protocol) {
Socket::~Socket() {
if (this->moved) return;
this->close();
#if HAVE_OPENSSL
if (this->ssl) {
SSL_free(this->ssl);
this->ssl = nullptr;
}
if (this->ssl_ctx) {
SSL_CTX_free(this->ssl_ctx);
this->ssl_ctx = nullptr;
}
if (this->web) {
BIO_free_all(this->web);
this->web = nullptr;
}
#endif
this->close();
if (this->addr) {
freeaddrinfo(this->addr);
this->addr = nullptr;
@@ -136,45 +132,6 @@ Socket::~Socket() {
}
void Socket::connect() {
#if HAVE_OPENSSL
if (this->protocol == "https") {
const SSL_METHOD* method = SSLv23_client_method();
if (!method) {
throw std::runtime_error("SSLv23_client_method failed");
}
this->ssl_ctx = SSL_CTX_new(method);
if (!this->ssl_ctx) {
throw std::runtime_error("SSL_CTX_new failed");
}
const long flags = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION;
SSL_CTX_set_options(this->ssl_ctx, flags);
this->web = BIO_new_ssl_connect(this->ssl_ctx);
if (!this->web) {
throw std::runtime_error("BIO_new_ssl_connect failed");
}
int res =BIO_set_conn_hostname(this->web, this->host.c_str());
if (res == -1) {
throw std::runtime_error("BIO_set_conn_hostname failed");
}
BIO_get_ssl(this->web, &this->ssl);
if (!this->ssl) {
throw std::runtime_error("BIO_get_ssl failed");
}
res = SSL_set_tlsext_host_name(this->ssl, this->host.c_str());
if (res == -1) {
throw std::runtime_error("SSL_set_tlsext_host_name failed");
}
res = BIO_do_connect(this->web);
if (res != 1) {
throw std::runtime_error("BIO_do_connect failed");
}
res = BIO_do_handshake(this->web);
if (res != 1) {
throw std::runtime_error("BIO_do_handshake failed");
}
return;
}
#endif
if (this->socket != -1 || !this->addr) {
return;
}
@@ -194,21 +151,39 @@ void Socket::connect() {
#endif
throw SocketError();
}
#if HAVE_OPENSSL
if (this->https) {
this->ssl_ctx = SSL_CTX_new(TLS_client_method());
if (!this->ssl_ctx) {
throw std::runtime_error("SSL_CTX_new failed");
}
this->ssl = SSL_new(this->ssl_ctx);
if (!this->ssl) {
throw std::runtime_error("SSL_new failed");
}
SSL_set_fd(this->ssl, this->socket);
SSL_set_tlsext_host_name(this->ssl, this->host.c_str());
re = SSL_connect(this->ssl);
if (re != 1) {
throw std::runtime_error("SSL_connect failed");
}
}
#endif
}
size_t Socket::send(const char* data, size_t len, int flags) {
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if HAVE_OPENSSL
if (this->protocol == "https") {
int sented = BIO_write(this->web, data, (int)len);
if (this->https) {
int sented = SSL_write(this->ssl, data, (int)len);
if (sented <= 0) {
throw std::runtime_error("BIO_write failed");
}
return sented;
}
#endif
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if _WIN32
int sented = ::send(this->socket, data, (int)len, flags);
if (sented == SOCKET_ERROR) {
@@ -226,18 +201,18 @@ size_t Socket::send(std::string data, int flags) {
}
size_t Socket::recv(char* data, size_t len, int flags) {
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if HAVE_OPENSSL
if (this->protocol == "https") {
int recved = BIO_read(this->web, data, (int)len);
if (recved <= 0) {
throw std::runtime_error("BIO_read failed");
if (this->https) {
int recved = SSL_read(this->ssl, data, (int)len);
if (recved < 0) {
throw std::runtime_error("SSL_read failed");
}
return recved;
}
#endif
if (this->socket == -1) {
throw std::runtime_error("Socket not connected");
}
#if _WIN32
int recved = ::recv(this->socket, data, (int)len, flags);
if (recved == SOCKET_ERROR) {
@@ -265,6 +240,12 @@ std::string Socket::recv(size_t len, int flags) {
void Socket::close() {
if (this->socket != -1 && !closed) {
#if HAVE_OPENSSL
if (this->https) {
SSL_set_shutdown(this->ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
SSL_shutdown(this->ssl);
}
#endif
#if _WIN32
closesocket(this->socket);
#else
@@ -274,8 +255,10 @@ void Socket::close() {
}
}
Request::Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) {
Request::Request(std::string host, std::string port, bool https, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) {
this->host = host;
this->port = port;
this->https = https;
this->path = path;
this->method = method;
this->headers = headers;
@@ -285,7 +268,7 @@ Request::Request(std::string host, std::string path, std::string method, HeaderM
Response Request::send() {
std::string data;
data += this->method + " " + this->path + " HTTP/1.1\r\n";
auto hasBody = !this->body || !this->body->isFinished();
auto hasBody = this->body && !this->body->isFinished();
if (hasBody) {
auto type = this->body->contentType();
if (!type.empty()) {
@@ -300,7 +283,7 @@ Response Request::send() {
}
data += "\r\n";
#if HAVE_OPENSSL
Socket socket(this->host, this->options.https ? "https" : "http");
Socket socket(this->host, this->port, this->https);
#else
Socket socket(this->host, "http");
#endif
@@ -323,7 +306,29 @@ void Request::setBody(HttpBody* body) {
}
HttpClient::HttpClient(std::string host) {
this->host = host;
auto pos = host.find("://");
bool is_number_port = false;
if (pos == std::string::npos) {
this->host = host;
} else {
this->host = host.substr(pos + 3);
auto protocol = host.substr(0, pos);
if (!cstr_stricmp(protocol.c_str(), "http")) {
this->https = false;
} else if (!cstr_stricmp(protocol.c_str(), "https")) {
this->https = true;
} else {
throw std::runtime_error("Unspported protocol");
}
}
pos = this->host.find(":");
if (pos != std::string::npos) {
this->port = this->host.substr(pos + 1);
this->host = this->host.substr(0, pos);
is_number_port = true;
} else {
this->port = this->https ? "https" : "http";
}
#if _WIN32
if (!inited) {
WSAStartup(MAKEWORD(2, 2), &wsaData);
@@ -340,13 +345,22 @@ HttpClient::HttpClient(std::string host) {
ssl_inited = true;
}
#endif
this->headers["Host"] = host;
if (is_number_port) {
this->headers["Host"] = this->host + ":" + this->port;
} else {
this->headers["Host"] = this->host;
}
this->headers["User-Agent"] = "simple-http-client";
this->headers["Accept"] = "*/*";
std::string ae = "";
#if HAVE_ZLIB
ae += "deflate, gzip";
#endif
this->headers["Accept-Encoding"] = ae;
}
Request HttpClient::request(std::string path, std::string method) {
return Request(this->host, path, method, this->headers, this->options);
return Request(this->host, this->port, this->https, path, method, this->headers, this->options);
}
void HttpFullBody::pull() {
@@ -435,6 +449,9 @@ Request::~Request() {
}
Response::Response(Socket socket): socket(socket) {
#if HAVE_ZLIB
memset(&this->zstream, 0, sizeof(z_stream));
#endif
parseStatus();
parseHeader();
}
@@ -479,10 +496,35 @@ void Response::parseHeader() {
if (!cstr_stricmp(it.c_str(), "chunked")) {
this->chunked = true;
break;
#if HAVE_ZLIB
} else if (!cstr_stricmp(it.c_str(), "gzip")) {
this->gzip = true;
inflateInit2(&this->zstream, 16 + MAX_WBITS);
break;
} else if (!cstr_stricmp(it.c_str(), "deflate")) {
this->deflate = true;
inflateInit2(&this->zstream, -MAX_WBITS);
break;
#endif
} else {
throw std::runtime_error("Unspported transfer-encoding");
}
}
} else if (!cstr_stricmp(kv[0].c_str(), "content-encoding")) {
#if HAVE_ZLIB
if (kv[1] == "gzip") {
this->gzip = true;
inflateInit2(&this->zstream, 16 + MAX_WBITS);
} else if (kv[1] == "deflate") {
this->deflate = true;
inflateInit2(&this->zstream, -MAX_WBITS);
} else {
throw std::runtime_error("Unspported content-encoding");
}
#else
throw std::runtime_error("Unspported content-encoding");
#endif
}
line = this->readLine();
}
@@ -527,11 +569,21 @@ std::string Response::read() {
throw std::runtime_error("Chunk size != data length");
}
this->readLine();
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
return this->inflate(data);
}
#endif
return data;
} else {
if (this->pullData()) return "";
auto data = this->buff;
this->buff.clear();
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
return this->inflate(data);
}
#endif
return data;
}
}
@@ -545,3 +597,37 @@ std::string Response::readAll() {
}
return data;
}
Response::~Response() {
#if HAVE_ZLIB
if (this->gzip || this->deflate) {
inflateEnd(&this->zstream);
}
#endif
}
#if HAVE_ZLIB
std::string Response::inflate(std::string data) {
this->zstream.next_in = (Bytef*)data.c_str();
this->zstream.avail_in = data.length();
std::string re;
while (true) {
char buf[10240];
this->zstream.next_out = (Bytef*)buf;
this->zstream.avail_out = 10240;
int res = ::inflate(&this->zstream, Z_NO_FLUSH);
if (res == Z_STREAM_END) {
re += std::string(buf, 10240 - this->zstream.avail_out);
break;
} else if (res == Z_BUF_ERROR) {
re += std::string(buf, 10240 - this->zstream.avail_out);
} else if (res != Z_OK) {
throw std::runtime_error("inflate failed");
} else {
re += std::string(buf, 10240 - this->zstream.avail_out);
break;
}
}
return re;
}
#endif