Skip to content

Commit

Permalink
Cherry Pick For 24.1 Release (#118)
Browse files Browse the repository at this point in the history
* Add Example RWKV to Readme

(cherry picked from commit fb4142d)

* Update Example

(cherry picked from commit 6441bef)

* Add Cache to VocabEncoder

(cherry picked from commit fd32aa5)

* Install torch CPU

(cherry picked from commit dcc5fcd)

* Parse QWEN EOD_TOKEN_ID

(cherry picked from commit 029a341)

* Fix Parse

(cherry picked from commit 0298fea)

* Fix Pass-Rate

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
  • Loading branch information
apaniukov and ilya-lavrenov authored Apr 18, 2024
1 parent e9d708d commit ad37623
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ jobs:
wheel_name=$(find . -name 'openvino_tokenizers*.whl')
python3 -m pip install $wheel_name[dev]
popd
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

- name: Tokenizers regression tests (using openvino python modules)
if: needs.openvino_download.outputs.status == 'success'
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ jobs:
wheel_name=$(find . -name 'openvino_tokenizers*.whl')
python3 -m pip install $wheel_name[dev]
popd
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

- name: Tokenizers regression tests (using openvino python modules)
if: needs.openvino_download.outputs.status == 'success'
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ jobs:
# Find and install wheel
$ovCoreWheelPath=Get-ChildItem -Path "${{ env.INSTALL_DIR }}\\ov_tokenizers" -Filter openvino_tokenizers*.whl | % { $_.FullName }
python3 -m pip install "$ovCoreWheelPath[all]"
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

- name: Tokenizers regression tests (using openvino python modules)
if: needs.openvino_download.outputs.status == 'success'
Expand Down
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# OpenVINO Tokenizers

[![Downloads](https://static.pepy.tech/badge/openvino-tokenizers)](https://pepy.tech/project/openvino-tokenizers)

OpenVINO Tokenizers adds text processing operations to OpenVINO.

## Features
Expand Down Expand Up @@ -59,6 +61,9 @@ This command is the equivalent of minimal installation. Install tokenizers conve
```bash
pip install transformers[sentencepiece] tiktoken
```
:warning: The latest commit of OpenVINO Tokenizers might rely on features that are not present in the released OpenVINO version.
If you have issues with the build process, use [a nightly build](https://docs.openvino.ai/2024/get-started/install-openvino.html?VERSION=NIGHTLY) of OpenVINO or build
OpenVINO Tokenizers from a release branch.

### Build and install for development
```bash
Expand Down Expand Up @@ -279,6 +284,29 @@ tf_result = tf_embed(sentences)
assert np.all(np.isclose(ov_result, tf_result, atol=1e-4))
```

### RWKV Tokenizer

```python
from urllib.request import urlopen

from openvino import compile_model
from openvino_tokenizers import build_rwkv_tokenizer


# Location of the RWKV vocabulary file used to build the tokenizer.
rwkv_vocab_url = (
    "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
)

with urlopen(rwkv_vocab_url) as vocab_file:
    # `map` is lazy: lines are read from the response only when iterated,
    # so the tokenizer must be built while the response is still open.
    vocab = map(bytes.decode, vocab_file)
    tokenizer, detokenizer = build_rwkv_tokenizer(vocab)

# Compile both models before running inference.
tokenizer, detokenizer = compile_model(tokenizer), compile_model(detokenizer)

print(tokenized := tokenizer(["Test string"])["input_ids"])  # [[24235 47429]]
print(detokenizer(tokenized)["string_output"])  # ['Test string']
```

## Supported Tokenizer Types

| Huggingface <br/>Tokenizer Type | Tokenizer Model Type | Tokenizer | Detokenizer |
Expand All @@ -288,6 +316,7 @@ assert np.all(np.isclose(ov_result, tf_result, atol=1e-4))
| | Unigram |||
| Legacy | SentencePiece .model |||
| Custom | tiktoken |||
| RWKV | Trie | ✅ | ✅ |

## Test Results

Expand Down
4 changes: 2 additions & 2 deletions python/openvino_tokenizers/build_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Iterable, Tuple

from openvino import Model, PartialShape, Type
from openvino.runtime import op
Expand All @@ -16,7 +16,7 @@


def build_rwkv_tokenizer(
rwkv_vocab: List[str],
rwkv_vocab: Iterable[str],
clean_up_tokenization_spaces: bool = False,
tokenizer_output_type: Type = Type.i64,
detokenizer_input_type: Type = Type.i64,
Expand Down
4 changes: 2 additions & 2 deletions python/openvino_tokenizers/tokenizer_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from functools import singledispatchmethod
from itertools import chain, islice
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from openvino.runtime import Model, Output, PartialShape, Type, op
Expand Down Expand Up @@ -382,7 +382,7 @@ def fill_vocab(vocab: List[str], indices: List[int]) -> Tuple[List[str], List[in
return new_vocab, new_indices

@classmethod
def from_rwkv_vocab(cls, vocab_file_strings: Iterator[str]) -> TrieTokenizerStep:
def from_rwkv_vocab(cls, vocab_file_strings: Iterable[str]) -> TrieTokenizerStep:
vocab = []
indices = []
for line in vocab_file_strings:
Expand Down
34 changes: 15 additions & 19 deletions src/vocab_encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
using namespace ov;


VocabEncoder::VocabEncoder (const ov::OutputVector& arguments) :
ov::op::Op(arguments) {
constructor_validate_and_infer_types();
}


void VocabEncoder::validate_and_infer_types() {
// main string input
check_string_input(this, 0);
Expand All @@ -44,19 +38,21 @@ bool VocabEncoder::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i
auto ends = inputs[1].data<const int32_t>();
auto chars = inputs[2].data<const uint8_t>();

// vocab string keys
auto vocab_begins = inputs[3].data<const int32_t>();
auto vocab_ends = inputs[4].data<const int32_t>();
auto vocab_chars = inputs[5].data<const uint8_t>();
if (m_vocab == nullptr) {
// vocab string keys
auto vocab_begins = inputs[3].data<const int32_t>();
auto vocab_ends = inputs[4].data<const int32_t>();
auto vocab_chars = inputs[5].data<const uint8_t>();

auto vocab_values = inputs[6].data<const int32_t>();
auto vocab_size = inputs[6].get_size();
auto vocab_values = inputs[6].data<const int32_t>();
auto vocab_size = inputs[6].get_size();

std::map<std::vector<uint8_t>, int32_t> vocab;
for (size_t i = 0; i < vocab_size; ++i) {
std::vector<uint8_t> token = std::vector<uint8_t>(vocab_chars + vocab_begins[i], vocab_chars + vocab_ends[i]);
vocab[token] = vocab_values[i];
};
m_vocab = std::make_shared<std::map<std::vector<unsigned char>, int32_t>>();
for (size_t i = 0; i < vocab_size; ++i) {
std::vector<uint8_t> token = std::vector<uint8_t>(vocab_chars + vocab_begins[i], vocab_chars + vocab_ends[i]);
m_vocab->insert(std::pair{token, vocab_values[i]});
};
}

auto default_value = *inputs[7].data<const int32_t>();
const size_t num_elements = inputs[0].get_size();
Expand All @@ -66,8 +62,8 @@ bool VocabEncoder::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i
auto token_ids = outputs[0].data<int32_t>();

for (size_t element_idx = 0; element_idx < num_elements; ++element_idx) {
auto element = vocab.find(std::vector<uint8_t>(chars + begins[element_idx], chars + ends[element_idx]));
if (element == vocab.end()) {
auto element = m_vocab->find(std::vector<uint8_t>(chars + begins[element_idx], chars + ends[element_idx]));
if (element == m_vocab->end()) {
token_ids[element_idx] = default_value;
} else {
token_ids[element_idx] = element->second;
Expand Down
19 changes: 13 additions & 6 deletions src/vocab_encoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,30 @@
#pragma once
#include <vector>
#include <openvino/op/op.hpp>
#include "openvino/opsets/opset13.hpp"

using namespace ov;
using namespace ov::opset13;


class VocabEncoder : public ov::op::Op {
public:
OPENVINO_OP("VocabEncoder");

VocabEncoder () = default;
VocabEncoder(
const ov::OutputVector& arguments
);

VocabEncoder(const ov::OutputVector& arguments) :
ov::op::Op(arguments) {
constructor_validate_and_infer_types();
}

VocabEncoder(const ov::OutputVector& arguments, std::shared_ptr<std::map<std::vector<unsigned char>, int32_t>> vocab) :
ov::op::Op(arguments), m_vocab(vocab) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override;

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return std::make_shared<VocabEncoder>(inputs);
return std::make_shared<VocabEncoder>(inputs, m_vocab);
}

bool visit_attributes(ov::AttributeVisitor& visitor) override {
Expand All @@ -35,4 +40,6 @@ class VocabEncoder : public ov::op::Op {
bool has_evaluate() const override {
return true;
}
private:
mutable std::shared_ptr<std::map<std::vector<unsigned char>, int32_t>> m_vocab;
};

0 comments on commit ad37623

Please sign in to comment.