Skip to content

Commit

Permalink
feat: add ExecRawWithExtensions method for retrieving extensions in r…
Browse files Browse the repository at this point in the history
…esponse (#144)

* feat: add ExecRawWithExtensions method for retrieving extensions in response

---------

Co-authored-by: Rafał Kałuski <[email protected]>
  • Loading branch information
r4fall1 and Rafał Kałuski authored Jun 28, 2024
1 parent fa97047 commit 132b131
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 15 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,20 @@ if err != nil {
err = json.Unmarshal(raw, &res)
```
Additionally, if you need information about the extensions returned in the response use `ExecRawWithExtensions`. This function returns a map with extensions as the second variable.
```Go
query := `query{something(where: { foo: { _eq: "bar" }}){id}}`

data, extensions, err := client.ExecRawWithExtensions(ctx, query, map[string]any{})
if err != nil {
panic(err)
}

// You can now use the `extensions` variable to access the extensions data
fmt.Println("Extensions:", extensions)
```
### With operation name (deprecated)
Operation name is still on API decision plan https://github.com/shurcooL/graphql/issues/12. However, in my opinion separate methods are easier choice to avoid breaking changes
Expand Down
53 changes: 38 additions & 15 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,12 @@ func (c *Client) buildAndRequest(ctx context.Context, op operationType, v interf
return nil, nil, nil, Errors{newError(ErrGraphQLEncode, err)}
}

return c.request(ctx, query, variables, optionOutput)
data, _, resp, respBuf, errs := c.request(ctx, query, variables, optionOutput)
return data, resp, respBuf, errs
}

// Request the common method that send graphql request
func (c *Client) request(ctx context.Context, query string, variables map[string]interface{}, options *constructOptionsOutput) ([]byte, *http.Response, io.Reader, Errors) {
func (c *Client) request(ctx context.Context, query string, variables map[string]interface{}, options *constructOptionsOutput) ([]byte, []byte, *http.Response, io.Reader, Errors) {
in := GraphQLRequestPayload{
Query: query,
Variables: variables,
Expand All @@ -133,7 +134,7 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(in)
if err != nil {
return nil, nil, nil, Errors{newError(ErrGraphQLEncode, err)}
return nil, nil, nil, nil, Errors{newError(ErrGraphQLEncode, err)}
}

reqReader := bytes.NewReader(buf.Bytes())
Expand All @@ -143,7 +144,7 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
if c.debug {
e = e.withRequest(request, reqReader)
}
return nil, nil, nil, Errors{e}
return nil, nil, nil, nil, Errors{e}
}
request.Header.Add("Content-Type", "application/json")

Expand All @@ -162,7 +163,7 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
if c.debug {
e = e.withRequest(request, reqReader)
}
return nil, nil, nil, Errors{e}
return nil, nil, nil, nil, Errors{e}
}
defer resp.Body.Close()

Expand All @@ -171,7 +172,7 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
if resp.Header.Get("Content-Encoding") == "gzip" {
gr, err := gzip.NewReader(r)
if err != nil {
return nil, nil, nil, Errors{newError(ErrJsonDecode, fmt.Errorf("problem trying to create gzip reader: %w", err))}
return nil, nil, nil, nil, Errors{newError(ErrJsonDecode, fmt.Errorf("problem trying to create gzip reader: %w", err))}
}
defer gr.Close()
r = gr
Expand All @@ -187,20 +188,21 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
if c.debug {
err = err.withRequest(request, reqReader)
}
return nil, nil, nil, Errors{err}
return nil, nil, nil, nil, Errors{err}
}

var out struct {
Data *json.RawMessage
Errors Errors
Data *json.RawMessage
Extensions *json.RawMessage
Errors Errors
}

// copy the response reader for debugging
var respReader *bytes.Reader
if c.debug {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, nil, Errors{newError(ErrJsonDecode, err)}
return nil, nil, nil, nil, Errors{newError(ErrJsonDecode, err)}
}
respReader = bytes.NewReader(body)
r = io.NopCloser(respReader)
Expand All @@ -218,25 +220,30 @@ func (c *Client) request(ctx context.Context, query string, variables map[string
we = we.withRequest(request, reqReader).
withResponse(resp, respReader)
}
return nil, nil, nil, Errors{we}
return nil, nil, nil, nil, Errors{we}
}

var rawData []byte
if out.Data != nil && len(*out.Data) > 0 {
rawData = []byte(*out.Data)
}

var extensions []byte
if out.Extensions != nil && len(*out.Extensions) > 0 {
extensions = []byte(*out.Extensions)
}

if len(out.Errors) > 0 {
if c.debug && (out.Errors[0].Extensions == nil || out.Errors[0].Extensions["request"] == nil) {
out.Errors[0] = out.Errors[0].
withRequest(request, reqReader).
withResponse(resp, respReader)
}

return rawData, resp, respReader, out.Errors
return rawData, extensions, resp, respReader, out.Errors
}

return rawData, resp, respReader, nil
return rawData, extensions, resp, respReader, nil
}

// do executes a single GraphQL operation.
Expand All @@ -263,7 +270,7 @@ func (c *Client) Exec(ctx context.Context, query string, v interface{}, variable
return err
}

data, resp, respBuf, errs := c.request(ctx, query, variables, optionsOutput)
data, _, resp, respBuf, errs := c.request(ctx, query, variables, optionsOutput)
return c.processResponse(v, data, resp, respBuf, errs)
}

Expand All @@ -275,13 +282,29 @@ func (c *Client) ExecRaw(ctx context.Context, query string, variables map[string
return nil, err
}

data, _, _, errs := c.request(ctx, query, variables, optionsOutput)
data, _, _, _, errs := c.request(ctx, query, variables, optionsOutput)
if len(errs) > 0 {
return data, errs
}
return data, nil
}

// Executes a pre-built query and returns the raw json message and a map with extensions (values also as raw json objects). Unlike the
// Query method you have to specify in the query the fields that you want to receive as they are not inferred from the interface. This method
// is useful if you need to build the query dynamically.
func (c *Client) ExecRawWithExtensions(ctx context.Context, query string, variables map[string]interface{}, options ...Option) ([]byte, []byte, error) {
optionsOutput, err := constructOptions(options)
if err != nil {
return nil, nil, err
}

data, ext, _, _, errs := c.request(ctx, query, variables, optionsOutput)
if len(errs) > 0 {
return data, ext, errs
}
return data, ext, nil
}

func (c *Client) processResponse(v interface{}, data []byte, resp *http.Response, respBuf io.Reader, errs Errors) error {
if len(data) > 0 {
err := jsonutil.UnmarshalGraphQL(data, v)
Expand Down
41 changes: 41 additions & 0 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,47 @@ func TestClient_Exec_QueryRaw(t *testing.T) {
}
}

// Test exec pre-built query, return raw json string and map
// with extensions
func TestClient_Exec_QueryRawWithExtensions(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) {
body := mustRead(req.Body)
if got, want := body, `{"query":"{user{id,name}}"}`+"\n"; got != want {
t.Errorf("got body: %v, want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
mustWrite(w, `{"data": {"user": {"name": "Gopher"}}, "extensions": {"id": 1, "domain": "users"}}`)
})
client := graphql.NewClient("/graphql", &http.Client{Transport: localRoundTripper{handler: mux}})

var ext struct {
ID int `graphql:"id"`
Domain string `graphql:"domain"`
}

_, extensions, err := client.ExecRawWithExtensions(context.Background(), "{user{id,name}}", map[string]interface{}{})
if err != nil {
t.Fatal(err)
}

if got := extensions; got == nil {
t.Errorf("got nil extensions: %q, want: non-nil", got)
}

err = json.Unmarshal(extensions, &ext)
if err != nil {
t.Fatal(err)
}

if got, want := ext.ID, 1; got != want {
t.Errorf("got ext.ID: %q, want: %q", got, want)
}
if got, want := ext.Domain, "users"; got != want {
t.Errorf("got ext.Domain: %q, want: %q", got, want)
}
}

// localRoundTripper is an http.RoundTripper that executes HTTP transactions
// by using handler directly, instead of going over an HTTP connection.
type localRoundTripper struct {
Expand Down

0 comments on commit 132b131

Please sign in to comment.