Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Java behaviour w.r.t fmin/fmax/dmin/dmax on Z #20716

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions runtime/compiler/codegen/J9CodeGenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,16 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz);
*/
void setSupportsInlineVectorizedHashCode() { _j9Flags.set(SupportsInlineVectorizedHashCode); }

/** \brief
* Determines whether the code generator supports inlining of java_lang_Math_max/min_F/D
*/
bool getSupportsInlineMath_MaxMin_FD() { return _j9Flags.testAny(SupportsInlineMath_MaxMin_FD); }

/** \brief
* The code generator supports inlining of java_lang_Math_max/min_F/D
*/
void setSupportsInlineMath_MaxMin_FD() { _j9Flags.set(SupportsInlineMath_MaxMin_FD); }

/**
* \brief
* The number of nodes between a monext and the next monent before
Expand Down Expand Up @@ -699,6 +709,7 @@ void addMonClass(TR::Node* monNode, TR_OpaqueClassBlock* clazz);
SupportsInlineVectorizedHashCode = 0x00002000,
SupportsInlineStringCodingHasNegatives = 0x00004000,
SupportsInlineStringCodingCountPositives = 0x00008000,
SupportsInlineMath_MaxMin_FD = 0x00010000,
};

flags32_t _j9Flags;
Expand Down
4 changes: 4 additions & 0 deletions runtime/compiler/env/j9method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5068,6 +5068,10 @@ TR_ResolvedJ9Method::setRecognizedMethodInfo(TR::RecognizedMethod rm)
case TR::java_lang_Math_min_I:
case TR::java_lang_Math_max_L:
case TR::java_lang_Math_min_L:
case TR::java_lang_Math_max_F:
case TR::java_lang_Math_min_F:
case TR::java_lang_Math_max_D:
case TR::java_lang_Math_min_D:
case TR::java_lang_Math_abs_I:
case TR::java_lang_Math_abs_L:
case TR::java_lang_Math_abs_F:
Expand Down
17 changes: 17 additions & 0 deletions runtime/compiler/optimizer/J9RecognizedCallTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,11 @@ bool J9::RecognizedCallTransformer::isInlineable(TR::TreeTop* treetop)
case TR::java_lang_Math_max_L:
case TR::java_lang_Math_min_L:
return !comp()->getOption(TR_DisableMaxMinOptimization);
case TR::java_lang_Math_max_F:
case TR::java_lang_Math_min_F:
case TR::java_lang_Math_max_D:
case TR::java_lang_Math_min_D:
return !comp()->getOption(TR_DisableMaxMinOptimization) && cg()->getSupportsInlineMath_MaxMin_FD();
case TR::java_lang_Math_multiplyHigh:
return cg()->getSupportsLMulHigh();
case TR::java_lang_StringUTF16_toBytes:
Expand Down Expand Up @@ -1563,6 +1568,18 @@ void J9::RecognizedCallTransformer::transform(TR::TreeTop* treetop)
case TR::java_lang_Math_min_L:
processIntrinsicFunction(treetop, node, TR::lmin);
break;
case TR::java_lang_Math_max_F:
processIntrinsicFunction(treetop, node, TR::fmax);
break;
case TR::java_lang_Math_min_F:
processIntrinsicFunction(treetop, node, TR::fmin);
break;
case TR::java_lang_Math_max_D:
processIntrinsicFunction(treetop, node, TR::dmax);
break;
case TR::java_lang_Math_min_D:
processIntrinsicFunction(treetop, node, TR::dmin);
break;
case TR::java_lang_Math_multiplyHigh:
processIntrinsicFunction(treetop, node, TR::lmulh);
break;
Expand Down
24 changes: 17 additions & 7 deletions runtime/compiler/z/codegen/J9CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ J9::Z::CodeGenerator::initialize()
cg->setSupportsInlineEncodeASCII();
}

static bool disableInlineMath_MaxMin_FD = feGetEnv("TR_disableInlineMaxMin") != NULL;
if (!disableInlineMath_MaxMin_FD)
{
cg->setSupportsInlineMath_MaxMin_FD();
}

static bool disableInlineVectorizedMismatch = feGetEnv("TR_disableInlineVectorizedMismatch") != NULL;
if (cg->getSupportsArrayCmpLen() &&
#if defined(J9VM_GC_SPARSE_HEAP_ALLOCATION)
Expand Down Expand Up @@ -4154,20 +4160,24 @@ J9::Z::CodeGenerator::inlineDirectCall(
}
}

if (!comp->getOption(TR_DisableSIMDDoubleMaxMin) && cg->getSupportsVectorRegisters())
{
switch (methodSymbol->getRecognizedMethod())
{
if (!self()->comp()->getOption(TR_DisableMaxMinOptimization) && cg->getSupportsInlineMath_MaxMin_FD()) {
switch (methodSymbol->getRecognizedMethod()) {
case TR::java_lang_Math_max_D:
resultReg = TR::TreeEvaluator::inlineDoubleMax(node, cg);
resultReg = J9::Z::TreeEvaluator::dmaxEvaluator(node, cg);
return true;
case TR::java_lang_Math_min_D:
resultReg = TR::TreeEvaluator::inlineDoubleMin(node, cg);
resultReg = J9::Z::TreeEvaluator::dminEvaluator(node, cg);
return true;
case TR::java_lang_Math_max_F:
resultReg = J9::Z::TreeEvaluator::fmaxEvaluator(node, cg);
return true;
case TR::java_lang_Math_min_F:
resultReg = J9::Z::TreeEvaluator::fminEvaluator(node, cg);
return true;
default:
break;
}
}
}

switch (methodSymbol->getRecognizedMethod())
{
Expand Down
96 changes: 18 additions & 78 deletions runtime/compiler/z/codegen/J9TreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,76 +1066,28 @@ allocateWriteBarrierInternalPointerRegister(TR::CodeGenerator * cg, TR::Node * s
}


extern TR::Register *
doubleMaxMinHelper(TR::Node *node, TR::CodeGenerator *cg, bool isMaxOp)
TR::Register *
J9::Z::TreeEvaluator::dmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
TR_ASSERT(node->getNumChildren() >= 1 || node->getNumChildren() <= 2, "node has incorrect number of children");

/* ===================== Allocating Registers ===================== */

