Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/compressed tree w protobuf #126

Draft
wants to merge 46 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
45b9a11
create tensor index tree object and incorporate into hullslicer
mathleur Feb 15, 2024
fe35193
fix some tests
mathleur Feb 16, 2024
4c868b2
fix get for xarray backend
mathleur Feb 16, 2024
6bd0b1c
fix xarray tests
mathleur Feb 16, 2024
2efbacf
fix more tests
mathleur Feb 16, 2024
07d48f3
fix type change transformation
mathleur Feb 16, 2024
3cbed71
fix all tests with xarray backend
mathleur Feb 16, 2024
7eb8755
make first version of tensor index tree work with fdb
mathleur Feb 19, 2024
ba382fa
more tests
mathleur Feb 19, 2024
e415a2d
start adding compression for grids
mathleur Feb 19, 2024
ba65c47
clean up
mathleur Feb 19, 2024
d60bb0a
make right grid axes be compressed in tree
mathleur Feb 19, 2024
4110f93
clean up
mathleur Feb 29, 2024
ed43774
remove duplicate print for index trees
mathleur Mar 4, 2024
cb94a1e
start moving compressed tree logic of adding children to tensor tree …
mathleur Mar 4, 2024
ef80bb5
merge develop and move compression logic into tensor index tree
mathleur Mar 5, 2024
bce6fb4
fix xarray get
mathleur Mar 6, 2024
9b5e71d
start to fix more tests
mathleur Mar 7, 2024
582f299
fix almost all problems except data size attached to leaves
mathleur Mar 7, 2024
fab2252
fix almost all tests, but data extracted from fdb is only the first i…
mathleur Mar 7, 2024
1ec5145
fix small bug on slicing for flat polytopes
mathleur Mar 11, 2024
3b16685
fix everything except returned values from pygribjump
mathleur Mar 11, 2024
8c9f273
flake8
mathleur Mar 11, 2024
c4ff333
flake8
mathleur Mar 11, 2024
171efae
flake8
mathleur Mar 11, 2024
9acfcf4
make associating data to right tree nodes work with compressed trees
mathleur Mar 14, 2024
670fb08
clean up
mathleur Mar 14, 2024
4394117
add a index tree protobuf
mathleur Mar 15, 2024
62f4363
fix test
mathleur Mar 15, 2024
83a6aec
add protobuf index tree encoder
mathleur Mar 15, 2024
8bdc5ba
add protobuf index tree decoder
mathleur Mar 15, 2024
b1b81a6
add encoder/decoder support for timestamps/timedeltas and test
mathleur Mar 18, 2024
74f0f75
add performance test for protobuf encoder/decoder
mathleur Mar 18, 2024
48da664
Merge branch 'feature/index_tree_protobuf' of github.com:ecmwf/polyto…
mathleur Mar 18, 2024
93b02b4
modify protobuf tree encoding to support compressed tree tuple values
mathleur Mar 19, 2024
61f3438
add result_size to protobuf compressed tree
mathleur Apr 18, 2024
9363e69
remove large file
mathleur Apr 18, 2024
dca546e
update gitignore
mathleur Apr 18, 2024
d5444fe
move to proto3
mathleur Apr 18, 2024
fa8e7fa
start to add cap proto index tree
mathleur Apr 23, 2024
5ce2c2a
finish capnp protobuffer of compressed tree
mathleur Apr 23, 2024
91a42a3
add and fix capnp encoder
mathleur Apr 29, 2024
2ed7e10
make encoder and decoder work for capnp
mathleur Apr 29, 2024
5038c63
make protobuf encoding right
mathleur Apr 29, 2024
b90e92a
fix tests
mathleur Apr 30, 2024
a07970e
do encoding and decoding to a byte array instead of a file
mathleur Apr 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ site
example_eo
example_mri
.mypy_cache
*.req
*.req
serializedTree
18 changes: 18 additions & 0 deletions indexTree.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@0xae1f1be0650fec43;

# Holds a single axis value of an index-tree node. The union lets one
# List(Value) mix value types (text, integer, float) side by side.
struct Value {
# NEED TO DO THIS STILL
value :union {
strVal @0 :Text;
intVal @1 :Int64;
doubleVal @2 :Float64;
}
}

