diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py index b0f5022446de15..b8c27666d7f538 100644 --- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py +++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py @@ -1,6 +1,7 @@ import collections import gzip import logging +import os import pathlib import pickle import shutil @@ -33,6 +34,14 @@ logger: logging.Logger = logging.getLogger(__name__) +OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR = ( + os.environ.get("OVERRIDE_SQLITE_VERSION_REQ") or "" +) +OVERRIDE_SQLITE_VERSION_REQUIREMENT = ( + OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR + and OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR.lower() != "false" +) + _DEFAULT_FILE_NAME = "sqlite.db" _DEFAULT_TABLE_NAME = "data" @@ -212,6 +221,7 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]): _active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field( init=False, repr=False ) + _use_sqlite_on_conflict: bool = field(repr=False, default=True) def __post_init__(self) -> None: assert ( @@ -232,7 +242,10 @@ def __post_init__(self) -> None: # We use the ON CONFLICT clause to implement UPSERTs with sqlite. # This was added in 3.24.0 from 2018-06-04. # See https://www.sqlite.org/lang_conflict.html - raise RuntimeError("SQLite version 3.24.0 or later is required") + if OVERRIDE_SQLITE_VERSION_REQUIREMENT: + self._use_sqlite_on_conflict = False + else: + raise RuntimeError("SQLite version 3.24.0 or later is required") # We keep a small cache in memory to avoid having to serialize/deserialize # data from the database too often. We use an OrderedDict to build @@ -295,7 +308,7 @@ def _prune_cache(self, num_items_to_prune: int) -> None: values.append(column_serializer(value)) items_to_write.append(tuple(values)) - if items_to_write: + if items_to_write and self._use_sqlite_on_conflict: # Tricky: By using a INSERT INTO ... 
ON CONFLICT (key) structure, we can # ensure that the rowid remains the same if a value is updated but is # autoincremented when rows are inserted. @@ -312,6 +325,26 @@ def _prune_cache(self, num_items_to_prune: int) -> None: """, items_to_write, ) + else: + for item in items_to_write: + try: + self._conn.execute( + f"""INSERT INTO {self.tablename} ( + key, + value + {''.join(f', {column_name}' for column_name in self.extra_columns.keys())} + ) + VALUES ({', '.join(['?'] *(2 + len(self.extra_columns)))})""", + item, + ) + except sqlite3.IntegrityError: + self._conn.execute( + f"""UPDATE {self.tablename} SET + value = ? + {''.join(f', {column_name} = ?' for column_name in self.extra_columns.keys())} + WHERE key = ?""", + (*item[1:], item[0]), + ) def flush(self) -> None: self._prune_cache(len(self._active_object_cache)) diff --git a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py index f4062f9a911453..6230c2e37edc6a 100644 --- a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py +++ b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py @@ -15,11 +15,13 @@ ) -def test_file_dict() -> None: +@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False]) +def test_file_dict(use_sqlite_on_conflict: bool) -> None: cache = FileBackedDict[int]( tablename="cache", cache_max_size=10, cache_eviction_batch_size=10, + _use_sqlite_on_conflict=use_sqlite_on_conflict, ) for i in range(100): @@ -92,7 +94,8 @@ def test_file_dict() -> None: cache["a"] = 1 -def test_custom_serde() -> None: +@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False]) +def test_custom_serde(use_sqlite_on_conflict: bool) -> None: @dataclass(frozen=True) class Label: a: str @@ -139,6 +142,7 @@ def deserialize(s: str) -> Main: deserializer=deserialize, # Disable the in-memory cache to force all reads/writes to the DB. 
cache_max_size=0, + _use_sqlite_on_conflict=use_sqlite_on_conflict, ) first = Main(3, {Label("one", 1): 0.1, Label("two", 2): 0.2}) second = Main(-100, {Label("z", 26): 0.26}) @@ -186,7 +190,8 @@ def test_file_dict_stores_counter() -> None: assert in_memory_counters[i].most_common(2) == cache[str(i)].most_common(2) -def test_file_dict_ordering() -> None: +@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False]) +def test_file_dict_ordering(use_sqlite_on_conflict: bool) -> None: """ We require that FileBackedDict maintains insertion order, similar to Python's built-in dict. This test makes one of each and validates that they behave the same. @@ -196,6 +201,7 @@ def test_file_dict_ordering() -> None: serializer=str, deserializer=int, cache_max_size=1, + _use_sqlite_on_conflict=use_sqlite_on_conflict, ) data = {} @@ -229,12 +235,14 @@ class Pair: @pytest.mark.parametrize("cache_max_size", [0, 1, 10]) -def test_custom_column(cache_max_size: int) -> None: +@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False]) +def test_custom_column(cache_max_size: int, use_sqlite_on_conflict: bool) -> None: cache = FileBackedDict[Pair]( extra_columns={ "x": lambda m: m.x, }, cache_max_size=cache_max_size, + _use_sqlite_on_conflict=use_sqlite_on_conflict, ) cache["first"] = Pair(3, "a") @@ -275,7 +283,8 @@ def test_custom_column(cache_max_size: int) -> None: ] -def test_shared_connection() -> None: +@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False]) +def test_shared_connection(use_sqlite_on_conflict: bool) -> None: with ConnectionWrapper() as connection: cache1 = FileBackedDict[int]( shared_connection=connection, @@ -283,6 +292,7 @@ def test_shared_connection() -> None: extra_columns={ "v": lambda v: v, }, + _use_sqlite_on_conflict=use_sqlite_on_conflict, ) cache2 = FileBackedDict[Pair]( shared_connection=connection, @@ -291,6 +301,7 @@ def test_shared_connection() -> None: "x": lambda m: m.x, "y": lambda m: m.y, }, + 
_use_sqlite_on_conflict=use_sqlite_on_conflict, ) cache1["a"] = 3