Skip to content

Commit

Permalink
fix race
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaraditya303 committed Dec 18, 2024
1 parent f802c8b commit 587f9d6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
22 changes: 20 additions & 2 deletions Lib/test/test_capi/test_unicode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import sys
from test import support
from test.support import import_helper
from test.support import threading_helper

try:
import _testcapi
Expand Down Expand Up @@ -1005,6 +1005,24 @@ def test_asutf8(self):
self.assertRaises(TypeError, unicode_asutf8, [], 0)
# CRASHES unicode_asutf8(NULL, 0)

@unittest.skipIf(_testcapi is None, 'need _testcapi module')
@threading_helper.requires_working_threading()
def test_asutf8_race(self):
    """Test that there's no race condition in PyUnicode_AsUTF8()

    Ten threads each ask 1000 times for the UTF-8 form of the same
    non-ASCII string. Every call must return the four encoded bytes of
    the emoji followed by a NUL byte, never a partially filled buffer.
    """
    from threading import Thread

    as_utf8 = _testcapi.unicode_asutf8
    text = "😊"
    # 4 bytes of UTF-8 for U+1F60A plus the trailing NUL: 5 bytes in total,
    # matching the size passed to the helper below.
    expected = b'\xf0\x9f\x98\x8a\0'

    def hammer():
        # All workers share `text`, so they contend on the same string
        # object's UTF-8 representation.
        for _ in range(1000):
            self.assertEqual(as_utf8(text, 5), expected)

    workers = [Thread(target=hammer) for _ in range(10)]
    # NOTE(review): start_threads is expected to start the workers on entry
    # and join them on exit, hence the empty body -- confirm against
    # test.support.threading_helper.
    with threading_helper.start_threads(workers):
        pass


@support.cpython_only
@unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module')
def test_asutf8andsize(self):
Expand Down Expand Up @@ -1938,4 +1956,4 @@ def copy(text):


if __name__ == "__main__":
unittest.main()
unittest.main()
46 changes: 30 additions & 16 deletions Objects/unicodeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ NOTE: In the interpreter's initialization phase, some globals are currently

/* Return the `utf8` pointer stored in the compact unicode object `op`.
 *
 * NOTE(review): this text comes from a rendered diff, so two versions of the
 * return statement appear back to back. The plain field read is presumably
 * the removed line and the FT_ATOMIC_LOAD_PTR_ACQUIRE read its replacement;
 * taken literally, the second return is unreachable. Confirm against the
 * real file before reusing this as code. */
static inline char* _PyUnicode_UTF8(PyObject *op)
{
return (_PyCompactUnicodeObject_CAST(op)->utf8);
return FT_ATOMIC_LOAD_PTR_ACQUIRE(_PyCompactUnicodeObject_CAST(op)->utf8);
}

static inline char* PyUnicode_UTF8(PyObject *op)
Expand All @@ -130,7 +130,7 @@ static inline char* PyUnicode_UTF8(PyObject *op)

/* Store `utf8` into the `utf8` field of the compact unicode object `op`.
 *
 * NOTE(review): this text comes from a rendered diff, so two versions of the
 * store appear back to back. The plain assignment is presumably the removed
 * line and the FT_ATOMIC_STORE_PTR_RELEASE store its replacement; the release
 * store is the counterpart of the acquire load in _PyUnicode_UTF8(). Confirm
 * against the real file before reusing this as code. */
static inline void PyUnicode_SET_UTF8(PyObject *op, char *utf8)
{
_PyCompactUnicodeObject_CAST(op)->utf8 = utf8;
FT_ATOMIC_STORE_PTR_RELEASE(_PyCompactUnicodeObject_CAST(op)->utf8, utf8);
}

static inline Py_ssize_t PyUnicode_UTF8_LENGTH(PyObject *op)
Expand Down Expand Up @@ -700,15 +700,15 @@ _PyUnicode_CheckConsistency(PyObject *op, int check_content)
CHECK(ascii->state.compact == 0);
CHECK(data != NULL);
if (ascii->state.ascii) {
CHECK(compact->utf8 == data);
CHECK(_PyUnicode_UTF8(op) == data);
CHECK(compact->utf8_length == ascii->length);
}
else {
CHECK(compact->utf8 != data);
CHECK(_PyUnicode_UTF8(op) != data);
}
}

if (compact->utf8 == NULL)
if (_PyUnicode_UTF8(op) == NULL)
CHECK(compact->utf8_length == 0);
}

