Skip to content

Commit

Permalink
Add speaker diarization API for HarmonyOS. (#1609)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 10, 2024
1 parent 14944d8 commit 1bae408
Show file tree
Hide file tree
Showing 18 changed files with 279 additions and 79 deletions.
29 changes: 14 additions & 15 deletions harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
export {
listRawfileDir,
readWave,
readWaveFromBinary,
} from "libsherpa_onnx.so";
export { listRawfileDir, readWave, readWaveFromBinary, } from "libsherpa_onnx.so";

export {
CircularBuffer,
export { CircularBuffer,
SileroVadConfig,
SpeechSegment,
Vad,
VadConfig,
} from './src/main/ets/components/Vad';


export {
Samples,
export { Samples,
OfflineStream,
FeatureConfig,
OfflineTransducerModelConfig,
Expand All @@ -31,8 +25,7 @@ export {
OfflineRecognizer,
} from './src/main/ets/components/NonStreamingAsr';

export {
OnlineStream,
export { OnlineStream,
OnlineTransducerModelConfig,
OnlineParaformerModelConfig,
OnlineZipformer2CtcModelConfig,
Expand All @@ -43,17 +36,23 @@ export {
OnlineRecognizer,
} from './src/main/ets/components/StreamingAsr';

export {
OfflineTtsVitsModelConfig,
export { OfflineTtsVitsModelConfig,
OfflineTtsModelConfig,
OfflineTtsConfig,
OfflineTts,
TtsOutput,
TtsInput,
} from './src/main/ets/components/NonStreamingTts';

export {
SpeakerEmbeddingExtractorConfig,
export { SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingExtractor,
SpeakerEmbeddingManager,
} from './src/main/ets/components/SpeakerIdentification';

export { OfflineSpeakerSegmentationPyannoteModelConfig,
OfflineSpeakerSegmentationModelConfig,
OfflineSpeakerDiarizationConfig,
OfflineSpeakerDiarizationSegment,
OfflineSpeakerDiarization,
FastClusteringConfig,
} from './src/main/ets/components/NonStreamingSpeakerDiarization';
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ static SherpaOnnxFastClusteringConfig GetFastClusteringConfig(
static Napi::External<SherpaOnnxOfflineSpeakerDiarization>
CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

#if __OHOS__
if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 2 arguments. Given: " << info.Length();

Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

return {};
}
#else
if (info.Length() != 1) {
std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length();
Expand All @@ -109,6 +120,7 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {

return {};
}
#endif

if (!info[0].IsObject()) {
Napi::TypeError::New(env, "Expect an object as the argument")
Expand All @@ -129,8 +141,18 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_on, minDurationOn);
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_off, minDurationOff);

#if __OHOS__
std::unique_ptr<NativeResourceManager,
decltype(&OH_ResourceManager_ReleaseNativeResourceManager)>
mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]),
&OH_ResourceManager_ReleaseNativeResourceManager);

const SherpaOnnxOfflineSpeakerDiarization *sd =
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(&c, mgr.get());
#else
const SherpaOnnxOfflineSpeakerDiarization *sd =
SherpaOnnxCreateOfflineSpeakerDiarization(&c);
#endif

if (c.segmentation.pyannote.model) {
delete[] c.segmentation.pyannote.model;
Expand Down Expand Up @@ -224,9 +246,17 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper(

Napi::Float32Array samples = info[1].As<Napi::Float32Array>();

#if __OHOS__
// Note(fangjun): For unknown reasons on HarmonyOS, we need to divide it by
// sizeof(float) here
const SherpaOnnxOfflineSpeakerDiarizationResult *r =
SherpaOnnxOfflineSpeakerDiarizationProcess(
sd, samples.Data(), samples.ElementLength() / sizeof(float));
#else
const SherpaOnnxOfflineSpeakerDiarizationResult *r =
SherpaOnnxOfflineSpeakerDiarizationProcess(sd, samples.Data(),
samples.ElementLength());
#endif

int32_t num_segments =
SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ export const speakerEmbeddingManagerVerify: (handle: object, obj: {name: string,
export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean;
export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number;
export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<string>;

// Creates a native offline speaker diarization object and returns an opaque handle.
// NOTE(review): the OHOS native wrapper appears to require exactly two arguments
// (config and a resource manager object), even though `mgr` is optional here — confirm.
export const createOfflineSpeakerDiarization: (config: object, mgr?: object) => object;
// Returns a number for the given handle; by its name, the sample rate expected for input audio.
export const getOfflineSpeakerDiarizationSampleRate: (handle: object) => number;
// Runs diarization on `samples`. The native wrapper returns an array, typed here as `object`.
export const offlineSpeakerDiarizationProcess: (handle: object, samples: Float32Array) => object;
// Applies `config` to the existing native object identified by `handle`.
export const offlineSpeakerDiarizationSetConfig: (handle: object, config: object) => void;
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ static Napi::Boolean WriteWaveWrapper(const Napi::CallbackInfo &info) {

Napi::Float32Array samples = obj.Get("samples").As<Napi::Float32Array>();
int32_t sample_rate = obj.Get("sampleRate").As<Napi::Number>().Int32Value();

#if __OHOS__
int32_t ok = SherpaOnnxWriteWave(
samples.Data(), samples.ElementLength() / sizeof(float), sample_rate,
info[0].As<Napi::String>().Utf8Value().c_str());
#else
int32_t ok =
SherpaOnnxWriteWave(samples.Data(), samples.ElementLength(), sample_rate,
info[0].As<Napi::String>().Utf8Value().c_str());
#endif

return Napi::Boolean::New(env, ok);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import {
createOfflineSpeakerDiarization,
getOfflineSpeakerDiarizationSampleRate,
offlineSpeakerDiarizationProcess,
offlineSpeakerDiarizationSetConfig,
} from 'libsherpa_onnx.so';

import { SpeakerEmbeddingExtractorConfig } from './SpeakerIdentification';

// Settings specific to the pyannote speaker segmentation model.
export class OfflineSpeakerSegmentationPyannoteModelConfig {
// Empty by default. Presumably the path of the pyannote model file — TODO confirm.
public model: string = '';
}

// Settings for the speaker segmentation model used by offline speaker diarization.
export class OfflineSpeakerSegmentationModelConfig {
// Pyannote-specific settings; default-constructed.
public pyannote: OfflineSpeakerSegmentationPyannoteModelConfig = new OfflineSpeakerSegmentationPyannoteModelConfig();
// Defaults to 1. Presumably the number of threads used to run the model — TODO confirm.
public numThreads: number = 1;
// Defaults to false. NOTE(review): the C API appears to log the whole diarization
// config when debug is set on the segmentation or embedding config — confirm.
public debug: boolean = false;
// Defaults to 'cpu'. Presumably selects the inference backend — TODO confirm.
public provider: string = 'cpu';
}

// Settings for the clustering step of offline speaker diarization.
export class FastClusteringConfig {
// Defaults to -1. NOTE(review): presumably a non-positive value means the number of
// speakers is unknown and `threshold` is used instead — confirm.
public numClusters: number = -1;
// Defaults to 0.5. Presumably the clustering threshold — TODO confirm exact meaning.
public threshold: number = 0.5;
}

// Top-level configuration passed to OfflineSpeakerDiarization.
export class OfflineSpeakerDiarizationConfig {
// Speaker segmentation model settings; default-constructed.
public segmentation: OfflineSpeakerSegmentationModelConfig = new OfflineSpeakerSegmentationModelConfig();
// Speaker embedding extractor settings; default-constructed.
public embedding: SpeakerEmbeddingExtractorConfig = new SpeakerEmbeddingExtractorConfig();
// Clustering settings; default-constructed.
public clustering: FastClusteringConfig = new FastClusteringConfig();
// Defaults to 0.2. Presumably a duration in seconds — TODO confirm units and meaning.
public minDurationOn: number = 0.2;
// Defaults to 0.5. Presumably a duration in seconds — TODO confirm units and meaning.
public minDurationOff: number = 0.5;
}

/**
 * One entry of a diarization result: a time span attributed to a speaker.
 *
 * Each field is declared on its own line; when they are run together, the
 * trailing `//` comment of one field swallows the declarations that follow.
 */
export class OfflineSpeakerDiarizationSegment {
  public start: number = 0; // in seconds
  public end: number = 0; // in seconds
  public speaker: number = 0; // ID of the speaker; count from 0
}

/**
 * Wrapper around the native offline speaker diarization object.
 *
 * The native object is created in the constructor and referenced through
 * an opaque handle for all subsequent calls.
 */
export class OfflineSpeakerDiarization {
  /** The configuration this object was created with (clustering is updated by setConfig). */
  public config: OfflineSpeakerDiarizationConfig;
  /** Sample rate reported by the native object right after creation. */
  public sampleRate: number;
  /** Opaque handle to the native object. */
  private handle: object;

  /**
   * @param config Diarization configuration forwarded to the native layer.
   * @param mgr Optional resource manager object forwarded to the native layer.
   */
  constructor(config: OfflineSpeakerDiarizationConfig, mgr?: object) {
    this.handle = createOfflineSpeakerDiarization(config, mgr);
    this.config = config;

    this.sampleRate = getOfflineSpeakerDiarizationSampleRate(this.handle);
  }

  /**
   * samples is a 1-d float32 array. Each element of the array should be
   * in the range [-1, 1].
   *
   * We assume its sample rate equals to this.sampleRate.
   *
   * Returns an array of objects, where each object is
   *
   *  {
   *    "start": start_time_in_seconds,
   *    "end": end_time_in_seconds,
   *    "speaker": an_integer,
   *  }
   */
  process(samples: Float32Array): OfflineSpeakerDiarizationSegment[] {
    // The native call yields an array of segments, so both the declared
    // return type and the cast must be the array type.
    return offlineSpeakerDiarizationProcess(this.handle, samples) as OfflineSpeakerDiarizationSegment[];
  }

  /**
   * Forwards config to the native object and mirrors only its clustering
   * part into this.config.
   */
  setConfig(config: OfflineSpeakerDiarizationConfig): void {
    offlineSpeakerDiarizationSetConfig(this.handle, config);
    this.config.clustering = config.clustering;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,15 @@ export class SpeakerEmbeddingExtractor {
}

createStream(): OnlineStream {
return new OnlineStream(
speakerEmbeddingExtractorCreateStream(this.handle));
return new OnlineStream(speakerEmbeddingExtractorCreateStream(this.handle));
}

isReady(stream: OnlineStream): boolean {
return speakerEmbeddingExtractorIsReady(this.handle, stream.handle);
}

compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array {
return speakerEmbeddingExtractorComputeEmbedding(
this.handle, stream.handle, enableExternalBuffer);
return speakerEmbeddingExtractorComputeEmbedding(this.handle, stream.handle, enableExternalBuffer);
}
}

Expand Down Expand Up @@ -106,9 +104,7 @@ export class SpeakerEmbeddingManager {

addMulti(speaker: SpeakerNameWithEmbeddingList): boolean {
const c: SpeakerNameWithEmbeddingN = {
name: speaker.name,
vv: flatten(speaker.v),
n: speaker.v.length,
name: speaker.name, vv: flatten(speaker.v), n: speaker.v.length,
};
return speakerEmbeddingManagerAddListFlattened(this.handle, c);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ export class OnlineRecognizer {
}

getResult(stream: OnlineStream): OnlineRecognizerResult {
const jsonStr: string =
getOnlineStreamResultAsJson(this.handle, stream.handle);
const jsonStr: string = getOnlineStreamResultAsJson(this.handle, stream.handle);

let o = JSON.parse(jsonStr) as OnlineRecognizerResultJson;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ export class CircularBuffer {

// return a float32 array
get(startIndex: number, n: number, enableExternalBuffer: boolean = true): Float32Array {
return circularBufferGet(
this.handle, startIndex, n, enableExternalBuffer);
return circularBufferGet(this.handle, startIndex, n, enableExternalBuffer);
}

pop(n: number) {
Expand Down Expand Up @@ -93,8 +92,7 @@ export class Vad {
private handle: object;

constructor(config: VadConfig, bufferSizeInSeconds?: number, mgr?: object) {
this.handle =
createVoiceActivityDetector(config, bufferSizeInSeconds, mgr);
this.handle = createVoiceActivityDetector(config, bufferSizeInSeconds, mgr);
this.config = config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class OfflineSpeakerDiarization {
}

setConfig(config) {
addon.offlineSpeakerDiarizationSetConfig(config);
addon.offlineSpeakerDiarizationSetConfig(this.handle, config);
this.config.clustering = config.clustering;
}
}
Expand Down
51 changes: 45 additions & 6 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1784,8 +1784,8 @@ struct SherpaOnnxOfflineSpeakerDiarizationResult {
sherpa_onnx::OfflineSpeakerDiarizationResult impl;
};

const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarization(
static sherpa_onnx::OfflineSpeakerDiarizationConfig
GetOfflineSpeakerDiarizationConfig(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;

Expand Down Expand Up @@ -1820,6 +1820,22 @@ SherpaOnnxCreateOfflineSpeakerDiarization(

sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5);

if (sd_config.segmentation.debug || sd_config.embedding.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", sd_config.ToString().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
#endif
}

return sd_config;
}

const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarization(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
auto sd_config = GetOfflineSpeakerDiarizationConfig(config);

if (!sd_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
Expand All @@ -1831,10 +1847,6 @@ SherpaOnnxCreateOfflineSpeakerDiarization(
sd->impl =
std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(sd_config);

if (sd_config.segmentation.debug || sd_config.embedding.debug) {
SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
}

return sd;
}

Expand Down Expand Up @@ -2029,5 +2041,32 @@ SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
}

#endif // #if SHERPA_ONNX_ENABLE_TTS == 1
//
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
// Create an offline speaker diarization object on HarmonyOS.
//
// @param config  C configuration; converted with
//                GetOfflineSpeakerDiarizationConfig().
// @param mgr     Native resource manager forwarded to the
//                sherpa_onnx::OfflineSpeakerDiarization constructor. If it is
//                null, this function delegates to
//                SherpaOnnxCreateOfflineSpeakerDiarization().
// @return A newly allocated object returned as a raw pointer, or nullptr if
//         the converted config fails validation.
const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(
    const SherpaOnnxOfflineSpeakerDiarizationConfig *config,
    NativeResourceManager *mgr) {
  if (!mgr) {
    return SherpaOnnxCreateOfflineSpeakerDiarization(config);
  }

  auto sd_config = GetOfflineSpeakerDiarizationConfig(config);

  if (!sd_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Keep the wrapper owned by a unique_ptr while impl is being built so it is
  // not leaked if constructing the implementation throws.
  auto sd = std::make_unique<SherpaOnnxOfflineSpeakerDiarization>();

  sd->impl =
      std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(mgr, sd_config);

  // Ownership leaves this function together with the raw pointer.
  return sd.release();
}

#endif // #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1

#endif // #ifdef __OHOS__
5 changes: 5 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,11 @@ SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
NativeResourceManager *mgr);

// HarmonyOS variant of the offline speaker diarization constructor.
//
// @param config  Diarization configuration.
// @param mgr     Native resource manager. When it is NULL, the implementation
//                delegates to SherpaOnnxCreateOfflineSpeakerDiarization().
// @return A pointer to the created object, or NULL if the configuration
//         fails validation.
SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config,
NativeResourceManager *mgr);
#endif

#if defined(__GNUC__)
Expand Down
Loading

0 comments on commit 1bae408

Please sign in to comment.