From 405160ef7afa2a06a392f7aac794259028a30e2e Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Fri, 8 Dec 2023 10:19:04 -0700 Subject: [PATCH] support nesting --- src/emmy/abstract/function.cljc | 2 +- src/emmy/tape.cljc | 47 +++++++++++++++++++++------------ test/emmy/tape_test.cljc | 3 --- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index caa0cb6d..fe18962c 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -267,7 +267,7 @@ the exemplar expected." [f provided expected indexes] (cond (number? expected) - (when-not (v/numerical? provided) + (when-not (v/scalar? provided) (u/illegal (str "expected numerical quantity in argument " indexes " of function call " f " but got " provided))) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index 44625477..123b2d2e 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -245,15 +245,11 @@ (defn tag-of "More permissive version of [[tape-tag]] that returns ##-Inf, the 'least - possible tag', when passed a non-[[TapeCell]] instance. - - TODO this will need to be extended to - handle [[emmy.differential/Differential]] instances when these namespaces - merge." + possible tag', when passed a non-[[TapeCell]] instance." [x] - (if (tape? x) - (tape-tag x) - ##-Inf)) + (cond (tape? x) (tape-tag x) + (d/differential? x) (d/max-order-tag x) + :else ##-Inf)) (defn tape-primal "Given a [[TapeCell]], returns the `primal` field of the supplied [[TapeCell]] @@ -301,9 +297,9 @@ Given a non-[[TapeCell]], acts as identity." [v] - (if (tape? v) - (recur (tape-primal v)) - v)) + (cond (tape? v) (recur (tape-primal v)) + (d/differential? v) (recur (d/primal-part v)) + :else v)) ;; ### Comparison, Control Flow ;; @@ -669,7 +665,18 @@ "No df:dx, df:dy supplied for `f` or registered generically.")))) ([f df:dx df:dy] (fn call [x y] - (letfn [(operate [tag] + (letfn [(operate-forward [tag] + (let [[xe dx] (d/primal-tangent-pair x tag) + [ye dy] (d/primal-tangent-pair y tag) + a (call xe ye) + b (if (g/numeric-zero? dx) + a + (d/d:+* a (df:dx xe ye) dx))] + (if (g/numeric-zero? dy) + b + (d/d:+* b (df:dy xe ye) dy)))) + + (operate-reverse [tag] (let [primal-x (tape-primal x tag) primal-y (tape-primal y tag) partial-x (if (and (tape? x) (= tag (tape-tag x))) @@ -683,11 +690,15 @@ (into partial-x partial-y))))] (let [tag-x (tag-of x) tag-y (tag-of y)] - (cond (and (tape? x) (>= tag-x tag-y)) - (operate tag-x) + (cond (>= tag-x tag-y) + (cond (tape? x) (operate-reverse tag-x) + (d/differential? x) (operate-forward tag-x) + :else (f x y)) - (and (tape? y) (< tag-x tag-y)) - (operate tag-y) + (< tag-x tag-y) + (cond (tape? y) (operate-reverse tag-y) + (d/differential? y) (operate-forward tag-y) + :else (f x y)) :else (f x y))))))) @@ -750,7 +761,9 @@ ([generic-op differential-op] (doseq [signature [[::tape ::tape] [::v/scalar ::tape] - [::tape ::v/scalar]]] + [::tape ::v/scalar] + [::tape ::d/differential] + [::d/differential ::tape]]] (defmethod generic-op signature [a b] (differential-op a b))))) (defn ^:no-doc by-primal diff --git a/test/emmy/tape_test.cljc b/test/emmy/tape_test.cljc index fae2836e..95685028 100644 --- a/test/emmy/tape_test.cljc +++ b/test/emmy/tape_test.cljc @@ -486,9 +486,6 @@ ((D (t/gradient f)) 'a 'b 'c 'd 'e 'f))) "forward-over-reverse") - ;; TODO enable this when we add support for tape and gradient comms in - ;; lift-2. - #_ (is (= expected (g/simplify ((t/gradient (D f)) 'a 'b 'c 'd 'e 'f)))