Skip to content

Commit

Permalink
[HLSL] Shore up floating point conversions (#90222)
Browse files Browse the repository at this point in the history
This PR fixes bugs in HLSL floating conversions. HLSL always has `half`,
`float` and `double` types, which promote in the order:

`half`->`float`->`double`

and convert in the order:

`double`->`float`->`half`

As with other conversions in C++, promotions are preferred over
conversions.

We do have floating conversions documented in the draft language
specification (https://microsoft.github.io/hlsl-specs/specs/hlsl.pdf
[Conv.rank.float]) although the exact language is still in flux
(microsoft/hlsl-specs#206).

Resolves #81047
  • Loading branch information
llvm-beanz authored May 2, 2024
1 parent e06d6ed commit aa5ff68
Show file tree
Hide file tree
Showing 4 changed files with 507 additions and 19 deletions.
43 changes: 42 additions & 1 deletion clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2587,7 +2587,8 @@ bool Sema::IsIntegralPromotion(Expr *From, QualType FromType, QualType ToType) {

// In HLSL an rvalue of integral type can be promoted to an rvalue of a larger
// integral type.
if (Context.getLangOpts().HLSL)
if (Context.getLangOpts().HLSL && FromType->isIntegerType() &&
ToType->isIntegerType())
return Context.getTypeSize(FromType) < Context.getTypeSize(ToType);

return false;
Expand Down Expand Up @@ -2616,6 +2617,13 @@ bool Sema::IsFloatingPointPromotion(QualType FromType, QualType ToType) {
ToBuiltin->getKind() == BuiltinType::Ibm128))
return true;

// In HLSL, `half` promotes to `float` or `double`, regardless of whether
// or not native half types are enabled.
if (getLangOpts().HLSL && FromBuiltin->getKind() == BuiltinType::Half &&
(ToBuiltin->getKind() == BuiltinType::Float ||
ToBuiltin->getKind() == BuiltinType::Double))
return true;

// Half can be promoted to float.
if (!getLangOpts().NativeHalfType &&
FromBuiltin->getKind() == BuiltinType::Half &&
Expand Down Expand Up @@ -4393,6 +4401,24 @@ getFixedEnumPromtion(Sema &S, const StandardConversionSequence &SCS) {
return FixedEnumPromotion::ToPromotedUnderlyingType;
}

static ImplicitConversionSequence::CompareKind
HLSLCompareFloatingRank(QualType LHS, QualType RHS) {
assert(LHS->isVectorType() == RHS->isVectorType() &&
"Either both elements should be vectors or neither should.");
if (const auto *VT = LHS->getAs<VectorType>())
LHS = VT->getElementType();

if (const auto *VT = RHS->getAs<VectorType>())
RHS = VT->getElementType();

const auto L = LHS->getAs<BuiltinType>()->getKind();
const auto R = RHS->getAs<BuiltinType>()->getKind();
if (L == R)
return ImplicitConversionSequence::Indistinguishable;
return L < R ? ImplicitConversionSequence::Better
: ImplicitConversionSequence::Worse;
}

/// CompareStandardConversionSequences - Compare two standard
/// conversion sequences to determine whether one is better than the
/// other or if they are indistinguishable (C++ 13.3.3.2p3).
Expand Down Expand Up @@ -4634,6 +4660,21 @@ CompareStandardConversionSequences(Sema &S, SourceLocation Loc,
: ImplicitConversionSequence::Worse;
}

if (S.getLangOpts().HLSL) {
// On a promotion we prefer the lower rank to disambiguate.
if ((SCS1.Second == ICK_Floating_Promotion &&
SCS2.Second == ICK_Floating_Promotion) ||
(SCS1.Element == ICK_Floating_Promotion &&
SCS2.Element == ICK_Floating_Promotion))
return HLSLCompareFloatingRank(SCS1.getToType(2), SCS2.getToType(2));
// On a conversion we prefer the higher rank to disambiguate.
if ((SCS1.Second == ICK_Floating_Conversion &&
SCS2.Second == ICK_Floating_Conversion) ||
(SCS1.Element == ICK_Floating_Conversion &&
SCS2.Element == ICK_Floating_Conversion))
return HLSLCompareFloatingRank(SCS2.getToType(2), SCS1.getToType(2));
}

return ImplicitConversionSequence::Indistinguishable;
}

Expand Down
26 changes: 8 additions & 18 deletions clang/test/SemaHLSL/OverloadResolutionBugs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@
// https://github.com/llvm/llvm-project/issues/81047

// expected-no-diagnostics
void Fn3(double2 D);
void Fn3(float2 F);

void Call3(half2 H) { Fn3(H); }

void Fn5(double2 D);

void Call5(half2 H) { Fn5(H); }

void Fn4(int64_t2 L);
void Fn4(int2 I);

Expand Down Expand Up @@ -61,13 +52,12 @@ float test_frac_int(int p0) { return frac(p0); }

float test_frac_bool(bool p0) { return frac(p0); }

// https://github.com/llvm/llvm-project/issues/81049
// This resolves the wrong overload. In clang this converts down to an int, in
// DXC it extends the scalar to a vector.
void Fn(int) {}
void Fn(vector<int64_t,2>) {}

// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.2-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefix=NO_HALF

half sqrt_h(half x) { return sqrt(x); }

// NO_HALF: define noundef float @"?sqrt_h@@YA$halff@$halff@@Z"(
// NO_HALF: call float @llvm.sqrt.f32(float %0)
void Call() {
int64_t V;
Fn(V);
}
229 changes: 229 additions & 0 deletions clang/test/SemaHLSL/ScalarOverloadResolution.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -Wconversion -verify -o - %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -ast-dump %s | FileCheck %s

// This test verifies floating point type implicit conversion ranks for overload
// resolution. In HLSL the built-in type ranks are half < float < double. This
// applies to both scalar and vector types.

// HLSL allows implicit truncation fo types, so it differentiates between
// promotions (converting to larger types) and conversions (converting to
// smaller types). Promotions are preferred over conversions. Promotions prefer
// promoting to the next lowest type in the ranking order. Conversions prefer
// converting to the next highest type in the ranking order.

void HalfFloatDouble(double D);
void HalfFloatDouble(float F);
void HalfFloatDouble(half H);

// CHECK: FunctionDecl {{.*}} used HalfFloatDouble 'void (double)'
// CHECK: FunctionDecl {{.*}} used HalfFloatDouble 'void (float)'
// CHECK: FunctionDecl {{.*}} used HalfFloatDouble 'void (half)'

void FloatDouble(double D);
void FloatDouble(float F);

// CHECK: FunctionDecl {{.*}} used FloatDouble 'void (double)'
// CHECK: FunctionDecl {{.*}} used FloatDouble 'void (float)'

void HalfDouble(double D);
void HalfDouble(half H);

// CHECK: FunctionDecl {{.*}} used HalfDouble 'void (double)'
// CHECK: FunctionDecl {{.*}} used HalfDouble 'void (half)'

void HalfFloat(float F);
void HalfFloat(half H);

// CHECK: FunctionDecl {{.*}} used HalfFloat 'void (float)'
// CHECK: FunctionDecl {{.*}} used HalfFloat 'void (half)'

void Double(double D);
void Float(float F);
void Half(half H);

// CHECK: FunctionDecl {{.*}} used Double 'void (double)'
// CHECK: FunctionDecl {{.*}} used Float 'void (float)'
// CHECK: FunctionDecl {{.*}} used Half 'void (half)'


// Case 1: A function declared with overloads for half float and double types.
// (a) When called with half, it will resolve to half because half is an exact
// match.
// (b) When called with float it will resolve to float because float is an
// exact match.
// (c) When called with double it will resolve to double because it is an
// exact match.

// CHECK-LABEL: FunctionDecl {{.*}} Case1 'void (half, float, double)'
void Case1(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'HalfFloatDouble' 'void (half)'
HalfFloatDouble(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'HalfFloatDouble' 'void (float)'
HalfFloatDouble(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'HalfFloatDouble' 'void (double)'
HalfFloatDouble(D);
}

// Case 2: A function declared with double and float overlaods.
// (a) When called with half, it will resolve to float because float is lower
// ranked than double.
// (b) When called with float it will resolve to float because float is an
// exact match.
// (c) When called with double it will resolve to double because it is an
// exact match.

// CHECK-LABEL: FunctionDecl {{.*}} Case2 'void (half, float, double)'
void Case2(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'FloatDouble' 'void (float)'
FloatDouble(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'FloatDouble' 'void (float)'
FloatDouble(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'FloatDouble' 'void (double)'
FloatDouble(D);
}

// Case 3: A function declared with half and double overloads
// (a) When called with half, it will resolve to half because it is an exact
// match.
// (b) When called with flaot, it will resolve to double because double is a
// valid promotion.
// (c) When called with double, it will resolve to double because it is an
// exact match.

// CHECK-LABEL: FunctionDecl {{.*}} Case3 'void (half, float, double)'
void Case3(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'HalfDouble' 'void (half)'
HalfDouble(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'HalfDouble' 'void (double)'
HalfDouble(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'HalfDouble' 'void (double)'
HalfDouble(D);
}

// Case 4: A function declared with half and float overloads.
// (a) When called with half, it will resolve to half because half is an exact
// match.
// (b) When called with float it will resolve to float because float is an
// exact match.
// (c) When called with double it will resolve to float because it is the
// float is higher rank than half.

// CHECK-LABEL: FunctionDecl {{.*}} Case4 'void (half, float, double)'
void Case4(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'HalfFloat' 'void (half)'
HalfFloat(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'HalfFloat' 'void (float)'
HalfFloat(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'HalfFloat' 'void (float)'
HalfFloat(D); // expected-warning{{implicit conversion loses floating-point precision: 'double' to 'float'}}
}

// Case 5: A function declared with only a double overload.
// (a) When called with half, it will resolve to double because double is a
// valid promotion.
// (b) When called with float it will resolve to double because double is a
// valid promotion.
// (c) When called with double it will resolve to double because it is an
// exact match.

// CHECK-LABEL: FunctionDecl {{.*}} Case5 'void (half, float, double)'
void Case5(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'Double' 'void (double)'
Double(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'Double' 'void (double)'
Double(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(double)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (double)' lvalue Function {{.*}} 'Double' 'void (double)'
Double(D);
}

// Case 6: A function declared with only a float overload.
// (a) When called with half, it will resolve to float because float is a
// valid promotion.
// (b) When called with float it will resolve to float because float is an
// exact match.
// (c) When called with double it will resolve to float because it is a
// valid conversion.

// CHECK-LABEL: FunctionDecl {{.*}} Case6 'void (half, float, double)'
void Case6(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'Float' 'void (float)'
Float(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'Float' 'void (float)'
Float(F);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float)' lvalue Function {{.*}} 'Float' 'void (float)'
Float(D); // expected-warning{{implicit conversion loses floating-point precision: 'double' to 'float'}}
}

// Case 7: A function declared with only a half overload.
// (a) When called with half, it will resolve to half because half is an
// exact match
// (b) When called with float it will resolve to half because half is a
// valid conversion.
// (c) When called with double it will resolve to float because it is a
// valid conversion.

// CHECK-LABEL: FunctionDecl {{.*}} Case7 'void (half, float, double)'
void Case7(half H, float F, double D) {
// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'Half' 'void (half)'
Half(H);

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'Half' 'void (half)'
Half(F); // expected-warning{{implicit conversion loses floating-point precision: 'float' to 'half'}}

// CHECK: CallExpr {{.*}} 'void'
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(half)' <FunctionToPointerDecay>
// CHECK-NEXT: DeclRefExpr {{.*}} 'void (half)' lvalue Function {{.*}} 'Half' 'void (half)'
Half(D); // expected-warning{{implicit conversion loses floating-point precision: 'double' to 'half'}}
}
Loading

0 comments on commit aa5ff68

Please sign in to comment.