spatial sharding WIP

fcollman committed Feb 1, 2024
1 parent b058f78 commit 331b523
Showing 1 changed file with 154 additions and 38 deletions.
192 changes: 154 additions & 38 deletions python/neuroglancer/write_annotations.py
@@ -22,9 +22,9 @@
import struct
from collections.abc import Sequence
from typing import Literal, NamedTuple, Optional, Union, cast
from logging import warning
import tensorstore as ts
import numpy as np
import math

from . import coordinate_space, viewer_state

@@ -102,6 +102,55 @@ def choose_output_spec(total_count, total_bytes,
return options


def compressed_morton_code(gridpt, grid_size):
    """Compute the compressed morton code of grid point(s) within a grid.

    Adapted from cloudvolume. Accepts a single (x, y, z) grid point or an
    array of them; the coordinate bits are interleaved (x least significant),
    skipping bits in dimensions whose grid size is already exhausted.
    """
    if hasattr(gridpt, "__len__") and len(gridpt) == 0:  # generators don't have len
        return np.zeros((0,), dtype=np.uint64)

gridpt = np.asarray(gridpt, dtype=np.uint32)
single_input = False
if gridpt.ndim == 1:
gridpt = np.atleast_2d(gridpt)
single_input = True

code = np.zeros((gridpt.shape[0],), dtype=np.uint64)
num_bits = [ math.ceil(math.log2(size)) for size in grid_size ]
j = np.uint64(0)
one = np.uint64(1)

if sum(num_bits) > 64:
raise ValueError(f"Unable to represent grids that require more than 64 bits. Grid size {grid_size} requires {num_bits} bits.")

max_coords = np.max(gridpt, axis=0)
if np.any(max_coords >= grid_size):
raise ValueError(f"Unable to represent grid points larger than the grid. Grid size: {grid_size} Grid points: {gridpt}")

for i in range(max(num_bits)):
for dim in range(3):
if 2 ** i < grid_size[dim]:
bit = (((np.uint64(gridpt[:, dim]) >> np.uint64(i)) & one) << j)
code |= bit
j += one
if single_input:
return code[0]
return code
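
# Illustrative check (not part of the library API): in a 2 x 2 x 2 grid the
# bits interleave x (least significant), then y, then z, so:
#   compressed_morton_code([1, 0, 1], [2, 2, 2])              -> 5  (0b101)
#   compressed_morton_code([[0, 0, 0], [1, 1, 1]], [2, 2, 2]) -> array([0, 7])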


def _get_dtype_for_geometry(annotation_type: AnnotationType, rank: int):
geometry_size = rank if annotation_type == "point" else 2 * rank
return [("geometry", "<f4", geometry_size)]
@@ -302,19 +351,67 @@ def _serialize_annotations_sharded(self, path, annotations, shard_spec):
txn.commit_async().result()

def _serialize_annotations(self, f, annotations: list[Annotation]):
f.write(struct.pack("<Q", len(annotations)))
for annotation in annotations:
f.write(annotation.encoded)
for annotation in annotations:
f.write(struct.pack("<Q", annotation.id))

f.write(self._encode_multiple_annotations(annotations))

def _serialize_annotation(self, f, annotation: Annotation):
f.write(annotation.encoded)
for related_ids in annotation.relationships:
f.write(struct.pack("<I", len(related_ids)))
for related_id in related_ids:
f.write(struct.pack("<Q", related_id))

    def _encode_multiple_annotations(self, annotations: list[Annotation]):
        """Encode a list of annotations into a single binary string.

        The encoding is the uint64le annotation count, then each annotation's
        encoded geometry/properties, then each annotation's uint64le id.

        Parameters:
            annotations (list): List of annotation objects. Each object must
                have 'encoded' and 'id' attributes.
        Returns:
            bytes: The serialized representation.
        """
binary_components = []
binary_components.append(struct.pack("<Q", len(annotations)))
for annotation in annotations:
binary_components.append(annotation.encoded)
for annotation in annotations:
binary_components.append(struct.pack("<Q", annotation.id))
return b"".join(binary_components)

def _serialize_annotations_by_related_id(self, path, related_id_dict, shard_spec):
spec = {
'driver': 'neuroglancer_uint64_sharded',
'metadata': shard_spec,
"base": f"file://{path}"
}
dataset = ts.KvStore.open(spec).result()
txn = ts.Transaction()
for related_id, annotations in related_id_dict.items():
# convert the ann.id to a binary representation of a uint64
key = related_id.to_bytes(8, 'little')
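            # e.g. related id 3 becomes b'\x03\x00\x00\x00\x00\x00\x00\x00'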
value = self._encode_multiple_annotations(annotations)
            dataset.with_transaction(txn)[key] = value
txn.commit_async().result()

def _serialize_annotation_chunk_sharded(self, path, annotations_by_chunk, shard_spec, max_sizes):
spec = {
'driver': 'neuroglancer_uint64_sharded',
'metadata': shard_spec,
"base": f"file://{path}"
}
dataset = ts.KvStore.open(spec).result()
txn = ts.Transaction()
for chunk_index, annotations in annotations_by_chunk.items():
# calculate the compressed morton code for the chunk index
key = compressed_morton_code(chunk_index, max_sizes)
            key = key.astype('<u8').tobytes()
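            # e.g. chunk index (1, 0, 1) in a 2 x 2 x 2 grid has morton code 5
            # and becomes the key b'\x05\x00\x00\x00\x00\x00\x00\x00'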
value = self._encode_multiple_annotations(annotations)
dataset.with_transaction(txn)[key] = value

txn.commit_async().result()

def write(self, path: Union[str, pathlib.Path]):
metadata = {
"@type": "neuroglancer_annotations_v1",
@@ -323,62 +420,81 @@ def write(self, path: Union[str, pathlib.Path]):
"upper_bound": [float(x) for x in self.upper_bound],
"annotation_type": self.annotation_type,
"properties": [p.to_json() for p in self.properties],
"relationships": [
{"id": relationship, "key": f"rel_{relationship}"}
for relationship in self.relationships
],
"relationships": [],
"by_id": {
"key": "by_id"
}
}
total_ann_bytes = sum(len(a.encoded) for a in self.annotations)
sharding_spec = choose_output_spec(len(self.annotations),
total_ann_bytes)

# calculate the number of chunks in each dimension
num_chunks = np.ceil((self.upper_bound - self.lower_bound) / self.chunk_size).astype(int)
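        # e.g. bounds [0, 0, 0]..[1000, 1000, 1000] with chunk_size
        # [512, 512, 512] give num_chunks == [2, 2, 2]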

# find the maximum number of annotations in any chunk
max_annotations = max(len(annotations) for annotations in self.annotations_by_chunk.values())

# make directories
os.makedirs(path, exist_ok=True)
for relationship in self.relationships:
os.makedirs(os.path.join(path, f"rel_{relationship}"), exist_ok=True)
os.makedirs(os.path.join(path, "by_id"), exist_ok=True)
os.makedirs(os.path.join(path, "spatial0"), exist_ok=True)

total_chunks = len(self.annotations_by_chunk)
spatial_sharding_spec = choose_output_spec(total_chunks,
total_ann_bytes + 8*len(self.annotations)+8*total_chunks)
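        # The byte estimate above counts the encoded annotation bytes, plus an
        # 8-byte id per annotation, plus an 8-byte count prefix per chunk.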
# initialize metadata for spatial index
metadata['spatial'] = [
{
"key": f"spatial0",
"key": "spatial0",
"grid_shape": num_chunks.tolist(),
"chunk_size": [int(x) for x in self.chunk_size],
"limit": max_annotations
}
]
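        # Illustrative resulting entry (values depend on the bounds, chunk size,
        # and annotation density):
        #   {"key": "spatial0", "grid_shape": [2, 2, 2],
        #    "chunk_size": [512, 512, 512], "limit": 1000}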
# write annotations by spatial chunk
if spatial_sharding_spec is not None:
self._serialize_annotation_chunk_sharded(os.path.join(path, "spatial0"),
self.annotations_by_chunk,
spatial_sharding_spec,
num_chunks.tolist())
metadata['spatial'][0]['sharding'] = spatial_sharding_spec
else:
for chunk_index, annotations in self.annotations_by_chunk.items():
chunk_name = "_".join([str(c) for c in chunk_index])
filepath = os.path.join(path, "spatial0", chunk_name)
with open(filepath, 'wb') as f:
self._serialize_annotations(f, annotations)



# write annotations by id
if sharding_spec is not None:
self._serialize_annotations_sharded(os.path.join(path, "by_id"), self.annotations, sharding_spec)
metadata["by_id"]["sharding"] = sharding_spec
else:
for annotation in self.annotations:
with open(os.path.join(path, "by_id", str(annotation.id)), "wb") as f:
self._serialize_annotation(f, annotation)

# write relationships
for i, relationship in enumerate(self.relationships):
rel_index = self.related_annotations[i]
            relationship_sharding_spec = choose_output_spec(
                len(rel_index),
                total_ann_bytes + 8 * len(self.annotations) + 8 * total_chunks)
rel_md = {"id": relationship,
"key": f"rel_{relationship}"}
if relationship_sharding_spec is not None:
rel_md["sharding"] = relationship_sharding_spec
self._serialize_annotations_by_related_id(os.path.join(path, f"rel_{relationship}"), rel_index, relationship_sharding_spec)
else:
for segment_id, annotations in rel_index.items():
filepath = os.path.join(path, f"rel_{relationship}", str(segment_id))
with open(filepath, "wb") as f:
self._serialize_annotations(f, annotations)

metadata["relationships"].append(rel_md)

# write metadata info file
with open(os.path.join(path, "info"), "w") as f:
f.write(json.dumps(metadata, cls=NumpyEncoder))
