Skip to content

Commit

Permalink
Add the ability to wrap the http.RoundTripper from Go code
Browse files Browse the repository at this point in the history
Signed-off-by: Evan Anderson <[email protected]>
  • Loading branch information
evankanderson committed Nov 19, 2024
1 parent 0d50f52 commit 4dfce4b
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 16 deletions.
12 changes: 12 additions & 0 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ type EvalContext struct {
interQueryBuiltinValueCache cache.InterQueryValueCache
ndBuiltinCache builtins.NDBCache
resolvers []refResolver
httpRoundTrip topdown.CustomizeRoundTrip
sortSets bool
copyMaps bool
printHook print.Hook
Expand Down Expand Up @@ -335,6 +336,13 @@ func EvalResolver(ref ast.Ref, r resolver.Resolver) EvalOption {
}
}

// EvalHTTPRoundTrip allows customizing the http.RoundTripper for this evaluation.
func EvalHTTPRoundTrip(t topdown.CustomizeRoundTrip) EvalOption {
return func(e *EvalContext) {
e.httpRoundTrip = t
}
}

// EvalSortSets causes the evaluator to sort sets before returning them as JSON arrays.
func EvalSortSets(yes bool) EvalOption {
return func(e *EvalContext) {
Expand Down Expand Up @@ -2165,6 +2173,10 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
q = q.WithInput(ast.NewTerm(ectx.parsedInput))
}

if ectx.httpRoundTrip != nil {
q = q.WithHTTPRoundTrip(ectx.httpRoundTrip)
}

for i := range ectx.resolvers {
q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r)
}
Expand Down
1 change: 1 addition & 0 deletions topdown/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type (
QueryID uint64 // identifies query being evaluated
ParentID uint64 // identifies parent of query being evaluated
PrintHook print.Hook // provides callback function to use for printing
RoundTrip CustomizeRoundTrip // customize transport to use for HTTP requests
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
rand *rand.Rand // randomization source for non-security-sensitive operations
Capabilities *ast.Capabilities
Expand Down
2 changes: 2 additions & 0 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type eval struct {
tracingOpts tracing.Options
findOne bool
strictObjects bool
roundTrip CustomizeRoundTrip
}

func (e *eval) Run(iter evalIterator) error {
Expand Down Expand Up @@ -836,6 +837,7 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
PrintHook: e.printHook,
DistributedTracingOpts: e.tracingOpts,
Capabilities: capabilities,
RoundTrip: e.roundTrip,
}

eval := evalBuiltin{
Expand Down
22 changes: 16 additions & 6 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ var (

type httpSendKey string

// CustomizeRoundTrip allows customizing an existing http.Transport,
// to the returned value, which could be the same Transport or a new one.
type CustomizeRoundTrip func(*http.Transport) http.RoundTripper

const (
// httpSendBuiltinCacheKey is the key in the builtin context cache that
// points to the http.send() specific cache resides at.
Expand Down Expand Up @@ -626,23 +630,29 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
tlsConfig.RootCAs = pool
}

var transport *http.Transport
if isTLS {
if ok, parsedURL, tr := useSocket(url, &tlsConfig); ok {
client.Transport = tr
transport = tr
url = parsedURL
} else {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tlsConfig
tr.DisableKeepAlives = true
client.Transport = tr
transport = http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tlsConfig
transport.DisableKeepAlives = true
}
} else {
if ok, parsedURL, tr := useSocket(url, nil); ok {
client.Transport = tr
transport = tr
url = parsedURL
}
}

if bctx.RoundTrip != nil {
client.Transport = bctx.RoundTrip(transport)
} else if transport != nil {
client.Transport = transport
}

// check if redirects are enabled
if enableRedirect {
client.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
Expand Down
135 changes: 127 additions & 8 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) {
ruleTemplate: `p = x { http.send(%REQ%, r); x = r.body }`,
headers: map[string][]string{"Cache-Control": {"max-age=290304000, public"}},
response: `{"x": 1}`,
expectedReqCount: 1,
expectedReqCount: 2, // Partial evaluation generates a second query, so expect 2 requests
expectedInterQueryCacheHit: false,
},
{
Expand All @@ -1197,7 +1197,7 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) {
}`,
headers: map[string][]string{"Cache-Control": {"max-age=290304000, public"}},
response: `{"x": 1}`,
expectedReqCount: 1,
expectedReqCount: 2, // Partial evaluation generates a second query, so expect 2 requests
expectedInterQueryCacheHit: false,
},
{
Expand All @@ -1213,11 +1213,11 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) {
}`,
headers: map[string][]string{"Cache-Control": {"max-age=290304000, public"}},
response: `{"x": 1}`,
expectedReqCount: 1,
expectedReqCount: 1, // Inter-query cache applies across full and partial eval
expectedInterQueryCacheHit: true,
},
{
note: "http.send GET multiple (inter-query cache enabled, )",
note: "http.send GET multiple (inter-query cache enabled, server no-store)",
request: `{"method": "get", "url": "%URL%", "force_json_decode": true, "cache": true}`,
ruleTemplate: `p = x {
r1 = http.send(%REQ%)
Expand All @@ -1229,7 +1229,7 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) {
}`,
headers: map[string][]string{"Cache-Control": {"no-store"}},
response: `{"x": 1}`,
expectedReqCount: 1,
expectedReqCount: 2, // no-store means the Partial evaluation generates a second query
expectedInterQueryCacheHit: false,
},
}
Expand Down Expand Up @@ -1285,8 +1285,8 @@ func TestHTTPSendIntraQueryCaching(t *testing.T) {
runTopDownTestCase(t, data, tc.note, []string{rule}, tc.response, opts...)

// Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial
// eval first), so expect 2x the total request count the test case specified.
actualCount := len(requests) / 2
// eval first); this affects inter-query caching enabled vs disabled.
actualCount := len(requests)
if actualCount != tc.expectedReqCount {
t.Fatalf("Expected to get %d requests, got %d", tc.expectedReqCount, actualCount)
}
Expand Down Expand Up @@ -3101,7 +3101,7 @@ func TestHTTPSendCacheDefaultStatusCodesIntraQueryCache(t *testing.T) {
}
}))

