Skip to content

Commit

Permalink
[upstream-mtls] Move certificate logic inside APIcast policy
Browse files Browse the repository at this point in the history
  • Loading branch information
tkan145 committed Oct 16, 2024
1 parent 26c408d commit 09c1c25
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 75 deletions.
39 changes: 38 additions & 1 deletion gateway/src/apicast/policy/apicast/apicast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ local math = math
local setmetatable = setmetatable
local assert = assert
local table_insert = table.insert
local base = require "resty.core.base"
local get_request = base.get_request
local tls = require 'resty.tls'

local user_agent = require('apicast.user_agent')

Expand Down Expand Up @@ -156,6 +159,40 @@ function _M:export()
}
end

_M.balancer = balancer.call
function _M:balancer(context)
-- All of this happens on balancer because this is subrequest inside APICAst
--to @upstream, so the request need to be the one that connects to the
--upstreamssl_client_raw_cert0
local r = get_request()
if not r then
ngx.log(ngx.WARN, "Invalid request")
return
end

if context.upstream_certificate and context.upstream_key then
local ok, err = tls.set_upstream_cert_and_key(context.upstream_certificate, context.upstream_key)
if ok ~= nil then
ngx.log(ngx.ERR, "Certificate cannot be set correctly, err: ", err)
end
end

if context.upstream_verify then
local ok, err = tls.set_upstream_ssl_verify(true, 1)
if ok ~= nil then
ngx.log(ngx.WARN, "Cannot verify SSL upstream connection, err: ", err)
end

if not context.upstream_ca_store then
ngx.log(ngx.WARN, "Set verify without including CA certificates")
end

ok, err = tls.set_upstream_ca_cert(context.upstream_ca_store)
if ok ~= nil then
ngx.log(ngx.WARN, "Cannot set a valid trusted CA store, err: ", err)
end
end

balancer:call(context)
end

return _M
29 changes: 5 additions & 24 deletions gateway/src/apicast/policy/upstream_mtls/upstream_mtls.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
local ssl = require('ngx.ssl')
local data_url = require('resty.data_url')
local util = require 'apicast.util'
local tls = require 'resty.tls'

local pairs = pairs

Expand Down Expand Up @@ -91,29 +90,11 @@ function _M.new(config)
return self
end

-- All of this happens on balancer because this is subrequest inside APICAst
--to @upstream, so the request need to be the one that connects to the
--upstream0
function _M:balancer(context)
if self.cert and self.cert_key then
self.set_certs(self.cert, self.cert_key)
end

if not self.verify then
return
end

local ok, err = tls.set_upstream_ssl_verify(true, 1)
if ok ~= nil then
ngx.log(ngx.WARN, "Cannot verify SSL upstream connection, err: ", err)
end

if not self.ca_store then
ngx.log(ngx.WARN, "Set verify without including CA certificates")
return
end

self.set_ca_cert(self.ca_store)
function _M:access(context)
context.upstream_certificate = self.cert
context.upstream_key = self.cert_key
context.upstream_verify = self.verify
context.upstream_ca_store = self.ca_store
end

return _M
77 changes: 77 additions & 0 deletions spec/policy/apicast/apicast_spec.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
local _M = require 'apicast.policy.apicast'
local util = require("apicast.util")
local ssl = require('ngx.ssl')
local tls = require('resty.tls')
local X509_STORE = require('resty.openssl.x509.store')
local X509 = require('resty.openssl.x509')
local balancer = require('apicast.balancer')

describe('APIcast policy', function()
local ngx_on_abort_stub
Expand Down Expand Up @@ -31,6 +37,77 @@ describe('APIcast policy', function()
end)
end)

describe(".balancer", function()
local certificate_path = 't/fixtures/CA/root-ca.crt'
local certificate_key_path = 't/fixtures/CA/root-ca.key'

local certificate_content = util.read_file(certificate_path)
local key_content = util.read_file(certificate_key_path)
local ca_cert, _ = X509.parse_pem_cert(certificate_content)

local ca_store = X509_STORE.new()
ca_store:add_cert(ca_cert)

local cert = ssl.parse_pem_cert(certificate_content)
local key = ssl.parse_pem_priv_key(key_content)

before_each(function()
stub.new(balancer, 'call', function() return true end)
end)

it("correctly set certificate and key", function()
local apicast = _M.new()
local context = {
upstream_certificate = cert,
upstream_key = key,
}

spy.on(tls, "set_upstream_cert_and_key")
apicast:balancer(context)
assert.spy(tls.set_upstream_cert_and_key).was.called()
end)

it("ignore invalid certificate and key", function()
local apicast = _M.new()
local context = {
upstream_certificate = nil,
upstream_key = nil,
}

spy.on(tls, "set_upstream_cert_and_key")
apicast:balancer(context)
assert.spy(tls.set_upstream_cert_and_key).was_not.called()
end)

it("CA certificate is not used if verify is not enabled", function()
local apicast = _M.new()
local context = {
upstream_certificate = cert,
upstream_key = key,
upstream_verify = false,
upstream_ca_store = cert
}

spy.on(tls, "set_upstream_ca_store")
apicast:balancer(context)
assert.spy(tls.set_upstream_ca_store).was_not.called()
end)

it("CA certificate is used if verify is enabled", function()
local apicast = _M.new()
local context = {
upstream_certificate = cert,
upstream_key = key,
upstream_verify = true,
upstream_ca_store = ca_store.store
}

spy.on(tls, "set_upstream_ca_store")
apicast:balancer(context)
assert.spy(tls.set_upstream_ca_store).was_not.called()
end)
end)

