Skip to content

Commit

Permalink
tape works
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 21, 2024
1 parent f343c88 commit 69ee50a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(d/dual? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))

Check warning on line 284 in src/emmy/abstract/function.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/abstract/function.cljc#L284

Added line #L284 was not covered by tests
primal-s (s/mapr (fn [x] (d/primal x tag)) s)]
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
Expand Down
18 changes: 11 additions & 7 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -543,12 +543,16 @@
(letfn [(process-term [term]
(g/simplify
(s/mapr (fn rec [x]
(if (d/dual? x)
(d/bundle-element
(rec (d/primal x))
(rec (d/tangent x))
(d/tag x))
(-> (g/simplify x)
(x/substitute replace-m))))
(cond (d/dual? x)
(d/bundle-element
(rec (d/primal x))
(rec (d/tangent x))
(d/tag x))

(tape/tape? x)
(u/illegal "TODO implement this using fmap style.")

Check warning on line 553 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L553

Added line #L553 was not covered by tests

:else (-> (g/simplify x)
(x/substitute replace-m))))
term)))]
(series/fmap process-term series)))))
61 changes: 56 additions & 5 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as matrix]
[emmy.operator :as o]
[emmy.structure :as s]
[emmy.util :as u]
[emmy.value :as v]))
Expand Down Expand Up @@ -329,10 +330,21 @@
d/*active-tags*)
(apply max tags)))

;; TODO we could change `perturbed?` into something like
;; `possible-perturbations`, to get collection types to return sequence of
;; inputs for this. Then we could handle map-shaped inputs etc into literal
;; functions, if we had the proper descriptor language for it.

(defn tag+perturbation
"A COPY of the same function in `differential`. I'm adding this here to avoid
import nonsense, and I'll delete one of the copies on the next PR, when I add
support for mixing forward and reverse modes together."
"Given any number of `dxs`, returns a pair of the form
[<tag> <tape-or-dual-number>]
containing the tag and instance of [[emmy.differential/Dual]] or [[TapeCell]]
associated with the inner-most call to [[with-active-tag]] in the current call
stack.
If none of `dxs` has an active tag, returns `nil`."
([& dxs]
(let [m (into {} (mapcat
(fn [dx]
Expand Down Expand Up @@ -479,6 +491,20 @@
;;
(defrecord Completed [v->partial]
d/IPerturbed
;; TODO note that this can happen because these can pop out from inside of
;; ->partial-fn. And that is currently where the tag-rewriting has to occur.
;;
;; But that is going to be inefficient for lots of intermediate values...
;; ideally we could call this AFTER we select out the IDs. That implies that
;; we want to shove that inside of extract.
;;
;; TODO TODO TODO definitely do this, we definitely want that to happen, don't
;; have those stacked levels, otherwise super inefficient to walk multiple
;; times.
;;
;; TODO AND THEN if that's true then we can delete this implementation, since
;; we'll already be pulled OUT of the completed map.

;; NOTE that it's a problem that `replace-tag` is called on [[Completed]]
;; instances now. In a future refactor I want `get` calls out of
;; a [[Completed]] map to occur before tag replacement needs to happen.
Expand Down Expand Up @@ -524,6 +550,10 @@
- the partial derivative of the output with respect to that value."
[root]
(let [nodes (topological-sort root)

;; TODO this is the spot where we want to wire in many sensitivities. So
;; how would it work, if we set all of the sensitivities for the outputs
;; at once? What would the ordering be as we walked backwards?
sensitivities {(tape-id root) 1}]
(->Completed
(reduce process sensitivities nodes))))
Expand All @@ -540,6 +570,10 @@

(declare ->partials)

;; TODO fix the docstring, and think of how we can combine this into the
;; narrative of what we find in derivative. Maybe this should be the main
;; version?

(defn- ->partials-fn
"Returns a new function that composes a 'tag extraction' step with `f`. The
returned fn will
Expand Down Expand Up @@ -582,12 +616,24 @@
(vector? output)
(mapv #(->partials % tag) output)

;; Here is an example of the subtlety. We MAY want to go one at a
;; time... or we may want to insert some sensitivity entry into the
;; entire structure and roll the entire structure back. We don't do that
;; YET so I bet we can get away with ignoring it for this first PR. But
;; we are close to needing that.
(s/structure? output)
(s/mapr #(->partials % tag) output)

(f/function? output)
(->partials-fn output tag)

(o/operator? output)
(o/->Operator (->partials-fn (o/procedure output) tag)
(o/arity output)
(o/name output)
(o/context output)
(meta output))

Check warning on line 635 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L631-L635

Added lines #L631 - L635 were not covered by tests

(v/scalar? output)
(->Completed {})

Expand Down Expand Up @@ -620,10 +666,15 @@
(s/mapr #(extract % id) output)

(f/function? output)
;; TODO this needs to handle perturbation confusion with tag
;; replacement. Make something similar to extract-tangent-fn.
(comp #(extract % id) output)

(o/operator? output)
(o/->Operator (extract (o/procedure output) id)
(o/arity output)
(o/name output)
(o/context output)
(meta output))

Check warning on line 676 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L671-L676

Added lines #L671 - L676 were not covered by tests

:else 0))

;; TODO note that [[interpret]] and [[tapify]] both need to become generic on
Expand Down

0 comments on commit 69ee50a

Please sign in to comment.