Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested forward/reverse mode #156

Merged
merged 6 commits into from
Apr 23, 2024
Merged

Support nested forward/reverse mode #156

merged 6 commits into from
Apr 23, 2024

Conversation

sritchie
Copy link
Member

@sritchie sritchie commented Dec 8, 2023

This PR:

  • Makes forward- and reverse-mode automatic differentiation compatible with
    each other, allowing for proper mixed-mode AD

  • Adds support for derivatives of literal functions in reverse-mode

  • Adds support for emmy.operator/Operator-shaped outputs in reverse-mode

In a future PR I want to add to / modify the IPerturbed protocol so that it's as supportive of reverse-mode as forward-mode. I'm leaving that to the future for now.

Copy link

codecov bot commented Dec 8, 2023

Codecov Report

Attention: Patch coverage is 94.91525% with 6 lines in your changes are missing coverage. Please review.

Project coverage is 87.67%. Comparing base (d354463) to head (2617c36).

❗ Current head 2617c36 differs from pull request most recent head 1e86dc5. Consider uploading reports for the commit 1e86dc5 to get more accurate results

Files Patch % Lines
src/emmy/abstract/function.cljc 71.42% 2 Missing and 2 partials ⚠️
src/emmy/tape.cljc 97.72% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #156      +/-   ##
==========================================
- Coverage   87.68%   87.67%   -0.02%     
==========================================
  Files          99       99              
  Lines       15862    15791      -71     
  Branches      853      848       -5     
==========================================
- Hits        13908    13844      -64     
+ Misses       1101     1099       -2     
+ Partials      853      848       -5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

src/emmy/tape.cljc Outdated Show resolved Hide resolved
Base automatically changed from sritchie/tape to main December 8, 2023 17:10
@sritchie sritchie force-pushed the sritchie/nested branch 2 times, most recently from 34dea5a to 405160e Compare December 8, 2023 17:19
@sritchie sritchie force-pushed the sritchie/nested branch 3 times, most recently from 1ba84d7 to e2d3f0e Compare April 9, 2024 00:27
@sritchie sritchie changed the base branch from main to sritchie/easy April 9, 2024 00:27
Base automatically changed from sritchie/easy to main April 13, 2024 01:15
@sritchie sritchie force-pushed the sritchie/nested branch 2 times, most recently from 99a0a08 to c74897e Compare April 18, 2024 18:37
Copy link
Member Author

@sritchie sritchie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments to help you on the journey, @littleredcomputer ...

src/emmy/abstract/function.cljc Outdated Show resolved Hide resolved
src/emmy/abstract/function.cljc Show resolved Hide resolved
@@ -262,15 +277,50 @@
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

(defn- reverse-mode-fold
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

building on the cryptic fold-chain from a previous PR, now we can drop this in and support reverse-mode.

(rec (d/tangent x))
(d/tag x))

(tape/tape? x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels janky, and is the only place where we deliberately introspect these types.

@@ -557,36 +556,6 @@
(boolean
(some #{tag} *active-tags*)))

(defn inner-tag
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to emmy.tape

@@ -699,280 +668,19 @@

;; ## Chain Rule and Lifted Functions
;;
;; Finally, we come to the heart of it! [[lift-1]] and [[lift-2]] "lift", or
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this all moves to emmy.tape

d/*active-tags*)
(apply max tags)))

(defn tag+perturbation
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as we had in emmy.differential, just using tag-of internally vs emmy.differential/tag

@@ -548,6 +588,13 @@
(f/function? output)
(->partials-fn output tag)

(o/operator? output)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operator support, this should be handled with a protocol

(defunary g/zero? (by-primal g/zero?))
(defunary g/one? (by-primal g/one?))
(defunary g/identity? (by-primal g/identity?))
(defunary g/zero?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are careful to only respond true here if the tangent is also zero. should we be just as careful with tapes? by construction we really shouldn't be building duals with empty tangents etc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need g/zero? for tapes? If not, I'd leave it completely undefined: those things are kind of dangerous as you have discovered.

(if (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
(let [partial (literal-partial f path)]
(conj partials [entry (literal-apply partial primal-s)]))
partials))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is odd that nothing triggers L300, so in every case, (and (tape/tape? entry) (= tag (tape/tape-tag entry))).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would happen in the case of nested gradient calls to a literal function, I just don't have those in the tests yet. In that case, the innermost tag wins and any tag with a different tape is treated as a scalar (i.e. partial derivative == 0, so we never add it to the map)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to share tests between forward mode and reverse mode... when I do that this should get hit.

(tape/tape-tag x)
(tape/tape-id x)
(rec (tape/tape-primal x))
(mapv (fn [[node partial]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the inner thing is mapv too. Does writing it that way feel less janky? [node partial] is almost an instance of Dual, right, except that the tag belongs to the containing tape. Armed with that, you could unify this part with the part above, but (I say sincerely) it would be too clever. Therefore I don't think this is janky: it feels like about what we want to see here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole function is too clever, and my cleverness bit me here: #168

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleverness in trying to make this compatible with D, that is...

(let [m (into {} (mapcat
(fn [dx]
(when-let [t (tag-of dx)]
{t dx})))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it is cheaper to write [t dx]. However you choose to write it, it's cool (and I do think that {t dx} is a reasonable way to write the lone binding

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your way is faster!!

NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] instances. The
best way to accomplish this is by building `df:dx` out of already-lifted
functions, and declaring them by forward reference if you need to."
NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this makes sense; it's so sweet. keep the magic at the most primitive layer

(defunary g/zero? (by-primal g/zero?))
(defunary g/one? (by-primal g/one?))
(defunary g/identity? (by-primal g/identity?))
(defunary g/zero?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need g/zero? for tapes? If not, I'd leave it completely undefined: those things are kind of dangerous as you have discovered.

@@ -11,6 +11,7 @@
[emmy.generators :as sg]
[emmy.generic :as g]
[emmy.numerical.derivative :refer [D-numeric]]
[emmy.operator :as o]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just had this random test idea, which I don't think is particularly valuable or anything, but thought it was cool: "the Jacobian of a permutation is the permutation matrix of that permutation", as in :

(defn f [a b c d e] (up d e c b a))
(def M ((D f) 'a 'b 'c 'd 'e))
;; M = (down (up 0 0 0 0 1) (up 0 0 0 1 0) (up 0 0 1 0 0) (up 1 0 0 0 0) (up 0 1 0 0 0))
;; and so:
 (* M [1 2 3 4 5])
;; => (up 4 5 3 2 1)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added as a test!

@sritchie sritchie merged commit bd47efd into main Apr 23, 2024
4 checks passed
@sritchie sritchie deleted the sritchie/nested branch April 23, 2024 22:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants