From 31d61c0e05bf558062dd0b11c21a66de77466443 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Tue, 29 Nov 2022 15:34:51 -0800
Subject: [PATCH 1/8] [src] Several cuda decoder fixes.

These affect both correctness and performance.

- Add missing cudaStreamSynchronize(). This was not caught before
  because we were running at smaller batch sizes, which allowed the
  init decoding kernels to run in parallel with the nnet3 kernels, and
  thus to have completed by this point. At large enough batch sizes, no
  such parallelization is possible (all blocks of the GPU are occupied).
- Faster host paged-to-pinned memory copy via multithreading.
- Disable timing in cuda events for increased performance.

  Before (on A100 PCIe):
  Overall: Aggregate Total Time: 26.6364 Total Audio: 194525 RealTimeX: 7302.96
  After (on A100 PCIe):
  Overall: Aggregate Total Time: 26.0323 Total Audio: 194525 RealTimeX: 7472.43

- In online decoder, create writers before initializing CUDA. CUDA
  initialization creates a lot of virtual memory (for unified virtual
  memory, if I understand correctly) that can cause errors if memory
  oversubscription is not set high enough when using the fork()
  syscall. The issue is further described here:
  https://groups.google.com/g/kaldi-help/c/3hc0xsRpqqY?pli=1
- Add cudaProfilerStart/Stop to online binary.
- Name H2H copy threads in Nsight Systems.
---
 src/cudadecoder/batched-static-nnet3.cc       |  2 +-
 ...hed-threaded-nnet3-cuda-online-pipeline.cc | 66 ++++++++++++++++---
 ...ched-threaded-nnet3-cuda-online-pipeline.h | 22 ++++++-
 .../batched-threaded-nnet3-cuda-pipeline2.cc  |  4 ++
 src/cudadecoder/cuda-decoder-common.h         |  2 -
 src/cudadecoder/cuda-decoder-kernels.cu       |  2 +-
 src/cudadecoder/cuda-decoder.cc               | 37 ++++++++---
 src/cudadecoder/cuda-decoder.h                |  6 +-
 src/cudadecoder/cuda-pipeline-common.h        |  2 +-
 src/cudadecoder/thread-pool-light.h           | 11 +++-
 .../batched-wav-nnet3-cuda-online.cc          | 11 ++--
 src/cudadecoderbin/batched-wav-nnet3-cuda2.cc |  6 +-
 12 files changed, 138 insertions(+), 33 deletions(-)

diff --git a/src/cudadecoder/batched-static-nnet3.cc b/src/cudadecoder/batched-static-nnet3.cc
index aa9ddd0f859..1d60f76e82b 100644
--- a/src/cudadecoder/batched-static-nnet3.cc
+++ b/src/cudadecoder/batched-static-nnet3.cc
@@ -77,7 +77,7 @@ void BatchedStaticNnet3::PresetKernelParams() {
 }
 
 void BatchedStaticNnet3::Allocate() {
-  cudaEventCreate(&batch_slot_assignement_copy_evt_);
+  cudaEventCreate(&batch_slot_assignement_copy_evt_, cudaEventDisableTiming);
   d_all_context_frames_.Resize(nchannels_ * total_nnet_context_, input_dim_);
   d_batch_with_context_.Resize(
       max_batch_size_ * input_frames_per_chunk_with_context_, input_dim_);
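The cudaEventDisableTiming change above recurs throughout this series. A
minimal sketch of the synchronization-only event pattern it enables (the
function and stream names here are illustrative, not from the patch):

    #include <cuda_runtime.h>

    // stream_a produces work; stream_b must not start until it is done.
    // cudaEventDisableTiming skips timestamp bookkeeping, making
    // cudaEventRecord()/cudaStreamWaitEvent() cheaper.
    void OrderStreams(cudaStream_t stream_a, cudaStream_t stream_b) {
      cudaEvent_t evt;
      cudaEventCreateWithFlags(&evt, cudaEventDisableTiming);
      cudaEventRecord(evt, stream_a);         // marks the current tail of stream_a
      cudaStreamWaitEvent(stream_b, evt, 0);  // stream_b waits on-device, no host sync
      cudaEventDestroy(evt);                  // destruction is deferred until completion
    }
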
diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
index 6e78d7212fd..900dcbf6bc8 100644
--- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
+++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
@@ -240,14 +240,57 @@ bool BatchedThreadedNnet3CudaOnlinePipeline::TryInitCorrID(
 
 void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix(
     const std::vector<SubVector<BaseFloat>> &wave_samples) {
-  for (int i = 0; i < wave_samples.size(); ++i) {
-    const SubVector<BaseFloat> &src = wave_samples[i];
-    int size = src.Dim();
-    n_samples_valid_[i] = size;
-    const BaseFloat *wave_src = src.Data();
-    BaseFloat *wave_dst = h_all_waveform_.RowData(i);
-    std::memcpy(wave_dst, wave_src, size * sizeof(BaseFloat));
+  nvtxRangePushA(__func__);
+
+  if (!batching_copy_thread_pool_) {
+    for (int i = 0; i < wave_samples.size(); ++i) {
+      const SubVector<BaseFloat> &src = wave_samples[i];
+      int size = src.Dim();
+      n_samples_valid_[i] = size;
+      const BaseFloat *wave_src = src.Data();
+      BaseFloat *wave_dst = h_all_waveform_.RowData(i);
+      std::memcpy(wave_dst, wave_src, size * sizeof(BaseFloat));
+    }
+  } else {
+    const size_t batch_size =
+        KALDI_CUDA_DECODER_DIV_ROUND_UP(wave_samples.size(),
+                                        config_.num_batching_copy_threads);
+
+    std::mutex m;
+    std::condition_variable cv;
+
+    std::atomic<size_t> tasks_remaining;
+    std::atomic_init(&tasks_remaining, KALDI_CUDA_DECODER_DIV_ROUND_UP(wave_samples.size(), batch_size));
+
+    for (size_t i = 0; i < wave_samples.size(); i += batch_size) {
+
+      auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size](void *ignore1, uint64_t ignore2, void *ignore3) {
+        nvtxRangePush("CompactWavesToMatrix task");
+        for (size_t j = i; j < std::min(i + batch_size, wave_samples.size()); ++j) {
+          const SubVector<BaseFloat> &src = wave_samples[j];
+          int size = src.Dim();
+          n_samples_valid_[j] = size;
+          const BaseFloat *wave_src = src.Data();
+          BaseFloat *wave_dst = this->h_all_waveform_.RowData(j);
+          std::memcpy(wave_dst, wave_src, size * sizeof(BaseFloat));
+        }
+        --tasks_remaining;
+        if (tasks_remaining.load() == 0) {
+          std::lock_guard<std::mutex> lock(m);
+          cv.notify_one();
+        }
+        nvtxRangePop();
+      };
+      batching_copy_thread_pool_->Push({task, nullptr, 0, nullptr});
+    }
+
+    // wait for all threads to finish
+    {
+      std::unique_lock<std::mutex> lock(m);
+      cv.wait(lock, [&tasks_remaining]() { return tasks_remaining == 0; });
+    }
   }
+  nvtxRangePop();
 }
 
 void BatchedThreadedNnet3CudaOnlinePipeline::ComputeGPUFeatureExtraction(
@@ -258,9 +301,11 @@ void BatchedThreadedNnet3CudaOnlinePipeline::ComputeGPUFeatureExtraction(
   // CopyFromMat syncs, avoiding it
   KALDI_ASSERT(d_all_waveform_.SizeInBytes() == h_all_waveform.SizeInBytes());
   // Note : we could have smaller copies using the actual channels.size()
+  nvtxRangePushA("ComputeGPUFeatureExtractioncudaMemcpyAsync");
   cudaMemcpyAsync(d_all_waveform_.Data(), h_all_waveform.Data(),
                   h_all_waveform.SizeInBytes(), cudaMemcpyHostToDevice,
                   cudaStreamPerThread);
+  nvtxRangePop();
 
   KALDI_ASSERT(channels.size() == is_last_chunk.size());
   KALDI_ASSERT(channels.size() == is_first_chunk.size());
@@ -348,6 +393,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch(
   ListIChannelsInBatch(corr_ids, &channels_);
 
   // Compact in h_all_waveform_ to use the main DecodeBatch version
+  // this is slow
   CompactWavesToMatrix(wave_samples);
 
   DecodeBatch(corr_ids, h_all_waveform_, n_samples_valid_, is_first_chunk,
@@ -575,6 +621,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::RunLatticeCallbacks(
 void BatchedThreadedNnet3CudaOnlinePipeline::RunCallbacksAndFinalize(
     const std::vector<CorrelationID> &corr_ids,
     const std::vector<int> &channels, const std::vector<bool> &is_last_chunk) {
+  nvtxRangePushA("RunCallbacksAndFinalize");
   // Reading endpoints, figuring out is_end_of_segment_
   for (size_t i = 0; i < is_last_chunk.size(); ++i) {
     bool endpoint_detected = false;
@@ -589,6 +636,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::RunCallbacksAndFinalize(
   RunBestPathCallbacks(corr_ids, channels);
 
   RunLatticeCallbacks(corr_ids, channels, is_last_chunk);
+  nvtxRangePop();
 }
 
 void BatchedThreadedNnet3CudaOnlinePipeline::ListIChannelsInBatch(
@@ -646,7 +694,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::InitDecoding(
     }
 
     if (should_reset_decoder)
-      init_decoding_list_channels_.push_back((channels)[i]);
+      init_decoding_list_channels_.push_back(channels[i]);
   }
 
   if (!init_decoding_list_channels_.empty())
@@ -655,6 +703,7 @@ void 
BatchedThreadedNnet3CudaOnlinePipeline::InitDecoding( void BatchedThreadedNnet3CudaOnlinePipeline::RunDecoder( const std::vector &channels, const std::vector &is_first_chunk) { + nvtxRangePushA("RunDecoder"); if (partial_hypotheses_) { // We're going to have to generate the partial hypotheses if (word_syms_ == nullptr) { @@ -690,6 +739,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::RunDecoder( (*end_points_)[i] = cuda_decoder_->EndpointDetected(ichannel); } } + nvtxRangePop(); } void BatchedThreadedNnet3CudaOnlinePipeline::ReadParametersFromModel() { diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h index 6608aa79dd8..ed12931b578 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h @@ -66,7 +66,8 @@ struct BatchedThreadedNnet3CudaOnlinePipelineConfig { determinize_lattice(true), num_decoder_copy_threads(2), use_gpu_feature_extraction(true), - reset_on_endpoint(false) {} + reset_on_endpoint(false), + num_batching_copy_threads(0) {} void Register(OptionsItf *po) { po->Register("max-batch-size", &max_batch_size, "The maximum execution batch size." @@ -88,6 +89,12 @@ struct BatchedThreadedNnet3CudaOnlinePipelineConfig { po->Register( "reset-on-endpoint", &reset_on_endpoint, "Reset a decoder channel when endpoint detected. Do not close stream"); + po->Register( + "batching-copy-threads", &num_batching_copy_threads, + "Number of threads to use for copying inputs on CPU into single pinned memory matrix. " + "0 means to just use the main thread. Recommend setting this to 8 because the memory " + "copy can starve the GPU of work." +); feature_opts.Register(po); decoder_opts.Register(po); @@ -101,6 +108,7 @@ struct BatchedThreadedNnet3CudaOnlinePipelineConfig { int num_decoder_copy_threads; bool use_gpu_feature_extraction; bool reset_on_endpoint; + int num_batching_copy_threads; OnlineNnet2FeaturePipelineConfig feature_opts; CudaDecoderConfig decoder_opts; @@ -121,6 +129,8 @@ struct BatchedThreadedNnet3CudaOnlinePipelineConfig { num_worker_threads = (num_worker_threads > 0) ? num_worker_threads : std::thread::hardware_concurrency(); + + KALDI_ASSERT(num_batching_copy_threads >= 0); } }; @@ -150,9 +160,15 @@ class BatchedThreadedNnet3CudaOnlinePipeline { word_syms_(NULL) { config_.compute_opts.CheckAndFixConfigs(am_nnet_->GetNnet().Modulus()); config_.CheckAndFixConfigs(); - Initialize(decode_fst); int num_worker_threads = config_.num_worker_threads; thread_pool_ = std::make_unique(num_worker_threads); + + int num_batching_copy_threads = config_.num_batching_copy_threads; + if (num_batching_copy_threads > 0) { + batching_copy_thread_pool_ = std::make_unique(num_batching_copy_threads); + } + + Initialize(decode_fst); } ~BatchedThreadedNnet3CudaOnlinePipeline(); @@ -503,6 +519,8 @@ class BatchedThreadedNnet3CudaOnlinePipeline { // destructor blocks until the thread pool is drained of work items. std::unique_ptr thread_pool_; + std::unique_ptr batching_copy_thread_pool_; + // The decoder owns thread(s) that reconstruct lattices transferred from the // device in a compacted form as arrays with offsets instead of pointers. 
std::unique_ptr cuda_decoder_; diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc index c076910672a..4186632d9a3 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc @@ -431,13 +431,17 @@ void BatchedThreadedNnet3CudaPipeline2::AcquireTasks() { void BatchedThreadedNnet3CudaPipeline2::ComputeTasks() { while (threads_running_) { + nvtxRangePushA("AcquireTasks"); if (current_tasks_.size() < max_batch_size_) AcquireTasks(); + nvtxRangePop(); if (current_tasks_.empty()) { // If we still have nothing to do, let's sleep a bit Sleep(kSleepForNewTask); continue; } + nvtxRangePushA("BuildBatch"); BuildBatchFromCurrentTasks(); + nvtxRangePop(); if (use_online_features_) cuda_online_pipeline_.DecodeBatch(batch_corr_ids_, batch_wave_samples_, diff --git a/src/cudadecoder/cuda-decoder-common.h b/src/cudadecoder/cuda-decoder-common.h index 388f6625a3f..eae1f36800c 100644 --- a/src/cudadecoder/cuda-decoder-common.h +++ b/src/cudadecoder/cuda-decoder-common.h @@ -139,8 +139,6 @@ #define KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(i, n) \ for (int i = blockIdx.y; i < (n); i += gridDim.y) -#define KALDI_CUDA_DECODER_DIV_ROUND_UP(a, b) ((a + b - 1) / b) - #define KALDI_CUDA_DECODER_1D_BLOCK 256 #define KALDI_CUDA_DECODER_LARGEST_1D_BLOCK 1024 #define KALDI_CUDA_DECODER_ONE_THREAD_BLOCK 1 diff --git a/src/cudadecoder/cuda-decoder-kernels.cu b/src/cudadecoder/cuda-decoder-kernels.cu index 3a835d02b76..76985f93299 100644 --- a/src/cudadecoder/cuda-decoder-kernels.cu +++ b/src/cudadecoder/cuda-decoder-kernels.cu @@ -1538,7 +1538,7 @@ __global__ void emitting_preprocess_and_list_extra_prev_tokens_step1_kernel( // Token index of one of the token which the lowest token.cost for that // state uint32_t state_best_int_cost_argmin; - GetArgFromPackedArgminUInt64(h_val.min_and_argmin_int_cost_u64, &state_best_int_cost_argmin); + GetArgFromPackedArgminUInt64(h_val.min_and_argmin_int_cost_u64, &state_best_int_cost_argmin); // Checking if we're the representative of that state representing_state = (main_q_idx == state_best_int_cost_argmin); diff --git a/src/cudadecoder/cuda-decoder.cc b/src/cudadecoder/cuda-decoder.cc index 1ec456ac32c..7d94e41d9e3 100644 --- a/src/cudadecoder/cuda-decoder.cc +++ b/src/cudadecoder/cuda-decoder.cc @@ -37,6 +37,11 @@ #include #include +#ifdef __linux__ +#include +#include +#endif // __linux__ + #include #include @@ -96,14 +101,18 @@ CudaDecoder::CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, InitHostData(); InitDeviceData(); - KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&nnet3_done_evt_)); - KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&d2h_copy_acoustic_evt_)); - KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&d2h_copy_infotoken_evt_)); + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&nnet3_done_evt_, + cudaEventDisableTiming)); + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&d2h_copy_acoustic_evt_, + cudaEventDisableTiming)); + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&d2h_copy_infotoken_evt_, + cudaEventDisableTiming)); KALDI_DECODER_CUDA_API_CHECK_ERROR( - cudaEventCreate(&d2h_copy_extra_prev_tokens_evt_)); + cudaEventCreate(&d2h_copy_extra_prev_tokens_evt_, cudaEventDisableTiming)); KALDI_DECODER_CUDA_API_CHECK_ERROR( - cudaEventCreate(&concatenated_data_ready_evt_)); - KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&lane_offsets_ready_evt_)); + 
cudaEventCreate(&concatenated_data_ready_evt_, cudaEventDisableTiming)); + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaEventCreate(&lane_offsets_ready_evt_, + cudaEventDisableTiming)); ComputeInitialChannel(); --nchannels_; // removing the special initial channel from the count @@ -682,6 +691,7 @@ void CudaDecoder::PostProcessingMainQueue() { } void CudaDecoder::CopyMainQueueDataToHost() { + nvtxRangePushA("CopyMainQueueDataToHost"); CU_SAFE_CALL(cudaEventRecord(concatenated_data_ready_evt_, compute_st_)); // The copies on copy_st will wait on compute_st_. CU_SAFE_CALL(cudaStreamWaitEvent(copy_st_, concatenated_data_ready_evt_, 0)); @@ -719,8 +729,8 @@ void CudaDecoder::CopyMainQueueDataToHost() { ++num_frames_decoded_[ichannel]; } } - LaunchH2HCopies(); + nvtxRangePop(); } void CudaDecoder::LaunchD2HCopies() { @@ -838,6 +848,8 @@ void CudaDecoder::AdvanceDecoding( h_lanes_counters_.lane(ilane)->loglikelihoods = lanes_assignements[ilane].second; } + // Make sure that InitDecoding() has completed. + CU_SAFE_CALL(cudaStreamSynchronize(compute_st_)); LoadChannelsStateToLanes(channels); KALDI_ASSERT(nlanes_used_ > 0); CU_SAFE_CALL(cudaMemcpyAsync(d_lanes_counters_.MutableData(), @@ -845,6 +857,9 @@ void CudaDecoder::AdvanceDecoding( nlanes_used_ * sizeof(*h_lanes_counters_.lane(0)), cudaMemcpyHostToDevice, compute_st_)); // compute_st_ will wait for nnet3 to complete + + // TODO: Pass this in as a parameter instead of assuming that the + // neural network computes on the per-thread default stream CU_SAFE_CALL(cudaEventRecord(nnet3_done_evt_, cudaStreamPerThread)); CU_SAFE_CALL(cudaStreamWaitEvent(compute_st_, nnet3_done_evt_, 0)); @@ -1820,6 +1835,7 @@ void CudaDecoder::CheckStaticAsserts() { } void CudaDecoder::LaunchH2HCopies() { + nvtxRangePushA("LaunchH2HCopies"); // Each H2H copy counter n_acoustic_h2h_copies_todo_.store(nlanes_used_ - 1); n_infotoken_h2h_copies_todo_.store(nlanes_used_ - 1); @@ -1844,10 +1860,14 @@ void CudaDecoder::LaunchH2HCopies() { } else { ComputeH2HCopies(); } + nvtxRangePop(); } void CudaDecoder::ComputeH2HCopiesCPUWorker() { // Run by a dedicated CPU thread +#ifdef __linux__ + nvtxNameOsThread(syscall(SYS_gettid), "h2hcopies"); +#endif while (h2h_threads_running_) { ComputeH2HCopies(); } @@ -2086,9 +2106,10 @@ void CudaDecoder::SetThreadPoolAndStartCPUWorkers(ThreadPoolLight *thread_pool, KALDI_ASSERT(nworkers > 0); n_threads_used_ = nworkers; thread_pool_ = thread_pool; - for (int32 i = 0; i < nworkers; ++i) + for (int32 i = 0; i < nworkers; ++i) { cpu_dedicated_threads_.emplace_back(&CudaDecoder::ComputeH2HCopiesCPUWorker, this); + } } } // namespace cuda_decoder diff --git a/src/cudadecoder/cuda-decoder.h b/src/cudadecoder/cuda-decoder.h index de2bd09f47c..0b86a5dd9bc 100644 --- a/src/cudadecoder/cuda-decoder.h +++ b/src/cudadecoder/cuda-decoder.h @@ -559,8 +559,8 @@ class CudaDecoder { // // The auxiliary queue is used to store the raw output of ExpandArcs. // We then prune that aux queue (and apply max-active) and move the - // survival tokens in the main queue. Tokens stored in the main q can - // then be used to generate new tokens (using ExpandArcs) We also + // survival tokens into the main queue. Tokens stored in the main q can + // then be used to generate new tokens (using ExpandArcs). We also // generate more information about what's in the main_q at the end of a // frame (in PostProcessingMainQueue) // @@ -587,7 +587,7 @@ class CudaDecoder { // // The data linked with a channel contains the data of frame i we need // to remember to compute frame i+1. 
It is the list of tokens from frame - // i, with some additional info (ie the prefix sum of the emitting arcs + // i, with some additional info (i.e. the prefix sum of the emitting arcs // degrees from those tokens). We are only storing // d_main_q_state_and_cost_ as channel data because that's all we need // in a token to compute frame i+1. We don't need token.arc_idx or diff --git a/src/cudadecoder/cuda-pipeline-common.h b/src/cudadecoder/cuda-pipeline-common.h index 01dbd95c796..e6197774110 100644 --- a/src/cudadecoder/cuda-pipeline-common.h +++ b/src/cudadecoder/cuda-pipeline-common.h @@ -158,7 +158,7 @@ struct HostDeviceVector { HostDeviceVector( const size_t new_size = KALDI_CUDA_DECODER_AUDIO_HOST_DEVICE_BUFFER_SIZE) : h_data(NULL), d_data(NULL), size(new_size) { - cudaEventCreate(&evt); + cudaEventCreate(&evt, cudaEventDisableTiming); Reallocate(new_size); } diff --git a/src/cudadecoder/thread-pool-light.h b/src/cudadecoder/thread-pool-light.h index 1906ab8bbb4..55f1d5c9d77 100644 --- a/src/cudadecoder/thread-pool-light.h +++ b/src/cudadecoder/thread-pool-light.h @@ -23,6 +23,12 @@ #include #include +#ifdef __linux__ +#include +#include +#endif // __linux__ + + namespace kaldi { namespace cuda_decoder { @@ -30,7 +36,7 @@ constexpr double kSleepForWorkAvailable = 1e-3; constexpr double kSleepForWorkerAvailable = 1e-3; struct ThreadPoolLightTask { - void (*func_ptr)(void *, uint64_t, void *); + std::function func_ptr; void *obj_ptr; uint64_t arg1; void *arg2; @@ -90,6 +96,9 @@ class ThreadPoolLightWorker final { std::weak_ptr other_; void Work() { +#ifdef __linux__ + nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); +#endif while (run_thread_) { bool got_task = queue_.TryPop(&curr_task_); if (!got_task) { diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc index 1aba7144af1..750b14103ae 100644 --- a/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc @@ -74,6 +74,10 @@ int main(int argc, char *argv[]) { CudaOnlineBinaryOptions opts; if (SetUpAndReadCmdLineOptions(argc, argv, &opts) != 0) return 1; + std::unique_ptr clat_writer; + std::unique_ptr ctm_writer; + OpenOutputHandles(opts.clat_wspecifier, &clat_writer, &ctm_writer); + TransitionModel trans_model; nnet3::AmNnetSimple am_nnet; fst::Fst *decode_fst; @@ -84,10 +88,6 @@ int main(int argc, char *argv[]) { delete decode_fst; if (word_syms) cuda_pipeline.SetSymbolTable(*word_syms); - std::unique_ptr clat_writer; - std::unique_ptr ctm_writer; - OpenOutputHandles(opts.clat_wspecifier, &clat_writer, &ctm_writer); - std::mutex output_writer_m_; if (!opts.write_lattice) { KALDI_LOG @@ -145,6 +145,7 @@ int main(int argc, char *argv[]) { std::uniform_real_distribution<> dis(0.0, 1.0); std::priority_queue streams; + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaProfilerStart()); nvtxRangePush("Global Timer"); Timer timer; @@ -312,7 +313,7 @@ int main(int argc, char *argv[]) { if (clat_writer) clat_writer->Close(); cudaDeviceSynchronize(); - + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaProfilerStop()); return 0; } catch (const std::exception &e) { std::cerr << e.what(); diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc index 992b34598d2..b8e406a328b 100644 --- a/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc @@ -163,6 +163,8 @@ int main(int argc, char *argv[]) { int32 num_task_submitted = 0, num_err = 0; double 
total_audio = 0; + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaProfilerStart()); + nvtxRangePush("Global Timer"); // starting timer here so we // can measure throughput @@ -242,7 +244,9 @@ int main(int argc, char *argv[]) { delete word_syms; // will delete if non-NULL. - cudaDeviceSynchronize(); + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaDeviceSynchronize()); + + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaProfilerStop()); return 0; } catch (const std::exception &e) { From 2c247a1939186c6958b8ba981e3ecf629cf79b8d Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Mon, 5 Dec 2022 14:03:34 -0800 Subject: [PATCH 2/8] async memory copy to speed up online decoding. --- .../cuda-online-pipeline-dynamic-batcher.h | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h b/src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h index d0ac1ab0e89..1340eddd5b5 100644 --- a/src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h +++ b/src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h @@ -70,7 +70,31 @@ class CudaOnlinePipelineDynamicBatcher { Batch(int max_batch_size, int max_samps_per_chunk) { h_all_waveform.Resize(max_batch_size, max_samps_per_chunk, kUndefined, kStrideEqualNumCols); - // TODO use cudaHostRegister, check cudaDevAttrHostRegisterSupported + int device; + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaGetDevice(&device)); + int supports_cudaHostRegister; + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaDeviceGetAttribute(&supports_cudaHostRegister, + cudaDevAttrHostRegisterSupported, + device)); + if (supports_cudaHostRegister) { + KALDI_DECODER_CUDA_API_CHECK_ERROR(cudaHostRegister(h_all_waveform.Data(), + h_all_waveform.SizeInBytes(), + cudaHostRegisterDefault)); + } else { + KALDI_WARN << "Your device does not support cudaHostRegister(). Copying data to GPU will be slow."; + } + } + + ~Batch() { + int device; + CU_SAFE_CALL(cudaGetDevice(&device)); + int supports_cudaHostRegister; + CU_SAFE_CALL(cudaDeviceGetAttribute(&supports_cudaHostRegister, + cudaDevAttrHostRegisterSupported, + device)); + if (supports_cudaHostRegister) { + CU_SAFE_CALL(cudaHostUnregister(h_all_waveform.Data())); + } } void Clear() { From abbaa6e4f214bfbe27b439db0928b943e5e6b3a2 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Mon, 5 Dec 2022 14:06:45 -0800 Subject: [PATCH 3/8] Decrease latency by doing partial hypothesis work on host at the same time as cuda calls on the device. --- src/cudadecoder/cuda-decoder.cc | 98 ++++++++++++++++++++--------- src/cudadecoder/cuda-decoder.h | 9 ++- src/cudadecoder/thread-pool-light.h | 4 +- 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/src/cudadecoder/cuda-decoder.cc b/src/cudadecoder/cuda-decoder.cc index 7d94e41d9e3..e2f53301bf0 100644 --- a/src/cudadecoder/cuda-decoder.cc +++ b/src/cudadecoder/cuda-decoder.cc @@ -350,6 +350,7 @@ CudaDecoder::~CudaDecoder() noexcept(false) { // Stop h2h tasks. 
WaitForInitDecodingH2HCopies(); WaitForH2HCopies(); + WaitForPartialHypotheses(); h2h_threads_running_ = false; n_h2h_main_task_todo_cv_.notify_all(); for (std::thread &thread : cpu_dedicated_threads_) thread.join(); @@ -459,6 +460,7 @@ void CudaDecoder::InitDecoding(const std::vector &channels) { std::lock_guard lk(n_init_decoding_h2h_task_not_done_mutex_); n_init_decoding_h2h_task_not_done_ += channels.size(); } + WaitForPartialHypotheses(); // good for (ChannelId ichannel : channels) { ChannelCounters &channel_counters = h_channels_counters_[ichannel]; channel_counters.prev_main_q_narcs_and_end = @@ -601,7 +603,7 @@ void CudaDecoder::MoveConcatenatedCopyToVector( // Unpacking the concatenated vector into individual channel storage int32 beg = lanes_offsets[ilane]; int32 end = lanes_offsets[ilane + 1]; - auto &vec = (*vecvec)[ichannel]; + std::vector &vec = (*vecvec)[ichannel]; vec.insert(vec.end(), h_concat + beg, h_concat + end); } @@ -702,6 +704,7 @@ void CudaDecoder::CopyMainQueueDataToHost() { // Making sure the previous H2H copies are done WaitForInitDecodingH2HCopies(); WaitForH2HCopies(); + WaitForPartialHypotheses(); std::swap(h_extra_and_acoustic_cost_concat_tmp_, h_extra_and_acoustic_cost_concat_); @@ -730,6 +733,7 @@ void CudaDecoder::CopyMainQueueDataToHost() { } } LaunchH2HCopies(); + LaunchPartialHypotheses(); nvtxRangePop(); } @@ -934,16 +938,14 @@ void CudaDecoder::AdvanceDecoding( main_q_end); } SaveChannelsStateFromLanes(); - - // Waiting for partial path to be ready (if set) - // They are computed async - WaitForPartialHypotheses(); } +// waiting here... Should use condition variable, right? void CudaDecoder::WaitForPartialHypotheses() { if (!generate_partial_hypotheses_) return; while (n_partial_traceback_threads_not_done_ .load(std::memory_order_acquire) > 0) { + // this is bad Sleep(200e-6); } } @@ -1568,7 +1570,6 @@ void CudaDecoder::SwapPrevAndCurrLatticeMap( } void CudaDecoder::WaitForH2HCopies() { - Timer timer; std::unique_lock lk(n_h2h_task_not_done_mutex_); h2h_done_.wait(lk, [this] { return (n_h2h_task_not_done_ == 0); }); } @@ -1834,17 +1835,64 @@ void CudaDecoder::CheckStaticAsserts() { KALDI_COMPILE_TIME_ASSERT(KALDI_CUDA_DECODER_NONEM_LT_MAX_NARCS > 0); } +void CudaDecoder::LaunchPartialHypotheses() { + if (partial_traceback_) { + // performance killer. Need to wait until WaitForH2HCopies() is + // done in each individual launched thread, not in this main + // thread. + // the other issue is that the thread worker sleep time is too large. + n_partial_traceback_threads_todo_.store(nlanes_used_ - 1); + // necessary because the todo_ variable can reach 0 when the last + // one still isn't done. Ugh. + + // I assign to this in one thread, but then use a different + // thread... Is that a problem? + n_partial_traceback_threads_not_done_.store(thread_pool_ ? nlanes_used_ : 1); + + auto nlanes_used = nlanes_used_; + + auto launch = [this, nlanes_used]() { + WaitForInitDecodingH2HCopies(); + WaitForH2HCopies(); + + for (std::size_t i = 0; i < nlanes_used; ++i) { + thread_pool_->Push(ThreadPoolLightTask{ + [this](void *a0, uint64_t a1, void *a2) { + nvtxRangePush("PartialHypothesis"); + // I wait for these, so why am I getting an error? Should I lock just for safety? + int ilane; + if ((ilane = n_partial_traceback_threads_todo_.fetch_sub(1)) >= 0) { + // ERROR!!! 
Need to keep the previous value of lanes2channels_todo_ + int32 ichannel = lanes2channels_todo_[ilane]; + // std::lock_guard channel_lk(channel_lock_[ichannel]); + GeneratePartialPath(ilane, ichannel); + if (generate_partial_hypotheses_) { + std::stack> traceback_buffer; + BuildPartialHypothesisOutput(ichannel, &traceback_buffer); + } + if (endpointing_) { + EndpointDetected(ilane, ichannel); + } + h_all_channels_prev_best_path_traceback_head_[ichannel] = + h_best_path_traceback_head_[ilane]; + } + n_partial_traceback_threads_not_done_.fetch_sub(1, std::memory_order_release); + nvtxRangePop(); + }, nullptr, uint64_t(0), nullptr}); + } + }; + + std::thread t(launch); + t.detach(); + } +} + void CudaDecoder::LaunchH2HCopies() { nvtxRangePushA("LaunchH2HCopies"); // Each H2H copy counter n_acoustic_h2h_copies_todo_.store(nlanes_used_ - 1); n_infotoken_h2h_copies_todo_.store(nlanes_used_ - 1); n_extra_prev_tokens_h2h_copies_todo_.store(nlanes_used_ - 1); - if (partial_traceback_) { - n_partial_traceback_threads_todo_.store(nlanes_used_ - 1); - n_partial_traceback_threads_not_done_.store(thread_pool_ ? n_threads_used_ - : 1); - } { std::lock_guard n_h2h_not_done_lk(n_h2h_task_not_done_mutex_); n_h2h_task_not_done_ += thread_pool_ ? n_threads_used_ : 1; @@ -1867,6 +1915,7 @@ void CudaDecoder::ComputeH2HCopiesCPUWorker() { // Run by a dedicated CPU thread #ifdef __linux__ nvtxNameOsThread(syscall(SYS_gettid), "h2hcopies"); + pthread_setname_np(pthread_self(), "h2hcopies"); #endif while (h2h_threads_running_) { ComputeH2HCopies(); @@ -1894,7 +1943,7 @@ void CudaDecoder::GeneratePartialPath(LaneId ilane, ChannelId ichannel) { // Adding that link at the end of the partial path partial_hypotheses.emplace_back(curr_token_idx, arc_idx); // If this is the first link, we don't have to check that we're still on the - // same best path than before + // same best path as before if (partial_hypotheses.size() == 1) return; // Backtracking until we reconnect with our stored partial path @@ -1904,7 +1953,7 @@ void CudaDecoder::GeneratePartialPath(LaneId ilane, ChannelId ichannel) { auto it = std::prev(partial_hypotheses.end(), 2); // The new partial best path is not directly to the previous partial - // best path We need to backtrack until we reconnect with the previous + // best path. We need to backtrack until we reconnect with the previous // partial best path (or until we reach the root node) while (true) { @@ -2028,23 +2077,10 @@ void CudaDecoder::ComputeH2HCopies() { if (!h2h_threads_running_) return; int32 ilane; - if (partial_traceback_) { - std::stack> traceback_buffer_; - while ((ilane = n_partial_traceback_threads_todo_.fetch_sub(1)) >= 0) { - int32 ichannel = lanes2channels_todo_[ilane]; - GeneratePartialPath(ilane, ichannel); - if (generate_partial_hypotheses_) - BuildPartialHypothesisOutput(ichannel, &traceback_buffer_); - if (endpointing_) EndpointDetected(ilane, ichannel); - h_all_channels_prev_best_path_traceback_head_[ichannel] = - h_best_path_traceback_head_[ilane]; - } - n_partial_traceback_threads_not_done_.fetch_sub(1, - std::memory_order_release); - } // Waiting for the D2H copies. 
This is threadsafe // Step 1: acoustic costs CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_acoustic_evt_)); + nvtxRangePush("acoustic copy"); while ((ilane = n_acoustic_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2061,9 +2097,10 @@ void CudaDecoder::ComputeH2HCopies() { auto &vec = h_all_tokens_acoustic_cost_[ichannel]; vec.insert(vec.end(), ntokens_nonemitting, 0.0f); } - + nvtxRangePop(); // Step 2: infotoken CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_infotoken_evt_)); + nvtxRangePush("infotoken copy"); while ((ilane = n_infotoken_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2071,11 +2108,12 @@ void CudaDecoder::ComputeH2HCopies() { MoveConcatenatedCopyToVector(ilane, ichannel, h_main_q_end_lane_offsets_, h_infotoken_concat_, &h_all_tokens_info_); } - + nvtxRangePop(); // Step 3: // - extra prev tokens // - partial path and endpointing CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_extra_prev_tokens_evt_)); + nvtxRangePush("extra prev tokens copy"); while ((ilane = n_extra_prev_tokens_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2088,7 +2126,7 @@ void CudaDecoder::ComputeH2HCopies() { h_extra_and_acoustic_cost_concat_, &h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_); } - + nvtxRangePop(); // If we're the last cpu thread to complete the current tasks, notify // the main thread bool all_done; diff --git a/src/cudadecoder/cuda-decoder.h b/src/cudadecoder/cuda-decoder.h index 0b86a5dd9bc..b644a246adc 100644 --- a/src/cudadecoder/cuda-decoder.h +++ b/src/cudadecoder/cuda-decoder.h @@ -277,10 +277,13 @@ class CudaDecoder { void SetOutputFrameShiftInSeconds(BaseFloat f) { frame_shift_seconds_ = f; } + // here's how we get the partial hypotheses... Need to wait until we + // can do this for thread safety. void GetPartialHypothesis(ChannelId ichannel, PartialHypothesis **out) { KALDI_ASSERT(generate_partial_hypotheses_); + WaitForPartialHypotheses(); // No need to lock, all ops on h_all_channels_partial_hypotheses_out_ are - // done before returning InitDecoding or AdvanceDecoding + // done after calling WaitForPartialHypotheses() *out = &h_all_channels_partial_hypotheses_out_[ichannel]; } @@ -492,6 +495,7 @@ class CudaDecoder { // before returning void WaitForPartialHypotheses(); + void LaunchPartialHypotheses(); // Takes care of preparing the data for ComputeH2HCopies // and check whether we can use the threadpool or we have to do the work // on the current thread @@ -780,13 +784,16 @@ class CudaDecoder { CostType *h_acoustic_cost_concat_tmp_; InfoToken *h_extra_prev_tokens_concat_tmp_; // Offsets used in MoveConcatenatedCopyToVector + // offsets, so size is nlanes_ + 1! std::vector h_main_q_end_lane_offsets_; std::vector h_emitting_main_q_end_lane_offsets_; std::vector h_n_extra_prev_tokens_lane_offsets_; // Index of the best index for the last frame. 
Used by endpointing/partial // results + // indexed by lanes std::vector h_best_path_traceback_head_; + // indexed by channels std::vector h_all_channels_prev_best_path_traceback_head_; // Partial path so far on a given channel diff --git a/src/cudadecoder/thread-pool-light.h b/src/cudadecoder/thread-pool-light.h index 55f1d5c9d77..79e3b2cf3c9 100644 --- a/src/cudadecoder/thread-pool-light.h +++ b/src/cudadecoder/thread-pool-light.h @@ -32,7 +32,7 @@ namespace kaldi { namespace cuda_decoder { -constexpr double kSleepForWorkAvailable = 1e-3; +constexpr double kSleepForWorkAvailable = 1e-4; constexpr double kSleepForWorkerAvailable = 1e-3; struct ThreadPoolLightTask { @@ -98,6 +98,7 @@ class ThreadPoolLightWorker final { void Work() { #ifdef __linux__ nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); + pthread_setname_np(pthread_self(), "threadpool"); #endif while (run_thread_) { bool got_task = queue_.TryPop(&curr_task_); @@ -114,6 +115,7 @@ class ThreadPoolLightWorker final { (curr_task_.func_ptr)(curr_task_.obj_ptr, curr_task_.arg1, curr_task_.arg2); } else { + // std::this_thread::yield(); Sleep(kSleepForWorkAvailable); // TODO } } From 0d31458caff9f8e3356a61800887e05679d9f5d7 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Mon, 5 Dec 2022 17:21:30 -0800 Subject: [PATCH 4/8] New Thread pool implementation. --- src/cudadecoder/Makefile | 3 +- ...hed-threaded-nnet3-cuda-online-pipeline.cc | 4 +- ...ched-threaded-nnet3-cuda-online-pipeline.h | 6 +- src/cudadecoder/thread-pool-cia.cc | 6 + src/cudadecoder/thread-pool-cia.h | 427 ++++++++++++++++++ 5 files changed, 441 insertions(+), 5 deletions(-) create mode 100644 src/cudadecoder/thread-pool-cia.cc create mode 100644 src/cudadecoder/thread-pool-cia.h diff --git a/src/cudadecoder/Makefile b/src/cudadecoder/Makefile index e2569e89ab7..0b69830f2ee 100644 --- a/src/cudadecoder/Makefile +++ b/src/cudadecoder/Makefile @@ -19,7 +19,8 @@ OBJFILES = cuda-decoder.o cuda-decoder-kernels.o cuda-fst.o \ batched-threaded-nnet3-cuda-pipeline2.o \ batched-static-nnet3.o batched-static-nnet3-kernels.o \ cuda-online-pipeline-dynamic-batcher.o decodable-cumatrix.o \ - cuda-pipeline-common.o lattice-postprocessor.o + cuda-pipeline-common.o lattice-postprocessor.o \ + thread-pool-cia.o LIBNAME = kaldi-cudadecoder diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc index 900dcbf6bc8..9f2ecdd592e 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc @@ -264,7 +264,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix( for (size_t i = 0; i < wave_samples.size(); i += batch_size) { - auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size](void *ignore1, uint64_t ignore2, void *ignore3) { + auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size]() { nvtxRangePush("CompactWavesToMatrix task"); for (size_t j = i; j < std::min(i + batch_size, wave_samples.size()); ++j) { const SubVector &src = wave_samples[j]; @@ -281,7 +281,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix( } nvtxRangePop(); }; - batching_copy_thread_pool_->Push({task, nullptr, 0, nullptr}); + batching_copy_thread_pool_->submit(task); } // wait for all threads to finish diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h index ed12931b578..b5b7f48097c 
100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h @@ -40,6 +40,8 @@ #include "nnet3/nnet-optimize.h" #include "online2/online-nnet2-feature-pipeline.h" +#include "cudadecoder/thread-pool-cia.h" + namespace kaldi { namespace cuda_decoder { @@ -165,7 +167,7 @@ class BatchedThreadedNnet3CudaOnlinePipeline { int num_batching_copy_threads = config_.num_batching_copy_threads; if (num_batching_copy_threads > 0) { - batching_copy_thread_pool_ = std::make_unique(num_batching_copy_threads); + batching_copy_thread_pool_ = std::make_unique(num_batching_copy_threads); } Initialize(decode_fst); @@ -519,7 +521,7 @@ class BatchedThreadedNnet3CudaOnlinePipeline { // destructor blocks until the thread pool is drained of work items. std::unique_ptr thread_pool_; - std::unique_ptr batching_copy_thread_pool_; + std::unique_ptr batching_copy_thread_pool_; // The decoder owns thread(s) that reconstruct lattices transferred from the // device in a compacted form as arrays with offsets instead of pointers. diff --git a/src/cudadecoder/thread-pool-cia.cc b/src/cudadecoder/thread-pool-cia.cc new file mode 100644 index 00000000000..4e294a5cd27 --- /dev/null +++ b/src/cudadecoder/thread-pool-cia.cc @@ -0,0 +1,6 @@ +#include + +namespace kaldi { +thread_local work_stealing_queue* work_stealing_thread_pool::local_work_queue; +thread_local unsigned int work_stealing_thread_pool::my_index; +} diff --git a/src/cudadecoder/thread-pool-cia.h b/src/cudadecoder/thread-pool-cia.h new file mode 100644 index 00000000000..bf9b1531d4e --- /dev/null +++ b/src/cudadecoder/thread-pool-cia.h @@ -0,0 +1,427 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace kaldi { + +class join_threads { + std::vector& threads; +public: + explicit join_threads(std::vector& threads_): threads(threads_) + {} + ~join_threads() { + for (unsigned int i = 0; i < threads.size(); ++i) { + if (threads[i].joinable()) { + threads[i].join(); + } + } + } +}; + +template +class threadsafe_queue { +private: + mutable std::mutex mut; + std::queue data_queue; + std::condition_variable data_cond; +public: + threadsafe_queue() {} + threadsafe_queue(const threadsafe_queue& other) { + std::lock_guard lk(other.mut); + // does not work if T is function_wrapper + data_queue = other.data_queue; + } + threadsafe_queue& operator=(const threadsafe_queue&) = delete; + template + typename std::enable_if::value && std::is_move_assignable::value, void>::type + push(T new_value) { + std::lock_guard lk(mut); + // There appears to be no reason not to use std::move here... + data_queue.push(std::move(new_value)); + data_cond.notify_one(); + } + template + typename std::enable_if::value && std::is_copy_assignable::value && !std::is_move_assignable::value, void>::type + push(T new_value) { + std::lock_guard lk(mut); + // There appears to be no reason not to use std::move here... 
+ data_queue.push(new_value); + data_cond.notify_one(); + } + void wait_and_pop(T& value) + { + std::unique_lock lk(mut); + data_cond.wait(lk, [this]{return !data_queue.empty();}); + value = data_queue.front(); + data_queue.pop(); + } + std::unique_ptr wait_and_pop() { + std::unique_lock lk(mut); + data_cond.wait(lk, [this]{return !data_queue.empty();}); + std::unique_ptr res(std::make_unique(data_queue.front())); + data_queue.pop(); + return res; + } + template + typename std::enable_if::value && std::is_move_assignable::value, bool>::type + try_pop(T& value) { + std::lock_guard lk(mut); + if(data_queue.empty()) { + return false; + } + value = std::move(data_queue.front()); + data_queue.pop(); + return true; + } + template + typename std::enable_if::value && std::is_copy_assignable::value && !std::is_move_assignable::value, bool>::type + try_pop(T& value) { + std::lock_guard lk(mut); + if(data_queue.empty()) { + return false; + } + value = data_queue.front(); + data_queue.pop(); + return true; + } + template + typename std::enable_if::value && std::is_move_assignable::value, std::unique_ptr>::type + try_pop() { + std::lock_guard lk(mut); + if (data_queue.empty()) { + return std::unique_ptr(); + } + std::unique_ptr res(std::make_unique(data_queue.front())); + data_queue.pop(); + return res; + } + template + typename std::enable_if::value && std::is_copy_assignable::value && !std::is_move_assignable::value, std::unique_ptr>::type + try_pop() { + std::lock_guard lk(mut); + if (data_queue.empty()) { + return std::unique_ptr(); + } + std::unique_ptr res(std::make_unique(data_queue.front())); + data_queue.pop(); + return res; + } + bool empty() const { + std::lock_guard lk(mut); + return data_queue.empty(); + } +}; + +class thread_pool { + std::atomic_bool done; + threadsafe_queue> work_queue; + std::vector threads; + join_threads joiner; + // class PassKey { + // friend class thread_pool; + // PassKey() = default; + // ~PassKey() = default; + // }; +public: + void worker_thread(/*PassKey*/) { + while (!done) { + std::function task; + // wait_and_pop seems more efficient than try_pop... + if (work_queue.try_pop(task)) { + task(); + } else { + std::this_thread::yield(); + } + } + } + thread_pool(unsigned int const num_threads): done(false), joiner(threads) { + try { + for (unsigned int i = 0; i < num_threads;++i) { + threads.push_back(std::thread(&thread_pool::worker_thread, this/*, PassKey()*/)); + } + } catch(...) 
{ + done = true; + throw; + } + } + + ~thread_pool() { + done = true; + } + + template + void submit(FunctionType f) { + work_queue.push(std::function(f)); + } +}; + + +// 9.2 + +class function_wrapper { + struct impl_base { + virtual void call()=0; + virtual ~impl_base() {} + }; + std::unique_ptr impl; + template + struct impl_type: impl_base + { + F f; + impl_type(F&& f_): f(std::move(f_)) {} + void call() { f(); } + }; +public: + template + function_wrapper(F&& f): impl(new impl_type(std::move(f))) {} + void operator()() {impl->call(); } + function_wrapper() = default; + function_wrapper(function_wrapper&& other): impl(std::move(other.impl)) {} + function_wrapper& operator=(function_wrapper&& other) + { + impl = std::move(other.impl); + return *this; + } + function_wrapper(const function_wrapper&) = delete; + function_wrapper(function_wrapper&) = delete; + function_wrapper& operator=(const function_wrapper&) = delete; +}; + + +class futures_thread_pool { + std::atomic_bool done; + threadsafe_queue work_queue; + std::vector threads; + join_threads joiner; +public: + void worker_thread() { + while (!done) { + function_wrapper task; + if (work_queue.try_pop(task)) { + task(); + } else { + std::this_thread::yield(); + } + } + } + futures_thread_pool(const unsigned int num_threads): done(false), joiner(threads) { + try { + for (unsigned int i = 0; i < num_threads;++i) { + threads.push_back(std::thread(&futures_thread_pool::worker_thread, this)); + } + } catch(...) { + done = true; + throw; + } + } + + ~futures_thread_pool() { + done = true; + } + + // can we include Args... args as well here? Don't think so... + template + std::future::type> + submit(FunctionType f) { + typedef typename std::result_of::type result_type; + std::packaged_task task(std::move(f)); + std::future res(task.get_future()); + work_queue.push(std::move(task)); + return res; + } +}; + +class thread_local_queue_thread_pool { + std::atomic_bool done; + std::vector threads; + join_threads joiner; + threadsafe_queue pool_work_queue; + typedef std::queue local_queue_type; + // why unique_ptr here? + static thread_local std::unique_ptr local_work_queue; + void run_pending_task() { + function_wrapper task; + if (local_work_queue && !local_work_queue->empty()) { + task = std::move(local_work_queue->front()); + local_work_queue->pop(); + task(); + } else if (pool_work_queue.try_pop(task)) { + task(); + } else { + std::this_thread::yield(); + } + } + +public: + void worker_thread() { + local_work_queue.reset(new local_queue_type); + + // spining here, unlike previous implementation... + while (!done) { + run_pending_task(); + } + } + + thread_local_queue_thread_pool(unsigned int const num_threads): done(false), joiner(threads) { + try { + for (unsigned int i = 0; i < num_threads;++i) { + threads.push_back(std::thread(&thread_local_queue_thread_pool::worker_thread, this)); + } + } catch(...) 
{
+      done = true;
+      throw;
+    }
+  }
+
+  ~thread_local_queue_thread_pool() {
+    done = true;
+  }
+
+  template<typename FunctionType>
+  std::future<typename std::result_of<FunctionType()>::type>
+  submit(FunctionType f) {
+    typedef typename std::result_of<FunctionType()>::type result_type;
+    std::packaged_task<result_type()> task(f);
+    std::future<result_type> res(task.get_future());
+    if(local_work_queue) {
+      local_work_queue->push(std::move(task));
+    } else {
+      pool_work_queue.push(std::move(task));
+    }
+    return res;
+  }
+};
+
+class work_stealing_queue {
+private:
+  typedef function_wrapper data_type;
+  std::deque<data_type> the_queue;
+  mutable std::mutex the_mutex;
+public:
+  work_stealing_queue() {}
+  work_stealing_queue(const work_stealing_queue& other) = delete;
+  work_stealing_queue& operator=(const work_stealing_queue& other) = delete;
+  void push(data_type data)
+  {
+    std::lock_guard<std::mutex> lock(the_mutex);
+    the_queue.push_front(std::move(data));
+  }
+  bool empty() const {
+    std::lock_guard<std::mutex> lock(the_mutex);
+    return the_queue.empty();
+  }
+  bool try_pop(data_type& res) {
+    std::lock_guard<std::mutex> lock(the_mutex);
+    if (the_queue.empty()) {
+      return false;
+    }
+    res = std::move(the_queue.front());
+    the_queue.pop_front();
+    return true;
+  }
+  bool try_steal(data_type& res) {
+    std::lock_guard<std::mutex> lock(the_mutex);
+    if (the_queue.empty()) {
+      return false;
+    }
+    res = std::move(the_queue.back());
+    the_queue.pop_back();
+    return true;
+  }
+};
+
+// namespace detail {
+// thread_local work_stealing_queue* local_work_queue;
+// thread_local unsigned int my_index;
+// }
+
+
+class work_stealing_thread_pool {
+  typedef function_wrapper task_type;
+  std::atomic_bool done;
+  threadsafe_queue<task_type> pool_work_queue;
+  std::vector<std::unique_ptr<work_stealing_queue> > queues;
+  std::vector<std::thread> threads;
+  join_threads joiner;
+  static thread_local work_stealing_queue* local_work_queue;
+  static thread_local unsigned int my_index;
+  bool pop_task_from_local_queue(task_type& task) {
+    return local_work_queue && local_work_queue->try_pop(task);
+  }
+
+  bool pop_task_from_pool_queue(task_type &task) {
+    return pool_work_queue.try_pop(task);
+  }
+
+  bool pop_task_from_other_thread_queue(task_type &task) {
+    for (unsigned int i = 0; i < queues.size(); ++i) {
+      unsigned int const index = (my_index + i + 1) % queues.size();
+      if (queues[index]->try_steal(task)) {
+        return true;
+      }
+    }
+    return false;
+  }
+public:
+  void worker_thread(unsigned int my_index_) {
+    my_index = my_index_;
+    local_work_queue = queues[my_index].get();
+    while(!done) {
+      run_pending_task();
+    }
+  }
+
+  work_stealing_thread_pool(unsigned int thread_count):
+    done(false), joiner(threads)
+  {
+    try {
+      for (unsigned int i = 0; i < thread_count; ++i) {
+        queues.push_back(std::make_unique<work_stealing_queue>());
+      }
+      for (unsigned int i = 0; i < thread_count; ++i) {
+        threads.push_back(std::thread(&work_stealing_thread_pool::worker_thread, this, i));
+      }
+    } catch(...) {
+      done = true;
+      throw;
+    }
+  }
+
+  ~work_stealing_thread_pool() {
+    done = true;
+  }
+
+  template<typename FunctionType>
+  std::future<typename std::result_of<FunctionType()>::type>
+  submit(FunctionType f) {
+    typedef typename std::result_of<FunctionType()>::type result_type;
+    std::packaged_task<result_type()> task(f);
+    std::future<result_type> res(task.get_future());
+    if (local_work_queue) {
+      local_work_queue->push(std::move(task));
+    } else {
+      pool_work_queue.push(std::move(task));
+    }
+    return res;
+  }
+
+  void run_pending_task() {
+    task_type task;
+    if (pop_task_from_local_queue(task) ||
+        pop_task_from_pool_queue(task) ||
+        // O(#threads). No good if threads never submit work to the
+        // thread pool themselves...
+        pop_task_from_other_thread_queue(task)) {
+      task();
+    } else {
+      std::this_thread::yield();
+    }
+  }
+};
+
+} // namespace kaldi
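For reference, a minimal usage sketch of the pools this new file introduces
(not part of the patch; Example() is an illustrative name). submit() wraps
the callable in a std::packaged_task, so the caller can block on the returned
std::future for per-task completion:

    #include <future>

    #include "cudadecoder/thread-pool-cia.h"

    int Example() {
      kaldi::futures_thread_pool pool(4);  // 4 worker threads
      std::future<int> f = pool.submit([] { return 6 * 7; });
      return f.get();  // blocks until a worker has executed the task
    }
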
From dc95de196f59130e829433bd606b066a4cad1c1c Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Sat, 10 Dec 2022 09:55:35 -0800
Subject: [PATCH 5/8] Add RTFx calculation to online decoder.

Note that the max RTFx in online mode is necessarily at most
--num-parallel-streaming-channels.
---
 src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc
index 750b14103ae..d2c9551fee3 100644
--- a/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc
+++ b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc
@@ -309,6 +309,16 @@ int main(int argc, char *argv[]) {
 
   KALDI_LOG << "Latency stats:";
   PrintLatencyStats(latencies);
+
+  double total_latency = std::accumulate(latencies.begin(), latencies.end(), 0.0);
+  auto sum_op = [](double accum, const auto& a){
+    return accum + a->Duration();
+  };
+  double total_duration = opts.niterations * std::accumulate(all_wav.begin(), all_wav.end(),
+                                                             0.0, sum_op);
+  double rtf_x = total_duration / total_latency;
+  KALDI_LOG << "RTFx: " << rtf_x;
+
   delete word_syms;
 
   if (clat_writer) clat_writer->Close();

From 501b9d947e90cbd2491c0f926730e5086c456b22 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Sat, 10 Dec 2022 10:24:37 -0800
Subject: [PATCH 6/8] Improve online performance of cuda decoder.

Use a thread pool that sleeps when there is no data to retrieve.

Sort data at the right point to improve cache performance.

Remove spin waits on atomics; these cause slowdowns compared to
condition variables, in particular because we cannot sleep accurately
for 200 microseconds or less (a 200 microsecond sleep turns out to
take 250 microseconds). These delays cause unnecessary slowdown.
---
 ...hed-threaded-nnet3-cuda-online-pipeline.cc |   8 +-
 ...ched-threaded-nnet3-cuda-online-pipeline.h |  22 +-
 src/cudadecoder/cuda-decoder.cc               |  97 +++++----
 src/cudadecoder/cuda-decoder.h                |  12 +-
 src/cudadecoder/thread-pool-cia.cc            |  20 ++
 src/cudadecoder/thread-pool-cia.h             | 104 ++++++++--
 src/cudadecoder/thread-pool-light.h           | 193 ------------------
 src/cudadecoder/thread-pool.h                 | 169 ---------------
 8 files changed, 171 insertions(+), 454 deletions(-)
 delete mode 100644 src/cudadecoder/thread-pool-light.h
 delete mode 100644 src/cudadecoder/thread-pool.h

diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
index 9f2ecdd592e..2a6e4f3bd0e 100644
--- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
+++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc
@@ -330,9 +330,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::ComputeCPUFeatureExtraction(
   n_compute_features_not_done_.store(channels.size());
 
   for (size_t i = 0; i < channels.size(); ++i) {
-    thread_pool_->Push(
-        {&BatchedThreadedNnet3CudaOnlinePipeline::ComputeOneFeatureWrapper,
-         this, i, 0});  // second argument "0" is not used
+    thread_pool_->submit(std::bind(&BatchedThreadedNnet3CudaOnlinePipeline::ComputeOneFeature, this, i));
   }
 
   while (n_compute_features_not_done_.load(std::memory_order_acquire))
@@ -611,9 +609,7 @@ void BatchedThreadedNnet3CudaOnlinePipeline::RunLatticeCallbacks(
         // If q is not empty, it means we already have a task in the threadpool
         // for that channel it is important to run those task in FIFO order if
         // empty, run a new task
-        thread_pool_->Push(
-            {&BatchedThreadedNnet3CudaOnlinePipeline::FinalizeDecodingWrapper,
-             this, ichannel, /* ignored */ nullptr});
+        thread_pool_->submit(std::bind(&BatchedThreadedNnet3CudaOnlinePipeline::FinalizeDecoding, this, ichannel));
       }
     }
   }

diff --git 
a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h index b5b7f48097c..3a377c2e51b 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h @@ -163,11 +163,11 @@ class BatchedThreadedNnet3CudaOnlinePipeline { config_.compute_opts.CheckAndFixConfigs(am_nnet_->GetNnet().Modulus()); config_.CheckAndFixConfigs(); int num_worker_threads = config_.num_worker_threads; - thread_pool_ = std::make_unique(num_worker_threads); + thread_pool_ = std::make_unique(num_worker_threads); int num_batching_copy_threads = config_.num_batching_copy_threads; if (num_batching_copy_threads > 0) { - batching_copy_thread_pool_ = std::make_unique(num_batching_copy_threads); + batching_copy_thread_pool_ = std::make_unique(num_batching_copy_threads); } Initialize(decode_fst); @@ -322,12 +322,6 @@ class BatchedThreadedNnet3CudaOnlinePipeline { // Used when features are computed on the host (CPU) on pool threads. void ComputeOneFeature(int element); - static void ComputeOneFeatureWrapper(void *obj, uint64_t element, - void *ignored) { - static_cast(obj) - ->ComputeOneFeature(element); - } - void RunNnet3(const std::vector &channels, const std::vector &d_features, const int feature_stride, @@ -361,14 +355,6 @@ class BatchedThreadedNnet3CudaOnlinePipeline { // it will call the utterance's callback when done void FinalizeDecoding(int32 ichannel); - // static wrapper for thread pool - static void FinalizeDecodingWrapper(void *obj, uint64_t ichannel64, - void *ignored) { - int32 ichannel = static_cast(ichannel64); - static_cast(obj) - ->FinalizeDecoding(ichannel); - } - // // Internal structs // @@ -519,9 +505,9 @@ class BatchedThreadedNnet3CudaOnlinePipeline { // The thread pool receives data from device and post-processes it. This class // destructor blocks until the thread pool is drained of work items. - std::unique_ptr thread_pool_; + std::unique_ptr thread_pool_; - std::unique_ptr batching_copy_thread_pool_; + std::unique_ptr batching_copy_thread_pool_; // The decoder owns thread(s) that reconstruct lattices transferred from the // device in a compacted form as arrays with offsets instead of pointers. 
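This patch's core synchronization change replaces the decoder's spin waits
(polling an atomic with Sleep(200e-6)) by a mutex/condition-variable latch. A
minimal sketch of that pattern with illustrative names (the patch inlines the
counter, mutex, and condition variable into CudaDecoder rather than using a
helper struct like this):

    #include <condition_variable>
    #include <mutex>

    struct CompletionLatch {
      std::mutex m;
      std::condition_variable cv;
      int not_done = 0;

      void Start(int n) {            // called before dispatching n tasks
        std::lock_guard<std::mutex> lk(m);
        not_done = n;
      }
      void WorkerDone() {            // each worker calls this exactly once
        std::lock_guard<std::mutex> lk(m);
        if (--not_done == 0) cv.notify_all();
      }
      void Wait() {                  // waiter sleeps instead of spinning
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [this] { return not_done == 0; });
      }
    };
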
diff --git a/src/cudadecoder/cuda-decoder.cc b/src/cudadecoder/cuda-decoder.cc index e2f53301bf0..c8054bf45c4 100644 --- a/src/cudadecoder/cuda-decoder.cc +++ b/src/cudadecoder/cuda-decoder.cc @@ -83,7 +83,9 @@ CudaDecoder::CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, CU_SAFE_CALL(cudaStreamCreate(&compute_st_)); // Copies D2H of tokens for storage on host are done on // copy_st_, in parallel with compute_st_ - CU_SAFE_CALL(cudaStreamCreate(©_st_)); + int least_priority, greatest_priority; + CU_SAFE_CALL(cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); + CU_SAFE_CALL(cudaStreamCreateWithPriority(©_st_, cudaStreamDefault, greatest_priority)); // For all the allocating/initializing process // We create a special channel // containing the exact state a channel should have when starting a new @@ -433,6 +435,8 @@ void CudaDecoder::ComputeInitialChannel() { } void CudaDecoder::InitDecoding(const std::vector &channels) { + WaitForH2HCopies(); + WaitForPartialHypotheses(); // Cloning the init_channel_id_ channel into all channels in the // channels vec const int nlanes_used = channels.size(); @@ -460,7 +464,6 @@ void CudaDecoder::InitDecoding(const std::vector &channels) { std::lock_guard lk(n_init_decoding_h2h_task_not_done_mutex_); n_init_decoding_h2h_task_not_done_ += channels.size(); } - WaitForPartialHypotheses(); // good for (ChannelId ichannel : channels) { ChannelCounters &channel_counters = h_channels_counters_[ichannel]; channel_counters.prev_main_q_narcs_and_end = @@ -845,12 +848,18 @@ void CudaDecoder::AdvanceDecoding( // Setting the loglikelihoods pointers for that frame std::vector channels; // TODO channels.reserve(lanes_assignements.size()); - for (LaneId ilane = 0; ilane < lanes_assignements.size(); ++ilane) { - ChannelId ichannel = lanes_assignements[ilane].first; + std::vector> lanes_assignments_copy(lanes_assignements); + // sorting this makes lanes2channels_todo_ sorted as well. Since + // contiguous chunks of lanes2channels_todo_ are assigned to + // separate worker threads, this sorting speeds up CPU work by + // increasing cache data locality. + std::sort(lanes_assignments_copy.begin(), lanes_assignments_copy.end(), [](auto&& a, auto&& b){return a.first < b.first;}); + for (LaneId ilane = 0; ilane < lanes_assignments_copy.size(); ++ilane) { + ChannelId ichannel = lanes_assignments_copy[ilane].first; channels.push_back(ichannel); channel_to_compute_[ilane] = ichannel; h_lanes_counters_.lane(ilane)->loglikelihoods = - lanes_assignements[ilane].second; + lanes_assignments_copy[ilane].second; } // Make sure that InitDecoding() has completed. CU_SAFE_CALL(cudaStreamSynchronize(compute_st_)); @@ -871,9 +880,6 @@ void CudaDecoder::AdvanceDecoding( ResetForFrameAndEstimateCutoffKernel( KaldiCudaDecoderNumBlocks(1, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, *h_device_params_, *h_kernel_params_); - // Reset max active status. If necessary, ApplyMaxActiveAndReduceBeam - // will switch it back on - compute_max_active_ = false; // Processing emitting arcs. We've done the preprocess stage at the end // of the previous frame @@ -943,11 +949,10 @@ void CudaDecoder::AdvanceDecoding( // waiting here... Should use condition variable, right? 
void CudaDecoder::WaitForPartialHypotheses() { if (!generate_partial_hypotheses_) return; - while (n_partial_traceback_threads_not_done_ - .load(std::memory_order_acquire) > 0) { - // this is bad - Sleep(200e-6); - } + std::unique_lock<std::mutex> lk(n_partial_traceback_threads_not_done_mutex_); + n_partial_traceback_threads_not_done_cv_.wait(lk, [this]{ + return n_partial_traceback_threads_not_done_ == 0;} + ); } void CudaDecoder::CheckOverflow() { @@ -1837,34 +1842,28 @@ void CudaDecoder::CheckStaticAsserts() { void CudaDecoder::LaunchPartialHypotheses() { if (partial_traceback_) { - // performance killer. Need to wait until WaitForH2HCopies() is - // done in each individual launched thread, not in this main - // thread. - // the other issue is that the thread worker sleep time is too large. - n_partial_traceback_threads_todo_.store(nlanes_used_ - 1); - // necessary because the todo_ variable can reach 0 when the last - // one still isn't done. Ugh. - - // I assign to this in one thread, but then use a different - // thread... Is that a problem? - n_partial_traceback_threads_not_done_.store(thread_pool_ ? nlanes_used_ : 1); - auto nlanes_used = nlanes_used_; + const size_t num_tasks = thread_pool_->num_workers(); - auto launch = [this, nlanes_used]() { + { + std::lock_guard<std::mutex> lk(n_partial_traceback_threads_not_done_mutex_); + KALDI_ASSERT(n_partial_traceback_threads_not_done_ == 0); + n_partial_traceback_threads_not_done_ = thread_pool_ ? num_tasks : 1; + } + + auto launch = [this, nlanes_used, num_tasks]() { WaitForInitDecodingH2HCopies(); WaitForH2HCopies(); - for (std::size_t i = 0; i < nlanes_used; ++i) { - thread_pool_->Push(ThreadPoolLightTask{ - [this](void *a0, uint64_t a1, void *a2) { - nvtxRangePush("PartialHypothesis"); - // I wait for these, so why am I getting an error? Should I lock just for safety? - int ilane; - if ((ilane = n_partial_traceback_threads_todo_.fetch_sub(1)) >= 0) { - // ERROR!!! Need to keep the previous value of lanes2channels_todo_ + const size_t batch_size = + KALDI_CUDA_DECODER_DIV_ROUND_UP(nlanes_used, + num_tasks); + for (size_t i = 0; i < num_tasks; i += 1) { + auto task = [this, nlanes_used, batch_size, i, num_tasks]() { + for (size_t ilane = i * batch_size; + ilane < std::min(size_t((i + 1) * batch_size), size_t(nlanes_used)); + ++ilane) { int32 ichannel = lanes2channels_todo_[ilane]; - // std::lock_guard<std::mutex> channel_lk(channel_lock_[ichannel]); GeneratePartialPath(ilane, ichannel); if (generate_partial_hypotheses_) { std::stack<std::pair<int32, PartialPathArc *>> traceback_buffer; @@ -1876,14 +1875,20 @@ void CudaDecoder::LaunchPartialHypotheses() { h_all_channels_prev_best_path_traceback_head_[ichannel] = h_best_path_traceback_head_[ilane]; } - n_partial_traceback_threads_not_done_.fetch_sub(1, std::memory_order_release); - nvtxRangePop(); - }, nullptr, uint64_t(0), nullptr}); - } + { + std::lock_guard<std::mutex> lk(n_partial_traceback_threads_not_done_mutex_); + --n_partial_traceback_threads_not_done_; + KALDI_ASSERT(n_partial_traceback_threads_not_done_ < num_tasks); + KALDI_ASSERT(n_partial_traceback_threads_not_done_ >= 0); + if (n_partial_traceback_threads_not_done_ == 0) { + n_partial_traceback_threads_not_done_cv_.notify_all(); + } + } + }; + thread_pool_->submit(task); + } }; - - std::thread t(launch); - t.detach(); + thread_pool_->submit(launch); } } @@ -2080,7 +2085,6 @@ void CudaDecoder::ComputeH2HCopies() { // Waiting for the D2H copies.
This is threadsafe // Step 1: acoustic costs CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_acoustic_evt_)); - nvtxRangePush("acoustic copy"); while ((ilane = n_acoustic_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2097,10 +2101,8 @@ void CudaDecoder::ComputeH2HCopies() { auto &vec = h_all_tokens_acoustic_cost_[ichannel]; vec.insert(vec.end(), ntokens_nonemitting, 0.0f); } - nvtxRangePop(); // Step 2: infotoken CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_infotoken_evt_)); - nvtxRangePush("infotoken copy"); while ((ilane = n_infotoken_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2108,12 +2110,10 @@ void CudaDecoder::ComputeH2HCopies() { MoveConcatenatedCopyToVector(ilane, ichannel, h_main_q_end_lane_offsets_, h_infotoken_concat_, &h_all_tokens_info_); } - nvtxRangePop(); // Step 3: // - extra prev tokens // - partial path and endpointing CU_SAFE_CALL(cudaEventSynchronize(d2h_copy_extra_prev_tokens_evt_)); - nvtxRangePush("extra prev tokens copy"); while ((ilane = n_extra_prev_tokens_h2h_copies_todo_.fetch_sub(1)) >= 0) { int32 ichannel = lanes2channels_todo_[ilane]; // Lock Channel @@ -2126,7 +2126,6 @@ void CudaDecoder::ComputeH2HCopies() { h_extra_and_acoustic_cost_concat_, &h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_); } - nvtxRangePop(); // If we're the last cpu thread to complete the current tasks, notify // the main thread bool all_done; @@ -2139,7 +2138,7 @@ void CudaDecoder::ComputeH2HCopies() { } } -void CudaDecoder::SetThreadPoolAndStartCPUWorkers(ThreadPoolLight *thread_pool, +void CudaDecoder::SetThreadPoolAndStartCPUWorkers(futures_thread_pool *thread_pool, int32 nworkers) { KALDI_ASSERT(nworkers > 0); n_threads_used_ = nworkers; diff --git a/src/cudadecoder/cuda-decoder.h b/src/cudadecoder/cuda-decoder.h index b644a246adc..75f35743c97 100644 --- a/src/cudadecoder/cuda-decoder.h +++ b/src/cudadecoder/cuda-decoder.h @@ -42,7 +42,7 @@ #include "cudadecoder/cuda-decodable-itf.h" #include "cudadecoder/cuda-decoder-common.h" #include "cudadecoder/cuda-fst.h" -#include "cudadecoder/thread-pool-light.h" +#include "cudadecoder/thread-pool-cia.h" #include "fst/symbol-table.h" #include "online2/online-endpoint.h" @@ -333,7 +333,7 @@ class CudaDecoder { // InitDecodingH2HCopies For recurrent CPU work, such as // ComputeH2HCopies, we will use dedicated CPU threads We will launch // nworkers of those threads - void SetThreadPoolAndStartCPUWorkers(ThreadPoolLight *thread_pool, + void SetThreadPoolAndStartCPUWorkers(futures_thread_pool *thread_pool, int32 nworkers); // Used to generate partial results @@ -813,7 +813,6 @@ class CudaDecoder { std::vector<bool> has_reached_final_; std::vector<std::vector<std::pair<int32, CostType>>> list_finals_token_idx_and_cost_; - bool compute_max_active_; cudaEvent_t nnet3_done_evt_; cudaEvent_t d2h_copy_acoustic_evt_; cudaEvent_t d2h_copy_infotoken_evt_; @@ -868,7 +867,7 @@ class CudaDecoder { // read comments associated with must_replay_frame in GetRawLattice to // understand what it does CostType extra_cost_min_delta_; - ThreadPoolLight *thread_pool_; + futures_thread_pool *thread_pool_; std::vector<std::thread> cpu_dedicated_threads_; int32 n_threads_used_; std::vector<int32> lanes2channels_todo_; @@ -888,8 +887,9 @@ class CudaDecoder { //TODO(hugovbraun): unused: std::atomic<bool> active_wait_; // Used for sync on partial hypotheses tasks - std::atomic<std::int32_t> n_partial_traceback_threads_todo_; - std::atomic<std::int32_t> n_partial_traceback_threads_not_done_; + std::int32_t n_partial_traceback_threads_not_done_; +
std::mutex n_partial_traceback_threads_not_done_mutex_; + std::condition_variable n_partial_traceback_threads_not_done_cv_; // Set to false in destructor to stop threads. volatile bool h2h_threads_running_; diff --git a/src/cudadecoder/thread-pool-cia.cc b/src/cudadecoder/thread-pool-cia.cc index 4e294a5cd27..d6e3a087874 100644 --- a/src/cudadecoder/thread-pool-cia.cc +++ b/src/cudadecoder/thread-pool-cia.cc @@ -1,3 +1,23 @@ +// cudadecoder/thread-pool-cia.cc +// +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +// Daniel Galvez +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This code was modified from Chapter 10 of C++ Concurrency in +// Action, which offers its code under the Boost License. + #include <cudadecoder/thread-pool-cia.h> namespace kaldi { diff --git a/src/cudadecoder/thread-pool-cia.h b/src/cudadecoder/thread-pool-cia.h index bf9b1531d4e..2fdb09f9d93 100644 --- a/src/cudadecoder/thread-pool-cia.h +++ b/src/cudadecoder/thread-pool-cia.h @@ -1,5 +1,26 @@ +// cudadecoder/thread-pool-cia.h +// +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +// Daniel Galvez +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This code was modified from Chapter 10 of C++ Concurrency in +// Action, which offers its code under the Boost License.
+ #pragma once +#include <atomic> #include #include #include @@ -7,6 +28,13 @@ #include #include +#ifdef __linux__ +#include <pthread.h> +#include <sys/syscall.h> +#include <unistd.h> +#endif // __linux__ + + namespace kaldi { class join_threads { @@ -29,14 +57,23 @@ class threadsafe_queue { mutable std::mutex mut; std::queue<T> data_queue; std::condition_variable data_cond; + std::atomic<bool> done; public: - threadsafe_queue() {} - threadsafe_queue(const threadsafe_queue& other) { - std::lock_guard<std::mutex> lk(other.mut); - // does not work if T is function_wrapper - data_queue = other.data_queue; - } + threadsafe_queue(): done(false) {} threadsafe_queue& operator=(const threadsafe_queue&) = delete; + + void mark_done() { + std::lock_guard<std::mutex> lk(mut); + done = true; + data_cond.notify_all(); + } + + ~threadsafe_queue() { + if (!done) { + assert(false && "Must set done to true before destroying threadsafe_queue."); + } + } + + template <typename U = T> typename std::enable_if<std::is_same<U, T>::value && std::is_move_assignable<U>::value, void>::type push(T new_value) { @@ -53,13 +90,35 @@ class threadsafe_queue { data_queue.push(new_value); data_cond.notify_one(); } - void wait_and_pop(T& value) + template <typename U = T> + typename std::enable_if<std::is_same<U, T>::value && std::is_move_assignable<U>::value, bool>::type + wait_and_pop(T& value) { std::unique_lock<std::mutex> lk(mut); - data_cond.wait(lk, [this]{return !data_queue.empty();}); - value = data_queue.front(); - data_queue.pop(); + data_cond.wait(lk, [this]{return !data_queue.empty() || done;}); + if (!data_queue.empty()) { + value = std::move(data_queue.front()); + data_queue.pop(); + return true; + } else { + return false; + } + } + template <typename U = T> + typename std::enable_if<std::is_same<U, T>::value && std::is_copy_assignable<U>::value && !std::is_move_assignable<U>::value, bool>::type + wait_and_pop(T& value) + { + std::unique_lock<std::mutex> lk(mut); + data_cond.wait(lk, [this]{return !data_queue.empty() || done;}); + if (!data_queue.empty()) { + value = data_queue.front(); + data_queue.pop(); + return true; + } else { + return false; + } } + // TODO: return null pointer if done. TODO: Add move assign overload.
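Together, mark_done() and the bool-returning wait_and_pop() define the queue's shutdown protocol: a consumer blocks until either an item or the done flag arrives, and a false return means "drained and done, exit now". A hypothetical usage sketch under those semantics (the process() helper and the include path are placeholders, not part of the patch):

    #include <cstdio>
    #include <thread>
    #include "cudadecoder/thread-pool-cia.h"  // hypothetical include path

    static void process(int v) { std::printf("got %d\n", v); }

    int main() {
      kaldi::threadsafe_queue<int> q;
      std::thread worker([&q] {
        int v;
        // Returns false only once mark_done() has been called and q is drained.
        while (q.wait_and_pop(v)) {
          process(v);  // placeholder for real per-item work
        }
      });
      q.push(42);
      q.mark_done();  // must precede destruction; ~threadsafe_queue() asserts this
      worker.join();
      return 0;
    }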
std::unique_ptr<T> wait_and_pop() { std::unique_lock<std::mutex> lk(mut); data_cond.wait(lk, [this]{return !data_queue.empty();}); @@ -200,13 +259,21 @@ class futures_thread_pool { join_threads joiner; public: void worker_thread() { + #ifdef __linux__ + nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); + pthread_setname_np(pthread_self(), "threadpool"); + #endif while (!done) { function_wrapper task; - if (work_queue.try_pop(task)) { + bool success = work_queue.wait_and_pop(task); + if (success) { task(); - } else { - std::this_thread::yield(); } } } futures_thread_pool(const unsigned int num_threads): done(false), joiner(threads) { @@ -221,6 +288,7 @@ class futures_thread_pool { } ~futures_thread_pool() { + work_queue.mark_done(); done = true; } @@ -234,6 +302,10 @@ class futures_thread_pool { work_queue.push(std::move(task)); return res; } + + size_t num_workers() const { + return threads.size(); + } }; class thread_local_queue_thread_pool { @@ -371,6 +443,12 @@ class work_stealing_thread_pool { void worker_thread(unsigned int my_index_) { my_index = my_index_; local_work_queue = queues[my_index].get(); + + #ifdef __linux__ + nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); + pthread_setname_np(pthread_self(), "threadpool"); + #endif + while(!done) { run_pending_task(); } diff --git a/src/cudadecoder/thread-pool-light.h b/src/cudadecoder/thread-pool-light.h deleted file mode 100644 index 79e3b2cf3c9..00000000000 --- a/src/cudadecoder/thread-pool-light.h +++ /dev/null @@ -1,193 +0,0 @@ -// cudadecoder/cuda-decoder.h -// -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -// Hugo Braun, Justin Luitjens, Ryan Leary, Daniel Galvez -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#ifndef KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_ -#define KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_ - -#include -#include -#include -#include - -#ifdef __linux__ -#include -#include -#endif // __linux__ - - -namespace kaldi { -namespace cuda_decoder { - -constexpr double kSleepForWorkAvailable = 1e-4; -constexpr double kSleepForWorkerAvailable = 1e-3; - -struct ThreadPoolLightTask { - std::function<void(void *, uint64_t, void *)> func_ptr; - void *obj_ptr; - uint64_t arg1; - void *arg2; -}; - -template <int QUEUE_SIZE> -// Single producer, multiple consumer -class ThreadPoolLightSPMCQueue { - static constexpr unsigned int QUEUE_MASK = QUEUE_SIZE - 1; - std::vector<ThreadPoolLightTask> tasks_{QUEUE_SIZE}; - std::atomic<int> back_{0}; - std::atomic<int> front_{0}; - static int inc(int curr) { return ((curr + 1) & QUEUE_MASK); } - - public: - ThreadPoolLightSPMCQueue() { - KALDI_COMPILE_TIME_ASSERT(QUEUE_SIZE > 1); - constexpr bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1)) == 0); - KALDI_COMPILE_TIME_ASSERT(is_power_of_2); // validity of QUEUE_MASK - } - - bool TryPush(const ThreadPoolLightTask &task) { - int back = back_.load(std::memory_order_relaxed); - int next = inc(back); - if (next == front_.load(std::memory_order_acquire)) { - return false; // queue is full - } - tasks_[back] = task; - back_.store(next, std::memory_order_release); - - return true; - } - - bool TryPop(ThreadPoolLightTask *front_task) { - while (true) { - int front = front_.load(std::memory_order_relaxed); - if (front == back_.load(std::memory_order_acquire)) { - return false; // queue is empty - } - *front_task = tasks_[front]; - if (front_.compare_exchange_weak(front, inc(front), - std::memory_order_release)) { - return true; - } - } - } -}; - -class ThreadPoolLightWorker final { - // Multi consumer queue, because worker can steal work - ThreadPoolLightSPMCQueue<512> queue_; - // If this thread has no more work to do, it will try to steal work from - // other - std::thread thread_; - volatile bool run_thread_; - ThreadPoolLightTask curr_task_; - std::weak_ptr<ThreadPoolLightWorker> other_; - - void Work() { -#ifdef __linux__ - nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); - pthread_setname_np(pthread_self(), "threadpool"); -#endif - while (run_thread_) { - bool got_task = queue_.TryPop(&curr_task_); - if (!got_task) { - if (auto other_sp = other_.lock()) { - got_task = other_sp->TrySteal(&curr_task_); - } - } - if (got_task) { - // Not calling func_ptr as a member function, - // because we need to specialize the arguments - // anyway (we may want to ignore arg2, for - // instance) Using a wrapper func - (curr_task_.func_ptr)(curr_task_.obj_ptr, curr_task_.arg1, - curr_task_.arg2); - } else { - // std::this_thread::yield(); - Sleep(kSleepForWorkAvailable); // TODO - } - } - } - - // Another worker can steal a task from this queue - // This is done so that a very long task computed by one thread does not - // hold the entire threadpool to complete a time-sensitive task - bool TrySteal(ThreadPoolLightTask *task) { return queue_.TryPop(task); } - - public: - ThreadPoolLightWorker() : run_thread_(true), other_() {} - ~ThreadPoolLightWorker() { - KALDI_ASSERT(!queue_.TryPop(&curr_task_)); - } - bool TryPush(const ThreadPoolLightTask &task) { - return queue_.TryPush(task); - } - void SetOtherWorkerToStealFrom( - const std::shared_ptr<ThreadPoolLightWorker>& other) { - other_ = other; - } - void Start() { - KALDI_ASSERT("Please call SetOtherWorkerToStealFrom() first" && - !other_.expired()); - thread_ = std::thread(&ThreadPoolLightWorker::Work, this); - } - void Stop() { - run_thread_ = false; - thread_.join(); - other_.reset(); - } -};
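One detail of the queue being deleted is worth recording: TryPush and TryPop advance the ring indices with a bitwise AND rather than a modulo, which only works because QUEUE_SIZE is a power of two, exactly what the compile-time asserts above enforce. A standalone sketch of that invariant (illustrative, not Kaldi code):

    #include <cassert>

    // With SIZE a power of two, (i + 1) & (SIZE - 1) wraps SIZE-1 back to 0,
    // which is what ThreadPoolLightSPMCQueue::inc() relied on.
    template <unsigned SIZE>
    unsigned inc(unsigned i) {
      static_assert((SIZE & (SIZE - 1)) == 0, "SIZE must be a power of two");
      return (i + 1) & (SIZE - 1);
    }

    int main() {
      assert(inc<512>(510) == 511);
      assert(inc<512>(511) == 0);  // wrap-around without a modulo or a branch
      return 0;
    }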
- -class ThreadPoolLight { - std::vector<std::shared_ptr<ThreadPoolLightWorker>> workers_; - int curr_iworker_; // next call on tryPush will post work on this - // worker - public: - ThreadPoolLight(int32 nworkers = std::thread::hardware_concurrency()) - : workers_(nworkers), curr_iworker_(0) { - KALDI_ASSERT(nworkers > 1); - for (size_t i = 0; i < workers_.size(); ++i) { - workers_[i] = std::make_shared<ThreadPoolLightWorker>(); - } - for (size_t i = 0; i < workers_.size(); ++i) { - int iother = (i + nworkers / 2) % nworkers; - workers_[i]->SetOtherWorkerToStealFrom(workers_[iother]); - workers_[i]->Start(); - } - } - - ~ThreadPoolLight() { - for (auto& wkr : workers_) wkr->Stop(); - } - - bool TryPush(const ThreadPoolLightTask &task) { - if (!workers_[curr_iworker_]->TryPush(task)) return false; - ++curr_iworker_; - if (curr_iworker_ == workers_.size()) curr_iworker_ = 0; - return true; - } - - void Push(const ThreadPoolLightTask &task) { - // Could try another curr_iworker_ - while (!TryPush(task)) { - Sleep(kSleepForWorkerAvailable); - } - } -}; - -} // namespace cuda_decoder -} // namespace kaldi - -#endif // KALDI_CUDADECODER_THREAD_POOL_LIGHT_H_ diff --git a/src/cudadecoder/thread-pool.h b/src/cudadecoder/thread-pool.h deleted file mode 100644 index dc26ecc688a..00000000000 --- a/src/cudadecoder/thread-pool.h +++ /dev/null @@ -1,169 +0,0 @@ -// cudadecoder/thread-pool.h -// Source: https://github.com/progschj/ThreadPool -// Modified to add a priority queue -// Obtained under this license: -/* -Copyright (c) 2012 Jakob Progsch, Václav Zeman - -This software is provided 'as-is', without any express or implied -warranty. In no event will the authors be held liable for any damages -arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, -including commercial applications, and to alter it and redistribute it -freely, subject to the following restrictions: - - 1. The origin of this software must not be misrepresented; you must not - claim that you wrote the original software. If you use this software - in a product, an acknowledgment in the product documentation would be - appreciated but is not required. - - 2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original software. - - 3. This notice may not be removed or altered from any source - distribution. -*/ - -// -// Important: This file is deprecated and will be removed in a future release -// - -#ifndef KALDI_CUDA_DECODER_DEPRECATED_THREAD_POOL_H_ -#define KALDI_CUDA_DECODER_DEPRECATED_THREAD_POOL_H_ - -#include <climits> -#include <condition_variable> -#include <functional> -#include <future> -#include <memory> -#include <mutex> -#include <queue> -#include <stdexcept> -#include <thread> -#include <vector> - -namespace kaldi { -namespace cuda_decoder { - -// C++ indexes enum 0,1,2... -enum [[deprecated]] ThreadPoolPriority { - THREAD_POOL_LOW_PRIORITY, - THREAD_POOL_NORMAL_PRIORITY, - THREAD_POOL_HIGH_PRIORITY -}; - -class [[deprecated]] ThreadPool { - public: - ThreadPool(size_t); - template <class F, class... Args> - auto enqueue(ThreadPoolPriority priority, F &&f, Args &&... args) - -> std::future<typename std::result_of<F(Args...)>::type>; - template <class F, class... Args> - auto enqueue(F &&f, Args &&...
args) - -> std::future<typename std::result_of<F(Args...)>::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector<std::thread> workers; - // the task queue - struct Task { - std::function<void()> func; - // Ordered first by priority, then FIFO order - // tasks created first will have a higher - // priority_with_fifo.second - std::pair<ThreadPoolPriority, long long> priority_with_fifo; - }; - friend bool operator<(const ThreadPool::Task &lhs, - const ThreadPool::Task &rhs); - - std::priority_queue<Task> tasks; - long long task_counter; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - - bool stop; -}; - -inline bool operator<(const ThreadPool::Task &lhs, - const ThreadPool::Task &rhs) { - return lhs.priority_with_fifo < rhs.priority_with_fifo; -} - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) - : task_counter(LONG_MAX), stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - Task task; - - { - std::unique_lock<std::mutex> lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - if (!tasks.empty()) { - task = std::move(this->tasks.top()); - this->tasks.pop(); - } - } - task.func(); - } - }); -} - -// add new work item to the pool : normal priority -template <class F, class... Args> -auto ThreadPool::enqueue(F &&f, Args &&... args) - -> std::future<typename std::result_of<F(Args...)>::type> { - return enqueue(THREAD_POOL_NORMAL_PRIORITY, std::forward<F>(f), - std::forward<Args>(args)...); -} - -// add new work item to the pool -template <class F, class... Args> -auto ThreadPool::enqueue(ThreadPoolPriority priority, F &&f, Args &&... args) - -> std::future<typename std::result_of<F(Args...)>::type> { - using return_type = typename std::result_of<F(Args...)>::type; - - auto func = std::make_shared<std::packaged_task<return_type()>>( - std::bind(std::forward<F>(f), std::forward<Args>(args)...)); - - std::future<return_type> res = func->get_future(); - { - std::unique_lock<std::mutex> lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); - Task task; - task.func = [func]() { (*func)(); }; - long long task_fifo_id = task_counter--; - // The following if will temporarily break the FIFO order - // (leading to a perf drop for a few seconds) - // But it should trigger in ~50 million years - if (task_counter == 0) task_counter = LONG_MAX; - task.priority_with_fifo = {priority, task_fifo_id}; - tasks.push(std::move(task)); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock<std::mutex> lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread &worker : workers) worker.join(); -} - -} // end namespace cuda_decoder -} // end namespace kaldi - -#endif // KALDI_CUDA_DECODER_THREAD_POOL_H_ From 6d94122a6fff59ff618f75a90fb965e433b05333 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 13 Dec 2022 11:03:29 -0800 Subject: [PATCH 7/8] [misc] Install python2.7 This fixes a CI error that appears to come from using "ubuntu-latest" in the CI workflow: the runner image was automatically upgraded to Ubuntu 22.04, which does not ship python2.7 by default.
--- .github/workflows/c-cpp.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index c1f923cf58a..8a21c82ea8f 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -19,6 +19,8 @@ jobs: - uses: actions/checkout@v3 - name: Install sox run: sudo apt-get install -y sox intel-mkl + - name: Install python2 + run: sudo apt-get install -y python2 - name: ccache uses: hendrikmuhs/ccache-action@v1.2 with: From 87d577f9f8de058d93ad1af0bf5dc3b86ecbc1ed Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 13 Dec 2022 11:05:58 -0800 Subject: [PATCH 8/8] Make codefactor changes. --- .../batched-threaded-nnet3-cuda-online-pipeline.cc | 1 - src/cudadecoder/thread-pool-cia.h | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc index 2a6e4f3bd0e..1e3bdeac894 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc @@ -263,7 +263,6 @@ void BatchedThreadedNnet3CudaOnlinePipeline::CompactWavesToMatrix( std::atomic_init(&tasks_remaining, KALDI_CUDA_DECODER_DIV_ROUND_UP(wave_samples.size(), batch_size)); for (size_t i = 0; i < wave_samples.size(); i += batch_size) { - auto task = [i, this, &wave_samples, &m, &cv, &tasks_remaining, &batch_size]() { nvtxRangePush("CompactWavesToMatrix task"); for (size_t j = i; j < std::min(i + batch_size, wave_samples.size()); ++j) { diff --git a/src/cudadecoder/thread-pool-cia.h b/src/cudadecoder/thread-pool-cia.h index 2fdb09f9d93..ca079194370 100644 --- a/src/cudadecoder/thread-pool-cia.h +++ b/src/cudadecoder/thread-pool-cia.h @@ -58,6 +58,7 @@ class threadsafe_queue { std::queue<T> data_queue; std::condition_variable data_cond; std::atomic<bool> done; + public: threadsafe_queue(): done(false) {} threadsafe_queue& operator=(const threadsafe_queue&) = delete; @@ -235,6 +236,7 @@ class function_wrapper { impl_type(F&& f_): f(std::move(f_)) {} void call() { f(); } }; + public: template <typename F> function_wrapper(F&& f): impl(new impl_type<F>(std::move(f))) {} @@ -257,6 +259,7 @@ class futures_thread_pool { threadsafe_queue<function_wrapper> work_queue; std::vector<std::thread> threads; join_threads joiner; + public: void worker_thread() { #ifdef __linux__ nvtxNameOsThread(syscall(SYS_gettid), "threadpool"); pthread_setname_np(pthread_self(), "threadpool"); #endif @@ -374,6 +377,7 @@ class work_stealing_queue { typedef function_wrapper data_type; std::deque<data_type> the_queue; mutable std::mutex the_mutex; + public: work_stealing_queue() {} work_stealing_queue(const work_stealing_queue& other) = delete; @@ -439,6 +443,7 @@ class work_stealing_thread_pool { } return false; } + public: void worker_thread(unsigned int my_index_) { my_index = my_index_;