diff --git a/app/middleware/csrf.js b/app/middleware/csrf.js index 077e6993d..ea94e3c78 100644 --- a/app/middleware/csrf.js +++ b/app/middleware/csrf.js @@ -1,52 +1,34 @@ -'use strict' - -// NPM dependencies -const csrf = require('csrf') - -// Local dependencies const logger = require('../utils/logger')(__filename) -const { getLoggingFields } = require('../utils/logging-fields-helper') const session = require('../utils/session') const responseRouter = require('../utils/response-router') const cookies = require('../utils/cookies') +const { configureCsrfMiddleware } = require('@govuk-pay/pay-js-commons/lib/utils/middleware/csrf.middleware') -exports.csrfSetSecret = (req, _, next) => { - const csrfSecret = cookies.getSessionCsrfSecret(req) - if (!csrfSecret) { - logger.info('Setting CSRF secret for session') - cookies.setSessionVariable(req, 'csrfSecret', csrf().secretSync()) - } - next() -} +const csrfMiddleware = configureCsrfMiddleware(logger, cookies.getSessionCookieName(), 'csrfSecret', 'csrfToken') -exports.csrfTokenGeneration = (req, res, next) => { - const csrfSecret = cookies.getSessionCsrfSecret(req) - res.locals.csrf = csrf().create(csrfSecret) - next() -} +exports.setSecret = csrfMiddleware.setSecret + +exports.generateToken = csrfMiddleware.generateToken + +exports.checkToken = [checkToken, handleCsrfError] -exports.csrfCheck = (req, res, next) => { +function checkToken (req, res, next) { const chargeId = fetchAndValidateChargeId(req) if (!chargeId) { return responseRouter.response(req, res, 'UNAUTHORISED') } + csrfMiddleware.checkToken(req, res, next) +} - const sessionCsrfSecret = cookies.getSessionCsrfSecret(req) - const csrfToken = req.body.csrfToken - - if (!sessionCsrfSecret) { - responseRouter.response(req, res, 'UNAUTHORISED') - logger.warn('CSRF secret is not defined', { - ...getLoggingFields(req), - referrer: req.get('Referrer'), - url: req.originalUrl, - method: req.method - }) - } else if (!csrfValid(csrfToken, sessionCsrfSecret, req)) { - responseRouter.systemErrorResponse(req, res, 'CSRF is invalid') - } else { - next() +function handleCsrfError (err, req, res, next) { + if (err && err.name === 'CsrfError') { + if (err.message.toLowerCase().includes('csrf secret was not found')) { + return responseRouter.response(req, res, 'UNAUTHORISED') + } else { + return responseRouter.systemErrorResponse(req, res, 'CSRF is invalid') + } } + next(err) } function fetchAndValidateChargeId (req) { @@ -55,14 +37,3 @@ function fetchAndValidateChargeId (req) { } return false } - -function csrfValid (csrfToken, secret, req) { - if (!secret) { - return false - } - if (!['put', 'post'].includes(req.method.toLowerCase())) { - return true - } else { - return csrf().verify(secret, csrfToken) - } -} diff --git a/app/routes.js b/app/routes.js index 43ab0b2ed..7da8bb195 100644 --- a/app/routes.js +++ b/app/routes.js @@ -14,7 +14,7 @@ const { log } = require('./controllers/client-side-logging.controller') const paths = require('./paths.js') // Express middleware -const { csrfSetSecret, csrfCheck, csrfTokenGeneration } = require('./middleware/csrf.js') +const { setSecret, generateToken, checkToken } = require('./middleware/csrf.js') const actionName = require('./middleware/action-name.js') const stateEnforcer = require('./middleware/state-enforcer.js') const retrieveCharge = require('./middleware/retrieve-charge.js') @@ -43,9 +43,9 @@ exports.bind = function (app) { const card = paths.card const standardMiddlewareStack = [ - csrfSetSecret, - csrfCheck, - csrfTokenGeneration, + setSecret, + checkToken, + generateToken, actionName, enforceSessionCookie, retrieveCharge, diff --git a/app/utils/response-router.js b/app/utils/response-router.js index d9a7e8f06..564938f1a 100644 --- a/app/utils/response-router.js +++ b/app/utils/response-router.js @@ -239,7 +239,7 @@ const actions = { } } -exports.errorResponse = function errorReponse (req, res, reason, options = {}, error) { +exports.errorResponse = function errorResponse (req, res, reason, options = {}, error) { const action = actions.ERROR logErrorPageShown(action.view, reason, getLoggingFields(req), error) options.viewName = 'ERROR' diff --git a/test/middleware/csrf.test.js b/test/middleware/csrf.test.js index 52a0938c7..b23d13564 100644 --- a/test/middleware/csrf.test.js +++ b/test/middleware/csrf.test.js @@ -4,7 +4,7 @@ const _ = require('lodash') const expect = require('chai').expect const nock = require('nock') const helper = require('../test-helpers/test-helpers.js') -const { csrfCheck, csrfTokenGeneration } = require('../../app/middleware/csrf.js') +const { checkToken, generateToken } = require('../../app/middleware/csrf.js') describe('retrieve param test', function () { const response = { @@ -32,9 +32,6 @@ describe('retrieve param test', function () { const noCharge = _.cloneDeep(validGetRequest) delete noCharge.frontend_state.ch_foo - const noSecret = _.cloneDeep(validGetRequest) - delete noSecret.frontend_state.csrfSecret - const invalidPost = _.cloneDeep(validGetRequest) delete invalidPost.method const invalidPut = _.cloneDeep(invalidPost) @@ -45,6 +42,9 @@ describe('retrieve param test', function () { const validPost = _.cloneDeep(invalidPost) validPost.body.csrfToken = helper.csrfToken() + const noSecret = _.cloneDeep(validPost) + delete noSecret.frontend_state.csrfSecret + const validPut = _.cloneDeep(invalidPut) validPut.body.csrfToken = helper.csrfToken() @@ -67,6 +67,18 @@ describe('retrieve param test', function () { expect(resp.locals.csrf).to.not.be.undefined // eslint-disable-line } + const callCheckToken = (scenario, expectedResponse, next) => { + const [checkTokenMiddleware, handleCsrfError] = checkToken + checkTokenMiddleware(scenario, expectedResponse, (err) => { + // simulate next(err) + if (err) { + handleCsrfError(err, scenario, expectedResponse, next) + } else { + next() + } + }) + } + beforeEach(function () { status = sinon.stub(response, 'status') render = sinon.stub(response, 'render') @@ -81,38 +93,38 @@ describe('retrieve param test', function () { it('should append csrf token to response locals on get request', function () { const resp = _.cloneDeep(response) - csrfTokenGeneration(validGetRequest, resp, next) + generateToken(validGetRequest, resp, next) assertValidRequest(next, resp, status, render) }) it('should error if no charge in session', function () { const resp = _.cloneDeep(response) - csrfCheck(noCharge, resp, next) + callCheckToken(noCharge, resp, next) assertUnauthorisedRequest(next, resp, status, render) }) it('should error if no secret in session', function () { const resp = _.cloneDeep(response) - csrfCheck(noSecret, resp, next) + callCheckToken(noSecret, resp, next) assertUnauthorisedRequest(next, resp, status, render) }) it('should error if no csrfToken in post request', function () { const resp = _.cloneDeep(response) - csrfCheck(invalidPost, resp, next) + callCheckToken(invalidPost, resp, next) assertErrorRequest(next, resp, status, render) }) it('should be successful on post if valid put', function () { const resp = _.cloneDeep(response) - csrfCheck(validPut, resp, next) - csrfTokenGeneration(validGetRequest, resp, next) + callCheckToken(validPut, resp, next) + generateToken(validGetRequest, resp, next) assertValidRequest(next, resp, status, render) }) it('should error if no csrfToken in put request', function () { const resp = _.cloneDeep(response) - csrfCheck(invalidPut, resp, next) + callCheckToken(invalidPut, resp, next) assertErrorRequest(next, resp, status, render) }) })