Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add client driver implementation #42

Merged
merged 19 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ set(SPIDER_CORE_SOURCES
set(SPIDER_CORE_HEADERS
core/Error.hpp
core/Data.hpp
core/DataImpl.hpp
core/Driver.hpp
core/KeyValueData.hpp
core/Task.hpp
core/TaskContextImpl.hpp
core/TaskGraph.hpp
core/JobMetadata.hpp
io/BoostAsio.hpp
Expand Down Expand Up @@ -124,7 +122,11 @@ target_link_libraries(
spdlog::spdlog
)

set(SPIDER_CLIENT_SHARED_SOURCES CACHE INTERNAL "spider client shared source files")
set(SPIDER_CLIENT_SHARED_SOURCES
client/Driver.cpp
CACHE INTERNAL
"spider client shared source files"
)

set(SPIDER_CLIENT_SHARED_HEADERS
client/Data.hpp
Expand All @@ -134,6 +136,9 @@ set(SPIDER_CLIENT_SHARED_HEADERS
client/TaskGraph.hpp
client/type_utils.hpp
client/Exception.hpp
core/DataImpl.hpp
core/TaskContextImpl.hpp
core/TaskGraphImpl.hpp
CACHE INTERNAL
"spider client shared header files"
)
Expand Down
12 changes: 7 additions & 5 deletions src/spider/client/Data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ namespace core {
class Data;
class DataStorage;
class DataImpl;
class TaskGraphImpl;
} // namespace core
class Driver;
class TaskContext;

/**
* A representation of data stored on external storage. This class allows the user to define:
Expand Down Expand Up @@ -120,7 +123,7 @@ class Data {
}
break;
}
return Data{data, m_data_store};
return Data{std::move(data), m_data_store};
}

private:
Expand All @@ -129,11 +132,9 @@ class Data {
TaskContext
};

explicit Builder(
std::shared_ptr<core::DataStorage> data_store,
Builder(std::shared_ptr<core::DataStorage> data_store,
boost::uuids::uuid const source_id,
DataSource const data_source
)
DataSource const data_source)
: m_data_store{std::move(data_store)},
m_source_id{source_id},
m_data_source{data_source} {}
Expand Down Expand Up @@ -163,6 +164,7 @@ class Data {
std::shared_ptr<core::DataStorage> m_data_store;

friend class core::DataImpl;
friend class core::TaskGraphImpl;
};
} // namespace spider

Expand Down
121 changes: 121 additions & 0 deletions src/spider/client/Driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#include "Driver.hpp"

#include <chrono>
#include <iostream>
#include <memory>
#include <optional>
#include <stop_token>
#include <string>
#include <thread>

#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid.hpp>

#include "../core/Driver.hpp"
#include "../core/Error.hpp"
#include "../core/KeyValueData.hpp"
#include "../io/BoostAsio.hpp" // IWYU pragma: keep
#include "../storage/MysqlStorage.hpp"
#include "Exception.hpp"

namespace spider {

Driver::Driver(std::string const& storage_url) {
boost::uuids::random_generator gen;
m_id = gen();

m_metadata_storage = std::make_shared<core::MySqlMetadataStorage>();
m_data_storage = std::make_shared<core::MySqlDataStorage>();
core::StorageErr err = m_metadata_storage->connect(storage_url);
if (!err.success()) {
throw ConnectionException(err.description);
}
err = m_data_storage->connect(storage_url);
if (!err.success()) {
throw ConnectionException(err.description);
}

std::optional<std::string> const optional_addr = core::get_address();
if (!optional_addr.has_value()) {
throw ConnectionException("Cannot get machine address");
}
std::string const& addr = optional_addr.value();
err = m_metadata_storage->add_driver(core::Driver{m_id, addr});
if (!err.success()) {
if (core::StorageErrType::DuplicateKeyErr == err.type) {
throw DriverIdInUseException(m_id);
}
throw ConnectionException(err.description);
}

// Start a thread to send heartbeats
// NOLINTNEXTLINE(performance-unnecessary-value-param)
m_heartbeat_thread = std::jthread([this](std::stop_token stoken) {
while (!stoken.stop_requested()) {
std::this_thread::sleep_for(std::chrono::seconds(1));
core::StorageErr const err = m_metadata_storage->update_heartbeat(m_id);
if (!err.success()) {
throw ConnectionException(err.description);
}
}
});
Comment on lines +50 to +60
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Heartbeat thread usage
Starting a continuous heartbeat in the constructor is helpful, but confirm that the destructor or move logic properly stops and joins the thread to prevent resource leaks.

}

/**
 * Creates a driver with the given id, registers it in the metadata storage and starts a
 * background thread that refreshes the driver's heartbeat once per second.
 *
 * @param storage_url URL handed to the metadata and data storage `connect` calls.
 * @param id Id under which this driver is registered.
 * @throw spider::ConnectionException if either storage fails to connect, the machine address
 * cannot be determined, or registering the driver fails for a reason other than a duplicate key.
 * @throw spider::DriverIdInUseException if the storage reports `id` as a duplicate key.
 */
Driver::Driver(std::string const& storage_url, boost::uuids::uuid const id) : m_id{id} {
    m_metadata_storage = std::make_shared<core::MySqlMetadataStorage>();
    m_data_storage = std::make_shared<core::MySqlDataStorage>();
    core::StorageErr err = m_metadata_storage->connect(storage_url);
    if (!err.success()) {
        throw ConnectionException(err.description);
    }
    err = m_data_storage->connect(storage_url);
    if (!err.success()) {
        throw ConnectionException(err.description);
    }

    std::optional<std::string> const optional_addr = core::get_address();
    if (!optional_addr.has_value()) {
        throw ConnectionException("Cannot get machine address");
    }
    std::string const& addr = optional_addr.value();
    err = m_metadata_storage->add_driver(core::Driver{m_id, addr});
    if (!err.success()) {
        if (core::StorageErrType::DuplicateKeyErr == err.type) {
            throw DriverIdInUseException(m_id);
        }
        throw ConnectionException(err.description);
    }

    // Start a thread to send heartbeats. `m_heartbeat_thread` is a `std::jthread`, so destroying
    // the driver requests stop and joins the thread.
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    m_heartbeat_thread = std::jthread([this](std::stop_token stoken) {
        while (!stoken.stop_requested()) {
            std::this_thread::sleep_for(std::chrono::seconds(1));
            if (stoken.stop_requested()) {
                // Shutdown was requested while sleeping; don't touch the storage any more.
                break;
            }
            core::StorageErr const heartbeat_err = m_metadata_storage->update_heartbeat(m_id);
            if (!heartbeat_err.success()) {
                // An exception thrown here would escape the thread's entry function, where no
                // caller can catch it, and end in std::terminate. Report the failure and retry
                // on the next beat instead.
                std::cerr << "spider: failed to update driver heartbeat: "
                          << heartbeat_err.description << '\n';
            }
        }
    });
}

auto Driver::kv_store_insert(std::string const& key, std::string const& value) -> void {
core::KeyValueData const kv_data{key, value, m_id};
core::StorageErr const err = m_data_storage->add_client_kv_data(kv_data);
if (!err.success()) {
throw ConnectionException(err.description);
}
}

auto Driver::kv_store_get(std::string const& key) -> std::optional<std::string> {
std::string value;
core::StorageErr const err = m_data_storage->get_client_kv_data(m_id, key, &value);
if (!err.success()) {
if (core::StorageErrType::KeyNotFoundErr == err.type) {
return std::nullopt;
}
throw ConnectionException(err.description);
}
return value;
}

} // namespace spider
41 changes: 32 additions & 9 deletions src/spider/client/Driver.hpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
#ifndef SPIDER_CLIENT_DRIVER_HPP
#define SPIDER_CLIENT_DRIVER_HPP

#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <boost/uuid/uuid.hpp>

#include "../core/TaskGraphImpl.hpp"
#include "../io/Serializer.hpp"
#include "../worker/FunctionManager.hpp"
#include "Data.hpp"
#include "Job.hpp"
#include "task.hpp"
#include "TaskGraph.hpp"
#include "utility"

/**
* Registers a Task function with Spider
Expand All @@ -31,6 +36,10 @@
#define SPIDER_REGISTER_TASK_TIMEOUT(func, timeout) SPIDER_WORKER_REGISTER_TASK(func)

namespace spider {
namespace core {
class MetadataStorage;
class DataStorage;
} // namespace core

/**
* An interface for a client to interact with Spider and create jobs, access the kv-store, etc.
Expand Down Expand Up @@ -58,7 +67,10 @@ class Driver {
* @return Data builder.
*/
template <Serializable T>
auto get_data_builder() -> Data<T>::Builder;
auto get_data_builder() -> Data<T>::Builder {
using DataBuilder = typename Data<T>::Builder;
return DataBuilder{m_data_storage, m_id, DataBuilder::DataSource::Driver};
}

/**
* Inserts the given key-value pair into the key-value store, overwriting any existing value.
Expand All @@ -67,7 +79,7 @@ class Driver {
* @param value
* @throw spider::ConnectionException
*/
auto kv_store_insert(std::string const& key, std::string const& value);
auto kv_store_insert(std::string const& key, std::string const& value) -> void;

/**
* Gets the value corresponding to the given key.
Expand All @@ -90,19 +102,24 @@ class Driver {
* @tparam ReturnType Return type for both the task and the resulting `TaskGraph`.
* @tparam TaskParams
* @tparam Inputs
* @tparam GraphParams
* @param task
* @param inputs Inputs to bind to `task`. If an input is a `Task` or `TaskGraph`, their
* outputs will be bound to the inputs of `task`.
* @return A `TaskGraph` of the inputs bound to `task`.
*/
template <
TaskIo ReturnType,
TaskIo... TaskParams,
RunnableOrTaskIo... Inputs,
TaskIo... GraphParams>
template <TaskIo ReturnType, TaskIo... TaskParams, RunnableOrTaskIo... Inputs>
auto bind(TaskFunction<ReturnType, TaskParams...> const& task, Inputs&&... inputs)
-> TaskGraph<ReturnType(GraphParams...)>;
-> TaskGraphType<ReturnType, Inputs...> {
std::optional<core::TaskGraphImpl> optional_graph
= core::TaskGraphImpl::bind(task, std::forward<Inputs>(inputs)...);
if (!optional_graph.has_value()) {
throw std::invalid_argument("Failed to bind inputs to task.");
}
std::unique_ptr<core::TaskGraphImpl> graph
= std::make_unique<core::TaskGraphImpl>(std::move(optional_graph.value()));

return TaskGraphType<ReturnType, Inputs...>{std::move(graph)};
}

/**
* Starts running a task with the given inputs on Spider.
Expand Down Expand Up @@ -141,6 +158,12 @@ class Driver {
* @throw spider::ConnectionException
*/
auto get_jobs() -> std::vector<boost::uuids::uuid>;

private:
boost::uuids::uuid m_id;
std::shared_ptr<core::MetadataStorage> m_metadata_storage;
std::shared_ptr<core::DataStorage> m_data_storage;
std::jthread m_heartbeat_thread;
};
} // namespace spider

