From 06a1f05adcd8f16d6f304c3dff017c0e61270a3b Mon Sep 17 00:00:00 2001 From: Tsar Nikolay <nsmirnov@ics.perm.ru> Date: Fri, 18 Dec 2020 13:39:42 +0500 Subject: [PATCH] Fix #2294: batch processing loses HttpContext for main request. --- .../Batch/ODataBatchReaderExtensions.cs | 15 +-- .../Batch/DefaultODataBatchHandlerTest.cs | 96 ++++++++++++++++++- 2 files changed, 97 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs index e054e0fa57..acfe36e2b2 100644 --- a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs +++ b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs @@ -14,7 +14,6 @@ using Microsoft.AspNet.OData.Interfaces; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using Microsoft.OData; @@ -239,17 +238,9 @@ private static HttpContext CreateHttpContext(HttpContext originalContext) features[typeof(IHttpResponseFeature)] = new HttpResponseFeature(); - // Create a context from the factory or use the default context. - HttpContext context = null; - IHttpContextFactory httpContextFactory = originalContext.RequestServices.GetRequiredService<IHttpContextFactory>(); - if (httpContextFactory != null) - { - context = httpContextFactory.Create(features); - } - else - { - context = new DefaultHttpContext(features); - } + // Create a context. + // IHttpContextFactory should not be used, because it resets IHttpContextAccessor.HttpContext; + HttpContext context = new DefaultHttpContext(features); // Clone parts of the request. All other parts of the request will be // populated during batch processing. diff --git a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs index 4b7a6fcf0b..e3fc8fa8cf 100644 --- a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs +++ b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs @@ -12,11 +12,14 @@ using Microsoft.AspNet.OData.Extensions; using Microsoft.AspNet.OData.Test.Abstraction; using Microsoft.AspNet.OData.Test.Common; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Xunit; #if !NETCORE using System.Web.Http; using System.Web.Http.Routing; #else +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json; #endif @@ -727,7 +730,7 @@ public async Task SendAsync_CorrectlyHandlesCookieHeader() var changesetRef = $"changeset_{Guid.NewGuid()}"; var endpoint = "http://localhost"; - Type[] controllers = new[] { typeof(BatchTestCustomersController), typeof(BatchTestOrdersController), }; + Type[] controllers = new[] { typeof(BatchTestOrdersController), }; var server = TestServerFactory.Create(controllers, (config) => { var builder = ODataConventionModelBuilderFactory.Create(config); @@ -782,8 +785,73 @@ public async Task SendAsync_CorrectlyHandlesCookieHeader() var response = await client.SendAsync(batchRequest); ExceptionAssert.DoesNotThrow(() => response.EnsureSuccessStatusCode()); + } + + [Fact] + public async Task ProcessBatchAsync_PreservesHttpContext() + { + var batchRef = $"batch_{Guid.NewGuid()}"; + var changesetRef = $"changeset_{Guid.NewGuid()}"; + var endpoint = "http://localhost"; + + Type[] controllers = new[] { typeof(BatchTestOrdersController), }; + var server = TestServerFactory.Create( + controllers, + config => + { + var builder = ODataConventionModelBuilderFactory.Create(config); + builder.EntitySet<BatchTestOrder>("BatchTestOrders"); + + config.MapODataServiceRoute("odata", null, builder.GetEdmModel(), new CustomODataBatchHandler()); + config.Expand(); + config.EnableDependencyInjection(); + }, + config => + { + config.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>(); + }); + + var client = TestServerFactory.CreateClient(server); + + var orderId = 2; + var createOrderPayload = $@"{{""@odata.type"":""Microsoft.AspNet.OData.Test.Batch.BatchTestOrder"",""Id"":{orderId},""Amount"":50}}"; + + var batchRequest = new HttpRequestMessage(HttpMethod.Post, $"{endpoint}/$batch"); + batchRequest.Headers.Accept.Add(MediaTypeWithQualityHeaderValue.Parse("text/plain")); + + var batchContent = $@" +--{batchRef} +Content-Type: multipart/mixed;boundary={changesetRef} + +--{changesetRef} +Content-Type: application/http +Content-Transfer-Encoding: binary +Content-ID: 1 + +POST {endpoint}/BatchTestOrders HTTP/1.1 +Content-Type: application/json;type=entry +Prefer: return=representation - // TODO: assert somehow? +{createOrderPayload} +--{changesetRef}-- +--{batchRef} +Content-Type: application/http +Content-Transfer-Encoding: binary + +GET {endpoint}/BatchTestOrders({orderId}) HTTP/1.1 +Content-Type: application/json;type=entry +Prefer: return=representation + +--{batchRef}-- +"; + + var httpContent = new StringContent(batchContent); + httpContent.Headers.ContentType = MediaTypeHeaderValue.Parse($"multipart/mixed;boundary={batchRef}"); + httpContent.Headers.ContentLength = batchContent.Length; + batchRequest.Content = httpContent; + var response = await client.SendAsync(batchRequest); + + ExceptionAssert.DoesNotThrow(() => response.EnsureSuccessStatusCode()); } #endif } @@ -824,6 +892,13 @@ public class BatchTestOrder return new List<BatchTestOrder> { order01 }; }); + + [EnableQuery] + public SingleResult<BatchTestOrder> Get([FromODataUri]int key) + { + return SingleResult.Create(Orders.Where(d => d.Id.Equals(key)).AsQueryable()); + } + public static IList<BatchTestOrder> Orders { get @@ -927,5 +1002,22 @@ public class BatchTestHeadersCustomer { public int Id { get; set; } } + + public class CustomODataBatchHandler : DefaultODataBatchHandler + { + /// <inheritdoc /> + public override async Task ProcessBatchAsync(HttpContext context, RequestDelegate nextHandler) + { + // Retrieve current httpcontext. + var httpContextAccessor = context.RequestServices.GetService<IHttpContextAccessor>(); + var beforeContext = httpContextAccessor?.HttpContext; + await base.ProcessBatchAsync(context, nextHandler); + var afterContext = httpContextAccessor?.HttpContext; + if (httpContextAccessor != null) + { + Assert.Equal(beforeContext, afterContext); + } + } + } #endif }