Skip to content

Commit

Permalink
add udf/sp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Sep 3, 2024
1 parent 318a589 commit dc49a2b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3041,7 +3041,8 @@ def _use_object(self, object_name: str, object_type: str) -> None:
# we do not validate here
object_type = match.group(1)
object_name = match.group(2)
setattr(self._conn, f"_active_{object_type}", object_name)
with self._conn._lock:
setattr(self._conn, f"_active_{object_type}", object_name)
else:
self._run_query(query)
else:
Expand Down
84 changes: 65 additions & 19 deletions tests/mock/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,62 @@ def test_table_update_delete_insert():
pass


def test_udf():
pass
def test_udf_register_and_invoke(session):
df = session.create_dataframe([[1], [2]], schema=["num"])
num_threads = 10

def register_udf(x: int):
def echo(x: int) -> int:
return x

def test_sp():
pass
return session.udf.register(echo, name="echo", replace=True)

def invoke_udf():
result = df.select(session.udf.call_udf("echo", df.num)).collect()
assert result[0][0] == 1
assert result[1][0] == 2

threads = []
for i in range(num_threads):
thread_register = Thread(target=register_udf, args=(i,))
threads.append(thread_register)
thread_register.start()

thread_invoke = Thread(target=invoke_udf)
threads.append(thread_invoke)
thread_invoke.start()

for thread in threads:
thread.join()


def test_sp_register_and_invoke(session):
num_threads = 10

def increment_by_one_fn(session_: Session, x: int) -> int:
return x + 1

def register_sproc():
session.sproc.register(
increment_by_one_fn, name="increment_by_one", replace=True
)

def invoke_sproc():
result = session.call("increment_by_one", 1)
assert result == 2

threads = []
for i in range(num_threads):
thread_register = Thread(target=register_sproc, args=(i,))
threads.append(thread_register)
thread_register.start()

thread_invoke = Thread(target=invoke_sproc)
threads.append(thread_invoke)
thread_invoke.start()

for thread in threads:
thread.join()


def test_mocked_function_registry_created_once():
Expand Down Expand Up @@ -122,19 +172,9 @@ def put_and_get_file():
thread.join()


def test_stage_entity_registry_upload_and_read():
# upload a json to stage_i and read it back
test_parameter = {
"account": "test_account",
"user": "test_user",
"schema": "test_schema",
"database": "test_database",
"warehouse": "test_warehouse",
"role": "test_role",
"local_testing": True,
}
session = Session.builder.configs(options=test_parameter).create()
def test_stage_entity_registry_upload_and_read(session):
stage_registry = StageEntityRegistry(MockServerConnection())
num_threads = 10

def upload_and_read_json(thread_id: int):
json_string = json.dumps({"thread_id": thread_id})
Expand All @@ -152,13 +192,19 @@ def upload_and_read_json(thread_id: int):
session._analyzer,
{"INFER_SCHEMA": "True"},
)
# TODO: read table emulator and compare results
assert df == thread_id

assert df['"thread_id"'].iloc[0] == thread_id

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(upload_and_read_json, i) for i in range(num_threads)]

for future in as_completed(futures):
future.result()


def test_stage_entity_registry_create_or_replace():
stage_registry = StageEntityRegistry(MockServerConnection())
num_threads = 100
num_threads = 10

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [
Expand Down

0 comments on commit dc49a2b

Please sign in to comment.