TR::Register * v16 = cg->allocateRegister(TR_VRF);
TR::Register * v17 = cg->allocateRegister(TR_VRF);
TR::Register * v18 = cg->allocateRegister(TR_VRF);

/* ===================== Generating instructions ===================== */

/* ====== LD FPR0,16(GPR5) Load a ====== */
TR::Register * v0 = cg->fprClobberEvaluate(node->getFirstChild());

/* ====== LD FPR2, 0(GPR5) Load b ====== */
TR::Register * v2 = cg->evaluate(node->getSecondChild());

/* ====== WFTCIDB V16,V0,X'F' a == NaN ====== */
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v16, v0, 0xF, 8, 3);

/* ====== For Max: WFCHE V17,V0,V2 Compare a >= b ====== */
if(isMaxOp)
{
generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v0, v2, 0, 8, 3);
}
/* ====== For Min: WFCHE V17,V0,V2 Compare a <= b ====== */
else
{
generateVRRcInstruction(cg, TR::InstOpCode::VFCH, node, v17, v2, v0, 0, 8, 3);
}

/* ====== VO V16,V16,V17 (a >= b) || (a == NaN) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0);

/* ====== For Max: WFTCIDB V17,V0,X'800' a == +0 ====== */
if(isMaxOp)
{
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x800, 8, 3);
}
/* ====== For Min: WFTCIDB V17,V0,X'400' a == -0 ====== */
else
{
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v17, v0, 0x400, 8, 3);
}
/* ====== WFTCIDB V18,V2,X'C00' b == 0 ====== */
generateVRIeInstruction(cg, TR::InstOpCode::VFTCI, node, v18, v2, 0xC00, 8, 3);

/* ====== VN V17,V17,V18 (a == -0) && (b == 0) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VN, node, v17, v17, v18, 0, 0, 0);

/* ====== VO V16,V16,V17 (a >= b) || (a == NaN) || ((a == -0) && (b == 0)) ====== */
generateVRRcInstruction(cg, TR::InstOpCode::VO, node, v16, v16, v17, 0, 0, 0);

/* ====== VSEL V0,V0,V2,V16 ====== */
generateVRReInstruction(cg, TR::InstOpCode::VSEL, node, v0, v0, v2, v16);

/* ===================== Deallocating Registers ===================== */
cg->stopUsingRegister(v2);
cg->stopUsingRegister(v16);
cg->stopUsingRegister(v17);
cg->stopUsingRegister(v18);
return OMR::Z::TreeEvaluator::dmaxHelper(node, cg);
}

node->setRegister(v0);
TR::Register *
J9::Z::TreeEvaluator::dminEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
return OMR::Z::TreeEvaluator::dminHelper(node, cg);
}

cg->decReferenceCount(node->getFirstChild());
cg->decReferenceCount(node->getSecondChild());
TR::Register *
J9::Z::TreeEvaluator::fmaxEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
return OMR::Z::TreeEvaluator::fmaxHelper(node, cg);
}

return node->getRegister();
TR::Register *
J9::Z::TreeEvaluator::fminEvaluator(TR::Node * node, TR::CodeGenerator * cg)
{
return OMR::Z::TreeEvaluator::fminHelper(node, cg);
}

