From 00bf8495f70167a829261e45836eac83924f2c18 Mon Sep 17 00:00:00 2001 From: Sean Marciniak Date: Thu, 24 Oct 2024 09:34:51 +1030 Subject: [PATCH] Correcting AsResponseError method --- errors.go | 29 +++++++++++++++++++++++------ errors_test.go | 13 ++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/errors.go b/errors.go index 1ef98e0..435ca63 100644 --- a/errors.go +++ b/errors.go @@ -38,12 +38,29 @@ func newResponseError(resp *http.Response, target int, targets ...int) error { } } -// IsResponseError is convenience function to see -// if it can convert into RequestError. -func IsResponseError(err error) (*ResponseError, bool) { - var re *ResponseError - if errors.As(err, &re) { - return err.(*ResponseError), true +// AsResponseError is a convenience function to check the error +// to see if it contains an `ResponseError` and returns the value with true. +// If the error was initially joined using [errors.Join], it will check each error +// within the list and return the first matching error. +func AsResponseError(err error) (*ResponseError, bool) { + // When `errors.Join` is called, it returns an error that + // matches the provided interface. + if joined, ok := err.(interface{ Unwrap() []error }); ok { + for _, err := range joined.Unwrap() { + if re, ok := AsResponseError(err); ok { + return re, ok + } + } + return nil, false + } + + for err != nil { + if re, ok := err.(*ResponseError); ok { + return re, true + } + // In case the error is wrapped using `fmt.Errorf` + // this will also account for that. + err = errors.Unwrap(err) } return nil, false } diff --git a/errors_test.go b/errors_test.go index ff3cc01..0ed14de 100644 --- a/errors_test.go +++ b/errors_test.go @@ -2,6 +2,7 @@ package signalfx import ( "errors" + "fmt" "io" "net/http" "net/url" @@ -130,12 +131,22 @@ func TestIsRequestError(t *testing.T) { err: &ResponseError{}, expected: true, }, + { + name: "joined errors", + err: errors.Join(errors.New("boom"), &ResponseError{}), + expected: true, + }, + { + name: "fmt error", + err: fmt.Errorf("check permissions: %w", &ResponseError{}), + expected: true, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, ok := IsResponseError(tc.err) + _, ok := AsResponseError(tc.err) assert.Equal(t, tc.expected, ok, "Must match the expected value") }) }