diff --git a/tiledb/ctx.py b/tiledb/ctx.py
index b4898f3767..cc1a6314d6 100644
--- a/tiledb/ctx.py
+++ b/tiledb/ctx.py
@@ -293,6 +293,19 @@ def save(self, uri: str):
         """
         self.save_to_file(uri)
 
+    def __reduce__(self):
+        """
+        Customize the pickling process by defining how to serialize
+        and reconstruct the Config object.
+
+        The parameters are captured as a plain dict and handed back to the
+        constructor on unpickling, so ``__init__`` runs exactly once on the
+        restored object.
+
+        :return: ``(callable, args)`` pair understood by :mod:`pickle`
+        """
+        return (self.__class__, (self.dict(),))
+
 
 class ConfigKeys:
     """
diff --git a/tiledb/tests/conftest.py b/tiledb/tests/conftest.py
index 6565dd89e9..ecce429e50 100644
--- a/tiledb/tests/conftest.py
+++ b/tiledb/tests/conftest.py
@@ -52,54 +52,6 @@ def pytest_configure(config):
     # default must be set here rather than globally
     pytest.tiledb_vfs = "file"
 
-    vfs_config(config)
-
-
-def vfs_config(pytestconfig):
-    vfs_config_override = {}
-
-    vfs = pytestconfig.getoption("vfs")
-    if vfs == "s3":
-        pytest.tiledb_vfs = "s3"
-
-        vfs_config_override.update(
-            {
-                "vfs.s3.endpoint_override": "localhost:9999",
-                "vfs.s3.aws_access_key_id": "minio",
-                "vfs.s3.aws_secret_access_key": "miniosecretkey",
-                "vfs.s3.scheme": "https",
-                "vfs.s3.verify_ssl": False,
-                "vfs.s3.use_virtual_addressing": False,
-            }
-        )
-
-    vfs_config_arg = pytestconfig.getoption("vfs-config", None)
-    if vfs_config_arg:
-        pass
-
-    tiledb._orig_ctx = tiledb.Ctx
-
-    def get_config(config):
-        final_config = {}
-        if isinstance(config, tiledb.Config):
-            final_config = config.dict()
-        elif config:
-            final_config = config
-
-        final_config.update(vfs_config_override)
-        return final_config
-
-    class PatchedCtx(tiledb.Ctx):
-        def __init__(self, config=None):
-            super().__init__(get_config(config))
-
-    class PatchedConfig(tiledb.Config):
-        def __init__(self, params=None):
-            super().__init__(get_config(params))
-
-    tiledb.Ctx = PatchedCtx
-    tiledb.Config = PatchedConfig
-
 
 @pytest.fixture(scope="function", autouse=True)
 def isolate_os_fork(original_os_fork):
diff --git a/tiledb/tests/test_context_and_config.py b/tiledb/tests/test_context_and_config.py
index 054338991c..13ce7b72fa 100644
--- a/tiledb/tests/test_context_and_config.py
+++ b/tiledb/tests/test_context_and_config.py
@@ -1,4 +1,6 @@
+import io
 import os
+import pickle
 import subprocess
 import sys
 import xml
@@ -261,3 +263,20 @@ def test_config_repr_html(self):
             pytest.fail(
                 f"Could not parse config._repr_html_(). Saw {config._repr_html_()}"
             )
+
+    def test_config_pickle(self):
+        # test that Config can be pickled and unpickled
+        config = tiledb.Config(
+            {
+                "rest.use_refactored_array_open": "false",
+                "rest.use_refactored_array_open_and_query_submit": "true",
+                "vfs.azure.storage_account_name": "myaccount",
+            }
+        )
+        with io.BytesIO() as buf:
+            pickle.dump(config, buf)
+            buf.seek(0)
+            config2 = pickle.load(buf)
+
+        self.assertIsInstance(config2, tiledb.Config)
+        self.assertEqual(config2.dict(), config.dict())
diff --git a/tiledb/tests/test_vfs.py b/tiledb/tests/test_vfs.py
index 8219539a02..da71057741 100644
--- a/tiledb/tests/test_vfs.py
+++ b/tiledb/tests/test_vfs.py
@@ -1,6 +1,7 @@
 import io
 import os
 import pathlib
+import pickle
 import random
 import sys
 
@@ -239,6 +240,21 @@ def test_io(self):
             txtio = io.TextIOWrapper(f2, encoding="utf-8")
             self.assertEqual(txtio.readlines(), lines)
 
+    def test_pickle(self):
+        # test that vfs can be pickled and unpickled with config options
+        config = tiledb.Config(
+            {"vfs.s3.region": "eu-west-1", "vfs.max_parallel_ops": "1"}
+        )
+        vfs = tiledb.VFS(config)
+        with io.BytesIO() as buf:
+            pickle.dump(vfs, buf)
+            buf.seek(0)
+            vfs2 = pickle.load(buf)
+
+        self.assertIsInstance(vfs2, tiledb.VFS)
+        self.assertEqual(vfs2.config()["vfs.s3.region"], "eu-west-1")
+        self.assertEqual(vfs2.config()["vfs.max_parallel_ops"], "1")
+
     def test_sc42569_vfs_memoryview(self):
         # This test is to ensure that giving np.ndarray buffer to readinto works
         # when trying to write bytes that cannot be converted to float32 or int32
diff --git a/tiledb/vfs.py b/tiledb/vfs.py
index 7a703fae06..5b27131915 100644
--- a/tiledb/vfs.py
+++ b/tiledb/vfs.py
@@ -39,9 +39,10 @@ def __init__(self, config: Union[Config, dict] = None, ctx: Optional[Ctx] = None
                     raise ValueError("`config` argument must be of type Config or dict")
 
             # Convert all values to strings
-            config = {k: str(v) for k, v in config.items()}
+            # Kept on the instance so __getstate__ can pickle the exact options.
+            self.config_dict = {k: str(v) for k, v in config.items()}
 
-            ccfg = tiledb.Config(config)
+            ccfg = tiledb.Config(self.config_dict)
             super().__init__(ctx, ccfg)
         else:
             super().__init__(ctx)
@@ -329,6 +330,33 @@ def touch(self, uri: _AnyPath):
     isfile = is_file
     size = file_size
 
+    # pickling support
+    def __getstate__(self):
+        """
+        Return the state to pickle: a 1-tuple holding the config dict.
+
+        Only configuration options are serialized; the context itself is
+        rebuilt from them in ``__setstate__``.
+        """
+        # self.config_dict might not exist. In that case use the config from ctx.
+        if hasattr(self, "config_dict"):
+            config_dict = self.config_dict
+        else:
+            config_dict = self.config().dict()
+        return (config_dict,)
+
+    def __setstate__(self, state):
+        """
+        Rebuild the VFS from the 1-tuple produced by ``__getstate__``.
+
+        A new Ctx is created from the stored options and ``__init__`` is run
+        with it.
+        """
+        config_dict = state[0]
+        config = Config(params=config_dict)
+        ctx = Ctx(config)
+        self.__init__(config=config, ctx=ctx)
+
 
 class FileIO(io.RawIOBase):
     """TileDB FileIO class that encapsulates files opened by tiledb.VFS. The file