Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenAI] Improved logging and specific handlers #43460

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ Mono<Void> start(Runnable postStartTask) {
.flatMap(authenticationHeader -> Mono.<Void>fromRunnable(() -> {
this.webSocketSession
.set(webSocketClient.connectToServer(this.clientEndpointConfiguration, () -> authenticationHeader,
loggerReference, this::handleMessage, this::handleSessionOpen, this::handleSessionClose));
this::handleMessage, this::handleSessionOpen, this::handleSessionClose));
}))
.subscribeOn(Schedulers.boundedElastic())
.doOnError(error -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.ai.openai.realtime.implementation.websocket;

import com.azure.core.util.logging.ClientLogger;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.CharsetUtil;

/**
* Dedicated handler for server-side ping messages.
*/
public class KeepAliveHandler extends ChannelDuplexHandler {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public class KeepAliveHandler extends ChannelDuplexHandler {
public final class KeepAliveHandler extends ChannelDuplexHandler {

    private static final ClientLogger LOGGER = new ClientLogger(KeepAliveHandler.class);

    /**
     * Handles WebSocket keep-alive control frames.
     * <p>
     * A {@link PingWebSocketFrame} is answered with a {@link PongWebSocketFrame} carrying the same payload, a
     * {@link PongWebSocketFrame} is only logged, and every other message is forwarded to the next inbound
     * handler untouched.
     *
     * @param ctx the context of this handler within the channel pipeline.
     * @param msg the inbound message.
     * @throws IllegalStateException if a {@link FullHttpResponse} reaches this handler.
     * @throws Exception if forwarding the message to the next handler fails.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpResponse) {
            FullHttpResponse response = (FullHttpResponse) msg;
            throw LOGGER.logExceptionAsError(new IllegalStateException("Unexpected FullHttpResponse (getStatus="
                + response.status() + ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')'));
        }
        Channel ch = ctx.channel();

        LOGGER.atVerbose().log("Processing message: ");
        if (!(msg instanceof WebSocketFrame)) {
            // Not a WebSocket frame: nothing for this handler to do. Forward it instead of failing with a
            // ClassCastException on a blind cast.
            ctx.fireChannelRead(msg);
            return;
        }

        WebSocketFrame frame = (WebSocketFrame) msg;
        if (frame instanceof PingWebSocketFrame) {
            // Ping, reply Pong
            LOGGER.atVerbose().log(() -> "Received PingWebSocketFrame");
            LOGGER.atVerbose().log(() -> "Sending PongWebSocketFrame");
            // RFC 6455 section 5.5.3: the Pong must carry the same application data as the Ping it answers.
            // retain() hands the outbound write its own reference to the payload, so the reference held by
            // the inbound frame is left for whoever owns that frame.
            ch.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
        } else if (frame instanceof PongWebSocketFrame) {
            // Pong
            LOGGER.atVerbose().log(() -> "Received PongWebSocketFrame");
        } else {
            // Pass other frames down the pipeline
            // We only pass down the pipeline messages this handler doesn't process
            ctx.fireChannelRead(msg);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

package com.azure.ai.openai.realtime.implementation.websocket;

import com.azure.core.util.logging.ClientLogger;

import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;

Expand All @@ -14,6 +11,5 @@ public interface WebSocketClient {

WebSocketSession connectToServer(ClientEndpointConfiguration cec,
Supplier<AuthenticationProvider.AuthenticationHeader> authenticationHeaderSupplier,
AtomicReference<ClientLogger> loggerReference, Consumer<Object> messageHandler,
Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler);
Consumer<Object> messageHandler, Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,27 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.util.CharsetUtil;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

final class WebSocketClientHandler extends SimpleChannelInboundHandler<Object> {

private final WebSocketClientHandshaker handShaker;
private ChannelPromise handshakeFuture;

private final AtomicReference<ClientLogger> loggerReference;
private static final ClientLogger LOGGER = new ClientLogger(WebSocketClientHandler.class);
private final MessageDecoder messageDecoder;
private final Consumer<Object> messageHandler;

WebSocketClientHandler(WebSocketClientHandshaker handShaker, AtomicReference<ClientLogger> loggerReference,
MessageDecoder messageDecoder, Consumer<Object> messageHandler) {
WebSocketClientHandler(WebSocketClientHandshaker handShaker, MessageDecoder messageDecoder,
Consumer<Object> messageHandler) {
this.handShaker = handShaker;
this.loggerReference = loggerReference;
this.messageDecoder = messageDecoder;
this.messageHandler = messageHandler;
}
Expand All @@ -62,42 +58,31 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
handShaker.finishHandshake(ch, (FullHttpResponse) msg);
handshakeFuture.setSuccess();
} catch (WebSocketHandshakeException e) {
handshakeFuture.setFailure(e);
handshakeFuture.setFailure(LOGGER.atError().log(e));
}
return;
}

if (msg instanceof FullHttpResponse) {
FullHttpResponse response = (FullHttpResponse) msg;
throw loggerReference.get()
.logExceptionAsError(new IllegalStateException("Unexpected FullHttpResponse (getStatus="
+ response.status() + ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')'));
throw LOGGER.logExceptionAsError(new IllegalStateException("Unexpected FullHttpResponse (getStatus="
+ response.status() + ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')'));
}

WebSocketFrame frame = (WebSocketFrame) msg;
LOGGER.atInfo().log("Processing frame: " + frame.toString());

if (frame instanceof TextWebSocketFrame) {
// Text
TextWebSocketFrame textFrame = (TextWebSocketFrame) frame;
loggerReference.get()
.atVerbose()
.addKeyValue("text", textFrame.text())
.log(() -> "Received TextWebSocketFrame");
LOGGER.atVerbose().addKeyValue("text", textFrame.text()).log(() -> "Received TextWebSocketFrame");

Object wpsMessage = messageDecoder.decode(textFrame.text());
messageHandler.accept(wpsMessage);
} else if (frame instanceof PingWebSocketFrame) {
// Ping, reply Pong
loggerReference.get().atVerbose().log(() -> "Received PingWebSocketFrame");
loggerReference.get().atVerbose().log(() -> "Send PongWebSocketFrame");
ch.writeAndFlush(new PongWebSocketFrame());
} else if (frame instanceof PongWebSocketFrame) {
// Pong
loggerReference.get().atVerbose().log(() -> "Received PongWebSocketFrame");
} else if (frame instanceof CloseWebSocketFrame) {
// Close
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame;
loggerReference.get()
.atVerbose()
LOGGER.atVerbose()
.addKeyValue("statusCode", closeFrame.statusCode())
.addKeyValue("reasonText", closeFrame.reasonText())
.log(() -> "Received CloseWebSocketFrame");
Expand All @@ -106,25 +91,24 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {

if (closeCallbackFuture == null) {
// close initiated from server, reply CloseWebSocketFrame, then close connection
loggerReference.get().atVerbose().log(() -> "Send CloseWebSocketFrame");
LOGGER.atVerbose().log(() -> "Sending CloseWebSocketFrame");
closeFrame.retain(); // retain before write it back
ch.writeAndFlush(closeFrame).addListener(future -> ch.close());
} else {
// close initiated from client, client already sent CloseWebSocketFrame
ch.close();
}
} else {
// Pass other frames down the pipeline
// We only pass down the pipeline messages this handler doesn't process
ctx.fireChannelRead(msg);
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ClientLogger logger = loggerReference.get();
if (logger != null) {
logger.atError().log(cause);
}
// cause.printStackTrace();
if (handshakeFuture != null && !handshakeFuture.isDone()) {
handshakeFuture.setFailure(cause);
handshakeFuture.setFailure(LOGGER.atError().log(cause));
}
ctx.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
import com.azure.ai.openai.realtime.models.ConnectFailedException;
import com.azure.core.util.logging.ClientLogger;

import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;

public final class WebSocketClientNettyImpl implements WebSocketClient {

private static final ClientLogger LOGGER = new ClientLogger(WebSocketClientNettyImpl.class);

@Override
public WebSocketSession connectToServer(ClientEndpointConfiguration cec,
Supplier<AuthenticationProvider.AuthenticationHeader> authenticationHeaderSupplier,
AtomicReference<ClientLogger> loggerReference, Consumer<Object> messageHandler,
Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler) {
Consumer<Object> messageHandler, Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler) {
try {
WebSocketSessionNettyImpl session = new WebSocketSessionNettyImpl(cec, authenticationHeaderSupplier,
loggerReference, messageHandler, openHandler, closeHandler);
messageHandler, openHandler, closeHandler);
session.connect();
return session;
} catch (Exception e) {
throw loggerReference.get().logExceptionAsError(new ConnectFailedException("Failed to connect", e));
throw LOGGER.logExceptionAsError(new ConnectFailedException("Failed to connect", e));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.ai.openai.realtime.implementation.websocket;

import com.azure.core.util.logging.ClientLogger;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.timeout.IdleStateEvent;

/**
* Handler that sends a ping frame to the server when the channel is idle to prevent keep-alive timeouts.
*/
public class WebSocketPingHandler extends ChannelDuplexHandler {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public class WebSocketPingHandler extends ChannelDuplexHandler {
public final class WebSocketPingHandler extends ChannelDuplexHandler {

    private static final ClientLogger LOGGER = new ClientLogger(WebSocketPingHandler.class);

    /**
     * Reacts to idle notifications by pinging the peer, then always delegates the event to the superclass.
     *
     * @param ctx the context of this handler within the channel pipeline.
     * @param evt the user event that was triggered.
     * @throws Exception if the superclass handling of the event fails.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            pingWhenAllIdle(ctx, evt);
        }
        super.userEventTriggered(ctx, evt);
    }

    /**
     * Logs the idle event and writes a {@link PingWebSocketFrame} when the event reports the all-idle state.
     *
     * @param ctx the context used to write the ping frame.
     * @param evt the user event, already known to be an {@link IdleStateEvent}.
     */
    private static void pingWhenAllIdle(ChannelHandlerContext ctx, Object evt) {
        LOGGER.atVerbose().log("Received IdleStateEvent");
        IdleStateEvent idleEvent = (IdleStateEvent) evt;
        boolean allIdle = IdleStateEvent.ALL_IDLE_STATE_EVENT.state() == idleEvent.state();
        if (!allIdle) {
            return;
        }
        LOGGER.atVerbose().log("Sending PingWebSocketFrame");
        ctx.writeAndFlush(new PingWebSocketFrame());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import com.azure.ai.openai.realtime.models.RealtimeClientEvent;
import com.azure.core.util.logging.ClientLogger;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
Expand All @@ -28,21 +30,22 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.IdleStateHandler;

import javax.net.ssl.SSLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;

final class WebSocketSessionNettyImpl implements WebSocketSession {

private static final int MAX_FRAME_SIZE = 65536;

private final AtomicReference<ClientLogger> loggerReference;
private static final ClientLogger LOGGER = new ClientLogger(WebSocketSessionNettyImpl.class);
private final MessageEncoder messageEncoder;
private final MessageDecoder messageDecoder;

Expand Down Expand Up @@ -81,15 +84,16 @@ protected void initChannel(SocketChannel ch) {
}
p.addLast(new HttpClientCodec(), new HttpObjectAggregator(MAX_FRAME_SIZE),
WebSocketClientCompressionHandler.INSTANCE, new WebSocketFrameAggregator(MAX_FRAME_SIZE), handler);
p.addLast(new KeepAliveHandler());
p.addLast(new IdleStateHandler(0, 0, 30, TimeUnit.SECONDS));
p.addLast(new WebSocketPingHandler());
}
}

WebSocketSessionNettyImpl(ClientEndpointConfiguration cec,
Supplier<AuthenticationProvider.AuthenticationHeader> authenticationHeaderSupplier,
AtomicReference<ClientLogger> loggerReference, Consumer<Object> messageHandler,
Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler) {
Consumer<Object> messageHandler, Consumer<WebSocketSession> openHandler, Consumer<CloseReason> closeHandler) {
this.uri = cec.getUri();
this.loggerReference = loggerReference;
this.messageEncoder = cec.getMessageEncoder();
this.messageDecoder = cec.getMessageDecoder();
this.subProtocol = cec.getSubProtocol();
Expand Down Expand Up @@ -132,11 +136,16 @@ void connect() throws URISyntaxException, SSLException, InterruptedException, Ex
handshaker = WebSocketClientHandshakerFactory.newHandshaker(uri, WebSocketVersion.V13, this.subProtocol, true,
this.headers);

clientHandler = new WebSocketClientHandler(handshaker, loggerReference, messageDecoder, messageHandler);
clientHandler = new WebSocketClientHandler(handshaker, messageDecoder, messageHandler);

Bootstrap b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000) // Connection establishment timeout
.option(ChannelOption.SO_KEEPALIVE, true) // Enable TCP-level keep-alive
.option(ChannelOption.TCP_NODELAY, true) // Disable Nagle's algorithm for low latency
.option(ChannelOption.SO_REUSEADDR, true) // Allow address reuse
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) // Efficient memory allocation
.handler(new WebSocketChannelHandler(host, port, sslCtx, clientHandler));

final CompletableFuture<Void> handshakeCallbackFuture = new CompletableFuture<>();
Expand Down Expand Up @@ -183,7 +192,6 @@ public boolean isOpen() {
@Override
public void sendObjectAsync(Object data, Consumer<SendResult> handler) {
if (ch != null && ch.isOpen()) {
// TODO: jpalvarezl adjust to the right type for casting
String msg = messageEncoder.encode((RealtimeClientEvent) data);
sendTextAsync(msg, handler);
} else {
Expand All @@ -195,7 +203,7 @@ public void sendObjectAsync(Object data, Consumer<SendResult> handler) {
public void sendTextAsync(String text, Consumer<SendResult> handler) {
if (ch != null && ch.isOpen()) {
TextWebSocketFrame frame = new TextWebSocketFrame(text);
loggerReference.get().atVerbose().addKeyValue("text", frame.text()).log(() -> "Send TextWebSocketFrame");
LOGGER.atVerbose().addKeyValue("text", frame.text()).log(() -> "Send TextWebSocketFrame");
ch.writeAndFlush(frame).addListener(future -> {
if (future.isSuccess()) {
handler.accept(new SendResult());
Expand All @@ -219,7 +227,7 @@ public void closeSocket() {

group.shutdownGracefully();
} catch (InterruptedException e) {
throw loggerReference.get().logExceptionAsError(new ConnectFailedException("Failed to disconnect", e));
throw LOGGER.logExceptionAsError(new ConnectFailedException("Failed to disconnect", e));
}
}
}
Expand All @@ -235,8 +243,7 @@ public void close() {
clientHandler.setClientCloseCallbackFuture(closeCallbackFuture);

CloseWebSocketFrame closeFrame = new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE);
loggerReference.get()
.atVerbose()
LOGGER.atVerbose()
.addKeyValue("statusCode", closeFrame.statusCode())
.addKeyValue("reasonText", closeFrame.reasonText())
.log(() -> "Send CloseWebSocketFrame");
Expand All @@ -252,7 +259,7 @@ public void close() {
closeCallbackFuture.get();
}
} catch (InterruptedException | ExecutionException e) {
throw loggerReference.get().logExceptionAsError(new ConnectFailedException("Failed to disconnect", e));
throw LOGGER.logExceptionAsError(new ConnectFailedException("Failed to disconnect", e));
}
}
}
Expand Down
Loading