Skip to content

Commit

Permalink
Improve segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Apr 26, 2022
1 parent 183463d commit 6e2f256
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name='starcc',
version='0.0.4',
version='0.0.5',
description='Python implementation of StarCC',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
81 changes: 42 additions & 39 deletions src/StarCC/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jieba
from os import path
from pygtrie import CharTrie
from typing import Callable, Optional, Sequence
from typing import Callable, List, Optional, Sequence

from .Dicts import Dicts

Expand Down Expand Up @@ -30,7 +30,7 @@ def _dicts2trie(dicts: str) -> CharTrie:

return trie

def _convert(trie: CharTrie, s: str) -> str:
def _convert_inner(trie: CharTrie, s: str) -> str:
results = []

total_len = len(s)
Expand All @@ -54,47 +54,54 @@ def _convert(trie: CharTrie, s: str) -> str:

return ''.join(results)

def _convert(tries: Sequence[CharTrie], s: str):
for trie in tries:
s = _convert_inner(trie, s)
return s

def _run_once(f):
is_executed = False
def wrapper():
nonlocal is_executed
if not is_executed:
f()
is_executed = True
return wrapper

@_run_once
def _jieba_add_words():
phrase_file = path.join(here, 'dict', 'STPhrases.txt')
with open(phrase_file, encoding='utf-8') as f:
for line in f:
line = line.rstrip('\n')
filenames = ['STPhrases', 'TWPhrasesIT', 'TWPhrasesName', 'TWPhrasesOther']
for filename in filenames:
phrase_file = path.join(here, 'dict', f'{filename}.txt')
with open(phrase_file, encoding='utf-8') as f:
for line in f:
line = line.rstrip('\n')

if line and not line.startswith('#'):
k, _ = line.split('\t')
jieba.add_word(k)
if line and not line.startswith('#'):
k, _ = line.split('\t')
jieba.add_word(k)

class Conversion:
def __init__(self, dicts_list: Sequence[str], seg_funcs: Optional[Sequence[Callable]]=None) -> None:
def __init__(self, dicts_list: Sequence[str], seg_func: Optional[Callable]=None) -> None:
self.tries = [_dicts2trie(dicts) for dicts in dicts_list]

if seg_funcs is None:
self.seg_funcs = [None for _ in dicts_list]
else:
if len(dicts_list) != len(seg_funcs):
raise ValueError('`seg_funcs` should either be `None`, or has the same length with `dicts_list`')
self.seg_funcs = seg_funcs
self.seg_func = (lambda x: [x]) if seg_func is None else seg_func

def __call__(self, s: str) -> str:
for trie, seg_func in zip(self.tries, self.seg_funcs):
if seg_func is None:
s = _convert(trie, s)
else:
results = []
for segment in seg_func(s):
segment = _convert(trie, segment)
results.append(segment)
s = ''.join(results)
return s
results = []
for seg in self.seg_func(s):
seg = _convert(self.tries, seg)
results.append(seg)
return ''.join(results)

class PresetConversion(Conversion):
def __init__(self, src: str='cn', dst: str='hk', with_phrase: bool=False, use_seg: bool=True) -> None:
'''
Initialize a `PresetConversion` object.
`use_seg` Whether to use an external segmentation tool (i.e. jieba) or not
when converting from Simplified to Traditional. If the conversion is not
from Simplified to Traditional, this option has no effect.
`use_seg` Whether to use an external segmentation tool (i.e. jieba) or not when
at least one of the following two conditions is satisfied: (1) converting from
Simplified Chinese; (2) converting to Traditional Chinese (Taiwan) with phrase
conversion. If the conditions are not meet, this option has no effect.
'''

if src not in ('st', 'cn', 'hk', 'tw', 'cnt', 'jp'):
Expand All @@ -104,7 +111,6 @@ def __init__(self, src: str='cn', dst: str='hk', with_phrase: bool=False, use_se
assert src != dst

dicts_list = []
seg_funcs = []

if src != 'st':
if not with_phrase:
Expand All @@ -123,12 +129,6 @@ def __init__(self, src: str='cn', dst: str='hk', with_phrase: bool=False, use_se
'tw': Dicts.TWP2ST,
}[src])

if src == 'cn' and use_seg:
_jieba_add_words()
seg_funcs.append(jieba.cut)
else:
seg_funcs.append(None)

if dst != 'st':
if not with_phrase:
dicts_list.append({
Expand All @@ -146,6 +146,9 @@ def __init__(self, src: str='cn', dst: str='hk', with_phrase: bool=False, use_se
'tw': Dicts.ST2TWP,
}[dst])

seg_funcs.append(None)
use_seg_func = use_seg and (src == 'cn' or dst == 'tw' and with_phrase)
if use_seg_func:
_jieba_add_words()
seg_func = None if not use_seg_func else jieba.cut

super().__init__(dicts_list, seg_funcs)
super().__init__(dicts_list, seg_func)

0 comments on commit 6e2f256

Please sign in to comment.