Expand Down Expand Up @@ -1156,8 +1156,8 @@ resize_compact(PyObject *unicode, Py_ssize_t length)

if (_PyUnicode_HAS_UTF8_MEMORY(unicode)) {
PyMem_Free(_PyUnicode_UTF8(unicode));
PyUnicode_SET_UTF8(unicode, NULL);
PyUnicode_SET_UTF8_LENGTH(unicode, 0);
PyUnicode_SET_UTF8(unicode, NULL);
}
#ifdef Py_TRACE_REFS
_Py_ForgetReference(unicode);
Expand Down Expand Up @@ -1210,8 +1210,8 @@ resize_inplace(PyObject *unicode, Py_ssize_t length)
if (!share_utf8 && _PyUnicode_HAS_UTF8_MEMORY(unicode))
{
PyMem_Free(_PyUnicode_UTF8(unicode));
PyUnicode_SET_UTF8(unicode, NULL);
PyUnicode_SET_UTF8_LENGTH(unicode, 0);
PyUnicode_SET_UTF8(unicode, NULL);
}

data = (PyObject *)PyObject_Realloc(data, new_size);
Expand All @@ -1221,8 +1221,8 @@ resize_inplace(PyObject *unicode, Py_ssize_t length)
}
_PyUnicode_DATA_ANY(unicode) = data;
if (share_utf8) {
PyUnicode_SET_UTF8(unicode, data);
PyUnicode_SET_UTF8_LENGTH(unicode, length);
PyUnicode_SET_UTF8(unicode, data);
}
_PyUnicode_LENGTH(unicode) = length;
PyUnicode_WRITE(PyUnicode_KIND(unicode), data, length, 0);
Expand Down Expand Up @@ -4216,6 +4216,21 @@ PyUnicode_FSDecoder(PyObject* arg, void* addr)

static int unicode_fill_utf8(PyObject *unicode);


/* Make sure `unicode` has its UTF-8 representation available.
 *
 * Returns 0 when the representation is already present, otherwise the
 * result of unicode_fill_utf8() (0 on success, -1 on failure).
 *
 * The pointer is checked once without locking so the common, already-filled
 * case stays cheap. It is checked again inside the critical section because
 * another thread may have filled it in the meantime; unicode_fill_utf8()
 * must only be called with the object locked. */
static int
unicode_ensure_utf8(PyObject *unicode)
{
    /* Fast path: nothing to do if the UTF-8 buffer already exists. */
    if (PyUnicode_UTF8(unicode) != NULL) {
        return 0;
    }

    int status = 0;
    Py_BEGIN_CRITICAL_SECTION(unicode);
    /* Re-check under the lock before filling. */
    if (PyUnicode_UTF8(unicode) == NULL) {
        status = unicode_fill_utf8(unicode);
    }
    Py_END_CRITICAL_SECTION();
    return status;
}

const char *
PyUnicode_AsUTF8AndSize(PyObject *unicode, Py_ssize_t *psize)
{
Expand All @@ -4227,13 +4242,11 @@ PyUnicode_AsUTF8AndSize(PyObject *unicode, Py_ssize_t *psize)
return NULL;
}

if (PyUnicode_UTF8(unicode) == NULL) {
if (unicode_fill_utf8(unicode) == -1) {
if (psize) {
*psize = -1;
}
return NULL;
if (unicode_ensure_utf8(unicode) == -1) {
if (psize) {
*psize = -1;
}
return NULL;
}

if (psize) {
Expand Down Expand Up @@ -5854,6 +5867,7 @@ unicode_encode_utf8(PyObject *unicode, _Py_error_handler error_handler,
static int
unicode_fill_utf8(PyObject *unicode)
{
_Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(unicode);
/* the string cannot be ASCII, or PyUnicode_UTF8() would be set */
assert(!PyUnicode_IS_ASCII(unicode));

Expand Down Expand Up @@ -5895,10 +5909,10 @@ unicode_fill_utf8(PyObject *unicode)
PyErr_NoMemory();
return -1;
}
PyUnicode_SET_UTF8(unicode, cache);
PyUnicode_SET_UTF8_LENGTH(unicode, len);
memcpy(cache, start, len);
cache[len] = '\0';
PyUnicode_SET_UTF8_LENGTH(unicode, len);
PyUnicode_SET_UTF8(unicode, cache);
_PyBytesWriter_Dealloc(&writer);
return 0;
}
Expand Down

0 comments on commit 587f9d6

Please sign in to comment.