Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add job implementation and driver start job implementation #43

Merged
merged 11 commits into from
Jan 3, 2025
27 changes: 8 additions & 19 deletions src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ target_link_libraries(
spider_task_executor
PRIVATE
spider_core
spider_client_lib
spider_client
)
target_link_libraries(
spider_task_executor
Expand Down Expand Up @@ -131,7 +131,9 @@ set(SPIDER_CLIENT_SHARED_SOURCES
set(SPIDER_CLIENT_SHARED_HEADERS
client/Data.hpp
client/Driver.hpp
client/Job.hpp
client/task.hpp
client/spider.hpp
client/TaskContext.hpp
client/TaskGraph.hpp
client/type_utils.hpp
Expand All @@ -143,28 +145,15 @@ set(SPIDER_CLIENT_SHARED_HEADERS
"spider client shared header files"
)

add_library(spider_client_lib)
target_sources(spider_client_lib PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client_lib PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
target_link_libraries(
spider_client_lib
spider_client
PUBLIC
spider_core
Comment on lines +148 to +154
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add the missing library alias

According to the summary, an alias spider::spider should be defined for the spider_client library. This alias is missing from the implementation.

Add the following after the target_link_libraries block:

+add_library(spider::spider ALIAS spider_client)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
target_link_libraries(
spider_client_lib
spider_client
PUBLIC
spider_core
add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SHARED_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_SHARED_HEADERS})
target_link_libraries(
spider_client
PUBLIC
spider_core
add_library(spider::spider ALIAS spider_client)

Boost::boost
absl::flat_hash_map
)

set(SPIDER_CLIENT_SOURCES CACHE INTERNAL "spider client source files")

set(SPIDER_CLIENT_HEADERS
client/spider.hpp
client/Job.hpp
CACHE INTERNAL
"spider client header files"
)

add_library(spider_client)
target_sources(spider_client PRIVATE ${SPIDER_CLIENT_SOURCES})
target_sources(spider_client PUBLIC ${SPIDER_CLIENT_HEADERS})
target_link_libraries(spider_client PRIVATE spider_core)
target_link_libraries(spider_client PUBLIC spider_client_lib)
add_library(spider::spider ALIAS spider_client)
95 changes: 90 additions & 5 deletions src/spider/client/Driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <vector>

#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid.hpp>
#include <fmt/format.h>

#include "../core/Error.hpp"
#include "../core/TaskGraphImpl.hpp"
#include "../io/Serializer.hpp"
#include "../worker/FunctionManager.hpp"
#include "Data.hpp"
#include "Exception.hpp"
#include "Job.hpp"
#include "task.hpp"
#include "TaskGraph.hpp"
Expand All @@ -39,6 +45,8 @@ namespace spider {
namespace core {
class MetadataStorage;
class DataStorage;
class Task;
class TaskGraph;
} // namespace core

/**
Expand Down Expand Up @@ -126,28 +134,98 @@ class Driver {
*
* @tparam ReturnType
* @tparam Params
* @tparam Inputs
* @param task
* @param inputs
* @return A job representing the running task.
* @throw spider::ConnectionException
*/
template <TaskIo ReturnType, TaskIo... Params>
template <TaskIo ReturnType, TaskIo... Params, TaskIo... Inputs>
auto
start(TaskFunction<ReturnType, Params...> const& task, Params&&... inputs) -> Job<ReturnType>;
start(TaskFunction<ReturnType, Params...> const& task, Inputs&&... inputs) -> Job<ReturnType> {
// Check input type
// The number of inputs is verified against the number of task parameters at compile time.
static_assert(
sizeof...(Inputs) == sizeof...(Params),
"Number of inputs must match number of parameters."
);
// Compare the type of each input with the parameter type at the same position, ignoring
// cv-qualifiers and references on both sides.
// NOTE(review): the comparison is an `if constexpr`, but a mismatch is reported with a `throw`
// instead of a `static_assert` -- presumably the error only surfaces when `for_n` invokes the
// lambda at run time; confirm whether a compile-time failure was intended.
for_n<sizeof...(Inputs)>([&](auto i) {
using InputType = std::tuple_element_t<i.cValue, std::tuple<Inputs...>>;
using ParamType = std::tuple_element_t<i.cValue, std::tuple<Params...>>;
if constexpr (!std::is_same_v<
std::remove_cvref_t<InputType>,
std::remove_cvref_t<ParamType>>)
{
throw std::invalid_argument("Input type does not match parameter type.");
}
});

// Create a task from the task function; an empty optional is reported as an invalid argument.
std::optional<core::Task> optional_task = core::TaskGraphImpl::create_task(task);
if (!optional_task.has_value()) {
throw std::invalid_argument("Failed to create task.");
}
// Attach the forwarded inputs to the new task; a false result is reported as an invalid
// argument.
core::Task& new_task = optional_task.value();
if (!core::TaskGraphImpl::task_add_input(new_task, std::forward<Inputs>(inputs)...)) {
throw std::invalid_argument("Failed to add inputs to task.");
}
// Generate a random UUID to identify the job.
boost::uuids::random_generator gen;
boost::uuids::uuid const job_id = gen();
// Wrap the task in a graph in which it is registered as both the input task and the output
// task.
core::TaskGraph graph;
graph.add_task(new_task);
graph.add_input_task(new_task.get_id());
graph.add_output_task(new_task.get_id());
// Add the job to the metadata storage together with `m_id`; a storage failure is rethrown as a
// ConnectionException carrying the storage error description.
core::StorageErr err = m_metadata_storage->add_job(job_id, m_id, graph);
if (!err.success()) {
throw ConnectionException(fmt::format("Failed to start job: {}", err.description));
}

// The returned Job is built from the job ID and both storage handles.
return Job<ReturnType>{job_id, m_metadata_storage, m_data_storage};
}

/**
* Starts running a task graph with the given inputs on Spider.
*
* @tparam ReturnType
* @tparam Params
* @tparam Inputs
* @param graph
* @param inputs
* @return A job representing the running task graph.
* @throw spider::ConnectionException
*/
template <TaskIo ReturnType, TaskIo... Params>
template <TaskIo ReturnType, TaskIo... Params, TaskIo... Inputs>
auto
start(TaskGraph<ReturnType(Params...)> const& graph, Params&&... inputs) -> Job<ReturnType>;
start(TaskGraph<ReturnType, Params...> const& graph, Inputs&&... inputs) -> Job<ReturnType> {
// Check input type
// The number of inputs is verified against the number of graph parameters at compile time.
static_assert(
sizeof...(Inputs) == sizeof...(Params),
"Number of inputs must match number of parameters."
);
// Compare the type of each input with the parameter type at the same position, ignoring
// cv-qualifiers and references on both sides.
// NOTE(review): the comparison is an `if constexpr`, but a mismatch is reported with a `throw`
// instead of a `static_assert` -- presumably the error only surfaces when `for_n` invokes the
// lambda at run time; confirm whether a compile-time failure was intended.
for_n<sizeof...(Inputs)>([&](auto i) {
using InputType = std::tuple_element_t<i.cValue, std::tuple<Inputs...>>;
using ParamType = std::tuple_element_t<i.cValue, std::tuple<Params...>>;
if constexpr (!std::is_same_v<
std::remove_cvref_t<InputType>,
std::remove_cvref_t<ParamType>>)
{
throw std::invalid_argument("Input type does not match parameter type.");
}
});

// Hand the forwarded inputs to the graph implementation; a false result is reported as an
// invalid argument.
// NOTE(review): `graph` is taken by const reference, yet `add_inputs` and `reset_ids` are called
// through `m_impl` -- presumably they modify the underlying implementation object; confirm that
// changing the caller's graph here is intended.
if (!graph.m_impl->add_inputs(std::forward<Inputs>(inputs)...)) {
throw std::invalid_argument("Failed to add inputs to task graph.");
}
// Reset ids in case the same graph was submitted before
graph.m_impl->reset_ids();
// Generate a random UUID to identify the job.
boost::uuids::random_generator gen;
boost::uuids::uuid const job_id = gen();
// Add the job to the metadata storage together with `m_id`; a storage failure is rethrown as a
// ConnectionException carrying the storage error description.
core::StorageErr const err
= m_metadata_storage->add_job(job_id, m_id, graph.m_impl->get_graph());
if (!err.success()) {
throw ConnectionException(fmt::format("Failed to start job: {}", err.description));
}

// The returned Job is built from the job ID and both storage handles.
return Job<ReturnType>{job_id, m_metadata_storage, m_data_storage};
}

/**
* Gets all scheduled and running jobs started by drivers with the current client's ID.
Expand All @@ -157,7 +235,14 @@ class Driver {
* @return IDs of the jobs.
* @throw spider::ConnectionException
*/
auto get_jobs() -> std::vector<boost::uuids::uuid>;
auto get_jobs() -> std::vector<boost::uuids::uuid> {
    // The metadata storage fills `job_ids` through the out-parameter for the jobs associated
    // with `m_id`, and reports failure through the returned StorageErr.
    std::vector<boost::uuids::uuid> job_ids;
    core::StorageErr const err = m_metadata_storage->get_jobs_by_client_id(m_id, &job_ids);
    if (!err.success()) {
        // Include the storage error description so the failure is diagnosable, matching the
        // messages thrown by the `start` overloads.
        throw ConnectionException(fmt::format("Failed to get jobs: {}", err.description));
    }
    return job_ids;
}

private:
boost::uuids::uuid m_id;
Expand Down
Loading
Loading