Skip to content

Commit

Permalink
Specializing slice_scatter. (FlagOpen#270)
Browse files Browse the repository at this point in the history
* specializing slice_scatter. WIP.

* polish and refine 2d_inner cases.

* fix slice_scatter error on 1d inputs.

* test slice_scatter fallback
  • Loading branch information
tongxin authored Nov 5, 2024
1 parent af0a3a6 commit 189f645
Show file tree
Hide file tree
Showing 5 changed files with 599 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def enable(lib=aten_lib):
lib.impl("fill.Scalar", fill_scalar, "CUDA")
lib.impl("fill.Tensor", fill_tensor, "CUDA")
lib.impl("flip", flip, "CUDA")
- lib.impl("slice_scatter", slice_scatter, "CUDA")
+ lib.impl("slice_scatter", slice_scatter_v2, "CUDA")
lib.impl("select_scatter", select_scatter, "CUDA")
lib.impl("index_select", index_select, "CUDA")
lib.impl("tile", tile, "CUDA")
Expand Down
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .sigmoid import sigmoid
from .silu import silu
from .sin import sin
- from .slice_scatter import slice_scatter
+ from .slice_scatter import slice_scatter, slice_scatter_v2
from .softmax import softmax
from .stack import stack
from .sub import sub
Expand Down Expand Up @@ -234,6 +234,7 @@
"where_scalar_other",
"select_scatter",
"slice_scatter",
+ "slice_scatter_v2",
"masked_fill",
"_unique2",
"_upsample_bicubic2d_aa",
Expand Down
Loading

0 comments on commit 189f645

Please sign in to comment.