Skip to content

Commit

Permalink
withall for tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 31, 2024
1 parent 2bc1bb1 commit c26c36d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/WithAlloc.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
module WithAlloc

using Bumper
export whatalloc, @withalloc1
export whatalloc, @withalloc1, @withalloc

function whatalloc end

function _bumper_alloc(allocinfo::Tuple{Type, Vararg{Int, N}}) where {N}
(Bumper.alloc!(Bumper.default_buffer(), allocinfo... ), )
end

function _bumper_alloc(allocinfo::NTuple{N, <: Tuple}) where {N}
ntuple(i -> Bumper.alloc!(Bumper.default_buffer(), allocinfo[i]...), N)
end


macro withalloc1(ex)
fncall = esc(ex.args[1])
Expand All @@ -18,6 +26,18 @@ macro withalloc1(ex)
end
end

macro withalloc(ex)
fncall = esc(ex.args[1])
args = esc.(ex.args[2:end])
quote
let
allocinfo = whatalloc($fncall, $(args...), )
storobj = _bumper_alloc(allocinfo)
$(fncall)(storobj..., $(args...), )
end
end
end


# Teemu's original draft, slightly edited
# macro withalloc(ex)
Expand Down
36 changes: 36 additions & 0 deletions test/test1.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using WithAlloc, Bumper, Test, LinearAlgebra

##

function mymul!(A, B, C)
mul!(A, B, C)
Expand Down Expand Up @@ -32,5 +33,40 @@ end
end
@test s3 s1

@no_escape begin
A4 = WithAlloc.@withalloc mymul!(B, C)

@show A4 A1
s4 = sum(A4)
end
@test s4 s1

##

B = randn(5,10)
C = randn(10, 3)
D = randn(10, 5)
A1 = B * C
A2 = B * D

function mymul2!(A1, A2, B, C, D)
mul!(A1, B, C)
mul!(A2, B, D)
return A1, A2
end

function WithAlloc.whatalloc(::typeof(mymul2!), B, C, D)
T1 = promote_type(eltype(B), eltype(C))
T2 = promote_type(eltype(B), eltype(D))
return ( (T1, size(B, 1), size(C, 2)),
(T2, size(B, 1), size(D, 2)) )
end


@no_escape begin
A1b, A2b = WithAlloc.@withalloc mymul2!(B, C, D)

@show A1 A1b
@show A2 A2b
end

0 comments on commit c26c36d

Please sign in to comment.