diff --git a/include/mscclpp/proxy_channel_device.hpp b/include/mscclpp/proxy_channel_device.hpp index c4cbc6ec2..bebc9f567 100644 --- a/include/mscclpp/proxy_channel_device.hpp +++ b/include/mscclpp/proxy_channel_device.hpp @@ -64,12 +64,21 @@ union ChannelTrigger { /// @param semaphoreId The ID of the semaphore. MSCCLPP_DEVICE_INLINE ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t bytes, int semaphoreId) { - value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + bytes); - value.snd = ((((((((semaphoreId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) + constexpr uint64_t maskSize = (1ULL << MSCCLPP_BITS_SIZE) - 1; + constexpr uint64_t maskSrcOffset = (1ULL << MSCCLPP_BITS_OFFSET) - 1; + constexpr uint64_t maskDstOffset = (1ULL << MSCCLPP_BITS_OFFSET) - 1; + constexpr uint64_t maskSrcMemoryId = (1ULL << MSCCLPP_BITS_REGMEM_HANDLE) - 1; + constexpr uint64_t maskDstMemoryId = (1ULL << MSCCLPP_BITS_REGMEM_HANDLE) - 1; + constexpr uint64_t maskType = (1ULL << MSCCLPP_BITS_TYPE) - 1; + constexpr uint64_t maskChanId = (1ULL << MSCCLPP_BITS_CONNID) - 1; + value.fst = (((srcOffset & maskSrcOffset) << MSCCLPP_BITS_SIZE) + (bytes & maskSize)); + value.snd = (((((((((semaphoreId & maskChanId) << MSCCLPP_BITS_TYPE) + ((uint64_t)type & maskType)) + << MSCCLPP_BITS_REGMEM_HANDLE) + + (dst & maskDstMemoryId)) << MSCCLPP_BITS_REGMEM_HANDLE) + - src) + (src & maskSrcMemoryId)) << MSCCLPP_BITS_OFFSET) + - dstOffset); + (dstOffset & maskDstOffset)); } #endif // defined(MSCCLPP_DEVICE_COMPILE) };