Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix go2SqlDataType was returning too large of a varbinary #74

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type Conn struct {

credentials
freetdsVersionGte095 bool
mssqlVersion2005 bool
}

func (conn *Conn) addMessage(msg string, msgno int) {
Expand Down Expand Up @@ -165,6 +166,12 @@ func (conn *Conn) connect() (*Conn, error) {
conn.close()
return nil, err
}

var sql_runtime_err = conn.readMSSQLVersion()
if sql_runtime_err != nil {
return nil, err
}

//log.Printf("freetds connected to %s@%s.%s", conn.user, conn.host, conn.database)
return conn, nil
}
Expand Down Expand Up @@ -235,6 +242,7 @@ func (conn *Conn) getDbProc() (*C.DBPROCESS, error) {
return nil, dbProcError("dbopen error")
}
conn.readFreeTdsVersion()

return dbproc, nil
}

Expand All @@ -244,6 +252,29 @@ func (conn *Conn) readFreeTdsVersion() {
conn.setFreetdsVersionGte095(freeTdsVersion)
}

func (conn *Conn) readMSSQLVersion() error {
var results, err = conn.ExecuteSql("SELECT @@VERSION")
if err != nil {
return err
}

var version_string string

if len(results) > 1 {
var result = results[0]
if result.Next() {
var err = result.Scan(&version_string)

if err != nil {
return err
}
}
}

conn.setMSSQLVersion(version_string)
return nil
}

func dbProcError(msg string) error {
return fmt.Errorf("%s\n%s\n%s", msg, lastError, lastMessage)
}
Expand Down Expand Up @@ -467,6 +498,14 @@ func (conn *Conn) setFreetdsVersionGte095(freeTdsVersion []int) {
}
}

func (conn *Conn) setMSSQLVersion(version string) {
conn.mssqlVersion2005 = false
version = strings.ToLower(version)
if version[0:25] == "microsoft sql server 2005" {
conn.mssqlVersion2005 = true
}
}

