Skip to content

Commit

Permalink
BINDINGS/GO/PERF: Implement flag.Value for UcsMemoryType - 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Artemy-Mellanox committed Nov 22, 2024
1 parent 381317d commit e7281f4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
14 changes: 0 additions & 14 deletions bindings/go/src/ucx/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ package ucx
import "C"
import "unsafe"

import (
"errors"
"strings"
)

// Memory handle is an opaque object representing a memory region allocated
// through UCP library, which is optimized for remote memory access
// operations (zero-copy operations). The memory could be registered
Expand Down Expand Up @@ -72,12 +67,3 @@ func (m *UcpMemory) Close() error {

return nil
}

func (mt *UcsMemoryType) Set (value string) error {
switch strings.ToLower(value) {
case "host": *mt = UCS_MEMORY_TYPE_HOST
case "cuda": *mt = UCS_MEMORY_TYPE_CUDA
default: return errors.New("memory type can be host or cuda")
}
return nil
}
18 changes: 15 additions & 3 deletions bindings/go/src/ucx/ucs_constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

package ucx

//#include "ucs_constants.h"
//#include <ucp/api/ucp.h>
//static inline const char* ucxgo_get_ucs_mem_type_name(ucs_memory_type_t idx) {
// return ucs_memory_type_names[idx];
//}
import "C"
import (
"errors"
"unsafe"
)

type UcsThreadMode int

Expand All @@ -34,6 +36,16 @@ func (m UcsMemoryType) String() string {
return C.GoString(C.ucxgo_get_ucs_mem_type_name(C.ucs_memory_type_t(m)))
}

func (m *UcsMemoryType) Set (value string) error {
cValue := unsafe.Pointer(C.CString(value))
res := C.ucxgo_parse_ucs_mem_type_name(cValue)
if res == -1 {
return errors.New("memory type can be either host or cuda")
}
*m = UcsMemoryType(res)
return nil
}

// Checks whether context's memory type mask
// (received via UcpContext.MemoryTypesMask()) supports particular memory type.
func IsMemTypeSupported(memType UcsMemoryType, mask uint64) bool {
Expand Down
22 changes: 22 additions & 0 deletions bindings/go/src/ucx/ucs_constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (C) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
#ifndef GO_UCS_CONSTANTS_H_
#define GO_UCS_CONSTANTS_H_

#include <ucs/memory/memory_type.h>
#include <ucs/sys/string.h>
#include <stdlib.h>

static inline const char* ucxgo_get_ucs_mem_type_name(ucs_memory_type_t idx) {
return ucs_memory_type_names[idx];
}

static inline ssize_t ucxgo_parse_ucs_mem_type_name(void* value) {
ssize_t idx = ucs_string_find_in_list((const char *)value, ucs_memory_type_names, 0);
free(value);
return idx;
}

#endif

0 comments on commit e7281f4

Please sign in to comment.