diff --git a/CMakeLists.txt b/CMakeLists.txt index f3eacc8..5c6b139 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ option(ENABLE_ICONV "Use libiconv to convert encoding" ON) option(ENABLE_STANDALONE "Build utils standalone" OFF) option(ENABLE_CXX17 "Enable C++ 17" OFF) option(INSTALL_DEP_FILES "Install a file with dependences." OFF) +option(ENABLE_SSL "Enable SSL" OFF) if (ENABLE_STANDALONE) project(utils) @@ -23,6 +24,11 @@ if (ENABLE_ICONV) find_package(Iconv) endif() +if (ENABLE_SSL) + find_package(OpenSSL 3.0.0 REQUIRED) + set(HAVE_OPENSSL 1) +endif() + if (Iconv_FOUND) set(HAVE_ICONV 1) endif() @@ -82,6 +88,7 @@ set(SOURCE_FILE c_linked_list.cpp file_reader.c urlparse.cpp + http_client.cpp ) set(SOURCE_FILE_HEADERS cfileop.h @@ -101,6 +108,7 @@ set(SOURCE_FILE_HEADERS c_linked_list.h file_reader.h urlparse.h + http_client.h ) add_library(utils STATIC ${SOURCE_FILE} ${SOURCE_FILE_HEADERS}) @@ -118,6 +126,12 @@ if (INSTALL_DEP_FILES) target_link_libraries(utils shell32) endif() endif() +if (WIN32) + target_link_libraries(utils Ws2_32) +endif() +if (ENABLE_SSL) + target_link_libraries(utils OpenSSL::SSL OpenSSL::Crypto) +endif() if (ENABLE_CXX17) target_compile_features(utils PRIVATE cxx_std_17) endif() diff --git a/cstr_util.c b/cstr_util.c index c007b49..7f2afb6 100644 --- a/cstr_util.c +++ b/cstr_util.c @@ -92,6 +92,21 @@ int cstr_tolowercase(const char* str, size_t input_len, char** output) { return 1; } +int cstr_touppercase(const char* str, size_t input_len, char** output) { + if (!str || !output) return 0; + if (!input_len) input_len = strlen(str); + if (input_len == (size_t)-1) return 0; + char* tmp = malloc(input_len + 1); + if (!tmp) return 0; + size_t i = 0; + for (; i < input_len; i++) { + tmp[i] = toupper(str[i]); + } + tmp[input_len] = 0; + *output = tmp; + return 1; +} + uint32_t cstr_read_uint32(const uint8_t* bytes, int big) { if (!bytes) return 0; return big ? (bytes[0] << 24) + (bytes[1] << 16) + (bytes[2] << 8) + bytes[3] : bytes[0] + (bytes[1] << 8) + (bytes[2] << 16) + (bytes[3] << 24); diff --git a/cstr_util.h b/cstr_util.h index 4dfc52f..1b03e61 100644 --- a/cstr_util.h +++ b/cstr_util.h @@ -27,6 +27,14 @@ int cstr_is_integer(const char* str, int allow_sign); * @return 1 if successed otherwise 0. */ int cstr_tolowercase(const char* str, size_t input_len, char** output); +/** + * @brief Convert string to uppercase + * @param str Origin string + * @param input_len The length of origin string. If is 0, strlen will be called to calculate length. + * @param output Output string. Need free memory by calling free. + * @return 1 if successed otherwise 0. +*/ +int cstr_touppercase(const char* str, size_t input_len, char** output); /** * @brief Convert bytes to uint32 * @param bytes Bytes (at least 4 bytes) diff --git a/http_client.cpp b/http_client.cpp new file mode 100644 index 0000000..70ba1ee --- /dev/null +++ b/http_client.cpp @@ -0,0 +1,423 @@ +#include "http_client.h" + +#ifndef _WIN32 +#include +#endif + +#include "err.h" +#include "str_util.h" + +#include +#include + +#if _WIN32 +static bool inited = false; +static WSADATA wsaData = { 0 }; +#endif + +#if HAVE_OPENSSL +static bool ssl_inited = false; +#endif + +AIException::AIException(int code) { + this->code = code; +} + +const char* AIException::what() { +#if _WIN32 + return gai_strerrorA(this->code); +#else + return gai_strerror(this->code); +#endif +} + +std::string decodeURIComponent(std::string str) { + std::string re; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '%') { + if (i + 2 >= str.length()) { + throw std::runtime_error("Invalid Percent-encoding"); + } + char buf[3] = { str[i + 1], str[i + 2], 0 }; + int c = strtol(buf, nullptr, 16); + re += (char)c; + i += 2; + } else { + re += str[i]; + } + } + return re; +} + +std::string encodeURIComponent(std::string str) { + std::string re; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == ' ') { + re += '+'; + } else if (str[i] == '-' || str[i] == '_' || str[i] == '.' || str[i] == '~' || + (str[i] >= 'a' && str[i] <= 'z') || + (str[i] >= 'A' && str[i] <= 'Z') || + (str[i] >= '0' && str[i] <= '9')) { + re += str[i]; + } else { + char buf[4] = { '%', 0 }; + snprintf(buf + 1, 3, "%02X", (unsigned char)str[i]); + re += buf; + } + } + return re; +} + +SocketError::SocketError() { +#if _WIN32 + this->code = WSAGetLastError(); +#else + this->code = errno; +#endif +#if _WIN32 + if (!err::get_winerror(this->message, this->code)) { +#else + if (!err::get_errno_message(this->message, this->code)) { +#endif + char buf[32]; + snprintf(buf, 32, "Unknown error %d", this->code); + this->message = std::string(buf); + } +} + +const char* SocketError::what() { + return this->message.c_str(); +} + +Socket::Socket(std::string host, std::string protocol) { + this->host = host; + this->protocol = protocol; +#if HAVE_OPENSSL + if (protocol == "https") { + return; + } +#endif + addrinfo hints = { 0 }; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + int re = getaddrinfo(host.c_str(), protocol.c_str(), &hints, &this->addr); + if (re) { + throw AIException(re); + } +} + +Socket::~Socket() { +#if HAVE_OPENSSL + if (this->ssl_ctx) { + SSL_CTX_free(this->ssl_ctx); + this->ssl_ctx = nullptr; + } + if (this->web) { + BIO_free_all(this->web); + this->web = nullptr; + } +#endif + this->close(); + if (this->addr) { + freeaddrinfo(this->addr); + this->addr = nullptr; + } +} + +void Socket::connect() { +#if HAVE_OPENSSL + if (this->protocol == "https") { + const SSL_METHOD* method = SSLv23_client_method(); + if (!method) { + throw std::runtime_error("SSLv23_client_method failed"); + } + this->ssl_ctx = SSL_CTX_new(method); + if (!this->ssl_ctx) { + throw std::runtime_error("SSL_CTX_new failed"); + } + const long flags = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION; + SSL_CTX_set_options(this->ssl_ctx, flags); + this->web = BIO_new_ssl_connect(this->ssl_ctx); + if (!this->web) { + throw std::runtime_error("BIO_new_ssl_connect failed"); + } + int res =BIO_set_conn_hostname(this->web, this->host.c_str()); + if (res == -1) { + throw std::runtime_error("BIO_set_conn_hostname failed"); + } + BIO_get_ssl(this->web, &this->ssl); + if (!this->ssl) { + throw std::runtime_error("BIO_get_ssl failed"); + } + res = SSL_set_tlsext_host_name(this->ssl, this->host.c_str()); + if (res == -1) { + throw std::runtime_error("SSL_set_tlsext_host_name failed"); + } + res = BIO_do_connect(this->web); + if (res != 1) { + throw std::runtime_error("BIO_do_connect failed"); + } + res = BIO_do_handshake(this->web); + if (res != 1) { + throw std::runtime_error("BIO_do_handshake failed"); + } + return; + } +#endif + if (this->socket != -1 || !this->addr) { + return; + } + this->socket = ::socket(this->addr->ai_family, this->addr->ai_socktype, this->addr->ai_protocol); +#if _WIN32 + if (this->socket == INVALID_SOCKET) { +#else + if (this->socket == -1) { +#endif + throw SocketError(); + } + int re = ::connect(this->socket, this->addr->ai_addr, this->addr->ai_addrlen); +#if _WIN32 + if (re == SOCKET_ERROR) { +#else + if (re == -1) { +#endif + throw SocketError(); + } +} + +size_t Socket::send(const char* data, size_t len, int flags) { +#if HAVE_OPENSSL + if (this->protocol == "https") { + int sented = BIO_write(this->web, data, (int)len); + if (sented <= 0) { + throw std::runtime_error("BIO_write failed"); + } + return sented; + } +#endif + if (this->socket == -1) { + throw std::runtime_error("Socket not connected"); + } +#if _WIN32 + int sented = ::send(this->socket, data, (int)len, flags); + if (sented == SOCKET_ERROR) { +#else + ssize_t sented = ::send(this->socket, data, len, flags); + if (sented == -1) { +#endif + throw SocketError(); + } + return sented; +} + +size_t Socket::send(std::string data, int flags) { + return this->send(data.c_str(), data.length(), flags); +} + +size_t Socket::recv(char* data, size_t len, int flags) { +#if HAVE_OPENSSL + if (this->protocol == "https") { + int recved = BIO_read(this->web, data, (int)len); + if (recved <= 0) { + throw std::runtime_error("BIO_read failed"); + } + return recved; + } +#endif + if (this->socket == -1) { + throw std::runtime_error("Socket not connected"); + } +#if _WIN32 + int recved = ::recv(this->socket, data, (int)len, flags); + if (recved == SOCKET_ERROR) { +#else + ssize_t recved = ::recv(this->socket, data, len, flags); + if (recved == -1) { +#endif + throw SocketError(); + } + return recved; +} + +std::string Socket::recv(size_t len, int flags) { + char* buf = new char[len]; + try { + size_t recved = this->recv(buf, len, flags); + std::string re(buf, recved); + delete[] buf; + return re; + } catch (std::exception& e) { + delete[] buf; + throw e; + } +} + +void Socket::close() { + if (this->socket != -1 && !closed) { +#if _WIN32 + closesocket(this->socket); +#else + ::close(this->socket); +#endif + closed = true; + } +} + +Request::Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options) { + this->host = host; + this->path = path; + this->method = method; + this->headers = headers; + this->options = options; +} + +void Request::send() { + std::string data; + data += this->method + " " + this->path + " HTTP/1.1\r\n"; + auto hasBody = !this->body || !this->body->isFinished(); + if (hasBody) { + auto type = this->body->contentType(); + if (!type.empty()) { + this->headers["Content-Type"] = type; + } + if (this->body->hasLength()) { + this->headers["Content-Length"] = std::to_string(this->body->length()); + } + } + for (auto& header : this->headers) { + data += header.first + ": " + header.second + "\r\n"; + } + data += "\r\n"; + printf("%s", data.c_str()); +#if HAVE_OPENSSL + Socket socket(this->host, this->options.https ? "https" : "http"); +#else + Socket socket(this->host, "http"); +#endif + socket.connect(); + socket.send(data); + if (hasBody) { + while (!this->body->isFinished()) { + char buf[1024]; + size_t len = this->body->pullData(buf, 1024); + socket.send(buf, len, 0); + printf("%s", std::string(buf, len).c_str()); + } + socket.send("\r\n"); + } + auto s = socket.recv(10240); + printf("%s\n", s.c_str()); +} + +HttpClient::HttpClient(std::string host) { + this->host = host; +#if _WIN32 + if (!inited) { + WSAStartup(MAKEWORD(2, 2), &wsaData); + inited = true; + } +#endif +#if HAVE_OPENSSL + if (!ssl_inited) { +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); +#else + SSL_load_error_strings(); +#endif + ssl_inited = true; + } +#endif + this->headers["Host"] = host; + this->headers["User-Agent"] = "simple-http-client"; + this->headers["Accept"] = "*/*"; +} + +Request HttpClient::request(std::string path, std::string method) { + return Request(this->host, path, method, this->headers, this->options); +} + +void HttpFullBody::pull() { + if (!this->pulled) { + this->buff = this->body(); + this->len = this->buff.length(); + this->pulled = true; + } +} + +size_t HttpFullBody::pullData(char* buf, size_t len) { + this->pull(); + size_t rlen = std::min(this->buff.length(), len); + memcpy(buf, this->buff.c_str(), rlen); + this->buff = this->buff.substr(rlen); + return rlen; +} + +bool HttpFullBody::isFinished() { + this->pull(); + return this->buff.length() == 0; +} + +bool HttpFullBody::hasLength() { + return true; +} + +std::string HttpFullBody::contentType() { + return ""; +} + +size_t HttpFullBody::length() { + this->pull(); + return this->len; +} + +QueryData::QueryData() {} + +QueryData::QueryData(std::string data) { + auto list = str_util::str_split(data, "&"); + for (auto& item : list) { + auto kv = str_util::str_splitv(item, "=", 2); + auto k = decodeURIComponent(kv[0]); + auto v = decodeURIComponent(kv[1]); + this->append(k, v); + } +} + +std::string QueryData::contentType() { + return "application/x-www-form-urlencoded"; +} + +std::string QueryData::toQuery() { + std::string re; + for (auto& item : this->data) { + for (auto& v : item.second) { + if (re.length() > 0) { + re += "&"; + } + re += encodeURIComponent(item.first) + "=" + encodeURIComponent(v); + } + } + return re; +} + +void QueryData::append(std::string key, std::string value) { + if (this->data.find(key) != this->data.end()) { + this->data[key].push_back(value); + } else { + this->data[key] = std::vector({ value }); + } +} + +void QueryData::set(std::string key, std::string value) { + this->data[key] = std::vector({ value }); +} + +std::string QueryData::body() { + return this->toQuery(); +} + +Request::~Request() { + if (this->body) { + delete this->body; + } +} diff --git a/http_client.h b/http_client.h new file mode 100644 index 0000000..781232e --- /dev/null +++ b/http_client.h @@ -0,0 +1,162 @@ +#ifndef _UTIL_HTTP_CLIENT_H +#define _UTIL_HTTP_CLIENT_H + +#include +#include +#include +#include +#include +#include "cstr_util.h" +#include "utils_config.h" + +#if _WIN32 +#include +#include +#else +#include +#include +#include +#endif + +#if HAVE_OPENSSL +#include "openssl/ssl.h" +#endif + +struct HeaderNameCompare { + bool operator()(const std::string& a, const std::string& b) const { + return cstr_stricmp(a.c_str(), b.c_str()) < 0; + } +}; + +typedef std::map HeaderMap; + +std::string decodeURIComponent(std::string str); + +class HttpClientOptions { +public: + bool https = false; +}; + +class HttpBody { +public: + virtual size_t pullData(char* buf, size_t len) = 0; + virtual bool isFinished() = 0; + virtual std::string contentType() = 0; + virtual bool hasLength() = 0; + virtual size_t length() = 0; +}; + +class HttpFullBody: public HttpBody { +public: + virtual size_t pullData(char* buf, size_t len); + virtual bool isFinished(); + virtual bool hasLength(); + virtual std::string contentType(); + virtual size_t length(); +protected: + virtual std::string body() = 0; +private: + void pull(); + std::string buff; + bool pulled; + size_t len; +}; + +class QueryData: public HttpFullBody { +public: + QueryData(); + QueryData(std::string data); + std::string contentType(); + std::map> data; + std::string toQuery(); + void append(std::string key, std::string value); + void set(std::string key, std::string value); +protected: + std::string body(); +}; + +class AIException: std::exception { +public: + AIException(int code); + const char* what(); +private: + int code; +}; + +class SocketError: std::exception { +public: + SocketError(); + const char* what(); +private: + int code; + std::string message; +}; + +class Socket { +public: + Socket(std::string host, std::string protocol); + ~Socket(); + void connect(); + size_t send(const char* data, size_t len, int flags = 0); + template + size_t send(const char(&data)[T], int flags = 0) { + return this->send(data, T, flags); + } + size_t send(std::string data, int flags = 0); + size_t recv(char* data, size_t len, int flags = 0); + template + size_t recv(char(&data)[T], int flags = 0) { + return this->recv(data, T, flags); + } + std::string recv(size_t len, int flags = 0); + void close(); +private: +#if HAVE_OPENSSL + SSL_CTX* ssl_ctx = nullptr; + BIO* web = nullptr; + SSL* ssl = nullptr; +#endif + std::string host; + std::string protocol; + int socket = -1; + bool closed = false; + addrinfo* addr = nullptr; +}; + +class Request { +public: + Request(std::string host, std::string path, std::string method, HeaderMap headers, HttpClientOptions options); + ~Request(); + void send(); + HeaderMap headers; + HttpClientOptions options; + void setBody(HttpBody* body) { + if (this->body != nullptr) delete this->body; + this->body = body; + } +private: + HttpBody* body = nullptr; + std::string host; + std::string path; + std::string method; +}; + +class Response { +public: + Response(Request req, Socket socket); +private: + Request req; + Socket socket; +}; + +class HttpClient { +public: + HttpClient(std::string host); + Request request(std::string path, std::string method); + HttpClientOptions options; + HeaderMap headers; +private: + std::string host; +}; + +#endif diff --git a/str_util.cpp b/str_util.cpp index daf79de..20edb4c 100644 --- a/str_util.cpp +++ b/str_util.cpp @@ -18,6 +18,18 @@ bool str_util::tolowercase(std::string ori, std::string& result) { } } +bool str_util::touppercase(std::string ori, std::string& result) { + char* tmp = nullptr; + auto re = cstr_touppercase(ori.c_str(), ori.length(), &tmp); + if (re) { + result = std::string(tmp, ori.length()); + free(tmp); + return true; + } else { + return false; + } +} + std::string str_util::str_replace(std::string input, std::string pattern, std::string new_content) { auto loc = input.find(pattern, 0); auto len = pattern.length(); diff --git a/str_util.h b/str_util.h index 48529e2..156fa64 100644 --- a/str_util.h +++ b/str_util.h @@ -12,6 +12,13 @@ namespace str_util { * @return true if successed. */ bool tolowercase(std::string ori, std::string& result); + /** + * @brief Convert string to uppercase + * @param ori Origin string + * @param result Output string. + * @return true if successed. + */ + bool touppercase(std::string ori, std::string& result); /** * @brief Replace all pattern to new_content * @param input Input string diff --git a/utils_config.h.in b/utils_config.h.in index e9d3027..5181966 100644 --- a/utils_config.h.in +++ b/utils_config.h.in @@ -17,3 +17,4 @@ #cmakedefine HAVE__STRNICMP @HAVE__STRNICMP@ #cmakedefine HAVE_STRNCASECMP @HAVE_STRNCASECMP@ #cmakedefine HAVE_FCLOSEALL @HAVE_FCLOSEALL@ +#cmakedefine HAVE_OPENSSL @HAVE_OPENSSL@