Skip to content

Commit

Permalink
WIP add test script
Browse files Browse the repository at this point in the history
  • Loading branch information
jmuhlich committed Jul 11, 2024
1 parent cae27cf commit 35e4f8c
Showing 1 changed file with 220 additions and 0 deletions.
220 changes: 220 additions & 0 deletions test_rotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import pathlib

import numpy as np
import matplotlib.pyplot as plt
import skimage.io
from ashlar import reg, utils, thumbnail


class TestMetadata(reg.Metadata):
def __init__(
self,
path,
tile_size,
overlap,
pixel_size,
channel=0,
zarr=None,
img=None,
series=None,
):
self.path = pathlib.Path(path)
self._tile_size = np.array(tile_size)
self.overlap = overlap
self._pixel_size = pixel_size
self.channel = channel
self.zarr = zarr
self.img = img
self.series = series
self.deconstruct_mosaic()

def deconstruct_mosaic(self):
if self.zarr is not None:
self.mosaic = self.zarr

if self.img is not None:
self.mosaic = self.img

if self.zarr is None and self.img is None:
self.mosaic = skimage.io.imread(self.path, key=self.channel)

m_shape = self.mosaic.shape

step_shape = (1 - self.overlap) * self._tile_size
# round position to integer since no subpixel needed for already stitched image
step_shape = np.around(step_shape).astype("int")
overlap_shape = np.around(self.overlap * self._tile_size).astype(int)
m_limit = m_shape - overlap_shape

self._slice_positions = (
np.mgrid[: m_limit[0] : step_shape[0], : m_limit[1] : step_shape[1]]
.reshape(2, -1)
.T
)

self._positions = self._slice_positions.astype(float)

if self.series is not None:
self._slice_positions = self._slice_positions[self.series]
self._positions = self._positions[self.series]

@property
def _num_images(self):
return len(self._positions)

@property
def num_channels(self):
return 1

@property
def pixel_size(self):
return self._pixel_size

@property
def pixel_dtype(self):
return self.zarr.dtype

@property
def mosaic_shape(self):
return self.zarr.shape

def tile_size(self, i):
return self._tile_size


class TestReader(reg.Reader):
def __init__(
self,
path=None,
tile_size=(1000, 1000),
overlap=0.1,
pixel_size=1,
channel=0,
zarr=None,
img=None,
series=None,
flip_x=False,
flip_y=False,
angle=0,
center_crop_shape=None,
noise=0,
):
path = "" if path is None else path
self.metadata = TestMetadata(
path, tile_size, overlap, pixel_size, channel, zarr, img, series
)
self.path = pathlib.Path(path)
self.mosaic = self.metadata.mosaic
self.flip_x = flip_x
self.flip_y = flip_y
self.angle = angle
self.noise = noise

def read(self, series, c):
position = self.metadata._slice_positions[series]
assert np.issubdtype(position.dtype, np.integer)
r, c = position
h, w = self.metadata._tile_size
img = self.mosaic[r : r + h, c : c + w]
if self.noise:
r = np.random.RandomState(seed=series)
noise_img = r.randint(0, self.noise + 1, size=img.shape)
img = np.clip(img + noise_img, img.min(), img.max()).astype(img.dtype)
if not np.all(img.shape == (h, w)):
img_h, img_w = img.shape
pad_h, pad_w = np.clip([h - img_h, w - img_w], 0, None)
img = np.pad(img, [(0, pad_h), (0, pad_w)])
if self.flip_x:
img = np.fliplr(img)
if self.flip_y:
img = np.flipud(img)
if self.angle != 0:
img = skimage.transform.rotate(img, self.angle, center=(0, 0), resize=True)
return img


def align_cycles(reader1, reader2, scale=0.05):
import skimage.transform

if not hasattr(reader1, "thumbnail"):
raise ValueError("reader1 does not have a thumbnail")
if not hasattr(reader2, "thumbnail"):
raise ValueError("reader2 does not have a thumbnail")
img1 = reader1.thumbnail
img2 = reader2.thumbnail
padded_shape = np.array((img1.shape, img2.shape)).max(axis=0)
img1 = skimage.transform.warp(img1, np.eye(3), output_shape=padded_shape)
img2 = skimage.transform.warp(img2, np.eye(3), output_shape=padded_shape)
angle = utils.register_angle(img1, img2, sigma=1)
if angle != 0:
print(f"\r estimated cycle rotation = {angle:.2f} degrees")
rotation_center = 0.5 * np.array(padded_shape[::-1]) - 0.5
img2 = skimage.transform.rotate(
img2, angle, resize=False, center=rotation_center
)
shifts = thumbnail.calculate_image_offset(img1, img2, int(1 / scale))
print(f"\r estimated shift {shifts / scale}")
tform_steps = [
("translation", -reader2.metadata.origin[::-1]),
("scale", scale),
("translation", -rotation_center),
("rotation", np.deg2rad(-angle)),
("translation", rotation_center),
("translation", shifts[::-1]),
("scale", 1 / scale),
("translation", reader1.metadata.origin[::-1]),
]
tform = skimage.transform.AffineTransform()
for step in tform_steps:
tform += skimage.transform.AffineTransform(**{step[0]: step[1]})

return tform


import numpy as np
import skimage.data
import skimage.transform
from ashlar import thumbnail

TILE_SIZE = (108, 128)

img = skimage.data.astronaut()[..., 1]
c1r = TestReader(img=img, tile_size=TILE_SIZE, overlap=0.25, noise=1)

affine = skimage.transform.AffineTransform
#tform = affine(
# translation=200 * (np.random.random(2) - 0.5),
# rotation=np.deg2rad(-10 * (np.random.random(1) - 0.5)[0]),
#)
tform = (
affine(translation=(-250, -280))
+ affine(rotation=np.deg2rad(88))
+ affine(translation=(250, 280))
)

# apply known transform to image
img2 = skimage.transform.warp(img, tform.inverse, preserve_range=True).astype(img.dtype)
c2r = TestReader(img=img2, tile_size=TILE_SIZE, overlap=0.25, noise=1)

# set random stage origin
c1r.metadata._positions += 2000 * (np.random.random(2) - 0.5)
c2r.metadata._positions += 2000 * (np.random.random(2) - 0.5)

# randomly perturb stage positions
c1r.metadata._positions += np.random.random_sample(c1r.metadata._positions.shape) * 5
c2r.metadata._positions += np.random.random_sample(c2r.metadata._positions.shape) * 5

a1 = reg.EdgeAligner(c1r, verbose=True)
a1.run()
print()
a2 = reg.LayerAligner(c2r, a1, verbose=True)
a2.run()

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(a1.reader.thumbnail, cmap='gray', vmax=2e5)
ax2.imshow(a2.reader.thumbnail, cmap='gray', vmax=2e5)
for i, (y, x) in enumerate(a1.metadata.centers - a1.metadata.origin):
ax1.annotate(str(i), (x,y), ha='center', va='center', color='yellow')
for i, (y, x) in enumerate(a2.metadata.centers - a2.metadata.origin):
ax2.annotate(str(i), (x,y), ha='center', va='center', color='magenta')
plt.show()

0 comments on commit 35e4f8c

Please sign in to comment.