Skip to content

Commit

Permalink
Implement serialization functionality for Config and VFS objects (#2110)
Browse files Browse the repository at this point in the history

* Implement serialization functionality for VFS objects
* Implement serialization for Config
* Remove PatchedConfig and PatchedCtx

---------

Co-authored-by: Theodore Tsirpanis <[email protected]>
  • Loading branch information
kounelisagis and teo-tsirpanis authored Nov 25, 2024
1 parent e9d05cd commit 806d1ae
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 53 deletions.
14 changes: 14 additions & 0 deletions tiledb/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,20 @@ def save(self, uri: str):
"""
self.save_to_file(uri)

def __reduce__(self):
    """
    Customize the pickling process by defining how to serialize
    and reconstruct the Config object.

    The parameters reported by ``self.dict()`` are handed directly to the
    class constructor when unpickling, so the restored object is built in
    a single step rather than being default-constructed and then having
    ``__init__`` invoked a second time on the live instance.

    :return: ``(callable, args)`` tuple as expected by :mod:`pickle`
    """
    return (self.__class__, (self.dict(),))

def __setstate__(self, state):
    """
    Restore this Config object from the state captured at pickling time.

    :param state: the value that was supplied as the object's state when
        it was serialized
    """
    # Hand the saved state to the initializer of this very instance.
    self.__init__(state)


class ConfigKeys:
"""
Expand Down
48 changes: 0 additions & 48 deletions tiledb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,54 +52,6 @@ def pytest_configure(config):
# default must be set here rather than globally
pytest.tiledb_vfs = "file"

vfs_config(config)


# Collect VFS config overrides chosen through the pytest "vfs" option and
# replace tiledb.Ctx / tiledb.Config with subclasses that apply them.
# NOTE(review): indentation is not visible in this view; the nesting described
# in the comments below is inferred from variable scope -- confirm.
def vfs_config(pytestconfig):
vfs_config_override = {}

vfs = pytestconfig.getoption("vfs")
if vfs == "s3":
pytest.tiledb_vfs = "s3"

# Settings pointing the S3 backend at an endpoint on localhost:9999
# (presumably a local MinIO test server, judging by the credentials -- confirm).
vfs_config_override.update(
{
"vfs.s3.endpoint_override": "localhost:9999",
"vfs.s3.aws_access_key_id": "minio",
"vfs.s3.aws_secret_access_key": "miniosecretkey",
"vfs.s3.scheme": "https",
"vfs.s3.verify_ssl": False,
"vfs.s3.use_virtual_addressing": False,
}
)

# The "vfs-config" option is read, but the branch below is a no-op.
vfs_config_arg = pytestconfig.getoption("vfs-config", None)
if vfs_config_arg:
pass

# Keep a reference to the unpatched Ctx class before it is replaced below.
tiledb._orig_ctx = tiledb.Ctx

# Normalize `config` (a tiledb.Config, another truthy value, or falsy) and
# layer the overrides on top of it.
# NOTE(review): when `config` is not a tiledb.Config it is used as-is and then
# updated, which appears to mutate the caller's object in place -- confirm.
def get_config(config):
final_config = {}
if isinstance(config, tiledb.Config):
final_config = config.dict()
elif config:
final_config = config

final_config.update(vfs_config_override)
return final_config

# Ctx subclass whose constructor routes its config through get_config().
class PatchedCtx(tiledb.Ctx):
def __init__(self, config=None):
super().__init__(get_config(config))

# Config subclass whose constructor routes its params through get_config().
class PatchedConfig(tiledb.Config):
def __init__(self, params=None):
super().__init__(get_config(params))

# Install the patched classes under the public tiledb names.
tiledb.Ctx = PatchedCtx
tiledb.Config = PatchedConfig


@pytest.fixture(scope="function", autouse=True)
def isolate_os_fork(original_os_fork):
Expand Down
19 changes: 19 additions & 0 deletions tiledb/tests/test_context_and_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import os
import pickle
import subprocess
import sys
import xml
Expand Down Expand Up @@ -261,3 +263,20 @@ def test_config_repr_html(self):
pytest.fail(
f"Could not parse config._repr_html_(). Saw {config._repr_html_()}"
)

def test_config_pickle(self):
    # A Config must come back from a pickle round trip as a Config that
    # reports the same parameters.
    params = {
        "rest.use_refactored_array_open": "false",
        "rest.use_refactored_array_open_and_query_submit": "true",
        "vfs.azure.storage_account_name": "myaccount",
    }
    original = tiledb.Config(params)

    restored = pickle.loads(pickle.dumps(original))

    self.assertIsInstance(restored, tiledb.Config)
    self.assertEqual(restored.dict(), original.dict())
16 changes: 16 additions & 0 deletions tiledb/tests/test_vfs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import os
import pathlib
import pickle
import random
import sys

Expand Down Expand Up @@ -239,6 +240,21 @@ def test_io(self):
txtio = io.TextIOWrapper(f2, encoding="utf-8")
self.assertEqual(txtio.readlines(), lines)

def test_pickle(self):
    # A VFS built with explicit config options must keep those options
    # after being pickled and unpickled.
    options = {"vfs.s3.region": "eu-west-1", "vfs.max_parallel_ops": "1"}
    original = tiledb.VFS(tiledb.Config(options))

    restored = pickle.loads(pickle.dumps(original))

    self.assertIsInstance(restored, tiledb.VFS)
    self.assertEqual(restored.config()["vfs.s3.region"], "eu-west-1")
    self.assertEqual(restored.config()["vfs.max_parallel_ops"], "1")

def test_sc42569_vfs_memoryview(self):
# This test is to ensure that giving np.ndarray buffer to readinto works
# when trying to write bytes that cannot be converted to float32 or int32
Expand Down
25 changes: 20 additions & 5 deletions tiledb/vfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class VFS(lt.VFS):
"""

def __init__(self, config: Union[Config, dict] = None, ctx: Optional[Ctx] = None):
ctx = ctx or default_ctx()
self.ctx = ctx or default_ctx()

if config:
from .libtiledb import Config
Expand All @@ -39,12 +39,12 @@ def __init__(self, config: Union[Config, dict] = None, ctx: Optional[Ctx] = None
raise ValueError("`config` argument must be of type Config or dict")

# Convert all values to strings
config = {k: str(v) for k, v in config.items()}
self.config_dict = {k: str(v) for k, v in config.items()}

ccfg = tiledb.Config(config)
super().__init__(ctx, ccfg)
ccfg = tiledb.Config(self.config_dict)
super().__init__(self.ctx, ccfg)
else:
super().__init__(ctx)
super().__init__(self.ctx)

def ctx(self) -> Ctx:
"""
Expand Down Expand Up @@ -329,6 +329,21 @@ def touch(self, uri: _AnyPath):
isfile = is_file
size = file_size

# pickling support
def __getstate__(self):
    """
    Return the picklable state of this VFS.

    :return: one-element tuple holding the configuration dict
    """
    try:
        # Set by __init__ only when an explicit config was supplied.
        saved_config = self.config_dict
    except AttributeError:
        # No stored dict: fall back to the configuration reported by
        # this object's config().
        saved_config = self.config().dict()
    return (saved_config,)

def __setstate__(self, state):
    """
    Rebuild this VFS from the state produced by ``__getstate__``.

    :param state: tuple whose first element is the configuration dict
    """
    restored_config = Config(params=state[0])
    # A context is created from the restored configuration and both are
    # passed back through the regular initializer.
    self.__init__(config=restored_config, ctx=Ctx(restored_config))


class FileIO(io.RawIOBase):
"""TileDB FileIO class that encapsulates files opened by tiledb.VFS. The file
Expand Down

0 comments on commit 806d1ae

Please sign in to comment.