diff --git a/src/Ocelot/Errors/Middleware/ExceptionHandlerMiddleware.cs b/src/Ocelot/Errors/Middleware/ExceptionHandlerMiddleware.cs index feb64af0d..58c189010 100644 --- a/src/Ocelot/Errors/Middleware/ExceptionHandlerMiddleware.cs +++ b/src/Ocelot/Errors/Middleware/ExceptionHandlerMiddleware.cs @@ -1,16 +1,15 @@ -using System; -using System.Linq; -using System.Threading.Tasks; -using Ocelot.Configuration.Repository; -using Ocelot.Infrastructure.Extensions; -using Ocelot.Infrastructure.RequestData; -using Ocelot.Logging; -using Ocelot.Middleware; - namespace Ocelot.Errors.Middleware { using Configuration; - + using System; + using System.Linq; + using System.Threading.Tasks; + using Ocelot.Configuration.Repository; + using Ocelot.Infrastructure.Extensions; + using Ocelot.Infrastructure.RequestData; + using Ocelot.Logging; + using Ocelot.Middleware; + /// /// Catches all unhandled exceptions thrown by middleware, logs and returns a 500 /// diff --git a/src/Ocelot/Responder/HttpContextResponder.cs b/src/Ocelot/Responder/HttpContextResponder.cs index 0b2959b4b..6511590c6 100644 --- a/src/Ocelot/Responder/HttpContextResponder.cs +++ b/src/Ocelot/Responder/HttpContextResponder.cs @@ -1,74 +1,74 @@ -using System.IO; -using System.Linq; -using System.Net; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Primitives; -using Ocelot.Headers; -using Ocelot.Middleware; - -namespace Ocelot.Responder -{ - /// - /// Cannot unit test things in this class due to methods not being implemented - /// on .net concretes used for testing - /// - public class HttpContextResponder : IHttpResponder - { - private readonly IRemoveOutputHeaders _removeOutputHeaders; - - public HttpContextResponder(IRemoveOutputHeaders removeOutputHeaders) - { - _removeOutputHeaders = removeOutputHeaders; - } - - public async Task SetResponseOnHttpContext(HttpContext context, DownstreamResponse response) - { - _removeOutputHeaders.Remove(response.Headers); - - foreach (var httpResponseHeader in response.Headers) - { - AddHeaderIfDoesntExist(context, httpResponseHeader); - } - - foreach (var httpResponseHeader in response.Content.Headers) - { - AddHeaderIfDoesntExist(context, new Header(httpResponseHeader.Key, httpResponseHeader.Value)); - } - - var content = await response.Content.ReadAsByteArrayAsync(); - - AddHeaderIfDoesntExist(context, new Header("Content-Length", new []{ content.Length.ToString() }) ); - - context.Response.OnStarting(state => - { - var httpContext = (HttpContext)state; - - httpContext.Response.StatusCode = (int)response.StatusCode; - - return Task.CompletedTask; - }, context); - - using (Stream stream = new MemoryStream(content)) - { - if (response.StatusCode != HttpStatusCode.NotModified && context.Response.ContentLength != 0) - { - await stream.CopyToAsync(context.Response.Body); - } - } - } - - public void SetErrorResponseOnContext(HttpContext context, int statusCode) +using System.IO; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Primitives; +using Ocelot.Headers; +using Ocelot.Middleware; + +namespace Ocelot.Responder +{ + /// + /// Cannot unit test things in this class due to methods not being implemented + /// on .net concretes used for testing + /// + public class HttpContextResponder : IHttpResponder + { + private readonly IRemoveOutputHeaders _removeOutputHeaders; + + public HttpContextResponder(IRemoveOutputHeaders removeOutputHeaders) { - context.Response.StatusCode = statusCode; - } - - private static void AddHeaderIfDoesntExist(HttpContext context, Header httpResponseHeader) - { - if (!context.Response.Headers.ContainsKey(httpResponseHeader.Key)) - { - context.Response.Headers.Add(httpResponseHeader.Key, new StringValues(httpResponseHeader.Values.ToArray())); - } - } - } -} + _removeOutputHeaders = removeOutputHeaders; + } + + public async Task SetResponseOnHttpContext(HttpContext context, DownstreamResponse response) + { + _removeOutputHeaders.Remove(response.Headers); + + foreach (var httpResponseHeader in response.Headers) + { + AddHeaderIfDoesntExist(context, httpResponseHeader); + } + + foreach (var httpResponseHeader in response.Content.Headers) + { + AddHeaderIfDoesntExist(context, new Header(httpResponseHeader.Key, httpResponseHeader.Value)); + } + + var content = await response.Content.ReadAsStreamAsync(); + + AddHeaderIfDoesntExist(context, new Header("Content-Length", new []{ content.Length.ToString() }) ); + + context.Response.OnStarting(state => + { + var httpContext = (HttpContext)state; + + httpContext.Response.StatusCode = (int)response.StatusCode; + + return Task.CompletedTask; + }, context); + + using(content) + { + if (response.StatusCode != HttpStatusCode.NotModified && context.Response.ContentLength != 0) + { + await content.CopyToAsync(context.Response.Body); + } + } + } + + public void SetErrorResponseOnContext(HttpContext context, int statusCode) + { + context.Response.StatusCode = statusCode; + } + + private static void AddHeaderIfDoesntExist(HttpContext context, Header httpResponseHeader) + { + if (!context.Response.Headers.ContainsKey(httpResponseHeader.Key)) + { + context.Response.Headers.Add(httpResponseHeader.Key, new StringValues(httpResponseHeader.Values.ToArray())); + } + } + } +} diff --git a/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs b/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs index 0487d811b..ec6696c55 100644 --- a/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs +++ b/src/Ocelot/Responder/Middleware/ResponderMiddleware.cs @@ -1,55 +1,55 @@ -using Microsoft.AspNetCore.Http; -using Ocelot.Errors; -using Ocelot.Logging; -using Ocelot.Middleware; -using System.Collections.Generic; -using System.Threading.Tasks; -using Ocelot.Infrastructure.Extensions; - -namespace Ocelot.Responder.Middleware -{ - /// - /// Completes and returns the request and request body, if any pipeline errors occured then sets the appropriate HTTP status code instead. - /// - public class ResponderMiddleware : OcelotMiddleware - { - private readonly OcelotRequestDelegate _next; - private readonly IHttpResponder _responder; - private readonly IErrorsToHttpStatusCodeMapper _codeMapper; - - public ResponderMiddleware(OcelotRequestDelegate next, - IHttpResponder responder, - IOcelotLoggerFactory loggerFactory, - IErrorsToHttpStatusCodeMapper codeMapper - ) - :base(loggerFactory.CreateLogger()) - { - _next = next; - _responder = responder; - _codeMapper = codeMapper; - } - - public async Task Invoke(DownstreamContext context) - { - await _next.Invoke(context); - - if (context.IsError) - { - Logger.LogWarning($"{context.Errors.ToErrorString()} errors found in {MiddlewareName}. Setting error response for request path:{context.HttpContext.Request.Path}, request method: {context.HttpContext.Request.Method}"); - - SetErrorResponse(context.HttpContext, context.Errors); - } - else - { - Logger.LogDebug("no pipeline errors, setting and returning completed response"); - await _responder.SetResponseOnHttpContext(context.HttpContext, context.DownstreamResponse); - } - } - - private void SetErrorResponse(HttpContext context, List errors) - { - var statusCode = _codeMapper.Map(errors); - _responder.SetErrorResponseOnContext(context, statusCode); - } - } -} +using Microsoft.AspNetCore.Http; +using Ocelot.Errors; +using Ocelot.Logging; +using Ocelot.Middleware; +using System.Collections.Generic; +using System.Threading.Tasks; +using Ocelot.Infrastructure.Extensions; + +namespace Ocelot.Responder.Middleware +{ + /// + /// Completes and returns the request and request body, if any pipeline errors occured then sets the appropriate HTTP status code instead. + /// + public class ResponderMiddleware : OcelotMiddleware + { + private readonly OcelotRequestDelegate _next; + private readonly IHttpResponder _responder; + private readonly IErrorsToHttpStatusCodeMapper _codeMapper; + + public ResponderMiddleware(OcelotRequestDelegate next, + IHttpResponder responder, + IOcelotLoggerFactory loggerFactory, + IErrorsToHttpStatusCodeMapper codeMapper + ) + :base(loggerFactory.CreateLogger()) + { + _next = next; + _responder = responder; + _codeMapper = codeMapper; + } + + public async Task Invoke(DownstreamContext context) + { + await _next.Invoke(context); + + if (context.IsError) + { + Logger.LogWarning($"{context.Errors.ToErrorString()} errors found in {MiddlewareName}. Setting error response for request path:{context.HttpContext.Request.Path}, request method: {context.HttpContext.Request.Method}"); + + SetErrorResponse(context.HttpContext, context.Errors); + } + else + { + Logger.LogDebug("no pipeline errors, setting and returning completed response"); + await _responder.SetResponseOnHttpContext(context.HttpContext, context.DownstreamResponse); + } + } + + private void SetErrorResponse(HttpContext context, List errors) + { + var statusCode = _codeMapper.Map(errors); + _responder.SetErrorResponseOnContext(context, statusCode); + } + } +} diff --git a/test/Ocelot.UnitTests/Errors/ExceptionHandlerMiddlewareTests.cs b/test/Ocelot.UnitTests/Errors/ExceptionHandlerMiddlewareTests.cs index ce88b4688..0ca0b659f 100644 --- a/test/Ocelot.UnitTests/Errors/ExceptionHandlerMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Errors/ExceptionHandlerMiddlewareTests.cs @@ -1,197 +1,197 @@ -namespace Ocelot.UnitTests.Errors -{ - using System; - using System.Net; - using System.Threading.Tasks; - using Ocelot.Errors.Middleware; - using Ocelot.Logging; - using Shouldly; - using TestStack.BDDfy; - using Xunit; - using Microsoft.AspNetCore.Http; - using Moq; - using Ocelot.Configuration; - using Ocelot.Errors; - using Ocelot.Infrastructure.RequestData; - using Ocelot.Middleware; - using Ocelot.Configuration.Repository; - - public class ExceptionHandlerMiddlewareTests - { - bool _shouldThrowAnException; - private readonly Mock _configRepo; - private readonly Mock _repo; - private Mock _loggerFactory; - private Mock _logger; - private readonly ExceptionHandlerMiddleware _middleware; - private readonly DownstreamContext _downstreamContext; - private OcelotRequestDelegate _next; - - public ExceptionHandlerMiddlewareTests() - { - _configRepo = new Mock(); - _repo = new Mock(); - _downstreamContext = new DownstreamContext(new DefaultHttpContext()); - _loggerFactory = new Mock(); - _logger = new Mock(); - _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); - _next = async context => { - await Task.CompletedTask; - - if (_shouldThrowAnException) - { - throw new Exception("BOOM"); - } - - context.HttpContext.Response.StatusCode = (int)HttpStatusCode.OK; - }; - _middleware = new ExceptionHandlerMiddleware(_next, _loggerFactory.Object, _configRepo.Object, _repo.Object); - } - - [Fact] - public void NoDownstreamException() - { - var config = new InternalConfiguration(null, null, null, null, null, null, null, null); - - this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) - .And(_ => GivenTheConfigurationIs(config)) - .When(_ => WhenICallTheMiddleware()) - .Then(_ => ThenTheResponseIsOk()) - .And(_ => TheAspDotnetRequestIdIsSet()) - .BDDfy(); - } - - [Fact] - public void DownstreamException() - { - var config = new InternalConfiguration(null, null, null, null, null, null, null, null); - - this.Given(_ => GivenAnExceptionWillBeThrownDownstream()) - .And(_ => GivenTheConfigurationIs(config)) - .When(_ => WhenICallTheMiddleware()) - .Then(_ => ThenTheResponseIsError()) - .BDDfy(); - } - - [Fact] - public void ShouldSetRequestId() - { - var config = new InternalConfiguration(null, null, null, "requestidkey", null, null, null, null); - - this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) - .And(_ => GivenTheConfigurationIs(config)) - .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) - .Then(_ => ThenTheResponseIsOk()) - .And(_ => TheRequestIdIsSet("RequestId", "1234")) - .BDDfy(); - } - - [Fact] - public void ShouldSetAspDotNetRequestId() - { - var config = new InternalConfiguration(null, null, null, null, null, null, null, null); - - this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) - .And(_ => GivenTheConfigurationIs(config)) - .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) - .Then(_ => ThenTheResponseIsOk()) - .And(_ => TheAspDotnetRequestIdIsSet()) - .BDDfy(); - } - - [Fact] - public void should_throw_exception_if_config_provider_returns_error() - { - this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) - .And(_ => GivenTheConfigReturnsError()) - .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) - .Then(_ => ThenAnExceptionIsThrown()) - .BDDfy(); - } - - [Fact] - public void should_throw_exception_if_config_provider_throws() - { - this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) - .And(_ => GivenTheConfigThrows()) - .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) - .Then(_ => ThenAnExceptionIsThrown()) - .BDDfy(); - } - - private void WhenICallTheMiddlewareWithTheRequestIdKey(string key, string value) - { - _downstreamContext.HttpContext.Request.Headers.Add(key, value); - _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); - } - - private void WhenICallTheMiddleware() - { - _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); - } - - private void GivenTheConfigThrows() - { - var ex = new Exception("outer", new Exception("inner")); - _configRepo - .Setup(x => x.Get()).Throws(ex); - } - - private void ThenAnExceptionIsThrown() - { - _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); - } - - private void GivenTheConfigReturnsError() - { - var response = new Responses.ErrorResponse(new FakeError()); - _configRepo - .Setup(x => x.Get()).Returns(response); - } - - private void TheRequestIdIsSet(string key, string value) - { - _repo.Verify(x => x.Add(key, value), Times.Once); - } - - private void GivenTheConfigurationIs(IInternalConfiguration config) - { - var response = new Responses.OkResponse(config); - _configRepo - .Setup(x => x.Get()).Returns(response); - } - - private void GivenAnExceptionWillNotBeThrownDownstream() - { - _shouldThrowAnException = false; - } - - private void GivenAnExceptionWillBeThrownDownstream() - { - _shouldThrowAnException = true; - } - - private void ThenTheResponseIsOk() - { - _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(200); - } - - private void ThenTheResponseIsError() - { - _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); - } - - private void TheAspDotnetRequestIdIsSet() - { - _repo.Verify(x => x.Add(It.IsAny(), It.IsAny()), Times.Once); - } - - class FakeError : Error - { - internal FakeError() - : base("meh", OcelotErrorCode.CannotAddDataError) - { - } - } - } -} +namespace Ocelot.UnitTests.Errors +{ + using System; + using System.Net; + using System.Threading.Tasks; + using Ocelot.Errors.Middleware; + using Ocelot.Logging; + using Shouldly; + using TestStack.BDDfy; + using Xunit; + using Microsoft.AspNetCore.Http; + using Moq; + using Ocelot.Configuration; + using Ocelot.Errors; + using Ocelot.Infrastructure.RequestData; + using Ocelot.Middleware; + using Ocelot.Configuration.Repository; + + public class ExceptionHandlerMiddlewareTests + { + bool _shouldThrowAnException; + private readonly Mock _configRepo; + private readonly Mock _repo; + private Mock _loggerFactory; + private Mock _logger; + private readonly ExceptionHandlerMiddleware _middleware; + private readonly DownstreamContext _downstreamContext; + private OcelotRequestDelegate _next; + + public ExceptionHandlerMiddlewareTests() + { + _configRepo = new Mock(); + _repo = new Mock(); + _downstreamContext = new DownstreamContext(new DefaultHttpContext()); + _loggerFactory = new Mock(); + _logger = new Mock(); + _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); + _next = async context => { + await Task.CompletedTask; + + if (_shouldThrowAnException) + { + throw new Exception("BOOM"); + } + + context.HttpContext.Response.StatusCode = (int)HttpStatusCode.OK; + }; + _middleware = new ExceptionHandlerMiddleware(_next, _loggerFactory.Object, _configRepo.Object, _repo.Object); + } + + [Fact] + public void NoDownstreamException() + { + var config = new InternalConfiguration(null, null, null, null, null, null, null, null); + + this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) + .And(_ => GivenTheConfigurationIs(config)) + .When(_ => WhenICallTheMiddleware()) + .Then(_ => ThenTheResponseIsOk()) + .And(_ => TheAspDotnetRequestIdIsSet()) + .BDDfy(); + } + + [Fact] + public void DownstreamException() + { + var config = new InternalConfiguration(null, null, null, null, null, null, null, null); + + this.Given(_ => GivenAnExceptionWillBeThrownDownstream()) + .And(_ => GivenTheConfigurationIs(config)) + .When(_ => WhenICallTheMiddleware()) + .Then(_ => ThenTheResponseIsError()) + .BDDfy(); + } + + [Fact] + public void ShouldSetRequestId() + { + var config = new InternalConfiguration(null, null, null, "requestidkey", null, null, null, null); + + this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) + .And(_ => GivenTheConfigurationIs(config)) + .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) + .Then(_ => ThenTheResponseIsOk()) + .And(_ => TheRequestIdIsSet("RequestId", "1234")) + .BDDfy(); + } + + [Fact] + public void ShouldSetAspDotNetRequestId() + { + var config = new InternalConfiguration(null, null, null, null, null, null, null, null); + + this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) + .And(_ => GivenTheConfigurationIs(config)) + .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) + .Then(_ => ThenTheResponseIsOk()) + .And(_ => TheAspDotnetRequestIdIsSet()) + .BDDfy(); + } + + [Fact] + public void should_throw_exception_if_config_provider_returns_error() + { + this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) + .And(_ => GivenTheConfigReturnsError()) + .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) + .Then(_ => ThenAnExceptionIsThrown()) + .BDDfy(); + } + + [Fact] + public void should_throw_exception_if_config_provider_throws() + { + this.Given(_ => GivenAnExceptionWillNotBeThrownDownstream()) + .And(_ => GivenTheConfigThrows()) + .When(_ => WhenICallTheMiddlewareWithTheRequestIdKey("requestidkey", "1234")) + .Then(_ => ThenAnExceptionIsThrown()) + .BDDfy(); + } + + private void WhenICallTheMiddlewareWithTheRequestIdKey(string key, string value) + { + _downstreamContext.HttpContext.Request.Headers.Add(key, value); + _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); + } + + private void WhenICallTheMiddleware() + { + _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); + } + + private void GivenTheConfigThrows() + { + var ex = new Exception("outer", new Exception("inner")); + _configRepo + .Setup(x => x.Get()).Throws(ex); + } + + private void ThenAnExceptionIsThrown() + { + _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); + } + + private void GivenTheConfigReturnsError() + { + var response = new Responses.ErrorResponse(new FakeError()); + _configRepo + .Setup(x => x.Get()).Returns(response); + } + + private void TheRequestIdIsSet(string key, string value) + { + _repo.Verify(x => x.Add(key, value), Times.Once); + } + + private void GivenTheConfigurationIs(IInternalConfiguration config) + { + var response = new Responses.OkResponse(config); + _configRepo + .Setup(x => x.Get()).Returns(response); + } + + private void GivenAnExceptionWillNotBeThrownDownstream() + { + _shouldThrowAnException = false; + } + + private void GivenAnExceptionWillBeThrownDownstream() + { + _shouldThrowAnException = true; + } + + private void ThenTheResponseIsOk() + { + _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(200); + } + + private void ThenTheResponseIsError() + { + _downstreamContext.HttpContext.Response.StatusCode.ShouldBe(500); + } + + private void TheAspDotnetRequestIdIsSet() + { + _repo.Verify(x => x.Add(It.IsAny(), It.IsAny()), Times.Once); + } + + class FakeError : Error + { + internal FakeError() + : base("meh", OcelotErrorCode.CannotAddDataError) + { + } + } + } +}