# One node of the (compressed) index tree: the axis name, the list of
# values held on that axis (compression stores several values per node),
# result data attached to the node with per-value sizes, and children.
# NOTE(review): exact semantics of result/resultSize are inferred from
# the PR description ("add result_size to protobuf compressed tree") —
# confirm against the encoder/decoder.
struct Node {
axis @0 :Text;
value @1 :List(Value);
result @2 :List(Float64);
resultSize @3 :List(Int64);
children @4 :List(Node);
}
5 changes: 5 additions & 0 deletions polytope/datacube/backends/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...utility.combinatorics import validate_axes
from ..datacube_axis import DatacubeAxis
from ..index_tree import DatacubePath, IndexTree
from ..transformations.datacube_mappers.datacube_mappers import DatacubeMapper
from ..transformations.datacube_transformations import (
DatacubeAxisTransformation,
has_transform,
Expand All @@ -31,6 +32,7 @@ def __init__(self, axis_options=None, datacube_options=None):
self.nearest_search = {}
self._axes = None
self.transformed_axes = []
self.compressed_grid_axes = []

@abstractmethod
def get(self, requests: IndexTree) -> Any:
Expand All @@ -54,6 +56,9 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt
)
for blocked_axis in transformation.blocked_axes():
self.blocked_axes.append(blocked_axis)
if isinstance(transformation, DatacubeMapper):
for compressed_grid_axis in transformation.compressed_grid_axes:
self.compressed_grid_axes.append(compressed_grid_axis)
if len(final_axis_names) > 1:
self.coupled_axes.append(final_axis_names)
for axis_name in final_axis_names:
Expand Down
85 changes: 67 additions & 18 deletions polytope/datacube/backends/fdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from copy import deepcopy
from itertools import product

import numpy as np
import pygribjump as pygj

from ...utility.geometry import nearest_pt
Expand Down Expand Up @@ -47,8 +49,46 @@ def get(self, requests: IndexTree):
fdb_requests = []
fdb_requests_decoding_info = []
self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info)
output_values = self.gj.extract(fdb_requests)
self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info)

# TODO: note that this doesn't exactly work as intended, it's just going to retrieve value from gribjump that
# corresponds to first value in the compressed tuples

complete_branch_combi_sizes = []
output_values = []
for request in fdb_requests:
interm_branch_tuple_values = []
for key in request[0].keys():
# remove the tuple of the request when we ask the fdb
interm_branch_tuple_values.append(request[0][key])
request[0][key] = request[0][key][0]
branch_tuple_combi = product(*interm_branch_tuple_values)
# TODO: now build the relevant requests from this and ask gj for them
# TODO: then group the output values together to fit back with the original compressed request and continue
new_requests = []
for combi in branch_tuple_combi:
new_request = {}
for i, key in enumerate(request[0].keys()):
new_request[key] = combi[i]
new_requests.append((new_request, request[1]))
branch_output_values = self.gj.extract(new_requests)
branch_combi_sizes = [len(t) for t in interm_branch_tuple_values]

all_remapped_output_values = []
for k, req in enumerate(new_requests):
output = branch_output_values[k][0]
output_dict = {}
for i, o in enumerate(output):
output_dict[i] = o[0]

all_remapped_output_values.append(output_dict)

output_data_branch = []
output_data_branch = np.array(all_remapped_output_values)
output_data_branch = np.reshape(output_data_branch, tuple(branch_combi_sizes))
output_values.append([output_data_branch])
complete_branch_combi_sizes.append([list(range(b)) for b in branch_combi_sizes])

self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info, complete_branch_combi_sizes)

def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path=None):
if leaf_path is None:
Expand All @@ -62,7 +102,7 @@ def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_de
self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info)
# If request node has no children, we have a leaf so need to assign fdb values to it
else:
key_value_path = {requests.axis.name: requests.value}
key_value_path = {requests.axis.name: requests.values}
ax = requests.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
Expand Down Expand Up @@ -112,7 +152,7 @@ def get_2nd_last_values(self, requests, leaf_path=None):
found_latlon_pts = []
for lat_child in requests.children:
for lon_child in lat_child.children:
found_latlon_pts.append([lat_child.value, lon_child.value])
found_latlon_pts.append([lat_child.values, lon_child.values])

# now find the nearest lat lon to the points requested
nearest_latlons = []
Expand All @@ -121,20 +161,21 @@ def get_2nd_last_values(self, requests, leaf_path=None):
nearest_latlons.append(nearest_latlon)

# need to remove the branches that do not fit
lat_children_values = [child.value for child in requests.children]
lat_children_values = [child.values for child in requests.children]
for i in range(len(lat_children_values)):
lat_child_val = lat_children_values[i]
lat_child = [child for child in requests.children if child.value == lat_child_val][0]
if lat_child.value not in [latlon[0] for latlon in nearest_latlons]:
lat_child = [child for child in requests.children if child.values == lat_child_val][0]
if lat_child.values not in [(latlon[0],) for latlon in nearest_latlons]:
lat_child.remove_branch()
else:
possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value]
lon_children_values = [child.value for child in lat_child.children]
possible_lons = [latlon[1] for latlon in nearest_latlons if (latlon[0],) == lat_child.values]
lon_children_values = [child.values for child in lat_child.children]
for j in range(len(lon_children_values)):
lon_child_val = lon_children_values[j]
lon_child = [child for child in lat_child.children if child.value == lon_child_val][0]
if lon_child.value not in possible_lons:
lon_child.remove_branch()
lon_child = [child for child in lat_child.children if child.values == lon_child_val][0]
for value in lon_child.values:
if value not in possible_lons:
lon_child.remove_compressed_branch(value)

lat_length = len(requests.children)
range_lengths = [False] * lat_length
Expand All @@ -149,7 +190,7 @@ def get_2nd_last_values(self, requests, leaf_path=None):
range_length = deepcopy(range_lengths[i])
current_start_idx = deepcopy(current_start_idxs[i])
fdb_range_nodes = deepcopy(fdb_node_ranges[i])
key_value_path = {lat_child.axis.name: lat_child.value}
key_value_path = {lat_child.axis.name: lat_child.values}
ax = lat_child.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
Expand All @@ -160,14 +201,14 @@ def get_2nd_last_values(self, requests, leaf_path=None):
)

leaf_path_copy = deepcopy(leaf_path)
leaf_path_copy.pop("values")
leaf_path_copy.pop("values", None)
return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)

def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n):
i = 0
for c in requests.children:
# now c are the leaves of the initial tree
key_value_path = {c.axis.name: c.value}
key_value_path = {c.axis.name: c.values}
ax = c.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
Expand All @@ -182,7 +223,7 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx,
range_l[i] += 1
fdb_range_n[i][range_l[i] - 1] = c
else:
key_value_path = {c.axis.name: c.value}
key_value_path = {c.axis.name: c.values}
ax = c.axis
(key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
key_value_path, leaf_path, self.unwanted_path
Expand All @@ -193,8 +234,10 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx,
current_idx[i] = current_start_idx
return (range_l, current_idx, fdb_range_n)

def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info):
def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info, complete_branch_combi_sizes):
for k in range(len(output_values)):
combi_sizes = complete_branch_combi_sizes[k]
combi_sizes_combis = list(product(*combi_sizes))
request_output_values = output_values[k]
(
original_indices,
Expand All @@ -215,7 +258,13 @@ def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info):
for i in range(len(sorted_fdb_range_nodes)):
for j in range(sorted_range_lengths[i]):
n = sorted_fdb_range_nodes[i][j]
n.result = request_output_values[0][i][0][j]
for size_combi in list(combi_sizes_combis):
interm_output_values = request_output_values[0]
# TODO: the result associated to nodes is still only a simple float, not an array and is not
# the right one...
for val in size_combi:
interm_output_values = interm_output_values[val]
n.result = interm_output_values[i][j]

def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length):
interm_request_ranges = []
Expand Down
2 changes: 1 addition & 1 deletion polytope/datacube/backends/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get(self, requests: IndexTree):
for r in requests.leaves:
path = r.flatten()
if len(path.items()) == len(self.dimensions.items()):
result = 0
result = (0,)
for k, v in path.items():
result += v * self.stride[k]

Expand Down
17 changes: 13 additions & 4 deletions polytope/datacube/backends/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,26 @@ def get(self, requests: IndexTree):
for r in requests.leaves:
path = r.flatten()
if len(path.items()) == self.axis_counter:
# first, find the grid mapper transform
# TODO: need to undo the tuples in the path into actual paths with a single value that xarray can read
unmapped_path = {}
path_copy = deepcopy(path)
for key in path_copy:
axis = self._axes[key]
key_value_path = {key: path_copy[key]}
# (path, unmapped_path) = axis.unmap_to_datacube(path, unmapped_path)
(key_value_path, path, unmapped_path) = axis.unmap_path_key(key_value_path, path, unmapped_path)
path.update(key_value_path)
path.update(unmapped_path)

unmapped_path = {}
self.refit_path(path, unmapped_path, path)
for key in path:
path[key] = list(path[key])
for key in unmapped_path:
if isinstance(unmapped_path[key], tuple):
unmapped_path[key] = list(unmapped_path[key])