defer ts.Close()
t.Cleanup(ts.Close)

t.Run("non-cacheable status code: intra-query cache", func(t *testing.T) {
base := fmt.Sprintf(`http.send({"method": "get", "url": %q, "cache": true})`, ts.URL)
Expand Down Expand Up @@ -3775,3 +3775,122 @@ func TestHTTPGetRequestAllowNet(t *testing.T) {
runTopDownTestCase(t, data, tc.note, append(tc.rules, httpSendHelperRules...), tc.expected, tc.options)
}
}

type secretTransport struct {
extraRequestHeaders http.Header

*http.Transport
}

func (st *secretTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range st.extraRequestHeaders {
// Set additional headers on the request not visible to the caller
req.Header[k] = v
}
return st.Transport.RoundTrip(req)
}

func (st *secretTransport) Transform(t *http.Transport) http.RoundTripper {
st.Transport = t.Clone()
return st
}

func TestHTTPWithCustomTransport(t *testing.T) {
// test data
body := map[string]bool{"ok": true}

// test server only returns answers when a custom header is set
var callCount int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Header.Get("secret-header") != "secret-value" {
w.WriteHeader(http.StatusForbidden)
return
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(body)
}))

defer ts.Close()

// host
serverURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
serverHost := strings.Split(serverURL.Host, ":")[0]

// expected result
expectedResult := make(map[string]interface{})
expectedResult["status"] = "200 OK"
expectedResult["status_code"] = http.StatusOK

expectedResult["body"] = body
expectedResult["raw_body"] = "{\"ok\":true}\n"

resultObj := ast.MustInterfaceToValue(expectedResult)

hostError := &Error{Code: "eval_builtin_error", Message: fmt.Sprintf("http.send: unallowed host: %s", serverHost)}
expectedError := map[string]any{"body": nil, "raw_body": "", "status": "403 Forbidden", "status_code": 403}
errorObj := ast.MustInterfaceToValue(expectedError)

rules := []string{fmt.Sprintf(
`p = x { http.send({"method": "get", "url": "%s", "force_json_decode": true}, resp); x := remove_headers(resp) }`, ts.URL)}