TR::Register*
Expand Down Expand Up @@ -2750,19 +2702,7 @@ J9::Z::TreeEvaluator::toLowerIntrinsic(TR::Node *node, TR::CodeGenerator *cg, bo
return caseConversionHelper(node, cg, false, isCompressedString);
}

TR::Register*
J9::Z::TreeEvaluator::inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg)
{
cg->generateDebugCounter("z13/simd/doubleMax", 1, TR::DebugCounter::Free);
return doubleMaxMinHelper(node, cg, true);
}

TR::Register*
J9::Z::TreeEvaluator::inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg)
{
cg->generateDebugCounter("z13/simd/doubleMin", 1, TR::DebugCounter::Free);
return doubleMaxMinHelper(node, cg, false);
}

TR::Register *
J9::Z::TreeEvaluator::inlineMathFma(TR::Node *node, TR::CodeGenerator *cg)
Expand Down
6 changes: 4 additions & 2 deletions runtime/compiler/z/codegen/J9TreeEvaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ class OMR_EXTENSIBLE TreeEvaluator: public J9::TreeEvaluator
*/
static TR::Register *inlineVectorizedStringIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isCompressed);
static TR::Register *inlineIntrinsicIndexOf(TR::Node *node, TR::CodeGenerator *cg, bool isLatin1);
static TR::Register *inlineDoubleMax(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *inlineDoubleMin(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *fminEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *dminEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *fmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *dmaxEvaluator(TR::Node *node, TR::CodeGenerator *cg);
static TR::Register *inlineMathFma(TR::Node *node, TR::CodeGenerator *cg);

/* This Evaluator generates the SIMD routine for methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
package jit.test.recognizedMethod;
import org.testng.AssertJUnit;
import org.testng.annotations.Test;
import java.util.Random;
import org.testng.asserts.SoftAssert;
import static jit.test.recognizedMethod.TestMathUtils.*;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Listeners;
import org.testng.AssertJUnit;



@Test(singleThreaded=true)
public class TestJavaLangMath {

/**
Expand All @@ -39,10 +48,10 @@ public class TestJavaLangMath {
*
* Subsequent tree simplification passes will attempt to reduce this constant
* operation to a <code>dsqrt</code> IL by performing the square root at compile
* time. The transformation will be performed when the function get executed
* time. The transformation will be performed when the function get executed
* twice, therefore, the "invocationCount=2" is needed. However we must ensure the
* result of the square root done by the compiler at compile time will be exactly
* the same as the result had it been done by the Java runtime at runtime. This
* result of the square root done by the compiler at compile time will be exactly
* the same as the result had it been done by the Java runtime at runtime. This
* test validates the results are the same.
*/
@Test(groups = {"level.sanity"}, invocationCount=2)
Expand All @@ -55,4 +64,82 @@ public void test_java_lang_Math_sqrt() {
AssertJUnit.assertEquals(Double.POSITIVE_INFINITY, Math.sqrt(Double.POSITIVE_INFINITY));
AssertJUnit.assertTrue(Double.isNaN(Math.sqrt(Double.NaN)));
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_zeros_FD(Number a, Number b, boolean isFirstArg) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.min(f1, f2), isFirstArg ? f1 : f2);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="zeroProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_zeros_FD(Number a, Number b, boolean isFirstArg) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
assertEquals(Math.max(f1, f2), isFirstArg ? f2 : f1);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_nan_FD(Number a, Number b, Number expected) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
AssertJUnit.assertTrue(Float.isNaN(Math.min(f1, f2)));
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
AssertJUnit.assertTrue(Double.isNaN(Math.min(f1, f2)));
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="nanProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_nan_FD(Number a, Number b, Number expected) {
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
AssertJUnit.assertTrue(Float.isNaN(Math.max(f1, f2)));
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
AssertJUnit.assertTrue(Double.isNaN(Math.max(f1, f2)));
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_min_normal_FD(Number a, Number b){
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
AssertJUnit.assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
AssertJUnit.assertEquals(Math.min(f1, f2), f1 <= f2 ? f1 : f2);
}
}

@Test(groups = {"level.sanity"}, invocationCount=2, dataProvider="normalNumberProviderFD", dataProviderClass=TestMathUtils.class)
public void test_java_lang_Math_max_normal_FD(Number a, Number b){
if (a instanceof Float) {
float f1 = a.floatValue();
float f2 = b.floatValue();
AssertJUnit.assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2);
} else {
double f1 = a.doubleValue();
double f2 = b.doubleValue();
AssertJUnit.assertEquals(Math.max(f1, f2), f1 >= f2 ? f1 : f2);
}
}
}
Loading