From 06a1f05adcd8f16d6f304c3dff017c0e61270a3b Mon Sep 17 00:00:00 2001
From: Tsar Nikolay <nsmirnov@ics.perm.ru>
Date: Fri, 18 Dec 2020 13:39:42 +0500
Subject: [PATCH] Fix #2294: batch processing loses HttpContext for main
 request.

---
 .../Batch/ODataBatchReaderExtensions.cs       | 15 +--
 .../Batch/DefaultODataBatchHandlerTest.cs     | 96 ++++++++++++++++++-
 2 files changed, 97 insertions(+), 14 deletions(-)

diff --git a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs
index e054e0fa57..acfe36e2b2 100644
--- a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs
+++ b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs
@@ -14,7 +14,6 @@
 using Microsoft.AspNet.OData.Interfaces;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Features;
-using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Primitives;
 using Microsoft.OData;
 
@@ -239,17 +238,9 @@ private static HttpContext CreateHttpContext(HttpContext originalContext)
 
             features[typeof(IHttpResponseFeature)] = new HttpResponseFeature();
 
-            // Create a context from the factory or use the default context.
-            HttpContext context = null;
-            IHttpContextFactory httpContextFactory = originalContext.RequestServices.GetRequiredService<IHttpContextFactory>();
-            if (httpContextFactory != null)
-            {
-                context = httpContextFactory.Create(features);
-            }
-            else
-            {
-                context = new DefaultHttpContext(features);
-            }
+            // Create a context.
+            // IHttpContextFactory should not be used, because it resets IHttpContextAccessor.HttpContext;
+            HttpContext context = new DefaultHttpContext(features);
 
             // Clone parts of the request. All other parts of the request will be 
             // populated during batch processing.
diff --git a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs
index 4b7a6fcf0b..e3fc8fa8cf 100644
--- a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs
+++ b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs
@@ -12,11 +12,14 @@
 using Microsoft.AspNet.OData.Extensions;
 using Microsoft.AspNet.OData.Test.Abstraction;
 using Microsoft.AspNet.OData.Test.Common;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.DependencyInjection.Extensions;
 using Xunit;
 #if !NETCORE
 using System.Web.Http;
 using System.Web.Http.Routing;
 #else
+using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Mvc;
 using Newtonsoft.Json;
 #endif
@@ -727,7 +730,7 @@ public async Task SendAsync_CorrectlyHandlesCookieHeader()
             var changesetRef = $"changeset_{Guid.NewGuid()}";
             var endpoint = "http://localhost";
 
-            Type[] controllers = new[] { typeof(BatchTestCustomersController), typeof(BatchTestOrdersController), };
+            Type[] controllers = new[] { typeof(BatchTestOrdersController), };
             var server = TestServerFactory.Create(controllers, (config) =>
             {
                 var builder = ODataConventionModelBuilderFactory.Create(config);
@@ -782,8 +785,73 @@ public async Task SendAsync_CorrectlyHandlesCookieHeader()
             var response = await client.SendAsync(batchRequest);
 
             ExceptionAssert.DoesNotThrow(() => response.EnsureSuccessStatusCode());
+        }
+
+        [Fact]
+        public async Task ProcessBatchAsync_PreservesHttpContext()
+        {
+            var batchRef = $"batch_{Guid.NewGuid()}";
+            var changesetRef = $"changeset_{Guid.NewGuid()}";
+            var endpoint = "http://localhost";
+
+            Type[] controllers = new[] { typeof(BatchTestOrdersController), };
+            var server = TestServerFactory.Create(
+                controllers,
+                config =>
+                {
+                    var builder = ODataConventionModelBuilderFactory.Create(config);
+                    builder.EntitySet<BatchTestOrder>("BatchTestOrders");
+
+                    config.MapODataServiceRoute("odata", null, builder.GetEdmModel(), new CustomODataBatchHandler());
+                    config.Expand();
+                    config.EnableDependencyInjection();
+                },
+                config =>
+                {
+                    config.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();
+                });
+
+            var client = TestServerFactory.CreateClient(server);
+
+            var orderId = 2;
+            var createOrderPayload = $@"{{""@odata.type"":""Microsoft.AspNet.OData.Test.Batch.BatchTestOrder"",""Id"":{orderId},""Amount"":50}}";
+
+            var batchRequest = new HttpRequestMessage(HttpMethod.Post, $"{endpoint}/$batch");
+            batchRequest.Headers.Accept.Add(MediaTypeWithQualityHeaderValue.Parse("text/plain"));
+
+            var batchContent = $@"
+--{batchRef}
+Content-Type: multipart/mixed;boundary={changesetRef}
+
+--{changesetRef}
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: 1
+
+POST {endpoint}/BatchTestOrders HTTP/1.1
+Content-Type: application/json;type=entry
+Prefer: return=representation
 
-            // TODO: assert somehow?
+{createOrderPayload}
+--{changesetRef}--
+--{batchRef}
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+
+GET {endpoint}/BatchTestOrders({orderId}) HTTP/1.1
+Content-Type: application/json;type=entry
+Prefer: return=representation
+
+--{batchRef}--
+";
+
+            var httpContent = new StringContent(batchContent);
+            httpContent.Headers.ContentType = MediaTypeHeaderValue.Parse($"multipart/mixed;boundary={batchRef}");
+            httpContent.Headers.ContentLength = batchContent.Length;
+            batchRequest.Content = httpContent;
+            var response = await client.SendAsync(batchRequest);
+
+            ExceptionAssert.DoesNotThrow(() => response.EnsureSuccessStatusCode());
         }
 #endif
     }
@@ -824,6 +892,13 @@ public class BatchTestOrder
                 return new List<BatchTestOrder> { order01 };
             });
 
+
+        [EnableQuery]
+        public SingleResult<BatchTestOrder> Get([FromODataUri]int key)
+        {
+            return SingleResult.Create(Orders.Where(d => d.Id.Equals(key)).AsQueryable());
+        }
+
         public static IList<BatchTestOrder> Orders
         {
             get
@@ -927,5 +1002,22 @@ public class BatchTestHeadersCustomer
     {
         public int Id { get; set; }
     }
+
+    public class CustomODataBatchHandler : DefaultODataBatchHandler
+    {
+        /// <inheritdoc />
+        public override async Task ProcessBatchAsync(HttpContext context, RequestDelegate nextHandler)
+        {
+            // Retrieve current httpcontext.
+            var httpContextAccessor = context.RequestServices.GetService<IHttpContextAccessor>();
+            var beforeContext = httpContextAccessor?.HttpContext;
+            await base.ProcessBatchAsync(context, nextHandler);
+            var afterContext = httpContextAccessor?.HttpContext;
+            if (httpContextAccessor != null)
+            {
+                Assert.Equal(beforeContext, afterContext);
+            }
+        }
+    }
 #endif
 }