-
Notifications
You must be signed in to change notification settings - Fork 5
/
join.go
90 lines (76 loc) · 1.76 KB
/
join.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package golgi
import (
"github.com/pkg/errors"
G "gorgonia.org/gorgonia"
)
// Compile-time assertion that *Join implements the Layer interface.
var _ Layer = (*Join)(nil)
// joinOp selects how a Join combines the forward results of its two terms.
type joinOp int

// Supported join operations. composeOp is deliberately the zero value, so a
// Join constructed without an explicit op behaves as a plain composition.
const (
	composeOp joinOp = 0 // delegate to Composition.Fwd
	addOp     joinOp = 1 // combine with G.Add
	elMulOp   joinOp = 2 // combine with G.HadamardProd
)
// Join is a generalized composition of two terms.
//
// The two terms (a and b) live in the embedded Composition. The op field
// selects what Fwd does with them: composeOp delegates to Composition.Fwd,
// while addOp and elMulOp apply both terms to the same input and combine
// the two results with G.Add or G.HadamardProd respectively.
type Join struct {
	Composition
	// op selects the combining operation; its zero value is composeOp.
	op joinOp
}
// Add builds a Join whose forward pass feeds the same input to both a and b
// and sums their results with G.Add.
func Add(a, b Term) *Join {
	j := &Join{op: addOp}
	j.Composition.a = a
	j.Composition.b = b
	return j
}
// HadamardProd builds a Join whose forward pass feeds the same input to both
// a and b and multiplies their results elementwise with G.HadamardProd.
func HadamardProd(a, b Term) *Join {
	var j Join
	j.Composition = Composition{a: a, b: b}
	j.op = elMulOp
	return &j
}
// Fwd runs the Join forwards on its input.
//
// A composeOp Join simply delegates to Composition.Fwd. For the other ops,
// the input is validated with G.CheckOne. If l.retVal is already set it is
// returned as-is. Otherwise l.a and l.b are each applied to the same input
// node and their results are combined with G.Add (addOp) or
// G.HadamardProd (elMulOp).
//
// All failures, including an unrecognized op, are reported through the
// returned G.Result via G.Err rather than by panicking.
//
// NOTE(review): Fwd never assigns l.retVal itself; confirm whether
// Composition or the caller is expected to populate that cache.
func (l *Join) Fwd(a G.Input) (output G.Result) {
	if l.op == composeOp {
		return l.Composition.Fwd(a)
	}
	if err := G.CheckOne(a); err != nil {
		return G.Err(errors.Wrapf(err, "Forward of a Join %v", l.Name()))
	}
	if l.retVal != nil {
		return l.retVal
	}

	input := a.Node()
	xn, err := l.applyTerm(l.a, input)
	if err != nil {
		return G.Err(err)
	}
	yn, err := l.applyTerm(l.b, input)
	if err != nil {
		return G.Err(err)
	}

	switch l.op {
	case addOp:
		return G.LiftResult(G.Add(xn, yn))
	case elMulOp:
		return G.LiftResult(G.HadamardProd(xn, yn))
	default:
		// Only reachable if a Join was built with an op value outside the
		// known set; surface it as an error instead of crashing the caller.
		return G.Err(errors.Errorf("Forward of Join %v - unknown join op %d", l.Name(), l.op))
	}
}

// applyTerm applies term t to input and checks that the result is a *G.Node.
// It uses one shared path for both terms, so that error messages always name
// the term that was actually applied.
func (l *Join) applyTerm(t Term, input *G.Node) (*G.Node, error) {
	out, err := Apply(t, input)
	if err != nil {
		return nil, errors.Wrapf(err, "Forward of Join %v - Applying %v to %v failed", l.Name(), t, input.Name())
	}
	n, ok := out.(*G.Node)
	if !ok {
		return nil, errors.Errorf("Expected the result of applying %v to %v to return a *Node. Got %v of %T instead", t, input.Name(), out, out)
	}
	return n, nil
}