Group support for conv3d #1262

Open: wants to merge 1 commit into base: main
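This PR extends grouped convolutions from the existing 1D/2D paths to 3D. A minimal usage sketch of what it enables (not part of the diff; it assumes the Python binding forwards `groups` to `conv3d` the same way the 1D/2D ops already do, with MLX's channels-last layout):

```python
import mlx.core as mx

N, D, H, W, C_in = 2, 8, 8, 8, 16
C_out, k, groups = 32, 3, 4

x = mx.random.normal((N, D, H, W, C_in))                 # input is channels-last
w = mx.random.normal((C_out, k, k, k, C_in // groups))   # each filter sees C_in // groups channels

y = mx.conv3d(x, w, stride=1, padding=1, groups=groups)
print(y.shape)  # (2, 8, 8, 8, 32)
```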
1 change: 1 addition & 0 deletions ACKNOWLEDGMENTS.md
@@ -17,6 +17,7 @@ MLX was developed with contributions from the following individuals:
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Max-Heinrich Laves: Added groups in 3D convolutions.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
214 changes: 123 additions & 91 deletions mlx/backend/common/conv.cpp
@@ -336,15 +336,19 @@ void slow_conv_3D(
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
const int C = in.shape(4); // In channels
const int oD = out.shape(1); // Output spatial dim
const int oH = out.shape(2); // Output spatial dim
const int oW = out.shape(3); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(4); // In channels
const int wD = wt.shape(1); // Weight spatial dim
const int wH = wt.shape(2); // Weight spatial dim
const int wW = wt.shape(3); // Weight spatial dim

const int groups = C / wt.shape(4);
const int C_per_group = wt.shape(4);
const int O_per_group = O / groups;

const size_t in_stride_N = in.strides()[0];
const size_t in_stride_D = in.strides()[1];
const size_t in_stride_H = in.strides()[2];
@@ -377,39 +381,40 @@
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];

for (int o = 0; o < O; ++o) {
float r = 0.;

for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];

const T* wt_ptr_pt =
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;

for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c

} // ww
} // wh
} // wd

out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;

for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];

const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + id * in_stride_D +
ih * in_stride_H + iw * in_stride_W;

for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c

} // ww
} // wh
} // wd

out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};

int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
@@ -484,47 +489,48 @@
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];

for (int o = 0; o < O; ++o) {
float r = 0.;

for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];

if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;

int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];

const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;

for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;

} // iD, ih, iw check
} // ww
} // wh
} // wd
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;

out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c

} // iD, ih, iw check
} // ww
} // wh
} // wd

out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};

int oD_border_0 = 0;
@@ -916,11 +922,15 @@ void explicit_gemm_conv_ND_cpu(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const int C = in.shape(-1); // Input channels
const int O = wt.shape(0); // Output channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

const int groups = C / wt.shape(-1);
const int C_per_group = wt.shape(-1);
const int O_per_group = O / groups;

auto conv_dtype = float32;

// Pad input
@@ -973,6 +983,15 @@

auto flags = in_padded.flags();

if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(
*(strided_shape.end() - (wDim.size() + 1)), *(strided_shape.end() - 1));
std::swap(
*(strided_strides.end() - (in_padded.strides().size() - 1)),
*(strided_strides.end() - 1));
}

array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
@@ -993,7 +1012,18 @@
auto gemm_wt = wt;
auto gemm_out = out;

if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
auto wt_shape = wt.shape();
std::swap(*(wt_shape.begin() + 1), *(wt_shape.end() - 1));
auto wt_strides = wt.strides();
std::swap(*(wt_strides.begin() + 1), *(wt_strides.end() - 1));

array wt_transpose(wt_shape, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(wt, wt_strides, wt.flags(), wt.size(), 0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -1005,24 +1035,26 @@
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}

// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);

for (int g = 0; g < groups; ++g) {
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
strided_reshape[1] / C * C_per_group, // K
1.0f, // alpha
in_strided.data<float>() +
g * strided_reshape[1] / C * C_per_group, // A
strided_reshape[1], // lda
gemm_wt.data<float>() +
g * O_per_group * strided_reshape[1] / C * C_per_group, // B
strided_reshape[1] / C * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>() + g * O_per_group, // C
O // ldc
);
}
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
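The per-group `cblas_sgemm` loop above splits the single GEMM of the ungrouped path into `groups` independent GEMMs: group `g` multiplies the slice of the unfolded input holding its `C_per_group` channels against its `O_per_group` filters and writes an interleaved slice of the output (hence `ldc = O`). A rough NumPy sketch of that decomposition, assuming each group's channels are contiguous in the flattened K dimension (which is what the shape/stride swaps above arrange); names are illustrative, not from the codebase:

```python
import numpy as np

def grouped_gemm(A, W, groups):
    """A: (M, K_total) unfolded input patches, K_total = window * C.
    W: (O, K_pg) filters with K_pg = window * (C // groups).
    Returns (M, O), one small GEMM per group, like the sgemm loop."""
    M, K_total = A.shape
    O, K_pg = W.shape
    O_pg = O // groups
    out = np.zeros((M, O), dtype=np.float32)
    for g in range(groups):
        A_g = A[:, g * K_pg:(g + 1) * K_pg]            # column offset into A
        W_g = W[g * O_pg:(g + 1) * O_pg]               # row offset into the weights
        out[:, g * O_pg:(g + 1) * O_pg] = A_g @ W_g.T  # interleaved slice of the output
    return out
```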
8 changes: 7 additions & 1 deletion mlx/backend/metal/conv.cpp
@@ -798,6 +798,7 @@ void conv_3D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
const int groups,
bool flip,
std::vector<array>& copies) {
// Make conv params
@@ -832,9 +833,13 @@
out.strides()[2],
out.strides()[3],
out.strides()[4]},
/* const int groups = */ 1,
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}

return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}

@@ -874,6 +879,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
copies);
}
4 changes: 2 additions & 2 deletions mlx/ops.cpp
@@ -3180,9 +3180,9 @@ array conv_general(
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {
if (groups != 1 && in.ndim() != 3 && in.ndim() != 4 && in.ndim() != 5) {
throw std::invalid_argument(
"[conv] Can only handle groups != 1 in 1D or 2D convolutions.");
"[conv] Can only handle groups != 1 in 1D, 2D and 3D convolutions.");
}

int spatial_dims = in.ndim() - 2;
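With the check relaxed, `conv_general` now accepts `groups != 1` for 5-dimensional inputs (batch, three spatial dims, channels) as well. A hedged sketch of the equivalent call from Python, assuming the `mx.conv_general` binding exposes the same `groups` argument as the C++ signature:

```python
import mlx.core as mx

x = mx.random.normal((1, 6, 6, 6, 8))        # (N, D, H, W, C)
w = mx.random.normal((8, 3, 3, 3, 8 // 2))   # (O, kD, kH, kW, C // groups)

# Previously this raised for 3D inputs; with the relaxed check it runs.
y = mx.conv_general(x, w, stride=1, padding=1, groups=2)
print(y.shape)  # (1, 6, 6, 6, 8)
```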
16 changes: 15 additions & 1 deletion python/tests/test_conv.py
@@ -497,7 +497,9 @@ def run_conv3D(
in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
np_dtype
)
wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
wt_np = np.random.normal(
0.0, 1.0, (O, kD, kH, kW, int(C / groups))
).astype(np_dtype)

in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
@@ -540,6 +542,18 @@ def run_conv3D(
):
run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)

# Groups tests
N, C, O = (4, 16, 32)
for idim, kdim, stride, padding in (
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
):
for group in (1, 2, 4, 8, 16):
run_conv3D(
N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
)

@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_3D_grad(self):
def run_conv3D_grad(
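The new test cases use PyTorch as the reference implementation. For readers without Torch, a hand-rolled NumPy reference for grouped, channels-last 3D convolution is sketched below; it mirrors the channel mapping in `slow_conv_3D` (input channel `c` of group `g` pairs with weight channel `c % C_per_group`) and the `(O, kD, kH, kW, C // groups)` weight shape used in `run_conv3D`. Illustrative only, not part of the test suite:

```python
import numpy as np

def conv3d_groups_ref(x, w, stride, padding, groups):
    """x: (N, D, H, W, C), w: (O, kD, kH, kW, C // groups) -> (N, oD, oH, oW, O)."""
    N, D, H, W, C = x.shape
    O, kD, kH, kW, C_pg = w.shape
    O_pg = O // groups
    sD, sH, sW = stride
    pD, pH, pW = padding
    xp = np.pad(x, ((0, 0), (pD, pD), (pH, pH), (pW, pW), (0, 0)))
    oD = (D + 2 * pD - kD) // sD + 1
    oH = (H + 2 * pH - kH) // sH + 1
    oW = (W + 2 * pW - kW) // sW + 1
    out = np.zeros((N, oD, oH, oW, O), dtype=np.float32)
    for g in range(groups):
        xg = xp[..., g * C_pg:(g + 1) * C_pg]   # this group's input channels
        wg = w[g * O_pg:(g + 1) * O_pg]         # this group's filters
        for od in range(oD):
            for oh in range(oH):
                for ow in range(oW):
                    patch = xg[:, od * sD:od * sD + kD,
                               oh * sH:oh * sH + kH,
                               ow * sW:ow * sW + kW, :]
                    # (N, kD, kH, kW, C_pg) x (O_pg, kD, kH, kW, C_pg) -> (N, O_pg)
                    out[:, od, oh, ow, g * O_pg:(g + 1) * O_pg] = np.einsum(
                        "ndhwc,odhwc->no", patch, wg)
    return out
```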