Skip to content

Commit

Permalink
Remove CSRF checking middleware
Browse files Browse the repository at this point in the history
The remaining two endpoints that were checking the CSRF token were
both unauthenticated requests. We don't need a CSRF token here because
we require Content-Type: application/json for these requests.
  • Loading branch information
zmb3 committed Dec 17, 2024
1 parent 17ffecc commit 93df810
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 116 deletions.
19 changes: 1 addition & 18 deletions lib/httplib/httplib.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/observability/tracing"
tracehttp "github.com/gravitational/teleport/api/observability/tracing/http"
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -155,23 +154,6 @@ func MakeStdHandlerWithErrorWriter(fn StdHandlerFunc, errWriter ErrorWriter) htt
}
}

// WithCSRFProtection ensures that request to unauthenticated API is checked against CSRF attacks
func WithCSRFProtection(fn HandlerFunc) httprouter.Handle {
handlerFn := MakeHandler(fn)
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
errHeader := csrf.VerifyHTTPHeader(r)
errForm := csrf.VerifyFormField(r)
if errForm != nil && errHeader != nil {
slog.WarnContext(r.Context(), "unable to validate CSRF token", "header_error", errHeader, "form_error", errForm)
trace.WriteError(w, trace.AccessDenied("access denied"))
return
}
}
handlerFn(w, r, p)
}
}

// ReadJSON reads HTTP json request and unmarshals it
// into passed any obj. A reasonable maximum size is enforced
// to mitigate resource exhaustion attacks.
Expand All @@ -188,6 +170,7 @@ func ReadResourceJSON(r *http.Request, val any) error {

func readJSON(r *http.Request, val any, maxSize int64) error {
// Check content type to mitigate CSRF attack.
// (Form POST requests don't support application/json payloads.)
contentType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
slog.WarnContext(r.Context(), "Error parsing media type for reading JSON", "error", err)
Expand Down
29 changes: 2 additions & 27 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/sessions/app", h.WithAuth(h.createAppSession))

// Web sessions
h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.WithLimiterHandlerFunc(h.createWebSession)))
h.POST("/webapi/sessions/web", h.WithLimiter(h.createWebSession))
h.DELETE("/webapi/sessions/web", h.WithAuth(h.deleteWebSession))
h.POST("/webapi/sessions/web/renew", h.WithAuth(h.renewWebSession))
h.POST("/webapi/users", h.WithAuth(h.createUserHandle))
Expand All @@ -793,7 +793,7 @@ func (h *Handler) bindDefaultEndpoints() {
// h.GET("/webapi/users/password/token/:token", h.WithLimiter(h.getResetPasswordTokenHandle))
h.GET("/webapi/users/*wildcard", h.handleGetUserOrResetToken)

h.PUT("/webapi/users/password/token", httplib.WithCSRFProtection(h.changeUserAuthentication))
h.PUT("/webapi/users/password/token", h.WithLimiter(h.changeUserAuthentication))
h.PUT("/webapi/users/password", h.WithAuth(h.changePassword))
h.POST("/webapi/users/password/token", h.WithAuth(h.createResetPasswordToken))
h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle))
Expand Down Expand Up @@ -1993,7 +1993,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
}

response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{
CSRFToken: req.CSRFToken,
ConnectorID: req.ConnectorID,
CreateWebSession: true,
ClientRedirectURL: req.ClientRedirectURL,
Expand All @@ -2003,7 +2002,6 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr
if err != nil {
logger.WithError(err).Error("Error creating auth request.")
return client.LoginFailedRedirectURL

}

return response.RedirectURL
Expand Down Expand Up @@ -4704,21 +4702,6 @@ func (h *Handler) WithSession(fn ContextHandler) httprouter.Handle {
})
}

// WithAuthCookieAndCSRF ensures that a request is authenticated
// for plain old non-AJAX requests (does not check the Bearer header).
// It enforces CSRF checks (except for "safe" methods).
func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
f := func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
sctx, err := h.AuthenticateRequest(w, r, false)
if err != nil {
return nil, trace.Wrap(err)
}
return fn(w, r, p, sctx)
}

return httplib.WithCSRFProtection(f)
}

// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests.
// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
// you're certain that no authenticated requests will be made.
Expand Down Expand Up @@ -5053,8 +5036,6 @@ type SSORequestParams struct {
// ConnectorID identifies the SSO connector to use to log in, from
// the connector_id query parameter.
ConnectorID string
// CSRFToken is the token in the CSRF cookie header.
CSRFToken string
}

// ParseSSORequestParams extracts the SSO request parameters from an http.Request,
Expand Down Expand Up @@ -5087,15 +5068,9 @@ func ParseSSORequestParams(r *http.Request) (*SSORequestParams, error) {
return nil, trace.BadParameter("missing connector_id query parameter")
}

csrfToken, err := csrf.ExtractTokenFromCookie(r)
if err != nil {
return nil, trace.Wrap(err)
}

return &SSORequestParams{
ClientRedirectURL: clientRedirectURL,
ConnectorID: connectorID,
CSRFToken: csrfToken,
}, nil
}