func parseFreeTdsVersion(dbVersion string) []int {
rxFreeTdsVersion := regexp.MustCompile(`v(\d+).(\d+).(\d+)`)
//log.Println("FreeTDS Version: ", dbVersion)
Expand Down
19 changes: 14 additions & 5 deletions executesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (conn *Conn) ExecuteSql(query string, params ...driver.Value) ([]*Result, e
if numParams != len(params) {
return nil, fmt.Errorf("Incorrect number of params, expecting %d got %d", numParams, len(params))
}
paramDef, paramVal, err := parseParams(params...)
paramDef, paramVal, err := parseParams(conn.mssqlVersion2005, params...)
if err != nil {
return nil, err
}
Expand All @@ -55,7 +55,7 @@ func (conn *Conn) executeSqlSybase125(query string, params ...driver.Value) ([]*
matches := re.FindAllSubmatchIndex([]byte(sql), -1)

for i, _ := range matches {
_, escapedValue, _ := go2SqlDataType(params[i])
_, escapedValue, _ := go2SqlDataType(conn.mssqlVersion2005, params[i])
sql = fmt.Sprintf("%s", strings.Replace(sql, "$bindkey", escapedValue, 1))
}

Expand All @@ -81,15 +81,15 @@ func query2Statement(query string) (string, int) {
return quote(statement), numParams
}

func parseParams(params ...driver.Value) (string, string, error) {
func parseParams(sql_2005 bool, params ...driver.Value) (string, string, error) {
paramDef := ""
paramVal := ""
for i, param := range params {
if i > 0 {
paramVal += ", "
paramDef += ", "
}
sqlType, sqlValue, err := go2SqlDataType(param)
sqlType, sqlValue, err := go2SqlDataType(sql_2005, param)
if err != nil {
return "", "", err
}
Expand All @@ -104,7 +104,7 @@ func quote(in string) string {
return strings.Replace(in, "'", "''", -1)
}

func go2SqlDataType(value interface{}) (string, string, error) {
func go2SqlDataType(sql_2005 bool, value interface{}) (string, string, error) {
max := func(a int, b int) int {
if a > b {
return a
Expand Down Expand Up @@ -136,11 +136,20 @@ func go2SqlDataType(value interface{}) (string, string, error) {
case time.Time:
{
strValue = t.Format(time.RFC3339Nano)
if sql_2005 {
return "datetime", fmt.Sprintf("'%s'", quote(strValue)), nil
}

return "datetimeoffset", fmt.Sprintf("'%s'", quote(strValue)), nil
}
case []byte:
{
b, _ := value.([]byte)
if len(b) > 8000 {
return "varbinary (MAX)",
fmt.Sprintf("0x%x", b), nil
}

return fmt.Sprintf("varbinary (%d)", max(1, len(b))),
fmt.Sprintf("0x%x", b), nil
}
Expand Down
58 changes: 31 additions & 27 deletions executesql_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package freetds

import (
"strings"
"testing"
"time"

Expand All @@ -12,29 +13,30 @@ const (
)

func TestGoTo2SqlDataType2(t *testing.T) {
var checker = func(value interface{}, sqlType string, sqlFormatedValue string) {
actualSqlType, actualSqlFormatedValue, err := go2SqlDataType(value)
var checker = func(sql_2005 bool, value interface{}, sqlType string, sqlFormatedValue string) {
actualSqlType, actualSqlFormatedValue, err := go2SqlDataType(sql_2005, value)
assert.Nil(t, err)
assert.Equal(t, actualSqlType, sqlType)
assert.Equal(t, actualSqlFormatedValue, sqlFormatedValue)
}

checker(123, "int", "123")
checker(int64(123), "bigint", "123")
checker(int16(123), "smallint", "123")
checker(int8(123), "tinyint", "123")
checker(123.23, "real", "123.23")
checker(float64(123.23), "real", "123.23")
checker(false, 123, "int", "123")
checker(false, int64(123), "bigint", "123")
checker(false, int16(123), "smallint", "123")
checker(false, int8(123), "tinyint", "123")
checker(false, 123.23, "real", "123.23")
checker(false, float64(123.23), "real", "123.23")

checker("iso medo", "nvarchar (8)", "'iso medo'")
checker("iso medo isn't", "nvarchar (14)", "'iso medo isn''t'")
checker(false, "iso medo", "nvarchar (8)", "'iso medo'")
checker(false, "iso medo isn't", "nvarchar (14)", "'iso medo isn''t'")

tm := time.Unix(1136239445, 0)
paris, _ := time.LoadLocation("Europe/Paris")

checker(tm.In(paris), "datetimeoffset", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")
checker(false, tm.In(paris), "datetimeoffset", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")
checker(true, tm.In(paris), "datetime", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")

checker([]byte{1, 2, 3, 4, 5, 6, 7, 8}, "varbinary (8)", "0x0102030405060708")
checker(false, []byte{1, 2, 3, 4, 5, 6, 7, 8}, "varbinary (8)", "0x0102030405060708")

//go2SqlDataType(t)
}
Expand All @@ -58,32 +60,34 @@ func TestQuery2Statement(t *testing.T) {
}

func TestGoTo2SqlDataType(t *testing.T) {
var checker = func(value interface{}, sqlType string, sqlFormatedValue string) {
actualSqlType, actualSqlFormatedValue, err := go2SqlDataType(value)
var checker = func(sql_2005 bool, value interface{}, sqlType string, sqlFormatedValue string) {
actualSqlType, actualSqlFormatedValue, err := go2SqlDataType(sql_2005, value)
assert.Nil(t, err)
assert.Equal(t, actualSqlType, sqlType)
assert.Equal(t, actualSqlFormatedValue, sqlFormatedValue)
}

checker(123, "int", "123")
checker(int64(123), "bigint", "123")
checker(int8(123), "tinyint", "123")
checker(123.23, "real", "123.23")
checker(float64(123.23), "real", "123.23")
checker(false, 123, "int", "123")
checker(false, int64(123), "bigint", "123")
checker(false, int8(123), "tinyint", "123")
checker(false, 123.23, "real", "123.23")
checker(false, float64(123.23), "real", "123.23")

checker("iso medo", "nvarchar (8)", "'iso medo'")
checker("iso medo isn't", "nvarchar (14)", "'iso medo isn''t'")
checker(false, "iso medo", "nvarchar (8)", "'iso medo'")
checker(false, "iso medo isn't", "nvarchar (14)", "'iso medo isn''t'")

tm := time.Unix(1136239445, 0)
paris, _ := time.LoadLocation("Europe/Paris")

checker(tm.In(paris), "datetimeoffset", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")
checker(false, tm.In(paris), "datetimeoffset", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")
checker(true, tm.In(paris), "datetime", "'"+tm.In(paris).Format(sqlDateTimeOffSet)+"'")

checker([]byte{1, 2, 3, 4, 5, 6, 7, 8}, "varbinary (8)", "0x0102030405060708")
checker(false, []byte{1, 2, 3, 4, 5, 6, 7, 8}, "varbinary (8)", "0x0102030405060708")
checker(false, make([]byte, 8001), "varbinary (MAX)", "0x" + strings.Repeat("00", 8001))

checker("", "nvarchar (1)", "''")
checker(true, "bit", "1")
checker(false, "bit", "0")
checker(false, "", "nvarchar (1)", "''")
checker(false, true, "bit", "1")
checker(false, false, "bit", "0")
}

func TestExecuteSqlNumberOfParams(t *testing.T) {
Expand All @@ -94,7 +98,7 @@ func TestExecuteSqlNumberOfParams(t *testing.T) {
}

func TestParseParams(t *testing.T) {
def, val, err := parseParams(1, 2, "pero")
def, val, err := parseParams(false, 1, 2, "pero")
assert.Nil(t, err)
assert.Equal(t, def, "@p1 int, @p2 int, @p3 nvarchar (4)")
assert.Equal(t, val, "@p1=1, @p2=2, @p3='pero'")
Expand Down