diff --git a/worker/pdm.lock b/worker/pdm.lock index 94b0f637..acdbeb0f 100644 --- a/worker/pdm.lock +++ b/worker/pdm.lock @@ -468,6 +468,11 @@ version = "1.3.0" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" +[[package]] +name = "spectralcluster" +version = "0.2.21" +summary = "Spectral Clustering" + [[package]] name = "speechbrain" version = "0.5.14" @@ -603,7 +608,7 @@ summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" [metadata] lock_version = "4.1" -content_hash = "sha256:c2beace8191afbf59f190ff9b7ceab852db61c8f0aa9df6d84dd96d782a3f5c8" +content_hash = "sha256:9041a37b09fea377779dc546d2595166b3653309897f69cd159a67b38d39cc1a" [metadata.files] "ansicon 1.89.0" = [ @@ -1048,6 +1053,44 @@ content_hash = "sha256:c2beace8191afbf59f190ff9b7ceab852db61c8f0aa9df6d84dd96d78 {url = "https://files.pythonhosted.org/packages/f2/74/41037079732f1976616356acc13bddceacd5d0c60d77ce3b4c79ba230d27/protobuf-4.25.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0f881b589ff449bf0b931a711926e9ddaad3b35089cc039ce1af50b21a4ae8cb"}, {url = "https://files.pythonhosted.org/packages/fe/6b/7f177e8d6fe4caa14f4065433af9f879d4fab84f0d17dcba7b407f6bd808/protobuf-4.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:3497c1af9f2526962f09329fd61a36566305e6c72da2590ae0d7d1322818843b"}, ] +"pydantic 1.10.7" = [ + {url = "https://files.pythonhosted.org/packages/00/43/f15d991ce715a2e7a229ef7c2534527d6fe4e5d260a675bd06615a4ede82/pydantic-1.10.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:976cae77ba6a49d80f461fd8bba183ff7ba79f44aa5cfa82f1346b5626542f8e"}, + {url = "https://files.pythonhosted.org/packages/05/4e/92a0c1fd305f764801dba26182b08ccf72026766fc4451d88186185467f2/pydantic-1.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe2507b8ef209da71b6fb5f4e597b50c5a34b78d7e857c4f8f3115effaef5fe"}, + {url = "https://files.pythonhosted.org/packages/07/3a/5bc906697c9aa0f0fc28f81ec25995315c999fb6df7b29e56a49b08009a3/pydantic-1.10.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0cfe895a504c060e5d36b287ee696e2fdad02d89e0d895f83037245218a87fe"}, + {url = "https://files.pythonhosted.org/packages/14/60/08f4b0a87561f64305002dffc5db2078043d46ed213e730a92e16840b120/pydantic-1.10.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6434b49c0b03a51021ade5c4daa7d70c98f7a79e95b551201fff682fc1661245"}, + {url = "https://files.pythonhosted.org/packages/21/ab/d7d0f74be71041507fe7ab1a61a71b251fc7667e720323b1f51a039370bb/pydantic-1.10.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c230c0d8a322276d6e7b88c3f7ce885f9ed16e0910354510e0bae84d54991143"}, + {url = "https://files.pythonhosted.org/packages/2e/97/e1e06d17f0f928083c660f6750b321797371ebd43aa16eda0ae80a4d3a7c/pydantic-1.10.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:abfb7d4a7cd5cc4e1d1887c43503a7c5dd608eadf8bc615413fc498d3e4645cd"}, + {url = "https://files.pythonhosted.org/packages/31/9e/32896df239096e0052e390e90eb0d374367e74bf7ce603a62841310c34c7/pydantic-1.10.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:01aea3a42c13f2602b7ecbbea484a98169fb568ebd9e247593ea05f01b884b2e"}, + {url = "https://files.pythonhosted.org/packages/34/d8/fd31b8172643cbf2cfd42398cba1406ea47ca1268f5e7ba48227f06c61a6/pydantic-1.10.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cc1dde4e50a5fc1336ee0581c1612215bc64ed6d28d2c7c6f25d2fe3e7c3e918"}, + {url = "https://files.pythonhosted.org/packages/38/cb/21afb81e5b3270cf5504543fb94a0d7734c4536b98c893701842602f9da0/pydantic-1.10.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f4a2b50e2b03d5776e7f21af73e2070e1b5c0d0df255a827e7c632962f8315af"}, + {url = "https://files.pythonhosted.org/packages/42/dc/092da33080729a95805e73084abf7cc064de7ae64462d1081859b2c1b7e2/pydantic-1.10.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75ae19d2a3dbb146b6f324031c24f8a3f52ff5d6a9f22f0683694b3afcb16fb"}, + {url = "https://files.pythonhosted.org/packages/43/5f/e53a850fd32dddefc998b6bfcbda843d4ff5b0dcac02a92e414ba6c97d46/pydantic-1.10.7.tar.gz", hash = "sha256:cfc83c0678b6ba51b0532bea66860617c4cd4251ecf76e9846fa5a9f3454e97e"}, + {url = "https://files.pythonhosted.org/packages/5e/06/a6b6a325b4085558d48f8804433b523bf31b62e8bcad6a9f8537418240d6/pydantic-1.10.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0f85904f73161817b80781cc150f8b906d521fa11e3cdabae19a581c3606209"}, + {url = "https://files.pythonhosted.org/packages/67/a9/f4fde01bb028c2afd0bd053ba440f7aeb609a9dc85f5d2d41a937526dbe8/pydantic-1.10.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:80b1fab4deb08a8292d15e43a6edccdffa5377a36a4597bb545b93e79c5ff0a5"}, + {url = "https://files.pythonhosted.org/packages/67/ac/ff5f7eca22bf58dbecfd266597e15b1ec7ddc68b886157a2095a25eedb17/pydantic-1.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:68792151e174a4aa9e9fc1b4e653e65a354a2fa0fed169f7b3d09902ad2cb6f1"}, + {url = "https://files.pythonhosted.org/packages/73/f9/860473019e228ac0b12e5cccecc086ce1f7e41d5f1482b64b9454a528e4f/pydantic-1.10.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae150a63564929c675d7f2303008d88426a0add46efd76c3fc797cd71cb1b46f"}, + {url = "https://files.pythonhosted.org/packages/7e/2f/05c7f8dbd1de1542d7560b5e7b5aeb7d58558af2262010f8de9abb466be1/pydantic-1.10.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf135c46099ff3f919d2150a948ce94b9ce545598ef2c6c7bf55dca98a304b52"}, + {url = "https://files.pythonhosted.org/packages/81/1b/04ce5303aee97af30b94c45699ed228b8ba6ba64c972efac184fb9a566f3/pydantic-1.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:516f1ed9bc2406a0467dd777afc636c7091d71f214d5e413d64fef45174cfc7a"}, + {url = "https://files.pythonhosted.org/packages/83/f2/b86db67c476177ec73fce0ea87e3fa0fd686c0602efbd4e42e5ccdb2bab9/pydantic-1.10.7-cp37-cp37m-win_amd64.whl", hash = "sha256:82dffb306dd20bd5268fd6379bc4bfe75242a9c2b79fec58e1041fbbdb1f7914"}, + {url = "https://files.pythonhosted.org/packages/8a/64/db1aafc37fab0dad89e0a27f120a18f2316fca704e9f95096ade47b933ac/pydantic-1.10.7-cp310-cp310-win_amd64.whl", hash = "sha256:a7cd2251439988b413cb0a985c4ed82b6c6aac382dbaff53ae03c4b23a70e80a"}, + {url = "https://files.pythonhosted.org/packages/8a/9b/4a6e7f721e54269966be968b7672f23b69d396ff59af7be6ea2e7bc30d0b/pydantic-1.10.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2a5ebb48958754d386195fe9e9c5106f11275867051bf017a8059410e9abf1f"}, + {url = "https://files.pythonhosted.org/packages/8d/e1/d9219c4e4161a511158e531a84aa719087064d208c2bf87df5c58812f190/pydantic-1.10.7-py3-none-any.whl", hash = "sha256:0cd181f1d0b1d00e2b705f1bf1ac7799a2d938cce3376b8007df62b29be3c2c6"}, + {url = "https://files.pythonhosted.org/packages/91/b8/e02d21709db955b92125059d6f80a1a543f9cc9f60ef212621514462b4e9/pydantic-1.10.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c15582f9055fbc1bfe50266a19771bbbef33dd28c45e78afbe1996fd70966c2a"}, + {url = "https://files.pythonhosted.org/packages/a0/ef/9b9a6c4f2e520c84c86908105bdec18a06449be0b2ec5c73526eba141402/pydantic-1.10.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d45fc99d64af9aaf7e308054a0067fdcd87ffe974f2442312372dfa66e1001d"}, + {url = "https://files.pythonhosted.org/packages/a4/cb/16648745548e4c18f4b98b7e323bbac698e77cd8fc250a6b2ff83688c95f/pydantic-1.10.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:464855a7ff7f2cc2cf537ecc421291b9132aa9c79aef44e917ad711b4a93163b"}, + {url = "https://files.pythonhosted.org/packages/aa/64/1b66f84ffe07562366c5ae87e83f0b3871afefd97f0632091629e6d5cfb2/pydantic-1.10.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:670bb4683ad1e48b0ecb06f0cfe2178dcf74ff27921cdf1606e527d2617a81ee"}, + {url = "https://files.pythonhosted.org/packages/b8/b7/158fb5bf629f5a97c997711757fb14e831825872c6d091a41a769c9c69e4/pydantic-1.10.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ecbbc51391248116c0a055899e6c3e7ffbb11fb5e2a4cd6f2d0b93272118a209"}, + {url = "https://files.pythonhosted.org/packages/c8/70/8fe094a67a9431095069f6f9eb2a893e11fdaec8c1182016f53a535adfec/pydantic-1.10.7-cp38-cp38-win_amd64.whl", hash = "sha256:9f6f0fd68d73257ad6685419478c5aece46432f4bdd8d32c7345f1986496171e"}, + {url = "https://files.pythonhosted.org/packages/c8/f3/8b3d444bdce482d6c206ab2b3ad309ae699f3074fde3d5e54c786f22b8c0/pydantic-1.10.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c7f51861d73e8b9ddcb9916ae7ac39fb52761d9ea0df41128e81e2ba42886cd"}, + {url = "https://files.pythonhosted.org/packages/d1/a1/0aa23b545299186f6eabc7a5d289a951e6c033852938ae6673d75846e611/pydantic-1.10.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10a86d8c8db68086f1e30a530f7d5f83eb0685e632e411dbbcf2d5c0150e8dcd"}, + {url = "https://files.pythonhosted.org/packages/d5/f0/a1bab22b297fc4333d496b34e0db42bc33c85c4b0e7e7a39da76fc65a643/pydantic-1.10.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:193924c563fae6ddcb71d3f06fa153866423ac1b793a47936656e806b64e24ca"}, + {url = "https://files.pythonhosted.org/packages/d6/59/8082b963e077ea4bec5bb85e8c0fc636e4e7b3484e6a8ceac94e743e3b74/pydantic-1.10.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e79e999e539872e903767c417c897e729e015872040e56b96e67968c3b918b2d"}, + {url = "https://files.pythonhosted.org/packages/dc/01/03bb09fdb5c06075c5dc79d4c68885e87fdc7e8becf347d6a1ff8f890f79/pydantic-1.10.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:701daea9ffe9d26f97b52f1d157e0d4121644f0fcf80b443248434958fd03dc3"}, + {url = "https://files.pythonhosted.org/packages/f1/bd/0dad4908e5f693b7951b68f435139ec583f5eebb3d75505e1efa0f2284fe/pydantic-1.10.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64d34ab766fa056df49013bb6e79921a0265204c071984e75a09cbceacbbdd5d"}, + {url = "https://files.pythonhosted.org/packages/f6/2d/0fc591686bc119d844f26268f503a7a504fbc9dd6a02e14aa42738c21fed/pydantic-1.10.7-cp39-cp39-win_amd64.whl", hash = "sha256:d71e69699498b020ea198468e2480a2f1e7433e32a3a99760058c6520e2bea7e"}, + {url = "https://files.pythonhosted.org/packages/fa/c2/3df79cd00e65678fce12e59e8c95378a992a93d7b9f9510d4f1f65df1936/pydantic-1.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:b4a849d10f211389502059c33332e91327bc154acc1845f375a99eca3afa802d"}, + {url = "https://files.pythonhosted.org/packages/fd/66/3da2e7c0306251435bd61ae9da52db8a00672fdf2b2db1e3efe1692f41dd/pydantic-1.10.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:950ce33857841f9a337ce07ddf46bc84e1c4946d2a3bba18f8280297157a3fd1"}, +] "pyicu 2.11" = [ {url = "https://files.pythonhosted.org/packages/03/1b/800fce0236be0b8a99b3ccbb797786dd178028960b3fd65544e2d8bad5ac/PyICU-2.11.tar.gz", hash = "sha256:3ab531264cfe9132b3d2ac5d708da9a4649d25f6e6813730ac88cf040a08a844"}, ] @@ -1332,6 +1375,10 @@ content_hash = "sha256:c2beace8191afbf59f190ff9b7ceab852db61c8f0aa9df6d84dd96d78 {url = "https://files.pythonhosted.org/packages/c3/a0/5dba8ed157b0136607c7f2151db695885606968d1fae123dc3391e0cfdbf/sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, {url = "https://files.pythonhosted.org/packages/cd/50/d49c388cae4ec10e8109b1b833fd265511840706808576df3ada99ecb0ac/sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +"spectralcluster 0.2.21" = [ + {url = "https://files.pythonhosted.org/packages/37/01/c94ada73cd9d32f2fb4a160b6615834fcf7d4809610fc27dc516e2b09fbe/spectralcluster-0.2.21.tar.gz", hash = "sha256:53a62b85b92900e02ba9196b895be57922750b72f273c7b28e77869cc0c7cbe6"}, + {url = "https://files.pythonhosted.org/packages/3b/02/6e235fad0d9ea5277fc1b2d967ba9c6fe66a7169a00561e10134a8507b18/spectralcluster-0.2.21-py3-none-any.whl", hash = "sha256:8935f235460b27942ba7d2e43ec8daf27a23589f4d7ad85ce1d91b06d0efa54e"}, +] "speechbrain 0.5.14" = [ {url = "https://files.pythonhosted.org/packages/64/ed/500aec1df26ee70fc341338091ff7b318cfc08908b404db4dc8dc14e06df/speechbrain-0.5.14-py3-none-any.whl", hash = "sha256:f1982de5a2bd4e1e285156b329c20448f4fe98771e086b50c1fbef773cf4498d"}, {url = "https://files.pythonhosted.org/packages/73/24/6ea470445382be56ebf843e62bed24a229cc6e7e1995eb783ff5baa75181/speechbrain-0.5.14.tar.gz", hash = "sha256:4363dc8d2a10c51ccbb4b9756501c45bc155f4889685830e6ea846a7b5d4641e"}, diff --git a/worker/pyproject.toml b/worker/pyproject.toml index 2c0c5866..c421ff8e 100644 --- a/worker/pyproject.toml +++ b/worker/pyproject.toml @@ -10,6 +10,7 @@ authors = [ ] dependencies = [ + "spectralcluster>=0.2.21", "numpy>=1.23.5", "pydantic[dotenv]>=1.10.7", "transformers>=4.26.1", diff --git a/worker/transcribee_worker/identify_speakers.py b/worker/transcribee_worker/identify_speakers.py index f1bf2a13..b2bc2e4e 100644 --- a/worker/transcribee_worker/identify_speakers.py +++ b/worker/transcribee_worker/identify_speakers.py @@ -4,7 +4,7 @@ import numpy as np import numpy.typing as npt import torch -from sklearn.cluster import AgglomerativeClustering +from spectralcluster import refinement, spectral_clusterer from speechbrain.pretrained import EncoderClassifier from transcribee_proto.document import Document from transcribee_worker.types import ProgressCallbackType @@ -12,6 +12,28 @@ from .config import settings +RefinementName = refinement.RefinementName +RefinementOptions = refinement.RefinementOptions +ThresholdType = refinement.ThresholdType +SymmetrizeType = refinement.SymmetrizeType +SpectralClusterer = spectral_clusterer.SpectralClusterer + +ICASSP2018_REFINEMENT_SEQUENCE = [ + RefinementName.CropDiagonal, + RefinementName.RowWiseThreshold, + RefinementName.Symmetrize, + RefinementName.Diffuse, + RefinementName.RowWiseNormalize, +] + +icassp2018_refinement_options = RefinementOptions( + gaussian_blur_sigma=5, + p_percentile=0.95, + thresholding_soft_multiplier=0.05, + thresholding_type=ThresholdType.RowMax, + refinement_sequence=ICASSP2018_REFINEMENT_SEQUENCE, +) + async def identify_speakers( number_of_speakers: int | None, @@ -75,24 +97,26 @@ def time_to_sample(time: float | None): progress=len(segments) / (len(segments) + 1), ) - clustering = AgglomerativeClustering( - compute_full_tree=True, # type: ignore - linkage="complete", - n_clusters=number_of_speakers, # type: ignore - # distance_threshold curtesty of - # https://huggingface.co/pyannote/speaker-diarization/blob/369ac1852c71759894a48c9bb1c6f499a54862fe/config.yaml#L15 - distance_threshold=0.7153 if number_of_speakers is None else None, - metric="cosine", + clusterer = SpectralClusterer( + min_clusters=1 if number_of_speakers is None else number_of_speakers, + max_clusters=100 + if number_of_speakers is None + else number_of_speakers, # TODO(robin): arbitrary upper limit + autotune=None, + laplacian_type=None, + refinement_options=icassp2018_refinement_options, + custom_dist="cosine", ) - clustering.fit(np.array(embeddings)) + + labels = clusterer.predict(np.vstack(embeddings)) # we now re-shuffle the labels so that the first occuring speaker is 1, the second is 2, ... label_map = {} - for label in clustering.labels_: + for label in labels: if label not in label_map: label_map[label] = str(len(label_map) + 1) - for para, label in zip(doc.children, clustering.labels_): + for para, label in zip(doc.children, labels): para.speaker = automerge.Text(label_map[label]) - await alist(aiter(async_task(work))) + return await alist(aiter(async_task(work)))