Skip to content

Commit

Permalink
fix(whisper): Reconnects on socket closed while operating. (#566)
Browse files Browse the repository at this point in the history
* fix(whisper): Reconnects on socket closed while operating.

Executes stop and connect in a different thread from the websocket messages thread, as recommended in the Jetty docs.

* squash: move comments.
  • Loading branch information
damencho authored Oct 30, 2024
1 parent 3548be7 commit a0e043b
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 44 deletions.
21 changes: 21 additions & 0 deletions src/main/java/org/jitsi/jigasi/stats/Statistics.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ public class Statistics
*/
public static final String TOTAL_TRANSCRIBER_CONNECTION_ERRORS = "total_transcriber_connection_errors";

/**
* The name of the statistic holding the total number of connection retries for the transcriber.
*/
public static final String TOTAL_TRANSCRIBER_CONNECTION_RETRIES = "total_transcriber_connection_retries";

/**
* The total number of no result errors for the transcriber.
*/
Expand Down Expand Up @@ -302,6 +307,13 @@ public class Statistics
TOTAL_TRANSCRIBER_CONNECTION_ERRORS,
"Total number of transcriber connection errors.");

/**
* Total number of transcriber connection retries.
*/
private static CounterMetric totalTrasnscriberConnectionRetries = JigasiMetricsContainer.INSTANCE.registerCounter(
TOTAL_TRANSCRIBER_CONNECTION_RETRIES,
"Total number of transcriber connection retries.");

/**
* Total number of transcriptions no result errors.
*/
Expand Down Expand Up @@ -465,6 +477,7 @@ public static synchronized void sendJSON(
stats.put(TOTAL_TRANSCRIBER_FAILED, totalTrasnscriberFailed.get());

stats.put(TOTAL_TRANSCRIBER_CONNECTION_ERRORS, totalTrasnscriberConnectionErrors.get());
stats.put(TOTAL_TRANSCRIBER_CONNECTION_RETRIES, totalTrasnscriberConnectionRetries.get());
stats.put(TOTAL_TRANSCRIBER_NO_RESUL_ERRORS, totalTrasnscriberNoResultErrors.get());
stats.put(TOTAL_TRANSCRIBER_SEND_ERRORS, totalTrasnscriberSendErrors.get());
stats.put(TOTAL_TRANSCRIBER_SESSION_CREATION_ERRORS, totalTrasnscriberSessionCreationErrors.get());
Expand Down Expand Up @@ -736,6 +749,14 @@ public static void incrementTotalTranscriberConnectionErrors()
totalTrasnscriberConnectionErrors.inc();
}

/**
* Increments by one the counter holding the total number of transcriber connection retries.
*/
public static void incrementTotalTranscriberConnectionRetries()
{
totalTrasnscriberConnectionRetries.inc();
}

/**
* Increment the value of total number of transcriber no result errors.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ public class WhisperConnectionPool
*/
private final Map<String, WhisperWebsocket> pool = new ConcurrentHashMap<>();

/**
* The thread pool to serve all connect disconnect operations.
*/
private static final ExecutorService threadPool = Util.createNewThreadPool("jigasi-whisper-ws");

/**
* Gets a connection if it exists, creates one if it doesn't.
* @param roomId The room jid.
Expand All @@ -66,8 +61,7 @@ public WhisperWebsocket getConnection(String roomId)
logger.info("Room " + roomId + " doesn't exist. Creating a new connection.");
final WhisperWebsocket socket = new WhisperWebsocket();

// connect socket in new thread to not block Smack threads
threadPool.execute(socket::connect);
socket.connect();

pool.put(roomId, socket);
}
Expand All @@ -81,31 +75,20 @@ public WhisperWebsocket getConnection(String roomId)
* @param participantId The participant id.
*/
public void end(String roomId, String participantId)
{
// execute this in new thread to not block Smack
threadPool.execute(() -> this.endInternal(roomId, participantId));
}

private void endInternal(String roomId, String participantId)
{
WhisperWebsocket wsConn = pool.getOrDefault(roomId, null);
if (wsConn == null)
{
return;
}

try
wsConn.disconnectParticipant(participantId, allDisconnected ->
{
if (wsConn.disconnectParticipant(participantId))
if (allDisconnected)
{
// remove from the pool if everyone is disconnected
pool.remove(roomId);
}
}
catch (IOException e)
{
logger.error("Error while finalizing websocket connection for participant " + participantId, e);
}
});
}

/**
Expand Down
123 changes: 100 additions & 23 deletions src/main/java/org/jitsi/jigasi/transcription/WhisperWebsocket.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.eclipse.jetty.websocket.client.*;
import org.jitsi.jigasi.*;
import org.jitsi.jigasi.stats.*;
import org.jitsi.jigasi.util.Util;
import org.jitsi.utils.logging.*;
import org.json.*;

Expand All @@ -32,8 +33,13 @@
import java.time.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.*;


/**
* This holds the websocket that is used to send audio data to Whisper.
* There is one WhisperWebsocket per room (the mapping is in {@link WhisperConnectionPool}).
* The jetty WebSocketClient processes messages in a single thread.
*/
@WebSocket
public class WhisperWebsocket
{
Expand Down Expand Up @@ -111,6 +117,8 @@ public class WhisperWebsocket

private WebSocketClient ws;

private boolean reconnecting = false;

static
{
jwtAudience = JigasiBundleActivator.getConfigurationService()
Expand Down Expand Up @@ -139,6 +147,11 @@ public class WhisperWebsocket
logger.info("Websocket transcription streaming endpoint: " + websocketUrlConfig);
}

/**
* The thread pool to serve all connect, disconnect or reconnect operations.
*/
private static final ExecutorService threadPool = Util.createNewThreadPool("jigasi-whisper-ws");

/**
* Creates a connection url by concatenating the websocket
* url with the Connection Id;
Expand All @@ -152,12 +165,19 @@ private void generateWebsocketUrl()
}
}

/**
* Connect to the websocket in a new thread so we do not block Smack.
* The actual work is done by {@code connectInternal()}, which is submitted to the
* thread pool; this method returns without waiting for the connection attempt to finish.
*/
void connect()
{
threadPool.submit(this::connectInternal);
}

/**
* Connect to the websocket, retry up to maxRetryAttempts
* with exponential backoff in case of failure
*/
void connect()
private void connectInternal()
{
int attempt = 0;
float multiplier = 1.5f;
Expand All @@ -178,6 +198,7 @@ void connect()
wsSession = ws.connect(this, new URI(websocketUrl), upgradeRequest).get();
wsSession.setIdleTimeout(Duration.ofSeconds(300));
isConnected = true;
reconnecting = false;
logger.info("Successfully connected to " + websocketUrl);
break;
}
Expand Down Expand Up @@ -208,14 +229,59 @@ void connect()
}
}

