forked from sayakpaul/SimCLR-in-TensorFlow-2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
33 lines (24 loc) · 891 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
# Keras CosineSimilarity *loss* objects reused as pairwise scoring kernels.
# Reduction.NONE keeps one value per compared pair instead of averaging them.
# NOTE(review): as a loss, Keras' CosineSimilarity returns the NEGATED cosine
# similarity (values near -1 mean "same direction"). Callers treating these as
# a similarity score must flip the sign — confirm against the Keras docs.
cosine_sim_1d = tf.keras.losses.CosineSimilarity(
    axis=1,
    reduction=tf.keras.losses.Reduction.NONE,
)
cosine_sim_2d = tf.keras.losses.CosineSimilarity(
    axis=2,
    reduction=tf.keras.losses.Reduction.NONE,
)
def _cosine_simililarity_dim1(x, y):
    """Score each row of ``x`` against the matching row of ``y``.

    Delegates to the module-level ``cosine_sim_1d`` Keras loss object, which
    reduces over axis 1 with no batch reduction. Presumably ``x`` and ``y``
    are both ``(N, C)``, giving an ``(N,)`` result — TODO confirm with callers.

    NOTE(review): ``tf.keras.losses.CosineSimilarity`` is a loss and returns
    the *negated* cosine similarity; check the sign before using this value
    as a similarity score.
    """
    return cosine_sim_1d(x, y)
def _cosine_simililarity_dim2(x, y):
    """Cosine score of every row of ``x`` against every row of ``y``.

    Assumes ``x`` is ``(N, C)`` and ``y`` is ``(2N, C)``. Broadcasting
    ``(N, 1, C)`` against ``(1, 2N, C)`` and reducing over the feature axis
    (axis 2, via the module-level ``cosine_sim_2d``) yields an ``(N, 2N)``
    score matrix.

    NOTE(review): the underlying Keras ``CosineSimilarity`` loss returns the
    negated cosine similarity; check the sign before using it as a similarity.
    """
    anchors = tf.expand_dims(x, 1)     # (N, 1, C)
    candidates = tf.expand_dims(y, 0)  # (1, 2N, C)
    return cosine_sim_2d(anchors, candidates)  # (N, 2N)
def _dot_simililarity_dim1(x, y):
    """Dot product of each row of ``x`` with the matching row of ``y``.

    Presumably ``x`` and ``y`` are both ``(N, C)``. Each pair is computed as
    a batched ``(1, C) @ (C, 1)`` matmul, so the singleton axes are kept and
    the result has shape ``(N, 1, 1)`` rather than ``(N,)``.
    """
    row_vectors = tf.expand_dims(x, 1)  # (N, 1, C)
    col_vectors = tf.expand_dims(y, 2)  # (N, C, 1)
    return tf.matmul(row_vectors, col_vectors)  # (N, 1, 1)
def _dot_simililarity_dim2(x, y):
    """Dot product of every row of ``x`` with every row of ``y``.

    Args:
        x: ``(N, C)`` tensor.
        y: ``(M, C)`` tensor (``M = 2N`` in this module's usage notes).

    Returns:
        ``(N, M)`` tensor with ``v[i, j] = <x[i], y[j]>``, i.e. ``x @ y^T``.
        Leading batch dimensions are also accepted:
        ``(B, N, C)`` and ``(B, M, C)`` give ``(B, N, M)``.
    """
    # transpose_b swaps only the last two axes of y, unlike a bare
    # tf.transpose(y) which reverses every axis and is only right for rank 2.
    return tf.matmul(x, y, transpose_b=True)