Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nagarwal4 committed Mar 27, 2024
1 parent 7ff42c5 commit 235a62f
Showing 1 changed file with 68 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,80 +42,76 @@ private enum SignatureAlgorithmHeader

public async Task Invoke(HttpContext context)
{
if (!SkipEndpoint(context))
// Retrieve headers from the request
string? signatureHeader = context.Request.Headers["x-signature"];
string? algorithmHeader = context.Request.Headers["x-algorithm"];
string? publicKeyHeader = context.Request.Headers["x-public-key"];
string? timestampHeader = context.Request.Headers["x-timestamp"];

// Make sure the headers are present
if (!Enum.TryParse<SignatureAlgorithmHeader>(algorithmHeader, ignoreCase: true, out _))
{
// Retrieve headers from the request
string? signatureHeader = context.Request.Headers["x-signature"];
string? algorithmHeader = context.Request.Headers["x-algorithm"];
string? publicKeyHeader = context.Request.Headers["x-public-key"];
string? timestampHeader = context.Request.Headers["x-timestamp"];

// Make sure the headers are present
if (!Enum.TryParse<SignatureAlgorithmHeader>(algorithmHeader, ignoreCase: true, out _))
{
_logger.LogError("Invalid algorithm header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid Header", "x-algorithm"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(signatureHeader))
{
_logger.LogError("Missing signature header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(publicKeyHeader))
{
_logger.LogError("Missing publicKey header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-public-key"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(timestampHeader)
|| !long.TryParse(timestampHeader, out var timestampHeaderLong))
{
_logger.LogError("Missing timestamp header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// Check if the timestamp is within the allowed time
if (!IsWithinAllowedTime(timestampHeaderLong))
{
_logger.LogError("Timestamp outdated");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid timestamp", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// TODO: Check if the public key is valid according to algorithm

var payloadSigningStream = await GetPayloadStream(context, timestampHeader);

// Parse the public key
var publicKeyBytes = Convert.FromBase64String(publicKeyHeader);

// Decode the signature header from Base64
var signatureBytes = Convert.FromBase64String(signatureHeader);

if (VerifySignature(publicKeyBytes, payloadSigningStream.ToArray(), signatureBytes))
{
// Signature is valid, continue with the request
await _next(context);
}
else
{
_logger.LogError("Invalid signature");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid signature", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
}
_logger.LogError("Invalid algorithm header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid Header", "x-algorithm"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(signatureHeader))
{
_logger.LogError("Missing signature header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(publicKeyHeader))
{
_logger.LogError("Missing publicKey header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-public-key"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(timestampHeader)
|| !long.TryParse(timestampHeader, out var timestampHeaderLong))
{
_logger.LogError("Missing timestamp header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// Check if the timestamp is within the allowed time
if (!IsWithinAllowedTime(timestampHeaderLong))
{
_logger.LogError("Timestamp outdated");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid timestamp", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// TODO: Check if the public key is valid according to algorithm

var payloadSigningStream = await GetPayloadStream(context, timestampHeader);

// Parse the public key
var publicKeyBytes = Convert.FromBase64String(publicKeyHeader);

// Decode the signature header from Base64
var signatureBytes = Convert.FromBase64String(signatureHeader);

if (VerifySignature(publicKeyBytes, payloadSigningStream.ToArray(), signatureBytes))
{
// Signature is valid, continue with the request
await _next(context);
}
else
{
_logger.LogError("Invalid signature");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid signature", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
}
await _next(context);
}

private static async Task<MemoryStream> GetPayloadStream(HttpContext context, string timestampHeader)
Expand Down Expand Up @@ -170,20 +166,6 @@ public static bool VerifySignature(byte[] publicKey, byte[] payload, byte[] sign
return false;
}
}

/// <summary>
/// Skip the middleware for specific endpoints
/// </summary>
private static bool SkipEndpoint(HttpContext context)
{
var endpoint = context.GetEndpoint();
var endpointName = endpoint?.Metadata.GetMetadata<EndpointNameMetadata>()?.EndpointName;

var excludeList = new[] { "SendOTPCodeEmail" };

return context.Request.Path.StartsWithSegments("/health")
|| excludeList.Contains(endpointName);
}
}

public static class SignatureVerificationMiddlewareExtensions
Expand Down

0 comments on commit 235a62f

Please sign in to comment.