Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU LU factorization and linear solvers #1451

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/src/python/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ Linear Algebra

.. currentmodule:: mlx.core.linalg

.. autosummary::
:toctree: _autosummary
.. autosummary::
:toctree: _autosummary

inv
tri_inv
Expand All @@ -18,3 +18,7 @@ Linear Algebra
svd
eigvalsh
eigh
lu
lu_factor
solve
solve_triangular
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
DEFAULT_MULTI(LUF)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

Expand Down
1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
DEFAULT_MULTI(LUF)

namespace {

Expand Down
66 changes: 66 additions & 0 deletions mlx/backend/common/luf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright © 2024 Apple Inc.

#include <cassert>
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

// Compute the packed LU factorization of `a` (a batch of M x N float32
// matrices) into `lu`, writing the LAPACK pivot indices into `pivots`.
// `lu` receives both factors in one matrix: U on and above the diagonal,
// the unit-lower-triangular L strictly below it. Factorization is done
// per-matrix with LAPACK sgetrf, which expects column-major storage.
// Throws std::runtime_error if sgetrf reports failure.
void lu_factor_impl(const array& a, array& lu, array& pivots) {
  int M = a.shape(-2);
  int N = a.shape(-1);

  // Copy a into lu and make it col contiguous
  auto ndim = lu.ndim();
  auto flags = lu.flags();
  // Only an unbatched (2-D) output is col-contiguous as a whole array;
  // batched outputs are col-contiguous per matrix but not globally.
  flags.col_contiguous = ndim == 2;
  flags.row_contiguous = false;
  flags.contiguous = true;
  auto strides = lu.strides();
  // Column-major layout within each matrix: rows are unit-stride,
  // columns are M apart.
  strides[ndim - 1] = M;
  strides[ndim - 2] = 1;
  lu.set_data(
      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
  // Transposing copy from a's layout into the column-major buffer.
  copy_inplace(
      a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

  // sgetrf factorizes in place on float32 data.
  float* a_ptr = lu.data<float>();

  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
  int* pivots_ptr = pivots.data<int>();

  int info;
  size_t num_matrices = a.size() / (M * N);
  for (size_t i = 0; i < num_matrices; ++i) {
    // Compute LU factorization of A
    MLX_LAPACK_FUNC(sgetrf)
    (/* m */ &M,
     /* n */ &N,
     /* a */ a_ptr,
     /* lda */ &M,
     /* ipiv */ pivots_ptr,
     /* info */ &info);

    // Per LAPACK convention: info < 0 flags an invalid argument,
    // info > 0 flags an exactly-singular U.
    if (info != 0) {
      std::stringstream ss;
      ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
         << ((info > 0) ? " because matrix is singular"
                        : " because argument had an illegal value");
      throw std::runtime_error(ss.str());
    }

    // Advance pointers to the next matrix
    a_ptr += M * N;
    pivots_ptr += pivots.shape(-1);
  }
}

void LUF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
assert(inputs.size() == 1);
lu_factor_impl(inputs[0], outputs[0], outputs[1]);
}

} // namespace mlx::core
6 changes: 6 additions & 0 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,12 @@ void Eigh::eval_gpu(
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
}

// LU factorization has no Metal kernel yet; callers must run this
// primitive on a CPU stream.
void LUF::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(Minimum)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
Expand Down
145 changes: 143 additions & 2 deletions mlx/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {

array tri_inv(
const array& a,
bool upper /* = true */,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
return inv_impl(a, /*tri=*/true, upper, s);
}
Expand Down Expand Up @@ -454,7 +454,7 @@ array cross(
return concatenate(outputs, axis, s);
}

void validate_eigh(const array& a, const std::string fname) {
void validate_eigh(const array& a, const std::string& fname) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << fname << " Arrays must have type float32. Received array "
Expand Down Expand Up @@ -500,4 +500,145 @@ std::pair<array, array> eigh(
return std::make_pair(out[0], out[1]);
}

// Validate an input to the LU routines: float32 dtype, at least 2-D,
// and square in the trailing two dimensions. `fname` prefixes every
// error message. Throws std::invalid_argument on violation.
void validate_lu(const array& a, const std::string& fname) {
  // Only float32 is supported (the CPU backend calls LAPACK sgetrf).
  if (a.dtype() != float32) {
    std::ostringstream msg;
    // Fixed message typo: "must type" -> "must have type".
    msg << fname << " Arrays must have type float32. Received array "
        << "with type " << a.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }

  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << fname
        << " Arrays must have >= 2 dimensions. Received array "
           "with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  // LU here is only exposed for square matrices.
  if (a.shape(-1) != a.shape(-2)) {
    throw std::invalid_argument(fname + " Only defined for square matrices.");
  }
}

// Full LU decomposition: returns {P, L, U} with a == P @ L @ U,
// where P is a permutation matrix, L unit-lower-triangular, and U
// upper-triangular. Batched inputs yield batched outputs.
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
  validate_lu(a, "[linalg::lu]");

  auto [LU, pivots] = lu_factor(a, s);
  int N = a.shape(-1);

  // NOTE(review): forces a synchronous evaluation so the pivot indices
  // can be read on the host to build the permutation matrices below.
  pivots.eval();
  int* pivots_ptr = pivots.data<int>();

  size_t num_matrices = a.size() / (a.shape(-2) * N);
  std::vector<array> P_matrices;
  P_matrices.reserve(num_matrices);
  for (size_t m = 0; m < num_matrices; ++m) {
    // Apply the recorded row interchanges, in order, to an identity
    // matrix to materialize the permutation for this batch element.
    array P = eye(N, s);
    for (int i = 0; i < N; ++i) {
      // Convert pivots to 0-based indexing
      int j = pivots_ptr[i] - 1;
      if (i != j) {
        // Swap rows i and j of P via slice / slice_update.
        array row_i = slice(P, {i, 0}, {i + 1, N}, s);
        array row_j = slice(P, {j, 0}, {j + 1, N}, s);

        P = slice_update(P, row_j, {i, 0}, {i + 1, N}, s);
        P = slice_update(P, row_i, {j, 0}, {j + 1, N}, s);
      }
    }
    // The swaps build the matrix that permutes a's rows; its transpose
    // is the P that satisfies a == P @ L @ U.
    P_matrices.push_back(transpose(P, s));
    pivots_ptr += pivots.shape(-1);
  }

  // Restore batch dimensions (a is square, so a.shape() fits the stack).
  array P = reshape(stack(P_matrices, /* axis = */ 0, s), a.shape(), s);
  // Split the packed factor: strict lower part + unit diagonal is L,
  // upper part (including diagonal) is U.
  array L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
  array U = triu(LU, /* k = */ 0, s);

  return {P, L, U};
}

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
validate_lu(a, "[linalg::lu_factor]");

int m = a.shape()[a.shape().size() - 2];
int n = a.shape()[a.shape().size() - 1];

std::vector<int> pivots_shape(a.shape().begin(), a.shape().end() - 2);
pivots_shape.push_back(std::min(m, n));

auto out = array::make_arrays(
{a.shape(), pivots_shape},
{a.dtype(), int32},
std::make_shared<LUF>(to_stream(s)),
{astype(a, a.dtype(), s)});
return std::make_pair(out[0], out[1]);
}

// Validate a linear-system pair (a, b): a must be a square matrix
// (>= 2-D), b a compatible vector or matrix (>= 1-D), and the promoted
// dtype must be float32. `fname` prefixes every error message.
// Throws std::invalid_argument on violation.
void validate_solve(const array& a, const array& b, const std::string& fname) {
  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << fname << " First input must have >= 2 dimensions. "
        << "Received array with " << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (b.ndim() < 1) {
    std::ostringstream msg;
    msg << fname << " Second input must have >= 1 dimensions. "
        << "Received array with " << b.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (a.shape(-1) != a.shape(-2)) {
    std::ostringstream msg;
    msg << fname << " First input must be a square matrix. "
        << "Received array with shape " << a.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  // A 1-D b is a vector (check its last axis); otherwise check the
  // row dimension of the right-hand-side matrix.
  const int b_sys_dim = b.ndim() > 1 ? -2 : -1;
  if (a.shape(-1) != b.shape(b_sys_dim)) {
    std::ostringstream msg;
    msg << fname << " Last dimension of first input with shape " << a.shape()
        << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  // Only float32 systems are supported by the underlying primitives.
  const auto out_type = promote_types(a.dtype(), b.dtype());
  if (out_type != float32) {
    std::ostringstream msg;
    msg << fname << " Input array must have type float32. Received arrays "
        << "with type " << a.dtype() << " and " << b.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }
}

// Solve the linear system a @ x == b via LU decomposition:
// a == P L U, so x = U^-1 (L^-1 (P^T b)).
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  validate_solve(a, b, "[linalg::solve]");

  // Decompose into permutation, lower and upper factors.
  const auto plu = lu(a, s);

  // Axis order that transposes the trailing two (matrix) axes.
  std::vector<int> axes(a.ndim());
  std::iota(axes.begin(), axes.end(), 0);
  std::swap(axes[a.ndim() - 1], axes[a.ndim() - 2]);

  // A permutation matrix is orthogonal, so its transpose undoes it.
  array p_t = transpose(plu[0], axes, s);
  array rhs = matmul(p_t, b, s);

  // Forward-substitute through L, then back-substitute through U.
  array y = solve_triangular(plu[1], rhs, /* upper = */ false, s);
  return solve_triangular(plu[2], y, /* upper = */ true, s);
}

// Solve a triangular system a @ x == b, where `a` is lower-triangular
// by default or upper-triangular when `upper` is true.
array solve_triangular(
    const array& a,
    const array& b,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  validate_solve(a, b, "[linalg::solve_triangular]");
  array a_inv = tri_inv(a, upper, s);
  // Bug fix: forward the stream so the matmul runs on the requested
  // stream/device like every other op in this function.
  return matmul(a_inv, b, s);
}

} // namespace mlx::core::linalg
12 changes: 12 additions & 0 deletions mlx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

std::vector<array> lu(const array& a, StreamOrDevice s = {});

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});

array solve(const array& a, const array& b, StreamOrDevice s = {});

array solve_triangular(
const array& a,
const array& b,
bool upper = false,
StreamOrDevice s = {});

/**
* Compute the cross product of two arrays along the given axis.
*/
Expand Down
16 changes: 15 additions & 1 deletion mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -2202,7 +2202,6 @@ class Eigh : public Primitive {
: Primitive(stream),
uplo_(std::move(uplo)),
compute_eigenvectors_(compute_eigenvectors) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
Expand Down Expand Up @@ -2236,4 +2235,19 @@ class Eigh : public Primitive {
bool compute_eigenvectors_;
};

/* LU Factorization primitive.
 *
 * Single-input, multi-output primitive: takes one array of matrices and
 * produces the packed LU factors plus the pivot indices. The CPU path is
 * implemented in the common backend; the GPU path is not yet available
 * and throws. */
class LUF : public Primitive {
 public:
  explicit LUF(Stream stream) : Primitive(stream) {}
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LUF)

 private:
  // Shared CPU implementation backing eval_cpu.
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};

} // namespace mlx::core
Loading