Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multithreading to the OCR #69

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ include(CheckCXXSourceCompiles)
include(CheckCXXSourceRuns)

set(CMAKE_C_FLAGS "-std=gnu99")
set(CMAKE_CXX_FLAGS "-ansi -pedantic -Wall -Wextra -Wno-long-long")
set(CMAKE_CXX_FLAGS "-DUSE_STD_NAMESPACE -std=gnu++11 -pedantic -Wall -Wextra -Wno-long-long")

set(CMAKE_CXX_FLAGS_RELEASE "-O3 -mtune=native -march=native -DNDEBUG -fomit-frame-pointer -ffast-math") # TODO -Ofast GCC 4.6
set(CMAKE_C_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
Expand Down
6 changes: 2 additions & 4 deletions CMakeModules/FindTesseract.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ check_cxx_source_compiles(
int main() {
}"
TESSERACT_NAMESPACE)
if(TESSERACT_NAMESPACE)
add_definitions("-DCONFIG_TESSERACT_NAMESPACE")
else()
message(WARNING "You are using an old Tesseract version. Support for Tesseract 2 is deprecated and will be removed in the future!")
if(NOT TESSERACT_NAMESPACE)
message(FATAL_ERROR "You are using an old Tesseract version. Support for Tesseract 2 has been removed when implementing multithreading.")
endif()
list(REMOVE_ITEM CMAKE_REQUIRED_INCLUDES ${Tesseract_INCLUDE_DIR})

Expand Down
2 changes: 1 addition & 1 deletion doc/completion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ _vobsub2srt() {

case $cur in
-*)
COMPREPLY=( $( compgen -W '--dump-images --verbose --ifo --lang --langlist --tesseract-lang --tesseract-data --blacklist --y-threshold --min-width --min-height' -- "$cur" ) )
COMPREPLY=( $( compgen -W '--dump-images --verbose --ifo --lang --langlist --tesseract-lang --tesseract-data --blacklist --y-threshold --min-width --min-height --max-threads' -- "$cur" ) )
;;
*)
_filedir '(idx|IDX|sub|SUB)'
Expand Down
3 changes: 3 additions & 0 deletions doc/vobsub2srt.1
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ Minimum width in pixels to consider a subpicture for OCR (Default: 9).
.TP
\fB\-\-min-height\fR \fIheight\fR
Minimum height in pixels to consider a subpicture for OCR (Default: 1).
.TP
\fB\-\-max\-threads\fR \fInb\fR
Maximum number of threads to use to do the OCR, use 0 to autodetect the number of cores (Default: 0).
.SH EXAMPLES
.nf
$ \fBvobsub2srt \-\-lang en foobar\fR
Expand Down
175 changes: 123 additions & 52 deletions src/vobsub2srt.c++
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@
// Tesseract OCR
#include "tesseract/baseapi.h"

#include <unistd.h>
#include <iostream>
#include <string>
#include <cstdio>
#include <vector>
#include <atomic>
#include <algorithm>
#include <thread>
#include <mutex>
using namespace std;

#include "langcodes.h++"
Expand All @@ -40,10 +45,10 @@ typedef void* spu_t;

// helper struct for caching and fixing end_pts in some cases
struct sub_text_t {
sub_text_t(unsigned start_pts, unsigned end_pts, char const *text)
: start_pts(start_pts), end_pts(end_pts), text(text)
sub_text_t(unsigned counter, unsigned start_pts, unsigned end_pts, char const *text)
: counter(counter), start_pts(start_pts), end_pts(end_pts), text(text)
{ }
unsigned start_pts, end_pts;
unsigned counter, start_pts, end_pts;
char const *text;
};

Expand Down Expand Up @@ -83,15 +88,69 @@ void dump_pgm(std::string const &filename, unsigned counter, unsigned width, uns
}
}

#ifdef CONFIG_TESSERACT_NAMESPACE
using namespace tesseract;
#endif

#define TESSERACT_DEFAULT_PATH "<builtin default>"
#ifndef TESSERACT_DATA_PATH
#define TESSERACT_DATA_PATH TESSERACT_DEFAULT_PATH
#endif

