-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support nested forward/reverse mode #156
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #156 +/- ##
==========================================
- Coverage 87.68% 87.67% -0.02%
==========================================
Files 99 99
Lines 15862 15791 -71
Branches 853 848 -5
==========================================
- Hits 13908 13844 -64
+ Misses 1101 1099 -2
+ Partials 853 848 -5 ☔ View full report in Codecov by Sentry. |
34dea5a
to
405160e
Compare
1ba84d7
to
e2d3f0e
Compare
99a0a08
to
c74897e
Compare
c74897e
to
69ee50a
Compare
227276b
to
3ece359
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments to help you on the journey, @littleredcomputer ...
@@ -262,15 +277,50 @@ | |||
(g/+ tangent (g/* (literal-apply partial primal-s) | |||
dx)))))))) | |||
|
|||
(defn- reverse-mode-fold |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
building on the cryptic fold-chain
from a previous PR, now we can drop this in and support reverse-mode.
(rec (d/tangent x)) | ||
(d/tag x)) | ||
|
||
(tape/tape? x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this feels janky, and is the only place where we deliberately introspect these types.
@@ -557,36 +556,6 @@ | |||
(boolean | |||
(some #{tag} *active-tags*))) | |||
|
|||
(defn inner-tag |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved to emmy.tape
@@ -699,280 +668,19 @@ | |||
|
|||
;; ## Chain Rule and Lifted Functions | |||
;; | |||
;; Finally, we come to the heart of it! [[lift-1]] and [[lift-2]] "lift", or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this all moves to emmy.tape
d/*active-tags*) | ||
(apply max tags))) | ||
|
||
(defn tag+perturbation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as we had in emmy.differential
, just using tag-of
internally vs emmy.differential/tag
@@ -548,6 +588,13 @@ | |||
(f/function? output) | |||
(->partials-fn output tag) | |||
|
|||
(o/operator? output) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
operator support, this should be handled with a protocol
(defunary g/zero? (by-primal g/zero?)) | ||
(defunary g/one? (by-primal g/one?)) | ||
(defunary g/identity? (by-primal g/identity?)) | ||
(defunary g/zero? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are careful to only respond true here if the tangent is also zero. should we be just as careful with tapes? by construction we really shouldn't be building duals with empty tangents etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need g/zero?
for tapes? If not, I'd leave it completely undefined: those things are kind of dangerous as you have discovered.
(if (and (tape/tape? entry) (= tag (tape/tape-tag entry))) | ||
(let [partial (literal-partial f path)] | ||
(conj partials [entry (literal-apply partial primal-s)])) | ||
partials)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is odd that nothing triggers L300, so in every case, (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would happen in the case of nested gradient
calls to a literal function, I just don't have those in the tests yet. In that case, the innermost tag wins and any tag with a different tape is treated as a scalar (i.e. partial derivative == 0, so we never add it to the map)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to share tests between forward mode and reverse mode... when I do that this should get hit.
(tape/tape-tag x) | ||
(tape/tape-id x) | ||
(rec (tape/tape-primal x)) | ||
(mapv (fn [[node partial]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like the inner thing is mapv
too. Does writing it that way feel less janky? [node partial]
is almost an instance of Dual, right, except that the tag belongs to the containing tape. Armed with that, you could unify this part with the part above, but (I say sincerely) it would be too clever. Therefore I don't think this is janky: it feels like about what we want to see here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this whole function is too clever, and my cleverness bit me here: #168
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleverness in trying to make this compatible with D
, that is...
src/emmy/tape.cljc
Outdated
(let [m (into {} (mapcat | ||
(fn [dx] | ||
(when-let [t (tag-of dx)] | ||
{t dx}))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it is cheaper to write [t dx]
. However you choose to write it, it's cool (and I do think that {t dx}
is a reasonable way to write the lone binding
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
your way is faster!!
NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] instances. The | ||
best way to accomplish this is by building `df:dx` out of already-lifted | ||
functions, and declaring them by forward reference if you need to." | ||
NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this makes sense; it's so sweet. keep the magic at the most primitive layer
(defunary g/zero? (by-primal g/zero?)) | ||
(defunary g/one? (by-primal g/one?)) | ||
(defunary g/identity? (by-primal g/identity?)) | ||
(defunary g/zero? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need g/zero?
for tapes? If not, I'd leave it completely undefined: those things are kind of dangerous as you have discovered.
@@ -11,6 +11,7 @@ | |||
[emmy.generators :as sg] | |||
[emmy.generic :as g] | |||
[emmy.numerical.derivative :refer [D-numeric]] | |||
[emmy.operator :as o] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just had this random test idea, which I don't think is particularly valuable or anything, but thought it was cool: "the Jacobian of a permutation is the permutation matrix of that permutation", as in :
(defn f [a b c d e] (up d e c b a))
(def M ((D f) 'a 'b 'c 'd 'e))
;; M = (down (up 0 0 0 0 1) (up 0 0 0 1 0) (up 0 0 1 0 0) (up 1 0 0 0 0) (up 0 1 0 0 0))
;; and so:
(* M [1 2 3 4 5])
;; => (up 4 5 3 2 1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added as a test!
This PR:
Makes forward- and reverse-mode automatic differentiation compatible with
each other, allowing for proper mixed-mode AD
Adds support for derivatives of literal functions in reverse-mode
Adds support for
emmy.operator/Operator
-shaped outputs in reverse-modeIn a future PR I want to add to / modify the
IPerturbed
protocol so that it's as supportive of reverse-mode as forward-mode. I'm leaving that to the future for now.