From 0a4bbb30cf673c3ab6d5a46dbb780782fc242c59 Mon Sep 17 00:00:00 2001 From: Ali Ince Date: Thu, 25 Oct 2018 22:25:39 +0100 Subject: [PATCH] Apply changes related to seabolt's encapsulation --- cmd/go-bolt/main.go | 60 ++++--------------------------- connection.go | 12 ++++--- connector.go | 88 +++++++++++++++++++++------------------------ error.go | 12 ++++--- logging.go | 20 +++-------- resolver.go | 8 ++--- value.go | 34 +++++++++--------- value_handler.go | 6 ++-- 8 files changed, 91 insertions(+), 149 deletions(-) diff --git a/cmd/go-bolt/main.go b/cmd/go-bolt/main.go index 1007911..be43bfc 100644 --- a/cmd/go-bolt/main.go +++ b/cmd/go-bolt/main.go @@ -23,11 +23,9 @@ import ( "errors" "flag" "fmt" - "os" - "time" - "github.com/neo4j-drivers/gobolt" "net/url" + "os" "strings" ) @@ -39,7 +37,6 @@ var ( query string mode string debug bool - stats bool ) func executeQuery() { @@ -50,7 +47,6 @@ func executeQuery() { logger := simpleLogger(logLevelDebug, os.Stderr) - start := time.Now() connector, err := gobolt.NewConnector(parsedURI, map[string]interface{}{ "scheme": "basic", "principal": username, @@ -60,68 +56,38 @@ func executeQuery() { panic(err) } defer connector.Close() - elapsed := time.Since(start) - if stats { - logger.Infof("NewConnector took %s", elapsed) - } accessMode := gobolt.AccessModeWrite if strings.ToLower(mode) == "read" { accessMode = gobolt.AccessModeRead } - start = time.Now() conn, err := connector.Acquire(accessMode) if err != nil { panic(err) } defer conn.Close() - elapsed = time.Since(start) - if stats { - logger.Infof("Acquire took %s", elapsed) - } - start = time.Now() runMsg, err := conn.Run(query, nil, nil, 0, nil) if err != nil { panic(err) } - elapsed = time.Since(start) - if stats { - logger.Infof("Run took %s", elapsed) - } - start = time.Now() pullAllMsg, err := conn.PullAll() if err != nil { panic(err) } - elapsed = time.Since(start) - if stats { - logger.Infof("PullAll took %s", elapsed) - } - start = time.Now() err = conn.Flush() if err != nil { panic(err) } - elapsed = time.Since(start) - if stats { - logger.Infof("Flush took %s", elapsed) - } - start = time.Now() records, err := conn.FetchSummary(runMsg) if records != 0 { panic(errors.New("unexpected summary fetch return")) } - elapsed = time.Since(start) - if stats { - logger.Infof("FetchSummary took %s", elapsed) - } - start = time.Now() fields, err := conn.Fields() if err != nil { panic(err) @@ -135,12 +101,7 @@ func executeQuery() { fmt.Print(fields[i]) } fmt.Println() - elapsed = time.Since(start) - if stats { - logger.Infof("Summary processing took %s", elapsed) - } - start = time.Now() for { fetch, err := conn.Fetch(pullAllMsg) if err != nil { @@ -165,25 +126,19 @@ func executeQuery() { fmt.Println() } - elapsed = time.Since(start) - if stats { - logger.Infof("Result processing took %s", elapsed) - } } func main() { flag.Parse() executeQuery() - if stats { - current, peak, events := gobolt.GetAllocationStats() + current, peak, events := gobolt.GetAllocationStats() - fmt.Fprintf(os.Stderr, "=====================================\n") - fmt.Fprintf(os.Stderr, "current allocation : %d bytes\n", current) - fmt.Fprintf(os.Stderr, "peak allocation : %d bytes\n", peak) - fmt.Fprintf(os.Stderr, "allocation events : %d\n", events) - fmt.Fprintf(os.Stderr, "=====================================\n") - } + fmt.Fprintf(os.Stderr, "=====================================\n") + fmt.Fprintf(os.Stderr, "current allocation : %d bytes\n", current) + fmt.Fprintf(os.Stderr, "peak allocation : %d bytes\n", peak) + fmt.Fprintf(os.Stderr, "allocation events : %d\n", events) + fmt.Fprintf(os.Stderr, "=====================================\n") } func init() { @@ -195,5 +150,4 @@ func init() { flag.StringVar(&query, "query", "UNWIND RANGE(1,1000) AS N RETURN N", "cypher query to run") flag.StringVar(&mode, "mode", "write", "access mode for routing mode (read or write)") flag.BoolVar(&debug, "debug", true, "whether to use debug logging") - flag.BoolVar(&stats, "stats", true, "whether to dump allocation stats on exit") } diff --git a/connection.go b/connection.go index 0b48f8e..706da5b 100644 --- a/connection.go +++ b/connection.go @@ -69,12 +69,12 @@ func (connection *neo4jConnection) Id() string { } func (connection *neo4jConnection) RemoteAddress() string { - connectedAddress := connection.cInstance.resolved_address + connectedAddress := C.BoltConnection_remote_endpoint(connection.cInstance) if connectedAddress == nil { return "UNKNOWN" } - return fmt.Sprintf("%s:%s", C.GoString(connectedAddress.host), C.GoString(connectedAddress.port)) + return fmt.Sprintf("%s:%s", C.GoString(C.BoltAddress_host(connectedAddress)), C.GoString(C.BoltAddress_port(connectedAddress))) } func (connection *neo4jConnection) Server() string { @@ -228,7 +228,9 @@ func (connection *neo4jConnection) DiscardAll() (RequestHandle, error) { } func (connection *neo4jConnection) assertReadyState() error { - if connection.cInstance.status != C.BOLT_READY { + cStatus := C.BoltConnection_status(connection.cInstance) + + if C.BoltStatus_get_state(cStatus) != C.BOLT_CONNECTION_STATE_READY { return newError(connection, "unexpected connection state") } @@ -245,7 +247,7 @@ func (connection *neo4jConnection) Flush() error { } func (connection *neo4jConnection) Fetch(request RequestHandle) (FetchType, error) { - res := C.BoltConnection_fetch(connection.cInstance, C.bolt_request(request)) + res := C.BoltConnection_fetch(connection.cInstance, C.BoltRequest(request)) if err := connection.assertReadyState(); err != nil { return FetchTypeError, err @@ -255,7 +257,7 @@ func (connection *neo4jConnection) Fetch(request RequestHandle) (FetchType, erro } func (connection *neo4jConnection) FetchSummary(request RequestHandle) (int, error) { - res := C.BoltConnection_fetch_summary(connection.cInstance, C.bolt_request(request)) + res := C.BoltConnection_fetch_summary(connection.cInstance, C.BoltRequest(request)) if res < 0 { return -1, newError(connection, "unable to fetch summary") } diff --git a/connector.go b/connector.go index a4e8592..4903b41 100644 --- a/connector.go +++ b/connector.go @@ -80,8 +80,8 @@ type neo4jConnector struct { authToken map[string]interface{} config Config - cAddress *C.struct_BoltAddress - cInstance *C.struct_BoltConnector + cAddress *C.BoltAddress + cInstance *C.BoltConnector cLogger *C.struct_BoltLog cResolver *C.struct_BoltAddressResolver @@ -122,7 +122,7 @@ func (conn *neo4jConnector) Acquire(mode AccessMode) (Connection, error) { cMode = C.BOLT_ACCESS_MODE_READ } - cResult := C.BoltConnector_acquire(conn.cInstance, cMode) + cResult := C.BoltConnector_acquire(conn.cInstance, C.BoltAccessMode(cMode)) if cResult.connection == nil { codeText := C.GoString(C.BoltError_get_string(cResult.connection_error)) context := C.GoString(cResult.connection_error_ctx) @@ -140,9 +140,9 @@ func (conn *neo4jConnector) release(connection *neo4jConnection) error { // GetAllocationStats returns statistics about seabolt (C) allocations func GetAllocationStats() (int64, int64, int64) { - current := C.BoltMem_current_allocation() - peak := C.BoltMem_peak_allocation() - events := C.BoltMem_allocation_events() + current := C.BoltStat_memory_allocation_current() + peak := C.BoltStat_memory_allocation_peak() + events := C.BoltStat_memory_allocation_events() return int64(current), int64(peak), int64(events) } @@ -160,11 +160,10 @@ func NewConnector(uri *url.URL, authToken map[string]interface{}, config *Config } } - cTrust := (*C.struct_BoltTrust)(C.malloc(C.sizeof_struct_BoltTrust)) - cTrust.certs = nil - cTrust.certs_len = 0 - cTrust.skip_verify = 0 - cTrust.skip_verify_hostname = 0 + cTrust := C.BoltTrust_create() + C.BoltTrust_set_certs(cTrust, nil, 0) + C.BoltTrust_set_skip_verify(cTrust, 0) + C.BoltTrust_set_skip_verify_hostname(cTrust, 0) certsBuf, err := pemEncodeCerts(config.TLSCertificates) if err != nil { @@ -173,38 +172,34 @@ func NewConnector(uri *url.URL, authToken map[string]interface{}, config *Config if certsBuf != nil { certsBytes := certsBuf.String() - cTrust.certs = C.CString(certsBytes) - cTrust.certs_len = C.int32_t(certsBuf.Len()) + C.BoltTrust_set_certs(cTrust, C.CString(certsBytes), C.int(certsBuf.Len())) } if config.TLSSkipVerify { - cTrust.skip_verify = 1 + C.BoltTrust_set_skip_verify(cTrust, 1) } if config.TLSSkipVerifyHostname { - cTrust.skip_verify_hostname = 1 + C.BoltTrust_set_skip_verify_hostname(cTrust, 1) } - cSocketOpts := (*C.struct_BoltSocketOptions)(C.malloc(C.sizeof_struct_BoltSocketOptions)) - cSocketOpts.connect_timeout = C.int(config.SockConnectTimeout / time.Millisecond) - cSocketOpts.recv_timeout = C.int(config.SockRecvTimeout / time.Millisecond) - cSocketOpts.send_timeout = C.int(config.SockSendTimeout / time.Millisecond) - cSocketOpts.keepalive = 0 - - if config.SockKeepalive { - cSocketOpts.keepalive = 1 + cSocketOpts := C.BoltSocketOptions_create() + C.BoltSocketOptions_set_connect_timeout(cSocketOpts, C.int(config.SockConnectTimeout/time.Millisecond)) + C.BoltSocketOptions_set_keep_alive(cSocketOpts, 1) + if !config.SockKeepalive { + C.BoltSocketOptions_set_keep_alive(cSocketOpts, 0) } valueSystem := createValueSystem(config) - var mode uint32 = C.BOLT_DIRECT + var mode uint32 = C.BOLT_MODE_DIRECT if uri.Scheme == "bolt+routing" { - mode = C.BOLT_ROUTING + mode = C.BOLT_MODE_ROUTING } - var transport uint32 = C.BOLT_SOCKET + var transport uint32 = C.BOLT_TRANSPORT_PLAINTEXT if config.Encryption { - transport = C.BOLT_SECURE_SOCKET + transport = C.BOLT_TRANSPORT_ENCRYPTED } userAgentStr := C.CString("Go Driver/1.7") @@ -217,21 +212,21 @@ func NewConnector(uri *url.URL, authToken map[string]interface{}, config *Config cLogger := registerLogging(key, config.Log) cResolver := registerResolver(key, config.AddressResolver) - cConfig := C.struct_BoltConfig{ - mode: mode, - transport: transport, - trust: cTrust, - user_agent: userAgentStr, - routing_context: routingContextValue, - address_resolver: cResolver, - log: cLogger, - max_pool_size: C.int(config.MaxPoolSize), - max_connection_lifetime: C.int(config.MaxConnLifetime / time.Millisecond), - max_connection_acquire_time: C.int(config.ConnAcquisitionTimeout / time.Millisecond), - sock_opts: cSocketOpts, - } - cInstance := C.BoltConnector_create(address, authTokenValue, &cConfig) + cConfig := C.BoltConfig_create() + C.BoltConfig_set_mode(cConfig, C.BoltMode(mode)) + C.BoltConfig_set_transport(cConfig, C.BoltTransport(transport)) + C.BoltConfig_set_trust(cConfig, cTrust) + C.BoltConfig_set_user_agent(cConfig, userAgentStr) + C.BoltConfig_set_routing_context(cConfig, routingContextValue) + C.BoltConfig_set_address_resolver(cConfig, cResolver) + C.BoltConfig_set_log(cConfig, cLogger) + C.BoltConfig_set_max_pool_size(cConfig, C.int(config.MaxPoolSize)) + C.BoltConfig_set_max_connection_life_time(cConfig, C.int(config.MaxConnLifetime/time.Millisecond)) + C.BoltConfig_set_max_connection_acquisition_time(cConfig, C.int(config.ConnAcquisitionTimeout/time.Millisecond)) + C.BoltConfig_set_socket_options(cConfig, cSocketOpts) + + cInstance := C.BoltConnector_create(address, authTokenValue, cConfig) conn := &neo4jConnector{ key: key, uri: uri, @@ -249,18 +244,15 @@ func NewConnector(uri *url.URL, authToken map[string]interface{}, config *Config C.free(unsafe.Pointer(portStr)) C.BoltValue_destroy(routingContextValue) C.BoltValue_destroy(authTokenValue) - - if cTrust.certs != nil { - C.free(unsafe.Pointer(cTrust.certs)) - } - C.free(unsafe.Pointer(cTrust)) - C.free(unsafe.Pointer(cSocketOpts)) + C.BoltTrust_destroy(cTrust) + C.BoltSocketOptions_destroy(cSocketOpts) + C.BoltConfig_destroy(cConfig) return conn, nil } func createValueSystem(config *Config) *boltValueSystem { - valueHandlersBySignature := make(map[int8]ValueHandler, len(config.ValueHandlers)) + valueHandlersBySignature := make(map[int16]ValueHandler, len(config.ValueHandlers)) valueHandlersByType := make(map[reflect.Type]ValueHandler, len(config.ValueHandlers)) for _, handler := range config.ValueHandlers { for _, readSignature := range handler.ReadableStructs() { diff --git a/error.go b/error.go index a67f0a5..6a5f420 100644 --- a/error.go +++ b/error.go @@ -146,7 +146,10 @@ func (failure *defaultGenericError) Error() string { } func newError(connection *neo4jConnection, description string) error { - if connection.cInstance.error == C.BOLT_SERVER_FAILURE { + cStatus := C.BoltConnection_status(connection.cInstance) + errorCode := C.BoltStatus_get_error(cStatus) + + if errorCode == C.BOLT_SERVER_FAILURE { failure, err := connection.valueSystem.valueAsDictionary(C.BoltConnection_failure(connection.cInstance)) if err != nil { return connection.valueSystem.genericErrorFactory("unable to construct database error: %s", err.Error()) @@ -178,10 +181,11 @@ func newError(connection *neo4jConnection, description string) error { return connection.valueSystem.databaseErrorFactory(classification, code, message) } - codeText := C.GoString(C.BoltError_get_string(connection.cInstance.error)) - context := C.GoString(connection.cInstance.error_ctx) + state := C.BoltStatus_get_state(cStatus) + errorText := C.GoString(C.BoltError_get_string(errorCode)) + context := C.GoString(C.BoltStatus_get_error_context(cStatus)) - return connection.valueSystem.connectorErrorFactory(int(connection.cInstance.status), int(connection.cInstance.error), codeText, context, description) + return connection.valueSystem.connectorErrorFactory(int(state), int(errorCode), errorText, context, description) } func newGenericError(format string, args ...interface{}) GenericError { diff --git a/logging.go b/logging.go index 39dc424..5ef82bc 100644 --- a/logging.go +++ b/logging.go @@ -85,32 +85,22 @@ func registerLogging(key int, logging Logging) *C.struct_BoltLog { mapLogging.Store(key, logging) - boltLog := C.BoltLog_create() - boltLog.state = C.int(key) - - boltLog.error_enabled = 0 + boltLog := C.BoltLog_create(C.int(key)) if logging != nil && logging.ErrorEnabled() { - boltLog.error_enabled = 1 + C.BoltLog_set_error_func(boltLog, C.log_func(C.go_seabolt_log_error_cb)) } - boltLog.error_logger = C.log_func(C.go_seabolt_log_error_cb) - boltLog.warning_enabled = 0 if logging != nil && logging.WarningEnabled() { - boltLog.warning_enabled = 1 + C.BoltLog_set_warning_func(boltLog, C.log_func(C.go_seabolt_log_warning_cb)) } - boltLog.warning_logger = C.log_func(C.go_seabolt_log_warning_cb) - boltLog.info_enabled = 0 if logging != nil && logging.InfoEnabled() { - boltLog.info_enabled = 1 + C.BoltLog_set_info_func(boltLog, C.log_func(C.go_seabolt_log_info_cb)) } - boltLog.info_logger = C.log_func(C.go_seabolt_log_info_cb) - boltLog.debug_enabled = 0 if logging != nil && logging.DebugEnabled() { - boltLog.debug_enabled = 1 + C.BoltLog_set_debug_func(boltLog, C.log_func(C.go_seabolt_log_debug_cb)) } - boltLog.debug_logger = C.log_func(C.go_seabolt_log_debug_cb) return boltLog } diff --git a/resolver.go b/resolver.go index a92c77f..ad8fd70 100644 --- a/resolver.go +++ b/resolver.go @@ -41,14 +41,14 @@ type URLAddressResolver func(address *url.URL) []*url.URL func go_seabolt_server_address_resolver_cb(state C.int, address *C.struct_BoltAddress, resolved *C.struct_BoltAddressSet) { resolver := lookupResolver(state) if resolver != nil { - resolvedAddresses := resolver(&url.URL{Host: fmt.Sprintf("%s:%s", C.GoString(address.host), C.GoString(address.port))}) + resolvedAddresses := resolver(&url.URL{Host: fmt.Sprintf("%s:%s", C.GoString(C.BoltAddress_host(address)), C.GoString(C.BoltAddress_port(address)))}) for _, addr := range resolvedAddresses { cHost := C.CString(addr.Hostname()) cPort := C.CString(addr.Port()) cAddress := C.BoltAddress_create(cHost, cPort) - C.BoltAddressSet_add(resolved, *cAddress) + C.BoltAddressSet_add(resolved, cAddress) C.BoltAddress_destroy(cAddress) C.free(unsafe.Pointer(cHost)) @@ -66,9 +66,7 @@ func registerResolver(key int, resolver URLAddressResolver) *C.struct_BoltAddres mapResolver.Store(key, resolver) - boltResolver := C.BoltAddressResolver_create() - boltResolver.state = C.int(key) - boltResolver.resolver = C.address_resolver_func(C.go_seabolt_server_address_resolver_cb) + boltResolver := C.BoltAddressResolver_create(C.int(key), C.address_resolver_func(C.go_seabolt_server_address_resolver_cb)) return boltResolver } diff --git a/value.go b/value.go index 6099673..114df96 100644 --- a/value.go +++ b/value.go @@ -32,7 +32,7 @@ import ( type boltValueSystem struct { valueHandlers []ValueHandler - valueHandlersBySignature map[int8]ValueHandler + valueHandlersBySignature map[int16]ValueHandler valueHandlersByType map[reflect.Type]ValueHandler connectorErrorFactory func(state, code int, codeText, context, description string) ConnectorError databaseErrorFactory func(classification, code, message string) DatabaseError @@ -40,25 +40,27 @@ type boltValueSystem struct { } func (valueSystem *boltValueSystem) valueAsGo(value *C.struct_BoltValue) (interface{}, error) { + valueType := C.BoltValue_type(value) + switch { - case value._type == C.BOLT_NULL: + case valueType == C.BOLT_NULL: return nil, nil - case value._type == C.BOLT_BOOLEAN: + case valueType == C.BOLT_BOOLEAN: return valueSystem.valueAsBoolean(value), nil - case value._type == C.BOLT_INTEGER: + case valueType == C.BOLT_INTEGER: return valueSystem.valueAsInt(value), nil - case value._type == C.BOLT_FLOAT: + case valueType == C.BOLT_FLOAT: return valueSystem.valueAsFloat(value), nil - case value._type == C.BOLT_STRING: + case valueType == C.BOLT_STRING: return valueSystem.valueAsString(value), nil - case value._type == C.BOLT_DICTIONARY: + case valueType == C.BOLT_DICTIONARY: return valueSystem.valueAsDictionary(value) - case value._type == C.BOLT_LIST: + case valueType == C.BOLT_LIST: return valueSystem.valueAsList(value) - case value._type == C.BOLT_BYTES: + case valueType == C.BOLT_BYTES: return valueSystem.valueAsBytes(value), nil - case value._type == C.BOLT_STRUCTURE: - signature := int8(value.subtype) + case valueType == C.BOLT_STRUCTURE: + signature := int16(C.BoltStructure_code(value)) if handler, ok := valueSystem.valueHandlersBySignature[signature]; ok { listValue, err := valueSystem.structAsList(value) @@ -92,11 +94,11 @@ func (valueSystem *boltValueSystem) valueAsFloat(value *C.struct_BoltValue) floa func (valueSystem *boltValueSystem) valueAsString(value *C.struct_BoltValue) string { val := C.BoltString_get(value) - return C.GoStringN(val, C.int(value.size)) + return C.GoStringN(val, C.BoltValue_size(value)) } func (valueSystem *boltValueSystem) valueAsDictionary(value *C.struct_BoltValue) (map[string]interface{}, error) { - size := int(value.size) + size := int(C.BoltValue_size(value)) dict := make(map[string]interface{}, size) for i := 0; i < size; i++ { index := C.int32_t(i) @@ -112,7 +114,7 @@ func (valueSystem *boltValueSystem) valueAsDictionary(value *C.struct_BoltValue) } func (valueSystem *boltValueSystem) valueAsList(value *C.struct_BoltValue) ([]interface{}, error) { - size := int(value.size) + size := int(C.BoltValue_size(value)) list := make([]interface{}, size) for i := 0; i < size; i++ { index := C.int32_t(i) @@ -127,7 +129,7 @@ func (valueSystem *boltValueSystem) valueAsList(value *C.struct_BoltValue) ([]in } func (valueSystem *boltValueSystem) structAsList(value *C.struct_BoltValue) ([]interface{}, error) { - size := int(value.size) + size := int(C.BoltValue_size(value)) list := make([]interface{}, size) for i := 0; i < size; i++ { index := C.int32_t(i) @@ -143,7 +145,7 @@ func (valueSystem *boltValueSystem) structAsList(value *C.struct_BoltValue) ([]i func (valueSystem *boltValueSystem) valueAsBytes(value *C.struct_BoltValue) []byte { val := C.BoltBytes_get_all(value) - return C.GoBytes(unsafe.Pointer(val), C.int(value.size)) + return C.GoBytes(unsafe.Pointer(val), C.BoltValue_size(value)) } func (valueSystem *boltValueSystem) valueToConnector(value interface{}) *C.struct_BoltValue { diff --git a/value_handler.go b/value_handler.go index 34bd280..4cacb9d 100644 --- a/value_handler.go +++ b/value_handler.go @@ -27,10 +27,10 @@ import ( // ValueHandler is the interface that custom value handlers should implement to // support reading/writing struct types into custom types type ValueHandler interface { - ReadableStructs() []int8 + ReadableStructs() []int16 WritableTypes() []reflect.Type - Read(signature int8, values []interface{}) (interface{}, error) - Write(value interface{}) (int8, []interface{}, error) + Read(signature int16, values []interface{}) (interface{}, error) + Write(value interface{}) (int16, []interface{}, error) } // ValueHandlerError is the special error that ValueHandlers should return in