Skip to content

Commit

Permalink
Merge pull request #302 from max-ieremenko/bugfix/301
Browse files Browse the repository at this point in the history
non-grpc response handling in SwaggerUiMiddleware
  • Loading branch information
max-ieremenko authored Nov 30, 2024
2 parents df38894 + 8eb69f6 commit a90b531
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
// </copyright>

using System.IO.Compression;
using System.Net;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.Extensions.DependencyInjection;
using NUnit.Framework;
using ServiceModel.Grpc.AspNetCore.TestApi;
using ServiceModel.Grpc.AspNetCore.TestApi.Domain;
using ServiceModel.Grpc.TestApi.Domain;

namespace ServiceModel.Grpc.AspNetCore.NSwag;
Expand All @@ -44,6 +46,7 @@ public async Task BeforeAll()
services.AddServiceModelGrpcSwagger();
services.AddOpenApiDocument();
services.AddMvc();
services.AddTransient<CustomResponseMiddleware>();
})
.ConfigureApp(app =>
{
Expand All @@ -52,6 +55,8 @@ public async Task BeforeAll()
app.UseReDoc(); // serve ReDoc UI
app.UseServiceModelGrpcSwaggerGateway();
app.UseMiddleware<CustomResponseMiddleware>();
})
.ConfigureEndpoints(endpoints =>
{
Expand Down Expand Up @@ -123,4 +128,43 @@ public async Task Sum5ValuesAsync()

actual.ShouldBe(15);
}

[Test]
public async Task BlockingCallAsync()
{
var parameters = new Dictionary<string, object>
{
{ "x", 1 },
{ "y", "a" }
};

var actual = await _client.InvokeAsync<string>(nameof(IMultipurposeService.BlockingCallAsync), parameters);

actual.ShouldBe("a1");
}

[Test]
[TestCase(HttpStatusCode.Unauthorized, "application/grpc")]
[TestCase(HttpStatusCode.OK, "text/plain")]
public async Task NonGrpcResponseAsync(HttpStatusCode statusCode, string contentType)
{
var headers = new Dictionary<string, string>
{
{ CustomResponseMiddleware.HeaderResponseStatusCode, statusCode.ToString() },
{ CustomResponseMiddleware.HeaderContentType, contentType },
{ CustomResponseMiddleware.HeaderResponseBody, "some message" }
};

var parameters = new Dictionary<string, object>
{
{ "x", 1 },
{ "y", "a" }
};

var response = await _client.PostAsync(nameof(IMultipurposeService.BlockingCallAsync), parameters, headers);

response.StatusCode.ShouldBe(statusCode);
response.Content.Headers.ContentType!.MediaType.ShouldBe(contentType);
(await response.Content.ReadAsStringAsync()).ShouldBe("some message");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
// </copyright>

using System.IO.Compression;
using System.Net;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.OpenApi.Models;
using NUnit.Framework;
using ServiceModel.Grpc.AspNetCore.TestApi;
using ServiceModel.Grpc.AspNetCore.TestApi.Domain;
using ServiceModel.Grpc.TestApi.Domain;
using OpenApiDocument = ServiceModel.Grpc.AspNetCore.TestApi.OpenApiDocument;

Expand Down Expand Up @@ -50,6 +52,7 @@ public async Task BeforeAll()
c.EnableAnnotations(true, true);
});
services.AddMvc();
services.AddTransient<CustomResponseMiddleware>();
})
.ConfigureApp(app =>
{
Expand All @@ -60,6 +63,8 @@ public async Task BeforeAll()
});
app.UseServiceModelGrpcSwaggerGateway();
app.UseMiddleware<CustomResponseMiddleware>();
})
.ConfigureEndpoints(endpoints =>
{
Expand Down Expand Up @@ -131,4 +136,43 @@ public async Task Sum5ValuesAsync()

actual.ShouldBe(15);
}

[Test]
public async Task BlockingCallAsync()
{
var parameters = new Dictionary<string, object>
{
{ "x", 1 },
{ "y", "a" }
};

var actual = await _client.InvokeAsync<string>(nameof(IMultipurposeService.BlockingCallAsync), parameters);

actual.ShouldBe("a1");
}

[Test]
[TestCase(HttpStatusCode.Unauthorized, "application/grpc")]
[TestCase(HttpStatusCode.OK, "text/plain")]
public async Task NonGrpcResponseAsync(HttpStatusCode statusCode, string contentType)
{
var headers = new Dictionary<string, string>
{
{ CustomResponseMiddleware.HeaderResponseStatusCode, statusCode.ToString() },
{ CustomResponseMiddleware.HeaderContentType, contentType },
{ CustomResponseMiddleware.HeaderResponseBody, "some message" }
};

var parameters = new Dictionary<string, object>
{
{ "x", 1 },
{ "y", "a" }
};

var response = await _client.PostAsync(nameof(IMultipurposeService.BlockingCallAsync), parameters, headers);

response.StatusCode.ShouldBe(statusCode);
response.Content.Headers.ContentType!.MediaType.ShouldBe(contentType);
(await response.Content.ReadAsStringAsync()).ShouldBe("some message");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// <copyright>
// Copyright Max Ieremenko
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// </copyright>

using System.Net;

namespace ServiceModel.Grpc.AspNetCore.TestApi.Domain;

public sealed class CustomResponseMiddleware : IMiddleware
{
public const string HeaderResponseStatusCode = "CustomStatusCode";
public const string HeaderContentType = "CustomContentType";
public const string HeaderResponseBody = "CustomBody";

public Task InvokeAsync(HttpContext context, RequestDelegate next)
{
var statusCode = context.Request.Headers[HeaderResponseStatusCode];
if (statusCode.Count == 0)
{
return next(context);
}

context.Response.StatusCode = (int)Enum.Parse<HttpStatusCode>(statusCode.ToString());
context.Response.ContentType = context.Request.Headers[HeaderContentType].ToString();

var body = context.Request.Headers[HeaderResponseBody].ToString();
return context.Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes(body)).AsTask();
}
}
103 changes: 48 additions & 55 deletions Sources/ServiceModel.Grpc.AspNetCore.TestApi/SwaggerUiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,81 +37,74 @@ public SwaggerUiClient(

public async Task<HttpResponseHeaders> InvokeAsync(string methodName, IDictionary<string, object> parameters, IDictionary<string, string>? headers = null)
{
var endpoint = _document.GetEndpoint(_serviceName, methodName);
var contentType = _document.GetRequestContentType(endpoint);
var url = new Uri(new Uri(_hostLocation), endpoint);
var response = await PostAsync(methodName, parameters, headers).ConfigureAwait(false);

using (var client = new HttpClient())
{
HttpResponseMessage response;
response.EnsureSuccessStatusCode();
return response.Headers;
}

using (var request = new MemoryStream())
{
using (var writer = new StreamWriter(request, leaveOpen: true))
{
JsonSerializer.CreateDefault().Serialize(writer, parameters);
}
public async Task<T> InvokeAsync<T>(string methodName, IDictionary<string, object> parameters, IDictionary<string, string>? headers = null)
{
var response = await PostAsync(methodName, parameters, headers).ConfigureAwait(false);

request.Position = 0;
using (var content = new StreamContent(request))
{
content.Headers.ContentType = new MediaTypeHeaderValue(contentType);
if (headers != null)
{
foreach (var entry in headers)
{
content.Headers.Add(entry.Key, entry.Value);
}
}

response = await client.PostAsync(url, content).ConfigureAwait(false);
}
}
response.EnsureSuccessStatusCode();

response.EnsureSuccessStatusCode();
return response.Headers;
}
var content = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
return JsonSerializer.CreateDefault().Deserialize<T>(new JsonTextReader(new StreamReader(content)))!;
}

public async Task<T> InvokeAsync<T>(string methodName, IDictionary<string, object> parameters, IDictionary<string, string>? headers = null)
public async Task<HttpResponseMessage> PostAsync(string methodName, IDictionary<string, object>? parameters = null, IDictionary<string, string>? headers = null)
{
var endpoint = _document.GetEndpoint(_serviceName, methodName);
var contentType = _document.GetRequestContentType(endpoint);
var url = new Uri(new Uri(_hostLocation), endpoint);

using (var client = new HttpClient())
using var client = new HttpClient();

using var requestBody = new MemoryStream();
using (var writer = new StreamWriter(requestBody, leaveOpen: true))
{
HttpResponseMessage response;
JsonSerializer.CreateDefault().Serialize(writer, parameters);
}

using (var request = new MemoryStream())
{
using (var writer = new StreamWriter(request, leaveOpen: true))
{
JsonSerializer.CreateDefault().Serialize(writer, parameters);
}
HttpResponseMessage response;

request.Position = 0;
using (var content = new StreamContent(request))
requestBody.Position = 0;
using (var request = new StreamContent(requestBody))
{
request.Headers.ContentType = new MediaTypeHeaderValue(contentType);
if (headers != null)
{
foreach (var entry in headers)
{
content.Headers.ContentType = new MediaTypeHeaderValue(contentType);
if (headers != null)
{
foreach (var entry in headers)
{
content.Headers.Add(entry.Key, entry.Value);
}
}

response = await client.PostAsync(url, content).ConfigureAwait(false);
request.Headers.Add(entry.Key, entry.Value);
}
}

response.EnsureSuccessStatusCode();
response = await client.PostAsync(url, request).ConfigureAwait(false);
}

var result = new HttpResponseMessage(response.StatusCode);
foreach (var header in response.Headers)
{
result.Headers.Add(header.Key, header.Value);
}

var responseBody = new MemoryStream();
await using (var content = await response.Content.ReadAsStreamAsync().ConfigureAwait(false))
{
await content.CopyToAsync(responseBody).ConfigureAwait(false);
}

using (var content = await response.Content.ReadAsStreamAsync().ConfigureAwait(false))
responseBody.Position = 0;
result.Content = new StreamContent(responseBody)
{
Headers =
{
return JsonSerializer.CreateDefault().Deserialize<T>(new JsonTextReader(new StreamReader(content)))!;
ContentType = response.Content.Headers.ContentType
}
}
};

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ private async Task HandleRequestAsync(
response = await proxy.GetResponseBody().ConfigureAwait(false);
}

// non GRPC response => pass as is
if (context.Response.StatusCode != (int)HttpStatusCode.OK
|| !ProtocolConstants.MediaTypeNameGrpc.Equals(context.Response.Headers.ContentType, StringComparison.OrdinalIgnoreCase))
{
if (response.Length > 0)
{
await response.CopyToAsync(context.Response.BodyWriter.AsStream(true), context.RequestAborted).ConfigureAwait(false);
}

return;
}

context.Response.ContentType = ProtocolConstants.MediaTypeNameSwaggerResponse;
if (status.StatusCode == StatusCode.OK)
{
Expand Down

0 comments on commit a90b531

Please sign in to comment.