Skip to content

Commit

Permalink
Add COM API for querying metadata. (shader-slang#5168)
Browse files Browse the repository at this point in the history
* Add COM API for querying metadata.

* Fix tests.

* fix test.
  • Loading branch information
csyonghe authored Sep 30, 2024
1 parent bc11579 commit 15d1c6c
Show file tree
Hide file tree
Showing 17 changed files with 360 additions and 37 deletions.
27 changes: 27 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -5280,6 +5280,22 @@ namespace slang

#define SLANG_UUID_ISession ISession::getTypeGuid()

struct IMetadata : public ISlangCastable
{
SLANG_COM_INTERFACE(0x8044a8a3, 0xddc0, 0x4b7f, { 0xaf, 0x8e, 0x2, 0x6e, 0x90, 0x5d, 0x73, 0x32 })

/*
Returns whether a resource parameter at the specifieid binding location is actually being used
in the compiled shader.
*/
virtual SlangResult isParameterLocationUsed(
SlangParameterCategory category, // is this a `t` register? `s` register?
SlangUInt spaceIndex, // `space` for D3D12, `set` for Vulkan
SlangUInt registerIndex, // `register` for D3D12, `binding` for Vulkan
bool& outUsed) = 0;
};
#define SLANG_UUID_IMetadata IMetadata::getTypeGuid()

/** A component type is a unit of shader code layout, reflection, and linking.
A component type is a unit of shader code that can be included into
Expand Down Expand Up @@ -5492,6 +5508,17 @@ namespace slang
SlangInt targetIndex,
IBlob** outCode,
IBlob** outDiagnostics = nullptr) = 0;

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
IMetadata** outMetadata,
IBlob** outDiagnostics = nullptr) = 0;

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
IMetadata** outMetadata,
IBlob** outDiagnostics = nullptr) = 0;
};
#define SLANG_UUID_IComponentType IComponentType::getTypeGuid()

Expand Down
20 changes: 20 additions & 0 deletions source/compiler-core/slang-artifact-associated-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,24 @@ Slice<String> ArtifactPostEmitMetadata::getExportedFunctionMangledNames()
return Slice<String>(m_exportedFunctionMangledNames.getBuffer(), m_exportedFunctionMangledNames.getCount());
}

SlangResult ArtifactPostEmitMetadata::isParameterLocationUsed(
SlangParameterCategory category,
SlangUInt spaceIndex,
SlangUInt registerIndex,
bool& outUsed)
{
for (const auto& range : getUsedBindingRanges())
{
if (range.containsBinding((slang::ParameterCategory)category, spaceIndex, registerIndex))
{
outUsed = true;
return SLANG_OK;
}
}

outUsed = false;
return SLANG_OK;
}


} // namespace Slang
8 changes: 8 additions & 0 deletions source/compiler-core/slang-artifact-associated-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ struct ShaderBindingRange
case slang::ShaderResource:
case slang::UnorderedAccess:
case slang::SamplerState:
case slang::DescriptorTableSlot:
return true;
default:
return false;
Expand All @@ -157,6 +158,13 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM
SLANG_NO_THROW virtual Slice<ShaderBindingRange> SLANG_MCALL getUsedBindingRanges() SLANG_OVERRIDE;
SLANG_NO_THROW virtual Slice<String> SLANG_MCALL getExportedFunctionMangledNames() SLANG_OVERRIDE;

// IMetadata
SLANG_NO_THROW virtual SlangResult SLANG_MCALL isParameterLocationUsed(
SlangParameterCategory category, // is this a `t` register? `s` register?
SlangUInt spaceIndex, // `space` for D3D12, `set` for Vulkan
SlangUInt registerIndex, // `register` for D3D12, `binding` for Vulkan
bool& outUsed) SLANG_OVERRIDE;

void* getInterface(const Guid& uuid);
void* getObject(const Guid& uuid);

Expand Down
2 changes: 1 addition & 1 deletion source/compiler-core/slang-artifact-associated.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class IArtifactDiagnostics : public IClonable

struct ShaderBindingRange;

class IArtifactPostEmitMetadata : public ICastable
class IArtifactPostEmitMetadata : public slang::IMetadata
{
public:
SLANG_COM_INTERFACE(0x5d03bce9, 0xafb1, 0x4fc8, { 0xa4, 0x6f, 0x3c, 0xe0, 0x7b, 0x6, 0x1b, 0x1b });
Expand Down
19 changes: 19 additions & 0 deletions source/slang-record-replay/record/slang-component-type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,25 @@ namespace SlangRecord
return res;
}

SLANG_NO_THROW SlangResult SLANG_MCALL IComponentTypeRecorder::getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics)
{
// No need to record this call.
return m_actualComponentType->getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL IComponentTypeRecorder::getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics)
{
// No need to record this call.
return m_actualComponentType->getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult IComponentTypeRecorder::getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
9 changes: 9 additions & 0 deletions source/slang-record-replay/record/slang-component-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ namespace SlangRecord
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) override;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics = nullptr) override;
protected:
virtual ApiClassId getClassId() = 0;
virtual SessionRecorder* getSessionRecorder() = 0;
Expand Down
17 changes: 17 additions & 0 deletions source/slang-record-replay/record/slang-entrypoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,23 @@ namespace SlangRecord
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
17 changes: 17 additions & 0 deletions source/slang-record-replay/record/slang-module.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ namespace SlangRecord
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
17 changes: 17 additions & 0 deletions source/slang-record-replay/record/slang-type-conformance.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ namespace SlangRecord
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

virtual SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down
65 changes: 65 additions & 0 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,22 @@ namespace Slang
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE;

IArtifact* getTargetArtifact(SlangInt targetIndex, slang::IBlob** outDiagnostics);

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetCode(
SlangInt targetIndex,
slang::IBlob** outCode,
slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE;
SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE;
SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics = nullptr) SLANG_OVERRIDE;

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
Expand Down Expand Up @@ -580,6 +592,8 @@ namespace Slang

Scope* m_lookupScope = nullptr;
std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal;

Dictionary<Int, ComPtr<IArtifact>> m_targetArtifacts;
};

/// A component type built up from other component types.
Expand Down Expand Up @@ -914,6 +928,23 @@ namespace Slang
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down Expand Up @@ -1159,6 +1190,23 @@ namespace Slang
return Super::getTargetCode(targetIndex, outCode, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getResultAsFileSystem(
SlangInt entryPointIndex,
SlangInt targetIndex,
Expand Down Expand Up @@ -1460,6 +1508,23 @@ namespace Slang
return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getEntryPointMetadata(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getEntryPointMetadata(entryPointIndex, targetIndex, outMetadata, outDiagnostics);
}

SLANG_NO_THROW SlangResult SLANG_MCALL getTargetMetadata(
SlangInt targetIndex,
slang::IMetadata** outMetadata,
slang::IBlob** outDiagnostics) SLANG_OVERRIDE
{
return Super::getTargetMetadata(targetIndex, outMetadata, outDiagnostics);
}

/// Get a serialized representation of the checked module.
virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override;

Expand Down
11 changes: 6 additions & 5 deletions source/slang/slang-parameter-binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1225,18 +1225,14 @@ static void addExplicitParameterBindings_GLSL(
}
}

// We use the HLSL binding directly (even though this notionally for GLSL/Vulkan)
// We'll do the shifting at later later point in _maybeApplyHLSLToVulkanShifts
info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(hlslInfo.kind);

if (warnedMissingVulkanLayoutModifier)
{
// If we warn due to invalid bindings and user did not set how to interpret 'hlsl style bindings', we should map
// `register` 1:1 with equivlent vulkan bindings.
if(!hlslToVulkanLayoutOptions
|| hlslToVulkanLayoutOptions->getKindShiftEnabledFlags() == HLSLToVulkanLayoutOptions::KindFlag::None)
{
info[kResInfo].resInfo->kind = LayoutResourceKind::DescriptorTableSlot;
info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot);
info[kResInfo].resInfo->count = 1;
}
else
Expand All @@ -1245,6 +1241,11 @@ static void addExplicitParameterBindings_GLSL(
}
}

// We use the HLSL binding directly (even though this notionally for GLSL/Vulkan)
// We'll do the shifting at later later point in _maybeApplyHLSLToVulkanShifts
if (!info[kResInfo].resInfo)
info[kResInfo].resInfo = typeLayout->findOrAddResourceInfo(hlslInfo.kind);

info[kResInfo].semanticInfo.kind = info[kResInfo].resInfo->kind;
info[kResInfo].semanticInfo.index = UInt(hlslInfo.index);
info[kResInfo].semanticInfo.space = UInt(hlslInfo.space);
Expand Down
Loading

0 comments on commit 15d1c6c

Please sign in to comment.