diff --git a/client/changes.go b/client/changes.go index 47fdfd84f..9e47feb98 100644 --- a/client/changes.go +++ b/client/changes.go @@ -89,7 +89,7 @@ type changeAndData struct { // Change fetches information about a Change given its ID. func (client *Client) Change(id string) (*Change, error) { var chgd changeAndData - _, err := client.doSync("GET", "/v1/changes/"+id, nil, nil, nil, &chgd) + err := client.doSync("GET", "/v1/changes/"+id, nil, nil, nil, &chgd) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func (client *Client) Abort(id string) (*Change, error) { } var chg Change - if _, err := client.doSync("POST", "/v1/changes/"+id, nil, nil, &body, &chg); err != nil { + if err := client.doSync("POST", "/v1/changes/"+id, nil, nil, &body, &chg); err != nil { return nil, err } @@ -158,7 +158,7 @@ func (client *Client) Changes(opts *ChangesOptions) ([]*Change, error) { } var chgds []changeAndData - _, err := client.doSync("GET", "/v1/changes", query, nil, nil, &chgds) + err := client.doSync("GET", "/v1/changes", query, nil, nil, &chgds) if err != nil { return nil, err } @@ -190,7 +190,7 @@ func (client *Client) WaitChange(id string, opts *WaitChangeOptions) (*Change, e query.Set("timeout", opts.Timeout.String()) } - _, err := client.doSync("GET", "/v1/changes/"+id+"/wait", query, nil, nil, &chgd) + err := client.doSync("GET", "/v1/changes/"+id+"/wait", query, nil, nil, &chgd) if err != nil { return nil, err } diff --git a/client/checks.go b/client/checks.go index b86d656be..381532707 100644 --- a/client/checks.go +++ b/client/checks.go @@ -79,7 +79,7 @@ func (client *Client) Checks(opts *ChecksOptions) ([]*CheckInfo, error) { query["names"] = opts.Names } var checks []*CheckInfo - _, err := client.doSync("GET", "/v1/checks", query, nil, nil, &checks) + err := client.doSync("GET", "/v1/checks", query, nil, nil, &checks) if err != nil { return nil, err } diff --git a/client/client.go b/client/client.go index 0e9cce91a..ee639a037 100644 --- a/client/client.go +++ b/client/client.go @@ -18,173 +18,147 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "io/ioutil" - "net" "net/http" "net/url" - "os" - "path" "time" - - "github.com/gorilla/websocket" - - "github.com/canonical/pebble/internals/wsutil" ) -// SocketNotFoundError is the error type returned when the client fails -// to find a unix socket at the specified path. -type SocketNotFoundError struct { - // Err is the wrapped error. - Err error - - // Path is the path of the non-existent socket. - Path string -} - -func (s SocketNotFoundError) Error() string { - if s.Path == "" && s.Err != nil { - return s.Err.Error() - } - return fmt.Sprintf("socket %q not found", s.Path) -} - -func (s SocketNotFoundError) Unwrap() error { - return s.Err -} - -// decodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved -// via json.Number data type. It errors out on invalid json or any excess input. -func decodeWithNumber(r io.Reader, value interface{}) error { - dec := json.NewDecoder(r) - dec.UseNumber() - if err := dec.Decode(&value); err != nil { - return err - } - if dec.More() { - return fmt.Errorf("cannot parse json value") - } - return nil -} - -func unixDialer(socketPath string) func(string, string) (net.Conn, error) { - return func(_, _ string) (net.Conn, error) { - _, err := os.Stat(socketPath) - if errors.Is(err, os.ErrNotExist) { - return nil, &SocketNotFoundError{Err: err, Path: socketPath} - } - if err != nil { - return nil, fmt.Errorf("cannot stat %q: %w", socketPath, err) - } - - return net.Dial("unix", socketPath) - } -} - -type doer interface { - Do(*http.Request) (*http.Response, error) -} - -// Config allows the user to customize client behavior. -type Config struct { - // BaseURL contains the base URL where the Pebble daemon is expected to be. - // It can be empty for a default behavior of talking over a unix socket. - BaseURL string - - // Socket is the path to the unix socket to use. - Socket string - - // DisableKeepAlive indicates that the connections should not be kept - // alive for later reuse (the default is to keep them alive). - DisableKeepAlive bool - - // UserAgent is the User-Agent header sent to the Pebble daemon. - UserAgent string +// DecoderFunc allows the client access to the HTTP response. See the SetDecoder +// description for more details. +type DecoderFunc func(ctx context.Context, res *http.Response, opts *RequestOptions, result interface{}) (*RequestResponse, error) + +type Requester interface { + // Do must support the following cases: + // + // 1. Sync response + // 2. Async response + // 3. Error response + // 4. Websocket creation request + // 5. Raw HTTP body request + // + // See the default implementation for information on what is expected in + // each use case. + Do(ctx context.Context, opts *RequestOptions, result interface{}) (*RequestResponse, error) + + // SetDecoder allows for client specific processing to be hooked into + // the sync and async response decoding process. The decoder is also + // responsible for unmarshalling the result, and populating the + // RequestReponse. + SetDecoder(decoder DecoderFunc) +} + +// RequestOptions allows setting up a specific request. +type RequestOptions struct { + Method string + Path string + Query url.Values + Headers map[string]string + Body io.Reader + Async bool +} + +// RequestResponse defines a common response associated with requests. +type RequestResponse struct { + StatusCode int + Type string + Change string } // A Client knows how to talk to the Pebble daemon. type Client struct { - baseURL url.URL - doer doer - userAgent string - - maintenance error + Requester Requester + maintenance error warningCount int warningTimestamp time.Time - - getWebsocket getWebsocketFunc } -type getWebsocketFunc func(url string) (clientWebsocket, error) - -type clientWebsocket interface { - wsutil.MessageReader - wsutil.MessageWriter - io.Closer - jsonWriter +func New(requester Requester) (*Client, error) { + client := &Client{Requester: requester} + client.Requester.SetDecoder(client.decoder) + return client, nil } -type jsonWriter interface { - WriteJSON(v interface{}) error -} +// decoder receives a raw HTTP response and performs internal client +// processing, as well as unmarshalling the custom result. +func (client *Client) decoder(ctx context.Context, rsp *http.Response, opts *RequestOptions, result interface{}) (*RequestResponse, error) { + var serverResp response + if err := decodeInto(rsp.Body, &serverResp); err != nil { + return nil, err + } -func New(config *Config) (*Client, error) { - if config == nil { - config = &Config{} + // Update the maintenance error state + if serverResp.Maintenance != nil { + client.maintenance = serverResp.Maintenance + } else { + client.maintenance = nil } - var client *Client - var transport *http.Transport + // Deal with error type response + if err := serverResp.err(); err != nil { + return nil, err + } - if config.BaseURL == "" { - // By default talk over a unix socket. - transport = &http.Transport{Dial: unixDialer(config.Socket), DisableKeepAlives: config.DisableKeepAlive} - baseURL := url.URL{Scheme: "http", Host: "localhost"} - client = &Client{baseURL: baseURL} - } else { - // Otherwise talk regular HTTP-over-TCP. - baseURL, err := url.Parse(config.BaseURL) - if err != nil { - return nil, fmt.Errorf("cannot parse base URL: %v", err) + // Warnings are only included if not an error type response + client.warningCount = serverResp.WarningCount + client.warningTimestamp = serverResp.WarningTimestamp + + // Decode the supplied result type + if result != nil { + if err := decodeWithNumber(bytes.NewReader(serverResp.Result), result); err != nil { + return nil, fmt.Errorf("cannot unmarshal: %w", err) } - transport = &http.Transport{DisableKeepAlives: config.DisableKeepAlive} - client = &Client{baseURL: *baseURL} } - client.doer = &http.Client{Transport: transport} - client.userAgent = config.UserAgent - client.getWebsocket = func(url string) (clientWebsocket, error) { - return getWebsocket(transport, url) - } + // Common response + return &RequestResponse{ + StatusCode: serverResp.StatusCode, + Type: serverResp.Type, + Change: serverResp.Change, + }, nil +} - return client, nil +type jsonWriter interface { + WriteJSON(v interface{}) error } -func (client *Client) getTaskWebsocket(taskID, websocketID string) (clientWebsocket, error) { +func (client *Client) getTaskWebsocket(taskID, websocketID string) (Websocket, error) { url := fmt.Sprintf("ws://localhost/v1/tasks/%s/websocket/%s", taskID, websocketID) - return client.getWebsocket(url) + var ws Websocket + _, err :=client.Requester.Do(context.Background(), &RequestOptions{Path: url}, &ws) + if err != nil { + return nil, err + } + return ws, nil } -func getWebsocket(transport *http.Transport, url string) (clientWebsocket, error) { - dialer := websocket.Dialer{ - NetDial: transport.Dial, - Proxy: transport.Proxy, - TLSClientConfig: transport.TLSClientConfig, - HandshakeTimeout: 5 * time.Second, - } - conn, _, err := dialer.Dial(url, nil) - return conn, err +func (client *Client) doSync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) error { + _, err := client.Requester.Do(context.Background(), &RequestOptions{ + Method: method, + Path: path, + Query: query, + Headers: headers, + Body: body, + }, v) + return err } -// CloseIdleConnections closes any API connections that are currently unused. -func (client *Client) CloseIdleConnections() { - c, ok := client.doer.(*http.Client) - if ok { - c.CloseIdleConnections() +func (client *Client) doAsync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (changeID string, err error) { + rsp, err := client.Requester.Do(context.Background(), &RequestOptions{ + Method: method, + Path: path, + Query: query, + Headers: headers, + Body: body, + Async: true, + }, v) + if err != nil { + return "", err } + return rsp.Change, nil } // Maintenance returns an error reflecting the daemon maintenance status or nil. @@ -199,120 +173,19 @@ func (client *Client) WarningsSummary() (count int, timestamp time.Time) { return client.warningCount, client.warningTimestamp } -// RequestError is returned when there's an error processing the request. -type RequestError struct{ error } - -func (e RequestError) Error() string { - return fmt.Sprintf("cannot build request: %v", e.error) -} - -// ConnectionError represents a connection or communication error. -type ConnectionError struct { - error -} - -func (e ConnectionError) Error() string { - return fmt.Sprintf("cannot communicate with server: %v", e.error) -} - -func (e ConnectionError) Unwrap() error { - return e.error -} - -// raw performs a request and returns the resulting http.Response and -// error you usually only need to call this directly if you expect the -// response to not be JSON, otherwise you'd call Do(...) instead. -func (client *Client) raw(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { - // fake a url to keep http.Client happy - u := client.baseURL - u.Path = path.Join(client.baseURL.Path, urlpath) - u.RawQuery = query.Encode() - req, err := http.NewRequestWithContext(ctx, method, u.String(), body) - if err != nil { - return nil, RequestError{err} - } - if client.userAgent != "" { - req.Header.Set("User-Agent", client.userAgent) - } - - for key, value := range headers { - req.Header.Set(key, value) - } - - rsp, err := client.doer.Do(req) - if err != nil { - return nil, ConnectionError{err} - } - - return rsp, nil -} - -var ( - doRetry = 250 * time.Millisecond - doTimeout = 5 * time.Second -) - // FakeDoRetry fakes the delays used by the do retry loop (intended for // testing). Calling restore will revert the changes. func FakeDoRetry(retry, timeout time.Duration) (restore func()) { - oldRetry := doRetry - oldTimeout := doTimeout - doRetry = retry - doTimeout = timeout + oldRetry := rawRetry + oldTimeout := rawTimeout + rawRetry = retry + rawTimeout = timeout return func() { - doRetry = oldRetry - doTimeout = oldTimeout + rawRetry = oldRetry + rawTimeout = oldTimeout } } -type hijacked struct { - do func(*http.Request) (*http.Response, error) -} - -func (h hijacked) Do(req *http.Request) (*http.Response, error) { - return h.do(req) -} - -// Hijack lets the caller take over the raw HTTP request. -func (client *Client) Hijack(f func(*http.Request) (*http.Response, error)) { - client.doer = hijacked{f} -} - -// do performs a request and decodes the resulting json into the given -// value. It's low-level, for testing/experimenting only; you should -// usually use a higher level interface that builds on this. -func (client *Client) do(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) error { - retry := time.NewTicker(doRetry) - defer retry.Stop() - timeout := time.After(doTimeout) - var rsp *http.Response - var err error - for { - rsp, err = client.raw(context.Background(), method, path, query, headers, body) - if err == nil || method != "GET" { - break - } - select { - case <-retry.C: - continue - case <-timeout: - } - break - } - if err != nil { - return err - } - defer rsp.Body.Close() - - if v != nil { - if err := decodeInto(rsp.Body, v); err != nil { - return err - } - } - - return nil -} - func decodeInto(reader io.Reader, v interface{}) error { dec := json.NewDecoder(reader) if err := dec.Decode(v); err != nil { @@ -326,66 +199,6 @@ func decodeInto(reader io.Reader, v interface{}) error { return nil } -// doSync performs a request to the given path using the specified HTTP method. -// It expects a "sync" response from the API and on success decodes the JSON -// response payload into the given value using the "UseNumber" json decoding -// which produces json.Numbers instead of float64 types for numbers. -func (client *Client) doSync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (*ResultInfo, error) { - var rsp response - if err := client.do(method, path, query, headers, body, &rsp); err != nil { - return nil, err - } - if err := rsp.err(client); err != nil { - return nil, err - } - if rsp.Type != "sync" { - return nil, fmt.Errorf("expected sync response, got %q", rsp.Type) - } - - if v != nil { - if err := decodeWithNumber(bytes.NewReader(rsp.Result), v); err != nil { - return nil, fmt.Errorf("cannot unmarshal: %w", err) - } - } - - client.warningCount = rsp.WarningCount - client.warningTimestamp = rsp.WarningTimestamp - - return &rsp.ResultInfo, nil -} - -func (client *Client) doAsync(method, path string, query url.Values, headers map[string]string, body io.Reader) (changeID string, err error) { - _, changeID, err = client.doAsyncFull(method, path, query, headers, body) - return -} - -func (client *Client) doAsyncFull(method, path string, query url.Values, headers map[string]string, body io.Reader) (result json.RawMessage, changeID string, err error) { - var rsp response - - if err := client.do(method, path, query, headers, body, &rsp); err != nil { - return nil, "", err - } - if err := rsp.err(client); err != nil { - return nil, "", err - } - if rsp.Type != "async" { - return nil, "", fmt.Errorf("expected async response for %q on %q, got %q", method, path, rsp.Type) - } - if rsp.StatusCode != 202 { - return nil, "", fmt.Errorf("operation not accepted") - } - if rsp.Change == "" { - return nil, "", fmt.Errorf("async response without change reference") - } - - return rsp.Result, rsp.Change, nil -} - -// ResultInfo is empty for now, but this is the mechanism that conveys -// general information that makes sense to requests at a more general -// level, and might be disconnected from the specific request at hand. -type ResultInfo struct{} - // A response produced by the REST API will usually fit in this // (exceptions are the icons/ endpoints obvs) type response struct { @@ -398,8 +211,6 @@ type response struct { WarningCount int `json:"warning-count"` WarningTimestamp time.Time `json:"warning-timestamp"` - ResultInfo - Maintenance *Error `json:"maintenance"` } @@ -423,16 +234,8 @@ const ( ErrorKindNoDefaultServices = "no-default-services" ) -func (rsp *response) err(cli *Client) error { - if cli != nil { - maintErr := rsp.Maintenance - // avoid setting to (*client.Error)(nil) - if maintErr != nil { - cli.maintenance = maintErr - } else { - cli.maintenance = nil - } - } +// err extract the error in case of an error type response +func (rsp *response) err() error { if rsp.Type != "error" { return nil } @@ -446,24 +249,6 @@ func (rsp *response) err(cli *Client) error { return &resultErr } -func parseError(r *http.Response) error { - var rsp response - if r.Header.Get("Content-Type") != "application/json" { - return fmt.Errorf("server error: %q", r.Status) - } - - dec := json.NewDecoder(r.Body) - if err := dec.Decode(&rsp); err != nil { - return fmt.Errorf("cannot unmarshal error: %w", err) - } - - err := rsp.err(nil) - if err == nil { - return fmt.Errorf("server error: %q", r.Status) - } - return err -} - type SysInfo struct { // Version is the server version. Version string `json:"version,omitempty"` @@ -476,7 +261,7 @@ type SysInfo struct { func (client *Client) SysInfo() (*SysInfo, error) { var sysInfo SysInfo - if _, err := client.doSync("GET", "/v1/system-info", nil, nil, nil, &sysInfo); err != nil { + if err := client.doSync("GET", "/v1/system-info", nil, nil, nil, &sysInfo); err != nil { return nil, fmt.Errorf("cannot obtain system details: %w", err) } @@ -498,7 +283,7 @@ func (client *Client) DebugPost(action string, params interface{}, result interf return err } - _, err = client.doSync("POST", "/v1/debug", nil, nil, bytes.NewReader(body), result) + err = client.doSync("POST", "/v1/debug", nil, nil, bytes.NewReader(body), result) return err } @@ -508,6 +293,6 @@ func (client *Client) DebugGet(action string, result interface{}, params map[str for k, v := range params { urlParams.Set(k, v) } - _, err := client.doSync("GET", "/v1/debug", urlParams, nil, nil, &result) + err := client.doSync("GET", "/v1/debug", urlParams, nil, nil, &result) return err } diff --git a/client/exec.go b/client/exec.go index 88d405bd4..b8b28a9ff 100644 --- a/client/exec.go +++ b/client/exec.go @@ -152,14 +152,10 @@ func (client *Client) Exec(opts *ExecOptions) (*ExecProcess, error) { headers := map[string]string{ "Content-Type": "application/json", } - resultBytes, changeID, err := client.doAsyncFull("POST", "/v1/exec", nil, headers, &body) - if err != nil { - return nil, err - } var result execResult - err = json.Unmarshal(resultBytes, &result) + changeID, err := client.doAsync("POST", "/v1/exec", nil, headers, &body, &result) if err != nil { - return nil, fmt.Errorf("cannot unmarshal JSON response: %w", err) + return nil, err } // Connect to the "control" websocket. @@ -178,7 +174,7 @@ func (client *Client) Exec(opts *ExecOptions) (*ExecProcess, error) { stdoutDone := wsutil.WebsocketRecvStream(stdout, ioConn) // Handle stderr separately if needed. - var stderrConn clientWebsocket + var stderrConn Websocket var stderrDone chan bool if opts.Stderr != nil { stderrConn, err = client.getTaskWebsocket(taskID, "stderr") diff --git a/client/files.go b/client/files.go index cce33414d..168a2c65b 100644 --- a/client/files.go +++ b/client/files.go @@ -121,7 +121,7 @@ func (client *Client) ListFiles(opts *ListFilesOptions) ([]*FileInfo, error) { } var results []fileInfoResult - _, err := client.doSync("GET", "/v1/files", q, nil, nil, &results) + err := client.doSync("GET", "/v1/files", q, nil, nil, &results) if err != nil { return nil, err } @@ -280,7 +280,7 @@ func (client *Client) MakeDir(opts *MakeDirOptions) error { headers := map[string]string{ "Content-Type": "application/json", } - if _, err := client.doSync("POST", "/v1/files", nil, headers, &body, &result); err != nil { + if err := client.doSync("POST", "/v1/files", nil, headers, &body, &result); err != nil { return err } @@ -347,7 +347,7 @@ func (client *Client) RemovePath(opts *RemovePathOptions) error { headers := map[string]string{ "Content-Type": "application/json", } - if _, err := client.doSync("POST", "/v1/files", nil, headers, &body, &result); err != nil { + if err := client.doSync("POST", "/v1/files", nil, headers, &body, &result); err != nil { return err } diff --git a/client/logs.go b/client/logs.go index 408b3bc4b..f374cf725 100644 --- a/client/logs.go +++ b/client/logs.go @@ -73,13 +73,18 @@ func (client *Client) logs(ctx context.Context, opts *LogsOptions, follow bool) if follow { query.Set("follow", "true") } - res, err := client.raw(ctx, "GET", "/v1/logs", query, nil, nil) + var body BodyReader + _, err := client.Requester.Do(ctx, &RequestOptions{ + Method: "GET", + Path: "/v1/logs", + Query: query, + }, &body) if err != nil { return err } - defer res.Body.Close() + defer body.Close() - reader := bufio.NewReaderSize(res.Body, logReaderSize) + reader := bufio.NewReaderSize(body, logReaderSize) for { err = decodeLog(reader, opts.WriteLog) if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { diff --git a/client/plan.go b/client/plan.go index 25b447f08..46b53daa6 100644 --- a/client/plan.go +++ b/client/plan.go @@ -52,7 +52,7 @@ func (client *Client) AddLayer(opts *AddLayerOptions) error { if err := json.NewEncoder(&body).Encode(&payload); err != nil { return err } - _, err := client.doSync("POST", "/v1/layers", nil, nil, &body, nil) + err := client.doSync("POST", "/v1/layers", nil, nil, &body, nil) return err } @@ -64,7 +64,7 @@ func (client *Client) PlanBytes(_ *PlanOptions) (data []byte, err error) { "format": []string{"yaml"}, } var dataStr string - _, err = client.doSync("GET", "/v1/plan", query, nil, nil, &dataStr) + err = client.doSync("GET", "/v1/plan", query, nil, nil, &dataStr) if err != nil { return nil, err } diff --git a/client/requester.go b/client/requester.go new file mode 100644 index 000000000..603e1abd4 --- /dev/null +++ b/client/requester.go @@ -0,0 +1,292 @@ +// Copyright (c) 2023 Canonical Ltd +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License version 3 as +// published by the Free Software Foundation. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "path" + "time" + + "github.com/gorilla/websocket" + + "github.com/canonical/pebble/internals/wsutil" +) + +var ( + rawRetry = 250 * time.Millisecond + rawTimeout = 5 * time.Second +) + +type BodyReader interface { + io.ReadCloser +} + +// websocket defines a minimal compliant interface which the requester will +// interpret as a websocket creation request. +type Websocket interface { + wsutil.MessageReader + wsutil.MessageWriter + io.Closer + WriteJSON(v interface{}) error +} + +type BaseRequesterConfig struct { + // BaseURL contains the base URL where the Pebble daemon is expected to be. + // It can be empty for a default behavior of talking over a unix socket. + BaseURL string + + // Socket is the path to the unix socket to use. + Socket string + + // DisableKeepAlive indicates that the connections should not be kept + // alive for later reuse (the default is to keep them alive). + DisableKeepAlive bool + + // UserAgent is the User-Agent header sent to the Pebble daemon. + UserAgent string +} + +type BaseRequester struct { + baseURL url.URL + doer doer + userAgent string + transport *http.Transport + decoder DecoderFunc +} + +func NewBaseRequester(opts *BaseRequesterConfig) (*BaseRequester, error) { + if opts == nil { + opts = &BaseRequesterConfig{} + } + + var requester *BaseRequester + + if opts.BaseURL == "" { + // By default talk over a unix socket. + transport := &http.Transport{Dial: unixDialer(opts.Socket), DisableKeepAlives: opts.DisableKeepAlive} + baseURL := url.URL{Scheme: "http", Host: "localhost"} + requester = &BaseRequester{baseURL: baseURL, transport: transport} + } else { + // Otherwise talk regular HTTP-over-TCP. + baseURL, err := url.Parse(opts.BaseURL) + if err != nil { + return nil, fmt.Errorf("cannot parse base URL: %v", err) + } + transport := &http.Transport{DisableKeepAlives: opts.DisableKeepAlive} + requester = &BaseRequester{baseURL: *baseURL, transport: transport} + } + + requester.doer = &http.Client{Transport: requester.transport} + requester.userAgent = opts.UserAgent + + return requester, nil +} + +func (br *BaseRequester) SetDecoder(decoder DecoderFunc) { + br.decoder = decoder +} + +// rawOnce is unchanged from the client implementation +func (br *BaseRequester) raw(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { + // fake a url to keep http.Client happy + u := br.baseURL + u.Path = path.Join(br.baseURL.Path, urlpath) + u.RawQuery = query.Encode() + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) + if err != nil { + return nil, RequestError{err} + } + if br.userAgent != "" { + req.Header.Set("User-Agent", br.userAgent) + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + rsp, err := br.doer.Do(req) + if err != nil { + return nil, ConnectionError{err} + } + + return rsp, nil +} + +// raw builds in a retry mechanism for GET failures (body-less request) +func (br *BaseRequester) rawWithRetry(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { + retry := time.NewTicker(rawRetry) + defer retry.Stop() + timeout := time.After(rawTimeout) + var rsp *http.Response + var err error + for { + rsp, err = br.raw(ctx, method, urlpath, query, headers, body) + if err == nil || method != "GET" { + break + } + select { + case <-retry.C: + continue + case <-timeout: + } + break + } + if err != nil { + return nil, err + } + return rsp, nil +} + +func (br *BaseRequester) Do(ctx context.Context, opts *RequestOptions, result interface{}) (*RequestResponse, error) { + // Is the result expecting a websocket? + if ws, ok := result.(*Websocket); ok { + dialer := websocket.Dialer{ + NetDial: br.transport.Dial, + Proxy: br.transport.Proxy, + TLSClientConfig: br.transport.TLSClientConfig, + HandshakeTimeout: 5 * time.Second, + } + conn, _, err := dialer.DialContext(ctx, opts.Path, nil) + if err != nil { + return nil, err + } + *ws = conn + return nil, nil + } + + // Is the result expecting a caller managed body reader? + if bodyReader, ok := result.(*BodyReader); ok { + httpResp, err := br.raw(ctx, opts.Method, opts.Path, opts.Query, opts.Headers, opts.Body) + if err != nil { + return nil, err + } + + *bodyReader = httpResp.Body + return nil, nil + } + + // This is a normal sync or async server request + httpResp, err := br.rawWithRetry(ctx, opts.Method, opts.Path, opts.Query, opts.Headers, opts.Body) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + + // Get the client decoder to extract what it needs before we proceed + reqResp, err := br.decoder(ctx, httpResp, opts, result) + if err != nil { + return nil, err + } + + // Sanity check sync and async requests + if opts.Async == false { + if reqResp.Type != "sync" { + return nil, fmt.Errorf("expected sync response, got %q", reqResp.Type) + } + } else { + if reqResp.Type != "async" { + return nil, fmt.Errorf("expected async response for %q on %q, got %q", opts.Method, opts.Path, reqResp.Type) + } + if reqResp.StatusCode != 202 { + return nil, fmt.Errorf("operation not accepted") + } + if reqResp.Change == "" { + return nil, fmt.Errorf("async response without change reference") + } + } + + return reqResp, nil +} + +// SocketNotFoundError is the error type returned when the client fails +// to find a unix socket at the specified path. +type SocketNotFoundError struct { + // Err is the wrapped error. + Err error + + // Path is the path of the non-existent socket. + Path string +} + +func (s SocketNotFoundError) Error() string { + if s.Path == "" && s.Err != nil { + return s.Err.Error() + } + return fmt.Sprintf("socket %q not found", s.Path) +} + +func (s SocketNotFoundError) Unwrap() error { + return s.Err +} + +// decodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved +// via json.Number data type. It errors out on invalid json or any excess input. +func decodeWithNumber(r io.Reader, value interface{}) error { + dec := json.NewDecoder(r) + dec.UseNumber() + if err := dec.Decode(&value); err != nil { + return err + } + if dec.More() { + return fmt.Errorf("cannot parse json value") + } + return nil +} + +func unixDialer(socketPath string) func(string, string) (net.Conn, error) { + return func(_, _ string) (net.Conn, error) { + _, err := os.Stat(socketPath) + if errors.Is(err, os.ErrNotExist) { + return nil, &SocketNotFoundError{Err: err, Path: socketPath} + } + if err != nil { + return nil, fmt.Errorf("cannot stat %q: %w", socketPath, err) + } + + return net.Dial("unix", socketPath) + } +} + +type doer interface { + Do(*http.Request) (*http.Response, error) +} + +// RequestError is returned when there's an error processing the request. +type RequestError struct{ error } + +func (e RequestError) Error() string { + return fmt.Sprintf("cannot build request: %v", e.error) +} + +// ConnectionError represents a connection or communication error. +type ConnectionError struct { + error +} + +func (e ConnectionError) Error() string { + return fmt.Sprintf("cannot communicate with server: %v", e.error) +} + +func (e ConnectionError) Unwrap() error { + return e.error +} diff --git a/client/services.go b/client/services.go index 0be19d23e..cb04ff6b4 100644 --- a/client/services.go +++ b/client/services.go @@ -30,33 +30,33 @@ type ServiceOptions struct { // AutoStart starts the services makes as "startup: enabled". opts.Names must // be empty for this call. func (client *Client) AutoStart(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("autostart", opts.Names) + changeID, err = client.doMultiServiceAction("autostart", opts.Names) return changeID, err } // Start starts the services named in opts.Names in dependency order. func (client *Client) Start(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("start", opts.Names) + changeID, err = client.doMultiServiceAction("start", opts.Names) return changeID, err } // Stop stops the services named in opts.Names in dependency order. func (client *Client) Stop(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("stop", opts.Names) + changeID, err = client.doMultiServiceAction("stop", opts.Names) return changeID, err } // Restart stops and then starts the services named in opts.Names in // dependency order. func (client *Client) Restart(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("restart", opts.Names) + changeID, err = client.doMultiServiceAction("restart", opts.Names) return changeID, err } // Replan stops and (re)starts the services whose configuration has changed // since they were started. opts.Names must be empty for this call. func (client *Client) Replan(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("replan", opts.Names) + changeID, err = client.doMultiServiceAction("replan", opts.Names) return changeID, err } @@ -65,19 +65,19 @@ type multiActionData struct { Services []string `json:"services"` } -func (client *Client) doMultiServiceAction(actionName string, services []string) (result json.RawMessage, changeID string, err error) { +func (client *Client) doMultiServiceAction(actionName string, services []string) (changeID string, err error) { action := multiActionData{ Action: actionName, Services: services, } data, err := json.Marshal(&action) if err != nil { - return nil, "", fmt.Errorf("cannot marshal multi-service action: %w", err) + return "", fmt.Errorf("cannot marshal multi-service action: %w", err) } headers := map[string]string{ "Content-Type": "application/json", } - return client.doAsyncFull("POST", "/v1/services", nil, headers, bytes.NewBuffer(data)) + return client.doAsync("POST", "/v1/services", nil, headers, bytes.NewBuffer(data), nil) } type ServicesOptions struct { @@ -119,7 +119,7 @@ func (client *Client) Services(opts *ServicesOptions) ([]*ServiceInfo, error) { "names": []string{strings.Join(opts.Names, ",")}, } var services []*ServiceInfo - _, err := client.doSync("GET", "/v1/services", query, nil, nil, &services) + err := client.doSync("GET", "/v1/services", query, nil, nil, &services) if err != nil { return nil, err } diff --git a/client/signals.go b/client/signals.go index 9452ad4f8..91be4ba1e 100644 --- a/client/signals.go +++ b/client/signals.go @@ -36,7 +36,7 @@ func (client *Client) SendSignal(opts *SendSignalOptions) error { if err != nil { return fmt.Errorf("cannot encode JSON payload: %w", err) } - _, err = client.doSync("POST", "/v1/signals", nil, nil, &body, nil) + err = client.doSync("POST", "/v1/signals", nil, nil, &body, nil) return err } diff --git a/client/warnings.go b/client/warnings.go index d63da1238..fd2d19bbd 100644 --- a/client/warnings.go +++ b/client/warnings.go @@ -52,7 +52,7 @@ func (client *Client) Warnings(opts WarningsOptions) ([]*Warning, error) { if opts.All { q.Add("select", "all") } - _, err := client.doSync("GET", "/v1/warnings", q, nil, nil, &jws) + err := client.doSync("GET", "/v1/warnings", q, nil, nil, &jws) ws := make([]*Warning, len(jws)) for i, jw := range jws { @@ -77,6 +77,6 @@ func (client *Client) Okay(t time.Time) error { if err := json.NewEncoder(&body).Encode(op); err != nil { return err } - _, err := client.doSync("POST", "/v1/warnings", nil, nil, &body, nil) + err := client.doSync("POST", "/v1/warnings", nil, nil, &body, nil) return err } diff --git a/internals/cli/cli.go b/internals/cli/cli.go index 9d7f23fd6..0daff6493 100644 --- a/internals/cli/cli.go +++ b/internals/cli/cli.go @@ -243,8 +243,7 @@ var ( osExit = os.Exit ) -// ClientConfig is the configuration of the Client used by all commands. -var clientConfig client.Config +var requesterConfig client.BaseRequesterConfig // exitStatus can be used in panic(&exitStatus{code}) to cause Pebble's main // function to exit with a given exit code, for the rare cases when you want @@ -270,9 +269,13 @@ func Run() error { logger.SetLogger(logger.New(os.Stderr, "[pebble] ")) - _, clientConfig.Socket = getEnvPaths() + _, requesterConfig.Socket = getEnvPaths() - cli, err := client.New(&clientConfig) + requester, err := client.NewBaseRequester(&requesterConfig) + if err != nil { + return fmt.Errorf("cannot create requester: %v", err) + } + cli, err := client.New(requester) if err != nil { return fmt.Errorf("cannot create client: %v", err) } diff --git a/internals/cli/cmd_run.go b/internals/cli/cmd_run.go index e32bb4b36..4d971a71e 100644 --- a/internals/cli/cmd_run.go +++ b/internals/cli/cmd_run.go @@ -239,9 +239,6 @@ out: } } - // Close our own self-connection, otherwise it prevents fast and clean termination. - rcmd.client.CloseIdleConnections() - return d.Stop(ch) }