Skip to content

Commit

Permalink
upgrade server
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Oct 27, 2024
1 parent 8f65cfb commit cacb7a1
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 108 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ https://github.com/user-attachments/assets/797e6552-27cd-41b1-a7f3-e5cbc72094f5

### Updates

Latest Version: v0.3.1
Latest Version: v0.3.2

See [release history](https://github.com/KoljaB/RealtimeSTT/releases).

Expand Down Expand Up @@ -549,6 +549,17 @@ Suggested starting parameters for OpenWakeWord usage:
) as recorder:
```
## FAQ
### Q: I encountered the following error: "Unable to load any of {libcudnn_ops.so.9.1.0, libcudnn_ops.so.9.1, libcudnn_ops.so.9, libcudnn_ops.so} Invalid handle. Cannot load symbol cudnnCreateTensorDescriptor." How do I fix this?
**A:** This issue arises from a mismatch between the version of `ctranslate2` and cuDNN. The `ctranslate2` library was updated to version 4.5.0, which uses cuDNN 9.2. There are two ways to resolve this issue:
1. **Downgrade `ctranslate2` to version 4.4.0**:
```bash
pip install ctranslate2==4.4.0
```
2. **Upgrade cuDNN** on your system to version 9.2 or above.
## Contribution
Contributions are always welcome!
Expand Down
117 changes: 96 additions & 21 deletions RealtimeSTT/audio_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def __init__(self,
Exception: Errors related to initializing transcription
model, wake word detection, or audio recording.
"""

self.language = language
self.compute_type = compute_type
self.input_device_index = input_device_index
Expand Down Expand Up @@ -598,6 +599,11 @@ def __init__(self,

logging.info("Starting RealTimeSTT")

if use_extended_logging:
logging.info("RealtimeSTT was called with these parameters:")
for param, value in locals().items():
logging.info(f"{param}: {value}")

self.interrupt_stop_event = mp.Event()
self.was_interrupted = mp.Event()
self.main_transcription_ready_event = mp.Event()
Expand Down Expand Up @@ -922,14 +928,78 @@ def get_highest_sample_rate(audio_interface, device_index):

def initialize_audio_stream(audio_interface, sample_rate, chunk_size):
nonlocal input_device_index

def validate_device(device_index):
"""Validate that the device exists and is actually available for input."""
try:
device_info = audio_interface.get_device_info_by_index(device_index)
if not device_info.get('maxInputChannels', 0) > 0:
return False

# Try to actually read from the device
test_stream = audio_interface.open(
format=pyaudio.paInt16,
channels=1,
rate=target_sample_rate,
input=True,
frames_per_buffer=chunk_size,
input_device_index=device_index,
start=False # Don't start the stream yet
)

# Start the stream and try to read from it
test_stream.start_stream()
test_data = test_stream.read(chunk_size, exception_on_overflow=False)
test_stream.stop_stream()
test_stream.close()

# Check if we got valid data
if len(test_data) == 0:
return False

return True

except Exception as e:
logging.debug(f"Device validation failed: {e}")
return False

"""Initialize the audio stream with error handling."""
while not shutdown_event.is_set():
try:
# Check and assign the input device index if it is not set
if input_device_index is None:
default_device = audio_interface.get_default_input_device_info()
input_device_index = default_device['index']
# First, get a list of all available input devices
input_devices = []
for i in range(audio_interface.get_device_count()):
try:
device_info = audio_interface.get_device_info_by_index(i)
if device_info.get('maxInputChannels', 0) > 0:
input_devices.append(i)
except Exception:
continue

if not input_devices:
raise Exception("No input devices found")

# If input_device_index is None or invalid, try to find a working device
if input_device_index is None or input_device_index not in input_devices:
# First try the default device
try:
default_device = audio_interface.get_default_input_device_info()
if validate_device(default_device['index']):
input_device_index = default_device['index']
except Exception:
# If default device fails, try other available input devices
for device_index in input_devices:
if validate_device(device_index):
input_device_index = device_index
break
else:
raise Exception("No working input devices found")

# Validate the selected device one final time
if not validate_device(input_device_index):
raise Exception("Selected device validation failed")

# If we get here, we have a validated device
stream = audio_interface.open(
format=pyaudio.paInt16,
channels=1,
Expand All @@ -938,13 +1008,15 @@ def initialize_audio_stream(audio_interface, sample_rate, chunk_size):
frames_per_buffer=chunk_size,
input_device_index=input_device_index,
)
logging.info("Microphone connected successfully.")

logging.info(f"Microphone connected and validated (input_device_index: {input_device_index})")
return stream

except Exception as e:
logging.error(f"Microphone connection failed: {e}. Retrying...")
input_device_index = None
time.sleep(3) # Wait for 3 seconds before retrying
time.sleep(3) # Wait before retrying
continue

def preprocess_audio(chunk, original_sample_rate, target_sample_rate):
"""Preprocess audio chunk similar to feed_audio method."""
Expand Down Expand Up @@ -980,7 +1052,8 @@ def preprocess_audio(chunk, original_sample_rate, target_sample_rate):
def setup_audio():
nonlocal audio_interface, stream, device_sample_rate, input_device_index
try:
audio_interface = pyaudio.PyAudio()
if audio_interface is None:
audio_interface = pyaudio.PyAudio()
if input_device_index is None:
try:
default_device = audio_interface.get_default_input_device_info()
Expand Down Expand Up @@ -1024,6 +1097,7 @@ def setup_audio():
silero_buffer_size = 2 * buffer_size # silero complains if too short

time_since_last_buffer_message = 0

try:
while not shutdown_event.is_set():
try:
Expand Down Expand Up @@ -1057,17 +1131,14 @@ def setup_audio():
else:
logging.error(f"OSError during recording: {e}")
# Attempt to reinitialize the stream
logging.info("Attempting to reinitialize the audio stream...")
logging.error("Attempting to reinitialize the audio stream...")

try:
if stream:
stream.stop_stream()
stream.close()
except Exception as e:
pass

if audio_interface:
audio_interface.terminate()

# Wait a bit before trying to reinitialize
time.sleep(1)
Expand All @@ -1076,7 +1147,7 @@ def setup_audio():
logging.error("Failed to reinitialize audio stream. Exiting.")
break
else:
logging.info("Audio stream reinitialized successfully.")
logging.error("Audio stream reinitialized successfully.")
continue

except Exception as e:
Expand All @@ -1086,14 +1157,15 @@ def setup_audio():
logging.error(f"Error: {e}")
# Attempt to reinitialize the stream
logging.info("Attempting to reinitialize the audio stream...")
if stream:
stream.stop_stream()
stream.close()
if audio_interface:
audio_interface.terminate()
try:
if stream:
stream.stop_stream()
stream.close()
except Exception as e:
pass

# Wait a bit before trying to reinitialize
time.sleep(0.5)
time.sleep(1)

if not setup_audio():
logging.error("Failed to reinitialize audio stream. Exiting.")
Expand All @@ -1110,9 +1182,12 @@ def setup_audio():
if buffer:
audio_queue.put(bytes(buffer))

if stream:
stream.stop_stream()
stream.close()
try:
if stream:
stream.stop_stream()
stream.close()
except Exception as e:
pass
if audio_interface:
audio_interface.terminate()

Expand Down
75 changes: 72 additions & 3 deletions RealtimeSTT/audio_recorder_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
log_outgoing_chunks = False

from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
import subprocess
Expand Down Expand Up @@ -208,6 +210,9 @@ def __init__(self,
self.realtime_text = ""
self.final_text = ""

self.request_counter = 0
self.pending_requests = {} # Map from request_id to threading.Event and value

if self.debug_mode:
print("Checking STT server")
if not self.connect():
Expand Down Expand Up @@ -349,6 +354,8 @@ def start_server(self):
args += ['--language', self.language]
if self.silero_sensitivity is not None:
args += ['--silero_sensitivity', str(self.silero_sensitivity)]
if self.silero_use_onnx:
args.append('--silero_use_onnx') # flag, no need for True/False
if self.webrtc_sensitivity is not None:
args += ['--webrtc_sensitivity', str(self.webrtc_sensitivity)]
if self.min_length_of_recording is not None:
Expand All @@ -359,12 +366,35 @@ def start_server(self):
args += ['--realtime_processing_pause', str(self.realtime_processing_pause)]
if self.early_transcription_on_silence is not None:
args += ['--early_transcription_on_silence', str(self.early_transcription_on_silence)]
if self.silero_deactivity_detection:
args.append('--silero_deactivity_detection') # flag, no need for True/False
if self.beam_size is not None:
args += ['--beam_size', str(self.beam_size)]
if self.beam_size_realtime is not None:
args += ['--beam_size_realtime', str(self.beam_size_realtime)]
if self.initial_prompt:
args += ['--initial_prompt', self.initial_prompt]
if self.wake_words is not None:
args += ['--wake_words', str(self.wake_words)]
if self.wake_words_sensitivity is not None:
args += ['--wake_words_sensitivity', str(self.wake_words_sensitivity)]
if self.wake_word_timeout is not None:
args += ['--wake_word_timeout', str(self.wake_word_timeout)]
if self.wake_word_activation_delay is not None:
args += ['--wake_word_activation_delay', str(self.wake_word_activation_delay)]
if self.wakeword_backend is not None:
args += ['--wakeword_backend', str(self.wakeword_backend)]
if self.openwakeword_model_paths:
args += ['--openwakeword_model_paths', str(self.openwakeword_model_paths)]
if self.openwakeword_inference_framework is not None:
args += ['--openwakeword_inference_framework', str(self.openwakeword_inference_framework)]
if self.wake_word_buffer_duration is not None:
args += ['--wake_word_buffer_duration', str(self.wake_word_buffer_duration)]
if self.use_main_model_for_realtime:
args.append('--use_main_model_for_realtime') # flag, no need for True/False
if self.use_extended_logging:
args.append('--use_extended_logging') # flag, no need for True/False

if self.control_url:
parsed_control_url = urlparse(self.control_url)
if parsed_control_url.port:
Expand All @@ -377,6 +407,7 @@ def start_server(self):
# Start the subprocess with the mapped arguments
if os.name == 'nt': # Windows
cmd = 'start /min cmd /c ' + subprocess.list2cmdline(args)
# print(f"Opening server with cli command: {cmd}")
subprocess.Popen(cmd, shell=True)
else: # Unix-like systems
subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True)
Expand Down Expand Up @@ -480,6 +511,8 @@ def record_and_send_audio(self):
message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + audio_data

if self.is_running:
if log_outgoing_chunks:
print(".", flush=True, end='')
self.data_ws.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
except KeyboardInterrupt: # handle manual interruption (Ctrl+C)
if self.debug_mode:
Expand Down Expand Up @@ -513,8 +546,12 @@ def on_control_message(self, ws, message):
if 'status' in data:
if data['status'] == 'success':
if 'parameter' in data and 'value' in data:
if self.debug_mode:
print(f"Parameter {data['parameter']} = {data['value']}")
request_id = data.get('request_id')
if request_id is not None and request_id in self.pending_requests:
if self.debug_mode:
print(f"Parameter {data['parameter']} = {data['value']}")
self.pending_requests[request_id]['value'] = data['value']
self.pending_requests[request_id]['event'].set()
elif data['status'] == 'error':
print(f"Server Error: {data.get('message', '')}")
else:
Expand Down Expand Up @@ -550,12 +587,20 @@ def on_data_message(self, ws, message):
elif data.get('type') == 'vad_detect_start':
if self.on_vad_detect_start:
self.on_vad_detect_start()
elif data.get('type') == 'vad_detect_stop':
if self.on_vad_detect_stop:
self.on_vad_detect_stop()
elif data.get('type') == 'wakeword_detected':
if self.on_wakeword_detected:
self.on_wakeword_detected()
elif data.get('type') == 'wakeword_detection_start':
if self.on_wakeword_detection_start:
self.on_wakeword_detection_start()
elif data.get('type') == 'wakeword_detection_end':
if self.on_wakeword_detection_end:
self.on_wakeword_detection_end()
elif data.get('type') == 'recorded_chunk':
pass

else:
print(f"Unknown data message format: {data}")
Expand Down Expand Up @@ -595,12 +640,36 @@ def set_parameter(self, parameter, value):
self.control_ws.send(json.dumps(command))

def get_parameter(self, parameter):
# Generate a unique request_id
request_id = self.request_counter
self.request_counter += 1

# Prepare the command with the request_id
command = {
"command": "get_parameter",
"parameter": parameter
"parameter": parameter,
"request_id": request_id
}

# Create an event to wait for the response
event = threading.Event()
self.pending_requests[request_id] = {'event': event, 'value': None}

# Send the command to the server
self.control_ws.send(json.dumps(command))

# Wait for the response or timeout after 5 seconds
if event.wait(timeout=5):
value = self.pending_requests[request_id]['value']
# Clean up the pending request
del self.pending_requests[request_id]
return value
else:
print(f"Timeout waiting for get_parameter {parameter}")
# Clean up the pending request
del self.pending_requests[request_id]
return None

def call_method(self, method, args=None, kwargs=None):
command = {
"command": "call_method",
Expand Down
Loading

0 comments on commit cacb7a1

Please sign in to comment.