From f78f6d4b19dc828adc25603c80abeb677b6cabcb Mon Sep 17 00:00:00 2001 From: aleinin <95333017+abeleinin@users.noreply.github.com> Date: Wed, 2 Oct 2024 23:32:49 -0500 Subject: [PATCH 1/4] linalg solve backend --- docs/src/python/linalg.rst | 5 +- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/solve.cpp | 131 ++++++++++++++++++++++ mlx/backend/metal/primitives.cpp | 6 + mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/linalg.cpp | 39 +++++++ mlx/linalg.h | 2 + mlx/primitives.cpp | 11 ++ mlx/primitives.h | 17 ++- python/src/linalg.cpp | 20 +++- python/tests/test_linalg.py | 54 +++++++++ tests/linalg_tests.cpp | 50 +++++++++ 15 files changed, 336 insertions(+), 4 deletions(-) create mode 100644 mlx/backend/common/solve.cpp diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index f6c51ed0b..853ad393b 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -5,8 +5,8 @@ Linear Algebra .. currentmodule:: mlx.core.linalg -.. autosummary:: - :toctree: _autosummary +.. autosummary:: + :toctree: _autosummary inv tri_inv @@ -18,3 +18,4 @@ Linear Algebra svd eigvalsh eigh + solve diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 1f80224ad..69f7eadca 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -82,6 +82,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) +DEFAULT_MULTI(Solve) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 4fca2274e..7a77f82e6 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -52,6 +52,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 547d8e25d..26c745537 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -111,6 +111,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) +DEFAULT_MULTI(Solve) namespace { diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp new file mode 100644 index 000000000..be0f190d8 --- /dev/null +++ b/mlx/backend/common/solve.cpp @@ -0,0 +1,131 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack_helper.h" +#include "mlx/primitives.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +#include + +namespace mlx::core { + +namespace { + +// Wrapper to account for differences in +// LAPACK implementations (basically how to pass the 'trans' string to fortran). 
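+// When LAPACK_FORTRAN_STRLEN_END is defined, the Fortran interface expects a
+// hidden string-length argument for each character parameter, so an explicit
+// trans_len is passed below; otherwise that argument is omitted.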
+int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) {
+  int info;
+
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  sgetrs_(
+      /* trans */ &trans,
+      /* n */ &N,
+      /* nrhs */ &NRHS,
+      /* a */ a,
+      /* lda */ &N,
+      /* ipiv */ ipiv,
+      /* b */ b,
+      /* ldb */ &N,
+      /* info */ &info,
+      /* trans_len = */ static_cast<size_t>(1));
+#else
+  sgetrs_(
+      /* trans */ &trans,
+      /* n */ &N,
+      /* nrhs */ &NRHS,
+      /* a */ a,
+      /* lda */ &N,
+      /* ipiv */ ipiv,
+      /* b */ b,
+      /* ldb */ &N,
+      /* info */ &info);
+#endif
+
+  return info;
+}
+
+} // namespace
+
+void solve_impl(const array& a, const array& b, array& out) {
+  int N = a.shape(-2);
+  int NRHS = out.shape(-1);
+  std::vector<int> ipiv(N);
+
+  // copy b into out and make it col-contiguous
+  auto flags = out.flags();
+  flags.col_contiguous = true;
+  flags.row_contiguous = false;
+  std::vector<size_t> strides(a.ndim(), 0);
+  std::copy(out.strides().begin(), out.strides().end(), strides.begin());
+  strides[a.ndim() - 2] = 1;
+  strides[a.ndim() - 1] = N;
+
+  out.set_data(
+      allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags);
+  copy_inplace(b, out, CopyType::GeneralGeneral);
+
+  // lapack clobbers the input, so we have to make a copy. the copy doesn't need
+  // to be col-contiguous because sgetrs has a transpose parameter (trans='T').
+  array a_cpy(a.shape(), float32, nullptr, {});
+  copy(
+      a,
+      a_cpy,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  float* a_ptr = a_cpy.data<float>();
+  float* out_ptr = out.data<float>();
+  int* ipiv_ptr = ipiv.data();
+
+  int info;
+  size_t num_matrices = a.size() / (N * N);
+  for (size_t i = 0; i < num_matrices; i++) {
+    // Compute LU factorization of A
+    MLX_LAPACK_FUNC(sgetrf)
+    (/* m */ &N,
+     /* n */ &N,
+     /* a */ a_ptr,
+     /* lda */ &N,
+     /* ipiv */ ipiv_ptr,
+     /* info */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "solve_impl: sgetrf_ failed with code " << info
+         << ((info > 0) ?
" because matrix is singular" + : " becuase argument had an illegal value"); + throw std::runtime_error(ss.str()); + } + + static constexpr char trans = 'T'; + // Solve the system using the LU factors from sgetrf + info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr); + + if (info != 0) { + std::stringstream ss; + ss << "solve_impl: sgetrs_ failed with code " << info; + throw std::runtime_error(ss.str()); + } + + // Advance pointers to the next matrix + a_ptr += N * N; + out_ptr += N * NRHS; + } +} + +void Solve::eval( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 2); + if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) { + throw std::runtime_error("[Solve::eval] only supports float32."); + } + solve_impl(inputs[0], inputs[1], outputs[0]); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index e5a7d885b..35586a257 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -438,4 +438,10 @@ void View::eval_gpu(const std::vector& inputs, array& out) { } } +void Solve::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI."); +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index c87fcc8bb..db9f40013 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -111,5 +111,6 @@ NO_CPU(Tanh) NO_CPU(Transpose) NO_CPU(Inverse) NO_CPU(View) +NO_CPU_MULTI(Solve) } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index aaee51d83..f8556a76d 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -114,6 +114,7 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) NO_GPU(View) +NO_GPU_MULTI(Solve) namespace fast { NO_GPU_MULTI(LayerNorm) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index daf5573fc..74c052366 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -500,4 +500,43 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } +array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (a.dtype() != float32 && b.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::solve] Input array must have type float32. Received arrays " + << "with type " << a.dtype() << " and " << b.dtype() << "."; + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::solve] First input must have >= 2 dimensions. " + << "Received array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (b.ndim() < 1) { + std::ostringstream msg; + msg << "[linalg::solve] Second input must have >= 1 dimensions. " + << "Received array with " << b.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + std::ostringstream msg; + msg << "[linalg::solve] First input must be a square matrix. 
" + << "Received array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != b.shape(b.ndim() - 2)) { + std::ostringstream msg; + msg << "[linalg::solve] Last dimension of first input with shape " + << a.shape() << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + return array( + b.shape(), out_type, std::make_shared(to_stream(s)), {a, b}); +} + } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index 4ea81bef0..bca1d56af 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -74,6 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {}); array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); +array solve(const array& a, const array& b, StreamOrDevice s = {}); + /** * Compute the cross product of two arrays along the given axis. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c9f839d4b..9ea5cc513 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4201,6 +4201,17 @@ std::pair, std::vector> SVD::vmap( return {{linalg::svd(a, stream())}, {ax, ax, ax}}; } +std::pair, std::vector> Solve::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto maybe_move_ax = [this](auto& arr, auto ax) { + return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr; + }; + auto a = maybe_move_ax(inputs[0], axes[0]); + auto b = maybe_move_ax(inputs[1], axes[1]); + return {{linalg::solve(a, b, stream())}, {0}}; +} + std::pair, std::vector> Inverse::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index f2b5bab7c..f15d66f40 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2202,7 +2202,6 @@ class Eigh : public Primitive { : Primitive(stream), uplo_(std::move(uplo)), compute_eigenvectors_(compute_eigenvectors) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -2236,4 +2235,20 @@ class Eigh : public Primitive { bool compute_eigenvectors_; }; +class Solve : public Primitive { + public: + explicit Solve(Stream stream) : Primitive(stream) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_PRINT(Solve) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + } // namespace mlx::core diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index e2c3aea23..0510f0f69 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -443,7 +443,6 @@ void init_linalg(nb::module_& parent_module) { m.def( "eigh", [](const array& a, const std::string UPLO, StreamOrDevice s) { - // TODO avoid cast? auto result = eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, @@ -486,4 +485,23 @@ void init_linalg(nb::module_& parent_module) { array([[ 0.707107, -0.707107], [ 0.707107, 0.707107]], dtype=float32) )pbdoc"); + m.def( + "solve", + &solve, + "a"_a, + "b"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the solution to a system of linear equations ax = b. + + Args: + a (array): Input array. + b (array): Input array. + + Returns: + array: The unique solution to the system ax = b. 
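+
+        Examples:
+            A small sketch of the intended usage (the CPU stream is used
+            since the Metal kernel is not yet implemented):
+
+            >>> a = mx.array([[3.0, 1.0], [1.0, 2.0]])
+            >>> b = mx.array([9.0, 8.0])
+            >>> x = mx.linalg.solve(a, b, stream=mx.cpu)
+            >>> mx.allclose(a @ x, b)
+            array(True, dtype=bool)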
+ )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 695d7704f..f81186818 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -319,6 +319,60 @@ def check_eigs_and_vecs(A_np, kwargs={}): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix + def test_solve(self): + mx.random.seed(7) + + # Test 3x3 matrix with 1D rhs + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + b = mx.array([11.0, 35.0, 28.0]) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test symmetric positive-definite matrix + N = 5 + a = mx.random.uniform(shape=(N, N)) + a = mx.matmul(a, a.T) + N * mx.eye(N) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test batch dimension + a = mx.random.uniform(shape=(5, 5, 4, 4)) + b = mx.random.uniform(shape=(5, 5, 4, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test large matrix + N = 1000 + a = mx.random.uniform(shape=(N, N)) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-2)) + + # Test multi-column rhs + a = mx.random.uniform(shape=(5, 5)) + b = mx.random.uniform(shape=(5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test batched multi-column rhs + a = mx.concat([a, a, a, a, a, a]).reshape((3, 2, 5, 5)) + b = mx.concat([b, b, b, b, b, b]).reshape((3, 2, 5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index f0b34cc01..a8b03ee14 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -473,3 +473,53 @@ TEST_CASE("test matrix eigh") { // Verify eigendecomposition CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item()); } + +TEST_CASE("test solve") { + // 0D and 1D throw + CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu)); + CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu)); + + // Unsupported types throw + CHECK_THROWS( + linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu)); + + // Non-square throws + array a = reshape(arange(6), {3, 2}); + array b = reshape(arange(3), {3, 1}); + CHECK_THROWS(linalg::solve(a, b, Device::cpu)); + + // Test 2x2 matrix with 1D rhs + a = array({2., 1., 1., 3.}, {2, 2}); + b = array({8., 13.}, {2}); + + array result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test 3x3 matrix + a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3}); + b = array({6., 15., 25.}, {3, 1}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test batch dimension + a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3}); + b = reshape(concatenate({b, b, b, b, b}), {5, 3, 1}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test multi-column rhs + a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, 
{3, 3}); + b = array({4., 2., 5., 3., 6., 1.}, {3, 2}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test batch multi-column rhs + a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3}); + b = reshape(concatenate({b, b, b, b, b}), {5, 3, 2}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); +} From ff4fb42a8f23aa6329d03f520cbcabf86a86e532 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 25 Oct 2024 12:21:35 -0700 Subject: [PATCH 2/4] nits --- mlx/backend/common/solve.cpp | 82 +++++++++++------------------------- mlx/linalg.cpp | 24 ++++++----- 2 files changed, 38 insertions(+), 68 deletions(-) diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp index be0f190d8..74e28aad2 100644 --- a/mlx/backend/common/solve.cpp +++ b/mlx/backend/common/solve.cpp @@ -1,57 +1,14 @@ // Copyright © 2024 Apple Inc. +#include + #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack_helper.h" +#include "mlx/backend/common/lapack.h" #include "mlx/primitives.h" -#ifdef ACCELERATE_NEW_LAPACK -#include -#else -#include -#endif - -#include - namespace mlx::core { -namespace { - -// Wrapper to account for differences in -// LAPACK implementations (basically how to pass the 'trans' string to fortran). -int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) { - int info; - -#ifdef LAPACK_FORTRAN_STRLEN_END - sgetrs_( - /* trans */ &trans, - /* n */ &N, - /* nrhs */ &NRHS, - /* a */ a, - /* lda */ &N, - /* ipiv */ ipiv, - /* b */ b, - /* ldb */ &N, - /* info */ &info, - /* trans_len = */ static_cast(1)); -#else - sgetrs_( - /* trans */ &trans, - /* n */ &N, - /* nrhs */ &NRHS, - /* a */ a, - /* lda */ &N, - /* ipiv */ ipiv, - /* b */ b, - /* ldb */ &N, - /* info */ &info); -#endif - - return info; -} - -} // namespace - void solve_impl(const array& a, const array& b, array& out) { int N = a.shape(-2); int NRHS = out.shape(-1); @@ -59,12 +16,14 @@ void solve_impl(const array& a, const array& b, array& out) { // copy b into out and make it col-contiguous auto flags = out.flags(); - flags.col_contiguous = true; + auto ndim = b.ndim(); + flags.col_contiguous = ndim <= 2; flags.row_contiguous = false; - std::vector strides(a.ndim(), 0); - std::copy(out.strides().begin(), out.strides().end(), strides.begin()); - strides[a.ndim() - 2] = 1; - strides[a.ndim() - 1] = N; + flags.contiguous = true; + auto strides = out.strides(); + if (ndim >= 2) { + std::swap(strides[ndim - 1], strides[ndim - 2]); + } out.set_data( allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags); @@ -96,19 +55,29 @@ void solve_impl(const array& a, const array& b, array& out) { if (info != 0) { std::stringstream ss; - ss << "solve_impl: sgetrf_ failed with code " << info + ss << "[Solve::eval_cpu] sgetrf_ failed with code " << info << ((info > 0) ? 
" because matrix is singular" : " becuase argument had an illegal value"); throw std::runtime_error(ss.str()); } - static constexpr char trans = 'T'; // Solve the system using the LU factors from sgetrf - info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr); + static constexpr char trans = 'T'; + MLX_LAPACK_FUNC(sgetrs) + ( + /* trans */ &trans, + /* n */ &N, + /* nrhs */ &NRHS, + /* a */ a_ptr, + /* lda */ &N, + /* ipiv */ ipiv_ptr, + /* b */ out_ptr, + /* ldb */ &N, + /* info */ &info); if (info != 0) { std::stringstream ss; - ss << "solve_impl: sgetrs_ failed with code " << info; + ss << "[Solve::eval_cpu] sgetrs_ failed with code " << info; throw std::runtime_error(ss.str()); } @@ -122,9 +91,6 @@ void Solve::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() == 2); - if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) { - throw std::runtime_error("[Solve::eval] only supports float32."); - } solve_impl(inputs[0], inputs[1], outputs[0]); } diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 74c052366..92f70e927 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -501,30 +501,24 @@ std::pair eigh( } array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { - if (a.dtype() != float32 && b.dtype() != float32) { - std::ostringstream msg; - msg << "[linalg::solve] Input array must have type float32. Received arrays " - << "with type " << a.dtype() << " and " << b.dtype() << "."; - } - if (a.ndim() < 2) { std::ostringstream msg; msg << "[linalg::solve] First input must have >= 2 dimensions. " - << "Received array with " << a.ndim() << " dimensions."; + << "Received array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } if (b.ndim() < 1) { std::ostringstream msg; msg << "[linalg::solve] Second input must have >= 1 dimensions. " - << "Received array with " << b.ndim() << " dimensions."; + << "Received array with " << b.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } if (a.shape(-1) != a.shape(-2)) { std::ostringstream msg; msg << "[linalg::solve] First input must be a square matrix. " - << "Received array with shape " << a.shape() << "."; + << "Received array with shape " << a.shape() << "."; throw std::invalid_argument(msg.str()); } @@ -535,8 +529,18 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { << " second input with shape " << b.shape() << "."; throw std::invalid_argument(msg.str()); } + + auto out_type = promote_types(a.dtype(), b.dtype()); + if (out_type != float32) { + std::ostringstream msg; + msg << "[linalg::solve] Input array must have type float32. 
Received arrays " + << "with type " << a.dtype() << " and " << b.dtype() << "."; + } return array( - b.shape(), out_type, std::make_shared(to_stream(s)), {a, b}); + b.shape(), + out_type, + std::make_shared(to_stream(s)), + {astype(a, out_type, s), astype(b, out_type, s)}); } } // namespace mlx::core::linalg From 233258ebbb76224e6f9f5c274dd21f0edac19570 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 25 Oct 2024 20:03:36 -0700 Subject: [PATCH 3/4] more nits + fix --- mlx/backend/common/solve.cpp | 22 +++++++++--------- mlx/linalg.cpp | 44 ++++++++++++++++++++++++++++++------ python/tests/test_linalg.py | 4 ++-- tests/linalg_tests.cpp | 4 ++-- 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp index 74e28aad2..8d81558a7 100644 --- a/mlx/backend/common/solve.cpp +++ b/mlx/backend/common/solve.cpp @@ -12,26 +12,24 @@ namespace mlx::core { void solve_impl(const array& a, const array& b, array& out) { int N = a.shape(-2); int NRHS = out.shape(-1); - std::vector ipiv(N); - // copy b into out and make it col-contiguous + // Copy b into out and make it col contiguous + auto ndim = out.ndim(); auto flags = out.flags(); - auto ndim = b.ndim(); - flags.col_contiguous = ndim <= 2; + flags.col_contiguous = ndim == 2; flags.row_contiguous = false; flags.contiguous = true; auto strides = out.strides(); - if (ndim >= 2) { - std::swap(strides[ndim - 1], strides[ndim - 2]); - } - + strides[ndim - 1] = N; + strides[ndim - 2] = 1; out.set_data( allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags); - copy_inplace(b, out, CopyType::GeneralGeneral); + copy_inplace( + b, out, b.shape(), b.strides(), strides, 0, 0, CopyType::GeneralGeneral); // lapack clobbers the input, so we have to make a copy. the copy doesn't need // to be col-contiguous because sgetrs has a transpose parameter (trans='T'). - array a_cpy(a.shape(), float32, nullptr, {}); + array a_cpy(a.shape(), a.dtype(), nullptr, {}); copy( a, a_cpy, @@ -39,7 +37,9 @@ void solve_impl(const array& a, const array& b, array& out) { float* a_ptr = a_cpy.data(); float* out_ptr = out.data(); - int* ipiv_ptr = ipiv.data(); + + auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; + int* ipiv_ptr = static_cast(ipiv.buffer.raw_ptr()); int info; size_t num_matrices = a.size() / (N * N); diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 92f70e927..bcef5b28c 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -500,21 +500,28 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } -array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { - if (a.ndim() < 2) { +array solve(const array& in_a, const array& in_b, StreamOrDevice s /* = {} */) { + if (in_a.ndim() < 2) { std::ostringstream msg; msg << "[linalg::solve] First input must have >= 2 dimensions. " - << "Received array with " << a.ndim() << " dimensions."; + << "Received array with " << in_a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - if (b.ndim() < 1) { + if (in_b.ndim() < 1) { std::ostringstream msg; msg << "[linalg::solve] Second input must have >= 1 dimensions. 
" - << "Received array with " << b.ndim() << " dimensions."; + << "Received array with " << in_b.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } + auto a = in_a; + auto b = in_b; + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = expand_dims(b, -1, s); + } + if (a.shape(-1) != a.shape(-2)) { std::ostringstream msg; msg << "[linalg::solve] First input must be a square matrix. " @@ -522,7 +529,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { throw std::invalid_argument(msg.str()); } - if (a.shape(-1) != b.shape(b.ndim() - 2)) { + if (a.shape(-1) != b.shape(-2)) { std::ostringstream msg; msg << "[linalg::solve] Last dimension of first input with shape " << a.shape() << " must match second to last dimension of" @@ -535,12 +542,35 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { std::ostringstream msg; msg << "[linalg::solve] Input array must have type float32. Received arrays " << "with type " << a.dtype() << " and " << b.dtype() << "."; + throw std::invalid_argument(msg.str()); } - return array( + + // Broadcast leading dimensions + if (a.ndim() > 2 || b.ndim() > 2) { + std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); + std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); + auto inner_shape = broadcast_shapes(bsx_a, bsx_b); + + // Broadcast a + inner_shape.push_back(a.shape(-2)); + inner_shape.push_back(a.shape(-1)); + a = broadcast_to(a, inner_shape, s); + + // Broadcast b + *(inner_shape.end() - 2) = b.shape(-2); + *(inner_shape.end() - 1) = b.shape(-1); + b = broadcast_to(b, inner_shape, s); + } + + auto out = array( b.shape(), out_type, std::make_shared(to_stream(s)), {astype(a, out_type, s), astype(b, out_type, s)}); + if (in_b.ndim() == 1) { + return squeeze(out, -1, s); + } + return out; } } // namespace mlx::core::linalg diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index f81186818..6a98720a2 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -366,8 +366,8 @@ def test_solve(self): self.assertTrue(np.allclose(result, expected, atol=1e-5)) # Test batched multi-column rhs - a = mx.concat([a, a, a, a, a, a]).reshape((3, 2, 5, 5)) - b = mx.concat([b, b, b, b, b, b]).reshape((3, 2, 5, 8)) + a = mx.broadcast_to(a, (3, 2, 5, 5)) + b = mx.broadcast_to(b, (3, 1, 5, 8)) result = mx.linalg.solve(a, b, stream=mx.cpu) expected = np.linalg.solve(a, b) diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index a8b03ee14..4dfe998c8 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -517,8 +517,8 @@ TEST_CASE("test solve") { CHECK(allclose(matmul(a, result), b).item()); // Test batch multi-column rhs - a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3}); - b = reshape(concatenate({b, b, b, b, b}), {5, 3, 2}); + a = broadcast_to(a, {5, 3, 3}); + b = broadcast_to(b, {5, 3, 2}); result = linalg::solve(a, b, Device::cpu); CHECK(allclose(matmul(a, result), b).item()); From 25580fd544aa2def2cc2ff317b1a4d257724c686 Mon Sep 17 00:00:00 2001 From: aleinin <95333017+abeleinin@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:43:16 -0500 Subject: [PATCH 4/4] luf primitive and lu, solve, and solve_triangular backends --- docs/src/python/linalg.rst | 3 + mlx/backend/accelerate/primitives.cpp | 2 +- mlx/backend/common/CMakeLists.txt | 2 +- mlx/backend/common/default_primitives.cpp | 2 +- mlx/backend/common/luf.cpp | 66 ++++++++++ mlx/backend/common/solve.cpp | 97 -------------- mlx/backend/metal/primitives.cpp | 12 +- 
mlx/backend/no_cpu/primitives.cpp | 2 +- mlx/backend/no_metal/primitives.cpp | 2 +- mlx/linalg.cpp | 152 ++++++++++++++++------ mlx/linalg.h | 10 ++ mlx/primitives.cpp | 11 -- mlx/primitives.h | 9 +- python/src/linalg.cpp | 75 ++++++++++- python/tests/test_linalg.py | 72 +++++++++- tests/linalg_tests.cpp | 43 +++++- 16 files changed, 384 insertions(+), 176 deletions(-) create mode 100644 mlx/backend/common/luf.cpp delete mode 100644 mlx/backend/common/solve.cpp diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 853ad393b..769f4bbb1 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -18,4 +18,7 @@ Linear Algebra svd eigvalsh eigh + lu + lu_factor solve + solve_triangular diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 69f7eadca..6d267accd 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -82,7 +82,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) -DEFAULT_MULTI(Solve) +DEFAULT_MULTI(LUF) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 7a77f82e6..a990ff119 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -52,7 +52,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 26c745537..d4d7f7cac 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -111,7 +111,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) -DEFAULT_MULTI(Solve) +DEFAULT_MULTI(LUF) namespace { diff --git a/mlx/backend/common/luf.cpp b/mlx/backend/common/luf.cpp new file mode 100644 index 000000000..82fed15ea --- /dev/null +++ b/mlx/backend/common/luf.cpp @@ -0,0 +1,66 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void lu_factor_impl(const array& a, array& lu, array& pivots) { + int M = a.shape(-2); + int N = a.shape(-1); + + // Copy a into lu and make it col contiguous + auto ndim = lu.ndim(); + auto flags = lu.flags(); + flags.col_contiguous = ndim == 2; + flags.row_contiguous = false; + flags.contiguous = true; + auto strides = lu.strides(); + strides[ndim - 1] = M; + strides[ndim - 2] = 1; + lu.set_data( + allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags); + copy_inplace( + a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral); + + float* a_ptr = lu.data(); + + pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); + int* pivots_ptr = pivots.data(); + + int info; + size_t num_matrices = a.size() / (M * N); + for (size_t i = 0; i < num_matrices; ++i) { + // Compute LU factorization of A + MLX_LAPACK_FUNC(sgetrf) + (/* m */ &M, + /* n */ &N, + /* a */ a_ptr, + /* lda */ &M, + /* ipiv */ pivots_ptr, + /* info */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info + << ((info > 0) ? 
" because matrix is singular" + : " because argument had an illegal value"); + throw std::runtime_error(ss.str()); + } + + // Advance pointers to the next matrix + a_ptr += M * N; + pivots_ptr += pivots.shape(-1); + } +} + +void LUF::eval(const std::vector& inputs, std::vector& outputs) { + assert(inputs.size() == 1); + lu_factor_impl(inputs[0], outputs[0], outputs[1]); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp deleted file mode 100644 index 8d81558a7..000000000 --- a/mlx/backend/common/solve.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include - -#include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" -#include "mlx/primitives.h" - -namespace mlx::core { - -void solve_impl(const array& a, const array& b, array& out) { - int N = a.shape(-2); - int NRHS = out.shape(-1); - - // Copy b into out and make it col contiguous - auto ndim = out.ndim(); - auto flags = out.flags(); - flags.col_contiguous = ndim == 2; - flags.row_contiguous = false; - flags.contiguous = true; - auto strides = out.strides(); - strides[ndim - 1] = N; - strides[ndim - 2] = 1; - out.set_data( - allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags); - copy_inplace( - b, out, b.shape(), b.strides(), strides, 0, 0, CopyType::GeneralGeneral); - - // lapack clobbers the input, so we have to make a copy. the copy doesn't need - // to be col-contiguous because sgetrs has a transpose parameter (trans='T'). - array a_cpy(a.shape(), a.dtype(), nullptr, {}); - copy( - a, - a_cpy, - a.flags().row_contiguous ? CopyType::Vector : CopyType::General); - - float* a_ptr = a_cpy.data(); - float* out_ptr = out.data(); - - auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; - int* ipiv_ptr = static_cast(ipiv.buffer.raw_ptr()); - - int info; - size_t num_matrices = a.size() / (N * N); - for (size_t i = 0; i < num_matrices; i++) { - // Compute LU factorization of A - MLX_LAPACK_FUNC(sgetrf) - (/* m */ &N, - /* n */ &N, - /* a */ a_ptr, - /* lda */ &N, - /* ipiv */ ipiv_ptr, - /* info */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "[Solve::eval_cpu] sgetrf_ failed with code " << info - << ((info > 0) ? 
" because matrix is singular" - : " becuase argument had an illegal value"); - throw std::runtime_error(ss.str()); - } - - // Solve the system using the LU factors from sgetrf - static constexpr char trans = 'T'; - MLX_LAPACK_FUNC(sgetrs) - ( - /* trans */ &trans, - /* n */ &N, - /* nrhs */ &NRHS, - /* a */ a_ptr, - /* lda */ &N, - /* ipiv */ ipiv_ptr, - /* b */ out_ptr, - /* ldb */ &N, - /* info */ &info); - - if (info != 0) { - std::stringstream ss; - ss << "[Solve::eval_cpu] sgetrs_ failed with code " << info; - throw std::runtime_error(ss.str()); - } - - // Advance pointers to the next matrix - a_ptr += N * N; - out_ptr += N * NRHS; - } -} - -void Solve::eval( - const std::vector& inputs, - std::vector& outputs) { - assert(inputs.size() == 2); - solve_impl(inputs[0], inputs[1], outputs[0]); -} - -} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 35586a257..1e6d7df16 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -407,6 +407,12 @@ void Eigh::eval_gpu( throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); } +void LUF::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); +} + void View::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto ibytes = size_of(in.dtype()); @@ -438,10 +444,4 @@ void View::eval_gpu(const std::vector& inputs, array& out) { } } -void Solve::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI."); -} - } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index db9f40013..53226b472 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -73,6 +73,7 @@ NO_CPU(LogicalNot) NO_CPU(LogicalAnd) NO_CPU(LogicalOr) NO_CPU(LogAddExp) +NO_CPU_MULTI(LUF) NO_CPU(Matmul) NO_CPU(Maximum) NO_CPU(Minimum) @@ -111,6 +112,5 @@ NO_CPU(Tanh) NO_CPU(Transpose) NO_CPU(Inverse) NO_CPU(View) -NO_CPU_MULTI(Solve) } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index f8556a76d..dc0388493 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -74,6 +74,7 @@ NO_GPU(LogicalNot) NO_GPU(LogicalAnd) NO_GPU(LogicalOr) NO_GPU(LogAddExp) +NO_GPU_MULTI(LUF) NO_GPU(Matmul) NO_GPU(Maximum) NO_GPU(Minimum) @@ -114,7 +115,6 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) NO_GPU(View) -NO_GPU_MULTI(Solve) namespace fast { NO_GPU_MULTI(LayerNorm) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index bcef5b28c..7ada566b1 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -270,7 +270,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) { array tri_inv( const array& a, - bool upper /* = true */, + bool upper /* = false */, StreamOrDevice s /* = {} */) { return inv_impl(a, /*tri=*/true, upper, s); } @@ -454,7 +454,7 @@ array cross( return concatenate(outputs, axis, s); } -void validate_eigh(const array& a, const std::string fname) { +void validate_eigh(const array& a, const std::string& fname) { if (a.dtype() != float32) { std::ostringstream msg; msg << fname << " Arrays must have type float32. 
Received array " @@ -500,39 +500,108 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } -array solve(const array& in_a, const array& in_b, StreamOrDevice s /* = {} */) { - if (in_a.ndim() < 2) { +void validate_lu(const array& a, const std::string& fname) { + if (a.dtype() != float32) { std::ostringstream msg; - msg << "[linalg::solve] First input must have >= 2 dimensions. " - << "Received array with " << in_a.ndim() << " dimensions."; + msg << fname << " Arrays must type float32. Received array " + << "with type " << a.dtype() << "."; throw std::invalid_argument(msg.str()); } - if (in_b.ndim() < 1) { + if (a.ndim() < 2) { std::ostringstream msg; - msg << "[linalg::solve] Second input must have >= 1 dimensions. " - << "Received array with " << in_b.ndim() << " dimensions."; + msg << fname + << " Arrays must have >= 2 dimensions. Received array " + "with " + << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto a = in_a; - auto b = in_b; - if (b.ndim() == 1) { - // Insert a singleton dim at the end - b = expand_dims(b, -1, s); + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument(fname + " Only defined for square matrices."); + } +} + +std::vector lu(const array& a, StreamOrDevice s /* = {} */) { + validate_lu(a, "[linalg::lu]"); + + auto [LU, pivots] = lu_factor(a, s); + int N = a.shape(-1); + + pivots.eval(); + int* pivots_ptr = pivots.data(); + + size_t num_matrices = a.size() / (a.shape(-2) * N); + std::vector P_matrices; + P_matrices.reserve(num_matrices); + for (size_t m = 0; m < num_matrices; ++m) { + array P = eye(N, s); + for (int i = 0; i < N; ++i) { + // Convert pivots to 0-based indexing + int j = pivots_ptr[i] - 1; + if (i != j) { + array row_i = slice(P, {i, 0}, {i + 1, N}, s); + array row_j = slice(P, {j, 0}, {j + 1, N}, s); + + P = slice_update(P, row_j, {i, 0}, {i + 1, N}, s); + P = slice_update(P, row_i, {j, 0}, {j + 1, N}, s); + } + } + P_matrices.push_back(transpose(P, s)); + pivots_ptr += pivots.shape(-1); + } + + array P = reshape(stack(P_matrices, /* axis = */ 0, s), a.shape(), s); + array L = add(tril(LU, /* k = */ -1, s), eye(N, s), s); + array U = triu(LU, /* k = */ 0, s); + + return {P, L, U}; +} + +std::pair lu_factor(const array& a, StreamOrDevice s /* = {} */) { + validate_lu(a, "[linalg::lu_factor]"); + + int m = a.shape()[a.shape().size() - 2]; + int n = a.shape()[a.shape().size() - 1]; + + std::vector pivots_shape(a.shape().begin(), a.shape().end() - 2); + pivots_shape.push_back(std::min(m, n)); + + auto out = array::make_arrays( + {a.shape(), pivots_shape}, + {a.dtype(), int32}, + std::make_shared(to_stream(s)), + {astype(a, a.dtype(), s)}); + return std::make_pair(out[0], out[1]); +} + +void validate_solve(const array& a, const array& b, const std::string& fname) { + if (a.ndim() < 2) { + std::ostringstream msg; + msg << fname << " First input must have >= 2 dimensions. " + << "Received array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (b.ndim() < 1) { + std::ostringstream msg; + msg << fname << " Second input must have >= 1 dimensions. " + << "Received array with " << b.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); } if (a.shape(-1) != a.shape(-2)) { std::ostringstream msg; - msg << "[linalg::solve] First input must be a square matrix. " + msg << fname << " First input must be a square matrix. 
" << "Received array with shape " << a.shape() << "."; throw std::invalid_argument(msg.str()); } - if (a.shape(-1) != b.shape(-2)) { + int lastDim = b.ndim() > 1 ? -2 : -1; + if (a.shape(-1) != b.shape(lastDim)) { std::ostringstream msg; - msg << "[linalg::solve] Last dimension of first input with shape " - << a.shape() << " must match second to last dimension of" + msg << fname << " Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" << " second input with shape " << b.shape() << "."; throw std::invalid_argument(msg.str()); } @@ -540,37 +609,36 @@ array solve(const array& in_a, const array& in_b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); if (out_type != float32) { std::ostringstream msg; - msg << "[linalg::solve] Input array must have type float32. Received arrays " + msg << fname << " Input array must have type float32. Received arrays " << "with type " << a.dtype() << " and " << b.dtype() << "."; throw std::invalid_argument(msg.str()); } +} - // Broadcast leading dimensions - if (a.ndim() > 2 || b.ndim() > 2) { - std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); - std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); - auto inner_shape = broadcast_shapes(bsx_a, bsx_b); +array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { + validate_solve(a, b, "[linalg::solve]"); - // Broadcast a - inner_shape.push_back(a.shape(-2)); - inner_shape.push_back(a.shape(-1)); - a = broadcast_to(a, inner_shape, s); + // P, L, U matrices + const auto luf = lu(a, s); - // Broadcast b - *(inner_shape.end() - 2) = b.shape(-2); - *(inner_shape.end() - 1) = b.shape(-1); - b = broadcast_to(b, inner_shape, s); - } + std::vector order(a.ndim()); + std::iota(order.begin(), order.end(), 0); + std::swap(order[order.size() - 1], order[order.size() - 2]); - auto out = array( - b.shape(), - out_type, - std::make_shared(to_stream(s)), - {astype(a, out_type, s), astype(b, out_type, s)}); - if (in_b.ndim() == 1) { - return squeeze(out, -1, s); - } - return out; + array P = transpose(luf[0], order, s); + array Pb = matmul(P, b, s); + array y = solve_triangular(luf[1], Pb, /* upper = */ false, s); + return solve_triangular(luf[2], y, /* upper = */ true, s); +} + +array solve_triangular( + const array& a, + const array& b, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + validate_solve(a, b, "[linalg::solve_triangular]"); + array a_inv = tri_inv(a, upper, s); + return matmul(a_inv, b); } } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index bca1d56af..9fe4dbf60 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -74,8 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {}); array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); +std::vector lu(const array& a, StreamOrDevice s = {}); + +std::pair lu_factor(const array& a, StreamOrDevice s = {}); + array solve(const array& a, const array& b, StreamOrDevice s = {}); +array solve_triangular( + const array& a, + const array& b, + bool upper = false, + StreamOrDevice s = {}); + /** * Compute the cross product of two arrays along the given axis. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9ea5cc513..c9f839d4b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4201,17 +4201,6 @@ std::pair, std::vector> SVD::vmap( return {{linalg::svd(a, stream())}, {ax, ax, ax}}; } -std::pair, std::vector> Solve::vmap( - const std::vector& inputs, - const std::vector& axes) { - auto maybe_move_ax = [this](auto& arr, auto ax) { - return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr; - }; - auto a = maybe_move_ax(inputs[0], axes[0]); - auto b = maybe_move_ax(inputs[1], axes[1]); - return {{linalg::solve(a, b, stream())}, {0}}; -} - std::pair, std::vector> Inverse::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index f15d66f40..7c34d6736 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2235,17 +2235,16 @@ class Eigh : public Primitive { bool compute_eigenvectors_; }; -class Solve : public Primitive { +/* LU Factorization primitive. */ +class LUF : public Primitive { public: - explicit Solve(Stream stream) : Primitive(stream) {} + explicit LUF(Stream stream) : Primitive(stream) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_VMAP() - DEFINE_PRINT(Solve) - DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_PRINT(LUF) private: void eval(const std::vector& inputs, std::vector& outputs); diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 0510f0f69..951c831e1 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -21,6 +21,10 @@ nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { const auto result = svd(a, s); return nb::make_tuple(result.at(0), result.at(1), result.at(2)); } +nb::tuple lu_helper(const array& a, StreamOrDevice s /* = {} */) { + const auto result = lu(a, s); + return nb::make_tuple(result.at(0), result.at(1), result.at(2)); +} } // namespace void init_linalg(nb::module_& parent_module) { @@ -264,7 +268,7 @@ void init_linalg(nb::module_& parent_module) { "tri_inv", &tri_inv, "a"_a, - "upper"_a, + "upper"_a = false, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -278,7 +282,7 @@ void init_linalg(nb::module_& parent_module) { Args: a (array): Input array. - upper (array): Whether the array is upper or lower triangular. Defaults to ``False``. + upper (bool, optional): Whether the array is upper or lower triangular. Defaults to ``False``. stream (Stream, optional): Stream or device. Defaults to ``None`` in which case the default stream of the default device is used. @@ -485,6 +489,44 @@ void init_linalg(nb::module_& parent_module) { array([[ 0.707107, -0.707107], [ 0.707107, 0.707107]], dtype=float32) )pbdoc"); + m.def( + "lu", + &lu_helper, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), + R"pbdoc( + Compute the LU factorization of the given matrix ``A``. + + Args: + a (array): Input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. 
+ + Returns: + tuple(array, array, array): The ``P``, ``L``, and ``U`` matrices, such that ``A = P @ L @ U`` + )pbdoc"); + m.def( + "lu_factor", + &lu_factor, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), + R"pbdoc( + Computes a compact representation of the LU factorization. + + Args: + a (array): Input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + tuple(array, array): The ``LU`` matrix and ``pivots`` array. + )pbdoc"); m.def( "solve", &solve, @@ -495,13 +537,38 @@ void init_linalg(nb::module_& parent_module) { nb::sig( "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Compute the solution to a system of linear equations ax = b. + Compute the solution to a system of linear equations ``AX = B``. Args: a (array): Input array. b (array): Input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The unique solution to the system ``AX = B``. + )pbdoc"); + m.def( + "solve_triangular", + &solve_triangular, + "a"_a, + "b"_a, + "upper"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def solve_triangular(a: array, b: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Computes the solution of a triangular system of linear equations ``AX = B``. + + Args: + a (array): Input array. + b (array): Input array. + upper (bool, optional): Whether the array is upper or lower triangular. Default ``False``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. Returns: - array: The unique solution to the system ax = b. + array: The unique solution to the system ``AX = B``. 
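+
+        Examples:
+            A lower-triangular system mirroring the values in the C++ tests
+            (the CPU stream is used since there is no Metal kernel yet):
+
+            >>> a = mx.array([[2.0, 0.0, 0.0], [3.0, 1.0, 0.0], [1.0, -1.0, 1.0]])
+            >>> b = mx.array([2.0, 5.0, 0.0])
+            >>> x = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
+            >>> mx.allclose(x, mx.array([1.0, 2.0, 1.0]))
+            array(True, dtype=bool)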
)pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 6a98720a2..6d20735a3 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -319,6 +319,45 @@ def check_eigs_and_vecs(A_np, kwargs={}): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix + def test_lu(self): + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array(0.0), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) + + # Test 3x3 matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(P @ L @ U, a)) + + # Test batch dimension + a = mx.broadcast_to(a, (5, 5, 3, 3)) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(P @ L @ U, a)) + + def test_lu_factor(self): + mx.random.seed(7) + + # Test 3x3 matrix + a = mx.random.uniform(shape=(5, 5)) + LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu) + n = a.shape[-1] + + P = mx.eye(n) + for i in range(len(pivots)): + j = pivots[i] - 1 + if i != j: + P[[i, j]] = P[[j, i]] + + P = mx.transpose(P) + L = mx.add(mx.tril(LU, k=-1), mx.eye(n)) + U = mx.triu(LU) + self.assertTrue(mx.allclose(P @ L @ U, a)) + def test_solve(self): mx.random.seed(7) @@ -338,7 +377,7 @@ def test_solve(self): result = mx.linalg.solve(a, b, stream=mx.cpu) expected = np.linalg.solve(a, b) - self.assertTrue(np.allclose(result, expected, atol=1e-5)) + self.assertTrue(np.allclose(result, expected)) # Test batch dimension a = mx.random.uniform(shape=(5, 5, 4, 4)) @@ -355,7 +394,7 @@ def test_solve(self): result = mx.linalg.solve(a, b, stream=mx.cpu) expected = np.linalg.solve(a, b) - self.assertTrue(np.allclose(result, expected, atol=1e-2)) + self.assertTrue(np.allclose(result, expected, atol=1e-3)) # Test multi-column rhs a = mx.random.uniform(shape=(5, 5)) @@ -363,7 +402,7 @@ def test_solve(self): result = mx.linalg.solve(a, b, stream=mx.cpu) expected = np.linalg.solve(a, b) - self.assertTrue(np.allclose(result, expected, atol=1e-5)) + self.assertTrue(np.allclose(result, expected)) # Test batched multi-column rhs a = mx.broadcast_to(a, (3, 2, 5, 5)) @@ -371,7 +410,32 @@ def test_solve(self): result = mx.linalg.solve(a, b, stream=mx.cpu) expected = np.linalg.solve(a, b) - self.assertTrue(np.allclose(result, expected, atol=1e-5)) + self.assertTrue(np.allclose(result, expected)) + + def test_solve_triangular(self): + # Test lower triangular matrix + a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]) + b = mx.array([8.0, 14.0, 3.0]) + + result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test upper triangular matrix + a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]]) + b = mx.array([13.0, 33.0, 18.0]) + + result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test batch multi-column rhs + a = mx.broadcast_to(a, (3, 4, 3, 3)) + b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8)) + + result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) if __name__ == "__main__": diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 
4dfe998c8..efc8ba9d5 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -474,6 +474,26 @@ TEST_CASE("test matrix eigh") {
   CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
 }
 
+TEST_CASE("test lu") {
+  // Test 2x2 matrix
+  array a = array({1., 2., 3., 4.}, {2, 2});
+  auto out = linalg::lu(a, Device::cpu);
+  array expected = matmul(matmul(out[0], out[1]), out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+
+  // Test 3x3 matrix
+  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
+  out = linalg::lu(a, Device::cpu);
+  expected = matmul(matmul(out[0], out[1]), out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+
+  // Test batch dimension
+  a = broadcast_to(a, {3, 3, 3});
+  out = linalg::lu(a, Device::cpu);
+  expected = matmul(matmul(out[0], out[1]), out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+}
+
 TEST_CASE("test solve") {
   // 0D and 1D throw
   CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
@@ -503,8 +523,8 @@ TEST_CASE("test solve") {
   CHECK(allclose(matmul(a, result), b).item<bool>());
 
   // Test batch dimension
-  a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3});
-  b = reshape(concatenate({b, b, b, b, b}), {5, 3, 1});
+  a = broadcast_to(a, {5, 3, 3});
+  b = broadcast_to(b, {5, 3, 1});
 
   result = linalg::solve(a, b, Device::cpu);
   CHECK(allclose(matmul(a, result), b).item<bool>());
@@ -523,3 +543,22 @@ TEST_CASE("test solve") {
   result = linalg::solve(a, b, Device::cpu);
   CHECK(allclose(matmul(a, result), b).item<bool>());
 }
+
+TEST_CASE("test solve_triangular") {
+  // Test lower triangular matrix
+  array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
+  array b = array({2., 5., 0.});
+
+  array result =
+      linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
+  array expected = array({1., 2., 1.});
+  CHECK(allclose(expected, result).item<bool>());
+
+  // Test upper triangular matrix
+  a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
+  b = array({5., 14., 3.});
+
+  result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
+  expected = array({-3., 2., 3.});
+  CHECK(allclose(expected, result).item<bool>());
+}