Skip to content

Commit

Permalink
Sequence template (#5)
Browse files Browse the repository at this point in the history
* init impl for sequence template

* add dynamic type id

* update to 0.5.9

* docstring for template term
  • Loading branch information
chengchingwen authored Sep 18, 2022
1 parent 77e6d51 commit 98c5e6e
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TextEncodeBase"
uuid = "f92c20c0-9f2a-4705-8116-881385faba05"
authors = ["chengchingwen <[email protected]> and contributors"]
version = "0.5.8"
version = "0.5.9"

[deps]
FuncPipelines = "9ed96fbb-10b6-44d4-99a6-7e2a3dc8861b"
Expand Down
218 changes: 218 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,221 @@ function _nested2batch!(arr, offset, x::AbstractArray{>:AbstractArray})
error("Input array is mixing array and non-array elements")
end
end

# Sequence template

"""
    abstract type TemplateTerm{T} end

Abstract supertype of every term that can be used in a [`SequenceTemplate`](@ref). The type parameter `T`
is the element type of the sequence the term contributes to.
"""
abstract type TemplateTerm{T} end

"""
InputTerm{T}(type_id = 1)
A `TemplateTerm` that take out a sequence from the input.
"""
struct InputTerm{T} <: TemplateTerm{T}
type_id::Int
InputTerm{T}(type_id = 1) where T = new{T}(type_id)
end

"""
IndexInputTerm{T}(idx::Int, type_id = 1)
A `TemplateTerm` that take the `idx`-th sequence of the input. If the `IndexInputTerm` is also the `idx`-th
input related term in a [`SequenceTemplate`](@ref), it behave the same as [`InputTerm`](@ref).
"""
struct IndexInputTerm{T} <: TemplateTerm{T}
idx::Int
type_id::Int
IndexInputTerm{T}(idx, type_id = 1) where T = new{T}(idx, type_id)
end

"""
ConstTerm(value::T, type_id = 1)
A `TemplateTerm` that simply put `value` to the output sequence.
"""
struct ConstTerm{T} <: TemplateTerm{T}
value::T
type_id::Int
end
ConstTerm(value, type_id = 1) = ConstTerm{typeof(value)}(value, type_id)

"""
RepeatedTerm(terms::TemplateTerm...; dynamic_type_id = false)
A special term that indicate the `terms` sequence can appear zero or multiple times. Cannot be nested.
If `dynamic_type_id` is set, each repeat would add an offset value to the type id of those repeat `terms`.
The offset value if the number of repetiton, starting form `0`, times `dynamic_type_id`.
"""
struct RepeatedTerm{T, Ts<:Tuple{Vararg{TemplateTerm{T}}}} <: TemplateTerm{T}
terms::Ts
dynamic_type_id::Int
function RepeatedTerm(terms::Tuple{Vararg{TemplateTerm{T}}}, dynamic_type_id = false) where T
@assert length(terms) >= 1 "No TemplateTerm provided."
@assert !any(Base.Fix2(isa, RepeatedTerm), terms) "Cannot nest RepeatedTerm"
return new{T, typeof(terms)}(terms, dynamic_type_id)
end
end
RepeatedTerm(terms::TemplateTerm...; dynamic_type_id = false) = RepeatedTerm(terms, dynamic_type_id)

"""
SequenceTemplate(terms::TemplateTerm)(sequences...)
Constructing a function by multiple `TemplateTerm` that indicate how to combine the input `sequences`. Return
a tuple of the result sequence and a type id (a special number associated with the template term) sequence.
# Example
```julia-repl
julia> SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))(1:5)[1] == TextEncodeBase.with_head_tail(1:5, -1, -2)
true
julia> SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))(1:5)
([-1, 1, 2, 3, 4, 5, -2], [1, 1, 1, 1, 1, 1, 1])
julia> bert_template = SequenceTemplate(
ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2))
)
SequenceTemplate{String}([CLS]:<type=1> Input:<type=1> [SEP]:<type=1> (Input:<type=2> [SEP]:<type=2>)...)
julia> bert_template(["hello", "world"])
(["[CLS]", "hello", "world", "[SEP]"], [1, 1, 1, 1])
julia> bert_template(["hello", "world"], ["today", "is", "a", "good", "day"])
(["[CLS]", "hello", "world", "[SEP]", "today", "is", "a", "good", "day", "[SEP]"], [1, 1, 1, 1, 2, 2, 2, 2, 2, 2])
```
"""
struct SequenceTemplate{T, Ts<:Tuple{Vararg{TemplateTerm{T}}}} <: Function
terms::Ts
function SequenceTemplate(terms::Tuple{Vararg{TemplateTerm{T}}}) where T
@assert length(terms) >= 1 "No TemplateTerm provided."
@assert count(Base.Fix2(isa, RepeatedTerm), terms) <= 1 "RepeatedTerm can only appear at most once."
return new{T, typeof(terms)}(terms)
end
end
SequenceTemplate(terms::TemplateTerm...) = SequenceTemplate(terms)

# Consume the `j`-th input sequence: append it to `output`, append one `term.type_id` per element to
# `type_ids`, and return the index of the next input to consume.
function process_term!(term::InputTerm, output, type_ids, i, j, terms, xs)
    @assert j <= length(xs) "InputTerm indexing $j-th input but only get $(length(xs))"
    seq = xs[j]
    append!(output, seq)
    append!(type_ids, fill(term.type_id, length(seq)))
    return j + 1
end

# Append the `term.idx`-th input sequence to `output` with one `term.type_id` per element. The input cursor
# `j` only advances when the indexed input is exactly the one the cursor points at.
function process_term!(term::IndexInputTerm, output, type_ids, i, j, terms, xs)
    idx = term.idx
    @assert idx <= length(xs) "IndexInputTerm indexing $idx-th input but only get $(length(xs))"
    seq = xs[idx]
    append!(output, seq)
    append!(type_ids, fill(term.type_id, length(seq)))
    if idx == j
        return j + 1
    end
    return j
end

# Emit the constant value together with its type id; no input is consumed, so `j` is returned unchanged.
function process_term!(term::ConstTerm, output, type_ids, i, j, terms, xs)
    value, type_id = term.value, term.type_id
    push!(output, value)
    push!(type_ids, type_id)
    return j
end

# Apply the repeated sub-terms until the input cursor `j` passes the last input available to this term.
# Inputs needed by the `InputTerm`s that follow this term (position `i` in `terms`) are left untouched.
# Returns the updated input cursor.
function process_term!(term::RepeatedTerm, output, type_ids, i, j, terms, xs)
    inner_terms = term.terms
    step = term.dynamic_type_id
    # Number of inputs reserved for the `InputTerm`s after this term.
    reserved = count(t -> t isa InputTerm, terms[i+1:end])
    last_input = length(xs) - reserved
    offset = 0
    while j <= last_input
        first_new = length(type_ids) + 1
        j_before = j
        for (k, inner) in enumerate(inner_terms)
            j = process_term!(inner, output, type_ids, k, j, inner_terms, xs)
        end
        # A repetition that consumes no input would loop forever.
        j_before == j && error("RepeatedTerm doesn't seem to terminate")
        if step != 0
            # Shift only the type ids produced by this repetition.
            type_ids[first_new:length(type_ids)] .+= offset
        end
        offset += step
    end
    return j
end

# Curried form: returns a callable that applies `st` to a single collection of input sequences.
function apply_template(st::SequenceTemplate)
    return Base.Fix1(apply_template, st)
end

# Run every term of `st` over the input sequences `xs`, in order, and return the combined output sequence
# together with the matching type id sequence. Throws an `AssertionError` when fewer inputs are given than
# the template has `InputTerm`s, or when not all inputs were consumed.
function apply_template(st::SequenceTemplate{T}, xs) where T
    terms = st.terms
    len = length(xs)
    n_input = count(t -> t isa InputTerm, terms)
    @assert len >= n_input "SequenceTemplate require at least $n_input but only get $len"

    output = T[]
    type_ids = Int[]

    # `j` is the cursor over `xs`; each term returns where the next term should continue.
    j = 1
    for i in eachindex(terms)
        j = process_term!(terms[i], output, type_ids, i, j, terms, xs)
    end
    @assert j > len "SequenceTemplate only take $(j-1) inputs but get $len"
    return output, type_ids
end

## static single sample
# The input sequences are given as separate arguments.
function (st::SequenceTemplate{T})(xs::AbstractVector{T}...) where T
    return apply_template(st, xs)
end

# The input sequences are given as one tuple.
function (st::SequenceTemplate{T})(xs::Tuple{Vararg{AbstractVector{T}}}) where T
    return apply_template(st, xs)
end

# The input sequences are given as one vector of sequences.
function (st::SequenceTemplate{T})(xs::AbstractVector{<:AbstractVector{T}}) where T
    return apply_template(st, xs)
end

## static multiple sample
# Every element of `xs` is one sample (a vector of sequences); the template is applied to each.
function (st::SequenceTemplate{T})(xs::AbstractArray{<:AbstractVector{<:AbstractVector{T}}}) where T
    return map(Base.Fix1(apply_template, st), xs)
end

## dynamic
# Fallback for arrays whose element type does not identify the nesting level (e.g. `Vector{Any}`): the
# elements are inspected at run time to decide whether `xs` is a single sequence, a single sample (an array
# of sequences) or multiple samples (an array of arrays of sequences). Throws a `MethodError` when `xs`
# matches none of these shapes.
#
# NOTE(review): `allany` is defined elsewhere in this file. The branches below treat `aoa` as "every element
# of `xs` is an array" and `aov` as "`xs` is a flat sequence of values" -- confirm against `allany`.
function (st::SequenceTemplate{T})(xs::AbstractArray) where T
    aoa, aov = allany(Base.Fix2(isa, AbstractArray), xs)
    if aoa
        if all(Base.Fix1(all, Base.Fix2(isa, T)), xs) # dynamic single sample
            # xs is an array of sequence
            return apply_template(st, xs)
        elseif all(Base.Fix1(all, Base.Fix2(isa, AbstractArray)), xs) # dynamic multiple sample
            # xs is an array of array of array
            return map(st, xs)
        else
            # `MethodError` expects the tuple of call arguments; passing `xs` itself would make the error
            # message list the elements of `xs` as if they were the arguments.
            throw(MethodError(st, (xs,)))
        end
    elseif aov # dynamic single sample
        # xs is a sequence
        !all(Base.Fix2(isa, T), xs) && throw(MethodError(st, (xs,))) # assert eltype of sequence == T
        return apply_template(st, (xs,))
    else
        throw(MethodError(st, (xs,)))
    end
end

# Compact one-line rendering of each kind of template term, used when showing a `SequenceTemplate`.
function _show(io, t::InputTerm)
    return print(io, "Input:<type=$(t.type_id)>")
end

function _show(io, t::IndexInputTerm)
    return print(io, "Input[$(t.idx)]:<type=$(t.type_id)>")
end

function _show(io, t::ConstTerm)
    return print(io, "$(t.value):<type=$(t.type_id)>")
end

# Renders the repeated terms space-separated inside parentheses, followed by `...`; the per-repetition type
# id increment is shown only when it is non-zero.
function _show(io, t::RepeatedTerm)
    print(io, '(')
    for (k, inner) in enumerate(t.terms)
        k == 1 || print(io, ' ')
        _show(io, inner)
    end
    suffix = iszero(t.dynamic_type_id) ? ")..." : ")<type+=$(t.dynamic_type_id)>..."
    return print(io, suffix)
end

# The multi-line REPL display reuses the compact single-line form.
function Base.show(io::IO, ::MIME"text/plain", st::SequenceTemplate)
    return show(io, st)
end

# Prints `SequenceTemplate{T}(...)` with the terms rendered by `_show`, separated by single spaces.
function Base.show(io::IO, st::SequenceTemplate{T}) where T
    print(io, "SequenceTemplate{", T, "}(")
    for (k, term) in enumerate(st.terms)
        k == 1 || print(io, ' ')
        _show(io, term)
    end
    return print(io, ')')
end
59 changes: 59 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using TextEncodeBase: AbstractTokenizer, AbstractTokenization,
TokenStages, Document, Sentence, Word, Token, Batch
using TextEncodeBase: getvalue, getmeta, updatevalue,
with_head_tail, trunc_and_pad, trunc_or_pad, nested2batch, nestedcall
using TextEncodeBase: SequenceTemplate, InputTerm, IndexInputTerm, ConstTerm, RepeatedTerm

using WordTokenizers

Expand Down Expand Up @@ -701,6 +702,64 @@ end

@test_throws DimensionMismatch nested2batch([[1:5], 2:6])
end

@testset "SequenceTemplate" begin
x = collect(1:5)
head_tail_template = SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))
@test head_tail_template(x)[1] == with_head_tail(x, -1, -2)
@test nestedcall(x->x[1], head_tail_template(AbstractVector[[x], [1:5], [2:3]])) ==
map(x->x[1], with_head_tail(AbstractVector[[x], [1:5], [2:3]], -1, -2))
@test nestedcall(x->x[1], head_tail_template(Any[Any[x], [1:5], [2:3]])) ==
map(x->x[1], with_head_tail(Any[Any[x], [1:5], [2:3]], -1, -2))
@test nestedcall(x->x[1], head_tail_template(Any[Any[Any[0,1,2]]])) ==
with_head_tail(Any[Any[Any[0,1,2]]], -1, -2)[1]
@test_throws MethodError head_tail_template(Any[Any[x], 1:5, 2:3])
@test_throws Exception head_tail_template(Any[1:5, 2:3])
@test_throws Exception head_tail_template(Any[1:5, 2:3])
bert_template = SequenceTemplate(
ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2))
)
@test bert_template(["A"]) == (["[CLS]", "A", "[SEP]"], [1,1,1])
@test bert_template(["A"], ["B"]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
@test bert_template(["A"], ["B"], ["C"]) ==
(["[CLS]", "A", "[SEP]", "B", "[SEP]", "C", "[SEP]"], [1,1,1,2,2,2,2])
@test bert_template([["A"], ["B"]]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
@test bert_template([[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
@test bert_template(Any[[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
@test bert_template([Any[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
@test bert_template([Any[Any["A"], Any["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
@test bert_template(Any[Any[Any["A"], Any["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
bert_template2 = SequenceTemplate(
ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2); dynamic_type_id = true)
)
@test bert_template2(["A"]) == (["[CLS]", "A", "[SEP]"], [1,1,1])
@test bert_template2(["A"], ["B"]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
@test bert_template2(["A"], ["B"], ["C"]) ==
(["[CLS]", "A", "[SEP]", "B", "[SEP]", "C", "[SEP]"], [1,1,1,2,2,3,3])
trail_template = SequenceTemplate(
IndexInputTerm{Int}(1, 1), RepeatedTerm(InputTerm{Int}(2)), IndexInputTerm{Int}(1, 1)
)
@test trail_template([3,5]) == ([3,5,3,5], [1,1,1,1])
@test trail_template([3,5],[1,2,4]) == ([3,5,1,2,4,3,5], [1,1,2,2,2,1,1])
@test SequenceTemplate(RepeatedTerm(InputTerm{Int}(3); dynamic_type_id = 2))(1:1, 2:2) == ([1,2],[3,5])
multi_repeat_template = SequenceTemplate(
ConstTerm(0,1),
RepeatedTerm(InputTerm{Int}(3), ConstTerm(1, 5), InputTerm{Int}(7); dynamic_type_id = 2),
ConstTerm(0,9)
)
@test multi_repeat_template() == ([0,0],[1,9])
@test_throws AssertionError multi_repeat_template(1:2)
@test multi_repeat_template(1:2, 3:4) == ([0,1,2,1,3,4,0], [1,3,3,5,7,7,9])
@test_throws AssertionError multi_repeat_template(1:2,3:4,5:6)
@test multi_repeat_template(1:2, 3:4,5:6,7:8) == ([0,1,2,1,3,4,5,6,1,7,8,0], [1,3,3,5,7,7,5,5,7,9,9,9])

@test sprint(show, bert_template2) ==
"SequenceTemplate{String}([CLS]:<type=1> Input:<type=1> [SEP]:<type=1> (Input:<type=2> [SEP]:<type=2>)<type+=1>...)"
@test sprint(show, trail_template) ==
"SequenceTemplate{Int64}(Input[1]:<type=1> (Input:<type=2>)... Input[1]:<type=1>)"
end
end

@testset "Encoder" begin
Expand Down

2 comments on commit 98c5e6e

@chengchingwen
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/68522

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.9 -m "<description of version>" 98c5e6e047f2fc4cf517ff7d124708d84a16a471
git push origin v0.5.9

Please sign in to comment.