Skip to content

Commit

Permalink
[SYCL] adding binary caching support to kernel_compiler extension (#1…
Browse files Browse the repository at this point in the history
…5537)

Trying to reuse as much of the `PersistentDeviceCodeCache` as possible.
We use the same top cache directory as the regular binary caches, but
with a `ext_kernel_compiler` subdirectory and a slightly different
system for assigning paths.

Rather than write new tests, am just adding a "test the cache" passes to
the existing tests.
  • Loading branch information
cperkinsintel authored Nov 1, 2024
1 parent b46b900 commit 5dae72c
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 85 deletions.
170 changes: 120 additions & 50 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <detail/kernel_compiler/kernel_compiler_opencl.hpp>
#include <detail/kernel_compiler/kernel_compiler_sycl.hpp>
#include <detail/kernel_impl.hpp>
#include <detail/persistent_device_code_cache.hpp>
#include <detail/program_manager/program_manager.hpp>
#include <sycl/backend_types.hpp>
#include <sycl/context.hpp>
Expand Down Expand Up @@ -396,6 +397,53 @@ class kernel_bundle_impl {
return SS.str();
}

bool
extKernelCompilerFetchFromCache(const std::vector<device> Devices,
const std::vector<std::string> &BuildOptions,
const std::string &SourceStr,
ur_program_handle_t &UrProgram) {
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();

std::string UserArgs = syclex::detail::userArgsAsString(BuildOptions);

std::vector<ur_device_handle_t> DeviceHandles;
std::transform(
Devices.begin(), Devices.end(), std::back_inserter(DeviceHandles),
[](const device &Dev) { return getSyclObjImpl(Dev)->getHandleRef(); });

std::vector<const uint8_t *> Binaries;
std::vector<size_t> Lengths;
std::vector<std::vector<std::vector<char>>> PersistentBinaries;
for (size_t i = 0; i < Devices.size(); i++) {
std::vector<std::vector<char>> BinProg =
PersistentDeviceCodeCache::getCompiledKernelFromDisc(
Devices[i], UserArgs, SourceStr);

// exit if any device binary is missing
if (BinProg.empty()) {
return false;
}
PersistentBinaries.push_back(BinProg);

Binaries.push_back((uint8_t *)(PersistentBinaries[i][0].data()));
Lengths.push_back(PersistentBinaries[i][0].size());
}

ur_program_properties_t Properties = {};
Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_PROPERTIES;
Properties.pNext = nullptr;
Properties.count = 0;
Properties.pMetadatas = nullptr;

Adapter->call<UrApiKind::urProgramCreateWithBinary>(
ContextImpl->getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
Lengths.data(), Binaries.data(), &Properties, &UrProgram);

return true;
}

