diff --git a/Jenkinsfile b/Jenkinsfile index cd1ad07..d1a3852 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,16 +1,19 @@ def buildDockerfile(main_folder, dockerfilePath, image_name, version, changedFiles) { - if (changedFiles.contains(main_folder) || changedFiles.contains('celery_app') || changedFiles.contains('http_server') || changedFiles.contains('websocket') || changedFiles.contains('document')) { + boolean has_changed = changedFiles.contains(main_folder) || changedFiles.contains('celery_app') || changedFiles.contains('http_server') || changedFiles.contains('websocket') || changedFiles.contains('document') + if (main_folder == "kaldi") { + // Kaldi also depends on recasepunc + has_changed = has_changed || changedFiles.contains('punctuation') + } + if (has_changed) { echo "Building Dockerfile for ${image_name} with version ${version} (using ${dockerfilePath})" script { def image = docker.build(image_name, "-f ${dockerfilePath} .") docker.withRegistry('https://registry.hub.docker.com', 'docker-hub-credentials') { - if (version == 'latest-unstable') { - image.push('latest-unstable') - } else { + image.push(version) + if (version != 'latest-unstable') { image.push('latest') - image.push(version) } } } diff --git a/celery_app/tasks.py b/celery_app/tasks.py index 114df2a..cae914b 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -6,10 +6,10 @@ from stt.processing.utils import load_audiofile from celery_app.celeryapp import celery - +from typing import Optional @celery.task(name="transcribe_task") -def transcribe_task(file_name: str, with_metadata: bool): +def transcribe_task(file_name: str, with_metadata: bool, language: Optional[str] = None): """transcribe_task""" logger.info(f"Received transcription task for {file_name}") @@ -26,7 +26,7 @@ def transcribe_task(file_name: str, with_metadata: bool): # Decode try: - result = decode(file_content, MODEL, with_metadata) + result = decode(file_content, MODEL, with_metadata, language=language) except Exception as err: import traceback diff --git a/document/swagger.yml b/document/swagger.yml index 89725bf..1d0a19a 100644 --- a/document/swagger.yml +++ b/document/swagger.yml @@ -30,6 +30,11 @@ paths: description: "Audio File - WaveFile PCM 16b 16KHz" required: true type: "file" + - name: "language" + in: "formData" + description: "Language (code or *)" + required: false + type: string responses: 200: description: Successfully transcribe the audio diff --git a/http_server/ingress.py b/http_server/ingress.py index 424e71c..f39747c 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -69,11 +69,12 @@ def transcribe(): raise ValueError(f"No audio file was uploaded (missing 'file' key)") file_buffer = request.files["file"].read() - + language = request.form.get("language") + audio_data = load_wave_buffer(file_buffer) # Transcription - transcription = decode(audio_data, MODEL, join_metadata) + transcription = decode(audio_data, MODEL, join_metadata, language=language) if join_metadata: return json.dumps(transcription, ensure_ascii=False), 200 diff --git a/kaldi/Dockerfile b/kaldi/Dockerfile index 17da10f..d9e020f 100644 --- a/kaldi/Dockerfile +++ b/kaldi/Dockerfile @@ -44,6 +44,9 @@ RUN git clone -b vosk --single-branch https://github.com/alphacep/kaldi /opt/kal && sed -i 's: -O1 : -O3 :g' kaldi.mk \ && make -j $(( $(nproc) < 8 ? 
$(nproc) : 8 )) online2 lm rnnlm +# Upgrade pip +RUN pip install --no-cache-dir --upgrade pip + # Install python dependencies COPY kaldi/requirements.txt ./ RUN pip install --no-cache-dir -r requirements.txt @@ -57,6 +60,12 @@ RUN git clone --depth 1 https://github.com/alphacep/vosk-api /opt/vosk-api && cd WORKDIR /usr/src/app +# Install what's needed for punctuation +COPY punctuation/requirements.cpu.txt ./ +RUN pip install --no-cache-dir -r requirements.cpu.txt -f https://download.pytorch.org/whl/torch_stable.html +RUN rm requirements.cpu.txt + +# Copy code COPY celery_app /usr/src/app/celery_app COPY http_server /usr/src/app/http_server COPY websocket /usr/src/app/websocket @@ -64,6 +73,7 @@ COPY document /usr/src/app/document COPY kaldi/stt /usr/src/app/stt COPY kaldi/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ COPY kaldi/lin_to_vosk.py /usr/src/app/lin_to_vosk.py +COPY punctuation ./punctuation RUN mkdir -p /var/log/supervisor/ @@ -71,4 +81,4 @@ ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" HEALTHCHECK CMD ./healthcheck.sh -ENTRYPOINT ["./docker-entrypoint.sh"] \ No newline at end of file +ENTRYPOINT ["./docker-entrypoint.sh"] diff --git a/kaldi/README.md b/kaldi/README.md index c7bd22b..40cfe81 100644 --- a/kaldi/README.md +++ b/kaldi/README.md @@ -15,13 +15,29 @@ To run the transcription models you'll need: * One CPU per worker. Inference time scales with CPU performance. ### Model -LinTO-STT-Kaldi accepts two kinds of models: +If not done already, download and unzip model folders into a directory accessible from the docker container. + +LinTO-STT-Kaldi accepts two kinds of ASR models: * LinTO Acoustic and Languages models. -* Vosk models. +* Vosk models (all in one). We provide home-cured models (v2) on [dl.linto.ai](https://doc.linto.ai/docs/developpers/apis/ASR/models). You can also use Vosk models available [here](https://alphacephei.com/vosk/models). + +If you want text with upper case letters and punctuation, you can specify a recasepunc model. +Some recasepunc models trained on [Common Crawl](http://data.statmt.org/cc-100/) are available on [recasepunc](https://github.com/benob/recasepunc) for the following languages: +* French + * [fr-txt.large.19000](https://github.com/benob/recasepunc/releases/download/0.3/fr-txt.large.19000) + +* English + * [en.23000](https://github.com/benob/recasepunc/releases/download/0.3/en.23000) +* Italian + * [it.22000](https://github.com/CoffeePerry/recasepunc/releases/download/v0.1.0/it.22000) +* Chinese + * [zh.24000](https://github.com/benob/recasepunc/releases/download/0.3/zh.24000) + + ### Docker The transcription service requires docker up and running. @@ -62,6 +78,8 @@ An example of .env file is provided in [kaldi/.envdefault](https://github.com/li | BROKER_PASS | Using the task mode, broker password | my-password | | STREAMING_PORT | Using the websocket mode, the listening port for incoming WS connections. | 80 | | CONCURRENCY | Maximum number of parallel requests | >1 | +| PUNCTUATION_MODEL | Path to a recasepunc model, for recovering punctuation and upper case letters | /opt/PUNCT | + ### Serving mode ![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png) @@ -89,6 +107,12 @@ docker run --rm \ linto-stt-kaldi:latest ``` +If you have a recasepunc model to recover punctuation marks, you can add the following options: +```bash +-v <>:/opt/PUNCT +--env PUNCTUATION_MODEL=/opt/PUNCT +``` + This will run a container providing an [HTTP API](#http-api) bound to the host HOST_SERVING_PORT port.
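To sanity-check the HTTP route once the container is up, here is a minimal Python sketch (assumptions: the container is reachable on `localhost:8080`, i.e. HOST_SERVING_PORT=8080, and `bonjour.wav` is any local 16kHz 16-bit mono WAV file):

```python
import requests

# POST a WAV file to the /transcribe route; the "accept" header selects
# the JSON output (text, words, confidence) over plain text.
with open("bonjour.wav", "rb") as f:
    response = requests.post(
        "http://localhost:8080/transcribe",
        headers={"accept": "application/json"},
        files={"file": ("bonjour.wav", f, "audio/wav")},
    )
response.raise_for_status()
# With PUNCTUATION_MODEL set, "text" comes back recased and punctuated.
print(response.json()["text"])
```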
**Parameters:** diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md index f1c55dc..e9687b6 100644 --- a/kaldi/RELEASE.md +++ b/kaldi/RELEASE.md @@ -1,7 +1,10 @@ -# 1.0.3 +# 1.1.0 +- Add the possibility to apply recasing & punctuation in streaming + +# 1.0.3 - Fix corner case in streaming where "eof" was found in message -# 1.0.2 +# 1.0.2 - Fix task mode for kaldi by updating SERVICES_BROKER and BROKER_PASS in .envdefault # 1.0.1 diff --git a/kaldi/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py index e0476a1..4fbf37b 100644 --- a/kaldi/stt/processing/__init__.py +++ b/kaldi/stt/processing/__init__.py @@ -5,7 +5,10 @@ from stt import logger from stt.processing.decoding import decode from stt.processing.utils import load_audiofile, load_wave_buffer +from punctuation.recasepunc import load_recasepunc_model from vosk import Model +import torch + __all__ = [ "logger", @@ -23,12 +26,27 @@ logger.info("Loading acoustic model and decoding graph ...") start = time() try: - MODEL = Model(MODEL_PATH) + ASR_MODEL = Model(MODEL_PATH) except Exception as err: raise Exception("Failed to load transcription model: {}".format(str(err))) from err - sys.exit(-1) + logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start)) + +# Number of CPU threads +NUM_THREADS = os.environ.get("NUM_THREADS", torch.get_num_threads()) +NUM_THREADS = int(NUM_THREADS) +# This sets the number of threads for sklearn +os.environ["OMP_NUM_THREADS"] = str( + NUM_THREADS +) # This must be done BEFORE importing packages (sklearn, etc.) +# For Torch, we will set it afterward, because setting that before loading the model can hang the process (see https://github.com/pytorch/pytorch/issues/58962) +torch.set_num_threads(1) + +PUNCTUATION_MODEL = load_recasepunc_model() + +MODEL = (ASR_MODEL, PUNCTUATION_MODEL) + def warmup(): pass diff --git a/kaldi/stt/processing/decoding.py b/kaldi/stt/processing/decoding.py index 8c06007..d329e5d 100644 --- a/kaldi/stt/processing/decoding.py +++ b/kaldi/stt/processing/decoding.py @@ -3,13 +3,19 @@ from vosk import KaldiRecognizer, Model +from punctuation.recasepunc import apply_recasepunc -def decode(audio: tuple[bytes, int], model: Model, with_metadata: bool) -> dict: +def decode(audio: tuple[bytes, int], model: Model, with_metadata: bool, language=None) -> dict: """Transcribe the audio data using the vosk library with the defined model.""" - result = {"text": "", "confidence-score": 0.0, "words": []} + decoder_result = {"text": "", "confidence-score": 0.0, "words": []} + if language: + raise NotImplementedError("Language selection is not implemented for kaldi.") + audio_data, sampling_rate = audio + model, punctuation_model = model + recognizer = KaldiRecognizer(model, sampling_rate) recognizer.SetMaxAlternatives(0) # Set confidence per words recognizer.SetWords(with_metadata) @@ -22,12 +28,14 @@ def decode(audio: tuple[bytes, int], model: Model, with_metadata: bool) -> dict: try: decoder_result = json.loads(decoder_result_raw) except Exception: - return result - result["text"] = re.sub("<unk>", "", decoder_result["text"]) - if "result" in decoder_result: - result["words"] = [w for w in decoder_result["result"] if w["word"] != ""] - if result["words"]: - result["confidence-score"] = sum([w["conf"] for w in result["words"]]) / len( - result["words"] + return decoder_result + + decoder_result = apply_recasepunc(punctuation_model, decoder_result) + + if "result" in decoder_result: + decoder_result["words"] = [w for w in decoder_result["result"] if w["word"] != 
""] + if decoder_result["words"]: + decoder_result["confidence-score"] = sum([w["conf"] for w in decoder_result["words"]]) / len( + decoder_result["words"] ) - return result + return decoder_result diff --git a/kaldi/stt/processing/streaming.py b/kaldi/stt/processing/streaming.py index f06cc82..bcaa5ce 100644 --- a/kaldi/stt/processing/streaming.py +++ b/kaldi/stt/processing/streaming.py @@ -7,6 +7,8 @@ from vosk import KaldiRecognizer, Model from websockets.legacy.server import WebSocketServerProtocol +from punctuation.recasepunc import apply_recasepunc + EOF_REGEX = re.compile(r' *\{.*"eof" *: *1.*\} *$') async def wssDecode(ws: WebSocketServerProtocol, model: Model): @@ -14,6 +16,8 @@ async def wssDecode(ws: WebSocketServerProtocol, model: Model): # Wait for config res = await ws.recv() + model, punctuation_model = model + # Parse config try: config = json.loads(res)["config"] @@ -42,6 +46,7 @@ async def wssDecode(ws: WebSocketServerProtocol, model: Model): # End frame if (isinstance(message, str) and re.match(EOF_REGEX, message)): ret = recognizer.FinalResult() + ret = apply_recasepunc(punctuation_model, ret) await ws.send(json.dumps(ret)) await ws.close(reason="End of stream") break @@ -49,6 +54,7 @@ async def wssDecode(ws: WebSocketServerProtocol, model: Model): # Audio chunk if recognizer.AcceptWaveform(message): ret = recognizer.Result() # Result seems to not work properly + ret = apply_recasepunc(punctuation_model, ret) await ws.send(ret) else: @@ -62,6 +68,8 @@ def ws_streaming(websocket_server: WSServer, model: Model): # Wait for config res = websocket_server.receive(timeout=10) + model, punctuation_model = model + # Timeout if res is None: pass @@ -93,6 +101,7 @@ def ws_streaming(websocket_server: WSServer, model: Model): # End frame if (isinstance(message, str) and re.match(EOF_REGEX, message)): ret = recognizer.FinalResult() + ret = apply_recasepunc(punctuation_model, ret) websocket_server.send(json.dumps(re.sub(" ", "", ret))) websocket_server.close() break @@ -100,6 +109,7 @@ def ws_streaming(websocket_server: WSServer, model: Model): print("Received chunk") if recognizer.AcceptWaveform(message): ret = recognizer.Result() + ret = apply_recasepunc(punctuation_model, ret) websocket_server.send(re.sub(" ", "", ret)) else: diff --git a/punctuation/__init__.py b/punctuation/__init__.py new file mode 100644 index 0000000..a05ea06 --- /dev/null +++ b/punctuation/__init__.py @@ -0,0 +1,8 @@ +import logging +import os + +logging.basicConfig( + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", +) +logger = logging.getLogger("__punctuation__") diff --git a/punctuation/recasepunc.py b/punctuation/recasepunc.py new file mode 100644 index 0000000..bb8c9c5 --- /dev/null +++ b/punctuation/recasepunc.py @@ -0,0 +1,511 @@ +# coding=utf-8 + +"""recasepunc file.""" + +import argparse +import os +import random +import sys +import re +import json + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +# from mosestokenizer import * +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer, BertTokenizer + +default_config = argparse.Namespace( + seed=871253, + lang='fr', + # flavor='flaubert/flaubert_base_uncased', + flavor=None, + max_length=256, + batch_size=16, + updates=24000, + period=1000, + lr=1e-5, + dab_rate=0.1, + device='cuda', + debug=False +) + +default_flavors = { + 'fr': 'flaubert/flaubert_base_uncased', + 'en': 'bert-base-uncased', + 'zh': 'ckiplab/bert-base-chinese', + 'it': 
'dbmdz/bert-base-italian-uncased', +} + + +class Config(argparse.Namespace): + def __init__(self, **kwargs): + super().__init__() + for key, value in default_config.__dict__.items(): + setattr(self, key, value) + for key, value in kwargs.items(): + setattr(self, key, value) + + assert self.lang in ['fr', 'en', 'zh', 'it'] + + if 'lang' in kwargs and ('flavor' not in kwargs or kwargs['flavor'] is None): + self.flavor = default_flavors[self.lang] + + # print(self.lang, self.flavor) + + +def init_random(seed): + # make sure everything is deterministic + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + set_seed(seed) + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + + +# NOTE: it is assumed in the implementation that y[:,0] is the punctuation label, and y[:,1] is the case label! + +punctuation = { + 'O': 0, + 'COMMA': 1, + 'PERIOD': 2, + 'QUESTION': 3, + 'EXCLAMATION': 4, +} + +punctuation_syms = ['', ',', '.', ' ?', ' !'] + +case = { + 'LOWER': 0, + 'UPPER': 1, + 'CAPITALIZE': 2, + 'OTHER': 3, +} + + +class Model(nn.Module): + def __init__(self, flavor, device): + super().__init__() + self.bert = AutoModel.from_pretrained(flavor) + # need a proper way of determining representation size + size = self.bert.dim \ + if hasattr(self.bert, 'dim') else self.bert.config.pooler_fc_size \ + if hasattr(self.bert.config, 'pooler_fc_size') else self.bert.config.emb_dim \ + if hasattr(self.bert.config, 'emb_dim') else self.bert.config.hidden_size + self.punc = nn.Linear(size, 5) + self.case = nn.Linear(size, 4) + self.dropout = nn.Dropout(0.3) + self.to(device) + + def forward(self, x): + output = self.bert(x) + representations = self.dropout(F.gelu(output['last_hidden_state'])) + punc = self.punc(representations) + case = self.case(representations) + return punc, case + +def recase(token, label): + if label == case['LOWER']: + return token.lower() + if label == case['CAPITALIZE']: + return token.lower().capitalize() + if label == case['UPPER']: + return token.upper() + return token + +_PUNCTUATION_MODEL = None + +def load_recasepunc_model(config=None): + + global _PUNCTUATION_MODEL + memoize = (config is None) + if memoize and _PUNCTUATION_MODEL is not None: + return _PUNCTUATION_MODEL + + checkpoint_path = os.environ.get('PUNCTUATION_MODEL') + if not checkpoint_path: + return None + + if config is None: + config = default_config + + device = os.environ.get("DEVICE") + if device is None: + if torch.cuda.is_available(): + config.device = 'cuda' + else: + config.device = 'cpu' + else: + config.device = device + + print(f"Loading recasepunc model from {checkpoint_path} on device={config.device}") # TODO: use logger.info + + loaded = torch.load(checkpoint_path, map_location=config.device) + if 'config' in loaded: + config = Config(**loaded['config']) + + if config.flavor is None: + config.flavor = default_flavors[config.lang] + + init(config) + + model = Model(config.flavor, config.device) + model.load_state_dict(loaded['model_state_dict']) + + config.model = model + + if memoize: + _PUNCTUATION_MODEL = config + return config + + +def apply_recasepunc(config, line, ignore_disfluencies=False): + + num_threads = os.environ.get("OMP_NUM_THREADS") + if num_threads: + torch.set_num_threads(int(num_threads)) + + if isinstance(line, list): + return [apply_recasepunc(config, l, ignore_disfluencies=ignore_disfluencies) for l in line] + + if isinstance(line, dict): + new_dict = line.copy() + assert "text" in line + line = 
line["text"] + line = apply_recasepunc(config, line, ignore_disfluencies=ignore_disfluencies) + new_dict["text"] = line + return new_dict + + assert isinstance(line, str) + line = line.strip() + + if line.startswith("{") and line.endswith("}"): + # A dict inside a string + line = json.loads(line) + assert isinstance(line, dict) + return json.dumps(apply_recasepunc(config, line, ignore_disfluencies=ignore_disfluencies), indent=2, ensure_ascii=False) + + if not line: + # Avoid hanging on empty lines + return "" + + # Remove tokens (Ugly: LinTO/Kaldi model specific here) + line = re.sub(" ", "", line) + + if config is None: + return line + + model = config.model + set_seed(config.seed) + + # Drop all punctuation that can be generated + line = ''.join([c for c in line if c not in mapped_punctuation]) + + # Relevant only if disfluences annotations + if ignore_disfluencies: + # TODO: fix when there are several disfluencies in a row ("euh euh") + line = collapse_whitespace(line) + line = re.sub(r"(\w) *' *(\w)", r"\1'\2", line) # glue apostrophes to words + disfluencies, line = remove_simple_disfluences(line) + + output = '' + if config.debug: + print(line) + tokens = [config.cls_token] + config.tokenizer.tokenize(line) + [config.sep_token] + if config.debug: + print(tokens) + previous_label = punctuation['PERIOD'] + first_time = True + was_word = False + for start in range(0, len(tokens), config.max_length): + instance = tokens[start: start + config.max_length] + ids = config.tokenizer.convert_tokens_to_ids(instance) + if len(ids) < config.max_length: + ids += [config.pad_token_id] * (config.max_length - len(ids)) + x = torch.tensor([ids]).long().to(config.device) + y_scores1, y_scores2 = model(x) + y_pred1 = torch.max(y_scores1, 2)[1] + y_pred2 = torch.max(y_scores2, 2)[1] + for id, token, punc_label, case_label in zip(ids, instance, y_pred1[0].tolist()[:len(instance)], + y_pred2[0].tolist()[:len(instance)]): + if config.debug: + print(id, token, punc_label, case_label, file=sys.stderr) + if id in (config.cls_token_id, config.sep_token_id): + continue + if previous_label is not None and previous_label > 1: + if case_label in [case['LOWER'], case['OTHER']]: + case_label = case['CAPITALIZE'] + previous_label = punc_label + # different strategy due to sub-lexical token encoding in Flaubert + if config.lang == 'fr': + if token.endswith(''): + cased_token = recase(token[:-4], case_label) + if was_word: + output += ' ' + output += cased_token + punctuation_syms[punc_label] + was_word = True + else: + cased_token = recase(token, case_label) + if was_word: + output += ' ' + output += cased_token + was_word = False + else: + if token.startswith('##'): + cased_token = recase(token[2:], case_label) + output += cased_token + else: + cased_token = recase(token, case_label) + if not first_time: + output += ' ' + first_time = False + output += cased_token + punctuation_syms[punc_label] + if previous_label == 0: + output += '.' 
+ # Glue apostrophes back to words + output = re.sub(r"(\w) *' *(\w)", r"\1'\2", output) + + if ignore_disfluencies: + output = collapse_whitespace(output) + output = reconstitute_text(output, disfluencies) + return output + +mapped_punctuation = { + '.': 'PERIOD', + '...': 'PERIOD', + ',': 'COMMA', + ';': 'COMMA', + ':': 'COMMA', + '(': 'COMMA', + ')': 'COMMA', + '?': 'QUESTION', + '!': 'EXCLAMATION', + '，': 'COMMA', + '！': 'EXCLAMATION', + '？': 'QUESTION', + '；': 'COMMA', + '：': 'COMMA', + '（': 'COMMA', + '(': 'COMMA', + '）': 'COMMA', + '[': 'COMMA', + ']': 'COMMA', + '【': 'COMMA', + '】': 'COMMA', + '└': 'COMMA', + #'└ ': 'COMMA', + '_': 'O', + '。': 'PERIOD', + '、': 'COMMA', # enumeration comma + '、': 'COMMA', + '…': 'PERIOD', + '—': 'COMMA', + '「': 'COMMA', + '」': 'COMMA', + '．': 'PERIOD', + '《': 'O', + '》': 'O', + '，': 'COMMA', + '“': 'O', + '”': 'O', + '"': 'O', + #'-': 'O', # hyphen is a word piece + '〉': 'COMMA', + '〈': 'COMMA', + '↑': 'O', + '〔': 'COMMA', + '〕': 'COMMA', +} + +def collapse_whitespace(text): + return re.sub(r'\s+', ' ', text).strip() + + +# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased +# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py + +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100, keep_case=True): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + self.keep_case = keep_case + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`. + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + Returns: + A list of wordpiece tokens. 
+ """ + + output_tokens = [] + for token in text.strip().split(): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + # optionaly lowercase substring before checking for inclusion in vocab + if (self.keep_case and substr.lower() in self.vocab) or (substr in self.vocab): + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +# modification of XLM bpe tokenizer for keeping case information when vocab is lowercase +# forked from https://github.com/huggingface/transformers/blob/cd56f3fe7eae4a53a9880e3f5e8f91877a78271c/src/transformers/models/xlm/tokenization_xlm.py +def bpe(self, token): + def to_lower(pair): + # print(' ',pair) + return (pair[0].lower(), pair[1].lower()) + + from transformers.models.xlm.tokenization_xlm import get_pairs + + word = tuple(token[:-1]) + (token[-1] + "",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(to_lower(pair), float("inf"))) + # print(bigram) + if to_lower(bigram) not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n ": + word = "\n" + self.cache[token] = word + return word + + +def init(config): + init_random(config.seed) + + if config.lang == 'fr': + config.tokenizer = tokenizer = AutoTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + from transformers.models.xlm.tokenization_xlm import XLMTokenizer + assert isinstance(tokenizer, XLMTokenizer) + + # monkey patch XLM tokenizer + import types + tokenizer.bpe = types.MethodType(bpe, tokenizer) + else: + # warning: needs to be BertTokenizer for monkey patching to work + config.tokenizer = tokenizer = BertTokenizer.from_pretrained(config.flavor, do_lower_case=False) + + # warning: monkey patch tokenizer to keep case information + # from recasing_tokenizer import WordpieceTokenizer + config.tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token) + + if config.lang == 'fr': + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.bos_token_id + config.cls_token = tokenizer.bos_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + else: + config.pad_token_id = tokenizer.pad_token_id + config.cls_token_id = tokenizer.cls_token_id + config.cls_token = tokenizer.cls_token + config.sep_token_id = tokenizer.sep_token_id + config.sep_token = tokenizer.sep_token + + if not torch.cuda.is_available() and config.device == 'cuda': + print('WARNING: reverting to cpu as cuda is not available', 
file=sys.stderr) + config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu') + +def remove_simple_disfluences(text, language=None): + if language is None: + # Get language from environment + language = os.environ.get("LANGUAGE", "") + language = language.lower()[:2] + disfluencies = DISFLUENCIES.get(language, []) + all_hits = [] + for disfluency in disfluencies: + all_hits += re.finditer(r" *\b"+disfluency+r"\b *", text) + all_hits = sorted(all_hits, key=lambda x: x.start()) + to_be_inserted = [(hit.start(), hit.group()) for hit in all_hits] + new_text = text + for hit in all_hits[::-1]: + new_text = new_text[:hit.start()] + " " + new_text[hit.end():] + return to_be_inserted, new_text + +punctuation_regex = r"["+re.escape("".join(mapped_punctuation.keys()))+r"]" + +def reconstitute_text(text, to_be_inserted): + if len(to_be_inserted) == 0: + return text + pos_punc = [s.start() for s in re.finditer(punctuation_regex, text)] + for start, token in to_be_inserted: + start += len([p for p in pos_punc if p < start]) + text = text[:start] + token.rstrip(" ") + text[start:] + # print(text) + return text + + +DISFLUENCIES = { + "fr": [ + "euh", + "heu", + ] +} \ No newline at end of file diff --git a/punctuation/requirements.cpu.txt b/punctuation/requirements.cpu.txt new file mode 100644 index 0000000..d60673f --- /dev/null +++ b/punctuation/requirements.cpu.txt @@ -0,0 +1,11 @@ +celery[redis,auth,msgpack]>=4.4.7 +flask>=1.1.2 +flask-cors>=3.0.10 +flask-swagger-ui==3.36.0 +gevent>=22.10.2 +gunicorn>=20.1.0 +git+https://github.com/benob/mosestokenizer.git@169bd3a504fe20a3e51b9a7af3f0ca359c2d36c9 +numpy==1.19.5 +regex==2021.8.28 +torch==1.9.0+cpu +transformers==4.10.0 diff --git a/punctuation/requirements.txt b/punctuation/requirements.txt new file mode 100644 index 0000000..9082873 --- /dev/null +++ b/punctuation/requirements.txt @@ -0,0 +1,11 @@ +celery[redis,auth,msgpack]>=4.4.7 +flask>=1.1.2 +flask-cors>=3.0.10 +flask-swagger-ui==3.36.0 +gevent>=22.10.2 +gunicorn>=20.1.0 +git+https://github.com/benob/mosestokenizer.git@169bd3a504fe20a3e51b9a7af3f0ca359c2d36c9 +numpy==1.19.5 +regex==2021.8.28 +torch==1.9.0 +transformers==4.10.0 diff --git a/test/automated/__init__.py b/test/automated/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/automated/automated_utils.py b/test/automated/automated_utils.py new file mode 100644 index 0000000..c8668b2 --- /dev/null +++ b/test/automated/automated_utils.py @@ -0,0 +1,44 @@ +import os +import re +from configparser import ConfigParser + +AUTOMATEDTESTDIR = os.path.dirname(os.path.realpath(__file__)) +TESTDIR = os.path.dirname(AUTOMATEDTESTDIR) +ROOTDIR = os.path.dirname(TESTDIR) +os.chdir(ROOTDIR) + +config = ConfigParser() +config.read(f"{AUTOMATEDTESTDIR}/test_config.ini") + +SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else 60 + +def copy_env_file(env_file, env_variables=""): + env_variables = env_variables.split() + env_variables.append("SERVICE_MODE=") + with open(env_file, "r") as f: + lines = f.readlines() + with open(f"{AUTOMATEDTESTDIR}/.env", "w") as f: + for line in lines: + if not any([line.startswith(b.split("=")[0] + "=") for b in env_variables]): + f.write(line) + +def parse_env_variables(env_variables): + # make a dict + env_variables = env_variables.split() + env = {} + for env_variable in env_variables: + key, value = env_variable.split("=") + env[key] = value + return env + +def get_file_regex(test_file, language=None): + if 
not language: + raise ValueError("Language must be set") + if os.path.basename(test_file) == "bonjour.wav": + if language == "ru": + return re.compile("Б") + else: + return re.compile("[bB]onjour") + elif os.path.basename(test_file) == "notexisting": + return re.compile("") + raise ValueError(f"Unknown test file {test_file}") \ No newline at end of file diff --git a/test/test.py b/test/automated/core.py similarity index 60% rename from test/test.py rename to test/automated/core.py index f683583..17a7a52 100644 --- a/test/test.py +++ b/test/automated/core.py @@ -4,74 +4,14 @@ import subprocess import requests import re -from ddt import ddt, idata -from pathlib import Path import warnings +from ddt import ddt +from pathlib import Path +from automated_utils import AUTOMATEDTESTDIR, TESTDIR, SERVER_STARTING_TIMEOUT, get_file_regex, parse_env_variables -TESTDIR = os.path.dirname(os.path.realpath(__file__)) -ROOTDIR = os.path.dirname(TESTDIR) -os.chdir(ROOTDIR) -TESTDIR = os.path.basename(TESTDIR) - - - -def generate_whisper_test_setups(): - dockerfiles = [ - "whisper/Dockerfile.ctranslate2", - "whisper/Dockerfile.ctranslate2.cpu", - "whisper/Dockerfile.torch", - "whisper/Dockerfile.torch.cpu", - ] - - servings = ["http", "task"] - - vads = [None, "false", "auditok", "silero"] - devices = [None, "cpu", "cuda"] - models = ["tiny"] - - for dockerfile in dockerfiles: - for device in devices: - for vad in vads: - for model in models: - for serving in servings: - - # Test CPU dockerfile only on CPU - if dockerfile.endswith("cpu") and device != "cpu": - continue - - # Do not test all VAD settings if not on CPU - if vad not in [None, "silero"]: - if device != "cpu": - continue - - env_variables = "" - if vad: - env_variables += f"VAD={vad} " - if device: - env_variables += f"DEVICE={device} " - env_variables += f"MODEL={model}" - - yield dockerfile, serving, env_variables - -def generate_kaldi_test_setups(): - dockerfiles = ["kaldi/Dockerfile"] - - servings = ["http", "task"] - - for dockerfile in dockerfiles: - for serving in servings: - env_variables = "" - yield dockerfile, serving, env_variables - -def copy_env_file(env_file, env_variables=""): - env_variables = env_variables.split() - env_variables.append("SERVICE_MODE=") - with open(env_file, "r") as f: - lines = f.readlines() - with open(f"{TESTDIR}/.env", "w") as f: - for line in lines: - if not any([line.startswith(b.split("=")[0] + "=") for b in env_variables]): - f.write(line) +def finalize_tests(): + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @ddt class TestRunner(unittest.TestCase): @@ -79,10 +19,6 @@ built_images = [] redis_launched = False - # def __init__(self, *args, **kwargs): - # super(TestRunner, self).__init__(*args, **kwargs) - # self.cleanup() - def echo_success(self, message): print('\033[0;32m' + u'\u2714' + '\033[0m ' + message) @@ -173,7 +109,7 @@ def build_and_run_container(self, serving, docker_image, env_variables, use_loca if serving == "task": self.launch_redis() - build_args += "-v {}/:/opt/audio ".format(os.getcwd()) + build_args += f"-v {TESTDIR}/:/opt/audio " tag = f"test_{os.path.basename(docker_image)}" if tag not in TestRunner.built_images: @@ -190,7 +126,7 @@ self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.") 
TestRunner.built_images.append(tag) - cmd=f"docker run --rm -p 8080:80 --name test_container --env-file {TESTDIR}/.env --gpus all {build_args} linto-stt-test:{tag}" + cmd=f"docker run --rm -p 8080:80 --name test_container --env-file {AUTOMATEDTESTDIR}/.env --gpus all {build_args} linto-stt-test:{tag}" self.echo_command(cmd) p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) if p.poll() is not None: @@ -210,11 +146,9 @@ def transcribe(self, command, regex, test_file, error_message, success_message, self.echo_note(f"{success_message} has transcribed {test_file} in {end - start:.0f} sec.") return - def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http", env_variables="", test_file=f"{TESTDIR}/bonjour.wav", use_local_cache=True, expect_failure=False): + def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http", env_variables="", test_file=f"{TESTDIR}/bonjour.wav", language=None, use_local_cache=True, expect_failure=False): warnings.simplefilter("ignore", ResourceWarning) - regex = "" - if os.path.basename(test_file) == "bonjour.wav": - regex = re.compile("[bB]onjour") + regex = get_file_regex(test_file, parse_env_variables(env_variables).get("LANGUAGE", "fr") if language is None else language) r, pid = self.build_and_run_container(serving, docker_image, env_variables, use_local_cache) if r: return self.report_failure(r, expect_failure=expect_failure) @@ -223,16 +157,22 @@ if r: return self.report_failure(r, expect_failure=expect_failure) cmd = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"' + if language: + cmd += f' -F "language={language}"' self.echo_command(cmd) r = self.transcribe(cmd, regex, test_file, "Error transcription", "HTTP route 'transcribe'") if r: return self.report_failure(r, expect_failure=expect_failure) cmd = f"python3 {TESTDIR}/test_streaming.py --audio_file {test_file}" + if language: + cmd += f" --language {language}" self.echo_command(cmd) r = self.transcribe(cmd, regex, test_file, "Error streaming", "HTTP route 'streaming'") elif serving == "task": # you can be stuck here if the server crashed because the task will be in the queue forever - cmd = f"python3 {TESTDIR}/test_celery.py {test_file}" + cmd = f"python3 {TESTDIR}/test_celery.py --audio_file {os.path.basename(test_file)}" + if language: + cmd += f" --language {language}" self.echo_command(cmd) r = self.transcribe(cmd, regex, test_file, "Error task", "TASK route", timeout=60) else: @@ -250,69 +190,3 @@ def setUp(self): def tearDown(self): print("-"*70) - - @idata(generate_kaldi_test_setups()) - def test_01_kaldi_integration(self, setup): - dockerfile, serving, env_variables = setup - if AM_PATH is None or LM_PATH is None or AM_PATH=="" or LM_PATH=="": - self.fail("AM or LM path not provided. 
Skipping kaldi test.") - if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): - self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") - copy_env_file("kaldi/.envdefault") - env_variables += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" - self.run_test(dockerfile, serving=serving, env_variables=env_variables) - - - @idata(generate_whisper_test_setups()) - def test_03_whisper_integration(self, setup): - dockerfile, serving, env_variables = setup - copy_env_file("whisper/.envdefault", env_variables) - self.run_test(dockerfile, serving=serving, env_variables=env_variables) - - def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): - env_variables = "MODEL=tiny DEVICE=cuda" - dockerfile = "whisper/Dockerfile.ctranslate2.cpu" - copy_env_file("whisper/.envdefault", env_variables) - self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True)) - - def test_02_whisper_failures_not_existing_file(self): - env_variables = "MODEL=tiny" - copy_env_file("whisper/.envdefault", env_variables) - with self.assertRaises(FileNotFoundError): - self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=True) - self.cleanup() - - def test_02_whisper_failures_wrong_vad(self): - env_variables = "VAD=whatever MODEL=tiny" - copy_env_file("whisper/.envdefault", env_variables) - self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=True)) - - def test_04_model_whisper(self): - env_variables = "MODEL=small" - copy_env_file("whisper/.envdefault", env_variables) - self.run_test(env_variables=env_variables) - -def finalize_tests(): - subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - - -AM_PATH = None -LM_PATH = None -SERVER_STARTING_TIMEOUT = 60 - -if __name__ == '__main__': - from configparser import ConfigParser - config = ConfigParser() - - config.read(f"{TESTDIR}/test_config.ini") - - SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else SERVER_STARTING_TIMEOUT - - AM_PATH = config.get('kaldi', 'AM_PATH') - LM_PATH = config.get('kaldi', 'LM_PATH') - - try: - unittest.main(verbosity=2) - finally: - finalize_tests() diff --git a/test/automated/kaldi.py b/test/automated/kaldi.py new file mode 100644 index 0000000..5e5cc16 --- /dev/null +++ b/test/automated/kaldi.py @@ -0,0 +1,45 @@ +import unittest +import os + +from core import TestRunner, finalize_tests +from automated_utils import config, copy_env_file, TESTDIR +from ddt import ddt, idata + + +def generate_kaldi_test_setups(): + dockerfiles = ["kaldi/Dockerfile"] + + servings = ["http", "task"] + + for dockerfile in dockerfiles: + for serving in servings: + env_variables = "" + yield dockerfile, serving, env_variables + + +@ddt +class KaldiTestRunner(TestRunner): + + @idata(generate_kaldi_test_setups()) + def test_01_integration(self, setup): + dockerfile, serving, env_variables = setup + if AM_PATH is None or LM_PATH is None or AM_PATH == "" or LM_PATH == "": + self.fail("AM or LM path not provided. 
Skipping kaldi test.") + if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): + self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") + copy_env_file("kaldi/.envdefault") + env_variables += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + +AM_PATH = None +LM_PATH = None + +if __name__ == "__main__": + AM_PATH = config.get("kaldi", "AM_PATH") + LM_PATH = config.get("kaldi", "LM_PATH") + + try: + unittest.main(verbosity=2) + finally: + finalize_tests() diff --git a/test/test_config.ini b/test/automated/test_config.ini similarity index 100% rename from test/test_config.ini rename to test/automated/test_config.ini diff --git a/test/automated/whisper.py b/test/automated/whisper.py new file mode 100644 index 0000000..79fd4f3 --- /dev/null +++ b/test/automated/whisper.py @@ -0,0 +1,140 @@ +import unittest +from ddt import ddt, idata +from core import TestRunner, finalize_tests +from automated_utils import config, copy_env_file, TESTDIR + + +def generate_whisper_test_setups( + device="cpu", vads=[None, "false", "auditok", "silero"] +): + # reduce the number of tests because it takes multiples hours + if device == "cpu": + dockerfiles = [ + "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch.cpu", + ] + elif device == "cuda": + dockerfiles = [ + "whisper/Dockerfile.ctranslate2", + "whisper/Dockerfile.torch", + ] + else: + dockerfiles = [ + "whisper/Dockerfile.ctranslate2", + "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch", + "whisper/Dockerfile.torch.cpu", + ] + + servings = ["http", "task"] + + models = ["tiny"] + + for dockerfile in dockerfiles: + for vad in vads: + for model in models: + for serving in servings: + env_variables = "" + if vad: + env_variables += f"VAD={vad} " + if device: + env_variables += f"DEVICE={device} " + env_variables += f"MODEL={model}" + + yield dockerfile, serving, env_variables + + +@ddt +class WhisperTestRunner(TestRunner): + + @idata(generate_whisper_test_setups(device="cpu")) + def test_04_integration_cpu(self, setup): + dockerfile, serving, env_variables = setup + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + @idata(generate_whisper_test_setups(device="cuda", vads=[None, "silero"])) + def test_05_integration_cuda(self, setup): + dockerfile, serving, env_variables = setup + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + @idata(generate_whisper_test_setups(device=None, vads=[None])) + def test_06_integration_nodevice(self, setup): + dockerfile, serving, env_variables = setup + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + def test_02_failures_cuda_on_cpu_dockerfile(self): + env_variables = "MODEL=tiny DEVICE=cuda" + dockerfile = "whisper/Dockerfile.ctranslate2.cpu" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn( + "cannot open shared object file", + self.run_test(dockerfile, env_variables=env_variables, expect_failure=True), + ) + + def test_02_failure_not_existing_file(self): + env_variables = "MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) + with self.assertRaises(FileNotFoundError): + self.run_test( + test_file="notexisting", + env_variables=env_variables, + expect_failure=True, + ) + self.cleanup() + + def test_02_failure_wrong_vad(self): + 
env_variables = "VAD=whatever MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn( + "Got unexpected VAD method whatever", + self.run_test(env_variables=env_variables, expect_failure=True), + ) + + def test_03_model(self): + env_variables = "MODEL=small" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables) + + def test_01_failure_wrong_language(self): + env_variables = "MODEL=tiny LANGUAGE=whatever" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn( + "ValueError: Language \'whatever\' is not available", + self.run_test(env_variables=env_variables, expect_failure=True), + ) + + def test_01_nolanguage(self): + env_variables = "MODEL=tiny LANGUAGE=*" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables) + + def test_01_russian(self): + env_variables = "MODEL=tiny LANGUAGE=ru" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables) + + def test_01_language_over_config(self): + env_variables = "MODEL=tiny LANGUAGE=ru" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables, language="fr") + + def test_01_russian_celery(self): + env_variables = "MODEL=tiny LANGUAGE=ru" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(serving="task", env_variables=env_variables) + + def test_01_language_over_config_celery(self): + env_variables = "MODEL=tiny LANGUAGE=ru" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(serving="task", env_variables=env_variables, language="fr") + + + +if __name__ == "__main__": + try: + unittest.main(verbosity=2) + finally: + finalize_tests() diff --git a/test/test_celery.py b/test/test_celery.py index 59ed62e..dfbc9b8 100755 --- a/test/test_celery.py +++ b/test/test_celery.py @@ -1,16 +1,24 @@ import sys from celery import Celery -def transcribe_task(file_path): +def transcribe_task(file_path, language=None): celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1') r = celery.send_task( 'transcribe_task', ( file_path, True, + language ), queue='stt') return r.get() if __name__ == '__main__': - print(transcribe_task(sys.argv[1])) \ No newline at end of file + import argparse + parser = argparse.ArgumentParser(description='Transcribe with LinSTT', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--audio_file", default="bonjour.wav", help="A path to an audio file to transcribe (if not provided, use mic)") + parser.add_argument("--language", default=None, help="Language model to use") + args = parser.parse_args() + print(transcribe_task(args.audio_file, args.language)) \ No newline at end of file diff --git a/test/test_streaming.py b/test/test_streaming.py index 670540a..5af369e 100644 --- a/test/test_streaming.py +++ b/test/test_streaming.py @@ -14,6 +14,7 @@ async def _linstt_streaming( audio_file, ws_api = "ws://localhost:8080/streaming", verbose = False, + language = None ): if audio_file is None: @@ -29,7 +30,11 @@ async def _linstt_streaming( text = "" partial = None async with websockets.connect(ws_api) as websocket: - await websocket.send(json.dumps({"config" : {"sample_rate": 16000 }})) + if language is not None: + config = {"config" : {"sample_rate": 16000, "language": language}} + else: + config = {"config" : {"sample_rate": 16000}} + await websocket.send(json.dumps(config)) while True: data = stream.read(2*2*16000) if audio_file and not data: 
@@ -107,6 +112,7 @@ def print_final(text, background=" "): ) parser.add_argument("-v", "--verbose", action="store_true", help="Verbose mode") parser.add_argument("--audio_file", default=None, help="A path to an audio file to transcribe (if not provided, use mic)") + parser.add_argument("--language", default=None, help="Language model to use") args = parser.parse_args() - res = linstt_streaming(args.audio_file, args.server, verbose=2 if args.verbose else 1) \ No newline at end of file + res = linstt_streaming(args.audio_file, args.server, verbose=2 if args.verbose else 1, language=args.language) \ No newline at end of file diff --git a/whisper/.envdefault b/whisper/.envdefault index 3230b46..c649248 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -20,8 +20,8 @@ STREAMING_PORT=80 ############################################ # The model can be a path to a model (e.g. "/root/.cache/whisper/large-v3.pt", "/root/.cache/huggingface/hub/models--openai--whisper-large-v3"), -# or a model size ("tiny", "base", "small", "medium", "large-v1", "large-v2" or "large-v3") -# or a HuggingFace model name (e.g. "distil-whisper/distil-large-v2") +# or a model size ("tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3" or "large-v3-turbo") +# or a HuggingFace model name (e.g. "distil-whisper/distil-large-v2", "bofenghuang/whisper-large-v3-french-distil-dec8", ...) MODEL=large-v3 # The language can be in different formats: "en", "en-US", "English", ... diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2 index 5fd3c53..9111564 100644 --- a/whisper/Dockerfile.ctranslate2 +++ b/whisper/Dockerfile.ctranslate2 @@ -1,5 +1,5 @@ FROM ghcr.io/opennmt/ctranslate2:latest-ubuntu20.04-cuda12.2 -LABEL maintainer="contact@linto.ai, jlouradour@linagora.com, dgaynullin@linagora.com" +LABEL maintainer="contact@linto.ai, jlouradour@linagora.com, abert@linagora.com" RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu index 1f0f40c..40d8d00 100644 --- a/whisper/Dockerfile.ctranslate2.cpu +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -1,5 +1,5 @@ FROM python:3.9 -LABEL maintainer="contact@linto.ai, jlouradour@linagora.com, dgaynullin@linagora.com" +LABEL maintainer="contact@linto.ai, jlouradour@linagora.com, abert@linagora.com" RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends ffmpeg git diff --git a/whisper/README.md b/whisper/README.md index 4bfb6da..1f4b99a 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -67,6 +67,13 @@ for some model sizes, depending on the backend 8.2G 10.4G + + large-v3-turbo + 1.3G + 2.0G + 4.0G + 6.0G + ### Model(s) @@ -114,7 +121,7 @@ An example of .env file is provided in [whisper/.envdefault](https://github.com/ | PARAMETER | DESCRIPTION | EXEMPLE | |---|---|---| -| SERVICE_MODE | (Required) STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | +| SERVICE_MODE | (Required) STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` \| `websocket` | | MODEL | (Required) Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... | | LANGUAGE | Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... 
| PROMPT | Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | @@ -154,13 +161,14 @@ you may want to download one of OpenAI Whisper models: * [large-v1](https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt) * [large-v2](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt) * [large-v3](https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt) + * [large-v3-turbo](https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt) * Whisper models specialized for English can also be found here: * [tiny.en](https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt) * [base.en](https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt) * [small.en](https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt) * [medium.en](https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt) -If you already used Whisper in the past locally using [OpenAI-Whipser](https://github.com/openai/whisper), models can be found under ~/.cache/whisper. +If you already used Whisper in the past locally using [OpenAI-Whisper](https://github.com/openai/whisper), models can be found under `~/.cache/whisper`. The same applies to Whisper models from Hugging Face (transformers), as for instance https://huggingface.co/distil-whisper/distil-large-v2 (you can either download the model or use the Hugging Face identifier `distil-whisper/distil-large-v2`). @@ -173,23 +181,12 @@ automatic language detection will be performed by Whisper. The language can be a code of two or three letters. 
The list of languages supported by Whisper is: ``` af(afrikaans), am(amharic), ar(arabic), as(assamese), az(azerbaijani), -ba(bashkir), be(belarusian), bg(bulgarian), bn(bengali), bo(tibetan), br(breton), bs(bosnian), -ca(catalan), cs(czech), cy(welsh), da(danish), de(german), el(greek), en(english), es(spanish), -et(estonian), eu(basque), fa(persian), fi(finnish), fo(faroese), fr(french), gl(galician), -gu(gujarati), ha(hausa), haw(hawaiian), he(hebrew), hi(hindi), hr(croatian), ht(haitian creole), -hu(hungarian), hy(armenian), id(indonesian), is(icelandic), it(italian), ja(japanese), -jw(javanese), ka(georgian), kk(kazakh), km(khmer), kn(kannada), ko(korean), la(latin), -lb(luxembourgish), ln(lingala), lo(lao), lt(lithuanian), lv(latvian), mg(malagasy), mi(maori), -mk(macedonian), ml(malayalam), mn(mongolian), mr(marathi), ms(malay), mt(maltese), my(myanmar), -ne(nepali), nl(dutch), nn(nynorsk), no(norwegian), oc(occitan), pa(punjabi), pl(polish), -ps(pashto), pt(portuguese), ro(romanian), ru(russian), sa(sanskrit), sd(sindhi), si(sinhala), -sk(slovak), sl(slovenian), sn(shona), so(somali), sq(albanian), sr(serbian), su(sundanese), -sv(swedish), sw(swahili), ta(tamil), te(telugu), tg(tajik), th(thai), tk(turkmen), tl(tagalog), -tr(turkish), tt(tatar), uk(ukrainian), ur(urdu), uz(uzbek), vi(vietnamese), yi(yiddish), -yo(yoruba), zh(chinese) +ba(bashkir), be(belarusian), bg(bulgarian), bn(bengali), bo(tibetan), br(breton), bs(bosnian), ca(catalan), cs(czech), cy(welsh), da(danish), de(german), el(greek), en(english), es(spanish), et(estonian), eu(basque), fa(persian), fi(finnish), fo(faroese), fr(french), gl(galician), gu(gujarati), ha(hausa), haw(hawaiian), he(hebrew), hi(hindi), hr(croatian), ht(haitian creole), hu(hungarian), hy(armenian), id(indonesian), is(icelandic), it(italian), ja(japanese), jw(javanese), ka(georgian), kk(kazakh), km(khmer), kn(kannada), ko(korean), la(latin), lb(luxembourgish), ln(lingala), lo(lao), lt(lithuanian), lv(latvian), mg(malagasy), mi(maori), mk(macedonian), ml(malayalam), mn(mongolian), mr(marathi), ms(malay), mt(maltese), my(myanmar), ne(nepali), nl(dutch), nn(nynorsk), no(norwegian), oc(occitan), pa(punjabi), pl(polish), ps(pashto), pt(portuguese), ro(romanian), ru(russian), sa(sanskrit), sd(sindhi), si(sinhala), sk(slovak), sl(slovenian), sn(shona), so(somali), sq(albanian), sr(serbian), su(sundanese), sv(swedish), sw(swahili), ta(tamil), te(telugu), tg(tajik), th(thai), tk(turkmen), tl(tagalog), tr(turkish), tt(tatar), uk(ukrainian), ur(urdu), uz(uzbek), vi(vietnamese), yi(yiddish), yo(yoruba), zh(chinese) ``` and also `yue(cantonese)` since large-v3. +Language codes with "-" in them like "fr-FR" are also supported, but the part after "-" is ignored. So "fr-CA" (Canadian French) is equivalent to just "fr". Language names like "French" are also supported and will be converted to "fr". + #### SERVING_MODE ![Serving Modes](https://i.ibb.co/qrtv3Z6/platform-stt.png) @@ -286,6 +283,7 @@ Transcription API * Method: POST * Response content: text/plain or application/json * File: A Wave file 16b 16kHz +* Language (optional): Overrides the env language Return the transcribed text using "text/plain" or a json object when using "application/json", structured as follows: ```json { @@ -300,6 +298,7 @@ Return the transcripted text using "text/plain" or a json object when using "app }, ... 
], + "language": "en", "confidence-score": 0.879 } ``` @@ -308,7 +307,7 @@ The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true. The route accepts websocket connections. Exchanges are structured as follows: -1. Client send a json {"config": {"sample_rate":16000}}. +1. Client sends a json {"config": {"sample_rate":16000, "language":"en"}}. Language is optional; if not specified, the language from the env is used. 2. Client sends an audio chunk (go to 3- ) or {"eof" : 1} (go to 5-). 3. Server sends either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}. 4. Back to 2- diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md index 448d3c2..6621cd6 100644 --- a/whisper/RELEASE.md +++ b/whisper/RELEASE.md @@ -1,3 +1,6 @@ +# 1.0.5 +- Upgrade faster-whisper and support (large v3) turbo models + # 1.0.4 - Add environment variables to control decoding strategy (USE_ACCURATE=0/1) - Add environment variables to control streaming performance (STREAMING_MIN_CHUNK_SIZE, STREAMING_BUFFER_TRIMMING_SEC) diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt index 87b5e80..251ce6a 100644 --- a/whisper/requirements.ctranslate2.txt +++ b/whisper/requirements.ctranslate2.txt @@ -11,7 +11,7 @@ regex requests>=2.26.0 wavio>=0.0.4 websockets -auditok -#faster_whisper==1.0.1 -# This is version faster_whisper==1.0.1 + option for (persistent) prompt + fix for large-v3 -git+https://github.com/linto-ai/faster-whisper.git \ No newline at end of file +auditok<0.3.0 +# faster_whisper==1.1.0 +# vvv This is version faster_whisper==1.1.0 + option for (persistent) prompt + fix for large-v3 (and turbo) models +git+https://github.com/linto-ai/faster-whisper.git diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt index e5f5f93..9219d1d 100644 --- a/whisper/requirements.torch.txt +++ b/whisper/requirements.torch.txt @@ -16,4 +16,4 @@ websockets whisper-timestamped onnxruntime torchaudio -auditok \ No newline at end of file +auditok<0.3.0 \ No newline at end of file diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 6716215..7464d6d 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -39,8 +39,7 @@ def decode( compression_ratio_threshold: float = 2.4, prompt: str = default_prompt, ) -> dict: - if language is None: - language = get_language() + language = get_language(language) kwargs = copy.copy(locals()) kwargs.pop("model_and_alignementmodel") kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel diff --git a/whisper/stt/processing/load_model.py b/whisper/stt/processing/load_model.py index 1ca80f2..d417b09 100644 --- a/whisper/stt/processing/load_model.py +++ b/whisper/stt/processing/load_model.py @@ -60,6 +60,9 @@ def load_whisper_model(model_type_or_file, device="cpu", download_root=None): ) logger.info(f"CTranslate2 model in {output_dir}") if not os.path.isdir(output_dir): + + check_torch_installed() + from transformers.utils import cached_file import json @@ -96,8 +99,6 @@ if hf_path is None: raise RuntimeError(f"Could not find pytorch_model.bin in {model_type_or_file}") - check_torch_installed() - # from ctranslate2.converters.transformers import TransformersConverter # converter = TransformersConverter( # model_type_or_file, diff 
diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py
index 6dc9d99..638274e 100644
--- a/whisper/stt/processing/streaming.py
+++ b/whisper/stt/processing/streaming.py
@@ -13,6 +13,7 @@
 )
 from websockets.legacy.server import WebSocketServerProtocol
 from simple_websocket.ws import Server as WSServer
+from .utils import get_language

 EOF_REGEX = re.compile(r' *\{.*"eof" *: *1.*\} *$')

@@ -48,12 +49,13 @@ async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel):
         logger.error("Failed to read stream configuration")
         await ws.close(reason="Failed to load configuration")
     model, _ = model_and_alignementmodel
+    language = get_language(config.get("language"))
     if USE_CTRANSLATE2:
         logger.info("Using ctranslate2 for decoding")
-        asr = FasterWhisperASR(model=model, lan="fr", beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
+        asr = FasterWhisperASR(model=model, lan=language, beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
     else:
         logger.info("Using whisper_timestamped for decoding")
-        asr = WhisperTimestampedASR(model=model, lan="fr", beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
+        asr = WhisperTimestampedASR(model=model, lan=language, beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
     online = OnlineASRProcessor(
         asr, logfile=sys.stderr, buffer_trimming=STREAMING_BUFFER_TRIMMING_SEC, vad=VAD, sample_rate=sample_rate, \
         dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION
@@ -105,12 +107,13 @@ def ws_streaming(websocket_server: WSServer, model_and_alignementmodel):
         logger.error("Failed to read stream configuration")
         websocket_server.close()
     model, _ = model_and_alignementmodel
+    language = get_language(config.get("language"))
     if USE_CTRANSLATE2:
         logger.info("Using ctranslate2 for decoding")
-        asr = FasterWhisperASR(model=model, lan="fr", beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
+        asr = FasterWhisperASR(model=model, lan=language, beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
     else:
         logger.info("Using whisper_timestamped for decoding")
-        asr = WhisperTimestampedASR(model=model, lan="fr", beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
+        asr = WhisperTimestampedASR(model=model, lan=language, beam_size=DEFAULT_BEAM_SIZE, best_of=DEFAULT_BEST_OF, temperature=DEFAULT_TEMPERATURE)
     online = OnlineASRProcessor(
         asr, logfile=sys.stderr, buffer_trimming=STREAMING_BUFFER_TRIMMING_SEC, vad=VAD, sample_rate=sample_rate, \
         dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION
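
The net effect of these two hunks: the streaming language is no longer hard-coded to "fr" but resolved once per connection from the initial config message, via get_language (reworked in the utils.py hunk below). A minimal illustration of the resolution, not part of this diff, with the import path as suggested by the patch:

```python
from stt.processing.utils import get_language  # import path assumed from this diff

# Per-connection resolution, as in the streaming handlers above:
config = {"sample_rate": 16000, "language": "fr-CA"}
print(get_language(config.get("language")))  # "fr"  (country suffix dropped)

print(get_language("French"))  # "fr"  (names are mapped back to ISO 639-1 codes)
print(get_language("*"))       # None  ("*" means auto-detect)
print(get_language(None))      # falls back to the LANGUAGE env variable ("*" if unset)
get_language("xx")             # raises ValueError listing languages as "code(name)"
```
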
""" - language = os.environ.get("LANGUAGE", "*") + if language is None: + language = os.environ.get("LANGUAGE", "*") # "fr-FR" -> "fr" (language-country code to ISO 639-1 code) - if len(language) > 2 and language[2] == "-": - language = language.split("-")[0] + language = language.split("-")[0].lower() + language_fields = language.split("-") + if len(language_fields) == 2: + language = language_fields[0] + language = language.lower() # "*" means "all languages" if language == "*": language = None @@ -64,11 +68,8 @@ def get_language(): language = {v: k for k, v in LANGUAGES.items()}.get(language.lower(), language) # Raise an exception for unknown languages if language not in LANGUAGES: - available_languages = ( - list(LANGUAGES.keys()) - + [k[0].upper() + k[1:] for k in LANGUAGES.values()] - + ["*", None] - ) + available_languages = [f"{k}({v})" for k, v in LANGUAGES.items()] + available_languages.append("*") raise ValueError( f"Language '{language}' is not available. Available languages are: {available_languages}" ) diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py index 4c0fc8e..b705e99 100644 --- a/whisper/stt/processing/vad.py +++ b/whisper/stt/processing/vad.py @@ -272,8 +272,8 @@ def apply_folder_hack(): data = (audio.numpy() * 32767).astype(np.int16).tobytes() audio_duration = len(audio) / sample_rate - from auditok import split - segments = split( + import auditok + segments = auditok.split( data, sampling_rate=sample_rate, # sampling frequency in Hz channels=1, # number of channels @@ -287,10 +287,16 @@ def apply_folder_hack(): drop_trailing_silence=True, ) - segments = [ - {"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} - for s in segments - ] + if auditok.__version__ < "0.3.0": + segments = [ + {"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} + for s in segments + ] + else: + segments = [ + {"start": s.start * sample_rate, "end": s.end * sample_rate} + for s in segments + ] else: raise ValueError(f"Got unexpected VAD method {method}")