Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the C++ interface easier to use #185

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/website/pages/docs/loading.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ int main()
// Create an input tensor
uint64_t shape[]{1};
auto tensor = carton::Tensor(carton::DataType::kString, shape);
tensor.set_string(0, "Today is a good [MASK].");
tensor.at<std::string_view>(0) = "Today is a good [MASK].";

// Create a map of inputs
std::unordered_map<std::string, carton::Tensor> inputs;
Expand All @@ -322,7 +322,7 @@ int main()

const auto scores_data = static_cast<const float *>(scores.data());

std::cout << "Got output token: " << tokens.get_string(0) << std::endl;
std::cout << "Got output token: " << tokens.at<std::string_view>(0) << std::endl;
std::cout << "Got output scores: " << scores_data[0] << std::endl;
}
```
Expand Down
4 changes: 2 additions & 2 deletions docs/website/pages/quickstart.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ int main()
// Create an input tensor
uint64_t shape[]{1};
auto tensor = carton::Tensor(carton::DataType::kString, shape);
tensor.set_string(0, "Today is a good [MASK].");
tensor.at<std::string_view>(0) = "Today is a good [MASK].";

// Create a map of inputs
std::unordered_map<std::string, carton::Tensor> inputs;
Expand All @@ -78,7 +78,7 @@ int main()

const auto scores_data = static_cast<const float *>(scores.data());

std::cout << "Got output token: " << tokens.get_string(0) << std::endl;
std::cout << "Got output token: " << tokens.at<std::string_view>(0) << std::endl;
std::cout << "Got output scores: " << scores_data[0] << std::endl;
}
```
Expand Down
20 changes: 16 additions & 4 deletions source/carton-bindings-c/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@ impl CartonTensor {
}
}

