Skip to content

Commit

Permalink
feat: Add client job implementation and driver start job implementati…
Browse files Browse the repository at this point in the history
…on (#43)
  • Loading branch information
sitaowang1998 authored Jan 3, 2025
1 parent c7da559 commit c255a2a
Show file tree
Hide file tree
Showing 16 changed files with 869 additions and 65 deletions.
27 changes: 8 additions & 19 deletions src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ target_link_libraries(
spider_task_executor
PRIVATE
spider_core
spider_client_lib
spider_client
)
target_link_libraries(
spider_task_executor
Expand Down Expand Up @@ -131,7 +131,9 @@ set(SPIDER_CLIENT_SHARED_SOURCES
set(SPIDER_CLIENT_SHARED_HEADERS
client/Data.hpp
client/Driver.hpp
client/Job.hpp
client/task.hpp
client/spider.hpp
client/TaskContext.hpp
client/TaskGraph.hpp
client/type_utils.hpp
Expand All @@ -143,28 +145,15 @@ set(SPIDER_CLIENT_SHARED_HEADERS
"spider client shared header files"
)

add_library(spider_client_lib)
target_sources(spider_client_lib PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client_lib PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
target_link_libraries(
spider_client_lib
spider_client
PUBLIC
spider_core
Boost::boost
absl::flat_hash_map
)

set(SPIDER_CLIENT_SOURCES CACHE INTERNAL "spider client source files")

set(SPIDER_CLIENT_HEADERS
client/spider.hpp
client/Job.hpp
CACHE INTERNAL
"spider client header files"
)

add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_HEADERS})
target_link_libraries(spider_client PRIVATE spider_core)
target_link_libraries(spider_client PUBLIC spider_client_lib)
add_library(spider::spider ALIAS spider_client)
95 changes: 90 additions & 5 deletions src/spider/client/Driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <vector>

#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid.hpp>
#include <fmt/format.h>

#include "../core/Error.hpp"
#include "../core/TaskGraphImpl.hpp"
#include "../io/Serializer.hpp"
#include "../worker/FunctionManager.hpp"
#include "Data.hpp"
#include "Exception.hpp"
#include "Job.hpp"
#include "task.hpp"
#include "TaskGraph.hpp"
Expand All @@ -39,6 +45,8 @@ namespace spider {
namespace core {
class MetadataStorage;
class DataStorage;
class Task;
class TaskGraph;
} // namespace core

/**
Expand Down Expand Up @@ -126,28 +134,98 @@ class Driver {
*
* @tparam ReturnType
* @tparam Params
* @tparam Inputs
* @param task
* @param inputs
* @return A job representing the running task.
* @throw spider::ConnectionException
*/
template <TaskIo ReturnType, TaskIo... Params>
template <TaskIo ReturnType, TaskIo... Params, TaskIo... Inputs>
auto
start(TaskFunction<ReturnType, Params...> const& task, Params&&... inputs) -> Job<ReturnType>;
start(TaskFunction<ReturnType, Params...> const& task, Inputs&&... inputs) -> Job<ReturnType> {
// Check input type
static_assert(
sizeof...(Inputs) == sizeof...(Params),
"Number of inputs must match number of parameters."
);
for_n<sizeof...(Inputs)>([&](auto i) {
using InputType = std::tuple_element_t<i.cValue, std::tuple<Inputs...>>;
using ParamType = std::tuple_element_t<i.cValue, std::tuple<Params...>>;
if constexpr (!std::is_same_v<
std::remove_cvref_t<InputType>,
std::remove_cvref_t<ParamType>>)
{
throw std::invalid_argument("Input type does not match parameter type.");
}
});

std::optional<core::Task> optional_task = core::TaskGraphImpl::create_task(task);
if (!optional_task.has_value()) {
throw std::invalid_argument("Failed to create task.");
}
core::Task& new_task = optional_task.value();
if (!core::TaskGraphImpl::task_add_input(new_task, std::forward<Inputs>(inputs)...)) {
throw std::invalid_argument("Failed to add inputs to task.");
}
boost::uuids::random_generator gen;
boost::uuids::uuid const job_id = gen();
core::TaskGraph graph;
graph.add_task(new_task);
graph.add_input_task(new_task.get_id());
graph.add_output_task(new_task.get_id());
core::StorageErr err = m_metadata_storage->add_job(job_id, m_id, graph);
if (!err.success()) {
throw ConnectionException(fmt::format("Failed to start job: {}", err.description));
}

return Job<ReturnType>{job_id, m_metadata_storage, m_data_storage};
}

/**
* Starts running a task graph with the given inputs on Spider.
*
* @tparam ReturnType
* @tparam Params
* @tparam Inputs
* @param graph
* @param inputs
* @return A job representing the running task graph.
* @throw spider::ConnectionException
*/
template <TaskIo ReturnType, TaskIo... Params>
template <TaskIo ReturnType, TaskIo... Params, TaskIo... Inputs>
auto
start(TaskGraph<ReturnType(Params...)> const& graph, Params&&... inputs) -> Job<ReturnType>;
start(TaskGraph<ReturnType, Params...> const& graph, Inputs&&... inputs) -> Job<ReturnType> {
// Check input type
static_assert(
sizeof...(Inputs) == sizeof...(Params),
"Number of inputs must match number of parameters."
);
for_n<sizeof...(Inputs)>([&](auto i) {
using InputType = std::tuple_element_t<i.cValue, std::tuple<Inputs...>>;
using ParamType = std::tuple_element_t<i.cValue, std::tuple<Params...>>;
if constexpr (!std::is_same_v<
std::remove_cvref_t<InputType>,
std::remove_cvref_t<ParamType>>)
{
throw std::invalid_argument("Input type does not match parameter type.");
}
});

if (!graph.m_impl->add_inputs(std::forward<Inputs>(inputs)...)) {
throw std::invalid_argument("Failed to add inputs to task graph.");
}
// Reset ids in case the same graph is submitted before
graph.m_impl->reset_ids();
boost::uuids::random_generator gen;
boost::uuids::uuid const job_id = gen();
core::StorageErr const err
= m_metadata_storage->add_job(job_id, m_id, graph.m_impl->get_graph());
if (!err.success()) {
throw ConnectionException(fmt::format("Failed to start job: {}", err.description));
}

return Job<ReturnType>{job_id, m_metadata_storage, m_data_storage};
}

/**
* Gets all scheduled and running jobs started by drivers with the current client's ID.
Expand All @@ -157,7 +235,14 @@ class Driver {
* @return IDs of the jobs.
* @throw spider::ConnectionException
*/
auto get_jobs() -> std::vector<boost::uuids::uuid>;
auto get_jobs() -> std::vector<boost::uuids::uuid> {
std::vector<boost::uuids::uuid> job_ids;
core::StorageErr const err = m_metadata_storage->get_jobs_by_client_id(m_id, &job_ids);
if (!err.success()) {
throw ConnectionException("Failed to get jobs.");
}
return job_ids;
}

private:
boost::uuids::uuid m_id;
Expand Down
Loading

0 comments on commit c255a2a

Please sign in to comment.