diff --git a/backend/controller/controller.go b/backend/controller/controller.go
index 8ddd017a2c..db3d5bf5a1 100644
--- a/backend/controller/controller.go
+++ b/backend/controller/controller.go
@@ -66,12 +66,20 @@ import (
// CommonConfig between the production controller and development server.
type CommonConfig struct {
AllowOrigins []*url.URL `help:"Allow CORS requests to ingress endpoints from these origins." env:"FTL_CONTROLLER_ALLOW_ORIGIN"`
+ AllowHeaders []string `help:"Allow these headers in CORS requests. (Requires AllowOrigins)" env:"FTL_CONTROLLER_ALLOW_HEADERS"`
NoConsole bool `help:"Disable the console."`
IdleRunners int `help:"Number of idle runners to keep around (not supported in production)." default:"3"`
WaitFor []string `help:"Wait for these modules to be deployed before becoming ready." placeholder:"MODULE"`
CronJobTimeout time.Duration `help:"Timeout for cron jobs." default:"5m"`
}
+func (c *CommonConfig) Validate() error {
+ if len(c.AllowHeaders) > 0 && len(c.AllowOrigins) == 0 {
+ return fmt.Errorf("AllowOrigins must be set when AllowHeaders is used")
+ }
+ return nil
+}
+
type Config struct {
Bind *url.URL `help:"Socket to bind to." default:"http://localhost:8892" env:"FTL_CONTROLLER_BIND"`
IngressBind *url.URL `help:"Socket to bind to for ingress." default:"http://localhost:8891" env:"FTL_CONTROLLER_INGRESS_BIND"`
@@ -139,7 +147,11 @@ func Start(ctx context.Context, config Config, runnerScaling scaling.RunnerScali
ingressHandler := http.Handler(svc)
if len(config.AllowOrigins) > 0 {
- ingressHandler = cors.Middleware(slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }), ingressHandler)
+ ingressHandler = cors.Middleware(
+ slices.Map(config.AllowOrigins, func(u *url.URL) string { return u.String() }),
+ config.AllowHeaders,
+ ingressHandler,
+ )
}
g, ctx := errgroup.WithContext(ctx)
diff --git a/backend/controller/ingress/ingress_integration_test.go b/backend/controller/ingress/ingress_integration_test.go
index 2db1cd0d47..8ca60f4339 100644
--- a/backend/controller/ingress/ingress_integration_test.go
+++ b/backend/controller/ingress/ingress_integration_test.go
@@ -4,6 +4,7 @@ package ingress_test
import (
"net/http"
+ "os"
"testing"
"github.com/alecthomas/assert/v2"
@@ -16,7 +17,7 @@ func TestHttpIngress(t *testing.T) {
in.Run(t, "",
in.CopyModule("httpingress"),
in.Deploy("httpingress"),
- in.HttpCall(http.MethodGet, "/users/123/posts/456", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/users/123/posts/456", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Get"])
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
@@ -31,7 +32,7 @@ func TestHttpIngress(t *testing.T) {
assert.True(t, ok, "good_stuff is not a string: %s", repr.String(resp.JsonBody))
assert.Equal(t, "This is good stuff", goodStuff)
}),
- in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"userId": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 201, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Post"])
success, ok := resp.JsonBody["success"].(bool)
@@ -39,88 +40,140 @@ func TestHttpIngress(t *testing.T) {
assert.True(t, success)
}),
// contains aliased field
- in.HttpCall(http.MethodPost, "/users", in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodPost, "/users", nil, in.JsonData(t, in.Obj{"user_id": 123, "postId": 345}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 201, resp.Status)
}),
- in.HttpCall(http.MethodPut, "/users/123", in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodPut, "/users/123", nil, in.JsonData(t, in.Obj{"postId": "346"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Put"])
assert.Equal(t, map[string]any{}, resp.JsonBody)
}),
- in.HttpCall(http.MethodDelete, "/users/123", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodDelete, "/users/123", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"Header from FTL"}, resp.Headers["Delete"])
assert.Equal(t, map[string]any{}, resp.JsonBody)
}),
- in.HttpCall(http.MethodGet, "/queryparams?foo=bar", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/queryparams?foo=bar", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, "bar", string(resp.BodyBytes))
}),
- in.HttpCall(http.MethodGet, "/queryparams", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/queryparams", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, "No value", string(resp.BodyBytes))
}),
- in.HttpCall(http.MethodGet, "/html", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/html", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/html; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, "
HTML Page From FTL 🚀!
", string(resp.BodyBytes))
}),
- in.HttpCall(http.MethodPost, "/bytes", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodPost, "/bytes", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/octet-stream"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/empty", nil, func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/empty", nil, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, nil, resp.Headers["Content-Type"])
assert.Equal(t, nil, resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/string", []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/string", nil, []byte("Hello, World!"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Hello, World!"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/int", []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/int", nil, []byte("1234"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("1234"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/float", []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/float", nil, []byte("1234.56789"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("1234.56789"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/bool", []byte("true"), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/bool", nil, []byte("true"), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("true"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/error", nil, func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/error", nil, nil, func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 500, resp.Status)
assert.Equal(t, []string{"text/plain; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, []byte("Error from FTL"), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/array/string", in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/array/string", nil, in.JsonData(t, []string{"hello", "world"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, []string{"hello", "world"}), resp.BodyBytes)
}),
- in.HttpCall(http.MethodPost, "/array/data", in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodPost, "/array/data", nil, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, []in.Obj{{"item": "a"}, {"item": "b"}}), resp.BodyBytes)
}),
- in.HttpCall(http.MethodGet, "/typeenum", in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/typeenum", nil, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
assert.Equal(t, []string{"application/json; charset=utf-8"}, resp.Headers["Content-Type"])
assert.Equal(t, in.JsonData(t, in.Obj{"name": "A", "value": "hello"}), resp.BodyBytes)
}),
+ // CORS preflight request without CORS middleware enabled
+ in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
+ "Origin": {"http://localhost:8892"},
+ "Access-Control-Request-Method": {"GET"},
+ "Access-Control-Request-Headers": {"x-forwarded-capabilities"},
+ }, nil, func(t testing.TB, resp *in.HTTPResponse) {
+ // should not return access control headers because we have not set up cors in this controller
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
+ }),
+ )
+}
+
+// Run with CORS enabled via FTL_CONTROLLER_ALLOW_ORIGIN and FTL_CONTROLLER_ALLOW_HEADERS
+// This test is similar to TestHttpIngress above with the addition of CORS enabled in the controller.
+func TestHttpIngressWithCors(t *testing.T) {
+ os.Setenv("FTL_CONTROLLER_ALLOW_ORIGIN", "http://localhost:8892")
+ os.Setenv("FTL_CONTROLLER_ALLOW_HEADERS", "x-forwarded-capabilities")
+ in.Run(t, "",
+ in.CopyModule("httpingress"),
+ in.Deploy("httpingress"),
+ // A correct CORS preflight request
+ in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
+ "Origin": {"http://localhost:8892"},
+ "Access-Control-Request-Method": {"GET"},
+ "Access-Control-Request-Headers": {"x-forwarded-capabilities"},
+ }, nil, func(t testing.TB, resp *in.HTTPResponse) {
+ assert.Equal(t, []string{"http://localhost:8892"}, resp.Headers["Access-Control-Allow-Origin"])
+ assert.Equal(t, []string{"GET"}, resp.Headers["Access-Control-Allow-Methods"])
+ assert.Equal(t, []string{"x-forwarded-capabilities"}, resp.Headers["Access-Control-Allow-Headers"])
+ }),
+ // Not allowed headers
+ in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
+ "Origin": {"http://localhost:8892"},
+ "Access-Control-Request-Method": {"GET"},
+ "Access-Control-Request-Headers": {"moo"},
+ }, nil, func(t testing.TB, resp *in.HTTPResponse) {
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
+ }),
+ // Not allowed origin
+ in.HttpCall(http.MethodOptions, "/typeenum", map[string][]string{
+ "Origin": {"http://localhost:4444"},
+ "Access-Control-Request-Method": {"GET"},
+ "Access-Control-Request-Headers": {"x-forwarded-capabilities"},
+ }, nil, func(t testing.TB, resp *in.HTTPResponse) {
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Origin"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Methods"])
+ assert.Equal(t, nil, resp.Headers["Access-Control-Allow-Headers"])
+ }),
)
}
diff --git a/frontend/local.go b/frontend/local.go
index 6ea258410a..f7961d662a 100644
--- a/frontend/local.go
+++ b/frontend/local.go
@@ -43,5 +43,5 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO
return proxy, nil
}
- return cors.Middleware([]string{allowOrigin.String()}, proxy), nil
+ return cors.Middleware([]string{allowOrigin.String()}, nil, proxy), nil
}
diff --git a/frontend/release.go b/frontend/release.go
index d9115df8f8..913119ab0a 100644
--- a/frontend/release.go
+++ b/frontend/release.go
@@ -47,7 +47,7 @@ func Server(ctx context.Context, timestamp time.Time, publicURL *url.URL, allowO
http.ServeContent(w, r, filePath, timestamp, f.(io.ReadSeeker))
})
if allowOrigin != nil {
- handler = cors.Middleware([]string{allowOrigin.String()}, handler)
+ handler = cors.Middleware([]string{allowOrigin.String()}, nil, handler)
}
return handler, nil
}
diff --git a/go-runtime/encoding/encoding_integration_test.go b/go-runtime/encoding/encoding_integration_test.go
index b3dea04d91..d8157ba364 100644
--- a/go-runtime/encoding/encoding_integration_test.go
+++ b/go-runtime/encoding/encoding_integration_test.go
@@ -14,7 +14,7 @@ func TestHttpEncodeOmitempty(t *testing.T) {
in.Run(t, "",
in.CopyModule("omitempty"),
in.Deploy("omitempty"),
- in.HttpCall(http.MethodGet, "/get", in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
+ in.HttpCall(http.MethodGet, "/get", nil, in.JsonData(t, in.Obj{}), func(t testing.TB, resp *in.HTTPResponse) {
assert.Equal(t, 200, resp.Status)
_, ok := resp.JsonBody["mustset"]
assert.True(t, ok)
diff --git a/integration/actions.go b/integration/actions.go
index 119593e04a..693a5760ca 100644
--- a/integration/actions.go
+++ b/integration/actions.go
@@ -403,7 +403,7 @@ func JsonData(t testing.TB, body interface{}) []byte {
}
// HttpCall makes an HTTP call to the running FTL ingress endpoint.
-func HttpCall(method string, path string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action {
+func HttpCall(method string, path string, headers map[string][]string, body []byte, onResponse func(t testing.TB, resp *HTTPResponse)) Action {
return func(t testing.TB, ic TestContext) {
Infof("HTTP %s %s", method, path)
baseURL, err := url.Parse(fmt.Sprintf("http://localhost:8891"))
@@ -415,6 +415,11 @@ func HttpCall(method string, path string, body []byte, onResponse func(t testing
assert.NoError(t, err)
r.Header.Add("Content-Type", "application/json")
+ for k, vs := range headers {
+ for _, v := range vs {
+ r.Header.Add(k, v)
+ }
+ }
client := http.Client{}
resp, err := client.Do(r)
diff --git a/internal/cors/cors.go b/internal/cors/cors.go
index 7bb0672f88..a3c7cd485f 100644
--- a/internal/cors/cors.go
+++ b/internal/cors/cors.go
@@ -6,7 +6,10 @@ import (
"github.com/rs/cors"
)
-func Middleware(allowOrigins []string, next http.Handler) http.Handler {
- c := cors.New(cors.Options{AllowedOrigins: allowOrigins})
+func Middleware(allowOrigins []string, allowHeaders []string, next http.Handler) http.Handler {
+ c := cors.New(cors.Options{
+ AllowedOrigins: allowOrigins,
+ AllowedHeaders: allowHeaders,
+ })
return c.Handler(next)
}