Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sort! rules #1000

Merged
merged 8 commits into from
Dec 17, 2023
Merged

sort! rules #1000

merged 8 commits into from
Dec 17, 2023

Conversation

jgreener64
Copy link
Contributor

I made Enzyme rules for sort! as suggested in #880. These are my first Enzyme rules so will need some checking.

I put them in src/internal_rules.jl for now, let me know if they should go somewhere else. I have tests too which I can add later.

@codecov-commenter
Copy link

codecov-commenter commented Aug 14, 2023

Codecov Report

Attention: 36 lines in your changes are missing coverage. Please review.

Comparison is base (10d380b) 76.07% compared to head (9015204) 75.79%.

Files Patch % Lines
src/internal_rules.jl 0.00% 36 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1000      +/-   ##
==========================================
- Coverage   76.07%   75.79%   -0.28%     
==========================================
  Files          35       35              
  Lines        9926     9962      +36     
==========================================
  Hits         7551     7551              
- Misses       2375     2411      +36     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

kwargs...
)
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like sortperm itself just makes a list of 1:N, then sorts that in place using the original data as the order: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1756

Could we change this to first sort the derivative array in place using the primal array as the order, then do the actual sort on the primal? That way we don't have that temporary

@jgreener64
Copy link
Contributor Author

Added tests and changed the augmented primal. I will address the other points later.

@jgreener64
Copy link
Contributor Author

Added a rule for batched mode.

I had a look at sorting without allocating the index array but couldn't get it to work. The Perm order at https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1756 is specifically for ordering indices, rather than generic ordering of another array.

@jgreener64
Copy link
Contributor Author

Is this ready for merge? Test failures look unrelated.

@matinraayai
Copy link

matinraayai commented Nov 27, 2023

@jgreener64 will this rule also cover out-of-place sort across different dims?
Also thank you for making this, I realized my code needs to diff a code with sort.

@jgreener64
Copy link
Contributor Author

In general sort will call this rule yes since it calls sort! internally: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1495-L1500.

sort with dims though maybe not as that goes down a different path: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1817. Give it a go and see what happens?

@matinraayai
Copy link

In general sort will call this rule yes since it calls sort! internally: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1495-L1500.

sort with dims though maybe not as that goes down a different path: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1817. Give it a go and see what happens?

This is what I'm getting on Enzyme Master right now:

Illegal updateAnalysis prev:{[-1]:Pointer, [-1,-1]:Float@double} new: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}
val:   %47 = bitcast double addrspace(13)* %46 to i64 addrspace(13)*, !dbg !291 origin=  store i64 %42, i64 addrspace(13)* %47, align 8, !dbg !291, !tbaa !207, !alias.scope !210, !noalias !211

Caused by:
Stacktrace:
 [1] setindex!
   @ ./array.jl:1019
 [2] _setindex_ra!
   @ ./reinterpretarray.jl:537
 [3] setindex!
   @ ./reinterpretarray.jl:505
 [4] _sort!
   @ ./sort.jl:906

I was hoping this PR would fix this issue, but if not maybe I can help add something new to the rule to cover the dim keyword?

@jgreener64
Copy link
Contributor Author

Does that error occur with this PR as well?

@matinraayai
Copy link

It seems like it. I just swapped Enzyme#master with your sort-rule branch and they both seem to give the same error.

@jgreener64
Copy link
Contributor Author

It looks like a specific rule needs to be written for this path: https://github.com/JuliaLang/julia/blob/750df9fb5bede16f321f5d5405943d12aec7b83e/base/sort.jl#L1817.

Though as a short-term fix you can copy the array and call sort!, that should hit the above rules and be fast.

@matinraayai
Copy link

@jgreener64 I will do that (I don't care about the unsorted array). I thought the rules would apply to the ! variant too however?

@jgreener64
Copy link
Contributor Author

The rules in this PR will always apply to sort! and will apply to sort in cases where dims is not set. So the workaround for sort with dims is to copy the array and use sort!, which always hits these rules.

@wsmoses wsmoses merged commit 770b064 into EnzymeAD:main Dec 17, 2023
29 of 42 checks passed
@jgreener64 jgreener64 deleted the sort-rule branch December 17, 2023 23:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants