Commit

[AL-1132] Added the ability to do inplace transforms (#1354)
AbhinavTuli authored Nov 28, 2021
1 parent 92cd5a9 commit fcc922f
Showing 8 changed files with 301 additions and 56 deletions.
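
In short, a transform can now be evaluated without a separate output dataset: when ds_out is omitted, the pipeline rewrites the input dataset's samples in place (checking out an auto-created branch first if the dataset is not at a head commit). A minimal usage sketch based on the new tests below; the dataset path and tensor name are illustrative:

    import hub

    @hub.compute
    def double(sample_in, samples_out):
        # each input sample produces one doubled output sample
        samples_out.img.append(2 * sample_in.img.numpy())

    ds = hub.dataset("./my_dataset")  # assumed to already contain an "img" tensor
    double().eval(ds, num_workers=2)  # no ds_out: ds itself is overwritten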
5 changes: 1 addition & 4 deletions hub/core/chunk_engine.py
@@ -219,7 +219,6 @@ def last_chunk_key(self) -> str:
def last_chunk_name(self) -> str:
return self.chunk_id_encoder.get_name_for_chunk(-1)

@property
def last_chunk(self) -> Optional[BaseChunk]:
if self.num_chunks == 0:
return None
@@ -307,9 +306,7 @@ def extend(self, samples):
if tensor_meta.dtype is None:
tensor_meta.set_dtype(get_dtype(samples))

current_chunk = (
self.last_chunk if self.last_chunk is not None else self._create_new_chunk()
)
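# last_chunk is now called as a method; it returns None when the engine has no chunks, so fall back to creating one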
current_chunk = self.last_chunk() or self._create_new_chunk()
updated_chunks = {current_chunk}

enc = self.chunk_id_encoder
3 changes: 3 additions & 0 deletions hub/core/dataset/dataset.py
@@ -18,6 +18,7 @@
from hub.htype import DEFAULT_HTYPE, HTYPE_CONFIGURATIONS, UNSPECIFIED
from hub.integrations import dataset_to_tensorflow
from hub.util.bugout_reporter import hub_reporter
from hub.util.dataset import try_flushing
from hub.util.exceptions import (
CouldNotCreateNewDatasetException,
InvalidKeyTypeError,
@@ -392,6 +393,7 @@ def commit(self, message: Optional[str] = None) -> None:
str: the commit id of the stored commit that can be used to access the snapshot.
"""
commit_id = self.version_state["commit_id"]
try_flushing(self)
commit(self.version_state, self.storage, message)

# do not store commit message
@@ -413,6 +415,7 @@ def checkout(self, address: str, create: bool = False) -> str:
Returns:
str: The commit_id of the dataset after checkout.
"""
try_flushing(self)
checkout(self.version_state, self.storage, address, create)

# do not store address
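
hub/util/dataset.py itself is not among the expanded diffs; as a rough sketch, try_flushing presumably just flushes the dataset while tolerating storages that cannot be written to — the exception handling shown here is an assumption, not the actual implementation:

    def try_flushing(ds):
        try:
            ds.flush()
        except Exception:  # e.g. read-only storage; the exact exception type is an assumption
            pass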
139 changes: 139 additions & 0 deletions hub/core/transform/test_transform.py
@@ -75,6 +75,23 @@ def filter_tr(sample_in, sample_out):
sample_out.image.append(sample_in * np.ones((100, 100)))


@hub.compute
def inplace_transform(sample_in, samples_out):
samples_out.img.append(2 * sample_in.img.numpy())
samples_out.img.append(3 * sample_in.img.numpy())
samples_out.label.append(2 * sample_in.label.numpy())
samples_out.label.append(3 * sample_in.label.numpy())


def check_target_array(ds, index, target):
np.testing.assert_array_equal(
ds.img[index].numpy(), target * np.ones((500, 500, 3))
)
np.testing.assert_array_equal(
ds.label[index].numpy(), target * np.ones((100, 100, 3))
)


@all_schedulers
@enabled_non_gcs_datasets
def test_single_transform_hub_dataset(ds, scheduler):
@@ -449,3 +466,125 @@ def test_ds_out():
test_ds_out()

data_in.delete()


def test_inplace_transform(local_ds_generator):
ds = local_ds_generator()

with ds:
ds.create_tensor("img")
ds.create_tensor("label")
for _ in range(100):
ds.img.append(np.ones((500, 500, 3)))
ds.label.append(np.ones((100, 100, 3)))
a = ds.commit()
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)

inplace_transform().eval(ds, num_workers=TRANSFORM_TEST_NUM_WORKERS)
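# no ds_out above: the transform ran in place, and each of the 100 input samples produced 2 outputs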
assert ds.img.chunk_engine.num_samples == len(ds) == 200

for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)

ds.checkout(a)
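# the pre-transform commit still holds the original, untransformed samples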
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)

ds = local_ds_generator()
assert len(ds) == 200
for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)

ds.checkout(a)
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)


def test_inplace_transform_without_commit(local_ds_generator):
ds = local_ds_generator()

with ds:
ds.create_tensor("img")
ds.create_tensor("label")
for _ in range(100):
ds.img.append(np.ones((500, 500, 3)))
ds.label.append(np.ones((100, 100, 3)))
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)

inplace_transform().eval(ds, num_workers=TRANSFORM_TEST_NUM_WORKERS)
assert ds.img.chunk_engine.num_samples == len(ds) == 200

for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)

ds = local_ds_generator()
assert len(ds) == 200
for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)


def test_inplace_transform_non_head(local_ds_generator):
ds = local_ds_generator()
with ds:
ds.create_tensor("img")
ds.create_tensor("label")
for _ in range(100):
ds.img.append(np.ones((500, 500, 3)))
ds.label.append(np.ones((100, 100, 3)))
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)
a = ds.commit()
for _ in range(50):
ds.img.append(np.ones((500, 500, 3)))
ds.label.append(np.ones((100, 100, 3)))
assert len(ds) == 150
for i in range(150):
check_target_array(ds, i, 1)

ds.checkout(a)

# transforming non-head node
inplace_transform().eval(ds, num_workers=4)
b = ds.commit_id

assert len(ds) == 200
for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)

ds.checkout(a)
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)

ds.checkout("main")
assert len(ds) == 150
for i in range(150):
check_target_array(ds, i, 1)

ds = local_ds_generator()
assert len(ds) == 150
for i in range(150):
check_target_array(ds, i, 1)

ds.checkout(a)
assert len(ds) == 100
for i in range(100):
check_target_array(ds, i, 1)

ds.checkout(b)
assert len(ds) == 200
for i in range(200):
target = 2 if i % 2 == 0 else 3
check_target_array(ds, i, target)
86 changes: 56 additions & 30 deletions hub/core/transform/transform.py
@@ -2,12 +2,11 @@
import math
from typing import List, Callable, Optional
from itertools import repeat
from hub.constants import FIRST_COMMIT_ID
from hub.core.compute.provider import ComputeProvider
from hub.core.compute.thread import ThreadProvider
from hub.core.compute.process import ProcessProvider
from hub.core.compute.serial import SerialProvider
from hub.core.ipc import Server
from hub.util.bugout_reporter import hub_reporter
from hub.util.chunk_paths import get_chunk_paths
from hub.util.compute import get_compute_provider
from hub.util.remove_cache import get_base_storage, get_dataset_with_zero_size_cache
from hub.util.transform import (
@@ -16,7 +15,11 @@
get_pbar_description,
store_data_slice,
)
from hub.util.encoder import merge_all_chunk_id_encoders, merge_all_tensor_metas
from hub.util.encoder import (
merge_all_chunk_id_encoders,
merge_all_commit_chunk_sets,
merge_all_tensor_metas,
)
from hub.util.exceptions import (
HubComposeEmptyListError,
HubComposeIncompatibleFunction,
@@ -28,6 +31,8 @@
import threading
import sys

from hub.util.version_control import auto_checkout, load_meta


class TransformFunction:
def __init__(self, func, args, kwargs):
@@ -39,7 +44,7 @@ def __init__(self, func, args, kwargs):
def eval(
self,
data_in,
ds_out: hub.Dataset,
ds_out: Optional[hub.Dataset] = None,
num_workers: int = 0,
scheduler: str = "threaded",
progressbar: bool = True,
@@ -79,7 +84,7 @@ def __len__(self):
def eval(
self,
data_in,
ds_out: hub.Dataset,
ds_out: Optional[hub.Dataset] = None,
num_workers: int = 0,
scheduler: str = "threaded",
progressbar: bool = True,
@@ -103,10 +108,12 @@ def eval(
UnsupportedSchedulerError: If the scheduler passed is not recognized. Supported values include: "serial", 'threaded', 'processed' and 'ray'.
TransformError: All other exceptions raised if there are problems while running the pipeline.
"""
num_workers = max(num_workers, 0)
if num_workers == 0:
if num_workers <= 0:
scheduler = "serial"
num_workers = max(num_workers, 1)
compute_provider = get_compute_provider(scheduler, num_workers)

original_data_in = data_in
if isinstance(data_in, hub.Dataset):
data_in = get_dataset_with_zero_size_cache(data_in)

@@ -116,26 +123,33 @@
)

check_transform_data_in(data_in, scheduler)
check_transform_ds_out(ds_out, scheduler)
target_ds = data_in if ds_out is None else ds_out
check_transform_ds_out(target_ds, scheduler)
target_ds.flush()
# if not the head node, checkout to an auto branch that is newly created
auto_checkout(target_ds.version_state, target_ds.storage)

ds_out.flush()
initial_autoflush = ds_out.storage.autoflush
ds_out.storage.autoflush = False
initial_autoflush = target_ds.storage.autoflush
target_ds.storage.autoflush = False

tensors = list(ds_out.tensors)

compute_provider = get_compute_provider(scheduler, num_workers)
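# ds_out omitted means in-place mode: the pipeline will overwrite data_in itself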
overwrite = ds_out is None
if overwrite:
original_data_in.clear_cache()

try:
self.run(
data_in, ds_out, tensors, compute_provider, num_workers, progressbar
data_in,
target_ds,
compute_provider,
num_workers,
progressbar,
overwrite,
)
except Exception as e:
raise TransformError(e)
finally:
compute_provider.close()

ds_out.storage.autoflush = initial_autoflush
target_ds.storage.autoflush = initial_autoflush

def _run_with_progbar(
self, func: Callable, ret: dict, total: int, desc: Optional[str] = ""
@@ -181,23 +195,23 @@ def callback(data):
def run(
self,
data_in,
ds_out: hub.Dataset,
tensors: List[str],
target_ds: hub.Dataset,
compute: ComputeProvider,
num_workers: int,
progressbar: bool = True,
overwrite: bool = False,
):
"""Runs the pipeline on the input data to produce output samples and stores in the dataset.
This receives arguments processed and sanitized by the Pipeline.eval method.
"""
is_serial = isinstance(compute, SerialProvider)
num_workers = max(num_workers, 1)
size = math.ceil(len(data_in) / num_workers)
slices = [data_in[i * size : (i + 1) * size] for i in range(num_workers)]
storage = get_base_storage(target_ds.storage)
group_index = target_ds.group_index # type: ignore
version_state = target_ds.version_state

output_base_storage = get_base_storage(ds_out.storage)
version_state = ds_out.version_state
tensors = [ds_out.tensors[t].key for t in tensors]
tensors = list(target_ds.tensors)
tensors = [target_ds.tensors[t].key for t in tensors]

ret = {}

@@ -206,7 +220,7 @@ def _run(progress_port=None):
store_data_slice,
zip(
slices,
repeat((output_base_storage, ds_out.group_index)), # type: ignore
repeat((storage, group_index)), # type: ignore
repeat(tensors),
repeat(self),
repeat(version_state),
@@ -221,11 +235,23 @@
else:
_run()

metas_and_encoders = ret["metas_and_encoders"]
if overwrite:
chunk_paths = get_chunk_paths(target_ds, tensors)
# TODO:
# delete_chunks(chunk_paths, storage, compute)

all_tensor_metas, all_chunk_id_encoders = zip(*metas_and_encoders)
merge_all_tensor_metas(all_tensor_metas, ds_out)
merge_all_chunk_id_encoders(all_chunk_id_encoders, ds_out)
metas_and_encoders = ret["metas_and_encoders"]
all_tensor_metas, all_chunk_id_encoders, all_chunk_commit_sets = zip(
*metas_and_encoders
)
merge_all_tensor_metas(all_tensor_metas, target_ds, storage, overwrite)
merge_all_chunk_id_encoders(
all_chunk_id_encoders, target_ds, storage, overwrite
)
if target_ds.commit_id != FIRST_COMMIT_ID:
merge_all_commit_chunk_sets(
all_chunk_commit_sets, target_ds, storage, overwrite
)


def compose(functions: List[TransformFunction]):
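
Composed pipelines go through the same eval path, so they can also run in place. A brief sketch, assuming compose is exposed as hub.compose and reusing the inplace_transform function and dataset from the tests above:

    pipeline = hub.compose([inplace_transform()])
    pipeline.eval(ds, num_workers=2)  # no ds_out: rewrites ds in place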
(Diffs for the remaining 4 changed files are not shown.)