describe('.post_action', function()
describe('when the "run_post_action" flag is set to true', function()
it('runs its logic', function()
Expand Down
72 changes: 22 additions & 50 deletions spec/policy/upstream_mtls/upstream_mtls_spec.lua
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
local upstream_mtls = require("apicast.policy.upstream_mtls")
local ssl = require('ngx.ssl')
local open = io.open

local function read_file(path)
local file = open(path, "rb")
if not file then return nil end
local content = file:read "*a" -- *a or *all reads the whole file
file:close()
return content
end
local util = require("apicast.util")

describe('Upstream MTLS policy', function()

local certificate_path = 't/fixtures/CA/root-ca.crt'
local certificate_key_path = 't/fixtures/CA/root-ca.key'

local certificate_content = read_file(certificate_path)
local certificate_content = util.read_file(certificate_path)
-- Set here the const to not use the pakcage ones, if not test will not fail
-- if changes0
local path_type = "path"
Expand Down Expand Up @@ -43,9 +35,10 @@ describe('Upstream MTLS policy', function()
assert.truthy(object.cert)
assert.truthy(object.cert_key)

spy.on(object, "set_certs")
object:balancer(context)
assert.spy(object.set_certs).was.called()
local context = {}
object:access(context)
assert.truthy(context.upstream_certificate)
assert.truthy(context.upstream_key)
end)


Expand All @@ -62,9 +55,10 @@ describe('Upstream MTLS policy', function()
assert.is_falsy(object.cert)
assert.is_falsy(object.cert_key)

spy.on(object, "set_certs")
object:balancer(context)
assert.spy(object.set_certs).was_not_called()
local context = {}
object:access(context)
assert.is_falsy(context.upstream_certificate)
assert.is_falsy(context.upstream_key)
end)
end)

Expand All @@ -86,9 +80,10 @@ describe('Upstream MTLS policy', function()
assert.truthy(object.cert)
assert.truthy(object.cert_key)

spy.on(object, "set_certs")
object:balancer(context)
assert.spy(object.set_certs).was.called()
local context = {}
object:access(context)
assert.truthy(context.upstream_certificate)
assert.truthy(context.upstream_key)
end)

it("Nil certificate", function()
Expand All @@ -103,10 +98,10 @@ describe('Upstream MTLS policy', function()
assert.spy(ssl.parse_pem_priv_key).was_not_called()
assert.falsy(object.cert)
assert.falsy(object.cert_key)

spy.on(object, "set_certs")
object:balancer(context)
assert.spy(object.set_certs).was_not_called()
local context = {}
object:access(context)
assert.is_falsy(context.upstream_certificate)
assert.is_falsy(context.upstream_key)
end)

it("Invalid certificate", function()
Expand All @@ -122,9 +117,10 @@ describe('Upstream MTLS policy', function()
assert.falsy(object.cert)
assert.falsy(object.cert_key)

spy.on(object, "set_certs")
object:balancer(context)
assert.spy(object.set_certs).was_not_called()
local context = {}
object:access(context)
assert.is_falsy(context.upstream_certificate)
assert.is_falsy(context.upstream_key)
end)

end)
Expand Down Expand Up @@ -160,30 +156,6 @@ describe('Upstream MTLS policy', function()
local object = upstream_mtls.new(config)
assert.same(type(object.ca_store), "cdata")
end)

it("CA certificate is not used if verify is not enabled", function()
config.ca_certificates = { certificate_content}
config.verify = false

local object = upstream_mtls.new(config)
assert.same(type(object.ca_store), "cdata")

spy.on(object, "set_ca_cert")
object:balancer({})
assert.spy(object.set_ca_cert).was_not.called()
end)

it("CA certificate is used if verify is enabled", function()
config.ca_certificates = { certificate_content}
config.verify = true

local object = upstream_mtls.new(config)
assert.same(type(object.ca_store), "cdata")

spy.on(object, "set_ca_cert")
object:balancer({})
assert.spy(object.set_ca_cert).was.called()
end)
end)
end)

0 comments on commit 09c1c25

Please sign in to comment.