diff --git a/client/client.go b/client/client.go index 6fa4b9d64..33084e7bf 100644 --- a/client/client.go +++ b/client/client.go @@ -34,6 +34,59 @@ import ( "github.com/canonical/pebble/internals/wsutil" ) +type Requester interface { + // Do performs the HTTP transaction using the provided options. + Do(ctx context.Context, opts *RequestOptions) (*RequestResponse, error) + + // Transport returns the HTTP transport in use by the underlying HTTP client. + Transport() *http.Transport +} + +type RequestType int + +const ( + RawRequest RequestType = iota + SyncRequest + AsyncRequest +) + +type RequestOptions struct { + Type RequestType + Method string + Path string + Query url.Values + Headers map[string]string + Body io.Reader +} + +type RequestResponse struct { + StatusCode int + // ChangeID is typically set when an AsyncRequest type is performed. The + // change id allows for introspection and progress tracking of the request. + ChangeID string + // Result can contain request specific JSON data. The result can be + // unmarshalled into the expected type using the DecodeResult method. + Result []byte + // Body is only set for request type RawRequest. + Body io.ReadCloser +} + +// DecodeResult decodes the endpoint-specific result payload that is included as part of +// sync and async request responses. The decoding is performed with the standard JSON +// package, so the usual field tags should be used to prepare the type for decoding. +func (resp *RequestResponse) DecodeResult(result interface{}) error { + reader := bytes.NewReader(resp.Result) + dec := json.NewDecoder(reader) + dec.UseNumber() + if err := dec.Decode(&result); err != nil { + return fmt.Errorf("cannot unmarshal: %w", err) + } + if dec.More() { + return fmt.Errorf("cannot unmarshal: cannot parse json value") + } + return nil +} + // SocketNotFoundError is the error type returned when the client fails // to find a unix socket at the specified path. type SocketNotFoundError struct { @@ -55,20 +108,6 @@ func (s SocketNotFoundError) Unwrap() error { return s.Err } -// decodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved -// via json.Number data type. It errors out on invalid json or any excess input. -func decodeWithNumber(r io.Reader, value interface{}) error { - dec := json.NewDecoder(r) - dec.UseNumber() - if err := dec.Decode(&value); err != nil { - return err - } - if dec.More() { - return fmt.Errorf("cannot parse json value") - } - return nil -} - func unixDialer(socketPath string) func(string, string) (net.Conn, error) { return func(_, _ string) (net.Conn, error) { _, err := os.Stat(socketPath) @@ -106,12 +145,9 @@ type Config struct { // A Client knows how to talk to the Pebble daemon. type Client struct { - baseURL url.URL - doer doer - userAgent string - - maintenance error + requester Requester + maintenance error warningCount int warningTimestamp time.Time @@ -136,33 +172,24 @@ func New(config *Config) (*Client, error) { config = &Config{} } - var client *Client - var transport *http.Transport - - if config.BaseURL == "" { - // By default talk over a unix socket. - transport = &http.Transport{Dial: unixDialer(config.Socket), DisableKeepAlives: config.DisableKeepAlive} - baseURL := url.URL{Scheme: "http", Host: "localhost"} - client = &Client{baseURL: baseURL} - } else { - // Otherwise talk regular HTTP-over-TCP. - baseURL, err := url.Parse(config.BaseURL) - if err != nil { - return nil, fmt.Errorf("cannot parse base URL: %v", err) - } - transport = &http.Transport{DisableKeepAlives: config.DisableKeepAlive} - client = &Client{baseURL: *baseURL} + client := &Client{} + requester, err := newDefaultRequester(client, config) + if err != nil { + return nil, err } - client.doer = &http.Client{Transport: transport} - client.userAgent = config.UserAgent + client.requester = requester client.getWebsocket = func(url string) (clientWebsocket, error) { - return getWebsocket(transport, url) + return getWebsocket(requester.Transport(), url) } return client, nil } +func (client *Client) Requester() Requester { + return client.requester +} + func (client *Client) getTaskWebsocket(taskID, websocketID string) (clientWebsocket, error) { url := fmt.Sprintf("ws://localhost/v1/tasks/%s/websocket/%s", taskID, websocketID) return client.getWebsocket(url) @@ -181,10 +208,7 @@ func getWebsocket(transport *http.Transport, url string) (clientWebsocket, error // CloseIdleConnections closes any API connections that are currently unused. func (client *Client) CloseIdleConnections() { - c, ok := client.doer.(*http.Client) - if ok { - c.CloseIdleConnections() - } + client.Requester().Transport().CloseIdleConnections() } // Maintenance returns an error reflecting the daemon maintenance status or nil. @@ -219,27 +243,24 @@ func (e ConnectionError) Unwrap() error { return e.error } -// raw performs a request and returns the resulting http.Response and -// error you usually only need to call this directly if you expect the -// response to not be JSON, otherwise you'd call Do(...) instead. -func (client *Client) raw(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { +func (rq *defaultRequester) dispatch(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { // fake a url to keep http.Client happy - u := client.baseURL - u.Path = path.Join(client.baseURL.Path, urlpath) + u := rq.baseURL + u.Path = path.Join(rq.baseURL.Path, urlpath) u.RawQuery = query.Encode() req, err := http.NewRequestWithContext(ctx, method, u.String(), body) if err != nil { return nil, RequestError{err} } - if client.userAgent != "" { - req.Header.Set("User-Agent", client.userAgent) + if rq.userAgent != "" { + req.Header.Set("User-Agent", rq.userAgent) } for key, value := range headers { req.Header.Set(key, value) } - rsp, err := client.doer.Do(req) + rsp, err := rq.doer.Do(req) if err != nil { return nil, ConnectionError{err} } @@ -265,30 +286,15 @@ func FakeDoRetry(retry, timeout time.Duration) (restore func()) { } } -type hijacked struct { - do func(*http.Request) (*http.Response, error) -} - -func (h hijacked) Do(req *http.Request) (*http.Response, error) { - return h.do(req) -} - -// Hijack lets the caller take over the raw HTTP request. -func (client *Client) Hijack(f func(*http.Request) (*http.Response, error)) { - client.doer = hijacked{f} -} - -// do performs a request and decodes the resulting json into the given -// value. It's low-level, for testing/experimenting only; you should -// usually use a higher level interface that builds on this. -func (client *Client) do(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) error { +// retry builds in a retry mechanism for GET failures. +func (rq *defaultRequester) retry(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) { retry := time.NewTicker(doRetry) defer retry.Stop() timeout := time.After(doTimeout) var rsp *http.Response var err error for { - rsp, err = client.raw(context.Background(), method, path, query, headers, body) + rsp, err = rq.dispatch(ctx, method, urlpath, query, headers, body) if err == nil || method != "GET" { break } @@ -300,17 +306,78 @@ func (client *Client) do(method, path string, query url.Values, headers map[stri break } if err != nil { - return err + return nil, err } - defer rsp.Body.Close() + return rsp, nil +} - if v != nil { - if err := decodeInto(rsp.Body, v); err != nil { - return err +// Do performs the HTTP request according to the provided options, possibly retrying GET requests +// if appropriate for the status reported by the server. +func (rq *defaultRequester) Do(ctx context.Context, opts *RequestOptions) (*RequestResponse, error) { + httpResp, err := rq.retry(ctx, opts.Method, opts.Path, opts.Query, opts.Headers, opts.Body) + if err != nil { + return nil, err + } + + // Is the result expecting a caller-managed raw body? + if opts.Type == RawRequest { + return &RequestResponse{Body: httpResp.Body}, nil + } + + defer httpResp.Body.Close() + var serverResp response + if err := decodeInto(httpResp.Body, &serverResp); err != nil { + return nil, err + } + + // Update the maintenance error state + if serverResp.Maintenance != nil { + rq.client.maintenance = serverResp.Maintenance + } else { + // We cannot assign a nil pointer of type *Error to an + // interface here because the interface is only nil if + // both the type and value is nil. + // https://go.dev/doc/faq#nil_error + rq.client.maintenance = nil + } + + // Deal with error type response + if err := serverResp.err(); err != nil { + return nil, err + } + + // At this point only sync and async type requests may exist so lets + // make sure this is the case. + switch opts.Type { + case SyncRequest: + if serverResp.Type != "sync" { + return nil, fmt.Errorf("expected sync response, got %q", serverResp.Type) + } + case AsyncRequest: + if serverResp.Type != "async" { + return nil, fmt.Errorf("expected async response for %q on %q, got %q", opts.Method, opts.Path, serverResp.Type) } + if serverResp.StatusCode != http.StatusAccepted { + return nil, fmt.Errorf("operation not accepted") + } + if serverResp.Change == "" { + return nil, fmt.Errorf("async response without change reference") + } + default: + return nil, fmt.Errorf("cannot process unknown request type") } - return nil + // Warnings are only included if not an error type response, so we don't + // replace valid local warnings with an empty state that comes from a failure. + rq.client.warningCount = serverResp.WarningCount + rq.client.warningTimestamp = serverResp.WarningTimestamp + + // Common response + return &RequestResponse{ + StatusCode: serverResp.StatusCode, + ChangeID: serverResp.Change, + Result: serverResp.Result, + }, nil } func decodeInto(reader io.Reader, v interface{}) error { @@ -326,66 +393,48 @@ func decodeInto(reader io.Reader, v interface{}) error { return nil } -// doSync performs a request to the given path using the specified HTTP method. -// It expects a "sync" response from the API and on success decodes the JSON -// response payload into the given value using the "UseNumber" json decoding -// which produces json.Numbers instead of float64 types for numbers. -func (client *Client) doSync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (*ResultInfo, error) { - var rsp response - if err := client.do(method, path, query, headers, body, &rsp); err != nil { - return nil, err - } - if err := rsp.err(client); err != nil { +func (client *Client) doSync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (*RequestResponse, error) { + resp, err := client.Requester().Do(context.Background(), &RequestOptions{ + Type: SyncRequest, + Method: method, + Path: path, + Query: query, + Headers: headers, + Body: body, + }) + if err != nil { return nil, err } - if rsp.Type != "sync" { - return nil, fmt.Errorf("expected sync response, got %q", rsp.Type) - } - if v != nil { - if err := decodeWithNumber(bytes.NewReader(rsp.Result), v); err != nil { - return nil, fmt.Errorf("cannot unmarshal: %w", err) + err = resp.DecodeResult(v) + if err != nil { + return nil, err } } - - client.warningCount = rsp.WarningCount - client.warningTimestamp = rsp.WarningTimestamp - - return &rsp.ResultInfo, nil -} - -func (client *Client) doAsync(method, path string, query url.Values, headers map[string]string, body io.Reader) (changeID string, err error) { - _, changeID, err = client.doAsyncFull(method, path, query, headers, body) - return + return resp, nil } -func (client *Client) doAsyncFull(method, path string, query url.Values, headers map[string]string, body io.Reader) (result json.RawMessage, changeID string, err error) { - var rsp response - - if err := client.do(method, path, query, headers, body, &rsp); err != nil { - return nil, "", err - } - if err := rsp.err(client); err != nil { - return nil, "", err - } - if rsp.Type != "async" { - return nil, "", fmt.Errorf("expected async response for %q on %q, got %q", method, path, rsp.Type) - } - if rsp.StatusCode != 202 { - return nil, "", fmt.Errorf("operation not accepted") +func (client *Client) doAsync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (*RequestResponse, error) { + resp, err := client.Requester().Do(context.Background(), &RequestOptions{ + Type: AsyncRequest, + Method: method, + Path: path, + Query: query, + Headers: headers, + Body: body, + }) + if err != nil { + return nil, err } - if rsp.Change == "" { - return nil, "", fmt.Errorf("async response without change reference") + if v != nil { + err = resp.DecodeResult(v) + if err != nil { + return nil, err + } } - - return rsp.Result, rsp.Change, nil + return resp, nil } -// ResultInfo is empty for now, but this is the mechanism that conveys -// general information that makes sense to requests at a more general -// level, and might be disconnected from the specific request at hand. -type ResultInfo struct{} - // A response produced by the REST API will usually fit in this // (exceptions are the icons/ endpoints obvs) type response struct { @@ -398,8 +447,6 @@ type response struct { WarningCount int `json:"warning-count"` WarningTimestamp time.Time `json:"warning-timestamp"` - ResultInfo - Maintenance *Error `json:"maintenance"` } @@ -427,16 +474,8 @@ const ( ErrorKindDaemonRestart = "daemon-restart" ) -func (rsp *response) err(cli *Client) error { - if cli != nil { - maintErr := rsp.Maintenance - // avoid setting to (*client.Error)(nil) - if maintErr != nil { - cli.maintenance = maintErr - } else { - cli.maintenance = nil - } - } +// err extracts the error in case of an error type response +func (rsp *response) err() error { if rsp.Type != "error" { return nil } @@ -461,7 +500,7 @@ func parseError(r *http.Response) error { return fmt.Errorf("cannot unmarshal error: %w", err) } - err := rsp.err(nil) + err := rsp.err() if err == nil { return fmt.Errorf("server error: %q", r.Status) } @@ -515,3 +554,44 @@ func (client *Client) DebugGet(action string, result interface{}, params map[str _, err := client.doSync("GET", "/v1/debug", urlParams, nil, nil, &result) return err } + +type defaultRequester struct { + baseURL url.URL + doer doer + userAgent string + transport *http.Transport + client *Client +} + +func newDefaultRequester(client *Client, opts *Config) (*defaultRequester, error) { + if opts == nil { + opts = &Config{} + } + + var requester *defaultRequester + + if opts.BaseURL == "" { + // By default talk over a unix socket. + transport := &http.Transport{Dial: unixDialer(opts.Socket), DisableKeepAlives: opts.DisableKeepAlive} + baseURL := url.URL{Scheme: "http", Host: "localhost"} + requester = &defaultRequester{baseURL: baseURL, transport: transport} + } else { + // Otherwise talk regular HTTP-over-TCP. + baseURL, err := url.Parse(opts.BaseURL) + if err != nil { + return nil, fmt.Errorf("cannot parse base URL: %w", err) + } + transport := &http.Transport{DisableKeepAlives: opts.DisableKeepAlive} + requester = &defaultRequester{baseURL: *baseURL, transport: transport} + } + + requester.doer = &http.Client{Transport: requester.transport} + requester.userAgent = opts.UserAgent + requester.client = client + + return requester, nil +} + +func (rq *defaultRequester) Transport() *http.Transport { + return rq.transport +} diff --git a/client/exec.go b/client/exec.go index 88d405bd4..5b9d3d5ae 100644 --- a/client/exec.go +++ b/client/exec.go @@ -152,14 +152,10 @@ func (client *Client) Exec(opts *ExecOptions) (*ExecProcess, error) { headers := map[string]string{ "Content-Type": "application/json", } - resultBytes, changeID, err := client.doAsyncFull("POST", "/v1/exec", nil, headers, &body) - if err != nil { - return nil, err - } var result execResult - err = json.Unmarshal(resultBytes, &result) + resp, err := client.doAsync("POST", "/v1/exec", nil, headers, &body, &result) if err != nil { - return nil, fmt.Errorf("cannot unmarshal JSON response: %w", err) + return nil, err } // Connect to the "control" websocket. @@ -211,7 +207,7 @@ func (client *Client) Exec(opts *ExecOptions) (*ExecProcess, error) { }() process := &ExecProcess{ - changeID: changeID, + changeID: resp.ChangeID, client: client, timeout: opts.Timeout, writesDone: writesDone, diff --git a/client/export_test.go b/client/export_test.go index ec509df21..be4b7a14a 100644 --- a/client/export_test.go +++ b/client/export_test.go @@ -15,6 +15,7 @@ package client import ( + "context" "fmt" "io" "net/url" @@ -26,19 +27,37 @@ var ( ) func (client *Client) SetDoer(d doer) { - client.doer = d + client.Requester().(*defaultRequester).doer = d } +// TODO: Clean up tests to use the new Requester API. Tests do not generate a client.response type +// reply in the body while SyncRequest or AsyncRequest responses assume the JSON body can be +// unmarshalled into client.response. func (client *Client) Do(method, path string, query url.Values, body io.Reader, v interface{}) error { - return client.do(method, path, query, nil, body, v) + resp, err := client.Requester().Do(context.Background(), &RequestOptions{ + Type: RawRequest, + Method: method, + Path: path, + Query: query, + Headers: nil, + Body: body, + }) + if err != nil { + return err + } + err = decodeInto(resp.Body, v) + if err != nil { + return err + } + return nil } func (client *Client) FakeAsyncRequest() (changeId string, err error) { - changeId, err = client.doAsync("GET", "/v1/async-test", nil, nil, nil) + resp, err := client.doAsync("GET", "/v1/async-test", nil, nil, nil, nil) if err != nil { return "", fmt.Errorf("cannot do async test: %v", err) } - return changeId, nil + return resp.ChangeID, nil } func (client *Client) SetGetWebsocket(f getWebsocketFunc) { diff --git a/client/logs.go b/client/logs.go index 408b3bc4b..21581652d 100644 --- a/client/logs.go +++ b/client/logs.go @@ -73,13 +73,18 @@ func (client *Client) logs(ctx context.Context, opts *LogsOptions, follow bool) if follow { query.Set("follow", "true") } - res, err := client.raw(ctx, "GET", "/v1/logs", query, nil, nil) + resp, err := client.Requester().Do(ctx, &RequestOptions{ + Type: RawRequest, + Method: "GET", + Path: "/v1/logs", + Query: query, + }) if err != nil { return err } - defer res.Body.Close() + defer resp.Body.Close() - reader := bufio.NewReaderSize(res.Body, logReaderSize) + reader := bufio.NewReaderSize(resp.Body, logReaderSize) for { err = decodeLog(reader, opts.WriteLog) if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { diff --git a/client/services.go b/client/services.go index 0be19d23e..0143cf627 100644 --- a/client/services.go +++ b/client/services.go @@ -30,33 +30,33 @@ type ServiceOptions struct { // AutoStart starts the services makes as "startup: enabled". opts.Names must // be empty for this call. func (client *Client) AutoStart(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("autostart", opts.Names) + changeID, err = client.doMultiServiceAction("autostart", opts.Names) return changeID, err } // Start starts the services named in opts.Names in dependency order. func (client *Client) Start(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("start", opts.Names) + changeID, err = client.doMultiServiceAction("start", opts.Names) return changeID, err } // Stop stops the services named in opts.Names in dependency order. func (client *Client) Stop(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("stop", opts.Names) + changeID, err = client.doMultiServiceAction("stop", opts.Names) return changeID, err } // Restart stops and then starts the services named in opts.Names in // dependency order. func (client *Client) Restart(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("restart", opts.Names) + changeID, err = client.doMultiServiceAction("restart", opts.Names) return changeID, err } // Replan stops and (re)starts the services whose configuration has changed // since they were started. opts.Names must be empty for this call. func (client *Client) Replan(opts *ServiceOptions) (changeID string, err error) { - _, changeID, err = client.doMultiServiceAction("replan", opts.Names) + changeID, err = client.doMultiServiceAction("replan", opts.Names) return changeID, err } @@ -65,19 +65,24 @@ type multiActionData struct { Services []string `json:"services"` } -func (client *Client) doMultiServiceAction(actionName string, services []string) (result json.RawMessage, changeID string, err error) { +func (client *Client) doMultiServiceAction(actionName string, services []string) (changeID string, err error) { action := multiActionData{ Action: actionName, Services: services, } data, err := json.Marshal(&action) if err != nil { - return nil, "", fmt.Errorf("cannot marshal multi-service action: %w", err) + return "", fmt.Errorf("cannot marshal multi-service action: %w", err) } headers := map[string]string{ "Content-Type": "application/json", } - return client.doAsyncFull("POST", "/v1/services", nil, headers, bytes.NewBuffer(data)) + + resp, err := client.doAsync("POST", "/v1/services", nil, headers, bytes.NewBuffer(data), nil) + if err != nil { + return "", err + } + return resp.ChangeID, nil } type ServicesOptions struct { diff --git a/internals/cli/cmd_run.go b/internals/cli/cmd_run.go index e32bb4b36..4ba5774f5 100644 --- a/internals/cli/cmd_run.go +++ b/internals/cli/cmd_run.go @@ -239,7 +239,9 @@ out: } } - // Close our own self-connection, otherwise it prevents fast and clean termination. + // Close the client idle connection to the server (self connection) before we + // start with the HTTP shutdown process. This will speed up the server shutdown, + // and allow the Pebble process to exit faster. rcmd.client.CloseIdleConnections() return d.Stop(ch)