diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index da1e3f23..c03ddae6 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -4,6 +4,7 @@ set(SPIDER_CORE_SOURCES core/DriverCleaner.cpp core/JobCleaner.cpp core/Task.cpp + core/JobRecovery.cpp storage/mysql/MySqlConnection.cpp storage/mysql/MySqlStorageFactory.cpp storage/mysql/MySqlJobSubmissionBatch.cpp @@ -27,6 +28,7 @@ set(SPIDER_CORE_HEADERS core/Task.hpp core/TaskGraph.hpp core/JobMetadata.hpp + core/JobRecovery.hpp io/BoostAsio.hpp io/MsgPack.hpp io/msgpack_message.hpp diff --git a/src/spider/client/Data.hpp b/src/spider/client/Data.hpp index b329e439..b5150373 100644 --- a/src/spider/client/Data.hpp +++ b/src/spider/client/Data.hpp @@ -78,6 +78,26 @@ class Data { m_data_store->set_data_locality(*conn, *m_impl); } + /** + * Sets the data as checkpointed, indicating the data should not be cleaned up. + * + * @throw spider::ConnectionException + */ + void set_checkpointed() { + m_impl->set_persisted(true); + if (nullptr != m_connection) { + m_data_store->set_data_persisted(*m_connection, *m_impl); + return; + } + std::variant, core::StorageErr> conn_result + = m_storage_factory->provide_storage_connection(); + if (std::holds_alternative(conn_result)) { + throw ConnectionException(std::get(conn_result).description); + } + auto conn = std::move(std::get>(conn_result)); + m_data_store->set_data_persisted(*conn, *m_impl); + } + class Builder { public: /** @@ -106,6 +126,16 @@ class Data { return *this; } + /** + * Sets the data as checkpointed, indicating the data should not be cleaned up. + * + * @return self + */ + auto set_checkpointed() -> Builder& { + m_persisted = true; + return *this; + } + /** * Builds the data object. * @@ -119,6 +149,7 @@ class Data { auto data = std::make_unique(std::string{buffer.data(), buffer.size()}); data->set_locality(m_nodes); data->set_hard_locality(m_hard_locality); + data->set_persisted(m_persisted); std::shared_ptr conn = m_connection; if (nullptr == conn) { std::variant, core::StorageErr> conn_result @@ -166,6 +197,7 @@ class Data { std::vector m_nodes; bool m_hard_locality = false; std::function m_cleanup_func; + bool m_persisted = false; std::shared_ptr m_data_store; std::shared_ptr m_storage_factory; diff --git a/src/spider/core/Data.hpp b/src/spider/core/Data.hpp index b5a1d556..1903aa95 100644 --- a/src/spider/core/Data.hpp +++ b/src/spider/core/Data.hpp @@ -31,11 +31,16 @@ class Data { void set_hard_locality(bool const hard) { m_hard_locality = hard; } + void set_persisted(bool const persisted) { this->m_persisted = persisted; } + + [[nodiscard]] auto is_persisted() const -> bool { return m_persisted; } + private: boost::uuids::uuid m_id; std::string m_value; std::vector m_locality; bool m_hard_locality = false; + bool m_persisted = false; void init_id() { boost::uuids::random_generator gen; diff --git a/src/spider/core/JobRecovery.cpp b/src/spider/core/JobRecovery.cpp new file mode 100644 index 00000000..62da1de8 --- /dev/null +++ b/src/spider/core/JobRecovery.cpp @@ -0,0 +1,167 @@ +#include "JobRecovery.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace spider::core { +JobRecovery::JobRecovery( + boost::uuids::uuid const job_id, + std::shared_ptr storage_connection, + std::shared_ptr data_store, + std::shared_ptr metadata_store +) + : m_job_id{job_id}, + m_conn{std::move(storage_connection)}, + 
m_data_store{std::move(data_store)}, + m_metadata_store{std::move(metadata_store)} {} + +auto JobRecovery::compute_graph() -> StorageErr { + StorageErr err = m_metadata_store->get_task_graph(*m_conn, m_job_id, &m_task_graph); + if (false == err.success()) { + return err; + } + + for (auto const& [task_id, task] : m_task_graph.get_tasks()) { + if (TaskState::Failed == task.get_state()) { + m_task_set.insert(task_id); + m_task_queue.push_front(task_id); + } + } + + while (!m_task_queue.empty()) { + auto const task_id = m_task_queue.front(); + m_task_queue.pop_front(); + err = process_task(task_id); + if (false == err.success()) { + return err; + } + } + + return StorageErr{}; +} + +auto JobRecovery::get_data(boost::uuids::uuid data_id, Data& data) -> StorageErr { + auto it = m_data_map.find(data_id); + if (it != m_data_map.end()) { + data = it->second; + return StorageErr{}; + } + StorageErr const err = m_data_store->get_data(*m_conn, data_id, &data); + if (err.success()) { + m_data_map[data_id] = data; + } + return err; +} + +auto JobRecovery::check_task_input( + Task const& task, + absl::flat_hash_set& not_persisted +) -> StorageErr { + for (auto const& task_input : task.get_inputs()) { + std::optional optional_data_id = task_input.get_data_id(); + if (false == optional_data_id.has_value()) { + continue; + } + boost::uuids::uuid const data_id = optional_data_id.value(); + Data data; + StorageErr err = get_data(data_id, data); + if (false == err.success()) { + return err; + } + if (false == data.is_persisted()) { + std::optional> optional_parent + = task_input.get_task_output(); + if (false == optional_parent.has_value()) { + continue; + } + boost::uuids::uuid const parent_task_id = std::get<0>(optional_parent.value()); + not_persisted.insert(parent_task_id); + } + } + return StorageErr{}; +} + +auto JobRecovery::process_task(boost::uuids::uuid task_id) -> StorageErr { + std::optional const optional_task = m_task_graph.get_task(task_id); + if (false == optional_task.has_value()) { + return StorageErr{ + StorageErrType::KeyNotFoundErr, + fmt::format("No task with id {}", to_string(task_id)) + }; + } + + for (boost::uuids::uuid const& child_id : m_task_graph.get_child_tasks(task_id)) { + if (m_task_set.contains(child_id)) { + continue; + } + std::optional optional_child_task = m_task_graph.get_task(child_id); + if (false == optional_child_task.has_value()) { + return StorageErr{ + StorageErrType::KeyNotFoundErr, + fmt::format("No task with id {}", to_string(child_id)) + }; + } + Task const& child_task = *optional_child_task.value(); + if (TaskState::Pending != child_task.get_state()) { + m_task_queue.push_back(child_id); + m_task_set.insert(child_id); + } + } + + Task const& task = *optional_task.value(); + absl::flat_hash_set not_persisted; + StorageErr err = check_task_input(task, not_persisted); + if (false == err.success()) { + return err; + } + + if (not_persisted.empty()) { + m_ready_tasks.insert(task_id); + } else { + m_pending_tasks.insert(task_id); + for (auto const& parent_id : not_persisted) { + if (false == m_task_set.contains(parent_id)) { + m_task_queue.push_back(parent_id); + m_task_set.insert(parent_id); + } + } + } + + return StorageErr{}; +} + +auto JobRecovery::get_pending_tasks() const -> std::vector { + std::vector pending_tasks; + pending_tasks.reserve(m_pending_tasks.size()); + for (auto const& task_id : m_pending_tasks) { + pending_tasks.push_back(task_id); + } + return pending_tasks; +} + +auto JobRecovery::get_ready_tasks() const -> std::vector { + std::vector 
ready_tasks; + ready_tasks.reserve(m_ready_tasks.size()); + for (auto const& task_id : m_ready_tasks) { + ready_tasks.push_back(task_id); + } + return ready_tasks; +} +} // namespace spider::core diff --git a/src/spider/core/JobRecovery.hpp b/src/spider/core/JobRecovery.hpp new file mode 100644 index 00000000..bfe51b41 --- /dev/null +++ b/src/spider/core/JobRecovery.hpp @@ -0,0 +1,95 @@ +#ifndef SPIDER_CORE_JOBRECOVERY_HPP +#define SPIDER_CORE_JOBRECOVERY_HPP + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace spider::core { +class JobRecovery { +public: + JobRecovery( + boost::uuids::uuid job_id, + std::shared_ptr storage_connection, + std::shared_ptr data_store, + std::shared_ptr metadata_store + ); + + /** + * Recovers the job by loading the task graph and data from storage and + * computing the minimal subgraph that contains all the failed tasks, such + * that all data crossing into the subgraph is persisted. + * The result is stored in m_ready_tasks and m_pending_tasks, where + * m_ready_tasks contains the tasks on the boundary of the subgraph, and + * m_pending_tasks contains the tasks that are not ready to run yet. + * @return StorageErr + */ + auto compute_graph() -> StorageErr; + + [[nodiscard]] auto get_ready_tasks() const -> std::vector; + + [[nodiscard]] auto get_pending_tasks() const -> std::vector; + +private: + /** + * Checks whether any parent tasks feed non-persisted Data into the given task. + * @param task + * @param not_persisted Returns the parents that feed non-persisted Data into the task. + * @return + */ + auto check_task_input(Task const& task, absl::flat_hash_set& not_persisted) + -> StorageErr; + + /** + * Gets the data associated with the given data_id. If the data is cached in + * m_data_map, return it. Otherwise, fetch it from the data store and cache + * it. + * @param data_id + * @param data + * @return + */ + auto get_data(boost::uuids::uuid data_id, Data& data) -> StorageErr; + + /* + * Processes the task from the task queue with the given task_id. + * 1. Add the non-pending children of the task to the working queue. + * 2. Check if its inputs contain non-persisted Data. + * 3. If the task has non-persisted Data inputs and has parents, add it to the pending tasks and add + * the parents that produce the non-persisted Data to the working queue. + * 4. Otherwise, add it to the ready tasks.
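+ * For example, a failed task whose inputs are all constants or persisted Data goes straight to the ready set, while a failed task fed by non-persisted Data becomes pending and the parent producing that Data is enqueued for processing.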
+ * + * @param task_id + * @return StorageErr + */ + auto process_task(boost::uuids::uuid task_id) -> StorageErr; + + boost::uuids::uuid m_job_id; + + std::shared_ptr m_conn; + std::shared_ptr m_data_store; + std::shared_ptr m_metadata_store; + + absl::flat_hash_map m_data_map; + + TaskGraph m_task_graph; + + absl::flat_hash_set m_task_set; + std::deque m_task_queue; + absl::flat_hash_set m_ready_tasks; + absl::flat_hash_set m_pending_tasks; +}; +} // namespace spider::core + +#endif diff --git a/src/spider/scheduler/scheduler.cpp b/src/spider/scheduler/scheduler.cpp index 8629f3b2..8c8f4011 100644 --- a/src/spider/scheduler/scheduler.cpp +++ b/src/spider/scheduler/scheduler.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -17,11 +18,13 @@ #include #include #include +#include #include // IWYU pragma: keep #include #include #include +#include #include // IWYU pragma: keep #include #include @@ -39,9 +42,11 @@ constexpr int cStorageConnectionErr = 3; constexpr int cSchedulerAddrErr = 4; constexpr int cStorageErr = 5; -constexpr int cCleanupInterval = 1000; +constexpr int cCleanupInterval = 10; constexpr int cRetryCount = 5; +constexpr int cRecoveryInterval = 1; + namespace { /* * Signal handler for SIGTERM. Sets the stop flag to request a stop. @@ -145,6 +150,67 @@ auto cleanup_loop( } } +auto recovery_loop( + std::shared_ptr const& storage_factory, + std::shared_ptr const& metadata_store, + std::shared_ptr const& data_store +) -> void { + while (!spider::core::StopFlag::is_stop_requested()) { + std::this_thread::sleep_for(std::chrono::seconds(cRecoveryInterval)); + spdlog::debug("Starting recovery"); + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + if (std::holds_alternative(conn_result)) { + spdlog::error( + "Failed to connect to storage: {}", + std::get(conn_result).description + ); + continue; + } + + std::shared_ptr const conn = std::move( + std::get>(conn_result) + ); + + std::vector job_ids; + spider::core::StorageErr err = metadata_store->get_failed_jobs(*conn, &job_ids); + if (false == err.success()) { + spdlog::error("Failed to get failed jobs: {}", err.description); + continue; + } + if (job_ids.empty()) { + spdlog::debug("No failed jobs found"); + continue; + } + for (boost::uuids::uuid const& job_id : job_ids) { + spdlog::debug("Recovering job: {}", to_string(job_id)); + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + err = recovery.compute_graph(); + if (false == err.success()) { + spdlog::error( + "Failed to compute graph for job {}: {}", + to_string(job_id), + err.description + ); + continue; + } + err = metadata_store->reset_tasks( + *conn, + recovery.get_ready_tasks(), + recovery.get_pending_tasks() + ); + if (false == err.success()) { + spdlog::error( + "Failed to reset tasks for job {}: {}", + to_string(job_id), + err.description + ); + continue; + } + } + } +} + constexpr int cSignalExitBase = 128; } // namespace @@ -261,8 +327,16 @@ auto main(int argc, char** argv) -> int { // Start a thread that periodically starts cleanup std::thread cleanup_thread{cleanup_loop, std::cref(storage_factory), std::cref(data_store)}; + std::thread recovery_thread{ + recovery_loop, + std::cref(storage_factory), + std::cref(metadata_store), + std::cref(data_store) + }; + heartbeat_thread.join(); cleanup_thread.join(); + recovery_thread.join(); server.stop(); } catch (std::system_error& e) { spdlog::error("Failed to join thread: {}", e.what()); diff --git 
a/src/spider/storage/DataStorage.hpp b/src/spider/storage/DataStorage.hpp index e2be1984..553005be 100644 --- a/src/spider/storage/DataStorage.hpp +++ b/src/spider/storage/DataStorage.hpp @@ -65,6 +65,7 @@ class DataStorage { ) -> StorageErr = 0; virtual auto set_data_locality(StorageConnection& conn, Data const& data) -> StorageErr = 0; + virtual auto set_data_persisted(StorageConnection& conn, Data const& data) -> StorageErr = 0; virtual auto remove_data(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr = 0; virtual auto add_task_reference(StorageConnection& conn, boost::uuids::uuid id, boost::uuids::uuid task_id) diff --git a/src/spider/storage/MetadataStorage.hpp b/src/spider/storage/MetadataStorage.hpp index 39a3e437..bd7b632e 100644 --- a/src/spider/storage/MetadataStorage.hpp +++ b/src/spider/storage/MetadataStorage.hpp @@ -78,6 +78,28 @@ class MetadataStorage { virtual auto remove_job(StorageConnection& conn, boost::uuids::uuid id) noexcept -> StorageErr = 0; virtual auto reset_job(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr = 0; + /** + * Gets all the jobs that contain a failed task. + * @param conn The storage connection. + * @param job_ids Returns the ids of the jobs that contain a failed task. + * @return The storage error code. + */ + virtual auto get_failed_jobs(StorageConnection& conn, std::vector* job_ids) + -> StorageErr + = 0; + /** + * Resets tasks in a job to their previous runnable states. + * @param conn The storage connection. + * @param ready_tasks The tasks to be set to ready. + * @param pending_tasks The tasks to be set to pending. + * @return The storage error code. + */ + virtual auto reset_tasks( + StorageConnection& conn, + std::vector const& ready_tasks, + std::vector const& pending_tasks + ) -> StorageErr + = 0; virtual auto add_child(StorageConnection& conn, boost::uuids::uuid parent_id, Task const& child) -> StorageErr = 0; diff --git a/src/spider/storage/mysql/MySqlStorage.cpp b/src/spider/storage/mysql/MySqlStorage.cpp index a7c321d9..7fe7049d 100644 --- a/src/spider/storage/mysql/MySqlStorage.cpp +++ b/src/spider/storage/mysql/MySqlStorage.cpp @@ -1166,6 +1166,142 @@ auto MySqlMetadataStorage::reset_job(StorageConnection& conn, boost::uuids::uuid return StorageErr{}; } +auto MySqlMetadataStorage::get_failed_jobs( + StorageConnection& conn, + std::vector* job_ids +) -> StorageErr { + try { + std::unique_ptr statement{ + static_cast(conn)->createStatement() + }; + // Every task should get at least one retry, even when `max_retry` is 0 + std::unique_ptr result{statement->executeQuery( + "SELECT `job_id` FROM `tasks` WHERE `state` = 'failed' AND (`retry` = 0 OR " + "`retry` < `max_retry`)" + )}; + job_ids->reserve(result->rowsCount()); + while (result->next()) { + job_ids->emplace_back(read_id(result->getBinaryStream("job_id"))); + } + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + +auto MySqlMetadataStorage::reset_tasks( + StorageConnection& conn, + std::vector const& ready_tasks, + std::vector const& pending_tasks +) -> StorageErr { + try { + // Reset ready tasks and update retry count + std::unique_ptr ready_statement( + static_cast(conn)->prepareStatement( + "UPDATE `tasks` SET `state` = 'ready', `retry` = `retry` + 1 WHERE `id` " + "= ?"
+ ) + ); + for (boost::uuids::uuid const& id : ready_tasks) { + sql::bytes id_bytes = uuid_get_bytes(id); + ready_statement->setBytes(1, &id_bytes); + ready_statement->addBatch(); + } + ready_statement->executeBatch(); + // Reset pending tasks + std::unique_ptr pending_statement( + static_cast(conn)->prepareStatement( + "UPDATE `tasks` SET `state` = 'pending', `retry` = `retry` + 1 WHERE `id` " + "= ?" + ) + ); + for (boost::uuids::uuid const& id : pending_tasks) { + sql::bytes id_bytes = uuid_get_bytes(id); + pending_statement->setBytes(1, &id_bytes); + pending_statement->addBatch(); + } + pending_statement->executeBatch(); + // Clear all the task outputs + std::unique_ptr output_statement( + static_cast(conn)->prepareStatement( + "UPDATE `task_outputs` SET `value` = NULL, `data_id` = NULL WHERE " + "`task_id` = ?" + ) + ); + for (boost::uuids::uuid const& id : ready_tasks) { + sql::bytes id_bytes = uuid_get_bytes(id); + output_statement->setBytes(1, &id_bytes); + output_statement->addBatch(); + } + for (boost::uuids::uuid const& id : pending_tasks) { + sql::bytes id_bytes = uuid_get_bytes(id); + output_statement->setBytes(1, &id_bytes); + output_statement->addBatch(); + } + output_statement->executeBatch(); + // Clear the task input values and data for pending tasks + std::unique_ptr input_statement( + static_cast(conn)->prepareStatement( + "UPDATE `task_inputs` SET `value` = NULL, `data_id` = NULL WHERE `task_id` " + "= ?" + ) + ); + for (boost::uuids::uuid const& id : pending_tasks) { + sql::bytes id_bytes = uuid_get_bytes(id); + input_statement->setBytes(1, &id_bytes); + input_statement->addBatch(); + } + input_statement->executeBatch(); + // Mark the data as not persisted if it is only referenced by the ready and pending tasks. + // 1. Get the list of data that are persisted and referenced by a task. + // 2. Filter out the data that is referenced by the driver or by other tasks. + // 3. Mark the remaining data as not persisted. + std::unique_ptr get_data_statement( + static_cast(conn)->prepareStatement( + "SELECT `data`.`id`, `data_ref_task`.`task_id` FROM `data` JOIN " + "`data_ref_task` ON `data`.`id` = `data_ref_task`.`id` WHERE " + "`data`.`persisted` = 1" + ) + ); + std::unique_ptr const data_res(get_data_statement->executeQuery()); + absl::flat_hash_set data_ids; + absl::flat_hash_set remove_data_ids; + while (data_res->next()) { + boost::uuids::uuid const data_id = read_id(data_res->getBinaryStream("id")); + boost::uuids::uuid const task_id = read_id(data_res->getBinaryStream("task_id")); + data_ids.insert(data_id); + if (std::ranges::find(ready_tasks, task_id) == ready_tasks.end() + && std::ranges::find(pending_tasks, task_id) == pending_tasks.end()) + { + remove_data_ids.insert(data_id); + } + } + for (boost::uuids::uuid const& id : remove_data_ids) { + data_ids.erase(id); + } + if (!data_ids.empty()) { + std::unique_ptr set_data_statement( + static_cast(conn)->prepareStatement( + "UPDATE `data` SET `persisted` = 0 WHERE `id` = ?"
+ ) + ); + for (boost::uuids::uuid const& id : data_ids) { + sql::bytes id_bytes = uuid_get_bytes(id); + set_data_statement->setBytes(1, &id_bytes); + set_data_statement->addBatch(); + } + set_data_statement->executeBatch(); + } + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + auto MySqlMetadataStorage::add_child( StorageConnection& conn, boost::uuids::uuid parent_id, @@ -1639,10 +1775,10 @@ auto MySqlMetadataStorage::task_finish( std::unique_ptr ready_statement( static_cast(conn)->prepareStatement( "UPDATE `tasks` SET `state` = 'ready' WHERE `id` IN (SELECT `task_id` FROM " - "`task_inputs` WHERE `output_task_id` = ?) AND `state` = 'pending' AND NOT " - "EXISTS (SELECT `task_id` FROM `task_inputs` WHERE `task_id` IN (SELECT " - "`task_id` FROM `task_inputs` WHERE `output_task_id` = ?) AND `value` IS " - "NULL AND `data_id` IS NULL)" + "`task_inputs` WHERE `output_task_id` = ?) AND `state` = 'pending' AND " + "`id` NOT IN (SELECT `task_id` FROM `task_inputs` WHERE `task_id` IN " + "(SELECT `task_id` FROM `task_inputs` WHERE `output_task_id` = ?) AND " + "`value` IS NULL AND `data_id` IS NULL)" ) ); ready_statement->setBytes(1, &task_id_bytes); @@ -1710,15 +1846,33 @@ auto MySqlMetadataStorage::task_fail( ); task_statement->setBytes(1, &task_id_bytes); task_statement->executeUpdate(); - // Set the job fails - std::unique_ptr const job_statement( + // Check whether the task has run out of retries + std::unique_ptr const retry_statement( static_cast(conn)->prepareStatement( - "UPDATE `jobs` SET `state` = 'fail' WHERE `id` = (SELECT `job_id` FROM " - "`tasks` WHERE `id` = ?)" + "SELECT `retry`, `max_retry` FROM `tasks` WHERE `id` = ?"
) ); - job_statement->setBytes(1, &task_id_bytes); - job_statement->executeUpdate(); + retry_statement->setBytes(1, &task_id_bytes); + std::unique_ptr const retry_res{retry_statement->executeQuery()}; + if (retry_res->rowsCount() == 0) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::KeyNotFoundErr, "Task not found"}; + } + retry_res->next(); + int32_t const retry = retry_res->getInt("retry"); + int32_t const max_retry = retry_res->getInt("max_retry"); + if (retry == 0 || retry >= max_retry) { + // Set the job as failed + std::unique_ptr const job_statement( + static_cast(conn)->prepareStatement( + "UPDATE `jobs` SET `state` = 'fail' WHERE `id` = (SELECT `job_id` " + "FROM " + "`tasks` WHERE `id` = ?)" + ) + ); + job_statement->setBytes(1, &task_id_bytes); + job_statement->executeUpdate(); + } } } catch (sql::SQLException& e) { spdlog::error("Task fail error: {}", e.what()); @@ -2010,13 +2164,15 @@ auto MySqlDataStorage::add_driver_data( try { std::unique_ptr statement( static_cast(conn)->prepareStatement( - "INSERT INTO `data` (`id`, `value`, `hard_locality`) VALUES(?, ?, ?)" + "INSERT INTO `data` (`id`, `value`, `hard_locality`, `persisted`) " + "VALUES(?, ?, ?, ?)" ) ); sql::bytes id_bytes = uuid_get_bytes(data.get_id()); statement->setBytes(1, &id_bytes); statement->setString(2, data.get_value()); statement->setBoolean(3, data.is_hard_locality()); + statement->setBoolean(4, data.is_persisted()); statement->executeUpdate(); for (std::string const& addr : data.get_locality()) { @@ -2058,13 +2214,15 @@ auto MySqlDataStorage::add_task_data( try { std::unique_ptr statement( static_cast(conn)->prepareStatement( - "INSERT INTO `data` (`id`, `value`, `hard_locality`) VALUES(?, ?, ?)" + "INSERT INTO `data` (`id`, `value`, `hard_locality`, `persisted`) " + "VALUES(?, ?, ?, ?)" ) ); sql::bytes id_bytes = uuid_get_bytes(data.get_id()); statement->setBytes(1, &id_bytes); statement->setString(2, data.get_value()); statement->setBoolean(3, data.is_hard_locality()); + statement->setBoolean(4, data.is_persisted()); statement->executeUpdate(); for (std::string const& addr : data.get_locality()) { @@ -2104,7 +2262,7 @@ auto MySqlDataStorage::get_data_with_locality( ) -> StorageErr { std::unique_ptr statement( static_cast(conn)->prepareStatement( - "SELECT `id`, `value`, `hard_locality` FROM `data` WHERE `id` = ?" + "SELECT `id`, `value`, `hard_locality`, `persisted` FROM `data` WHERE `id` = ?" ) ); sql::bytes id_bytes = uuid_get_bytes(id); @@ -2120,6 +2278,7 @@ auto MySqlDataStorage::get_data_with_locality( res->next(); *data = Data{id, get_sql_string(res->getString(2))}; data->set_hard_locality(res->getBoolean(3)); + data->set_persisted(res->getBoolean(4)); std::unique_ptr locality_statement( static_cast(conn)->prepareStatement( @@ -2249,6 +2408,25 @@ auto MySqlDataStorage::set_data_locality(StorageConnection& conn, Data const& da return StorageErr{}; } +auto MySqlDataStorage::set_data_persisted(StorageConnection& conn, Data const& data) -> StorageErr { + try { + sql::bytes id_bytes = uuid_get_bytes(data.get_id()); + std::unique_ptr statement( + static_cast(conn)->prepareStatement( + "UPDATE `data` SET `persisted` = ? WHERE `id` = ?"
+ ) + ); + statement->setBoolean(1, data.is_persisted()); + statement->setBytes(2, &id_bytes); + statement->executeUpdate(); + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + auto MySqlDataStorage::remove_data(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr { try { std::unique_ptr statement( diff --git a/src/spider/storage/mysql/MySqlStorage.hpp b/src/spider/storage/mysql/MySqlStorage.hpp index 90808dac..e185c24d 100644 --- a/src/spider/storage/mysql/MySqlStorage.hpp +++ b/src/spider/storage/mysql/MySqlStorage.hpp @@ -75,6 +75,13 @@ class MySqlMetadataStorage : public MetadataStorage { ) -> StorageErr override; auto remove_job(StorageConnection& conn, boost::uuids::uuid id) noexcept -> StorageErr override; auto reset_job(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr override; + auto get_failed_jobs(StorageConnection& conn, std::vector* job_ids) + -> StorageErr override; + auto reset_tasks( + StorageConnection& conn, + std::vector const& ready_tasks, + std::vector const& pending_tasks + ) -> StorageErr override; auto add_child(StorageConnection& conn, boost::uuids::uuid parent_id, Task const& child) -> StorageErr override; auto get_task(StorageConnection& conn, boost::uuids::uuid id, Task* task) @@ -163,6 +170,7 @@ class MySqlDataStorage : public DataStorage { Data* data ) -> StorageErr override; auto set_data_locality(StorageConnection& conn, Data const& data) -> StorageErr override; + auto set_data_persisted(StorageConnection& conn, Data const& data) -> StorageErr override; auto remove_data(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr override; auto add_task_reference(StorageConnection& conn, boost::uuids::uuid id, boost::uuids::uuid task_id) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a04b1c58..d273c154 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,7 @@ set(SPIDER_TEST_SOURCES worker/test-TaskExecutor.cpp worker/test-Process.cpp io/test-MsgpackMessage.cpp + scheduler/test-JobRecovery.cpp scheduler/test-SchedulerPolicy.cpp scheduler/test-SchedulerServer.cpp client/test-Driver.cpp diff --git a/tests/scheduler/test-JobRecovery.cpp b/tests/scheduler/test-JobRecovery.cpp new file mode 100644 index 00000000..9d27bdab --- /dev/null +++ b/tests/scheduler/test-JobRecovery.cpp @@ -0,0 +1,517 @@ +// NOLINTBEGIN(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-analyzer-optin.core.EnumCastOutOfRange) + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { +TEMPLATE_LIST_TEST_CASE("Recovery single task", "[storage]", spider::test::StorageFactoryTypeList) { + std::shared_ptr const storage_factory + = spider::test::create_storage_factory(); + std::shared_ptr const metadata_store + = storage_factory->provide_metadata_storage(); + std::shared_ptr const data_store + = storage_factory->provide_data_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + std::shared_ptr const conn + = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + + 
boost::uuids::uuid const job_id = gen(); + boost::uuids::uuid const client_id = gen(); + // Submit a task with no data inputs + spider::core::Task task{"task"}; + REQUIRE(metadata_store->add_driver(*conn, spider::core::Driver{client_id}).success()); + task.add_input(spider::core::TaskInput{"10", "int"}); + task.add_output(spider::core::TaskOutput{"int"}); + spider::core::TaskGraph graph; + graph.add_task(task); + graph.add_input_task(task.get_id()); + graph.add_output_task(task.get_id()); + REQUIRE(metadata_store->add_job(*conn, job_id, client_id, graph).success()); + + // Set task as failed + REQUIRE(metadata_store->set_task_state(*conn, task.get_id(), spider::core::TaskState::Failed) + .success()); + + // Recover the job + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + REQUIRE(recovery.compute_graph().success()); + auto const ready_tasks = recovery.get_ready_tasks(); + auto const pending_tasks = recovery.get_pending_tasks(); + REQUIRE(ready_tasks.size() == 1); + REQUIRE(pending_tasks.empty()); + REQUIRE(ready_tasks[0] == task.get_id()); + + REQUIRE(metadata_store->remove_job(*conn, job_id).success()); + REQUIRE(metadata_store->remove_driver(*conn, client_id).success()); +} + +TEMPLATE_LIST_TEST_CASE( + "Recovery single task with data", + "[storage]", + spider::test::StorageFactoryTypeList +) { + std::shared_ptr const storage_factory + = spider::test::create_storage_factory(); + std::shared_ptr const metadata_store + = storage_factory->provide_metadata_storage(); + std::shared_ptr const data_store + = storage_factory->provide_data_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + std::shared_ptr const conn + = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + + boost::uuids::uuid const job_id = gen(); + boost::uuids::uuid const client_id = gen(); + // Submit a task with a non-persisted data input + spider::core::Task task{"task"}; + spider::core::Data const data{"data"}; + REQUIRE(metadata_store->add_driver(*conn, spider::core::Driver{client_id}).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data).success()); + task.add_input(spider::core::TaskInput{data.get_id()}); + task.add_output(spider::core::TaskOutput{"int"}); + spider::core::TaskGraph graph; + graph.add_task(task); + graph.add_input_task(task.get_id()); + graph.add_output_task(task.get_id()); + REQUIRE(metadata_store->add_job(*conn, job_id, client_id, graph).success()); + + // Set task as failed + REQUIRE(metadata_store->set_task_state(*conn, task.get_id(), spider::core::TaskState::Failed) + .success()); + + // Recover the job + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + REQUIRE(recovery.compute_graph().success()); + auto const ready_tasks = recovery.get_ready_tasks(); + auto const pending_tasks = recovery.get_pending_tasks(); + REQUIRE(ready_tasks.size() == 1); + REQUIRE(pending_tasks.empty()); + REQUIRE(ready_tasks[0] == task.get_id()); + + REQUIRE(metadata_store->remove_job(*conn, job_id).success()); + REQUIRE(data_store->remove_data(*conn, data.get_id()).success()); + REQUIRE(metadata_store->remove_driver(*conn, client_id).success()); +} + +TEMPLATE_LIST_TEST_CASE( + "Recovery single task with persisted data", + "[storage]", + spider::test::StorageFactoryTypeList +) { + std::shared_ptr const storage_factory + = spider::test::create_storage_factory(); + std::shared_ptr const metadata_store + =
storage_factory->provide_metadata_storage(); + std::shared_ptr const data_store + = storage_factory->provide_data_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + std::shared_ptr const conn + = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + + boost::uuids::uuid const job_id = gen(); + boost::uuids::uuid const client_id = gen(); + // Submit a task with a persisted data input + spider::core::Task task{"task"}; + spider::core::Data data{"data"}; + data.set_persisted(true); + REQUIRE(metadata_store->add_driver(*conn, spider::core::Driver{client_id}).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data).success()); + task.add_input(spider::core::TaskInput{data.get_id()}); + task.add_output(spider::core::TaskOutput{"int"}); + spider::core::TaskGraph graph; + graph.add_task(task); + graph.add_input_task(task.get_id()); + graph.add_output_task(task.get_id()); + REQUIRE(metadata_store->add_job(*conn, job_id, client_id, graph).success()); + + // Set task as failed + REQUIRE(metadata_store->set_task_state(*conn, task.get_id(), spider::core::TaskState::Failed) + .success()); + + // Recover the job + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + REQUIRE(recovery.compute_graph().success()); + auto const ready_tasks = recovery.get_ready_tasks(); + auto const pending_tasks = recovery.get_pending_tasks(); + REQUIRE(ready_tasks.size() == 1); + REQUIRE(pending_tasks.empty()); + REQUIRE(ready_tasks[0] == task.get_id()); + + REQUIRE(metadata_store->remove_job(*conn, job_id).success()); + REQUIRE(data_store->remove_data(*conn, data.get_id()).success()); + REQUIRE(metadata_store->remove_driver(*conn, client_id).success()); +} + +/** + * Recovers a job with multiple tasks.
The task graph is: + * \dot + * digraph task_graph { + * node [shape="rect"]; + * 1 [color="green"]; + * 2 [color="green"]; + * 3 [color="green"]; + * 4 [color="green"]; + * 6 [color="blue"]; + * 7 [color="blue"]; + * 1 -> 3; + * 1 -> 4; + * 2 -> 4; + * 2 -> 5; + * 3 -> 6 [style="dashed"]; + * 4 -> 7 [style="dashed"]; + * subgraph cluster_recovery { + * style=filled; + * color=yellow; + * 5 [color="green"]; + * 8 [color="red"]; + * 5 -> 8 [style="dashed"]; + * } + * } + * \enddot + */ +TEMPLATE_LIST_TEST_CASE( + "Recovery multiple tasks", + "[storage]", + spider::test::StorageFactoryTypeList +) { + std::shared_ptr const storage_factory + = spider::test::create_storage_factory(); + std::shared_ptr const metadata_store + = storage_factory->provide_metadata_storage(); + std::shared_ptr const data_store + = storage_factory->provide_data_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + std::shared_ptr const conn + = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + + boost::uuids::uuid const job_id = gen(); + boost::uuids::uuid const client_id = gen(); + // Build task graph with multiple tasks + spider::core::Task task1{"task1"}; + task1.add_input(spider::core::TaskInput{"10", "int"}); + spider::core::Data data1{"data1"}; + data1.set_persisted(true); + task1.add_output(spider::core::TaskOutput{data1.get_id()}); + spider::core::Task task2{"task2"}; + task2.add_input(spider::core::TaskInput{"10", "int"}); + spider::core::Data data2{"data2"}; + data2.set_persisted(true); + task2.add_output(spider::core::TaskOutput{data2.get_id()}); + spider::core::Task task3{"task3"}; + task3.add_input(spider::core::TaskInput{task1.get_id(), 0, ""}); + spider::core::Data const data3{"data3"}; + task3.add_output(spider::core::TaskOutput{data3.get_id()}); + spider::core::Task task4{"task4"}; + task4.add_input(spider::core::TaskInput{task1.get_id(), 0, ""}); + task4.add_input(spider::core::TaskInput{task2.get_id(), 0, ""}); + spider::core::Data const data4{"data4"}; + task4.add_output(spider::core::TaskOutput{data4.get_id()}); + spider::core::Task task5{"task5"}; + task5.add_input(spider::core::TaskInput{task2.get_id(), 0, ""}); + spider::core::Data const data5{"data5"}; + task5.add_output(spider::core::TaskOutput{data5.get_id()}); + spider::core::Task task6{"task6"}; + task6.add_input(spider::core::TaskInput{task3.get_id(), 0, ""}); + task6.add_output(spider::core::TaskOutput{"int"}); + spider::core::Task task7{"task7"}; + task7.add_input(spider::core::TaskInput{task4.get_id(), 0, ""}); + task7.add_output(spider::core::TaskOutput{"int"}); + spider::core::Task task8{"task8"}; + task8.add_input(spider::core::TaskInput{task5.get_id(), 0, ""}); + task8.add_output(spider::core::TaskOutput{"int"}); + spider::core::TaskGraph graph; + graph.add_task(task1); + graph.add_task(task2); + graph.add_task(task3); + graph.add_task(task4); + graph.add_task(task5); + graph.add_task(task6); + graph.add_task(task7); + graph.add_task(task8); + graph.add_input_task(task1.get_id()); + graph.add_input_task(task2.get_id()); + graph.add_output_task(task6.get_id()); + graph.add_output_task(task7.get_id()); + graph.add_output_task(task8.get_id()); + graph.add_dependency(task1.get_id(), task3.get_id()); + graph.add_dependency(task1.get_id(), task4.get_id()); + graph.add_dependency(task2.get_id(), task4.get_id()); + graph.add_dependency(task2.get_id(), task5.get_id()); + 
graph.add_dependency(task3.get_id(), task6.get_id()); + graph.add_dependency(task4.get_id(), task7.get_id()); + graph.add_dependency(task5.get_id(), task8.get_id()); + REQUIRE(metadata_store->add_driver(*conn, spider::core::Driver{client_id}).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data1).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data2).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data3).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data4).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data5).success()); + REQUIRE(metadata_store->add_job(*conn, job_id, client_id, graph).success()); + REQUIRE(metadata_store->set_task_running(*conn, task1.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task1.get_id()}, + {spider::core::TaskOutput{data1.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task2.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task2.get_id()}, + {spider::core::TaskOutput{data2.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task3.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task3.get_id()}, + {spider::core::TaskOutput{data3.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task4.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task4.get_id()}, + {spider::core::TaskOutput{data4.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task5.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task5.get_id()}, + {spider::core::TaskOutput{data5.get_id()}} + ) + .success()); + + REQUIRE(metadata_store->set_task_state(*conn, task8.get_id(), spider::core::TaskState::Failed) + .success()); + + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + REQUIRE(recovery.compute_graph().success()); + auto const ready_tasks = recovery.get_ready_tasks(); + auto const pending_tasks = recovery.get_pending_tasks(); + REQUIRE(ready_tasks.size() == 1); + REQUIRE(ready_tasks[0] == task5.get_id()); + REQUIRE(pending_tasks.size() == 1); + REQUIRE(pending_tasks[0] == task8.get_id()); + + REQUIRE(metadata_store->remove_job(*conn, job_id).success()); + REQUIRE(data_store->remove_data(*conn, data1.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data2.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data3.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data4.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data5.get_id()).success()); + REQUIRE(metadata_store->remove_driver(*conn, client_id).success()); +} + +/** + * Recovers a job with multiple tasks. 
The task graph is: + * \dot + * digraph task_graph { + * node [shape="rect"]; + * 1 [color="green"]; + * 3 [color="green"]; + * 6 [color="blue"]; + * 7 [color="blue"]; + * 1 -> 3; + * 1 -> 4; + * 3 -> 6 [style="dashed"]; + * 4 -> 7 [style="dashed"]; + * subgraph cluster_recovery { + * style=filled; + * color=yellow; + * 2 [color="green"]; + * 4 [color="blue"] + * 5 [color="green"]; + * 8 [color="red"]; + * 2 -> 4 [style="dashed"]; + * 2 -> 5 [style="dashed"]; + * 5 -> 8 [style="dashed"]; + * } + * } + * \enddot + */ +TEMPLATE_LIST_TEST_CASE( + "Recovery multiple tasks with children", + "[storage]", + spider::test::StorageFactoryTypeList +) { + std::shared_ptr const storage_factory + = spider::test::create_storage_factory(); + std::shared_ptr const metadata_store + = storage_factory->provide_metadata_storage(); + std::shared_ptr const data_store + = storage_factory->provide_data_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + std::shared_ptr const conn + = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + + boost::uuids::uuid const job_id = gen(); + boost::uuids::uuid const client_id = gen(); + // Build task graph with multiple tasks + spider::core::Task task1{"task1"}; + task1.add_input(spider::core::TaskInput{"10", "int"}); + spider::core::Data data1{"data1"}; + data1.set_persisted(true); + task1.add_output(spider::core::TaskOutput{data1.get_id()}); + spider::core::Task task2{"task2"}; + task2.add_input(spider::core::TaskInput{"10", "int"}); + spider::core::Data const data2{"data2"}; + task2.add_output(spider::core::TaskOutput{data2.get_id()}); + spider::core::Task task3{"task3"}; + task3.add_input(spider::core::TaskInput{task1.get_id(), 0, ""}); + spider::core::Data const data3{"data3"}; + task3.add_output(spider::core::TaskOutput{data3.get_id()}); + spider::core::Task task4{"task4"}; + task4.add_input(spider::core::TaskInput{task1.get_id(), 0, ""}); + task4.add_input(spider::core::TaskInput{task2.get_id(), 0, ""}); + spider::core::Data const data4{"data4"}; + task4.add_output(spider::core::TaskOutput{data4.get_id()}); + spider::core::Task task5{"task5"}; + task5.add_input(spider::core::TaskInput{task2.get_id(), 0, ""}); + spider::core::Data const data5{"data5"}; + task5.add_output(spider::core::TaskOutput{data5.get_id()}); + spider::core::Task task6{"task6"}; + task6.add_input(spider::core::TaskInput{task3.get_id(), 0, ""}); + task6.add_output(spider::core::TaskOutput{"int"}); + spider::core::Task task7{"task7"}; + task7.add_input(spider::core::TaskInput{task4.get_id(), 0, ""}); + task7.add_output(spider::core::TaskOutput{"int"}); + spider::core::Task task8{"task8"}; + task8.add_input(spider::core::TaskInput{task5.get_id(), 0, ""}); + task8.add_output(spider::core::TaskOutput{"int"}); + spider::core::TaskGraph graph; + graph.add_task(task1); + graph.add_task(task2); + graph.add_task(task3); + graph.add_task(task4); + graph.add_task(task5); + graph.add_task(task6); + graph.add_task(task7); + graph.add_task(task8); + graph.add_input_task(task1.get_id()); + graph.add_input_task(task2.get_id()); + graph.add_output_task(task6.get_id()); + graph.add_output_task(task7.get_id()); + graph.add_output_task(task8.get_id()); + graph.add_dependency(task1.get_id(), task3.get_id()); + graph.add_dependency(task1.get_id(), task4.get_id()); + graph.add_dependency(task2.get_id(), task4.get_id()); + graph.add_dependency(task2.get_id(), task5.get_id()); + 
graph.add_dependency(task3.get_id(), task6.get_id()); + graph.add_dependency(task4.get_id(), task7.get_id()); + graph.add_dependency(task5.get_id(), task8.get_id()); + REQUIRE(metadata_store->add_driver(*conn, spider::core::Driver{client_id}).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data1).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data2).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data3).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data4).success()); + REQUIRE(data_store->add_driver_data(*conn, client_id, data5).success()); + REQUIRE(metadata_store->add_job(*conn, job_id, client_id, graph).success()); + REQUIRE(metadata_store->set_task_running(*conn, task1.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task1.get_id()}, + {spider::core::TaskOutput{data1.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task2.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task2.get_id()}, + {spider::core::TaskOutput{data2.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task3.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task3.get_id()}, + {spider::core::TaskOutput{data3.get_id()}} + ) + .success()); + REQUIRE(metadata_store->set_task_running(*conn, task5.get_id()).success()); + REQUIRE(metadata_store + ->task_finish( + *conn, + spider::core::TaskInstance{task5.get_id()}, + {spider::core::TaskOutput{data5.get_id()}} + ) + .success()); + + REQUIRE(metadata_store->set_task_state(*conn, task8.get_id(), spider::core::TaskState::Failed) + .success()); + + spider::core::JobRecovery recovery{job_id, conn, data_store, metadata_store}; + REQUIRE(recovery.compute_graph().success()); + auto const ready_tasks = recovery.get_ready_tasks(); + auto const pending_tasks = recovery.get_pending_tasks(); + REQUIRE(ready_tasks.size() == 1); + REQUIRE(ready_tasks[0] == task2.get_id()); + REQUIRE(pending_tasks.size() == 3); + REQUIRE(pending_tasks.end() != std::ranges::find(pending_tasks, task4.get_id())); + REQUIRE(pending_tasks.end() != std::ranges::find(pending_tasks, task5.get_id())); + REQUIRE(pending_tasks.end() != std::ranges::find(pending_tasks, task8.get_id())); + + REQUIRE(metadata_store->remove_job(*conn, job_id).success()); + REQUIRE(data_store->remove_data(*conn, data1.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data2.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data3.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data4.get_id()).success()); + REQUIRE(data_store->remove_data(*conn, data5.get_id()).success()); + REQUIRE(metadata_store->remove_driver(*conn, client_id).success()); +} +} // namespace + +// NOLINTEND(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-analyzer-optin.core.EnumCastOutOfRange) diff --git a/tests/storage/test-DataStorage.cpp b/tests/storage/test-DataStorage.cpp index e415ed68..a6b1ee38 100644 --- a/tests/storage/test-DataStorage.cpp +++ b/tests/storage/test-DataStorage.cpp @@ -40,7 +40,7 @@ TEMPLATE_LIST_TEST_CASE( auto conn = std::move(std::get>(conn_result)); // Add driver and data - spider::core::Data const data{"value"}; + spider::core::Data data{"value"}; 
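+ // `data` is intentionally non-const: the test flips its persisted flag below.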
boost::uuids::random_generator gen; boost::uuids::uuid const driver_id = gen(); REQUIRE(metadata_storage->add_driver(*conn, spider::core::Driver{driver_id}).success()); @@ -56,6 +56,13 @@ TEMPLATE_LIST_TEST_CASE( REQUIRE(data_storage->get_data(*conn, data.get_id(), &result).success()); REQUIRE(spider::test::data_equal(data, result)); + // Set data persisted should succeed + data.set_persisted(true); + REQUIRE(data_storage->set_data_persisted(*conn, data).success()); + // Get data should match + REQUIRE(data_storage->get_data(*conn, data.get_id(), &result).success()); + REQUIRE(spider::test::data_equal(data, result)); + // Remove data should succeed REQUIRE(data_storage->remove_data(*conn, data.get_id()).success()); diff --git a/tests/storage/test-MetadataStorage.cpp b/tests/storage/test-MetadataStorage.cpp index 40ce7472..1fb14e6d 100644 --- a/tests/storage/test-MetadataStorage.cpp +++ b/tests/storage/test-MetadataStorage.cpp @@ -546,6 +546,92 @@ TEMPLATE_LIST_TEST_CASE("Job reset", "[storage]", spider::test::StorageFactoryTy REQUIRE(storage->remove_job(*conn, job_id).success()); } +TEMPLATE_LIST_TEST_CASE("Job partial reset", "[storage]", spider::test::StorageFactoryTypeList) { + std::unique_ptr storage_factory + = spider::test::create_storage_factory(); + std::unique_ptr storage + = storage_factory->provide_metadata_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + auto conn = std::move(std::get>(conn_result)); + + boost::uuids::random_generator gen; + boost::uuids::uuid const job_id = gen(); + + // Create a complicated task graph + spider::core::Task child_task{"child"}; + spider::core::Task parent_1{"p1"}; + spider::core::Task parent_2{"p2"}; + parent_1.add_input(spider::core::TaskInput{"1", "float"}); + parent_1.add_input(spider::core::TaskInput{"2", "float"}); + parent_2.add_input(spider::core::TaskInput{"3", "int"}); + parent_2.add_input(spider::core::TaskInput{"4", "int"}); + parent_1.add_output(spider::core::TaskOutput{"float"}); + parent_2.add_output(spider::core::TaskOutput{"int"}); + child_task.add_input(spider::core::TaskInput{parent_1.get_id(), 0, "float"}); + child_task.add_input(spider::core::TaskInput{parent_2.get_id(), 0, "int"}); + child_task.add_output(spider::core::TaskOutput{"float"}); + parent_1.set_max_retries(1); + parent_2.set_max_retries(1); + child_task.set_max_retries(1); + spider::core::TaskGraph graph; + // Add tasks and dependencies to the task graph in the wrong order + graph.add_task(child_task); + graph.add_task(parent_1); + graph.add_task(parent_2); + graph.add_dependency(parent_2.get_id(), child_task.get_id()); + graph.add_dependency(parent_1.get_id(), child_task.get_id()); + graph.add_input_task(parent_1.get_id()); + graph.add_input_task(parent_2.get_id()); + graph.add_output_task(child_task.get_id()); + // Submitting the job should succeed + REQUIRE(storage->add_job(*conn, job_id, gen(), graph).success()); + + // Task finish for parent 1 should succeed + spider::core::TaskInstance const parent_1_instance{gen(), parent_1.get_id()}; + REQUIRE(storage->set_task_state(*conn, parent_1.get_id(), spider::core::TaskState::Running) + .success()); + REQUIRE(storage->task_finish( + *conn, + parent_1_instance, + {spider::core::TaskOutput{"1.1", "float"}} + ) + .success()); + + // Job partial reset + REQUIRE(storage->reset_tasks(*conn, {parent_2.get_id()}, {child_task.get_id()}).success()); + // Parent 1 should remain succeeded; parent 2 should be ready and the child task pending + // Parent task inputs should remain available, while the child task's inputs should be cleared + // The reset tasks' outputs should be cleared + spider::core::Task res_task{""}; + REQUIRE(storage->get_task(*conn, parent_1.get_id(), &res_task).success()); + REQUIRE(res_task.get_state() == spider::core::TaskState::Succeed); + REQUIRE(res_task.get_num_inputs() == 2); + REQUIRE(res_task.get_input(0).get_value() == "1"); + REQUIRE(res_task.get_input(1).get_value() == "2"); + REQUIRE(res_task.get_num_outputs() == 1); + REQUIRE(res_task.get_output(0).get_value().has_value()); + REQUIRE(storage->get_task(*conn, parent_2.get_id(), &res_task).success()); + REQUIRE(res_task.get_state() == spider::core::TaskState::Ready); + REQUIRE(res_task.get_num_inputs() == 2); + REQUIRE(res_task.get_input(0).get_value() == "3"); + REQUIRE(res_task.get_input(1).get_value() == "4"); + REQUIRE(res_task.get_num_outputs() == 1); + REQUIRE(!res_task.get_output(0).get_value().has_value()); + REQUIRE(storage->get_task(*conn, child_task.get_id(), &res_task).success()); + REQUIRE(res_task.get_state() == spider::core::TaskState::Pending); + REQUIRE(res_task.get_num_inputs() == 2); + REQUIRE(!res_task.get_input(0).get_value().has_value()); + REQUIRE(!res_task.get_input(1).get_value().has_value()); + REQUIRE(res_task.get_num_outputs() == 1); + REQUIRE(!res_task.get_output(0).get_value().has_value()); + + // Clean up + REQUIRE(storage->remove_job(*conn, job_id).success()); +} + TEMPLATE_LIST_TEST_CASE( "Scheduler lease timeout", "[storage]", diff --git a/tests/utils/CoreDataUtils.hpp b/tests/utils/CoreDataUtils.hpp index 3136e6df..50d848dc 100644 --- a/tests/utils/CoreDataUtils.hpp +++ b/tests/utils/CoreDataUtils.hpp @@ -20,6 +20,10 @@ inline auto data_equal(core::Data const& d1, core::Data const& d2) -> bool { return false; } + if (d1.is_persisted() != d2.is_persisted()) { + return false; + } + return true; } } // namespace spider::test
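To illustrate the client-facing half of this change, here is a minimal usage sketch. It is not part of the diff: the accessor used to obtain the builder and the exact build() signature are assumptions, and only set_checkpointed() on the builder and on Data, plus the ConnectionException behavior, come from the code above.

    // Hypothetical usage sketch; builder acquisition and build() signature are assumed.
    spider::Data<int> data = driver.get_data_builder<int>()  // accessor name assumed
            .set_checkpointed()  // mark at build time so cleanup leaves the row in place
            .build(42);

    // Checkpointing an existing Data object issues an UPDATE through
    // DataStorage::set_data_persisted and throws spider::ConnectionException
    // if no storage connection can be acquired.
    data.set_checkpointed();

Checkpointed (persisted) Data is what JobRecovery treats as a safe cut point: a failed task whose inputs are checkpointed can be reset to ready directly, without re-running its parents.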