Expand Down
23 changes: 22 additions & 1 deletion src/spider/client/TaskGraph.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
#ifndef SPIDER_CLIENT_TASKGRAPH_HPP
#define SPIDER_CLIENT_TASKGRAPH_HPP

#include <memory>
#include <utility>

#include "task.hpp"

namespace spider {
namespace core {
class TaskGraphImpl;
} // namespace core

class Driver;
class TaskContext;

/**
 * A TaskGraph represents a directed acyclic graph (DAG) of tasks.
 *
 * The graph holds its implementation object through a `std::unique_ptr`. Both the constructor
 * and `get_impl` are private, so only the friend classes declared at the end of the class can
 * create a graph or reach its implementation.
 *
 * @tparam ReturnType
 * @tparam Params
 */
template <TaskIo ReturnType, TaskIo... Params>
class TaskGraph {};
class TaskGraph {
private:
/**
 * Wraps an existing graph implementation and takes ownership of it.
 *
 * @param impl Implementation object; moved into `m_impl`.
 */
explicit TaskGraph(std::unique_ptr<core::TaskGraphImpl> impl) : m_impl{std::move(impl)} {}

/**
 * @return A const reference to the owned implementation. `m_impl` is dereferenced without a
 * null check.
 */
[[nodiscard]] auto get_impl() const -> core::TaskGraphImpl const& { return *m_impl; }

// Owned graph implementation, set by the constructor.
std::unique_ptr<core::TaskGraphImpl> m_impl;

// Friends are granted access to the private constructor and `get_impl`.
friend class core::TaskGraphImpl;
friend class Driver;
friend class TaskContext;
};
} // namespace spider

#endif // SPIDER_CLIENT_TASKGRAPH_HPP
Loading
Loading