TessBaseAPI* init_tesseract(std::string tesseract_data_path, char const *tess_lang, std::string blacklist) {
char const *tess_path = NULL;
if (tesseract_data_path != TESSERACT_DEFAULT_PATH)
tess_path = tesseract_data_path.c_str();

TessBaseAPI *tess_base_api = new TessBaseAPI();
if(tess_base_api->Init(tess_path, tess_lang) == -1) {
delete tess_base_api;
cerr << "Failed to initialize tesseract (OCR).\n";
return NULL;
}
if(not blacklist.empty()) {
tess_base_api->SetVariable("tessedit_char_blacklist", blacklist.c_str());
}
return tess_base_api;
}

void do_ocr(TessBaseAPI *tess_base_api, atomic<bool> *done, vector<sub_text_t> *conv_subs, mutex *mut,
unsigned counter, unsigned width, unsigned height, unsigned stride,
unsigned char *image_cpy, unsigned start_pts, unsigned end_pts, bool verb) {

char *text = tess_base_api->TesseractRect(image_cpy, 1, stride, 0, 0, width, height);
free(image_cpy);

if(not text) {
cerr << "ERROR: OCR failed for " << counter << '\n';
char const errormsg[] = "VobSub2SRT ERROR: OCR failure!";
// using raw memory is evil but that's the way Tesseract works
// If we switch to C++11 we can use unique_ptr.
text = new char[sizeof(errormsg)];
memcpy(text, errormsg, sizeof(errormsg));
}
else {
size_t size = strlen(text);
while (size > 0 and isspace(text[--size])) {
text[size] = '\0';
}
}
if(verb) {
cout << counter << " Text: " << text << endl;
}
mut->lock();
conv_subs->push_back(sub_text_t(counter, start_pts, end_pts, text));
mut->unlock();
done->store(true);
}

struct ocr_thread_t {
ocr_thread_t(TessBaseAPI *tess_base_api)
: tess_base_api(tess_base_api)
{ }
thread *t = NULL;
atomic<bool> done{false};
TessBaseAPI *tess_base_api = NULL;
};

