Skip to content

Commit

Permalink
support nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Dec 8, 2023
1 parent 89f75a7 commit 405160e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@
the exemplar expected."
[f provided expected indexes]
(cond (number? expected)
(when-not (v/numerical? provided)
(when-not (v/scalar? provided)
(u/illegal (str "expected numerical quantity in argument " indexes
" of function call " f
" but got " provided)))
Expand Down
47 changes: 30 additions & 17 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,11 @@

(defn tag-of
"More permissive version of [[tape-tag]] that returns ##-Inf, the 'least
possible tag', when passed a non-[[TapeCell]] instance.
TODO this will need to be extended to
handle [[emmy.differential/Differential]] instances when these namespaces
merge."
possible tag', when passed a non-[[TapeCell]] instance."
[x]
(if (tape? x)
(tape-tag x)
##-Inf))
(cond (tape? x) (tape-tag x)
(d/differential? x) (d/max-order-tag x)
:else ##-Inf))

(defn tape-primal
"Given a [[TapeCell]], returns the `primal` field of the supplied [[TapeCell]]
Expand Down Expand Up @@ -301,9 +297,9 @@
Given a non-[[TapeCell]], acts as identity."
[v]
(if (tape? v)
(recur (tape-primal v))
v))
(cond (tape? v) (recur (tape-primal v))
(d/differential? v) (recur (d/primal-part v))
:else v))

;; ### Comparison, Control Flow
;;
Expand Down Expand Up @@ -669,7 +665,18 @@
"No df:dx, df:dy supplied for `f` or registered generically."))))
([f df:dx df:dy]
(fn call [x y]
(letfn [(operate [tag]
(letfn [(operate-forward [tag]
(let [[xe dx] (d/primal-tangent-pair x tag)
[ye dy] (d/primal-tangent-pair y tag)
a (call xe ye)
b (if (g/numeric-zero? dx)
a
(d/d:+* a (df:dx xe ye) dx))]
(if (g/numeric-zero? dy)
b
(d/d:+* b (df:dy xe ye) dy))))

(operate-reverse [tag]
(let [primal-x (tape-primal x tag)
primal-y (tape-primal y tag)
partial-x (if (and (tape? x) (= tag (tape-tag x)))
Expand All @@ -683,11 +690,15 @@
(into partial-x partial-y))))]
(let [tag-x (tag-of x)
tag-y (tag-of y)]
(cond (and (tape? x) (>= tag-x tag-y))
(operate tag-x)
(cond (>= tag-x tag-y)
(cond (tape? x) (operate-reverse tag-x)
(d/differential? x) (operate-forward tag-x)
:else (f x y))

(and (tape? y) (< tag-x tag-y))
(operate tag-y)
(< tag-x tag-y)
(cond (tape? y) (operate-reverse tag-y)
(d/differential? y) (operate-forward tag-y)
:else (f x y))

Check warning on line 701 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L701

Added line #L701 was not covered by tests

:else (f x y)))))))

Expand Down Expand Up @@ -750,7 +761,9 @@
([generic-op differential-op]
(doseq [signature [[::tape ::tape]
[::v/scalar ::tape]
[::tape ::v/scalar]]]
[::tape ::v/scalar]
[::tape ::d/differential]
[::d/differential ::tape]]]
(defmethod generic-op signature [a b] (differential-op a b)))))

(defn ^:no-doc by-primal
Expand Down
3 changes: 0 additions & 3 deletions test/emmy/tape_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,6 @@
((D (t/gradient f)) 'a 'b 'c 'd 'e 'f)))
"forward-over-reverse")

;; TODO enable this when we add support for tape and gradient comms in
;; lift-2.
#_
(is (= expected
(g/simplify
((t/gradient (D f)) 'a 'b 'c 'd 'e 'f)))
Expand Down

0 comments on commit 405160e

Please sign in to comment.