subxarray = self.dataarray.sel(path, method="nearest")
subxarray = subxarray.sel(unmapped_path)
value = subxarray.item()
value = subxarray.values
key = subxarray.name
r.result = (key, value)
else:
Expand Down Expand Up @@ -93,6 +96,12 @@ def refit_path(self, path_copy, unmapped_path, path):
path_copy.pop(key, None)

def select(self, path, unmapped_path):
for key in path:
key_value = path[key][0]
path[key] = key_value
for key in unmapped_path:
key_value = unmapped_path[key][0]
unmapped_path[key] = key_value
path_copy = deepcopy(path)
self.refit_path(path_copy, unmapped_path, path)
subarray = self.dataarray.sel(path_copy, method="nearest")
Expand Down
5 changes: 5 additions & 0 deletions polytope/datacube/datacube_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(self):
# TODO: Maybe here, store transformations as a dico instead
self.transformations = []
self.type = 0
self.can_round = True

def parse(self, value: Any) -> Any:
return float(value)
Expand All @@ -194,6 +195,7 @@ def __init__(self):
self.range = None
self.transformations = []
self.type = 0.0
self.can_round = True

def parse(self, value: Any) -> Any:
return float(value)
Expand All @@ -215,6 +217,7 @@ def __init__(self):
self.range = None
self.transformations = []
self.type = pd.Timestamp("2000-01-01T00:00:00")
self.can_round = False

def parse(self, value: Any) -> Any:
if isinstance(value, np.str_):
Expand Down Expand Up @@ -244,6 +247,7 @@ def __init__(self):
self.range = None
self.transformations = []
self.type = np.timedelta64(0, "s")
self.can_round = False

def parse(self, value: Any) -> Any:
if isinstance(value, np.str_):
Expand Down Expand Up @@ -272,6 +276,7 @@ def __init__(self):
self.tol = float("NaN")
self.range = None
self.transformations = []
self.can_round = False

def parse(self, value: Any) -> Any:
return value
Expand Down
18 changes: 18 additions & 0 deletions polytope/datacube/index_tree.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Unique 64-bit file ID required by the Cap'n Proto compiler; without it
# `capnp compile` fails with "File does not declare an ID". Must differ
# from the root-level indexTree.capnp's ID.
@0xd4a5e3f8b1c40967;

# Holds a single axis value of an index-tree node. The union lets one
# List(Value) mix value types (text, integer, float) side by side.
struct Value {
  # NEED TO DO THIS STILL
  value :union {
    # Field names use camelCase: the capnp compiler rejects identifiers
    # containing underscores (str_val / int_val / double_val would not
    # compile). camelCase also matches the sibling indexTree.capnp
    # schema at the repository root.
    strVal @0 :Text;
    intVal @1 :Int64;
    doubleVal @2 :Float64;
  }
}

# One node of the (compressed) index tree: the axis name, the list of
# values held on that axis, result data attached to the node with
# per-value sizes, and the child nodes.
struct Node {
  axis @0 :Text;
  value @1 :List(Value);
  result @2 :List(Float64);
  resultSize @3 :List(Int64);
  children @4 :List(Node);
}
20 changes: 20 additions & 0 deletions polytope/datacube/index_tree.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
syntax = "proto3";

package index_tree;

// Holds a single axis value of an index-tree node; the oneof lets a
// repeated Value field mix value types (string, integer, double).
message Value {
oneof value {
string str_val = 1;
int64 int_val = 2;
double double_val = 3;
}
}

// One node of the (compressed) index tree: the axis name, the list of
// values held on that axis (compression stores several values per
// node), result data attached to the node with per-value sizes, and
// the child nodes.
// NOTE(review): exact semantics of result/result_size are inferred
// from the PR description ("add result_size to protobuf compressed
// tree") — confirm against the encoder/decoder.
message Node {
string axis = 1;
repeated Value value = 2;
repeated double result = 3;
repeated int64 result_size = 4;
repeated Node children = 5;
}

7 changes: 0 additions & 7 deletions polytope/datacube/index_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,6 @@ def copy_children_from_other(self, other):
c.copy_children_from_other(o)
return

def pprint_2(self, level=0):
if self.axis.name == "root":
print("\n")
print("\t" * level + "\u21b3" + str(self))
for child in self.children:
child.pprint_2(level + 1)

def _collect_leaf_nodes_old(self, leaves):
if len(self.children) == 0:
leaves.append(self)
Expand Down
28 changes: 28 additions & 0 deletions polytope/datacube/index_tree_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading