diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index 6c08825..2bedefc 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -87,7 +87,7 @@ target_link_libraries( spider_task_executor PRIVATE spider_core - spider_client_lib + spider_client ) target_link_libraries( spider_task_executor @@ -131,7 +131,9 @@ set(SPIDER_CLIENT_SHARED_SOURCES set(SPIDER_CLIENT_SHARED_HEADERS client/Data.hpp client/Driver.hpp + client/Job.hpp client/task.hpp + client/spider.hpp client/TaskContext.hpp client/TaskGraph.hpp client/type_utils.hpp @@ -143,28 +145,15 @@ set(SPIDER_CLIENT_SHARED_HEADERS "spider client shared header files" ) -add_library(spider_client_lib) -target_sources(spider_client_lib PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES}) -target_sources(spider_client_lib PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS}) +add_library(spider_client) +target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES}) +target_sources(spider_client PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS}) target_link_libraries( - spider_client_lib + spider_client PUBLIC + spider_core Boost::boost absl::flat_hash_map ) -set(SPIDER_CLIENT_SOURCES CACHE INTERNAL "spider client source files") - -set(SPIDER_CLIENT_HEADERS - client/spider.hpp - client/Job.hpp - CACHE INTERNAL - "spider client header files" -) - -add_library(spider_client) -target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SOURCES}) -target_sources(spider_client PUBLIC ${SPIDER_CLIENT_HEADERS}) -target_link_libraries(spider_client PRIVATE spider_core) -target_link_libraries(spider_client PUBLIC spider_client_lib) add_library(spider::spider ALIAS spider_client) diff --git a/src/spider/client/Driver.hpp b/src/spider/client/Driver.hpp index 6e30a21..3e79393 100644 --- a/src/spider/client/Driver.hpp +++ b/src/spider/client/Driver.hpp @@ -6,14 +6,20 @@ #include #include #include +#include +#include #include +#include #include +#include +#include "../core/Error.hpp" #include "../core/TaskGraphImpl.hpp" #include "../io/Serializer.hpp" #include "../worker/FunctionManager.hpp" #include "Data.hpp" +#include "Exception.hpp" #include "Job.hpp" #include "task.hpp" #include "TaskGraph.hpp" @@ -39,6 +45,8 @@ namespace spider { namespace core { class MetadataStorage; class DataStorage; +class Task; +class TaskGraph; } // namespace core /** @@ -126,28 +134,98 @@ class Driver { * * @tparam ReturnType * @tparam Params + * @tparam Inputs * @param task * @param inputs * @return A job representing the running task. * @throw spider::ConnectionException */ - template + template auto - start(TaskFunction const& task, Params&&... inputs) -> Job; + start(TaskFunction const& task, Inputs&&... inputs) -> Job { + // Check input type + static_assert( + sizeof...(Inputs) == sizeof...(Params), + "Number of inputs must match number of parameters." + ); + for_n([&](auto i) { + using InputType = std::tuple_element_t>; + using ParamType = std::tuple_element_t>; + if constexpr (!std::is_same_v< + std::remove_cvref_t, + std::remove_cvref_t>) + { + throw std::invalid_argument("Input type does not match parameter type."); + } + }); + + std::optional optional_task = core::TaskGraphImpl::create_task(task); + if (!optional_task.has_value()) { + throw std::invalid_argument("Failed to create task."); + } + core::Task& new_task = optional_task.value(); + if (!core::TaskGraphImpl::task_add_input(new_task, std::forward(inputs)...)) { + throw std::invalid_argument("Failed to add inputs to task."); + } + boost::uuids::random_generator gen; + boost::uuids::uuid const job_id = gen(); + core::TaskGraph graph; + graph.add_task(new_task); + graph.add_input_task(new_task.get_id()); + graph.add_output_task(new_task.get_id()); + core::StorageErr err = m_metadata_storage->add_job(job_id, m_id, graph); + if (!err.success()) { + throw ConnectionException(fmt::format("Failed to start job: {}", err.description)); + } + + return Job{job_id, m_metadata_storage, m_data_storage}; + } /** * Starts running a task graph with the given inputs on Spider. * * @tparam ReturnType * @tparam Params + * @tparam Inputs * @param graph * @param inputs * @return A job representing the running task graph. * @throw spider::ConnectionException */ - template + template auto - start(TaskGraph const& graph, Params&&... inputs) -> Job; + start(TaskGraph const& graph, Inputs&&... inputs) -> Job { + // Check input type + static_assert( + sizeof...(Inputs) == sizeof...(Params), + "Number of inputs must match number of parameters." + ); + for_n([&](auto i) { + using InputType = std::tuple_element_t>; + using ParamType = std::tuple_element_t>; + if constexpr (!std::is_same_v< + std::remove_cvref_t, + std::remove_cvref_t>) + { + throw std::invalid_argument("Input type does not match parameter type."); + } + }); + + if (!graph.m_impl->add_inputs(std::forward(inputs)...)) { + throw std::invalid_argument("Failed to add inputs to task graph."); + } + // Reset ids in case the same graph is submitted before + graph.m_impl->reset_ids(); + boost::uuids::random_generator gen; + boost::uuids::uuid const job_id = gen(); + core::StorageErr const err + = m_metadata_storage->add_job(job_id, m_id, graph.m_impl->get_graph()); + if (!err.success()) { + throw ConnectionException(fmt::format("Failed to start job: {}", err.description)); + } + + return Job{job_id, m_metadata_storage, m_data_storage}; + } /** * Gets all scheduled and running jobs started by drivers with the current client's ID. @@ -157,7 +235,14 @@ class Driver { * @return IDs of the jobs. * @throw spider::ConnectionException */ - auto get_jobs() -> std::vector; + auto get_jobs() -> std::vector { + std::vector job_ids; + core::StorageErr const err = m_metadata_storage->get_jobs_by_client_id(m_id, &job_ids); + if (!err.success()) { + throw ConnectionException("Failed to get jobs."); + } + return job_ids; + } private: boost::uuids::uuid m_id; diff --git a/src/spider/client/Job.hpp b/src/spider/client/Job.hpp index 79e3244..fc6a242 100644 --- a/src/spider/client/Job.hpp +++ b/src/spider/client/Job.hpp @@ -1,13 +1,39 @@ #ifndef SPIDER_CLIENT_JOB_HPP #define SPIDER_CLIENT_JOB_HPP +#include +#include #include +#include +#include #include +#include +#include #include +#include +#include +#include + +#include "../core/DataImpl.hpp" +#include "../core/Error.hpp" +#include "../core/JobMetadata.hpp" +#include "../io/MsgPack.hpp" // IWYU pragma: keep +#include "../storage/MetadataStorage.hpp" +#include "Data.hpp" #include "task.hpp" +#include "type_utils.hpp" namespace spider { +namespace core { +class Data; +class DataStorage; +class MetadataStorage; +class Task; +class TaskOutput; +} // namespace core +class Driver; + // TODO: Use std::expected or Boost's outcome so that the user can get the result of the job in one // call rather than the current error-prone approach which requires that the user check the job's // status and then call the relevant method. @@ -32,7 +58,25 @@ class Job { * * @throw spider::ConnectionException */ - auto wait_complete(); + auto wait_complete() -> void { + bool complete = false; + core::StorageErr err = m_metadata_storage->get_job_complete(m_id, &complete); + if (!err.success()) { + throw ConnectionException{ + fmt::format("Failed to get job completion status: {}", err.description) + }; + } + while (!complete) { + constexpr int cSleepMs = 10; + std::this_thread::sleep_for(std::chrono::milliseconds(cSleepMs)); + err = m_metadata_storage->get_job_complete(m_id, &complete); + if (!err.success()) { + throw ConnectionException{ + fmt::format("Failed to get job completion status: {}", err.description) + }; + } + } + } /** * Cancels the job and waits for the job's tasks to be cancelled. @@ -45,8 +89,28 @@ class Job { * @return Status of the job. * @throw spider::ConnectionException */ - auto get_status() -> JobStatus; + auto get_status() -> JobStatus { + core::JobStatus status = core::JobStatus::Running; + core::StorageErr const err = m_metadata_storage->get_job_status(m_id, &status); + if (!err.success()) { + throw ConnectionException{fmt::format("Failed to get job status: {}", err.description)}; + } + switch (status) { + case core::JobStatus::Running: + return JobStatus::Running; + case core::JobStatus::Succeeded: + return JobStatus::Succeeded; + case core::JobStatus::Failed: + return JobStatus::Failed; + case core::JobStatus::Cancelled: + return JobStatus::Cancelled; + } + throw ConnectionException{ + fmt::format("Unknown job status: {}", static_cast(status)) + }; + } + // NOLINTBEGIN(readability-function-cognitive-complexity) /** * NOTE: It is undefined behavior to call this method for a job that is not in the `Succeeded` * state. @@ -54,7 +118,131 @@ class Job { * @return Result of the job. * @throw spider::ConnectionException */ - auto get_result() -> ReturnType; + auto get_result() -> ReturnType { + std::vector output_task_ids; + core::StorageErr err = m_metadata_storage->get_job_output_tasks(m_id, &output_task_ids); + if (!err.success()) { + throw ConnectionException{ + fmt::format("Failed to get job output tasks: {}", err.description) + }; + } + std::vector tasks; + for (auto const& id : output_task_ids) { + core::Task task{""}; + err = m_metadata_storage->get_task(id, &task); + if (!err.success()) { + throw ConnectionException{fmt::format("Failed to get task: {}", err.description)}; + } + tasks.push_back(task); + } + ReturnType result; + if constexpr (cIsSpecializationV) { + size_t task_index = 0; + size_t output_index = 0; + for_n>([&](auto i) { + using T = std::tuple_element_t; + if (task_index >= output_task_ids.size()) { + throw ConnectionException{fmt::format("Not enough output tasks for job result") + }; + } + core::Task const& task = tasks[task_index]; + if (output_index >= task.get_num_outputs()) { + throw ConnectionException{fmt::format("Not enough outputs for task")}; + } + core::TaskOutput const& output = task.get_output(output_index); + if (output.get_type() != typeid(T).name()) { + throw ConnectionException{fmt::format("Output type mismatch")}; + } + if constexpr (cIsSpecializationV) { + using DataType = ExtractTemplateParamT; + core::Data data; + std::optional const optional_data_id = output.get_data_id(); + if (!optional_data_id.has_value()) { + throw ConnectionException{fmt::format("Output data ID is missing")}; + } + err = m_data_storage->get_data(optional_data_id.value(), &data); + if (!err.success()) { + throw ConnectionException{ + fmt::format("Failed to get data: {}", err.description) + }; + } + std::get(result) = core::DataImpl::create_data( + std::make_unique(std::move(data)), + m_data_storage + ); + } else { + std::optional const optional_value = output.get_value(); + if (!optional_value.has_value()) { + throw ConnectionException{fmt::format("Output value is missing")}; + } + std::string const& value = optional_value.value(); + try { + msgpack::object_handle const handle + = msgpack::unpack(value.data(), value.size()); + msgpack::object const& obj = handle.get(); + std::get(result) = obj.as(); + } catch (msgpack::type_error const& e) { + throw ConnectionException{fmt::format("Failed to unpack data: {}", e.what()) + }; + } + } + output_index++; + if (output_index >= task.get_num_outputs()) { + task_index++; + output_index = 0; + } + }); + } else { + if (output_task_ids.size() != 1) { + throw ConnectionException{fmt::format("Expected one output task for job result")}; + } + core::Task task{""}; + err = m_metadata_storage->get_task(output_task_ids[0], &task); + if (!err.success()) { + throw ConnectionException{fmt::format("Failed to get task: {}", err.description)}; + } + if (task.get_num_outputs() != 1) { + throw ConnectionException{fmt::format("Expected one output for task")}; + } + core::TaskOutput const& output = task.get_output(0); + if (output.get_type() != typeid(ReturnType).name()) { + throw ConnectionException{fmt::format("Output type mismatch")}; + } + if constexpr (cIsSpecializationV) { + using DataType = ExtractTemplateParamT; + core::Data data; + std::optional const optional_data_id = output.get_data_id(); + if (!optional_data_id.has_value()) { + throw ConnectionException{fmt::format("Output data ID is missing")}; + } + err = m_data_storage->get_data(optional_data_id.value(), &data); + if (!err.success()) { + throw ConnectionException{fmt::format("Failed to get data: {}", err.description) + }; + } + return core::DataImpl::create_data( + std::make_unique(std::move(data)), + m_data_storage + ); + } else { + std::optional const optional_value = output.get_value(); + if (!optional_value.has_value()) { + throw ConnectionException{fmt::format("Output value is missing")}; + } + std::string const& value = optional_value.value(); + try { + msgpack::object_handle const handle + = msgpack::unpack(value.data(), value.size()); + msgpack::object const& obj = handle.get(); + return obj.as(); + } catch (msgpack::type_error const& e) { + throw ConnectionException{fmt::format("Failed to unpack data: {}", e.what())}; + } + } + } + } + + // NOLINTEND(readability-function-cognitive-complexity) /** * NOTE: It is undefined behavior to call this method for a job that is not in the `Failed` @@ -66,6 +254,20 @@ class Job { * @throw spider::ConnectionException */ auto get_error() -> std::pair; + +private: + Job(boost::uuids::uuid id, + std::shared_ptr metadata_storage, + std::shared_ptr data_storage) + : m_id{id}, + m_metadata_storage{std::move(metadata_storage)}, + m_data_storage{std::move(data_storage)} {} + + boost::uuids::uuid m_id; + std::shared_ptr m_metadata_storage; + std::shared_ptr m_data_storage; + + friend class Driver; }; } // namespace spider diff --git a/src/spider/client/type_utils.hpp b/src/spider/client/type_utils.hpp index 1a1a062..6c87cb1 100644 --- a/src/spider/client/type_utils.hpp +++ b/src/spider/client/type_utils.hpp @@ -1,7 +1,10 @@ #ifndef SPIDER_CLIENT_TYPE_UTILS_HPP #define SPIDER_CLIENT_TYPE_UTILS_HPP +#include +#include #include +#include namespace spider { // The template and partial specialization below check whether a given type is a specialization of @@ -27,5 +30,34 @@ struct IsSpecialization, type> : public std::true_type {}; template class template_type> inline constexpr bool cIsSpecializationV = IsSpecialization::value; + +template +struct Num { + static constexpr auto cValue = n; +}; + +template +void for_n(F func, std::index_sequence) { + (void)std::initializer_list{0, ((void)func(Num{}), 0)...}; +} + +template +void for_n(F func) { + for_n(func, std::make_index_sequence()); +} + +template +struct ExtractTemplateParam { + using Type = T; +}; + +template