int main(int argc, char **argv) {
bool dump_images = false;
bool verb = false;
Expand All @@ -106,6 +165,7 @@ int main(int argc, char **argv) {
int y_threshold = 0;
int min_width = 9;
int min_height = 1;
int max_threads = 0;

{
/************************************************************************************
Expand All @@ -125,6 +185,7 @@ int main(int argc, char **argv) {
add_option("y-threshold", y_threshold, "Y (luminance) threshold below which colors treated as black (Default: 0)").
add_option("min-width", min_width, "Minimum width in pixels to consider a subpicture for OCR (Default: 9)").
add_option("min-height", min_height, "Minimum height in pixels to consider a subpicture for OCR (Default: 1)").
add_option("max-threads", max_threads, "Maximum number of threads to use to do the OCR, use 0 to autodetect the number of cores (Default: 0)").
add_unnamed(subname, "subname", "name of the subtitle files WITHOUT .idx/.sub ending! (REQUIRED)");
if(not opts.parse_cmd(argc, argv) or subname.empty()) {
return 1;
Expand Down Expand Up @@ -200,27 +261,6 @@ int main(int argc, char **argv) {
}
}

// Init Tesseract
char const *tess_path = NULL;
if (tesseract_data_path != TESSERACT_DEFAULT_PATH)
tess_path = tesseract_data_path.c_str();

#ifdef CONFIG_TESSERACT_NAMESPACE
TessBaseAPI tess_base_api;
if(tess_base_api.Init(tess_path, tess_lang) == -1) {
cerr << "Failed to initialize tesseract (OCR).\n";
return 1;
}
if(not blacklist.empty()) {
tess_base_api.SetVariable("tessedit_char_blacklist", blacklist.c_str());
}
#else
TessBaseAPI::SimpleInit(tess_path, tess_lang, false); // TODO params
if(not blacklist.empty()) {
TessBaseAPI::SetVariable("tessedit_char_blacklist", blacklist.c_str());
}
#endif

// Open srt output file
string const srt_filename = subname + ".srt";
FILE *srtout = fopen(srt_filename.c_str(), "w");
Expand All @@ -229,14 +269,22 @@ int main(int argc, char **argv) {
return 1;
}

if (max_threads <= 0)
max_threads = thread::hardware_concurrency();

vector<ocr_thread_t*> threads;

// Read subtitles and convert
void *packet;
int timestamp; // pts100
int len;
unsigned last_start_pts = 0;
unsigned sub_counter = 1;

vector<sub_text_t> conv_subs;
conv_subs.reserve(200); // TODO better estimate
mutex mut;

while( (len = vobsub_get_next_packet(vob, &packet, &timestamp)) > 0) {
if(timestamp >= 0) {
spudec_assemble(spu, reinterpret_cast<unsigned char*>(packet), len, timestamp);
Expand Down Expand Up @@ -269,50 +317,73 @@ int main(int argc, char **argv) {
dump_pgm(subname, sub_counter, width, height, stride, image, image_size);
}

#ifdef CONFIG_TESSERACT_NAMESPACE
char *text = tess_base_api.TesseractRect(image, 1, stride, 0, 0, width, height);
#else
char *text = TessBaseAPI::TesseractRect(image, 1, stride, 0, 0, width, height);
#endif
if(not text) {
cerr << "ERROR: OCR failed for " << sub_counter << '\n';
char const errormsg[] = "VobSub2SRT ERROR: OCR failure!";
// using raw memory is evil but that's the way Tesseract works
// If we switch to C++11 we can use unique_ptr.
text = new char[sizeof(errormsg)];
memcpy(text, errormsg, sizeof(errormsg));
}
else {
size_t size = strlen(text);
while (size > 0 and isspace(text[--size])) {
text[size] = '\0';
ocr_thread_t *ocr_thread = NULL;
if (threads.size() < static_cast<unsigned>(max_threads)) {
TessBaseAPI *tess_base_api = init_tesseract(tesseract_data_path, tess_lang, blacklist);
if (tess_base_api == NULL)
return -1;
ocr_thread = new ocr_thread_t(tess_base_api);
threads.push_back(ocr_thread);
} else if (max_threads == 1) {
ocr_thread = threads[0];
} else {
while (ocr_thread == NULL) {
for (unsigned i=0; i < threads.size(); i++) {
if (threads[i]->done) {
threads[i]->t->join();
delete threads[i]->t;
ocr_thread = threads[i];
break;
}
}
if (ocr_thread == NULL)
usleep(50);
}
}
if(verb) {
cout << sub_counter << " Text: " << text << endl;

unsigned char *image_cpy = (unsigned char *)malloc(image_size);
memcpy(image_cpy, image, image_size);

if (max_threads == 1)
do_ocr(ocr_thread->tess_base_api, &ocr_thread->done, &conv_subs, &mut, sub_counter, width, height, stride, image_cpy, start_pts, end_pts, verb);
else {
ocr_thread->done = false;
ocr_thread->t = new thread(do_ocr, ocr_thread->tess_base_api, &ocr_thread->done, &conv_subs, &mut, sub_counter, width, height, stride, image_cpy, start_pts, end_pts, verb);
}
conv_subs.push_back(sub_text_t(start_pts, end_pts, text));

++sub_counter;
}
}

for(unsigned i = 0; i < threads.size(); ++i) {
if (threads[i]->t != NULL) {
threads[i]->t->join();
delete threads[i]->t;
}
threads[i]->tess_base_api->End();
delete threads[i]->tess_base_api;
delete threads[i];
}

struct {
bool operator()(sub_text_t a, sub_text_t b) const {
return a.counter < b.counter;
}
} sort_fct;
sort(conv_subs.begin(), conv_subs.end(), sort_fct);

// write the file, fixing end_pts when needed
for(unsigned i = 0; i < conv_subs.size(); ++i) {
if(conv_subs[i].end_pts == UINT_MAX && i+1 < conv_subs.size())
conv_subs[i].end_pts = conv_subs[i+1].start_pts;

fprintf(srtout, "%u\n%s --> %s\n%s\n\n", i+1, pts2srt(conv_subs[i].start_pts).c_str(),
fprintf(srtout, "%u\n%s --> %s\n%s\n\n", conv_subs[i].counter, pts2srt(conv_subs[i].start_pts).c_str(),
pts2srt(conv_subs[i].end_pts).c_str(), conv_subs[i].text);

delete[]conv_subs[i].text;
conv_subs[i].text = 0x0;
}

#ifdef CONFIG_TESSERACT_NAMESPACE
tess_base_api.End();
#else
TessBaseAPI::End();
#endif
fclose(srtout);
cout << "Wrote Subtitles to '" << srt_filename << "'\n";
vobsub_close(vob);
Expand Down