diff --git a/CMakeLists.txt b/CMakeLists.txt index 5c6b139..1c7293e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ option(ENABLE_STANDALONE "Build utils standalone" OFF) option(ENABLE_CXX17 "Enable C++ 17" OFF) option(INSTALL_DEP_FILES "Install a file with dependences." OFF) option(ENABLE_SSL "Enable SSL" OFF) +option(ENABLE_ZLIB "Use Zlib to uncompress http data." OFF) if (ENABLE_STANDALONE) project(utils) @@ -29,6 +30,12 @@ if (ENABLE_SSL) set(HAVE_OPENSSL 1) endif() +if (ENABLE_ZLIB) + find_package(ZLIB REQUIRED) + set(HAVE_ZLIB 1) +endif() + + if (Iconv_FOUND) set(HAVE_ICONV 1) endif() @@ -132,6 +139,9 @@ endif() if (ENABLE_SSL) target_link_libraries(utils OpenSSL::SSL OpenSSL::Crypto) endif() +if (ENABLE_ZLIB) + target_link_libraries(utils ZLIB::ZLIB) +endif() if (ENABLE_CXX17) target_compile_features(utils PRIVATE cxx_std_17) endif() diff --git a/http_client.cpp b/http_client.cpp index 5dad271..fbbc0dd 100644 --- a/http_client.cpp +++ b/http_client.cpp @@ -98,19 +98,15 @@ const char* SocketError::what() { return this->message.c_str(); } -Socket::Socket(std::string host, std::string protocol) { +Socket::Socket(std::string host, std::string port, bool https) { this->host = host; - this->protocol = protocol; -#if HAVE_OPENSSL - if (protocol == "https") { - return; - } -#endif + this->port = port; + this->https = https; addrinfo hints = { 0 }; hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; - int re = getaddrinfo(host.c_str(), protocol.c_str(), &hints, &this->addr); + int re = getaddrinfo(host.c_str(), port.c_str(), &hints, &this->addr); if (re) { throw AIException(re); } @@ -118,17 +114,17 @@ Socket::Socket(std::string host, std::string protocol) { Socket::~Socket() { if (this->moved) return; + this->close(); #if HAVE_OPENSSL + if (this->ssl) { + SSL_free(this->ssl); + this->ssl = nullptr; + } if (this->ssl_ctx) { SSL_CTX_free(this->ssl_ctx); this->ssl_ctx = nullptr; } - if (this->web) { - BIO_free_all(this->web); - this->web = nullptr; - } #endif - this->close(); if (this->addr) { freeaddrinfo(this->addr); this->addr = nullptr; @@ -136,45 +132,6 @@ Socket::~Socket() { } void Socket::connect() { -#if HAVE_OPENSSL - if (this->protocol == "https") { - const SSL_METHOD* method = SSLv23_client_method(); - if (!method) { - throw std::runtime_error("SSLv23_client_method failed"); - } - this->ssl_ctx = SSL_CTX_new(method); - if (!this->ssl_ctx) { - throw std::runtime_error("SSL_CTX_new failed"); - } - const long flags = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION; - SSL_CTX_set_options(this->ssl_ctx, flags); - this->web = BIO_new_ssl_connect(this->ssl_ctx); - if (!this->web) { - throw std::runtime_error("BIO_new_ssl_connect failed"); - } - int res =BIO_set_conn_hostname(this->web, this->host.c_str()); - if (res == -1) { - throw std::runtime_error("BIO_set_conn_hostname failed"); - } - BIO_get_ssl(this->web, &this->ssl); - if (!this->ssl) { - throw std::runtime_error("BIO_get_ssl failed"); - } - res = SSL_set_tlsext_host_name(this->ssl, this->host.c_str()); - if (res == -1) { - throw std::runtime_error("SSL_set_tlsext_host_name failed"); - } - res = BIO_do_connect(this->web); - if (res != 1) { - throw std::runtime_error("BIO_do_connect failed"); - } - res = BIO_do_handshake(this->web); - if (res != 1) { - throw std::runtime_error("BIO_do_handshake failed"); - } - return; - } -#endif if (this->socket != -1 || !this->addr) { return; } @@ -194,21 +151,39 @@ void Socket::connect() { #endif throw SocketError(); } +#if HAVE_OPENSSL + if (this->https) { + this->ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!this->ssl_ctx) { + throw std::runtime_error("SSL_CTX_new failed"); + } + this->ssl = SSL_new(this->ssl_ctx); + if (!this->ssl) { + throw std::runtime_error("SSL_new failed"); + } + SSL_set_fd(this->ssl, this->socket); + SSL_set_tlsext_host_name(this->ssl, this->host.c_str()); + re = SSL_connect(this->ssl); + if (re != 1) { + throw std::runtime_error("SSL_connect failed"); + } + } +#endif } size_t Socket::send(const char* data, size_t len, int flags) { + if (this->socket == -1) { + throw std::runtime_error("Socket not connected"); + } #if HAVE_OPENSSL - if (this->protocol == "https") { - int sented = BIO_write(this->web, data, (int)len); + if (this->https) { + int sented = SSL_write(this->ssl, data, (int)len); if (sented <= 0) { throw std::runtime_error("BIO_write failed"); } return sented; } #endif - if (this->socket == -1) { - throw std::runtime_error("Socket not connected"); - } #if _WIN32 int sented = ::send(this->socket, data, (int)len, flags); if (sented == SOCKET_ERROR) { @@ -226,18 +201,18 @@ size_t Socket::send(std::string data, int flags) { } size_t Socket::recv(char* data, size_t len, int flags) { + if (this->socket == -1) { + throw std::runtime_error("Socket not connected"); + } #if HAVE_OPENSSL - if (this->protocol == "https") { - int recved = BIO_read(this->web, data, (int)len); - if (recved <= 0) { - throw std::runtime_error("BIO_read failed"); + if (this->https) { + int recved = SSL_read(this->ssl, data, (int)len); + if (recved < 0) { + throw std::runtime_error("SSL_read failed"); } return recved; } #endif - if (this->socket == -1) { - throw std::runtime_error("Socket not connected"); - } #if _WIN32 int recved = ::recv(this->socket, data, (int)len, flags); if (recved == SOCKET_ERROR) { @@ -265,6 +240,12 @@ std::string Socket::recv(size_t len, int flags) { void Socket::close() { if (this->socket != -1 && !closed) { +#if HAVE_OPENSSL + if (this->https) { + SSL_set_shutdown(this->ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN); + SSL_shutdown(this->ssl); + } +#endif #if _WIN32 closesocket(this->socket); #else @@ -274,8 +255,10 @@ void Socket::close() { } } -Request::Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) { +Request::Request(std::string host, std::string port, bool https, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) { this->host = host; + this->port = port; + this->https = https; this->path = path; this->method = method; this->headers = headers; @@ -285,7 +268,7 @@ Request::Request(std::string host, std::string path, std::string method, HeaderM Response Request::send() { std::string data; data += this->method + " " + this->path + " HTTP/1.1\r\n"; - auto hasBody = !this->body || !this->body->isFinished(); + auto hasBody = this->body && !this->body->isFinished(); if (hasBody) { auto type = this->body->contentType(); if (!type.empty()) { @@ -300,7 +283,7 @@ Response Request::send() { } data += "\r\n"; #if HAVE_OPENSSL - Socket socket(this->host, this->options.https ? "https" : "http"); + Socket socket(this->host, this->port, this->https); #else Socket socket(this->host, "http"); #endif @@ -323,7 +306,29 @@ void Request::setBody(HttpBody* body) { } HttpClient::HttpClient(std::string host) { - this->host = host; + auto pos = host.find("://"); + bool is_number_port = false; + if (pos == std::string::npos) { + this->host = host; + } else { + this->host = host.substr(pos + 3); + auto protocol = host.substr(0, pos); + if (!cstr_stricmp(protocol.c_str(), "http")) { + this->https = false; + } else if (!cstr_stricmp(protocol.c_str(), "https")) { + this->https = true; + } else { + throw std::runtime_error("Unspported protocol"); + } + } + pos = this->host.find(":"); + if (pos != std::string::npos) { + this->port = this->host.substr(pos + 1); + this->host = this->host.substr(0, pos); + is_number_port = true; + } else { + this->port = this->https ? "https" : "http"; + } #if _WIN32 if (!inited) { WSAStartup(MAKEWORD(2, 2), &wsaData); @@ -340,13 +345,22 @@ HttpClient::HttpClient(std::string host) { ssl_inited = true; } #endif - this->headers["Host"] = host; + if (is_number_port) { + this->headers["Host"] = this->host + ":" + this->port; + } else { + this->headers["Host"] = this->host; + } this->headers["User-Agent"] = "simple-http-client"; this->headers["Accept"] = "*/*"; + std::string ae = ""; +#if HAVE_ZLIB + ae += "deflate, gzip"; +#endif + this->headers["Accept-Encoding"] = ae; } Request HttpClient::request(std::string path, std::string method) { - return Request(this->host, path, method, this->headers, this->options); + return Request(this->host, this->port, this->https, path, method, this->headers, this->options); } void HttpFullBody::pull() { @@ -435,6 +449,9 @@ Request::~Request() { } Response::Response(Socket socket): socket(socket) { +#if HAVE_ZLIB + memset(&this->zstream, 0, sizeof(z_stream)); +#endif parseStatus(); parseHeader(); } @@ -479,10 +496,35 @@ void Response::parseHeader() { if (!cstr_stricmp(it.c_str(), "chunked")) { this->chunked = true; break; +#if HAVE_ZLIB + } else if (!cstr_stricmp(it.c_str(), "gzip")) { + this->gzip = true; + inflateInit2(&this->zstream, 16 + MAX_WBITS); + break; + } else if (!cstr_stricmp(it.c_str(), "deflate")) { + this->deflate = true; + inflateInit2(&this->zstream, -MAX_WBITS); + break; +#endif } else { throw std::runtime_error("Unspported transfer-encoding"); } } + } else if (!cstr_stricmp(kv[0].c_str(), "content-encoding")) { +#if HAVE_ZLIB + if (kv[1] == "gzip") { + this->gzip = true; + inflateInit2(&this->zstream, 16 + MAX_WBITS); + + } else if (kv[1] == "deflate") { + this->deflate = true; + inflateInit2(&this->zstream, -MAX_WBITS); + } else { + throw std::runtime_error("Unspported content-encoding"); + } +#else + throw std::runtime_error("Unspported content-encoding"); +#endif } line = this->readLine(); } @@ -527,11 +569,21 @@ std::string Response::read() { throw std::runtime_error("Chunk size != data length"); } this->readLine(); +#if HAVE_ZLIB + if (this->gzip || this->deflate) { + return this->inflate(data); + } +#endif return data; } else { if (this->pullData()) return ""; auto data = this->buff; this->buff.clear(); +#if HAVE_ZLIB + if (this->gzip || this->deflate) { + return this->inflate(data); + } +#endif return data; } } @@ -545,3 +597,37 @@ std::string Response::readAll() { } return data; } + +Response::~Response() { +#if HAVE_ZLIB + if (this->gzip || this->deflate) { + inflateEnd(&this->zstream); + } +#endif +} + +#if HAVE_ZLIB +std::string Response::inflate(std::string data) { + this->zstream.next_in = (Bytef*)data.c_str(); + this->zstream.avail_in = data.length(); + std::string re; + while (true) { + char buf[10240]; + this->zstream.next_out = (Bytef*)buf; + this->zstream.avail_out = 10240; + int res = ::inflate(&this->zstream, Z_NO_FLUSH); + if (res == Z_STREAM_END) { + re += std::string(buf, 10240 - this->zstream.avail_out); + break; + } else if (res == Z_BUF_ERROR) { + re += std::string(buf, 10240 - this->zstream.avail_out); + } else if (res != Z_OK) { + throw std::runtime_error("inflate failed"); + } else { + re += std::string(buf, 10240 - this->zstream.avail_out); + break; + } + } + return re; +} +#endif diff --git a/http_client.h b/http_client.h index e02f3b9..d7c4710 100644 --- a/http_client.h +++ b/http_client.h @@ -23,6 +23,10 @@ #include "openssl/ssl.h" #endif +#if HAVE_ZLIB +#include "zlib.h" +#endif + struct HeaderNameCompare { bool operator()(const std::string& a, const std::string& b) const { return cstr_stricmp(a.c_str(), b.c_str()) < 0; @@ -35,7 +39,6 @@ std::string decodeURIComponent(std::string str); class HttpClientOptions { public: - bool https = false; }; class HttpBody { @@ -95,7 +98,7 @@ private: class Socket { public: - Socket(std::string host, std::string protocol); + Socket(std::string host, std::string port, bool https = false); Socket(Socket& socket) { *this = socket; socket.moved = true; @@ -118,12 +121,12 @@ public: private: #if HAVE_OPENSSL SSL_CTX* ssl_ctx = nullptr; - BIO* web = nullptr; SSL* ssl = nullptr; #endif bool moved = false; std::string host; - std::string protocol; + std::string port; + bool https = false; int socket = -1; bool closed = false; addrinfo* addr = nullptr; @@ -133,13 +136,15 @@ typedef struct Response Response; class Request { public: - Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options); + Request(std::string host, std::string port, bool https, std::string path, std::string method, HeaderMap headers, HttpClientOptions options); ~Request(); Response send(); HeaderMap headers; HttpClientOptions options; void setBody(HttpBody* body); std::string host; + std::string port; + bool https = false; std::string path; std::string method; private: @@ -150,6 +155,7 @@ class Response { public: Response() = delete; explicit Response(Socket socket); + ~Response(); HeaderMap headers; uint16_t code = 0; std::string reason; @@ -162,6 +168,12 @@ private: bool pullData(); bool headerParsed = false; bool chunked = false; +#if HAVE_ZLIB + bool gzip = false; + bool deflate = false; + z_stream zstream; + std::string inflate(std::string data); +#endif Socket socket; std::string buff; }; @@ -174,6 +186,8 @@ public: HeaderMap headers; private: std::string host; + std::string port; + bool https = false; }; #endif diff --git a/utils_config.h.in b/utils_config.h.in index 5181966..2b90c52 100644 --- a/utils_config.h.in +++ b/utils_config.h.in @@ -18,3 +18,4 @@ #cmakedefine HAVE_STRNCASECMP @HAVE_STRNCASECMP@ #cmakedefine HAVE_FCLOSEALL @HAVE_FCLOSEALL@ #cmakedefine HAVE_OPENSSL @HAVE_OPENSSL@ +#cmakedefine HAVE_ZLIB @HAVE_ZLIB@