From 3dde884b735e7f20db8a90664e3156facf48b3f0 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Mon, 9 Sep 2024 12:52:31 +0200 Subject: [PATCH] fix memory issues (#257) --- src/bpe_tokenizer.cpp | 10 +++++----- src/regex_split.hpp | 2 +- src/utf8_validate.cpp | 8 ++++---- src/utils.cpp | 18 ++++++++++++------ 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/bpe_tokenizer.cpp b/src/bpe_tokenizer.cpp index 71c90ba07..d55b8fd52 100644 --- a/src/bpe_tokenizer.cpp +++ b/src/bpe_tokenizer.cpp @@ -253,19 +253,19 @@ BPETokenizerImpl::BPETokenizerImpl( Vocab new_vocab = vocab; for (size_t i = 0; i < merges.size(); i++) { - auto pair = merges.at(i); + auto& pair = merges.at(i); auto id_pair = std::make_pair(vocab.at(pair.first), vocab.at(pair.second)); new_merges[id_pair] = {i, vocab.at(pair.first + pair.second)}; new_vocab.erase(pair.first + pair.second); } - this->m_vocab = new_vocab; - this->m_merges = new_merges; + m_vocab = std::move(new_vocab); + m_merges = std::move(new_merges); m_trie = std::make_unique(); - for(const auto& word: new_vocab) { + for(const auto& word: m_vocab) { const auto token = std::vector(word.first.begin(), word.first.end()); m_trie->add(token, word.second); } m_cache.reserve(cache_capacity); -} \ No newline at end of file +} diff --git a/src/regex_split.hpp b/src/regex_split.hpp index e6bd38ef3..ccf91e133 100644 --- a/src/regex_split.hpp +++ b/src/regex_split.hpp @@ -69,7 +69,7 @@ class RegexSplit : public ov::op::Op { mutable std::shared_ptr m_search_pattern_pcre2; mutable std::shared_ptr> m_skip_tokens; mutable std::string m_behaviour = "remove"; - mutable SplitMode m_split_mode; + mutable SplitMode m_split_mode = SplitMode::REMOVED; bool m_invert = false; int m_max_splits = -1; diff --git a/src/utf8_validate.cpp b/src/utf8_validate.cpp index abf4aab01..1c87ef835 100644 --- a/src/utf8_validate.cpp +++ b/src/utf8_validate.cpp @@ -39,10 +39,10 @@ bool UTF8Validate::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i // 
UTF-8 code points should not intersect: // if 2 byte object has code point < 0x80 then it's not valid 2 byte utf-8, // even if it has a valid bit mask. - const uint64_t code_point_starts[4] = {0x0, 0x80, 0x800, 0x10000}; - uint64_t utf_code_point; - size_t bytes_to_consume; // Number of additional 0b10xxxxxx bytes to consume to produce a valid UTF8 symbol. - size_t num_bytes; + const uint32_t code_point_starts[4] = {0x0, 0x80, 0x800, 0x10000}; + uint32_t utf_code_point; + uint32_t bytes_to_consume; // Number of additional 0b10xxxxxx bytes to consume to produce a valid UTF8 symbol. + uint32_t num_bytes; size_t out_idx = begins[0]; for (size_t i = 0; i < begins_shape[0]; i++) { diff --git a/src/utils.cpp b/src/utils.cpp index ba58d6797..dd4b02d8f 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -254,11 +254,6 @@ std::string PCRE2Wrapper::substitute(const std::string& orig_str, pcre2_match_data* match_data = pcre2_match_data_create_from_pattern(m_compiled, NULL); PCRE2_SIZE subject_length = orig_str.size(); - // Usually found pattern is replaced by shorter string, but set 3 times more space for safety. - // Allocate dynamically since lenght depends dynamically on the lenght of input string. - // Allocated memory will be freed at the exit from function. - auto buffer = (PCRE2_UCHAR*) std::malloc(sizeof(PCRE2_UCHAR) * subject_length * 3); - // Check if the string matches the pattern int match_result = pcre2_match( m_compiled, @@ -272,7 +267,17 @@ std::string PCRE2Wrapper::substitute(const std::string& orig_str, pcre2_match_data_free(match_data); return orig_str; } - + + // Usually found pattern is replaced by shorter string, but set 3 times more space for safety. + // Allocate dynamically since length depends dynamically on the length of input string. + // Allocated memory will be freed at the exit from function. 
+ auto buffer = (PCRE2_UCHAR*) std::malloc(sizeof(PCRE2_UCHAR) * subject_length * 3); + if (buffer == nullptr) { + std::cerr << "Memory allocation failed" << std::endl; + pcre2_match_data_free(match_data); + return orig_str; + } + int rc = pcre2_substitute( m_compiled, (PCRE2_SPTR) orig_str.c_str(), orig_str.size(), @@ -292,6 +297,7 @@ std::string PCRE2Wrapper::substitute(const std::string& orig_str, std::cerr << "PCRE2 substitution failed with error code " << rc << std::endl; } pcre2_match_data_free(match_data); + std::free(buffer); return orig_str; } auto res = std::string(reinterpret_cast(buffer), subject_length);