-
Notifications
You must be signed in to change notification settings - Fork 1
/
spatial_transformer.py
94 lines (73 loc) · 3.35 KB
/
spatial_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tensorflow as tf
def generate_normalized_homo_meshgrids(inputs):
# for x, y in grid, -1 <=x,y<=1
batch_size = tf.shape(inputs)[0]
_, H, W, _ = inputs.shape
x_range = tf.range(W)
y_range = tf.range(H)
x_mesh, y_mesh = tf.meshgrid(x_range, y_range)
x_mesh = (x_mesh / W - 0.5) * 2
y_mesh = (y_mesh / H - 0.5) * 2
y_mesh = tf.reshape(y_mesh, (*y_mesh.shape, 1))
x_mesh = tf.reshape(x_mesh, (*x_mesh.shape, 1))
ones_mesh = tf.ones_like(x_mesh)
homogeneous_grid = tf.concat([x_mesh, y_mesh, ones_mesh], -1)
homogeneous_grid = tf.reshape(homogeneous_grid, (-1, 3, 1))
homogeneous_grid = tf.dtypes.cast(homogeneous_grid, tf.float32)
homogeneous_grid = tf.expand_dims(homogeneous_grid, 0)
return tf.tile(homogeneous_grid, [batch_size, 1, 1, 1])
def transform_grids(transformations, grids, inputs):
trans_matrices = tf.reshape(transformations, (-1, 2, 3))
batch_size = tf.shape(trans_matrices)[0]
gs = tf.squeeze(grids, -1)
reprojected_grids = tf.matmul(trans_matrices, gs, transpose_b=True)
# transform grid range from [-1,1) to the range of [0,1)
reprojected_grids = (tf.linalg.matrix_transpose(reprojected_grids) + 1) * 0.5
_, H, W, _ = inputs.shape
reprojected_grids = tf.math.multiply(reprojected_grids, [W, H])
return reprojected_grids
def generate_four_neighbors_from_reprojection(inputs, reprojected_grids):
_, H, W, _ = inputs.shape
x, y = tf.split(reprojected_grids, 2, axis=-1)
x1 = tf.floor(x)
x1 = tf.dtypes.cast(x1, tf.int32)
x2 = x1 + tf.constant(1)
y1 = tf.floor(y)
y1 = tf.dtypes.cast(y1, tf.int32)
y2 = y1 + tf.constant(1)
y_max = tf.constant(H - 1, dtype=tf.int32)
x_max = tf.constant(W - 1, dtype=tf.int32)
zero = tf.zeros([1], dtype=tf.int32)
x1_safe = tf.clip_by_value(x1, zero, x_max)
y1_safe = tf.clip_by_value(y1, zero, y_max)
x2_safe = tf.clip_by_value(x2, zero, x_max)
y2_safe = tf.clip_by_value(y2, zero, y_max)
return x1_safe, y1_safe, x2_safe, y2_safe
def bilinear_sample(inputs, reprojected_grids):
x1, y1, x2, y2 = generate_four_neighbors_from_reprojection(inputs, reprojected_grids)
x1y1 = tf.concat([y1, x1], -1)
x1y2 = tf.concat([y2, x1], -1)
x2y1 = tf.concat([y1, x2], -1)
x2y2 = tf.concat([y2, x2], -1)
pixel_x1y1 = tf.gather_nd(inputs, x1y1, batch_dims=1)
pixel_x1y2 = tf.gather_nd(inputs, x1y2, batch_dims=1)
pixel_x2y1 = tf.gather_nd(inputs, x2y1, batch_dims=1)
pixel_x2y2 = tf.gather_nd(inputs, x2y2, batch_dims=1)
x, y = tf.split(reprojected_grids, 2, axis=-1)
wx = tf.concat([tf.dtypes.cast(x2, tf.float32) - x, x - tf.dtypes.cast(x1, tf.float32)], -1)
wx = tf.expand_dims(wx, -2)
wy = tf.concat([tf.dtypes.cast(y2, tf.float32) - y, y - tf.dtypes.cast(y1, tf.float32)], -1)
wy = tf.expand_dims(wy, -1)
Q = tf.concat([pixel_x1y1, pixel_x1y2, pixel_x2y1, pixel_x2y2], -1)
Q_shape = tf.shape(Q)
Q = tf.reshape(Q, (Q_shape[0], Q_shape[1], 2, 2))
Q = tf.cast(Q, tf.float32)
r = wx @ Q @ wy
_, H, W, channels = inputs.shape
r = tf.reshape(r, (-1, H, W, 1))
return r
def spatial_transform_input(inputs, transformations):
grids = generate_normalized_homo_meshgrids(inputs)
reprojected_grids = transform_grids(transformations, grids, inputs)
result = bilinear_sample(inputs, reprojected_grids)
return result