forked from AngusG/tensorflow-xnor-bnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_gemm_op.py
26 lines (21 loc) · 827 Bytes
/
tf_gemm_op.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
# Load the compiled XNOR GEMM custom-op library shipped under ``libs/``.
# NOTE(review): TF's resource loader presumably resolves this relative path
# against the directory of the *calling* module (it inspects the caller's
# frame), so the lookup is intentionally made here at module level rather
# than inside a helper function — confirm before moving it.
_gemm_lib_path = tf.resource_loader.get_path_to_datafile('./libs/gemm_op.so')
gemm_module = tf.load_op_library(_gemm_lib_path)

# Public handle to the library's ``gemm`` op wrapper; its gradient is
# registered below under the op type name "Gemm".
xnor_gemm = gemm_module.gemm
@ops.RegisterGradient("Gemm")
def _xnor_gemm_grad(op, grad):
    """Backpropagate through the custom ``Gemm`` op as an ordinary matmul.

    The gradient registered here treats the op as the matrix product
    ``C = A @ B`` and returns the standard matmul gradients:

        dL/dA = dL/dC @ B^T
        dL/dB = A^T @ dL/dC

    NOTE(review): assumes both inputs and ``grad`` are rank-2 tensors with
    shapes (M, K), (K, N) and (M, N) respectively — confirm against the
    C++ op definition.

    Args:
        op: The ``Gemm`` operation being differentiated; its first and
            second inputs are read from ``op.inputs``.
        grad: Gradient of the loss with respect to the op's output.

    Returns:
        A tuple ``(grad for input 0, grad for input 1)``, in input order.
    """
    lhs = op.inputs[0]
    rhs = op.inputs[1]

    # Upstream gradient contracted against the right operand: grad @ B^T.
    lhs_grad = math_ops.matmul(grad, rhs, transpose_b=True)
    # Left operand (transposed) contracted against the upstream gradient: A^T @ grad.
    rhs_grad = math_ops.matmul(lhs, grad, transpose_a=True)

    return lhs_grad, rhs_grad