diff --git a/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift b/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift index 10c81adb..51575633 100644 --- a/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift +++ b/Sources/NIOHTTPTypesHTTP2/HTTP2ToHTTPCodec.swift @@ -152,6 +152,14 @@ public final class HTTP2FramePayloadToHTTPClientCodec: ChannelDuplexHandler, Rem context.fireErrorCaught(error) } } + + public func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + if let ev = event as? NIOHTTP2FramePayloadToHTTPEvent, let code = ev.reset { + context.writeAndFlush(self.wrapOutboundOut(.rstStream(code)), promise: promise) + return + } + context.triggerUserOutboundEvent(event, promise: promise) + } } // MARK: - Server @@ -262,4 +270,34 @@ public final class HTTP2FramePayloadToHTTPServerCodec: ChannelDuplexHandler, Rem let transformedPayload = self.baseCodec.processOutboundData(responsePart, allocator: context.channel.allocator) context.write(self.wrapOutboundOut(transformedPayload), promise: promise) } + + public func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { + if let ev = event as? NIOHTTP2FramePayloadToHTTPEvent, let code = ev.reset { + context.writeAndFlush(self.wrapOutboundOut(.rstStream(code)), promise: promise) + return + } + context.triggerUserOutboundEvent(event, promise: promise) + } +} + +/// Events that can be sent by the application to be handled by the `HTTP2StreamChannel` +public struct NIOHTTP2FramePayloadToHTTPEvent: Hashable, Sendable { + private enum Kind: Hashable, Sendable { + case reset(HTTP2ErrorCode) + } + + private var kind: Kind + + /// Send a `RST_STREAM` with the specified code + public static func reset(code: HTTP2ErrorCode) -> Self { + .init(kind: .reset(code)) + } + + /// Returns reset code if the event is a reset + public var reset: HTTP2ErrorCode? { + switch self.kind { + case .reset(let code): + return code + } + } } diff --git a/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift b/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift index 769b7255..8030ba5a 100644 --- a/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift +++ b/Tests/NIOHTTPTypesHTTP2Tests/NIOHTTPTypesHTTP2Tests.swift @@ -116,9 +116,16 @@ final class NIOHTTPTypesHTTP2Tests: XCTestCase { try self.channel.writeOutbound(HTTPRequestPart.head(Self.request)) try self.channel.writeOutbound(HTTPRequestPart.end(Self.trailers)) + try self.channel.triggerUserOutboundEvent(NIOHTTP2FramePayloadToHTTPEvent.reset(code: .enhanceYourCalm)).wait() XCTAssertEqual(try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self)?.headers, Self.oldRequest) XCTAssertEqual(try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self)?.headers, Self.oldTrailers) + switch try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self) { + case .rstStream(.enhanceYourCalm): + break + default: + XCTFail("expected reset") + } try self.channel.writeInbound(HTTP2Frame.FramePayload(headers: Self.oldResponse)) try self.channel.writeInbound(HTTP2Frame.FramePayload(headers: Self.oldTrailers)) @@ -142,9 +149,16 @@ final class NIOHTTPTypesHTTP2Tests: XCTestCase { try self.channel.writeOutbound(HTTPResponsePart.head(Self.response)) try self.channel.writeOutbound(HTTPResponsePart.end(Self.trailers)) + try self.channel.triggerUserOutboundEvent(NIOHTTP2FramePayloadToHTTPEvent.reset(code: .enhanceYourCalm)).wait() XCTAssertEqual(try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self)?.headers, Self.oldResponse) XCTAssertEqual(try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self)?.headers, Self.oldTrailers) + switch try self.channel.readOutbound(as: HTTP2Frame.FramePayload.self) { + case .rstStream(.enhanceYourCalm): + break + default: + XCTFail("expected reset") + } XCTAssertTrue(try self.channel.finish().isClean) }