diff --git a/nvtx_plugins/cc/nvtx_kernels.cc b/nvtx_plugins/cc/nvtx_kernels.cc
index 0d75479..cff7fe5 100644
--- a/nvtx_plugins/cc/nvtx_kernels.cc
+++ b/nvtx_plugins/cc/nvtx_kernels.cc
@@ -205,3 +205,29 @@ class NvtxEndOp : public OpKernel {
 
 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
+
+// Registers the NvtxStart and NvtxEnd kernels on DEVICE_CPU for a single
+// value of the "T" type attribute. `type` is supplied once per numeric type
+// by TF_CALL_NUMBER_TYPES below, so each expansion must constrain "T" to that
+// `type`; without the template argument the macro parameter would be unused
+// and every expansion would produce the same registration.
+#define REGISTER_CPU_KERNEL(type)                                  \
+  REGISTER_KERNEL_BUILDER(Name("NvtxStart")                        \
+                              .Device(DEVICE_CPU)                  \
+                              .HostMemory("message")               \
+                              .HostMemory("domain_name")           \
+                              .HostMemory("marker_id")             \
+                              .HostMemory("domain_handle")         \
+                              .TypeConstraint<type>("T"),          \
+                          NvtxStartOp);                            \
+  REGISTER_KERNEL_BUILDER(Name("NvtxEnd")                          \
+                              .Device(DEVICE_CPU)                  \
+                              .HostMemory("marker_id")             \
+                              .HostMemory("domain_handle")         \
+                              .HostMemory("grad_message")          \
+                              .HostMemory("grad_domain_name")      \
+                              .TypeConstraint<type>("T"),          \
+                          NvtxEndOp);
+
+TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL