Skip to content

Commit

Permalink
Merge pull request #20362 from babsingh/main9
Browse files Browse the repository at this point in the history
Fix overflow issues
  • Loading branch information
keithc-ca authored Oct 16, 2024
2 parents d51176e + 619c7da commit 85753ad
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 77 deletions.
32 changes: 24 additions & 8 deletions runtime/j9vm/java11vmi.c
Original file line number Diff line number Diff line change
Expand Up @@ -789,8 +789,12 @@ JVM_DefineModule(JNIEnv * env, jobject module, jboolean isOpen, jstring version,
j9array_t array = (j9array_t)J9_JNI_UNWRAP_REFERENCE(packageArray);
j9object_t stringObject = J9JAVAARRAYOFOBJECT_LOAD(currentThread, array, pkgIndex);
if (NULL != stringObject) {
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject) + 1;
char *packageName = (char*)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject);
char *packageName = NULL;
if (utfLength < UDATA_MAX) {
utfLength += 1;
packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
}
if (NULL == packageName) {
oom = TRUE;
break;
Expand Down Expand Up @@ -992,8 +996,12 @@ JVM_AddModuleExports(JNIEnv * env, jobject fromModule, const char *package, jobj
#if JAVA_SPEC_VERSION >= 15
if (NULL != packageObj) {
j9object_t stringObject = J9_JNI_UNWRAP_REFERENCE(packageObj);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject) + 1;
char* packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject);
char *packageName = NULL;
if (utfLength < UDATA_MAX) {
utfLength += 1;
packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
}
if (NULL == packageName) {
vmFuncs->setNativeOutOfMemoryError(currentThread, 0, 0);
goto done;
Expand Down Expand Up @@ -1066,8 +1074,12 @@ JVM_AddModuleExportsToAll(JNIEnv * env, jobject fromModule, const char *package)
#if JAVA_SPEC_VERSION >= 15
if (NULL != packageObj) {
j9object_t stringObject = J9_JNI_UNWRAP_REFERENCE(packageObj);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject) + 1;
char* packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject);
char *packageName = NULL;
if (utfLength < UDATA_MAX) {
utfLength += 1;
packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
}
if (NULL == packageName) {
vmFuncs->setNativeOutOfMemoryError(currentThread, 0, 0);
goto done;
Expand Down Expand Up @@ -1306,8 +1318,12 @@ JVM_AddModuleExportsToAllUnnamed(JNIEnv * env, jobject fromModule, const char *p
#if JAVA_SPEC_VERSION >= 15
if (NULL != packageObj) {
j9object_t stringObject = J9_JNI_UNWRAP_REFERENCE(packageObj);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject) + 1;
char* packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
UDATA utfLength = vmFuncs->getStringUTF8Length(currentThread, stringObject);
char *packageName = NULL;
if (utfLength < UDATA_MAX) {
utfLength += 1;
packageName = (char *)j9mem_allocate_memory(utfLength, OMRMEM_CATEGORY_VM);
}
if (NULL == packageName) {
vmFuncs->setNativeOutOfMemoryError(currentThread, 0, 0);
goto done;
Expand Down
80 changes: 55 additions & 25 deletions runtime/jcl/common/java_lang_invoke_MethodHandleNatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,10 @@ getClassSignatureLength(J9VMThread *currentThread, J9Class *clazz)
j9object_t sigString = J9VMJAVALANGCLASS_CLASSNAMESTRING(currentThread, J9VM_J9CLASS_TO_HEAPCLASS(clazz));
if (NULL != sigString) {
/* +2 so that we can fit 'L' and ';' around the class name. */
signatureLength = vm->internalVMFunctions->getStringUTF8Length(currentThread, sigString) + 2;
signatureLength = vm->internalVMFunctions->getStringUTF8Length(currentThread, sigString);
if (signatureLength <= (UDATA_MAX - 2)) {
signatureLength += 2;
}
} else {
J9Class *myClass = clazz;
UDATA numDims = 0;
Expand Down Expand Up @@ -470,24 +473,27 @@ getClassSignatureInout(J9VMThread *currentThread, J9Class *clazz, LocalJ9UTF8Buf
j9object_t sigString = J9VMJAVALANGCLASS_CLASSNAMESTRING(currentThread, J9VM_J9CLASS_TO_HEAPCLASS(clazz));
if (NULL != sigString) {
/* +3 so that we can fit 'L' and ';' around the class name and add null-terminator. */
UDATA utfLength = vm->internalVMFunctions->getStringUTF8Length(currentThread, sigString) + 3;
if (utfLength <= stringBuffer->remaining()) {
if (J9ROMCLASS_IS_ARRAY(clazz->romClass)) {
vm->internalVMFunctions->copyStringToUTF8Helper(
currentThread, sigString, J9_STR_XLAT, 0, J9VMJAVALANGSTRING_LENGTH(currentThread, sigString),
stringBuffer->cursor, utfLength - 3);
/* Adjust cursor to account for the call to copyStringToUTF8Helper. */
stringBuffer->advanceN(utfLength - 3);
} else {
stringBuffer->putCharAtCursor('L');
vm->internalVMFunctions->copyStringToUTF8Helper(
currentThread, sigString, J9_STR_XLAT, 0, J9VMJAVALANGSTRING_LENGTH(currentThread, sigString),
stringBuffer->cursor, utfLength - 3);
/* Adjust cursor to account for the call to copyStringToUTF8Helper. */
stringBuffer->advanceN(utfLength - 3);
stringBuffer->putCharAtCursor(';');
UDATA utfLength = vm->internalVMFunctions->getStringUTF8Length(currentThread, sigString);
if (utfLength <= (UDATA_MAX - 3)) {
utfLength += 3;
if (utfLength <= stringBuffer->remaining()) {
if (J9ROMCLASS_IS_ARRAY(clazz->romClass)) {
vm->internalVMFunctions->copyStringToUTF8Helper(
currentThread, sigString, J9_STR_XLAT, 0, J9VMJAVALANGSTRING_LENGTH(currentThread, sigString),
stringBuffer->cursor, utfLength - 3);
/* Adjust cursor to account for the call to copyStringToUTF8Helper. */
stringBuffer->advanceN(utfLength - 3);
} else {
stringBuffer->putCharAtCursor('L');
vm->internalVMFunctions->copyStringToUTF8Helper(
currentThread, sigString, J9_STR_XLAT, 0, J9VMJAVALANGSTRING_LENGTH(currentThread, sigString),
stringBuffer->cursor, utfLength - 3);
/* Adjust cursor to account for the call to copyStringToUTF8Helper. */
stringBuffer->advanceN(utfLength - 3);
stringBuffer->putCharAtCursor(';');
}
result = true;
}
result = true;
}
} else {
J9Class *myClass = clazz;
Expand Down Expand Up @@ -557,20 +563,34 @@ getJ9UTF8SignatureFromMethodTypeWithMemAlloc(J9VMThread *currentThread, j9object
j9object_t ptypes = J9VMJAVALANGINVOKEMETHODTYPE_PTYPES(currentThread, typeObject);
U_32 numArgs = J9INDEXABLEOBJECT_SIZE(currentThread, ptypes);
UDATA signatureLength = 2; /* space for '(', ')' */
UDATA tempSignatureLength = 0;
UDATA signatureUtf8Size = 0;
J9UTF8 *result = NULL;
j9object_t rtype = NULL;
J9Class *rclass = NULL;
PORT_ACCESS_FROM_JAVAVM(vm);

/* Calculate total signature length, including all ptypes and rtype. */
for (U_32 i = 0; i < numArgs; i++) {
j9object_t pObject = J9JAVAARRAYOFOBJECT_LOAD(currentThread, ptypes, i);
J9Class *pclass = J9VM_J9CLASS_FROM_HEAPCLASS(currentThread, pObject);
signatureLength += getClassSignatureLength(currentThread, pclass);
tempSignatureLength = getClassSignatureLength(currentThread, pclass);
if (signatureLength > (J9UTF8_MAX_LENGTH - tempSignatureLength)) {
goto done;
}
signatureLength += tempSignatureLength;
}
j9object_t rtype = J9VMJAVALANGINVOKEMETHODTYPE_RTYPE(currentThread, typeObject);
J9Class *rclass = J9VM_J9CLASS_FROM_HEAPCLASS(currentThread, rtype);
signatureLength += getClassSignatureLength(currentThread, rclass);
rtype = J9VMJAVALANGINVOKEMETHODTYPE_RTYPE(currentThread, typeObject);
rclass = J9VM_J9CLASS_FROM_HEAPCLASS(currentThread, rtype);
tempSignatureLength = getClassSignatureLength(currentThread, rclass);
if (signatureLength > (J9UTF8_MAX_LENGTH - tempSignatureLength)) {
goto done;
}
signatureLength += tempSignatureLength;

signatureUtf8Size = signatureLength + sizeof(J9UTF8) + 1; /* +1 for a null-terminator */
result = reinterpret_cast<J9UTF8 *>(j9mem_allocate_memory(signatureUtf8Size, OMRMEM_CATEGORY_VM));

UDATA signatureUtf8Size = signatureLength + sizeof(J9UTF8) + 1; /* +1 for a null-terminator */
J9UTF8 *result = reinterpret_cast<J9UTF8 *>(j9mem_allocate_memory(signatureUtf8Size, OMRMEM_CATEGORY_VM));
if (NULL != result) {
LocalJ9UTF8Buffer stringBuffer(result, signatureUtf8Size);

Expand All @@ -588,6 +608,7 @@ getJ9UTF8SignatureFromMethodTypeWithMemAlloc(J9VMThread *currentThread, j9object
stringBuffer.commitLength();
}

done:
return result;
}

Expand Down Expand Up @@ -1036,12 +1057,21 @@ Java_java_lang_invoke_MethodHandleNatives_resolve(
} else {
LocalJ9UTF8Buffer stringBuffer(reinterpret_cast<J9UTF8 *>(signatureBuffer), sizeof(signatureBuffer));
signature = getJ9UTF8SignatureFromMethodType(currentThread, typeObject, &stringBuffer);
if (NULL == signature) {
vmFuncs->setCurrentExceptionUTF(currentThread, J9VMCONSTANTPOOL_JAVALANGINTERNALERROR, NULL);
goto done;
}
}
} else if (J9VMJAVALANGSTRING_OR_NULL(vm) == typeClass) {
signature = vmFuncs->copyStringToJ9UTF8WithMemAlloc(currentThread, typeObject, J9_STR_XLAT, "", 0, signatureBuffer, sizeof(signatureBuffer));
} else if (J9VMJAVALANGCLASS(vm) == typeClass) {
J9Class *rclass = J9VM_J9CLASS_FROM_HEAPCLASS(currentThread, typeObject);
UDATA signatureLength = getClassSignatureLength(currentThread, rclass) + sizeof(J9UTF8) + 1 /* null-terminator */;
UDATA signatureLength = getClassSignatureLength(currentThread, rclass);
if (signatureLength > J9UTF8_MAX_LENGTH) {
vmFuncs->setCurrentExceptionUTF(currentThread, J9VMCONSTANTPOOL_JAVALANGINTERNALERROR, NULL);
goto done;
}
signatureLength += sizeof(J9UTF8) + 1 /* null-terminator */;
LocalJ9UTF8Buffer stringBuffer;
if (signatureLength <= sizeof(signatureBuffer)) {
stringBuffer = LocalJ9UTF8Buffer(reinterpret_cast<J9UTF8 *>(signatureBuffer), sizeof(signatureBuffer));
Expand Down
5 changes: 4 additions & 1 deletion runtime/oti/j9nonbuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -3599,6 +3599,8 @@ typedef struct J9UTF8 {
#pragma warning(pop)
#endif /* defined(_MSC_VER) */

#define J9UTF8_MAX_LENGTH U_16_MAX

typedef struct J9ROMClass {
U_32 romSize;
U_32 singleScalarStaticCount;
Expand Down Expand Up @@ -4815,7 +4817,8 @@ typedef struct J9InternalVMFunctions {
struct J9Class* ( *internalFindKnownClass)(struct J9VMThread *currentThread, UDATA index, UDATA flags) ;
struct J9Class* ( *resolveKnownClass)(struct J9JavaVM * vm, UDATA index) ;
UDATA ( *computeHashForUTF8)(const U_8 * string, UDATA size) ;
IDATA ( *getStringUTF8Length)(struct J9VMThread *vmThread, j9object_t string) ;
UDATA ( *getStringUTF8Length)(struct J9VMThread *vmThread, j9object_t string) ;
U_64 ( *getStringUTF8LengthTruncated)(struct J9VMThread *vmThread, j9object_t string, U_64 maxLength) ;
void ( *acquireExclusiveVMAccess)(struct J9VMThread * vmThread) ;
void ( *releaseExclusiveVMAccess)(struct J9VMThread * vmThread) ;
void ( *internalReleaseVMAccess)(struct J9VMThread * currentThread) ;
Expand Down
31 changes: 24 additions & 7 deletions runtime/oti/vm_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3583,13 +3583,30 @@ copyStringToUTF8Helper(J9VMThread *vmThread, j9object_t string, UDATA stringFlag


/**
* @brief
* @param *vm
* @param *string
* @return IDATA
*/
IDATA
getStringUTF8Length(J9VMThread *vmThread,j9object_t string);
* @brief Find the length of the string object when it is converted to UTF-8.
*
* Note: On 32-bit platforms, the length may be truncated.
*
* @param vm a pointer to J9JavaVM
* @param string a string object
*
* @return the length of the string in UTF-8
*/
UDATA
getStringUTF8Length(J9VMThread *vmThread, j9object_t string);

/**
* @brief Find the length of the string object when it is converted to UTF-8, but truncate it
* using maxLength as the upper bound.
*
* @param vm a pointer to J9JavaVM
* @param string a string object
* @param maxLength the upper bound of the length used for truncation
*
* @return the length of the string in UTF-8
*/
U_64
getStringUTF8LengthTruncated(J9VMThread *vmThread, j9object_t string, U_64 maxLength);


/**
Expand Down
1 change: 1 addition & 0 deletions runtime/vm/intfunc.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ J9InternalVMFunctions J9InternalFunctions = {
resolveKnownClass,
computeHashForUTF8,
getStringUTF8Length,
getStringUTF8LengthTruncated,
acquireExclusiveVMAccess,
releaseExclusiveVMAccess,
internalReleaseVMAccess,
Expand Down
17 changes: 10 additions & 7 deletions runtime/vm/jnimisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ getStringUTFLength(JNIEnv *env, jstring string)
VM_VMAccess::inlineEnterVMFromJNI(currentThread);
j9object_t stringObject = J9_JNI_UNWRAP_REFERENCE(string);

UDATA utfLength = getStringUTF8Length(currentThread, stringObject);
U_64 utfLength = getStringUTF8LengthTruncated(currentThread, stringObject, INT32_MAX);
VM_VMAccess::inlineExitVMToJNI(currentThread);
return (jsize)utfLength;
}
Expand All @@ -840,14 +840,17 @@ getStringUTFCharsImpl(JNIEnv *env, jstring string, jboolean *isCopy, jboolean en
J9VMThread *currentThread = (J9VMThread*)env;
VM_VMAccess::inlineEnterVMFromJNI(currentThread);
j9object_t stringObject = J9_JNI_UNWRAP_REFERENCE(string);
/* Add 1 for null terminator */
UDATA utfLength = getStringUTF8Length(currentThread, stringObject) + 1;

UDATA utfLength = getStringUTF8Length(currentThread, stringObject);
U_8 *utfChars = NULL;
if (ensureMem32) {
utfChars = (U_8*)jniArrayAllocateMemory32FromThread(currentThread, utfLength);
} else {
utfChars = (U_8*)jniArrayAllocateMemoryFromThread(currentThread, utfLength);
if (utfLength < UDATA_MAX) {
/* Add 1 for a null terminator. */
utfLength += 1;
if (ensureMem32) {
utfChars = (U_8 *)jniArrayAllocateMemory32FromThread(currentThread, utfLength);
} else {
utfChars = (U_8 *)jniArrayAllocateMemoryFromThread(currentThread, utfLength);
}
}

if (NULL == utfChars) {
Expand Down
Loading

0 comments on commit 85753ad

Please sign in to comment.