Skip to content

Commit

Permalink
basic support for prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
marctrem committed Jul 21, 2021
1 parent 08ff1e7 commit 150b1fa
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ gorqlite.TraceOn(os.Stderr)

// turn off
gorqlite.TraceOff()


// using prepared statements
wr, err := conn.WritePrepared(
[]*gorqlite.PreparedStatement{
{
Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)",
Arguments: []interface{}{7, "James Bond", []byte{0x42}}
}
}
)
// alternatively
wr, err := conn.WriteOnePrepared(
&gorqlite.PreparedStatement{
Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)",
Arguments: []interface{}{7, "James Bond", []byte{0x42}},
},
)
```
## Important Notes

Expand Down
107 changes: 107 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ import "io/ioutil"
import "net/http"
import "time"

type PreparedStatement struct {
Query string
Arguments []interface{}
}

/* *****************************************************************
method: rqliteApiGet() - for api_STATUS
Expand Down Expand Up @@ -201,3 +206,105 @@ PeerLoop:
}
return responseBody, errors.New(stringBuffer.String())
}


func (conn *Connection) rqliteApiPostPrepared(apiOp apiOperation, sqlStatements []*PreparedStatement) ([]byte, error) {
var responseBody []byte

switch apiOp {
case api_QUERY:
trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements))
case api_WRITE:
trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements))
default:
return responseBody, errors.New("weird! called for an invalid apiOperation in rqliteApiPost()")
}

// jsonify the statements. not really needed in the
// case of api_STATUS but doesn't hurt


formattedStatements := make([][]interface{}, 0, len(sqlStatements))

for _, statement := range sqlStatements {
formattedStatement := make([]interface{}, 0, len(statement.Arguments)+1)
formattedStatement = append(formattedStatement, statement.Query)

for _, argument := range statement.Arguments {
formattedStatement = append(formattedStatement, argument)
}
formattedStatements = append(formattedStatements, formattedStatement)
}

jStatements, err := json.Marshal(formattedStatements)
if err != nil {
return nil, err
}

// just to be safe, check this
peersToTry := conn.cluster.makePeerList()
if len(peersToTry) < 1 {
return responseBody, errors.New("I don't have any cluster info")
}

// failure log is used so that if all peers fail, we can say something
// about why each failed
failureLog := make([]string, 0)

PeerLoop:
for peerNum, peer := range peersToTry {
trace("%s: trying peer #%d", conn.ID, peerNum)

// we're doing a post, and the RFCs say that if you get a 301, it's not
// automatically followed, so we have to do that ourselves

responseStatus := "Haven't Tried Yet"
var url string
for responseStatus == "Haven't Tried Yet" || responseStatus == "301 Moved Permanently" {
url = conn.assembleURL(apiOp, peer)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jStatements))
if err != nil {
trace("%s: got error '%s' doing http.NewRequest", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
response, err := client.Do(req)
if err != nil {
trace("%s: got error '%s' doing client.Do", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
defer response.Body.Close()
responseBody, err = ioutil.ReadAll(response.Body)
if err != nil {
trace("%s: got error '%s' doing ioutil.ReadAll", conn.ID, err.Error())
failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error()))
continue PeerLoop
}
responseStatus = response.Status
if responseStatus == "301 Moved Permanently" {
v := response.Header["Location"]
failureLog = append(failureLog, fmt.Sprintf("%s redirected me to %s", url, v[0]))
url = v[0]
continue PeerLoop
} else if responseStatus == "200 OK" {
trace("%s: api call OK, returning", conn.ID)
return responseBody, nil
} else {
trace("%s: got error in responseStatus: %s", conn.ID, responseStatus)
failureLog = append(failureLog, fmt.Sprintf("%s failed, got: %s", url, response.Status))
continue PeerLoop
}
}
}

// if we got here, all peers failed. Let's build a verbose error message
var stringBuffer bytes.Buffer
stringBuffer.WriteString("tried all peers unsuccessfully. here are the results:\n")
for n, v := range failureLog {
stringBuffer.WriteString(fmt.Sprintf(" peer #%d: %s\n", n, v))
}
return responseBody, errors.New(stringBuffer.String())
}
102 changes: 102 additions & 0 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ func (conn *Connection) WriteOne(sqlStatement string) (wr WriteResult, err error
return wra[0], err
}

func (conn *Connection) WriteOnePrepared(statement *PreparedStatement) (wr WriteResult, err error) {
if conn.hasBeenClosed {
wr.Err = errClosed
return wr, errClosed
}
wra, err := conn.WritePrepared([]*PreparedStatement{statement})
return wra[0], err
}


/*
Write() is used to perform DDL/DML in the database. ALTER, CREATE, DELETE, DROP, INSERT, UPDATE, etc. all go through Write().
Expand Down Expand Up @@ -171,6 +181,98 @@ func (conn *Connection) Write(sqlStatements []string) (results []WriteResult, er
}
}

func (conn *Connection) WritePrepared(sqlStatements []*PreparedStatement) (results []WriteResult, err error) {
results = make([]WriteResult, 0)

if conn.hasBeenClosed {
var errResult WriteResult
errResult.Err = errClosed
results = append(results, errResult)
return results, errClosed
}

trace("%s: Write() for %d statements", conn.ID, len(sqlStatements))

response, err := conn.rqliteApiPostPrepared(api_WRITE, sqlStatements)
if err != nil {
trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error())
var errResult WriteResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
trace("%s: rqliteApiCall() OK", conn.ID)

var sections map[string]interface{}
err = json.Unmarshal(response, &sections)
if err != nil {
trace("%s: json.Unmarshal() ERROR: %s", conn.ID, err.Error())
var errResult WriteResult
errResult.Err = err
results = append(results, errResult)
return results, err
}

/*
at this point, we have a "results" section and
a "time" section. we can igore the latter.
*/

resultsArray, ok := sections["results"].([]interface{})
if !ok {
err = errors.New("Result key is missing from response")
trace("%s: sections[\"results\"] ERROR: %s", conn.ID, err)
var errResult WriteResult
errResult.Err = err
results = append(results, errResult)
return results, err
}
trace("%s: I have %d result(s) to parse", conn.ID, len(resultsArray))
numStatementErrors := 0
for n, k := range resultsArray {
trace("%s: starting on result %d", conn.ID, n)
thisResult := k.(map[string]interface{})

var thisWR WriteResult
thisWR.conn = conn

// did we get an error?
_, ok := thisResult["error"]
if ok {
trace("%s: have an error on this result: %s", conn.ID, thisResult["error"].(string))
thisWR.Err = errors.New(thisResult["error"].(string))
results = append(results, thisWR)
numStatementErrors += 1
continue
}

_, ok = thisResult["last_insert_id"]
if ok {
thisWR.LastInsertID = int64(thisResult["last_insert_id"].(float64))
}

_, ok = thisResult["rows_affected"] // could be zero for a CREATE
if ok {
thisWR.RowsAffected = int64(thisResult["rows_affected"].(float64))
}
_, ok = thisResult["time"] // could be nil
if ok {
thisWR.Timing = thisResult["time"].(float64)
}

trace("%s: this result (LII,RA,T): %d %d %f", conn.ID, thisWR.LastInsertID, thisWR.RowsAffected, thisWR.Timing)
results = append(results, thisWR)
}

trace("%s: finished parsing, returning %d results", conn.ID, len(results))

if numStatementErrors > 0 {
return results, errors.New(fmt.Sprintf("there were %d statement errors", numStatementErrors))
} else {
return results, nil
}
}

/* *****************************************************************
type: WriteResult
Expand Down

0 comments on commit 150b1fa

Please sign in to comment.