/**
 * Schedules a reconnect of the websocket unless one is already in progress.
 * The {@code reconnecting} flag is raised here and is cleared again once a
 * connection has been successfully established. Stopping the old client and
 * connecting again is submitted to the thread pool, so it does not run in the
 * thread that delivers the websocket callbacks.
 */
private synchronized void reconnect()
{
    if (!reconnecting)
    {
        reconnecting = true;

        Statistics.incrementTotalTranscriberConnectionRetries();

        // stop first, then connect, both on the same pool thread
        threadPool.submit(() ->
        {
            stopWebSocketClient();
            connectInternal();
        });
    }
}

/**
 * Called when the websocket session is closed.
 * If there are still participants and the session is no longer open, a
 * reconnect is scheduled and the state is kept. Otherwise the state of this
 * instance is cleared and the websocket client is stopped in the thread pool.
 *
 * @param statusCode the close status code.
 * @param reason the reason for closing, as reported by the websocket.
 */
@OnWebSocketClose
public void onClose(int statusCode, String reason)
{
    // participants is set to null below on a final close, so it must be
    // null-checked before use (a later close would otherwise throw a NPE)
    final boolean hasParticipants = participants != null && !participants.isEmpty();

    if (hasParticipants)
    {
        // let's try to reconnect; a missing session is treated as not open
        if (wsSession == null || !wsSession.isOpen())
        {
            reconnect();

            return;
        }

        logger.error("Websocket closed: " + statusCode + " reason:" + reason);
    }

    wsSession = null;
    participants = null;
    participantListeners = null;
    participantTranscriptionStarts = null;
    participantTranscriptionIds = null;

    // the client must not be stopped from the websocket callback thread
    threadPool.submit(this::stopWebSocketClient);
}

/**
* Stop the websocket client.
* Make sure this is executed in a different thread than the one
* the websocket client is running in (the onMessage, onError or onClose callbacks).
*/
private void stopWebSocketClient()
{
try
{
if (ws != null)
Expand Down Expand Up @@ -300,7 +366,7 @@ public void onMessage(String msg)
@OnWebSocketError
public void onError(Throwable cause)
{
if (wsSession != null)
if (!ended() && participants != null && !participants.isEmpty())
{
Statistics.incrementTotalTranscriberSendErrors();
logger.error("Error while streaming audio data to transcription service.", cause);
Expand Down Expand Up @@ -337,17 +403,21 @@ private ByteBuffer buildPayload(String participantId, Participant participant, B
}

/**
* Disconnect a participant from the transcription service.
* Disconnect a participant from the transcription service, executing that in a new thread so we do not block Smack.
* @param participantId the participant to disconnect.
* @return <tt>true</tt> if the last participant has left and the session was closed.
* @throws IOException
* @param callback the callback that is notified whether the last participant has left and the session was closed.
*/
public boolean disconnectParticipant(String participantId)
throws IOException
public void disconnectParticipant(String participantId, Consumer<Boolean> callback)
{
if (this.wsSession == null)
threadPool.submit(() -> this.disconnectParticipantInternal(participantId, callback));
}

private void disconnectParticipantInternal(String participantId, Consumer<Boolean> callback)
{
if (ended() && (participants == null || participants.isEmpty()))
{
return true;
callback.accept(true);
return;
}

synchronized (this)
Expand All @@ -362,11 +432,21 @@ public boolean disconnectParticipant(String participantId)
if (participants.isEmpty())
{
logger.info("All participants have left, disconnecting from Whisper transcription server.");
wsSession.getRemote().sendBytes(EOF_MESSAGE);

try
{
wsSession.getRemote().sendBytes(EOF_MESSAGE);
}
catch (IOException e)
{
logger.error("Error while finalizing websocket connection for participant " + participantId, e);
}

wsSession.disconnect();
return true;
callback.accept(true);
}
return false;

callback.accept(false);
}
}

Expand All @@ -384,18 +464,15 @@ public void sendAudio(String participantId, Participant participant, ByteBuffer
logger.error("Failed sending audio for " + participantId + ". Attempting to reconnect.");
if (!wsSession.isOpen())
{
try
{
connect();
remoteEndpoint = wsSession.getRemote();
}
catch (Exception ex)
{
logger.error(ex);
}
reconnect();
}
else
{
logger.warn("Failed sending audio for " + participantId
+ ". RemoteEndpoint is null but sessions is open.");
}
}
if (remoteEndpoint != null)
else
{
try
{
Expand Down

0 comments on commit a0e043b

Please sign in to comment.