Skip to content

Commit

Permalink
feat: dot/2 supports receiving 1-D tensors (vectors) (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy authored Nov 2, 2023
1 parent ba8fa30 commit e4d9621
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 15 deletions.
19 changes: 19 additions & 0 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,25 @@ defmodule Candlex.Backend do
end

@impl true
def dot(
%T{type: _out_type} = out,
%T{shape: left_shape, type: _left_type} = left,
[left_axis] = _left_axes,
[] = _left_batched_axes,
%T{shape: right_shape, type: _right_type} = right,
[0] = _right_axes,
[] = _right_batched_axes
)
when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and
left_axis == tuple_size(left_shape) - 1 do
{left, right} = maybe_upcast(left, right)

from_nx(left)
|> Native.dot(from_nx(right))
|> unwrap!()
|> to_nx(out)
end

def dot(
%T{type: _out_type} = out,
%T{shape: left_shape, type: _left_type} = left,
Expand Down
1 change: 1 addition & 0 deletions lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ defmodule Candlex.Native do
:bitwise_or,
:bitwise_xor,
:divide,
:dot,
:equal,
:greater,
:greater_equal,
Expand Down
1 change: 1 addition & 0 deletions native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ rustler::init! {
tensors::permute,
tensors::slice_scatter,
tensors::pad_with_zeros,
tensors::dot,
tensors::matmul,
tensors::abs,
tensors::acos,
Expand Down
8 changes: 8 additions & 0 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError>
))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn dot(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(
left.mul(&right.broadcast_as(left.shape())?)?
.sum(left.rank() - 1)?,
))
}

macro_rules! unary_nif {
($nif_name:ident, $native_fn_name:ident) => {
#[rustler::nif(schedule = "DirtyCpu")]
Expand Down
33 changes: 18 additions & 15 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -511,14 +511,13 @@ defmodule CandlexTest do

# Dot product of vectors

# TODO:
# t([1, 2, 3])
# |> Nx.dot(t([4, 5, 6]))
# |> assert_equal(t(32))
t([1, 2, 3])
|> Nx.dot(t([4, 5, 6]))
|> assert_equal(t(32))

# t([1.0, 2.0, 3.0])
# |> Nx.dot(t([1, 2, 3]))
# |> assert_equal(t(14.0))
t([1.0, 2, 3])
|> Nx.dot(t([1, 2, 3]))
|> assert_equal(t(14.0))

# Dot product of matrices (2-D tensors)

Expand All @@ -543,14 +542,18 @@ defmodule CandlexTest do

# Dot product of vector and n-D tensor

# t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k])
# |> Nx.dot(t([5.0, 10], names: [:x]))
# |> assert_equal(t(
# [
# [25, 55],
# [85, 115]
# ]
# ))
t([[0.0]])
|> Nx.dot(t([55.0]))
|> assert_equal(t([0.0]))

t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]])
|> Nx.dot(t([5, 10]))
|> assert_equal(
t([
[25.0, 55],
[85, 115]
])
)

# t([5.0, 10], names: [:x])
# |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j]))
Expand Down

0 comments on commit e4d9621

Please sign in to comment.