diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c177466..a79d9ad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,8 +2,6 @@ name: CI on: pull_request: push: - branches: - - main jobs: main: @@ -33,6 +31,6 @@ jobs: - run: mix deps.unlock --check-unused if: ${{ matrix.lint }} - run: mix deps.compile - # - run: mix compile --warnings-as-errors - # if: ${{ matrix.lint }} + - run: mix compile --warnings-as-errors + if: ${{ matrix.lint }} - run: mix test diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index a998f25..4e68eba 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -118,9 +118,31 @@ defmodule Candlex.Backend do # Aggregates @impl true - def all(%T{} = out, %T{} = tensor, _opts) do - from_nx(tensor) - |> Native.all() + def all(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.all() + + axes -> + from_nx(tensor) + |> Native.all_within_dims(axes, opts[:keep_axes]) + end + |> unwrap!() + |> to_nx(out) + end + + @impl true + def any(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.any() + + axes -> + from_nx(tensor) + |> Native.any_within_dims(axes, opts[:keep_axes]) + end |> unwrap!() |> to_nx(out) end @@ -506,6 +528,25 @@ defmodule Candlex.Backend do end @impl true + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + [left_axis] = _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + [0] = _right_axes, + [] = _right_batched_axes + ) + when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and + left_axis == tuple_size(left_shape) - 1 do + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.dot(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + def dot( %T{type: _out_type} = out, %T{shape: left_shape, type: _left_type} = left, @@ -516,6 +557,8 @@ defmodule Candlex.Backend do [] = _right_batched_axes ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do + {left, right} = maybe_upcast(left, right) + Native.matmul( from_nx(left), from_nx(right) @@ -827,7 +870,6 @@ defmodule Candlex.Backend do end for op <- [ - :any, :argsort, :eigh, :fft, diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index d5c2326..fc045cc 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -28,6 +28,9 @@ defmodule Candlex.Native do def from_binary(_binary, _dtype, _shape, _device), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() + def all_within_dims(_tensor, _dims, _keep_dims), do: error() + def any(_tensor), do: error() + def any_within_dims(_tensor, _dims, _keep_dims), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def gather(_tensor, _indexes, _dim), do: error() @@ -92,6 +95,7 @@ defmodule Candlex.Native do :bitwise_or, :bitwise_xor, :divide, + :dot, :equal, :greater, :greater_equal, diff --git a/mix.exs b/mix.exs index d758ee5..89774a2 100644 --- a/mix.exs +++ b/mix.exs @@ -32,11 +32,12 @@ defmodule Candlex.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:nx, "~> 0.6.2"}, + # {:nx, "~> 0.6.2"}, + {:nx, git: "https://github.com/elixir-nx/nx", sparse: "nx"}, {:rustler_precompiled, "~> 0.7.0"}, # Optional - {:rustler, "~> 0.30.0", optional: true}, + {:rustler, "~> 0.29", optional: true}, # Dev {:ex_doc, "~> 0.30.9", only: :dev, runtime: false} diff --git a/mix.lock b/mix.lock index aff8a25..3747805 100644 --- a/mix.lock +++ b/mix.lock @@ -8,7 +8,7 @@ "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:hex, :nx, "0.6.2", "f1d137f477b1a6f84f8db638f7a6d5a0f8266caea63c9918aa4583db38ebe1d6", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ac913b68d53f25f6eb39bddcf2d2cd6ea2e9bcb6f25cf86a79e35d0411ba96ad"}, + "nx": {:git, "https://github.com/elixir-nx/nx", "7706e8601e40916c02f8773df7802b3bfab43054", [sparse: "nx"]}, "rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index 8d2ae71..ef72f43 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -42,6 +42,9 @@ rustler::init! { tensors::less, tensors::less_equal, tensors::all, + tensors::all_within_dims, + tensors::any, + tensors::any_within_dims, tensors::sum, tensors::dtype, tensors::t_shape, @@ -68,6 +71,7 @@ rustler::init! { tensors::permute, tensors::slice_scatter, tensors::pad_with_zeros, + tensors::dot, tensors::matmul, tensors::abs, tensors::acos, diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 03b8a71..78c94b4 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -154,22 +154,38 @@ pub fn arange( #[rustler::nif(schedule = "DirtyCpu")] pub fn all(ex_tensor: ExTensor) -> Result { - let device = ex_tensor.device(); - let t = ex_tensor.flatten_all()?; - let dims = t.shape().dims(); - let on_true = Tensor::ones(dims, DType::U8, device)?; - let on_false = Tensor::zeros(dims, DType::U8, device)?; - - let bool_scalar = match t - .where_cond(&on_true, &on_false)? - .min(0)? - .to_scalar::()? - { - 0 => 0u8, - _ => 1u8, - }; + Ok(ExTensor::new(_all( + &ex_tensor.flatten_all()?, + vec![0], + false, + )?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn all_within_dims( + ex_tensor: ExTensor, + dims: Vec, + keep_dims: bool, +) -> Result { + Ok(ExTensor::new(_all(ex_tensor.deref(), dims, keep_dims)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn any(ex_tensor: ExTensor) -> Result { + Ok(ExTensor::new(_any( + &ex_tensor.flatten_all()?, + vec![0], + false, + )?)) +} - Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) +#[rustler::nif(schedule = "DirtyCpu")] +pub fn any_within_dims( + ex_tensor: ExTensor, + dims: Vec, + keep_dims: bool, +) -> Result { + Ok(ExTensor::new(_any(ex_tensor.deref(), dims, keep_dims)?)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -346,6 +362,14 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result )) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn dot(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new( + left.mul(&right.broadcast_as(left.shape())?)? + .sum(left.rank() - 1)?, + )) +} + macro_rules! unary_nif { ($nif_name:ident, $native_fn_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] @@ -446,6 +470,38 @@ custom_binary_nif!(pow, Pow); custom_binary_nif!(right_shift, Shr); custom_binary_nif!(remainder, Remainder); +fn _any(tensor: &Tensor, dims: Vec, keep_dims: bool) -> Result { + let comparison = tensor.ne(&tensor.zeros_like()?)?; + + let result = if keep_dims { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.max_keepdim(*dim).unwrap()) + } else { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.max(*dim).unwrap()) + }; + + Ok(result) +} + +fn _all(tensor: &Tensor, dims: Vec, keep_dims: bool) -> Result { + let comparison = tensor.ne(&tensor.zeros_like()?)?; + + let result = if keep_dims { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.min_keepdim(*dim).unwrap()) + } else { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.min(*dim).unwrap()) + }; + + Ok(result) +} + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { rustler::types::tuple::get_tuple(term)? .iter() diff --git a/test/candlex_test.exs b/test/candlex_test.exs index 111e86f..5899143 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -511,14 +511,13 @@ defmodule CandlexTest do # Dot product of vectors - # TODO: - # t([1, 2, 3]) - # |> Nx.dot(t([4, 5, 6])) - # |> assert_equal(t(32)) + t([1, 2, 3]) + |> Nx.dot(t([4, 5, 6])) + |> assert_equal(t(32)) - # t([1.0, 2.0, 3.0]) - # |> Nx.dot(t([1, 2, 3])) - # |> assert_equal(t(14.0)) + t([1.0, 2, 3]) + |> Nx.dot(t([1, 2, 3])) + |> assert_equal(t(14.0)) # Dot product of matrices (2-D tensors) @@ -533,7 +532,7 @@ defmodule CandlexTest do # )) t([[1.0, 2, 3], [4, 5, 6]]) - |> Nx.dot(t([[7.0, 8], [9, 10], [11, 12]])) + |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) |> assert_equal( t([ [58.0, 64], @@ -543,14 +542,18 @@ defmodule CandlexTest do # Dot product of vector and n-D tensor - # t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k]) - # |> Nx.dot(t([5.0, 10], names: [:x])) - # |> assert_equal(t( - # [ - # [25, 55], - # [85, 115] - # ] - # )) + t([[0.0]]) + |> Nx.dot(t([55.0])) + |> assert_equal(t([0.0])) + + t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) + |> Nx.dot(t([5, 10])) + |> assert_equal( + t([ + [25.0, 55], + [85, 115] + ]) + ) # t([5.0, 10], names: [:x]) # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j])) @@ -2143,6 +2146,109 @@ defmodule CandlexTest do ) end + test "all" do + t(0) + |> Nx.all() + |> assert_equal(t(0)) + + t(10) + |> Nx.all() + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.all() + |> assert_equal(t(0)) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:x]) + |> assert_equal(t([1, 0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y]) + |> assert_equal(t([0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y], keep_axes: true) + |> assert_equal( + t([ + [0], + [1] + ]) + ) + + tensor = Nx.tensor([[[1, 2], [0, 4]], [[5, 6], [7, 8]]], names: [:x, :y, :z]) + + tensor + |> Nx.all(axes: [:x, :y]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:y, :z]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:x, :z]) + |> assert_equal(t([1, 0])) + + tensor + |> Nx.all(axes: [:x, :y], keep_axes: true) + |> assert_equal( + t([ + [ + [0, 1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:y, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [0] + ], + [ + [1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:x, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [1], + [0] + ] + ]) + ) + end + + test "any" do + t([0, 1, 2]) + |> Nx.any() + |> assert_equal(t(1)) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:x]) + |> assert_equal(t([0, 1, 1])) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:y]) + |> assert_equal(t([1, 1])) + + tensor = t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + + tensor + |> Nx.any(axes: [:x], keep_axes: true) + |> assert_equal(t([[0, 1, 1]])) + + tensor + |> Nx.any(axes: [:y], keep_axes: true) + |> assert_equal(t([[1], [1]])) + end + if Candlex.Backend.cuda_available?() do test "different devices" do t([1, 2, 3], backend: {Candlex.Backend, device: :cpu})