Skip to content

Commit

Permalink
add test on thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 28, 2024
1 parent 4e90a70 commit cb5faa1
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,7 @@ def run(self):
test.join()

def test_multithread(self):

import threading

# Running evaluate() from multiple threads shouldn't crash
Expand All @@ -1218,6 +1219,77 @@ def work(n):
for t in threads:
t.join()

def test_thread_safety(self):
"""
Expected output
When not safe (before the pr this test is commited)
AssertionError: Thread-0 failed: result does not match expected
When safe (after the pr this test is commited)
Should pass without failure
"""
import threading
import time

barrier = threading.Barrier(4)

# Function that each thread will run with different expressions
def thread_function(a_value, b_value, expression, expected_result, results, index):
validate(expression, local_dict={"a": a_value, "b": b_value})
# Wait for all threads to reach this point
# such that they all set _numexpr_last
barrier.wait()

# Simulate some work or a context switch delay
time.sleep(0.1)

result = re_evaluate(local_dict={"a": a_value, "b": b_value})
results[index] = np.array_equal(result, expected_result)

def test_thread_safety_with_numexpr():
num_threads = 4
array_size = 1000000

expressions = [
"a + b",
"a - b",
"a * b",
"a / b"
]

a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]

expected_results = [
a_value[i] + b_value[i] if expr == "a + b" else
a_value[i] - b_value[i] if expr == "a - b" else
a_value[i] * b_value[i] if expr == "a * b" else
a_value[i] / b_value[i] if expr == "a / b" else None
for i, expr in enumerate(expressions)
]

results = [None] * num_threads
threads = []

# Create and start threads with different expressions
for i in range(num_threads):
thread = threading.Thread(
target=thread_function,
args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

for i in range(num_threads):
if not results[i]:
self.fail(f"Thread-{i} failed: result does not match expected")

test_thread_safety_with_numexpr()


# The worker function for the subprocess (needs to be here because Windows
# has problems pickling nested functions with the multiprocess module :-/)
Expand Down

0 comments on commit cb5faa1

Please sign in to comment.