/// For a string tensor, get a string at a particular (flattened) index into the tensor.
/// For a string tensor, get a string at a particular flattened index into the tensor.
/// Note: any returned pointers are only valid until the tensor is modified.
/// Note: `index` should take strides into account.
#[no_mangle]
pub extern "C" fn carton_tensor_get_string(
&self,
Expand All @@ -196,7 +197,10 @@ impl CartonTensor {
) {
if let carton_core::types::Tensor::String(v) = &self.inner {
let view = v.view();
let item = view.iter().nth(index as _).unwrap();
let ptr = view.as_ptr();

// TODO: assert that the index is in bounds
let item = unsafe { &*ptr.add(index as _) };
unsafe {
*string_out = item.as_ptr() as *const _;
*strlen_out = item.len() as _;
Expand All @@ -206,13 +210,18 @@ impl CartonTensor {
}
}

/// For a string tensor, set a string at a particular (flattened) index.
/// For a string tensor, set a string at a particular flattened index.
/// Copies the null-terminated string `string` into the tensor at the specified index.
/// Note: `index` should take strides into account.
#[no_mangle]
pub extern "C" fn carton_tensor_set_string(&mut self, index: u64, string: *const c_char) {
    // SAFETY: relies on the caller passing a valid pointer to a null-terminated string
    // that stays alive for the duration of this call.
    let c_string = unsafe { CStr::from_ptr(string) };

    // Take an owned copy so the tensor does not borrow the caller's buffer.
    // Panics if the bytes are not valid UTF-8.
    let owned = String::from(c_string.to_str().unwrap());

    self.carton_tensor_set_string_inner(index, owned);
}

/// For a string tensor, set a string at a particular flattened index.
/// Copies `strlen` bytes of `string` into the tensor at the specified index.
/// Note: `index` should take strides into account.
#[no_mangle]
pub extern "C" fn carton_tensor_set_string_with_strlen(
&mut self,
Expand All @@ -232,7 +241,10 @@ impl CartonTensor {
fn carton_tensor_set_string_inner(&mut self, index: u64, string: String) {
if let carton_core::types::Tensor::String(v) = &mut self.inner {
let mut view = v.view_mut();
let item = view.iter_mut().nth(index as _).unwrap();
let ptr = view.as_mut_ptr();

// TODO: assert that the index is in bounds
let item = unsafe { &mut *ptr.add(index as _) };
*item = string;
} else {
panic!("Tried to call `set_string` on a non-string tensor")
Expand Down
71 changes: 63 additions & 8 deletions source/carton-bindings-cpp/src/carton.hh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ namespace carton

friend class TensorMap;

template <typename T>
friend class TensorStringValue;

// For a string tensor, set a string at a particular flattened index
// This will copy data from the provided string_view.
// Note: `index` should take strides into account.
void set_string(uint64_t index, std::string_view string);

// For a string tensor, get a string at a particular flattened index
// Note: the returned view is only valid until the tensor is modified.
// Note: `index` should take strides into account.
std::string_view get_string(uint64_t index) const;

public:
// Create a tensor with dtype `dtype` and shape `shape`
Tensor(DataType dtype, std::span<uint64_t> shape);
Expand Down Expand Up @@ -131,15 +144,57 @@ namespace carton
// Note: the returned span is only valid while this Tensor is in scope
std::span<const int64_t> strides() const;

// For a string tensor, set a string at a particular (flattened) index
// This will copy data from the provided string_view.
// TODO: do some template magic to make this easy to use
void set_string(uint64_t index, std::string_view string);
// Using the accessor methods can be faster than `at` when accessing many elements
// because they avoid making function calls on each element access.
// See `TensorAccessor` below for usage.
template <typename T, size_t NumDims>
auto accessor();

// Using the accessor methods can be faster than `at` when accessing many elements
// because they avoid making function calls on each element access.
// See `TensorAccessor` below for usage.
template <typename T, size_t NumDims>
auto accessor() const;

// Get an element at an index
// This is a convenience wrapper that creates an `accessor` and uses it
// Consider explicitly creating an accessor if you need to access many elements
template <typename T, typename... Index>
auto at(Index... index) const;

// Get an element at an index
// This is a convenience wrapper that creates an `accessor` and uses it
// Consider explicitly creating an accessor if you need to access many elements
template <typename T, typename... Index>
auto at(Index... index);
};

// For a string tensor, get a string at a particular (flattened) index
// Note: the returned view is only valid until the tensor is modified.
// TODO: do some template magic to make this easy to use
std::string_view get_string(uint64_t index) const;
// The return type of the `accessor` methods of `Tensor`
//
// Template parameters:
// - `T`: the element type requested by the caller
// - `NumDims`: the number of indices that `operator[]` must be given
// - `DataContainer`: how the underlying data is referenced. `Tensor::accessor` passes a
//   (const) `Tensor &` for string tensors and a (const) `void *` for everything else.
//
// An accessor only refers to the tensor's data and strides; it does not own them, so it
// should not be used after the tensor it was created from goes away.
template <typename T, size_t NumDims, typename DataContainer>
class TensorAccessor
{
private:
// Handle to the tensor's elements (see `DataContainer` above)
DataContainer data_;

// The strides of the tensor
const std::span<const int64_t> strides_;

// Accessors are only created by `Tensor`, hence the private constructor
friend class Tensor;
TensorAccessor(DataContainer data, std::span<const int64_t> strides) : data_(data), strides_(strides) {}

public:
// Return the element at `index`
// One value of `index` must be provided for each dimension.
//
// ```
// auto acc = t.accessor<float, 3>();
// auto val = acc[1, 2, 3];
// ```
//
// NOTE(review): the `acc[1, 2, 3]` spelling needs multi-argument subscript support
// (C++23); `acc.operator[](1, 2, 3)` is the equivalent explicit call.
//
// Note: For string values, the returned view is only valid until the tensor is modified. Users
// should make a copy if they need to both persist the value and modify the tensor.
template <typename... Index>
auto operator[](Index... index) const;
};

template <typename T>
Expand Down
115 changes: 115 additions & 0 deletions source/carton-bindings-cpp/src/carton_impl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,121 @@ namespace carton
}
}

// Proxy for a single element of a string tensor, so that it can be read and assigned
// with ordinary syntax. `TensorAccessor` returns one of these when indexing string tensors.
//
// `T` is the (possibly const-qualified) tensor type. Writes are forwarded to
// `T::set_string` and reads to `T::get_string`.
template <typename T>
class TensorStringValue
{
private:
    // The tensor that holds the string (referenced, not owned)
    T &tensor_;

    // Flattened index of the element within `tensor_`
    uint64_t index_;

public:
    TensorStringValue(T &owner, uint64_t flat_index) : tensor_{owner}, index_{flat_index} {}

    // Assigning a string stores it in the tensor at this element's index
    void operator=(std::string_view val)
    {
        tensor_.set_string(index_, val);
    }

    // Converting to `std::string_view` reads the tensor's current value at this index
    operator std::string_view() const
    {
        return tensor_.get_string(index_);
    }
};

// Stream insertion for `TensorStringValue`: converts the value to a `std::string_view`
// and writes that to `os`.
template <typename T>
std::ostream &operator<<(std::ostream &os, const TensorStringValue<T> &v)
{
    return os << static_cast<std::string_view>(v);
}

// Impl for Tensor
// Using the accessor methods can be faster when accessing many elements because
// they avoid making function calls on each element access
template <typename T, size_t NumDims>
auto Tensor::accessor()
{
// TODO: assert N == ndims
// TODO: assert data type
if constexpr (std::is_same_v<T, std::string_view> || std::is_same_v<T, std::string>)
{
return TensorAccessor<std::string_view, NumDims, Tensor &>(*this, strides());
}
else
{
return TensorAccessor<T, NumDims, void *>(data(), strides());
}
}

// Using the accessor methods can be faster when accessing many elements because
// they avoid making function calls on each element access
template <typename T, size_t NumDims>
auto Tensor::accessor() const
{
// TODO: assert N == ndims
// TODO: assert data type
if constexpr (std::is_same_v<T, std::string_view> || std::is_same_v<T, std::string>)
{
return TensorAccessor<std::string_view, NumDims, const Tensor &>(*this, strides());
}
else
{
return TensorAccessor<T, NumDims, const void *>(data(), strides());
}
}

template <typename T, typename... Index>
auto Tensor::at(Index... index) const
{
constexpr auto N = sizeof...(Index);
auto acc = accessor<T, N>();
return acc.operator[](std::forward<Index>(index)...);
}

template <typename T, typename... Index>
auto Tensor::at(Index... index)
{
constexpr auto N = sizeof...(Index);
auto acc = accessor<T, N>();
return acc.operator[](std::forward<Index>(index)...);
}

// Impl for TensorAccessor
// Return the element at `index`; exactly `NumDims` indices must be supplied.
//
// For string tensors this returns a `TensorStringValue` proxy that reads and writes through
// the tensor. For numeric tensors it returns the element by value.
template <typename T, size_t NumDims, typename DataContainer>
template <typename... Index>
auto TensorAccessor<T, NumDims, DataContainer>::operator[](Index... index) const
{
    static_assert(sizeof...(Index) == NumDims, "Incorrect number of indices");

    // Flattened element offset: the dot product of `index` and `strides_`.
    //
    // This is accumulated with a comma fold because a comma fold evaluates its operands
    // strictly left to right, so the n-th index is always multiplied by the n-th stride. The
    // operands of a `+` fold may be evaluated in any order, which would let the shared
    // stride counter pair an index with the wrong stride. A comma fold is also valid for an
    // empty pack, in which case the offset stays 0.
    int64_t offset = 0;
    [[maybe_unused]] size_t dim = 0;
    ((offset += static_cast<int64_t>(index) * strides_[dim++]), ...);

    if constexpr (std::is_same_v<T, std::string_view> || std::is_same_v<T, std::string>)
    {
        // Handle string tensors separately
        // For convenience, we allow T to be std::string, but we always use `std::string_view`
        // to avoid unnecessary copies.
        return TensorStringValue(data_, static_cast<uint64_t>(offset));
    }
    else
    {
        // Numeric tensors
        static_assert(std::is_arithmetic_v<T>, "accessor() only supports string and numeric tensors");

        // `offset` counts elements (the string path above hands the same value to the
        // tensor as a flattened element index), and indexing a `const T *` already scales
        // by `sizeof(T)`, so it must not be multiplied by the element size again.
        return static_cast<const T *>(data_)[offset];
    }
}

// Impl for AsyncNotifier
template <typename T>
AsyncNotifier<T>::AsyncNotifier() : AsyncNotifierBase() {}
Expand Down
10 changes: 6 additions & 4 deletions source/carton-bindings-cpp/tests/callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ void infer_callback(Result<TensorMap> infer_result, void *arg)
const auto tokens = out.get_and_remove("tokens");
const auto scores = out.get_and_remove("scores");

const auto scores_data = static_cast<const float *>(scores.data());
// Can use a template arg of `std::string` or `std::string_view`
std::cout << "Got output token: " << tokens.at<std::string>(0) << std::endl;
std::cout << "Got output scores: " << scores.at<float>(0) << std::endl;

std::cout << "Got output token: " << tokens.get_string(0) << std::endl;
std::cout << "Got output scores: " << scores_data[0] << std::endl;
assert(tokens.at<std::string_view>(0) == std::string_view("day"));
assert(std::abs(scores.at<float>(0) - 14.5513) < 0.0001);

exit(0);
}
Expand All @@ -47,7 +49,7 @@ void load_callback(Result<Carton> model_result, void *arg)

uint64_t shape[]{1};
auto tensor = Tensor(DataType::kString, shape);
tensor.set_string(0, "Today is a good [MASK].");
tensor.at<std::string_view>(0) = "Today is a good [MASK].";

std::unordered_map<std::string, Tensor> inputs;
inputs.insert(std::make_pair("input", std::move(tensor)));
Expand Down
10 changes: 8 additions & 2 deletions source/carton-bindings-cpp/tests/future.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <iostream>

#include "../src/carton.hh"
Expand All @@ -28,7 +29,9 @@ int main()

uint64_t shape[]{1};
auto tensor = Tensor(DataType::kString, shape);
tensor.set_string(0, "Today is a good [MASK].");

// Can use a template arg of `std::string` or `std::string_view`
tensor.at<std::string>(0) = "Today is a good [MASK].";

std::unordered_map<std::string, Tensor> inputs;
inputs.insert(std::make_pair("input", std::move(tensor)));
Expand All @@ -41,6 +44,9 @@ int main()

const auto scores_data = static_cast<const float *>(scores.data());

std::cout << "Got output token: " << tokens.get_string(0) << std::endl;
std::cout << "Got output token: " << tokens.at<std::string_view>(0) << std::endl;
std::cout << "Got output scores: " << scores_data[0] << std::endl;

assert(tokens.at<std::string_view>(0) == std::string_view("day"));
assert(std::abs(scores_data[0] - 14.5513) < 0.0001);
}
11 changes: 9 additions & 2 deletions source/carton-bindings-cpp/tests/notifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int main()

uint64_t shape[]{1};
auto tensor = Tensor(DataType::kString, shape);
tensor.set_string(0, "Today is a good [MASK].");
tensor.at<std::string_view>(0) = "Today is a good [MASK].";

std::unordered_map<std::string, Tensor> inputs;
inputs.insert(std::make_pair("input", std::move(tensor)));
Expand All @@ -56,6 +56,13 @@ int main()

const auto scores_data = static_cast<const float *>(scores.data());

std::cout << "Got output token: " << tokens.get_string(0) << std::endl;
// If you're accessing a few elements, you can just use `.at`, but we'll use
// an accessor here for testing
const auto token_accessor = tokens.accessor<std::string_view, 1>();

std::cout << "Got output token: " << token_accessor[0] << std::endl;
std::cout << "Got output scores: " << scores_data[0] << std::endl;

assert(token_accessor[0] == std::string_view("day"));
assert(std::abs(scores_data[0] - 14.5513) < 0.0001);
}