Skip to content

Commit

Permalink
Fix anyvalue marshalling for matrix and 64 bit types. (shader-slang#5827
Browse files Browse the repository at this point in the history
)

* Fix anyvalue marshalling for matrix types.

* Add support for 64bit types marshalling.

---------

Co-authored-by: Ellie Hermaszewska <[email protected]>
  • Loading branch information
csyonghe and expipiplus1 authored Dec 11, 2024
1 parent f687688 commit f573c15
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 14 deletions.
16 changes: 16 additions & 0 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3285,6 +3285,19 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return nullptr;
}

SpvInst* emitMakeUInt64(SpvInstParent* parent, IRInst* inst)
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto vec = emitOpCompositeConstruct(
parent,
nullptr,
builder.getVectorType(builder.getUIntType(), 2),
inst->getOperand(0),
inst->getOperand(1));
return emitOpBitcast(parent, inst, inst->getDataType(), vec);
}

// The instructions that appear inside the basic blocks of
// functions are what we will call "local" instructions.
//
Expand Down Expand Up @@ -3391,6 +3404,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_BitCast:
result = emitOpBitcast(parent, inst, inst->getDataType(), inst->getOperand(0));
break;
case kIROp_MakeUInt64:
result = emitMakeUInt64(parent, inst);
break;
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
Expand Down
110 changes: 96 additions & 14 deletions source/slang/slang-ir-any-value-marshalling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ struct AnyValueMarshallingContext
intraFieldOffset = 0;
}
}
void ensureOffsetAt8ByteBoundary()
{
ensureOffsetAt4ByteBoundary();
if ((fieldOffset & 1) != 0)
fieldOffset++;
}
void ensureOffsetAt2ByteBoundary()
{
if (intraFieldOffset == 0)
Expand Down Expand Up @@ -146,6 +152,7 @@ struct AnyValueMarshallingContext
case kIROp_BoolType:
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
case kIROp_PtrType:
context->marshalBasicType(builder, dataType, concreteTypedVar);
break;
case kIROp_VectorType:
Expand All @@ -166,17 +173,36 @@ struct AnyValueMarshallingContext
auto matrixType = static_cast<IRMatrixType*>(dataType);
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
for (IRIntegerValue i = 0; i < colCount; i++)
if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
{
auto col = builder->emitElementAddress(
concreteTypedVar,
builder->getIntValue(builder->getIntType(), i));
for (IRIntegerValue j = 0; j < rowCount; j++)
for (IRIntegerValue i = 0; i < colCount; i++)
{
for (IRIntegerValue j = 0; j < rowCount; j++)
{
auto row = builder->emitElementAddress(
concreteTypedVar,
builder->getIntValue(builder->getIntType(), j));
auto element = builder->emitElementAddress(
row,
builder->getIntValue(builder->getIntType(), i));
emitMarshallingCode(builder, context, element);
}
}
}
else
{
for (IRIntegerValue i = 0; i < rowCount; i++)
{
auto element = builder->emitElementAddress(
col,
builder->getIntValue(builder->getIntType(), j));
emitMarshallingCode(builder, context, element);
auto row = builder->emitElementAddress(
concreteTypedVar,
builder->getIntValue(builder->getIntType(), i));
for (IRIntegerValue j = 0; j < colCount; j++)
{
auto element = builder->emitElementAddress(
row,
builder->getIntValue(builder->getIntType(), j));
emitMarshallingCode(builder, context, element);
}
}
}
break;
Expand Down Expand Up @@ -348,11 +374,39 @@ struct AnyValueMarshallingContext
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
case kIROp_PtrType:
#if SLANG_PTR_IS_64
case kIROp_UIntPtrType:
case kIROp_IntPtrType:
#endif
SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
ensureOffsetAt8ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
auto dstVal = builder->emitBitCast(builder->getUInt64Type(), srcVal);
auto lowBits = builder->emitCast(builder->getUIntType(), dstVal);
auto highBits = builder->emitShr(
builder->getUInt64Type(),
dstVal,
builder->getIntValue(builder->getIntType(), 32));
highBits = builder->emitCast(builder->getUIntType(), highBits);

auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr, lowBits);
fieldOffset++;
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr, lowBits);
fieldOffset++;
}
}
break;
default:
SLANG_UNREACHABLE("unknown basic type");
Expand Down Expand Up @@ -545,7 +599,34 @@ struct AnyValueMarshallingContext
case kIROp_DoubleType:
case kIROp_Int8Type:
case kIROp_UInt8Type:
SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
case kIROp_PtrType:
#if SLANG_PTR_IS_64
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
#endif
ensureOffsetAt8ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
auto lowBits = builder->emitLoad(srcAddr);
fieldOffset++;
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcAddr1 = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
fieldOffset++;
auto highBits = builder->emitLoad(srcAddr1);
auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
if (dataType->getOp() != kIROp_UInt64Type)
combinedBits = builder->emitBitCast(dataType, combinedBits);
builder->emitStore(concreteVar, combinedBits);
}
}
break;
default:
SLANG_UNREACHABLE("unknown basic type");
Expand Down Expand Up @@ -735,7 +816,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
return -1;
case kIROp_PtrType:
return alignUp(offset, 8) + 8;
case kIROp_Int16Type:
case kIROp_UInt16Type:
case kIROp_HalfType:
Expand All @@ -762,9 +844,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
auto elementType = matrixType->getElementType();
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
for (IRIntegerValue i = 0; i < colCount; i++)
for (IRIntegerValue i = 0; i < rowCount; i++)
{
for (IRIntegerValue j = 0; j < rowCount; j++)
for (IRIntegerValue j = 0; j < colCount; j++)
{
offset = _getAnyValueSizeRaw(elementType, offset);
if (offset < 0)
Expand Down
44 changes: 44 additions & 0 deletions tests/language-feature/anyvalue-layout.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -dx12 -use-dxil -profile cs_6_1 -output-using-type
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -output-using-type

interface IFoo
{
float getVal();
uint64_t getPtrVal();
}

struct Foo : IFoo
{
column_major float3x2 m;
int x;
uint64_t ptr;
float getVal()
{
return m[2][0];
}
uint64_t getPtrVal()
{
return ptr;
}
}

//TEST_INPUT: type_conformance Foo:IFoo = 0

//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0 0 0 1 2], stride=4)
RWStructuredBuffer<IFoo> gFoo;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 3.0
outputBuffer[0] = gFoo[0].getVal();

// CHECK: 1.0
outputBuffer[1] = gFoo[0].getPtrVal()&0xFFFFFFFF;

// CHECK: 2.0
outputBuffer[2] = gFoo[0].getPtrVal()>>32;
}
30 changes: 30 additions & 0 deletions tests/language-feature/anyvalue-matrix-layout.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type

interface IFoo
{
float getVal();
}

struct Foo : IFoo
{
column_major float3x2 m;
float getVal()
{
return m[2][0];
}
}

//TEST_INPUT: type_conformance Foo:IFoo = 0

//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0], stride=4)
RWStructuredBuffer<IFoo> gFoo;

//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> outputBuffer;

[numthreads(1,1,1)]
void computeMain()
{
// CHECK: 3.0
outputBuffer[0] = gFoo[0].getVal();
}

0 comments on commit f573c15

Please sign in to comment.