Expand Down
64 changes: 11 additions & 53 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ import (
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/httplib/csrf"
"github.com/gravitational/teleport/lib/inventory"
kubeproxy "github.com/gravitational/teleport/lib/kube/proxy"
"github.com/gravitational/teleport/lib/limiter"
Expand Down Expand Up @@ -947,50 +946,32 @@ func TestWebSessionsCRUD(t *testing.T) {
func TestCSRF(t *testing.T) {
t.Parallel()
s := newWebSuite(t)
type input struct {
reqToken string
cookieToken string
}

// create a valid user
user := "csrfuser"
pass := "abcdef123456"
otpSecret := newOTPSharedSecret()
s.createUser(t, user, user, pass, otpSecret)

encodedToken1 := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
encodedToken2 := "bf355921bbf3ef3672a03e410d4194077dfa5fe863c652521763b3e7f81e7b11"
invalid := []input{
{reqToken: encodedToken2, cookieToken: encodedToken1},
{reqToken: "", cookieToken: encodedToken1},
{reqToken: "", cookieToken: ""},
{reqToken: encodedToken1, cookieToken: ""},
}

clt := s.client(t)
ctx := context.Background()

// valid
validReq := loginWebOTPParams{
webClient: clt,
clock: s.clock,
user: user,
password: pass,
otpSecret: otpSecret,
cookieCSRF: &encodedToken1,
headerCSRF: &encodedToken1,
webClient: clt,
clock: s.clock,
user: user,
password: pass,
otpSecret: otpSecret,
}
loginWebOTP(t, ctx, validReq)

// invalid
for i := range invalid {
req := validReq
req.cookieCSRF = &invalid[i].cookieToken
req.headerCSRF = &invalid[i].reqToken
httpResp, _, err := rawLoginWebOTP(ctx, req)
require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
assert.Equal(t, http.StatusForbidden, httpResp.StatusCode, "HTTP status code mismatch")
}
// invalid - wrong content-type header
invalidReq := validReq
invalidReq.overrideContentType = "multipart/form-data"
httpResp, _, err := rawLoginWebOTP(ctx, invalidReq)
require.NoError(t, err, "Login via /webapi/sessions/new failed unexpectedly")
require.Equal(t, http.StatusBadRequest, httpResp.StatusCode, "HTTP status code mismatch")
}

func TestPasswordChange(t *testing.T) {
Expand Down Expand Up @@ -5953,13 +5934,9 @@ func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) {
httpReqData, err := json.Marshal(req)
require.NoError(t, err)

// CSRF protected endpoint.
csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
httpReq, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(httpReqData))
require.NoError(t, err)
addCSRFCookieToReq(httpReq, csrfToken)
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set(csrf.HeaderName, csrfToken)
httpRes, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) {
return clt.HTTPClient().Do(httpReq)
}))
Expand Down Expand Up @@ -6104,10 +6081,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body))
require.NoError(t, err)

csrfToken, err := csrf.GenerateToken()
require.NoError(t, err)
addCSRFCookieToReq(req, csrfToken)
req.Header.Set(csrf.HeaderName, csrfToken)
req.Header.Set("Content-Type", "application/json")

re, err := clt.Client.RoundTrip(func() (*http.Response, error) {
Expand All @@ -6129,8 +6102,6 @@ func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing
func TestParseSSORequestParams(t *testing.T) {
t.Parallel()

token := "someMeaninglessTokenString"

tests := []struct {
name, url string
wantErr bool
Expand All @@ -6142,7 +6113,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
ConnectorID: "oidc",
CSRFToken: token,
},
},
{
Expand All @@ -6151,7 +6121,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc",
ConnectorID: "github",
CSRFToken: token,
},
},
{
Expand All @@ -6160,7 +6129,6 @@ func TestParseSSORequestParams(t *testing.T) {
expected: &SSORequestParams{
ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc",
ConnectorID: "saml",
CSRFToken: token,
},
},
{
Expand All @@ -6179,7 +6147,6 @@ func TestParseSSORequestParams(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest("", tc.url, nil)
require.NoError(t, err)
addCSRFCookieToReq(req, token)

params, err := ParseSSORequestParams(req)

Expand Down Expand Up @@ -7932,15 +7899,6 @@ func (s *WebSuite) url() *url.URL {
return u
}

func addCSRFCookieToReq(req *http.Request, token string) {
cookie := &http.Cookie{
Name: csrf.CookieName,
Value: token,
}

req.AddCookie(cookie)
}

func removeSpace(in string) string {
for _, c := range []string{"\n", "\r", "\t"} {
in = strings.Replace(in, c, " ", -1)
Expand Down
22 changes: 4 additions & 18 deletions lib/web/login_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package web

import (
"bytes"
"cmp"
"context"
"encoding/base32"
"encoding/json"
Expand All @@ -34,7 +35,6 @@ import (

"github.com/gravitational/teleport/lib/auth/mocku2f"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/httplib/csrf"
)

// newOTPSharedSecret returns an OTP shared secret, encoded as a base32 string.
Expand All @@ -54,9 +54,8 @@ type loginWebOTPParams struct {
// If empty then no OTP is sent in the request.
otpSecret string

userAgent string // Optional.

cookieCSRF, headerCSRF *string // Explicit CSRF tokens. Optional.
userAgent string // Optional.
overrideContentType string // Optional.
}

// DrainedHTTPResponse mimics an http.Response, but without a body.
Expand Down Expand Up @@ -124,24 +123,11 @@ func rawLoginWebOTP(ctx context.Context, params loginWebOTPParams) (resp *Draine
}

// Set assorted headers.
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Type", cmp.Or(params.overrideContentType, "application/json"))
if params.userAgent != "" {
req.Header.Set("User-Agent", params.userAgent)
}

// Set CSRF cookie and header.
const defaultCSRFToken = "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
cookieCSRF := defaultCSRFToken
if params.cookieCSRF != nil {
cookieCSRF = *params.cookieCSRF
}
addCSRFCookieToReq(req, cookieCSRF)
headerCSRF := defaultCSRFToken
if params.headerCSRF != nil {
headerCSRF = *params.headerCSRF
}
req.Header.Set(csrf.HeaderName, headerCSRF)

httpResp, err := webClient.HTTPClient().Do(req)
if err != nil {
return nil, nil, trace.Wrap(err, "do HTTP request")
Expand Down

0 comments on commit 93df810

Please sign in to comment.