From e7281f489947ec3b704eb25224b7b49d5e639a9d Mon Sep 17 00:00:00 2001 From: Artemy Kovalyov Date: Fri, 22 Nov 2024 16:24:40 +0200 Subject: [PATCH] BINDINGS/GO/PERF: Implement flag.Value for UcsMemoryType - 2 --- bindings/go/src/ucx/memory.go | 14 -------------- bindings/go/src/ucx/ucs_constants.go | 18 +++++++++++++++--- bindings/go/src/ucx/ucs_constants.h | 22 ++++++++++++++++++++++ 3 files changed, 37 insertions(+), 17 deletions(-) create mode 100644 bindings/go/src/ucx/ucs_constants.h diff --git a/bindings/go/src/ucx/memory.go b/bindings/go/src/ucx/memory.go index a67b2ace6fb1..5d3086db733f 100644 --- a/bindings/go/src/ucx/memory.go +++ b/bindings/go/src/ucx/memory.go @@ -9,11 +9,6 @@ package ucx import "C" import "unsafe" -import ( - "errors" - "strings" -) - // Memory handle is an opaque object representing a memory region allocated // through UCP library, which is optimized for remote memory access // operations (zero-copy operations). The memory could be registered @@ -72,12 +67,3 @@ func (m *UcpMemory) Close() error { return nil } - -func (mt *UcsMemoryType) Set (value string) error { - switch strings.ToLower(value) { - case "host": *mt = UCS_MEMORY_TYPE_HOST - case "cuda": *mt = UCS_MEMORY_TYPE_CUDA - default: return errors.New("memory type can be host or cuda") - } - return nil -} diff --git a/bindings/go/src/ucx/ucs_constants.go b/bindings/go/src/ucx/ucs_constants.go index eab44f336d60..205e1802eb23 100644 --- a/bindings/go/src/ucx/ucs_constants.go +++ b/bindings/go/src/ucx/ucs_constants.go @@ -5,11 +5,13 @@ package ucx +//#include "ucs_constants.h" //#include -//static inline const char* ucxgo_get_ucs_mem_type_name(ucs_memory_type_t idx) { -// return ucs_memory_type_names[idx]; -//} import "C" +import ( + "errors" + "unsafe" +) type UcsThreadMode int @@ -34,6 +36,16 @@ func (m UcsMemoryType) String() string { return C.GoString(C.ucxgo_get_ucs_mem_type_name(C.ucs_memory_type_t(m))) } +func (m *UcsMemoryType) Set (value string) error { + cValue := unsafe.Pointer(C.CString(value)) + res := C.ucxgo_parse_ucs_mem_type_name(cValue) + if res == -1 { + return errors.New("memory type can be either host or cuda") + } + *m = UcsMemoryType(res) + return nil +} + // Checks whether context's memory type mask // (received via UcpContext.MemoryTypesMask()) supports particular memory type. func IsMemTypeSupported(memType UcsMemoryType, mask uint64) bool { diff --git a/bindings/go/src/ucx/ucs_constants.h b/bindings/go/src/ucx/ucs_constants.h new file mode 100644 index 000000000000..6c2f1400aa2d --- /dev/null +++ b/bindings/go/src/ucx/ucs_constants.h @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ +#ifndef GO_UCS_CONSTANTS_H_ +#define GO_UCS_CONSTANTS_H_ + +#include +#include +#include + +static inline const char* ucxgo_get_ucs_mem_type_name(ucs_memory_type_t idx) { + return ucs_memory_type_names[idx]; +} + +static inline ssize_t ucxgo_parse_ucs_mem_type_name(void* value) { + ssize_t idx = ucs_string_find_in_list((const char *)value, ucs_memory_type_names, 0); + free(value); + return idx; +} + +#endif