From b3f2f4c1d9fcaec10ce66e02bd9727da7a4ca52a Mon Sep 17 00:00:00 2001
From: chengchingwen
Date: Wed, 12 Jan 2022 02:56:15 +0800
Subject: [PATCH] add match tokenization

---
 src/TextEncodeBase.jl |  5 ++--
 src/base.jl           | 19 +++++++++---
 src/match.jl          | 54 +++++++++++++++++++++++++++++++++
 src/tkrs.jl           | 17 +++++++++++
 src/utils.jl          | 70 +++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 159 insertions(+), 6 deletions(-)
 create mode 100644 src/match.jl
 create mode 100644 src/utils.jl

diff --git a/src/TextEncodeBase.jl b/src/TextEncodeBase.jl
index 0db125b..0c195f8 100644
--- a/src/TextEncodeBase.jl
+++ b/src/TextEncodeBase.jl
@@ -30,9 +30,10 @@ struct DefaultTokenization <: AbstractTokenization end
 
 tokenization(::AbstractTokenizer) = DefaultTokenization()
 
+include("./utils.jl")
 include("./base.jl")
 include("./indexed.jl")
-# include("./match.jl")
-include("tkrs.jl")
+include("./match.jl")
+include("./tkrs.jl")
 
 end
diff --git a/src/base.jl b/src/base.jl
index 4ed6361..1da5d47 100644
--- a/src/base.jl
+++ b/src/base.jl
@@ -39,8 +39,16 @@ Document(x) = Document(x, nothing)
 Sentence(x) = Sentence(x, nothing)
 SubSentence(x) = SubSentence(x, nothing)
 Word(x) = Word(x, nothing)
+SubWord(x) = SubWord(x, nothing)
 Token(x) = Token(x, nothing)
 
+updatemeta(x::Document, meta) = Document(x.x, meta)
+updatemeta(x::Sentence, meta) = Sentence(x.x, meta)
+updatemeta(x::SubSentence, meta) = SubSentence(x.x, meta)
+updatemeta(x::Word, meta) = Word(x.x, meta)
+updatemeta(x::SubWord, meta) = SubWord(x.x, meta)
+updatemeta(x::Token, meta) = Token(x.x, meta)
+
 function Base.show(io::IO, t::TokenStages)
     print(io, typeof(t).name.name)
     vs = filter(!isnothing, ntuple(i->getfield(t, i), fieldcount(typeof(t))))
@@ -70,13 +78,16 @@ let ATR = AbstractTokenizer, AT = AbstractTokenization
     # [tokenization dispatch] default behavior on specific stages, mark the splitting result for further tokenization
     global @inline tokenize(::AT, ::DocumentStage, x) = Sentence(x)
     global @inline tokenize(::AT, ::SentenceStage, x) = Token(x)
-    # [tokenization dispatch] skip if splitting result is already wrapped
+    global @inline tokenize(::AT, ::SubSentenceStage, x) = Token(x)
+    # [tokenization dispatch] default skip if splitting result is already wrapped
     global @inline tokenize(::AT, ::TokenStages, x::TokenStages) = x
     # [full dispatch, default to ignore tokenizer] the outer-most api, but these stages are usually unsplittable
-    global @inline tokenize(::ATR, ::AT, w::WordStage) = [Token(w.x)]
-    global @inline tokenize(::ATR, ::AT, w::SubWordStage) = [Token(w.x)]
-    global @inline tokenize(::ATR, ::AT, t::TokenStage) = [t]
+    global @inline tokenize(tkr::ATR, t::AT, s::Union{WordStage, SubWordStage, TokenStage}) = tokenize(t, s)
+    # [tokenization dispatch] default behavior of unsplittable type
+    global @inline tokenize(::AT, w::WordStage) = [Token(w.x)]
+    global @inline tokenize(::AT, w::SubWordStage) = [Token(w.x)]
+    global @inline tokenize(::AT, t::TokenStage) = [t]
     # [full dispatch] the outer-most api, splitting input and recursively tokenize the result. ignore if input is empty
     global @inline tokenize(tkr::ATR, t::AT, x::TokenStages) = tokenize_procedure(tkr, t, x)
 end
 
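To make the reworked dispatch concrete: a minimal sketch of how these methods compose, using only signatures visible in this patch (a hypothetical REPL session, not part of the commit; `NaiveTokenizer` is defined in src/tkrs.jl):

    using TextEncodeBase: tokenize, updatemeta, DefaultTokenization,
        Document, Word, Token, NaiveTokenizer

    tkr = NaiveTokenizer()

    # full dispatch: a Document is split into Sentences and each Sentence
    # into Tokens via tokenize_procedure
    tokenize(tkr, DefaultTokenization(), Document("hello world."))

    # unsplittable stages now route through the new tokenizer-free methods:
    # a Word collapses to a one-element Vector of Token
    tokenize(tkr, DefaultTokenization(), Word("hello"))   # == [Token("hello")]

    # updatemeta replaces a stage's metadata while keeping its wrapped value
    updatemeta(Token("hello"), (token_id = 1,))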
diff --git a/src/match.jl b/src/match.jl
new file mode 100644
index 0000000..f4d3d5d
--- /dev/null
+++ b/src/match.jl
@@ -0,0 +1,54 @@
+struct MatchTokenization <: AbstractTokenization
+    patterns::Vector{Regex}
+end
+
+splitting(t::MatchTokenization, s::SentenceStage) = collect(Tuple{Bool, SubString}, matchsplits(t.patterns, s.x))
+
+@inline tokenize(t::MatchTokenization, s::SentenceStage, (istoken, x)) = istoken ? Token(x, s.meta) : SubSentence(x, s.meta)
+
+
+struct IndexedMatchTokenization <: AbstractTokenization
+    patterns::Vector{Regex}
+end
+
+@inline splitting(t::IndexedMatchTokenization, s::SentenceStage) = splitting(MatchTokenization(t.patterns), s)
+@inline splitting(::IndexedMatchTokenization, s::TokenStages, x) = splitting(IndexedTokenization(), s, x)
+
+function splitting(::IndexedMatchTokenization, s::SubSentenceStage, x)
+    lastid = length(x)
+    !isnothing(s.meta.rsibling) && (s.meta.rsibling[] = lastid + s.meta.offset[])
+    return enumerate(x)
+end
+
+function splitting(::IndexedMatchTokenization, s::SentenceStage, x)
+    tokenoffset = map(Base.RefValue, 0:length(x)-1)
+    RV = Base.RefValue{Int}
+    v = Tuple{RV, Tuple{Bool, SubString}, Union{RV, Nothing}}[]
+    for ((i, sp), offset) in zip(enumerate(x), tokenoffset)
+        push!(v, (offset, sp, i == lastindex(x) ? nothing : tokenoffset[i+1]))
+    end
+    return v
+end
+
+function tokenize(::IndexedMatchTokenization, s::SentenceStage, (offset, (istoken, x), rsibling))
+    meta = merge(s.meta, (offset = offset, rsibling = rsibling))
+    return istoken ? Token(x, meta) : SubSentence(x, meta)
+end
+
+@inline tokenize(::IndexedMatchTokenization, d::DocumentStage, x) = tokenize(IndexedTokenization(), d, x)
+
+function tokenize(::IndexedMatchTokenization, s::SubSentenceStage, (i, x))
+    offset = s.meta.offset[]
+    meta = Base.structdiff(s.meta, NamedTuple{(:offset, :rsibling)})
+    return Token(x, merge(meta, (token_id = i+offset,)))
+end
+
+function tokenize(::IndexedMatchTokenization, x::TokenStage)
+    if haskey(x.meta, :offset) && haskey(x.meta, :rsibling)
+        cid = x.meta.offset[]+1
+        x.meta.rsibling[] = cid
+        meta = Base.structdiff(x.meta, NamedTuple{(:offset, :rsibling)})
+        return [updatemeta(x, merge(meta, (token_id = cid,)))]
+    end
+    return [x]
+end
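The intent of src/match.jl in one example: spans matching any pattern are frozen into Tokens up front, and only the text between matches is re-split. A sketch (hypothetical usage, not part of the commit; it relies on the NaiveMatchTokenizer wrapper added in the src/tkrs.jl hunk below, and on the package's default word splitting for the rest):

    using TextEncodeBase: tokenize, tokenization, Sentence, NaiveMatchTokenizer

    tkr = NaiveMatchTokenizer([r"\d+"])
    tokenize(tkr, tokenization(tkr), Sentence("pay 42 dollars"))
    # expected: "42" survives as a single Token because it matched a pattern,
    # while "pay" and "dollars" are produced via the SubSentence fallback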
diff --git a/src/tkrs.jl b/src/tkrs.jl
index f0d39e6..0087cb0 100644
--- a/src/tkrs.jl
+++ b/src/tkrs.jl
@@ -1,3 +1,8 @@
+struct MixedTokenization{T <: Tuple} <: AbstractTokenization
+    ts::T
+end
+MixedTokenization(ts...) = MixedTokenization(ts)
+
 "tokenizer that run the default behavior"
 struct NaiveTokenizer <: AbstractTokenizer end
 
@@ -5,3 +10,15 @@
 struct NaiveIndexedTokenizer <: AbstractTokenizer end
 tokenization(::NaiveIndexedTokenizer) = IndexedTokenization()
 
+"default behavior but don't split some pattern"
+struct NaiveMatchTokenizer <: AbstractTokenizer
+    patterns::Vector{Regex}
+end
+tokenization(tkr::NaiveMatchTokenizer) = MatchTokenization(tkr.patterns)
+
+"default behavior but counting index and don't split some pattern"
+struct NaiveIndexedMatchTokenizer <: AbstractTokenizer
+    patterns::Vector{Regex}
+end
+tokenization(tkr::NaiveIndexedMatchTokenizer) = IndexedMatchTokenization(tkr.patterns)
+
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..666dde7
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,70 @@
+# match utils
+
+struct MatchSplitIterator
+    t::Regex
+    s::Union{String, SubString}
+end
+Base.eltype(::Type{MatchSplitIterator}) = Tuple{Bool, SubString}
+Base.IteratorSize(::Type{MatchSplitIterator}) = Base.SizeUnknown()
+
+function Base.iterate(itr::MatchSplitIterator, (r, i, e) = (nothing, firstindex(itr.s), lastindex(itr.s)))
+    i > e && return nothing
+    t, s = itr.t, itr.s
+    if !isnothing(r)
+        ri, re = first(r), last(r)
+        j = isempty(r) ? first(r) : last(r)
+        v = (true, SubString(s, ri, re))
+        return v, j > e ? (nothing, i, -1) : (nothing, @inbounds(nextind(s, j)), e)
+    end
+
+    r = findnext(itr.t, itr.s, i)
+    if isnothing(r)
+        return (false, SubString(s, i, e)), (nothing, i, -1)
+    end
+
+    ri, re = first(r), last(r)
+    if i != ri
+        return (false, SubString(s, i, @inbounds(prevind(s, ri)))), (r, i, e)
+    else
+        j = isempty(r) ? first(r) : last(r)
+        v = (true, SubString(s, ri, re))
+        return v, j > e ? (nothing, i, -1) : (nothing, @inbounds(nextind(s, j)), e)
+    end
+    nothing
+end
+
+matchsplit(t, s) = matchsplit!(Tuple{Bool, SubString}[], t, s)
+function matchsplit!(found, t, s)
+    i, e = firstindex(s), lastindex(s)
+
+    while true
+        r = findnext(t, s, i)
+        if isnothing(r)
+            push!(found, (false, SubString(s, i, e)))
+            break
+        end
+
+        ri, re = first(r), last(r)
+        i != ri && push!(found, (false, @inbounds SubString(s, i, prevind(s, ri))))
+        push!(found, (true, SubString(s, ri, re)))
+
+        j = isempty(r) ? first(r) : last(r)
+        j > e && break
+        @inbounds i = nextind(s, j)
+        i > e && break
+    end
+    return found
+end
+
+function matchsplits(patterns, x)
+    m, ms = first(patterns), @view patterns[2:end]
+    sp = MatchSplitIterator(m, x)
+
+    for m in ms
+        iters = Iterators.map(sp) do (istoken, s)
+            istoken ? ((istoken, s) for _ = 1:1) : MatchSplitIterator(m, s)
+        end
+        sp = Iterators.Flatten(iters)
+    end
+    return sp
+end
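Since the utils are self-contained, their contract is easy to pin down with a small example (a sketch; the commented results follow from tracing the iterator protocol above, they are not captured output):

    collect(MatchSplitIterator(r"\d+", "abc123def"))
    # [(false, "abc"), (true, "123"), (false, "def")]

    # matchsplits chains iterators, so each later pattern only re-splits the
    # still-unmatched (false) pieces left by the earlier ones
    collect(matchsplits([r"\d+", r"[a-z]+"], "abc123DEF456"))
    # [(true, "abc"), (true, "123"), (false, "DEF"), (true, "456")]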