diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 598161f78..e0d31d145 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -17,6 +17,7 @@ MLX was developed with contributions from the following individuals: - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. +- Max-Heinrich Laves: Added groups in 3D convolutions. diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 79bc3c4a1..fd9da04e6 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -336,15 +336,19 @@ void slow_conv_3D( const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim + const int C = in.shape(4); // In channels const int oD = out.shape(1); // Output spatial dim const int oH = out.shape(2); // Output spatial dim const int oW = out.shape(3); // Output spatial dim const int O = wt.shape(0); // Out channels - const int C = wt.shape(4); // In channels const int wD = wt.shape(1); // Weight spatial dim const int wH = wt.shape(2); // Weight spatial dim const int wW = wt.shape(3); // Weight spatial dim + const int groups = C / wt.shape(4); + const int C_per_group = wt.shape(4); + const int O_per_group = O / groups; + const size_t in_stride_N = in.strides()[0]; const size_t in_stride_D = in.strides()[1]; const size_t in_stride_H = in.strides()[2]; @@ -377,39 +381,40 @@ void slow_conv_3D( int ih_base = oh * wt_strides[1] - padding[1]; int iw_base = ow * wt_strides[2] - padding[2]; - for (int o = 0; o < O; ++o) { - float r = 0.; - - for (int wd = 0; wd < wD; ++wd) { - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wd_flip = flip ? wD - wd - 1 : wd; - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int id = id_base + wd_flip * wt_dilation[0]; - int ih = ih_base + wh_flip * wt_dilation[1]; - int iw = iw_base + ww_flip * wt_dilation[2]; - - const T* wt_ptr_pt = - wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = - in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W; - - for (int c = 0; c < C; ++c) { - r += static_cast(in_ptr_pt[0]) * - static_cast(wt_ptr_pt[0]); - in_ptr_pt += in_stride_C; - wt_ptr_pt += wt_stride_C; - } // c - - } // ww - } // wh - } // wd - - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; + + for (int wd = 0; wd < wD; ++wd) { + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; + + const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = in_ptr + id * in_stride_D + + ih * in_stride_H + iw * in_stride_W; + + for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + + } // ww + } // wh + } // wd + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g }; int jump_d = flip ? -wt_dilation[0] : wt_dilation[0]; @@ -484,47 +489,48 @@ void slow_conv_3D( int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; - for (int o = 0; o < O; ++o) { - float r = 0.; - - for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wd_flip = flip ? wD - wd - 1 : wd; - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int id = id_base + wd_flip * wt_dilation[0]; - int ih = ih_base + wh_flip * wt_dilation[1]; - int iw = iw_base + ww_flip * wt_dilation[2]; - - if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && - iw < iW) { - const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + - wh * wt_stride_H + ww * wt_stride_W; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; - int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; + for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wd_flip = flip ? wD - wd - 1 : wd; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int id = id_base + wd_flip * wt_dilation[0]; + int ih = ih_base + wh_flip * wt_dilation[1]; + int iw = iw_base + ww_flip * wt_dilation[2]; - const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + - ih_dil * in_stride_H + iw_dil * in_stride_W; + if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && + iw < iW) { + const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + + wh * wt_stride_H + ww * wt_stride_W; - for (int c = 0; c < C; ++c) { - r += static_cast(in_ptr_pt[0]) * - static_cast(wt_ptr_pt[0]); - in_ptr_pt += in_stride_C; - wt_ptr_pt += wt_stride_C; - } // c + int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; + int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; - } // iD, ih, iw check - } // ww - } // wh - } // wd + const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + + ih_dil * in_stride_H + iw_dil * in_stride_W; - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o + for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + + } // iD, ih, iw check + } // ww + } // wh + } // wd + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g }; int oD_border_0 = 0; @@ -916,11 +922,15 @@ void explicit_gemm_conv_ND_cpu( in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim const auto oDim = std::vector( out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim - const int O = wt.shape(0); // Out channels - const int C = wt.shape(-1); // In channels + const int C = in.shape(-1); // Input channels + const int O = wt.shape(0); // Output channels const auto wDim = std::vector( wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim + const int groups = C / wt.shape(-1); + const int C_per_group = wt.shape(-1); + const int O_per_group = O / groups; + auto conv_dtype = float32; // Pad input @@ -973,6 +983,15 @@ void explicit_gemm_conv_ND_cpu( auto flags = in_padded.flags(); + if (groups > 1) { + // Transpose the last two dimensions for grouped convolutions + std::swap( + *(strided_shape.end() - (wDim.size() + 1)), *(strided_shape.end() - 1)); + std::swap( + *(strided_strides.end() - (in_padded.strides().size() - 1)), + *(strided_strides.end() - 1)); + } + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); @@ -993,7 +1012,18 @@ void explicit_gemm_conv_ND_cpu( auto gemm_wt = wt; auto gemm_out = out; - if (wt.dtype() != float32 || !wt.flags().row_contiguous) { + if (groups > 1) { + // Transpose the last two dimensions for grouped convolutions + auto wt_shape = wt.shape(); + std::swap(*(wt_shape.begin() + 1), *(wt_shape.end() - 1)); + auto wt_strides = wt.strides(); + std::swap(*(wt_strides.begin() + 1), *(wt_strides.end() - 1)); + + array wt_transpose(wt_shape, wt.dtype(), nullptr, {}); + wt_transpose.copy_shared_buffer(wt, wt_strides, wt.flags(), wt.size(), 0); + gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); + copy(wt_transpose, gemm_wt, CopyType::General); + } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); @@ -1005,24 +1035,26 @@ void explicit_gemm_conv_ND_cpu( gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); } - // Perform gemm - cblas_sgemm( - CblasRowMajor, - CblasNoTrans, // no trans A - CblasTrans, // transB - strided_reshape[0], // M - O, // N - strided_reshape[1], // K - 1.0f, // alpha - in_strided.data(), - strided_reshape[1], // lda - gemm_wt.data(), - strided_reshape[1], // ldb - 0.0f, // beta - gemm_out.data(), - O // ldc - ); - + for (int g = 0; g < groups; ++g) { + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O_per_group, // N + strided_reshape[1] / C * C_per_group, // K + 1.0f, // alpha + in_strided.data() + + g * strided_reshape[1] / C * C_per_group, // A + strided_reshape[1], // lda + gemm_wt.data() + + g * O_per_group * strided_reshape[1] / C * C_per_group, // B + strided_reshape[1] / C * C_per_group, // ldb + 0.0f, // beta + gemm_out.data() + g * O_per_group, // C + O // ldc + ); + } // Copy results if needed if (out.dtype() != float32) { copy(gemm_out, out, CopyType::Vector); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 395488ead..9be7785c5 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -798,6 +798,7 @@ void conv_3D_gpu( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, + const int groups, bool flip, std::vector& copies) { // Make conv params @@ -832,9 +833,13 @@ void conv_3D_gpu( out.strides()[2], out.strides()[3], out.strides()[4]}, - /* const int groups = */ 1, + /* const int groups = */ groups, /* const bool flip = */ flip, }; + if (groups > 1) { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); } @@ -874,6 +879,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { kernel_strides_, kernel_dilation_, input_dilation_, + groups_, flip_, copies); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 5d019edf0..2114f7cc2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3180,9 +3180,9 @@ array conv_general( bool flip /* = false */, StreamOrDevice s /* = {} */) { // Run checks - if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) { + if (groups != 1 && in.ndim() != 3 && in.ndim() != 4 && in.ndim() != 5) { throw std::invalid_argument( - "[conv] Can only handle groups != 1 in 1D or 2D convolutions."); + "[conv] Can only handle groups != 1 in 1D, 2D and 3D convolutions."); } int spatial_dims = in.ndim() - 2; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index b54619671..3f707f8c0 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -497,7 +497,9 @@ def run_conv3D( in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype( np_dtype ) - wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype) + wt_np = np.random.normal( + 0.0, 1.0, (O, kD, kH, kW, int(C / groups)) + ).astype(np_dtype) in_mx, wt_mx = map(mx.array, (in_np, wt_np)) in_pt, wt_pt = map( @@ -540,6 +542,18 @@ def run_conv3D( ): run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype) + # Groups tests + N, C, O = (4, 16, 32) + for idim, kdim, stride, padding in ( + ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)), + ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)), + ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)), + ): + for group in (1, 2, 4, 8, 16): + run_conv3D( + N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype + ) + @unittest.skipIf(not has_torch, "requires Torch") def test_torch_conv_3D_grad(self): def run_conv3D_grad(