diff --git a/README.md b/README.md index 892293a..f8afbb2 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,24 @@ gorqlite.TraceOn(os.Stderr) // turn off gorqlite.TraceOff() + + +// using prepared statements +wr, err := conn.WritePrepared( + []*gorqlite.PreparedStatement{ + { + Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)", + Arguments: []interface{}{7, "James Bond", []byte{0x42}} + } + } +) +// alternatively +wr, err := conn.WriteOnePrepared( + &gorqlite.PreparedStatement{ + Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)", + Arguments: []interface{}{7, "James Bond", []byte{0x42}}, + }, +) ``` ## Important Notes diff --git a/api.go b/api.go index 54937e0..ce0cc5a 100644 --- a/api.go +++ b/api.go @@ -21,6 +21,11 @@ import "io/ioutil" import "net/http" import "time" +type PreparedStatement struct { + Query string + Arguments []interface{} +} + /* ***************************************************************** method: rqliteApiGet() - for api_STATUS @@ -201,3 +206,105 @@ PeerLoop: } return responseBody, errors.New(stringBuffer.String()) } + + +func (conn *Connection) rqliteApiPostPrepared(apiOp apiOperation, sqlStatements []*PreparedStatement) ([]byte, error) { + var responseBody []byte + + switch apiOp { + case api_QUERY: + trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements)) + case api_WRITE: + trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements)) + default: + return responseBody, errors.New("weird! called for an invalid apiOperation in rqliteApiPost()") + } + + // jsonify the statements. not really needed in the + // case of api_STATUS but doesn't hurt + + + formattedStatements := make([][]interface{}, 0, len(sqlStatements)) + + for _, statement := range sqlStatements { + formattedStatement := make([]interface{}, 0, len(statement.Arguments)+1) + formattedStatement = append(formattedStatement, statement.Query) + + for _, argument := range statement.Arguments { + formattedStatement = append(formattedStatement, argument) + } + formattedStatements = append(formattedStatements, formattedStatement) + } + + jStatements, err := json.Marshal(formattedStatements) + if err != nil { + return nil, err + } + + // just to be safe, check this + peersToTry := conn.cluster.makePeerList() + if len(peersToTry) < 1 { + return responseBody, errors.New("I don't have any cluster info") + } + + // failure log is used so that if all peers fail, we can say something + // about why each failed + failureLog := make([]string, 0) + +PeerLoop: + for peerNum, peer := range peersToTry { + trace("%s: trying peer #%d", conn.ID, peerNum) + + // we're doing a post, and the RFCs say that if you get a 301, it's not + // automatically followed, so we have to do that ourselves + + responseStatus := "Haven't Tried Yet" + var url string + for responseStatus == "Haven't Tried Yet" || responseStatus == "301 Moved Permanently" { + url = conn.assembleURL(apiOp, peer) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jStatements)) + if err != nil { + trace("%s: got error '%s' doing http.NewRequest", conn.ID, err.Error()) + failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error())) + continue PeerLoop + } + req.Header.Set("Content-Type", "application/json") + client := &http.Client{} + response, err := client.Do(req) + if err != nil { + trace("%s: got error '%s' doing client.Do", conn.ID, err.Error()) + failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error())) + continue PeerLoop + } + defer response.Body.Close() + responseBody, err = ioutil.ReadAll(response.Body) + if err != nil { + trace("%s: got error '%s' doing ioutil.ReadAll", conn.ID, err.Error()) + failureLog = append(failureLog, fmt.Sprintf("%s failed due to %s", url, err.Error())) + continue PeerLoop + } + responseStatus = response.Status + if responseStatus == "301 Moved Permanently" { + v := response.Header["Location"] + failureLog = append(failureLog, fmt.Sprintf("%s redirected me to %s", url, v[0])) + url = v[0] + continue PeerLoop + } else if responseStatus == "200 OK" { + trace("%s: api call OK, returning", conn.ID) + return responseBody, nil + } else { + trace("%s: got error in responseStatus: %s", conn.ID, responseStatus) + failureLog = append(failureLog, fmt.Sprintf("%s failed, got: %s", url, response.Status)) + continue PeerLoop + } + } + } + + // if we got here, all peers failed. Let's build a verbose error message + var stringBuffer bytes.Buffer + stringBuffer.WriteString("tried all peers unsuccessfully. here are the results:\n") + for n, v := range failureLog { + stringBuffer.WriteString(fmt.Sprintf(" peer #%d: %s\n", n, v)) + } + return responseBody, errors.New(stringBuffer.String()) +} \ No newline at end of file diff --git a/write.go b/write.go index 5716927..153ea4d 100644 --- a/write.go +++ b/write.go @@ -70,6 +70,16 @@ func (conn *Connection) WriteOne(sqlStatement string) (wr WriteResult, err error return wra[0], err } +func (conn *Connection) WriteOnePrepared(statement *PreparedStatement) (wr WriteResult, err error) { + if conn.hasBeenClosed { + wr.Err = errClosed + return wr, errClosed + } + wra, err := conn.WritePrepared([]*PreparedStatement{statement}) + return wra[0], err +} + + /* Write() is used to perform DDL/DML in the database. ALTER, CREATE, DELETE, DROP, INSERT, UPDATE, etc. all go through Write(). @@ -171,6 +181,98 @@ func (conn *Connection) Write(sqlStatements []string) (results []WriteResult, er } } +func (conn *Connection) WritePrepared(sqlStatements []*PreparedStatement) (results []WriteResult, err error) { + results = make([]WriteResult, 0) + + if conn.hasBeenClosed { + var errResult WriteResult + errResult.Err = errClosed + results = append(results, errResult) + return results, errClosed + } + + trace("%s: Write() for %d statements", conn.ID, len(sqlStatements)) + + response, err := conn.rqliteApiPostPrepared(api_WRITE, sqlStatements) + if err != nil { + trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error()) + var errResult WriteResult + errResult.Err = err + results = append(results, errResult) + return results, err + } + trace("%s: rqliteApiCall() OK", conn.ID) + + var sections map[string]interface{} + err = json.Unmarshal(response, §ions) + if err != nil { + trace("%s: json.Unmarshal() ERROR: %s", conn.ID, err.Error()) + var errResult WriteResult + errResult.Err = err + results = append(results, errResult) + return results, err + } + + /* + at this point, we have a "results" section and + a "time" section. we can igore the latter. + */ + + resultsArray, ok := sections["results"].([]interface{}) + if !ok { + err = errors.New("Result key is missing from response") + trace("%s: sections[\"results\"] ERROR: %s", conn.ID, err) + var errResult WriteResult + errResult.Err = err + results = append(results, errResult) + return results, err + } + trace("%s: I have %d result(s) to parse", conn.ID, len(resultsArray)) + numStatementErrors := 0 + for n, k := range resultsArray { + trace("%s: starting on result %d", conn.ID, n) + thisResult := k.(map[string]interface{}) + + var thisWR WriteResult + thisWR.conn = conn + + // did we get an error? + _, ok := thisResult["error"] + if ok { + trace("%s: have an error on this result: %s", conn.ID, thisResult["error"].(string)) + thisWR.Err = errors.New(thisResult["error"].(string)) + results = append(results, thisWR) + numStatementErrors += 1 + continue + } + + _, ok = thisResult["last_insert_id"] + if ok { + thisWR.LastInsertID = int64(thisResult["last_insert_id"].(float64)) + } + + _, ok = thisResult["rows_affected"] // could be zero for a CREATE + if ok { + thisWR.RowsAffected = int64(thisResult["rows_affected"].(float64)) + } + _, ok = thisResult["time"] // could be nil + if ok { + thisWR.Timing = thisResult["time"].(float64) + } + + trace("%s: this result (LII,RA,T): %d %d %f", conn.ID, thisWR.LastInsertID, thisWR.RowsAffected, thisWR.Timing) + results = append(results, thisWR) + } + + trace("%s: finished parsing, returning %d results", conn.ID, len(results)) + + if numStatementErrors > 0 { + return results, errors.New(fmt.Sprintf("there were %d statement errors", numStatementErrors)) + } else { + return results, nil + } +} + /* ***************************************************************** type: WriteResult