diff --git a/binary_search_tree.h b/binary_search_tree.h index 2179f72..400d48f 100644 --- a/binary_search_tree.h +++ b/binary_search_tree.h @@ -1,5 +1,6 @@ #ifndef _UTIL_BINARY_SEARCH_TREE_H #define _UTIL_BINARY_SEARCH_TREE_H +#include #include "binary_tree.h" template @@ -28,6 +29,116 @@ inline void binary_search_tree_clear(BinarySearchTree*& root, F free_key, binary_search_tree_clear(root, std::function(free_key), std::function(free_value)); } +template +bool binary_search_tree_delete(BinarySearchTree*& root, K key, V* deleted_value, std::function comp, std::function free_key = std::function(), std::function free_value = std::function()) { + if (!root) return false; + BinarySearchTree* cur = root, *prev = nullptr; + while (cur) { + int re = comp(key, cur->data.key); + if (re == 0) { + break; + } + prev = cur; + cur = re < 0 ? cur->left : cur->right; + } + if (!cur) return false; + if (binary_tree_is_leaf(cur)) { + if (prev) prev->left == cur ? prev->left = nullptr : prev->right = nullptr; + else root = nullptr; + } else if (!cur->right) { + if (prev) prev->left == cur ? prev->left = cur->left : prev->right = cur->left; + else root = cur->left; + } else if (!cur->left) { + if (prev) prev->left == cur ? prev->left = cur->right : prev->right = cur->right; + else root = cur->right; + } else { + prev = cur; + BinarySearchTree* tmp = cur->left; + while (tmp->right) { + prev = tmp; + tmp = tmp->right; + } + K tk = tmp->data.key; + V tv = tmp->data.value; + tmp->data.key = cur->data.key; + tmp->data.value = cur->data.value; + cur->data.key = tk; + cur->data.value = tv; + cur = tmp; + prev->left == cur ? prev->left = nullptr : prev->right = nullptr; + } + if (free_key) free_key(cur->data.key); + if (deleted_value) *deleted_value = cur->data.value; + if (!deleted_value && free_value) free_value(cur->data.value); + free(cur); + return true; +} + +template +inline bool binary_search_tree_delete(BinarySearchTree*& root, K key, V* deleted_value, F comp, G free_key, H free_value) { + return binary_search_tree_delete(root, key, deleted_value, std::function(comp), std::function(free_key), std::function(free_value)); +} + +template +bool binary_search_tree_delete(BinarySearchTree*& root, K key, V* deleted_value = nullptr) { + return binary_search_tree_delete(root, key, deleted_value, [](K k1, K k2) { + return k1 == k2 ? 0 : k1 < k2 ? -1 : 1; + }, nullptr, nullptr); +} + +template +bool binary_search_tree_get(BinarySearchTree* root, K key, V& value, std::function comp) { + if (!root) return false; + BinarySearchTree* cur = root; + while (cur) { + int re = comp(key, cur->data.key); + if (re == 0) { + value = cur->data.value; + return true; + } + cur = re < 0 ? cur->left : cur->right; + } + return false; +} + +template +inline bool binary_search_tree_get(BinarySearchTree* root, K key, V& value, F comp) { + return binary_search_tree_get(root, key, value, std::function(comp)); +} + +template +bool binary_search_tree_get(BinarySearchTree* root, K key, V& value) { + return binary_search_tree_get(root, key, value, [](K k1, K k2) { + return k1 == k2 ? 0 : k1 < k2 ? -1 : 1; + }); +} + +template +BinarySearchTree* binary_search_tree_get_node(BinarySearchTree* root, K key, std::function comp) { + if (!root) return nullptr; + BinarySearchTree* cur = root; + while (cur) { + int re = comp(key, cur->data.key); + if (re == 0) { + return cur; + } + cur = re < 0 ? cur->left : cur->right; + } + return nullptr; +} + +template +inline BinarySearchTree* binary_search_tree_get_node(BinarySearchTree* root, K key, F comp) { + return binary_search_tree_get_node(root, key, std::function(comp)); +} + +template +BinarySearchTree* binary_search_tree_get_node(BinarySearchTree* root, K key) { + return binary_search_tree_get_node(root, key, [](K k1, K k2) { + return k1 == k2 ? 0 : k1 < k2 ? -1 : 1; + }); +} + template void binary_search_tree_iter(BinarySearchTree* root, std::function callback, Args... args) { binary_tree_lnr(root, [&callback](BinarySearchTree* e, Args... args) { @@ -41,14 +152,14 @@ inline void binary_search_tree_iter(BinarySearchTree* root, F callback, Ar } template -bool binary_search_tree_insert(BinarySearchTree*& root, K key, V value, std::function comp, std::function free_key = std::function(), std::function free_value = std::function()) { +BinarySearchTree* binary_search_tree_insert(BinarySearchTree*& root, K key, V value, std::function comp, std::function free_key = std::function(), std::function free_value = std::function()) { if (!root) { root = binary_tree_new>({ key, value }); if (!root) { if (free_key) free_key(key); if (free_value) free_value(value); } - return root != nullptr; + return root; } BinarySearchTree* cur = root; int re = comp(key, cur->data.key); @@ -65,32 +176,86 @@ bool binary_search_tree_insert(BinarySearchTree*& root, K key, V value, st if (free_value) free_value(cur->data.value); if (free_key) free_key(key); cur->data.value = value; - return true; + return cur; } BinarySearchTree* node = binary_tree_new>({key, value}); if (!node) { if (free_key) free_key(key); if (free_value) free_value(value); - return false; + return nullptr; } if (re == -1) { cur->left = node; } else { cur->right = node; } - return true; + return node; } template -inline bool binary_search_tree_insert(BinarySearchTree*& root, K key, V value, F comp) { +inline BinarySearchTree* binary_search_tree_insert(BinarySearchTree*& root, K key, V value, F comp) { return binary_search_tree_insert(root, key, value, std::function(comp)); } template -bool binary_search_tree_insert(BinarySearchTree*& root, K key, V value) { +BinarySearchTree* binary_search_tree_insert(BinarySearchTree*& root, K key, V value) { return binary_search_tree_insert(root, key, value, [](K k1, K k2) { return k1 == k2 ? 0 : k1 < k2 ? -1 : 1; }); } +template > +class BinarySearchMap { +private: + BinarySearchTree* tree = nullptr; + std::function comp_func; + std::function free_key_func; + std::function free_value_func; +public: + BinarySearchMap(std::function free_key = std::function(), std::function free_value = std::function(), Compare cmp = Compare()) { + comp_func = std::function([&cmp] (K k1, K k2) { + return cmp(k1, k2) ? -1 : cmp(k2, k1) ? 1 : 0; + }); + free_key_func = free_key; + free_value_func = free_value; + } + template + BinarySearchMap(F free_key, G free_value = std::function(), Compare cmp = Compare()) : BinarySearchMap(std::function(free_key), std::function(free_value), cmp) {} + ~BinarySearchMap() { + clear(); + } + void inline clear() { + binary_search_tree_clear(tree, free_key_func, free_value_func); + } + bool inline del(K key, V* deleted_value = nullptr) { + return binary_search_tree_delete(tree, key, deleted_value, comp_func, free_key_func, free_value_func); + } + bool inline get(K key, V& value) { + return binary_search_tree_get(tree, key, value, comp_func); + } + bool inline insert(K key, V value) { + return binary_search_tree_insert(tree, key, value, comp_func, free_key_func, free_value_func) != nullptr; + } + template + void inline iter(std::function callback, Args... args) { + return binary_search_tree_iter(tree, callback, args...); + } + template + void inline iter(F callback, Args... args) { + return binary_search_tree_iter(tree, std::function(callback), args...); + } + V& operator[](K key) { + BinarySearchTree* node = binary_search_tree_get_node(tree, key, comp_func); + if (!node) { + V tmp; + BinarySearchTree* node = binary_search_tree_insert(tree, key, tmp, comp_func, free_key_func, free_value_func); + if (!node) { + throw std::runtime_error("Failed to insert new node"); + } + return node->data.value; + } + return node->data.value; + } +}; + #endif diff --git a/test/binary_tree_test.cpp b/test/binary_tree_test.cpp index 71295a8..46f0efc 100644 --- a/test/binary_tree_test.cpp +++ b/test/binary_tree_test.cpp @@ -68,5 +68,96 @@ TEST(BinaryTreeTest, BinarySearchTree1) { }); GTEST_ASSERT_EQ(keys, "20,33,40,77,100,120"); GTEST_ASSERT_EQ(values, "45,23,13,33,20,222"); + int v; + GTEST_ASSERT_EQ(binary_search_tree_get(tree, 100, v), true); + GTEST_ASSERT_EQ(v, 20); + GTEST_ASSERT_EQ(binary_search_tree_get(tree, 33, v), true); + GTEST_ASSERT_EQ(v, 23); + GTEST_ASSERT_EQ(binary_search_tree_get(tree, 119, v), false); + GTEST_ASSERT_EQ(v, 23); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 77, &v), true); + GTEST_ASSERT_EQ(v, 33); + binary_search_tree_insert(tree, 27, 311); + keys = ""; + values = ""; + binary_search_tree_iter(tree, [&keys, &values](int key, int value) { + if (!keys.empty()) { + keys += ","; + values += ","; + } + keys += std::to_string(key); + values += std::to_string(value); + }); + GTEST_ASSERT_EQ(keys, "20,27,33,40,100,120"); + GTEST_ASSERT_EQ(values, "45,311,23,13,20,222"); + binary_search_tree_insert(tree, 80, 122); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 33, &v), true); + GTEST_ASSERT_EQ(v, 23); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 40, &v), true); + GTEST_ASSERT_EQ(v, 13); + keys = ""; + values = ""; + binary_search_tree_iter(tree, [&keys, &values](int key, int value) { + if (!keys.empty()) { + keys += ","; + values += ","; + } + keys += std::to_string(key); + values += std::to_string(value); + }); + GTEST_ASSERT_EQ(keys, "20,27,80,100,120"); + GTEST_ASSERT_EQ(values, "45,311,122,20,222"); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 120), true); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 120, &v), false); + GTEST_ASSERT_EQ(v, 13); + auto j = binary_search_tree_get_node(tree, 20); + GTEST_ASSERT_EQ(j->data.key, 20); binary_search_tree_clear(tree); } + +TEST(BinaryTreeTest, BinarySearchTree2) { + BinarySearchTree* tree = nullptr; + binary_search_tree_insert(tree, 100, 3); + int v; + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 100, &v), true); + GTEST_ASSERT_EQ(v, 3); + GTEST_ASSERT_EQ(tree, nullptr); + binary_search_tree_insert(tree, 2, 3); + binary_search_tree_insert(tree, 1, 2); + binary_search_tree_insert(tree, 3, 4); + GTEST_ASSERT_EQ(binary_search_tree_delete(tree, 2, &v), true); + GTEST_ASSERT_EQ(v, 3); + GTEST_ASSERT_EQ(tree->data.key, 1); + GTEST_ASSERT_EQ(tree->data.value, 2); + binary_search_tree_insert(tree, -100, 123); + binary_search_tree_insert(tree, -200, 200); + binary_search_tree_clear(tree); +} + +TEST(BinaryTreeTest, BinarySearchTree3) { + BinarySearchMap map; + map.insert(20, 30); + map.insert(30, 40); + map.insert(10, 20); + map[20] = 55; + map[33] = 44; + GTEST_ASSERT_EQ(map[20], 55); + GTEST_ASSERT_EQ(map[33], 44); + int v; + GTEST_ASSERT_EQ(map.get(33, v), true); + GTEST_ASSERT_EQ(v, 44); + GTEST_ASSERT_EQ(map.del(20, &v), true); + GTEST_ASSERT_EQ(v, 55); + std::string keys; + std::string values; + map.iter([&keys, &values](int key, int value) { + if (!keys.empty()) { + keys += ","; + values += ","; + } + keys += std::to_string(key); + values += std::to_string(value); + }); + GTEST_ASSERT_EQ(keys, "10,30,33"); + GTEST_ASSERT_EQ(values, "20,40,44"); +}