Allow string literals as segmentation modes #245

Merged · 4 commits · Mar 26, 2024
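The change, in short: everywhere the Python API took a `sudachipy.SplitMode` (`Dictionary.create`, `Dictionary.pre_tokenizer`, `Tokenizer.tokenize`, `Morpheme.split`), the string literals `"A"`, `"B"`, and `"C"` are now accepted too. A minimal sketch of the new surface, assuming a system dictionary such as `sudachidict_core` is installed:

```python
from sudachipy import Dictionary, SplitMode

sudachi_dict = Dictionary()                # assumes sudachidict_core is installed
tokenizer = sudachi_dict.create(mode="A")  # string literal instead of SplitMode.A
print([m.surface() for m in tokenizer.tokenize("外国人参政権")])

# the old spelling keeps working
tokenizer_c = sudachi_dict.create(mode=SplitMode.C)
```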
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ README*.html
python/dist/
__pycache__/
.env
+.venv
*.egg-info
*.so
python/py_src/sudachipy/*.pyd
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions python/Cargo.toml
@@ -15,6 +15,7 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
thread_local = "1.1" # Apache 2.0/MIT
+scopeguard = "1" # Apache 2.0/MIT

[dependencies.sudachi]
path = "../sudachi"
25 changes: 19 additions & 6 deletions python/py_src/sudachipy/sudachipy.pyi
@@ -1,5 +1,5 @@
from typing import ClassVar, Iterator, List, Tuple, Union, Callable, Iterable, Optional, Literal, Set
-from sudachipy.config import Config
+from .config import Config

POS = Tuple[str, str, str, str, str, str]
# POS element
@@ -32,7 +32,12 @@ class SplitMode:
    B: ClassVar[SplitMode] = ...
    C: ClassVar[SplitMode] = ...
    @classmethod
-    def __init__(cls) -> None: ...
+    def __init__(cls, mode: str = "C") -> None:
+        """
+        Creates a split mode from a string value
+        :param mode: string representation of the split mode
+        """
+        ...


class Dictionary:
@@ -65,7 +70,7 @@ class Dictionary:
        ...

    def create(self,
-               mode: SplitMode = SplitMode.C,
+               mode: Union[SplitMode, Literal["A", "B", "C"]] = SplitMode.C,
               fields: FieldSet = None,
               *,
               projection: str = None) -> Tokenizer:
@@ -96,7 +101,7 @@ class Dictionary:
        ...

    def pre_tokenizer(self,
-                      mode: SplitMode = SplitMode.C,
+                      mode: Union[SplitMode, Literal["A", "B", "C"]] = "C",
                      fields: FieldSet = None,
                      handler: Optional[Callable[[int, object, MorphemeList], list]] = None,
                      *,
@@ -191,7 +196,7 @@ class Morpheme:
        """
        ...

-    def split(self, mode: SplitMode, out: Optional[MorphemeList] = None, add_single: bool = True) -> MorphemeList:
+    def split(self, mode: Union[SplitMode, Literal["A", "B", "C"]], out: Optional[MorphemeList] = None, add_single: bool = True) -> MorphemeList:
        """
        Returns sub-morphemes in the provided split mode.

@@ -278,7 +283,7 @@ class Tokenizer:
    def __init__(cls) -> None: ...

    def tokenize(self, text: str,
-                 mode: SplitMode = ...,
+                 mode: Union[SplitMode, Literal["A", "B", "C"]] = ...,
                 out: Optional[MorphemeList] = None) -> MorphemeList:
        """
        Break text into morphemes.
@@ -295,6 +300,14 @@
        """
        ...

+    @property
+    def mode(self) -> SplitMode:
+        """
+        Get the current analysis mode
+        :return: current analysis mode
+        """
+        ...


class WordInfo:
    a_unit_split: ClassVar[List[int]] = ...
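The stub above also documents the new `SplitMode` constructor, which builds a mode from a string and defaults to `C`. A small sketch of what this enables; the equality checks assume pyo3's comparison support for the simple `SplitMode` enum defined in tokenizer.rs below:

```python
from sudachipy import SplitMode

assert SplitMode("A") == SplitMode.A  # constructed from a string literal
assert SplitMode() == SplitMode.C     # no argument falls back to C
SplitMode("X")                        # raises: Mode::from_str rejects unknown strings
```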
39 changes: 30 additions & 9 deletions python/src/dictionary.rs
@@ -20,7 +20,9 @@ use std::convert::TryFrom;
use std::fmt::Write;
use std::ops::Deref;
use std::path::{Path, PathBuf};
+use std::str::FromStr;
use std::sync::Arc;
+use sudachi::analysis::Mode;

use crate::errors::{wrap, wrap_ctx, SudachiError as SudachiErr};
use sudachi::analysis::stateless_tokenizer::DictionaryAccess;
@@ -218,16 +220,20 @@ impl PyDictionary {
    /// :param fields: load only a subset of fields.
    ///     See https://worksapplications.github.io/sudachi.rs/python/topics/subsetting.html
    #[pyo3(
-        text_signature = "($self, mode: sudachipy.SplitMode = sudachipy.SplitMode.C) -> sudachipy.Tokenizer",
+        text_signature = "($self, mode = 'C') -> sudachipy.Tokenizer",
        signature = (mode = None, fields = None, *, projection = None)
    )]
-    fn create(
-        &self,
-        mode: Option<PySplitMode>,
-        fields: Option<&PySet>,
-        projection: Option<&PyString>,
+    fn create<'py>(
+        &'py self,
+        py: Python<'py>,
+        mode: Option<&'py PyAny>,
+        fields: Option<&'py PySet>,
+        projection: Option<&'py PyString>,
    ) -> PyResult<PyTokenizer> {
-        let mode = mode.unwrap_or(PySplitMode::C).into();
+        let mode = match mode {
+            Some(m) => extract_mode(py, m)?,
+            None => Mode::C,
+        };
        let fields = parse_field_subset(fields)?;
        let mut required_fields = self.config.projection.required_subset();
        let dict = self.dictionary.as_ref().unwrap().clone();

@@ -283,12 +289,15 @@ impl PyDictionary {
    fn pre_tokenizer<'p>(
        &'p self,
        py: Python<'p>,
-        mode: Option<PySplitMode>,
+        mode: Option<&PyAny>,
        fields: Option<&PySet>,
        handler: Option<Py<PyAny>>,
        projection: Option<&PyString>,
    ) -> PyResult<&'p PyAny> {
-        let mode = mode.unwrap_or(PySplitMode::C).into();
+        let mode = match mode {
+            Some(m) => extract_mode(py, m)?,
+            None => Mode::C,
+        };
        let subset = parse_field_subset(fields)?;
        if let Some(h) = handler.as_ref() {
            if !h.as_ref(py).is_callable() {

@@ -401,6 +410,18 @@ fn config_repr(cfg: &Config) -> Result<String, std::fmt::Error> {
    Ok(result)
}

+pub(crate) fn extract_mode<'py>(py: Python<'py>, mode: &'py PyAny) -> PyResult<Mode> {
+    if mode.is_instance_of::<PyString>() {
+        let mode = mode.str()?.to_str()?;
+        Mode::from_str(mode).map_err(|e| SudachiErr::new_err(e).into())
+    } else if mode.is_instance_of::<PySplitMode>() {
+        let mode = mode.extract::<PySplitMode>()?;
+        Ok(Mode::from(mode))
+    } else {
+        Err(SudachiErr::new_err(("unknown mode", mode.into_py(py))))
+    }
+}
+
fn read_config_from_fs(path: Option<&Path>) -> PyResult<ConfigBuilder> {
    wrap(ConfigBuilder::from_opt_file(path))
}
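`extract_mode` is the single dispatch point that makes both spellings interchangeable: strings go through `Mode::from_str`, `SplitMode` instances convert via `From`, and anything else is rejected with an "unknown mode" error. A hedged sketch of the resulting Python-side behavior, assuming an installed system dictionary:

```python
from sudachipy import Dictionary, SplitMode

sudachi_dict = Dictionary()            # assumes a system dictionary
sudachi_dict.create(mode="B")          # str  -> Mode::from_str
sudachi_dict.create(mode=SplitMode.B)  # enum -> Mode::from
try:
    sudachi_dict.create(mode=42)       # neither str nor SplitMode
except Exception as err:
    print(err)                         # ('unknown mode', 42)
```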
9 changes: 5 additions & 4 deletions python/src/morpheme.rs
@@ -24,9 +24,8 @@ use pyo3::types::{PyList, PyString, PyTuple, PyType};

use sudachi::prelude::{Morpheme, MorphemeList};

-use crate::dictionary::{PyDicData, PyDictionary};
+use crate::dictionary::{extract_mode, PyDicData, PyDictionary};
use crate::projection::MorphemeProjection;
-use crate::tokenizer::PySplitMode;
use crate::word_info::PyWordInfo;

pub(crate) type PyMorphemeList = MorphemeList<Arc<PyDicData>>;
@@ -362,12 +361,14 @@ impl PyMorpheme {
    fn split<'py>(
        &'py self,
        py: Python<'py>,
-        mode: PySplitMode,
+        mode: &PyAny,
        out: Option<&'py PyCell<PyMorphemeListWrapper>>,
        add_single: Option<bool>,
    ) -> PyResult<&'py PyCell<PyMorphemeListWrapper>> {
        let list = self.list(py);

+        let mode = extract_mode(py, mode)?;
+
        let out_cell = match out {
            None => {
                let list = list.empty_clone(py);

@@ -385,7 +386,7 @@
        out_ref.clear();
        let splitted = list
            .internal(py)
-            .split_into(mode.into(), self.index, out_ref)
+            .split_into(mode, self.index, out_ref)
            .map_err(|e| {
                PyException::new_err(format!("Error while splitting morpheme: {}", e.to_string()))
            })?;
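With `extract_mode` wired into `Morpheme.split`, re-splitting a morpheme takes the same string literals. A short sketch; the expected surfaces mirror the 外国/人/参政/権 split exercised by the pre-tokenizer test below, assuming the core dictionary:

```python
from sudachipy import Dictionary

tokenizer = Dictionary().create(mode="C")  # assumes a system dictionary
for morpheme in tokenizer.tokenize("外国人参政権"):
    # split a long-mode unit into short-mode units: 外国 / 人 / 参政 / 権
    print([m.surface() for m in morpheme.split("A")])
```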
100 changes: 64 additions & 36 deletions python/src/tokenizer.rs
@@ -14,17 +14,19 @@
 * limitations under the License.
 */

+use std::ops::DerefMut;
+use std::str::FromStr;
use std::sync::Arc;

-use pyo3::exceptions::PyException;
use pyo3::prelude::*;

use sudachi::analysis::stateful_tokenizer::StatefulTokenizer;

use sudachi::dic::subset::InfoSubset;
use sudachi::prelude::*;

-use crate::dictionary::PyDicData;
+use crate::dictionary::{extract_mode, PyDicData};
+use crate::errors::SudachiError as SudachiPyErr;
use crate::morpheme::{PyMorphemeListWrapper, PyProjector};
/// Unit to split text
@@ -35,33 +37,47 @@ use crate::morpheme::{PyMorphemeListWrapper, PyProjector};
///
/// C == long mode
-//
-// This implementation is a workaround. Waiting for the pyo3 enum feature.
-// ref: [PyO3 issue #834](https://github.com/PyO3/pyo3/issues/834).
-#[pyclass(module = "sudachipy.tokenizer", name = "SplitMode")]
-#[derive(Clone, PartialEq, Eq)]
-#[repr(transparent)]
-pub struct PySplitMode {
-    mode: u8,
-}
-
-#[pymethods]
-impl PySplitMode {
-    #[classattr]
-    pub const A: Self = Self { mode: 0 };
-
-    #[classattr]
-    pub const B: Self = Self { mode: 1 };
-
-    #[classattr]
-    pub const C: Self = Self { mode: 2 };
+#[pyclass(module = "sudachipy.tokenizer", name = "SplitMode", frozen)]
+#[derive(Clone, PartialEq, Eq, Copy, Debug)]
+#[repr(u8)]
+pub enum PySplitMode {
+    A,
+    B,
+    C,
}

impl From<PySplitMode> for Mode {
    fn from(mode: PySplitMode) -> Self {
        match mode {
            PySplitMode::A => Mode::A,
            PySplitMode::B => Mode::B,
-            _ => Mode::C,
+            PySplitMode::C => Mode::C,
        }
    }
}

+impl From<Mode> for PySplitMode {
+    fn from(value: Mode) -> Self {
+        match value {
+            Mode::A => PySplitMode::A,
+            Mode::B => PySplitMode::B,
+            Mode::C => PySplitMode::C,
+        }
+    }
+}
+
+#[pymethods]
+impl PySplitMode {
+    #[new]
+    fn new(mode: Option<&str>) -> PyResult<PySplitMode> {
+        let mode = match mode {
+            Some(m) => m,
+            None => return Ok(PySplitMode::C),
+        };
+
+        match Mode::from_str(mode) {
+            Ok(m) => Ok(m.into()),
+            Err(e) => Err(SudachiPyErr::new_err(e.to_string())),
+        }
+    }
+}
@@ -112,29 +128,39 @@ impl PyTokenizer {
    /// :type mode: sudachipy.SplitMode
    /// :type out: sudachipy.MorphemeList
    #[pyo3(
-        text_signature = "($self, text: str, mode: SplitMode = None, logger = None, out = None) -> sudachipy.MorphemeList",
+        text_signature = "($self, text: str, mode = None, logger = None, out = None) -> sudachipy.MorphemeList",
        signature = (text, mode = None, logger = None, out = None)
    )]
    #[allow(unused_variables)]
    fn tokenize<'py>(
        &'py mut self,
        py: Python<'py>,
        text: &'py str,
-        mode: Option<PySplitMode>,
+        mode: Option<&PyAny>,
        logger: Option<PyObject>,
        out: Option<&'py PyCell<PyMorphemeListWrapper>>,
    ) -> PyResult<&'py PyCell<PyMorphemeListWrapper>> {
-        // keep default mode to restore later
+        // restore default mode on scope exit
+        let mode = match mode {
+            None => None,
+            Some(m) => Some(extract_mode(py, m)?),
+        };
        let default_mode = mode.map(|m| self.tokenizer.set_mode(m.into()));
+        let mut tokenizer = scopeguard::guard(&mut self.tokenizer, |t| {
+            default_mode.map(|m| t.set_mode(m));
+        });

+        // analysis can be done without GIL
+        let err = py.allow_threads(|| {
+            tokenizer.reset().push_str(text);
+            tokenizer.do_tokenize()
+        });

-        self.tokenizer.reset().push_str(text);
-        self.tokenizer
-            .do_tokenize()
-            .map_err(|e| PyException::new_err(format!("Tokenization error: {}", e.to_string())))?;
+        err.map_err(|e| SudachiPyErr::new_err(format!("Tokenization error: {}", e.to_string())))?;

        let out_list = match out {
            None => {
-                let dict = self.tokenizer.dict_clone();
+                let dict = tokenizer.dict_clone();
                let morphemes = MorphemeList::empty(dict);
                let wrapper =
                    PyMorphemeListWrapper::from_components(morphemes, self.projection.clone());

@@ -146,16 +172,18 @@
        let mut borrow = out_list.try_borrow_mut();
        let morphemes = match borrow {
            Ok(ref mut ms) => ms.internal_mut(py),
-            Err(e) => return Err(PyException::new_err("out was used twice at the same time")),
+            Err(e) => return Err(SudachiPyErr::new_err("out was used twice at the same time")),
        };

        morphemes
-            .collect_results(&mut self.tokenizer)
-            .map_err(|e| PyException::new_err(format!("Tokenization error: {}", e.to_string())))?;
-
-        // restore default mode
-        default_mode.map(|m| self.tokenizer.set_mode(m));
+            .collect_results(tokenizer.deref_mut())
+            .map_err(|e| SudachiPyErr::new_err(format!("Tokenization error: {}", e.to_string())))?;

        Ok(out_list)
    }

+    #[getter]
+    fn mode(&self) -> PySplitMode {
+        self.tokenizer.mode().into()
+    }
}
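Beyond accepting `&PyAny` for the mode, this rewrite changes two behaviors worth noting: a per-call `mode` override is now restored even when tokenization fails (via the `scopeguard` above, rather than a manual reset on the success path only), and the analysis loop runs inside `py.allow_threads`, releasing the GIL. A sketch of the per-call override together with the new `mode` getter, assuming a system dictionary:

```python
from sudachipy import Dictionary

tokenizer = Dictionary().create(mode="C")    # assumes a system dictionary
print(tokenizer.mode)                        # SplitMode.C

tokenizer.tokenize("外国人参政権", mode="A")  # one-off override via string literal
print(tokenizer.mode)                        # still SplitMode.C afterwards
```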
15 changes: 15 additions & 0 deletions python/tests/test_pretokenizers.py
@@ -58,6 +58,21 @@ def test_works_with_different_split_mode(self):
        res = tok.encode("外国人参政権")
        self.assertEqual(res.ids, [1, 5, 2, 3])

+    def test_works_with_different_split_mode_str(self):
+        pretok = self.dict.pre_tokenizer(mode='A')
+        vocab = {
+            "[UNK]": 0,
+            "外国": 1,
+            "参政": 2,
+            "権": 3,
+            "人": 5,
+            "外国人参政権": 4
+        }
+        tok = tokenizers.Tokenizer(WordLevel(vocab, unk_token="[UNK]"))
+        tok.pre_tokenizer = pretok
+        res = tok.encode("外国人参政権")
+        self.assertEqual(res.ids, [1, 5, 2, 3])

    def test_with_handler(self):
        def _handler(index, sentence: tokenizers.NormalizedString, ml: MorphemeList):
            return [tokenizers.NormalizedString(ml[0].part_of_speech()[0]), tokenizers.NormalizedString(str(len(ml)))]