Skip to content

Commit

Permalink
SNOW-1418543 make local testing thread safe (#2185)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Sep 11, 2024
1 parent 195226e commit 7d54d20
Show file tree
Hide file tree
Showing 9 changed files with 1,091 additions and 640 deletions.
311 changes: 175 additions & 136 deletions src/snowflake/snowpark/mock/_connection.py

Large diffs are not rendered by default.

23 changes: 15 additions & 8 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import operator
import re
import string
import threading
from decimal import Decimal
from functools import partial, reduce
from numbers import Real
Expand Down Expand Up @@ -130,14 +131,17 @@ def __call__(self, *args, input_data=None, row_number=None, **kwargs):

class MockedFunctionRegistry:
_instance = None
_lock_init = threading.Lock()

def __init__(self) -> None:
self._registry = dict()
self._lock = threading.RLock()

@classmethod
def get_or_create(cls) -> "MockedFunctionRegistry":
if cls._instance is None:
cls._instance = MockedFunctionRegistry()
with cls._lock_init:
if cls._instance is None:
cls._instance = MockedFunctionRegistry()
return cls._instance

def get_function(
Expand All @@ -151,10 +155,11 @@ def get_function(
distinct = func.is_distinct
func_name = func_name.lower()

if func_name not in self._registry:
return None
with self._lock:
if func_name not in self._registry:
return None

function = self._registry[func_name]
function = self._registry[func_name]

return function.distinct if distinct else function

Expand All @@ -169,7 +174,8 @@ def register(
snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__
)
mocked_function = MockedFunction(name, func_implementation, *args, **kwargs)
self._registry[name] = mocked_function
with self._lock:
self._registry[name] = mocked_function
return mocked_function

def unregister(
Expand All @@ -180,8 +186,9 @@ def unregister(
snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__
)

if name in self._registry:
del self._registry[name]
with self._lock:
if name in self._registry:
del self._registry[name]


class LocalTimezone:
Expand Down
Loading

0 comments on commit 7d54d20

Please sign in to comment.