std::shared_ptr<kernel_bundle_impl>
build_from_source(const std::vector<device> Devices,
const std::vector<std::string> &BuildOptions,
Expand All @@ -415,57 +463,68 @@ class kernel_bundle_impl {
DeviceVec.push_back(Dev);
}

const auto spirv = [&]() -> std::vector<uint8_t> {
if (Language == syclex::source_language::opencl) {
// if successful, the log is empty. if failed, throws an error with the
// compilation log.
const auto &SourceStr = std::get<std::string>(this->Source);
std::vector<uint32_t> IPVersionVec(Devices.size());
std::transform(DeviceVec.begin(), DeviceVec.end(), IPVersionVec.begin(),
[&](ur_device_handle_t d) {
uint32_t ipVersion = 0;
Adapter->call<UrApiKind::urDeviceGetInfo>(
d, UR_DEVICE_INFO_IP_VERSION, sizeof(uint32_t),
&ipVersion, nullptr);
return ipVersion;
});
return syclex::detail::OpenCLC_to_SPIRV(SourceStr, IPVersionVec,
BuildOptions, LogPtr);
}
if (Language == syclex::source_language::spirv) {
const auto &SourceBytes =
std::get<std::vector<std::byte>>(this->Source);
std::vector<uint8_t> Result(SourceBytes.size());
std::transform(SourceBytes.cbegin(), SourceBytes.cend(), Result.begin(),
[](std::byte B) { return static_cast<uint8_t>(B); });
return Result;
}
if (Language == syclex::source_language::sycl) {
const auto &SourceStr = std::get<std::string>(this->Source);
return syclex::detail::SYCL_to_SPIRV(SourceStr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
if (Language == syclex::source_language::sycl_jit) {
const auto &SourceStr = std::get<std::string>(this->Source);
return syclex::detail::SYCL_JIT_to_SPIRV(SourceStr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
throw sycl::exception(
make_error_code(errc::invalid),
"OpenCL C and SPIR-V are the only supported languages at this time");
}();

ur_program_handle_t UrProgram = nullptr;
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl->getHandleRef(),
spirv.data(), spirv.size(),
nullptr, &UrProgram);
// program created by urProgramCreateWithIL is implicitly retained.
if (UrProgram == nullptr)
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"urProgramCreateWithIL resulted in a null program handle.");
// SourceStrPtr will be null when source is Spir-V bytes.
const std::string *SourceStrPtr = std::get_if<std::string>(&this->Source);
bool FetchedFromCache = false;
if (PersistentDeviceCodeCache::isEnabled() && SourceStrPtr) {
FetchedFromCache = extKernelCompilerFetchFromCache(
Devices, BuildOptions, *SourceStrPtr, UrProgram);
}

if (!FetchedFromCache) {
const auto spirv = [&]() -> std::vector<uint8_t> {
if (Language == syclex::source_language::opencl) {
// if successful, the log is empty. if failed, throws an error with
// the compilation log.
std::vector<uint32_t> IPVersionVec(Devices.size());
std::transform(DeviceVec.begin(), DeviceVec.end(),
IPVersionVec.begin(), [&](ur_device_handle_t d) {
uint32_t ipVersion = 0;
Adapter->call<UrApiKind::urDeviceGetInfo>(
d, UR_DEVICE_INFO_IP_VERSION, sizeof(uint32_t),
&ipVersion, nullptr);
return ipVersion;
});
return syclex::detail::OpenCLC_to_SPIRV(*SourceStrPtr, IPVersionVec,
BuildOptions, LogPtr);
}
if (Language == syclex::source_language::spirv) {
const auto &SourceBytes =
std::get<std::vector<std::byte>>(this->Source);
std::vector<uint8_t> Result(SourceBytes.size());
std::transform(SourceBytes.cbegin(), SourceBytes.cend(),
Result.begin(),
[](std::byte B) { return static_cast<uint8_t>(B); });
return Result;
}
if (Language == syclex::source_language::sycl) {
return syclex::detail::SYCL_to_SPIRV(*SourceStrPtr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
if (Language == syclex::source_language::sycl_jit) {
const auto &SourceStr = std::get<std::string>(this->Source);
return syclex::detail::SYCL_JIT_to_SPIRV(SourceStr, IncludePairs,
BuildOptions, LogPtr,
RegisteredKernelNames);
}
throw sycl::exception(
make_error_code(errc::invalid),
"SYCL C++, OpenCL C and SPIR-V are the only supported "
"languages at this time");
}();

Adapter->call<UrApiKind::urProgramCreateWithIL>(
ContextImpl->getHandleRef(), spirv.data(), spirv.size(), nullptr,
&UrProgram);
// program created by urProgramCreateWithIL is implicitly retained.
if (UrProgram == nullptr)
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"urProgramCreateWithIL resulted in a null program handle.");

} // if(!FetchedFromCache)

std::string XsFlags = extractXsFlags(BuildOptions);
auto Res = Adapter->call_nocheck<UrApiKind::urProgramBuildExp>(
Expand Down Expand Up @@ -501,6 +560,17 @@ class kernel_bundle_impl {
nullptr, MContext, MDevices, bundle_state::executable, KernelIDs,
UrProgram);
device_image_plain DevImg{DevImgImpl};

// If caching enabled and kernel not fetched from cache, cache.
if (PersistentDeviceCodeCache::isEnabled() && !FetchedFromCache &&
SourceStrPtr) {
for (const auto &Device : Devices) {
PersistentDeviceCodeCache::putCompiledKernelToDisc(
Device, syclex::detail::userArgsAsString(BuildOptions),
*SourceStrPtr, UrProgram);
}
}

return std::make_shared<kernel_bundle_impl>(MContext, MDevices, DevImg,
KernelNames, Language);
}
Expand Down
9 changes: 9 additions & 0 deletions sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ SYCL_to_SPIRV(const std::string &SYCLSource, include_pairs_t IncludePairs,
throw sycl::exception(sycl::errc::build,
"kernel_compiler does not support GCC<8");
}

std::string userArgsAsString(const std::vector<std::string> &UserArguments) {
return std::accumulate(UserArguments.begin(), UserArguments.end(),
std::string(""),
[](const std::string &A, const std::string &B) {
return A.empty() ? B : A + " " + B;
});
}

} // namespace detail
} // namespace ext::oneapi::experimental
} // namespace _V1
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ SYCL_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,

bool SYCL_Compilation_Available();

std::string userArgsAsString(const std::vector<std::string> &UserArguments);

spirv_vec_t
SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr,
Expand Down
Loading

0 comments on commit 5dae72c

Please sign in to comment.