diff --git a/callbacks.go b/callbacks.go index 6d6b23b..4e622b4 100644 --- a/callbacks.go +++ b/callbacks.go @@ -36,7 +36,7 @@ func errHandler(dbprocAddr C.long, severity, dberr, oserr C.int, dberrstr, oserr conn := getConnection(int64(dbprocAddr)) if conn != nil { - conn.addError(err) + conn.addError(err, int(dberr)) } //fmt.Printf("err: %s", err) diff --git a/conn.go b/conn.go index e8cc296..631594a 100644 --- a/conn.go +++ b/conn.go @@ -94,6 +94,9 @@ type Conn struct { messageNums map[int]int messageMutex sync.RWMutex + errors map[int]string + errorsMutex sync.RWMutex + currentResult *Result expiresFromPool time.Time belongsToPool *ConnPool @@ -120,7 +123,12 @@ func (conn *Conn) addMessage(msg string, msgno int) { conn.messageNums[msgno] = i + 1 } -func (conn *Conn) addError(err string) { +func (conn *Conn) addError(err string, errno int) { + conn.errorsMutex.Lock() + defer conn.errorsMutex.Unlock() + + conn.errors[errno] = err + if len(conn.Error) > 0 { conn.Error += "\n" } @@ -141,6 +149,7 @@ func connectWithCredentials(crd *credentials) (*Conn, error) { spParamsCache: NewParamsCache(), credentials: *crd, messageNums: make(map[int]int), + errors: make(map[int]string), } err := conn.reconnect() if err != nil { @@ -188,6 +197,14 @@ func (conn *Conn) close() { } } +// Remove a pooled connection from it's pool. +func (conn *Conn) RemoveFromPool() *Conn { + if conn.belongsToPool != nil { + conn.belongsToPool.Remove(conn) + } + return conn +} + //ensure only one getDbProc at a time var getDbProcMutex = &sync.Mutex{} @@ -260,11 +277,14 @@ func (conn *Conn) DbUse() error { func (conn *Conn) clearMessages() { conn.messageMutex.Lock() - defer conn.messageMutex.Unlock() - conn.Error = "" + conn.errors = make(map[int]string) + conn.messageMutex.Unlock() + + conn.errorsMutex.Lock() conn.Message = "" conn.messageNums = make(map[int]int) + conn.errorsMutex.Unlock() } //Returns the number of occurances of a supplied FreeTDS message number. @@ -276,6 +296,15 @@ func (conn *Conn) HasMessageNumber(msgno int) int { return count } +//Returns the error string for a supplied FreeTDS error number. +//if the error has not occurred then an empty string and false is returned. +func (conn *Conn) HasErrorNumber(errno int) (string, bool) { + conn.errorsMutex.RLock() + err, found := conn.errors[errno] + conn.errorsMutex.RUnlock() + return err, found +} + //Execute sql query. func (conn *Conn) Exec(sql string) ([]*Result, error) { results, err := conn.exec(sql) diff --git a/conn_pool.go b/conn_pool.go index a48bd76..cfd8064 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -167,6 +167,17 @@ func (p *ConnPool) addToPool(conn *Conn) { } } +// Remove a pooled connection from the pool, +// forcing a new connection to take it's place. +func (p *ConnPool) Remove(conn *Conn) { + if conn.belongsToPool != p { + return + } + p.connCount-- + <-p.poolGuard //remove reservation + conn.belongsToPool = nil +} + //Release connection to the pool. func (p *ConnPool) Release(conn *Conn) { if conn.belongsToPool != p { diff --git a/conn_pool_test.go b/conn_pool_test.go index eeda49b..2025255 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -3,6 +3,7 @@ package freetds import ( "fmt" "github.com/stretchr/testify/assert" + "sync" "testing" "time" ) @@ -168,3 +169,106 @@ func TestConnPoolDo(t *testing.T) { assert.Equal(t, 1, len(p.pool)) assert.Equal(t, 1, p.connCount) } + +func TestPoolRemove_TwoSizedPool(t *testing.T) { + p, _ := NewConnPool(testDbConnStr(2)) + assert.Equal(t, p.connCount, 1) + c1, _ := p.Get() + c2, _ := p.Get() + // The pool has used 2 connections (c1, c2) + assert.Equal(t, p.connCount, 2) + // ...and has no unused connections in it. + assert.Equal(t, len(p.pool), 0) + + // Remove c1 from the pool + p.Remove(c1) + assert.Nil(t, c1.belongsToPool) + // The pool has 1 used connection (c2) + assert.Equal(t, p.connCount, 1) + // ...and still has no unused connections in it. + assert.Equal(t, len(p.pool), 0) + + // Trying to release the removed conn is a safe noop + p.Release(c1) + c1.Close() + + // Get another connection from the pool. + // The pool will need to create this connection + // as it has no unused connections. + c3, _ := p.Get() + + // The pool has used 2 connections (c2, c3) + assert.Equal(t, p.connCount, 2) + // ...and has no unused connections in it. + assert.Equal(t, len(p.pool), 0) + + p.Release(c2) + p.Release(c3) +} + +func TestPoolRemove_OneSizedPool(t *testing.T) { + p, _ := NewConnPool(testDbConnStr(1)) + assert.Equal(t, p.connCount, 1) + c1, _ := p.Get() + // The pool has used 1 connection (c1) + assert.Equal(t, p.connCount, 1) + // ...and has no unused connections in it. + assert.Equal(t, len(p.pool), 0) + + var wg sync.WaitGroup + chGotConn := make(chan *Conn) + + // This goroutine attemps to Get another connection from the pool. + wg.Add(1) + go func() { + // The call to p.Get() will block until c1 is + // removed from the pool by the other goroutine. + c2, _ := p.Get() + assert.NotNil(t, c2) + chGotConn <- c2 + wg.Done() + }() + + wg.Add(1) + go func() { + // Sleep the goroutine for 2 seconds, + // keeping c1 in the pool. + time.Sleep(2 * time.Second) + // Now remove c1 from the pool. + // This will allow the call to p.Get() + // in the other goroutine to unblock. + // Note the use of the RemoveFromPool() func on the conn itself. + c1.RemoveFromPool().Close() + wg.Done() + }() + + // Wait to receive c2, or timeout. + select { + case c2 := <-chGotConn: + // The pool has used 1 connection (c2) + assert.Equal(t, p.connCount, 1) + // ...and has no unused connections in it. + assert.Equal(t, len(p.pool), 0) + p.Release(c2) + + case <-time.After(5 * time.Second): + assert.Fail(t, "timed out waiting for a pooled connection") + } + + wg.Wait() + close(chGotConn) +} + +func TestPoolRemove_OnConn(t *testing.T) { + p, _ := NewConnPool(testDbConnStr(2)) + assert.Equal(t, p.connCount, 1) + c1, _ := p.Get() + + c1.RemoveFromPool() + assert.Nil(t, c1.belongsToPool) + // Calling RemoveFromPool again is a safe noop + c1.RemoveFromPool() + // Trying to remove the already removed conn is a safe noop + p.Remove(c1) + c1.Close() +} diff --git a/conn_sp.go b/conn_sp.go index d58882e..5a621f4 100644 --- a/conn_sp.go +++ b/conn_sp.go @@ -49,7 +49,7 @@ func (conn *Conn) ExecSp(spName string, params ...interface{}) (*SpResult, error return nil, err } for i, spParam := range spParams { - //get datavalue for the suplied stored procedure parametar + //get datavalue for the suplied stored procedure parameter var datavalue *C.BYTE datalen := 0 if i < len(params) { @@ -57,7 +57,7 @@ func (conn *Conn) ExecSp(spName string, params ...interface{}) (*SpResult, error if param != nil { data, sqlDatalen, err := typeToSqlBuf(int(spParam.UserTypeId), param, conn.freetdsVersionGte095) if err != nil { - conn.Close() //close the connection + conn.close() //hard close the connection, if pooled don't return it. return nil, err } if len(data) > 0 { diff --git a/conn_test.go b/conn_test.go index eb3f46a..0757c68 100644 --- a/conn_test.go +++ b/conn_test.go @@ -516,3 +516,41 @@ func TestMessageNumbers(t *testing.T) { assert.Equal(t, c.HasMessageNumber(msgnumOne), 2) assert.Equal(t, c.HasMessageNumber(msgnumTwo), 1) } + +// Also run with "go test --race" for race condition checking. +func TestErrors(t *testing.T) { + const errnumOne = 11111 + const errnumOneMessage = "errnum-1, goroutine-" + const errnumTwo = 22222 + const errnumTwoMessage = "errnum-2, goroutine-" + + c := &Conn{ + errors: make(map[int]string), + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + c.addError((errnumOneMessage + "alpha"), errnumOne) + wg.Done() + }() + + wg.Add(1) + go func() { + c.addError((errnumOneMessage + "beta"), errnumOne) + c.addError((errnumTwoMessage + "beta"), errnumTwo) + wg.Done() + }() + + wg.Wait() + // The most recent error using errnumOne will overwrite the previous error. + str, found := c.HasErrorNumber(errnumOne) + assert.True(t, found) + assert.Contains(t, str, errnumOneMessage) + + // Only one error using errnumTwo was raised. + str, found = c.HasErrorNumber(errnumTwo) + assert.True(t, found) + assert.Contains(t, str, errnumTwoMessage) +} diff --git a/result.go b/result.go index fa86fd6..48a0ece 100644 --- a/result.go +++ b/result.go @@ -98,7 +98,7 @@ func (r *Result) MustScan(cnt int, dest ...interface{}) error { return err } if cnt != r.scanCount { - return errors.New(fmt.Sprintf("Worng scan count, expected %d, actual %d.", cnt, r.scanCount)) + return errors.New(fmt.Sprintf("Wrong scan count, expected %d, actual %d.", cnt, r.scanCount)) } return nil } @@ -132,7 +132,7 @@ func (r *Result) ScanColumn(name string, dest interface{}) error { err = convertAssign(dest, r.Rows[r.currentRow][i]) if err != nil { - return err + return fmt.Errorf("%s, column '%s'", err.Error(), name) } return nil @@ -146,7 +146,7 @@ func (r *Result) scanStruct(s *reflect.Value) error { if f.IsValid() { if f.CanSet() { if err := convertAssign(f.Addr().Interface(), r.Rows[r.currentRow][i]); err != nil { - return err + return fmt.Errorf("%s, column '%s'", err.Error(), col.Name) } r.scanCount++ } diff --git a/sp_result.go b/sp_result.go index 64c6c8f..8bfba8a 100644 --- a/sp_result.go +++ b/sp_result.go @@ -89,7 +89,7 @@ func (r *SpResult) Next() bool { return rst.Next() } -//Sacaning output parameters of stored procedure +//Scanning output parameters of stored procedure func (r *SpResult) ParamScan(values ...interface{}) error { outputValues := make([]interface{}, len(r.outputParams)) for i := 0; i < len(r.outputParams); i++ {