Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alignment invariant map to map distance #103

Draft
wants to merge 29 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5dd4a8a
Merge pull request #98 from flatironinstitute/dev
DSilva27 Sep 18, 2024
bae6271
Merge pull request #99 from flatironinstitute/dev
geoffwoollard Sep 18, 2024
63470a2
point cloud -> gw distance
geoffwoollard Sep 19, 2024
bef535e
average point clouds or distances, but does not make sense since orde…
geoffwoollard Sep 19, 2024
6a967b6
clean points
geoffwoollard Sep 26, 2024
07865b7
Merge branch 'dev' into 96-alignment-invariant-map-to-map-distance
geoffwoollard Sep 26, 2024
bd304af
notebook
geoffwoollard Nov 14, 2024
e2abda1
script with dask downsampling on one submission with itself
geoffwoollard Nov 21, 2024
f8ac02f
script with dask downsampling on one submission with itself
geoffwoollard Nov 21, 2024
1b8bfcf
script with dask downsampling on one submission with itself
geoffwoollard Nov 21, 2024
59a1ba8
update
geoffwoollard Nov 21, 2024
2f5b9ad
Merge branch 'dev' into 96-alignment-invariant-map-to-map-distance
geoffwoollard Dec 19, 2024
ac6f705
results
geoffwoollard Dec 20, 2024
f6d1f5a
results in notebook
geoffwoollard Dec 20, 2024
1dba48c
float128 and other types. generalize to non-symmetric with two sets o…
geoffwoollard Dec 22, 2024
874b02d
exponent
geoffwoollard Dec 22, 2024
d43bac3
fix symmetric bug overwriting results with zero
geoffwoollard Dec 28, 2024
8d51f10
mockup example
geoffwoollard Dec 28, 2024
c420650
gw
geoffwoollard Dec 28, 2024
d2c02af
scheduler args
geoffwoollard Dec 29, 2024
0eaaaa6
slurm mockup working
geoffwoollard Dec 29, 2024
d46aa85
npix3
geoffwoollard Dec 29, 2024
45e793b
single volume_i,j for low memory
geoffwoollard Dec 29, 2024
c427319
refactor out prepare marginals
geoffwoollard Dec 29, 2024
c18fd87
refactor name of volume to marginal
geoffwoollard Dec 29, 2024
2bf32eb
running via sbatch
geoffwoollard Dec 29, 2024
0c915ec
refactor args, main
geoffwoollard Dec 29, 2024
00ce528
exponent
geoffwoollard Dec 30, 2024
419310b
gw distance class. return empty
geoffwoollard Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# gw
src/cryo_challenge/_map_to_map/gromov_wasserstein/*mp4
src/cryo_challenge/_map_to_map/gromov_wasserstein/*map

# downloaded data
data/dataset_2_submissions
data/dataset_1_submissions
Expand Down
Empty file.
90 changes: 90 additions & 0 deletions src/cryo_challenge/_map_to_map/gromov_wasserstein/coords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np


def coords_n_by_d(coords_1d=None, N=None, d=3):
if N is None:
assert coords_1d is not None
elif coords_1d is None:
assert N is not None
coords_1d = np.arange(-N // 2, N // 2)

if d == 2:
X = np.meshgrid(coords_1d, coords_1d)
elif d == 3:
X = np.meshgrid(coords_1d, coords_1d, coords_1d)
coords = np.zeros((X[0].size, d))
for di in range(d):
coords[:, di] = X[di].flatten()
# make compatible with flatten
if d == 3:
coords[:, [0, 1]] = coords[:, [1, 0]]
elif d == 2:
coords[:, [0, 1]] = coords[:, [1, 0]]

return coords


def EA_to_R3(phi, theta, psi=None):
"""
Makes a rotation matrix from Z-Y-Z Euler angles.
maps image coordinates (x,y,0) view coordinates
See Z_1 Y_2 Z_3 entry in the table "Proper Euler angles" at https://en.wikipedia.org/wiki/Euler_angles#Rotation_matrix
http://www.gregslabaugh.net/publications/euler.pdf
"""
R_z = np.array(
[[np.cos(phi), -np.sin(phi), 0], [np.sin(phi), np.cos(phi), 0], [0, 0, 1]]
)
R_y = np.array(
[
[np.cos(theta), 0, np.sin(theta)],
[0, 1, 0],
[-np.sin(theta), 0, np.cos(theta)],
]
)
R = np.dot(R_z, R_y)
if psi is not None and psi != 0:
R_in = np.array(
[[np.cos(psi), -np.sin(psi), 0], [np.sin(psi), np.cos(psi), 0], [0, 0, 1]]
)

R = np.dot(R, R_in)

return R


def deg_to_rad(deg):
return deg * np.pi / 180


def get_random_quat(num_pts, method="sphere"):
"""
Get num_pts of unit quaternions with a uniform random distribution.
:param num_pts: The number of quaternions to return
: param method:
hemisphere: uniform on the 4 hemisphere, with x in [0,1], y,z in [-1,1]
sphere: uniform on the sphere, with x,y,z in [-1,1]
:return: Quaternion list of shape [number of quaternion, 4]
"""
u = np.random.rand(3, num_pts)
u1, u2, u3 = [u[x] for x in range(3)]

quat = np.zeros((4, num_pts))

quat[0] = np.sqrt(1 - u1) * np.sin(np.pi * u2 / 2)
quat[1] = np.sqrt(1 - u1) * np.cos(np.pi * u2 / 2)
quat[2] = np.sqrt(u1) * np.sin(np.pi * u3 / 2)
quat[3] = np.sqrt(u1) * np.cos(np.pi * u3 / 2)

return np.transpose(quat)


def quaternion_to_R(q):
a, b, c, d = q[0], q[1], q[2], q[3]
R = np.array(
[
[a**2 + b**2 - c**2 - d**2, 2 * b * c - 2 * a * d, 2 * b * d + 2 * a * c],
[2 * b * c + 2 * a * d, a**2 - b**2 + c**2 - d**2, 2 * c * d - 2 * a * b],
[2 * b * d - 2 * a * c, 2 * c * d + 2 * a * b, a**2 - b**2 - c**2 + d**2],
]
)
return R
251 changes: 251 additions & 0 deletions src/cryo_challenge/_map_to_map/gromov_wasserstein/cvt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import numpy as np
import scipy as sp
import scipy.spatial
from scipy.spatial import Delaunay

from coords import coords_n_by_d


# eps = sys.float_info.epsilon
eps = 0.000001


def in_box(robots, bounding_box):
return np.logical_and(
np.logical_and(
bounding_box[0] <= robots[:, 0], robots[:, 0] <= bounding_box[1]
),
np.logical_and(
bounding_box[2] <= robots[:, 1], robots[:, 1] <= bounding_box[3]
),
np.logical_and(
bounding_box[4] <= robots[:, 2], robots[:, 2] <= bounding_box[5]
),
)


def voronoi(robots, bounding_box):
i = in_box(robots, bounding_box)
points_center = robots[i, :]
points_left = np.copy(points_center)
points_left[:, 0] = bounding_box[0] - (points_left[:, 0] - bounding_box[0])
points_right = np.copy(points_center)
points_right[:, 0] = bounding_box[1] + (bounding_box[1] - points_right[:, 0])
points_down = np.copy(points_center)
points_down[:, 1] = bounding_box[2] - (points_down[:, 1] - bounding_box[2])
points_up = np.copy(points_center)
points_up[:, 1] = bounding_box[3] + (bounding_box[3] - points_up[:, 1])
points_back = np.copy(points_center)
points_back[:, 2] = bounding_box[4] - (points_back[:, 2] - bounding_box[4])
points_forth = np.copy(points_center)
points_forth[:, 2] = bounding_box[5] + (bounding_box[5] - points_forth[:, 2])
points = np.append(
points_center,
np.append(
np.append(points_left, points_right, axis=0),
np.append(
np.append(points_down, points_up, axis=0),
np.append(points_back, points_forth, axis=0),
axis=0,
),
axis=0,
),
axis=0,
)

vor = sp.spatial.Voronoi(points)
# Filter regions and select corresponding points
regions = []
points_to_filter = [] # we'll need to gather points too
ind = np.arange(points.shape[0])
ind = np.expand_dims(ind, axis=1)

for i, region in enumerate(vor.regions): # enumerate the regions
if not region: # nicer to skip the empty region altogether
continue

flag = True
tot = 0
tot_fail = 0
for index in region:
tot += 1
if index == -1:
flag = False
tot_fail += 1
break
else:
x = vor.vertices[index, 0]
y = vor.vertices[index, 1]
z = vor.vertices[index, 2]
if not (
bounding_box[0] - eps <= x
and x <= bounding_box[1] + eps
and bounding_box[2] - eps <= y
and y <= bounding_box[3] + eps
and bounding_box[4] - eps <= z
and z <= bounding_box[5] + eps
):
flag = False
tot_fail += 1
break
if flag:
regions.append(region)

# find the point which lies inside
points_to_filter.append(vor.points[vor.point_region == i][0, :])

vor.filtered_points = np.array(points_to_filter)
vor.filtered_regions = regions
return vor


def centroid_region(vertices, map_2d, memory, threshold=0):
min_x = map_2d.shape[0] + 1
min_y = map_2d.shape[1] + 1
min_z = map_2d.shape[2] + 1
max_x = -1
max_y = -1
max_z = -1
for i in range(len(vertices)):
min_x = min(min_x, vertices[i, 0])
min_y = min(min_y, vertices[i, 1])
min_z = min(min_z, vertices[i, 2])
max_x = max(max_x, vertices[i, 0])
max_y = max(max_y, vertices[i, 1])
max_z = max(max_z, vertices[i, 2])
min_x = int(min_x * map_2d.shape[0])
max_x = int(max_x * map_2d.shape[0]) + 1
min_y = int(min_y * map_2d.shape[1])
max_y = int(max_y * map_2d.shape[1]) + 1
min_z = int(min_z * map_2d.shape[2])
max_z = int(max_z * map_2d.shape[2]) + 1
max_x = min(max_x, map_2d.shape[0])
max_y = min(max_y, map_2d.shape[1])
max_z = min(max_z, map_2d.shape[2])
min_x = max(min_x, 0)
min_y = max(min_y, 0)
min_z = max(min_z, 0)
A = 0
C_x = 0
C_y = 0
C_z = 0

temp = np.zeros(vertices[0].shape)
for i in range(len(vertices)):
temp += vertices[i] / len(vertices)

N_x = max_x - min_x
N_y = max_y - min_y
N_z = max_z - min_z
grid = np.indices((N_x, N_y, N_z))
g0 = (grid[0].ravel() + min_x + 0.5) / map_2d.shape[0]
g1 = (grid[1].ravel() + min_y + 0.5) / map_2d.shape[1]
g2 = (grid[2].ravel() + min_z + 0.5) / map_2d.shape[2]
points = np.asarray([g0, g1, g2]).transpose()

if len(points) > 0:
in_array = Delaunay(vertices).find_simplex(points) >= 0

map_array = map_2d[min_x:max_x, min_y:max_y, min_z:max_z].ravel()
A += (in_array * map_array).sum()
C_x += (in_array * map_array * g0).sum()
C_y += (in_array * map_array * g1).sum()
C_z += (in_array * map_array * g2).sum()

if A == 0:
return np.array([[temp[0], temp[1], temp[2]]]), A
C_x /= A
C_y /= A
C_z /= A
return np.array([[C_x, C_y, C_z]]), A


def plot(r, map_2d, bounding_box):
vor = voronoi(r, bounding_box)

centroids = []
weights = []
memory = np.zeros(map_2d.shape)
for region in vor.filtered_regions:
vertices = vor.vertices[region + [region[0]], :]
centroid, w = centroid_region(vertices, map_2d, memory)
centroids.append(list(centroid[0, :]))
weights.append(w)
centroids = np.asarray(centroids)

return centroids, weights


def update(rob, centroids):
interim_x = np.asarray(centroids[:, 0] - rob[:, 0])
interim_y = np.asarray(centroids[:, 1] - rob[:, 1])
interim_z = np.asarray(centroids[:, 2] - rob[:, 2])
# magn = [np.linalg.norm(centroids[i, :] - rob[i, :]) for i in range(rob.shape[0])]
# x = np.copy(interim_x)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if I need this...

# x = np.asarray([interim_x[i] / magn[i] for i in range(interim_x.shape[0])])
# y = np.copy(interim_y)
# y = np.asarray([interim_y[i] / magn[i] for i in range(interim_y.shape[0])])
# z = np.copy(interim_z)
# z = np.asarray([interim_z[i] / magn[i] for i in range(interim_z.shape[0])])
temp = np.copy(rob)
temp[:, 0] = [rob[i, 0] + 1 * interim_x[i] for i in range(rob.shape[0])]
temp[:, 1] = [rob[i, 1] + 1 * interim_y[i] for i in range(rob.shape[0])]
temp[:, 2] = [rob[i, 2] + 1 * interim_z[i] for i in range(rob.shape[0])]
return np.asarray(temp)


def get_init(map_3d, M, random_seed=None):
cube_length = max(map_3d.shape)
cube_length += cube_length % 2
map_3d = np.pad(
map_3d,
(
(0, cube_length - map_3d.shape[0]),
(0, cube_length - map_3d.shape[1]),
(0, cube_length - map_3d.shape[2]),
),
"minimum",
)
assert (
np.unique(map_3d.shape).size == 1
), "map must be cube, not non-cubic rectangular parallelepiped"
N = map_3d.shape[0]
assert N % 2 == 0, "N must be even"

map_3d /= map_3d.sum() # 3d map to probability density
map_3d_flat = map_3d.flatten()
map_3d_idx = np.arange(map_3d_flat.shape[0])
if random_seed is not None:
np.random.seed(seed=random_seed)

# this scales with M (the number of chosen items), not map_3d_idx (the possibilities to choose from)
samples_idx = np.random.choice(
map_3d_idx, size=M, replace=True, p=map_3d_flat
) # chosen voxel indeces
coords_1d = np.arange(-N // 2, N // 2)
xyz = coords_n_by_d(coords_1d, d=3)
rm0 = xyz[
samples_idx
] # pick out coordinates that where chosen. note that this assumes map_3d_idx matches with rows of xyz

robs = []
for i in range(len(rm0)):
rob = [0, 0, 0]
rob[0] = rm0[i][0] / N + 1 / 2 + np.random.normal(0, 0.1) / N
rob[1] = rm0[i][1] / N + 1 / 2 + np.random.normal(0, 0.1) / N
rob[2] = rm0[i][2] / N + 1 / 2 + np.random.normal(0, 0.1) / N
robs.append(rob)
robs = np.asarray(robs)

return robs, map_3d


def iterate(map_3d, r0, max_iter=5):
robots = r0
bounding_box = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0])

for i in range(max_iter):
centroids, weights = plot(robots, map_3d, bounding_box)
robots = update(robots, centroids)

return robots
Loading
Loading