st := &secretTransport{
extraRequestHeaders: http.Header{"secret-header": []string{"secret-value"}},
}

// run the test
tests := []struct {
note string
rules []string
options func(*Query) *Query
expected interface{}
calls int
}{
{
"http.send transport is default",
rules,
func(q *Query) *Query {
return q
},
errorObj.String(),
1,
},
{
"http.send transport is nil",
rules,
setRoundTripper(nil),
errorObj.String(),
1,
},
{
"http.send transport adds secret header",
rules,
setRoundTripper(st.Transform),
resultObj.String(),
1,
},
{
"http.send allow_net empty, no call to RoundTrip",
rules,
setAllowNet([]string{}),
hostError,
0,
},
}

data := loadSmallTestData()

for _, tc := range tests {
startingCalls := callCount
runTopDownTestCase(t, data, tc.note, append(tc.rules, httpSendHelperRules...), tc.expected, tc.options)
// Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial
// eval first), so expect 2x the total request count the test case specified.
serverCalls := (callCount - startingCalls) / 2
if serverCalls != tc.calls {
t.Errorf("Expected %d calls to server, got %d", tc.calls, serverCalls)
}
}
}
8 changes: 8 additions & 0 deletions topdown/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Query struct {
strictBuiltinErrors bool
builtinErrorList *[]Error
strictObjects bool
roundTrip CustomizeRoundTrip
printHook print.Hook
tracingOpts tracing.Options
virtualCache VirtualCache
Expand Down Expand Up @@ -279,6 +280,12 @@ func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query {
return q
}

// WithHTTPRoundTrip configures a custom HTTP transport for built-in functions that make HTTP requests.
func (q *Query) WithHTTPRoundTrip(t CustomizeRoundTrip) *Query {
q.roundTrip = t
return q
}

func (q *Query) WithPrintHook(h print.Hook) *Query {
q.printHook = h
return q
Expand Down Expand Up @@ -561,6 +568,7 @@ func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
printHook: q.printHook,
tracingOpts: q.tracingOpts,
strictObjects: q.strictObjects,
roundTrip: q.roundTrip,
}
e.caller = e
q.metrics.Timer(metrics.RegoQueryEval).Start()
Expand Down
15 changes: 13 additions & 2 deletions topdown/topdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,12 @@ func setAllowNet(a []string) func(*Query) *Query {
}
}

func setRoundTripper(t CustomizeRoundTrip) func(*Query) *Query {
return func(q *Query) *Query {
return q.WithHTTPRoundTrip(t)
}
}

func runTopDownTestCase(t *testing.T, data map[string]interface{}, note string, rules []string, expected interface{}, options ...func(*Query) *Query) {
t.Helper()

Expand Down Expand Up @@ -2228,15 +2234,16 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler
// the result of a query against `data` because the queries need to be
// converted into rules (which would result in recursion.)
if len(path) > 0 {
runTopDownPartialTestCase(ctx, t, compiler, store, txn, inputTerm, rhs, body, requiresSort, expected)
runTopDownPartialTestCase(ctx, t, compiler, store, txn, inputTerm, rhs, body, requiresSort, expected, options...)
}
default:
t.Fatalf("Unexpected expected value type: %+v", e)
}
})
}

func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast.Compiler, store storage.Store, txn storage.Transaction, input *ast.Term, output *ast.Term, body ast.Body, requiresSort bool, expected interface{}) {
func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast.Compiler, store storage.Store, txn storage.Transaction, input *ast.Term, output *ast.Term, body ast.Body, requiresSort bool, expected interface{},
options ...func(*Query) *Query) {
t.Helper()

// add an inter-query cache
Expand Down Expand Up @@ -2286,6 +2293,10 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast.
WithInterQueryBuiltinCache(interQueryCache).
WithInterQueryBuiltinValueCache(interQueryValueCache)

for _, opt := range options {
query = opt(query)
}

qrs, err := query.Run(ctx)
if err != nil {
t.Fatal("Unexpected error on query after partial evaluation:", err)
Expand Down

0 comments on commit 4dfce4b

Please sign in to comment.