diff --git a/Project.toml b/Project.toml
index 1522b20..198f613 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TextEncodeBase"
 uuid = "f92c20c0-9f2a-4705-8116-881385faba05"
 authors = ["chengchingwen and contributors"]
-version = "0.5.8"
+version = "0.5.9"
 
 [deps]
 FuncPipelines = "9ed96fbb-10b6-44d4-99a6-7e2a3dc8861b"
diff --git a/src/utils.jl b/src/utils.jl
index e1e7df9..cd312a8 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -523,3 +523,225 @@ function _nested2batch!(arr, offset, x::AbstractArray{>:AbstractArray})
         error("Input array is mixing array and non-array elements")
     end
 end
+
+# Sequence template
+
+"""
+    abstract type TemplateTerm{T} end
+
+Abstract type for the terms used in a [`SequenceTemplate`](@ref).
+"""
+abstract type TemplateTerm{T} end
+
+"""
+    InputTerm{T}(type_id = 1)
+
+A `TemplateTerm` that takes the next sequence out of the input.
+"""
+struct InputTerm{T} <: TemplateTerm{T}
+    type_id::Int
+    InputTerm{T}(type_id = 1) where T = new{T}(type_id)
+end
+
+"""
+    IndexInputTerm{T}(idx::Int, type_id = 1)
+
+A `TemplateTerm` that takes the `idx`-th sequence of the input. If the `IndexInputTerm` is also the `idx`-th
+ input related term in a [`SequenceTemplate`](@ref), it behaves the same as [`InputTerm`](@ref).
+"""
+struct IndexInputTerm{T} <: TemplateTerm{T}
+    idx::Int
+    type_id::Int
+    IndexInputTerm{T}(idx, type_id = 1) where T = new{T}(idx, type_id)
+end
+
+"""
+    ConstTerm(value::T, type_id = 1)
+
+A `TemplateTerm` that simply puts `value` into the output sequence.
+"""
+struct ConstTerm{T} <: TemplateTerm{T}
+    value::T
+    type_id::Int
+end
+ConstTerm(value, type_id = 1) = ConstTerm{typeof(value)}(value, type_id)
+
+"""
+    RepeatedTerm(terms::TemplateTerm...; dynamic_type_id = false)
+
+A special term that indicates the `terms` sequence can appear zero or multiple times. Cannot be nested.
+ If `dynamic_type_id` is set, each repeat adds an offset value to the type id of the repeated `terms`.
+ The offset value is the number of previous repetitions, starting from `0`, times `dynamic_type_id`.
+"""
+struct RepeatedTerm{T, Ts<:Tuple{Vararg{TemplateTerm{T}}}} <: TemplateTerm{T}
+    terms::Ts
+    dynamic_type_id::Int
+    function RepeatedTerm(terms::Tuple{Vararg{TemplateTerm{T}}}, dynamic_type_id = false) where T
+        @assert length(terms) >= 1 "No TemplateTerm provided."
+        @assert !any(Base.Fix2(isa, RepeatedTerm), terms) "Cannot nest RepeatedTerm"
+        return new{T, typeof(terms)}(terms, dynamic_type_id)
+    end
+end
+RepeatedTerm(terms::TemplateTerm...; dynamic_type_id = false) = RepeatedTerm(terms, dynamic_type_id)
+
+"""
+    SequenceTemplate(terms::TemplateTerm...)(sequences...)
+
+Construct a function from multiple `TemplateTerm`s that indicate how to combine the input `sequences`. Returns
+ a tuple of the result sequence and a type id (a special number associated with the template term) sequence.
+
+# Example
+
+```julia-repl
+julia> SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))(1:5)[1] == TextEncodeBase.with_head_tail(1:5, -1, -2)
+true
+
+julia> SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))(1:5)
+([-1, 1, 2, 3, 4, 5, -2], [1, 1, 1, 1, 1, 1, 1])
+
+julia> bert_template = SequenceTemplate(
+           ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
+           RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2))
+       )
+SequenceTemplate{String}([CLS]: Input: [SEP]: (Input: [SEP]:)...)
+
+julia> bert_template(["hello", "world"])
+(["[CLS]", "hello", "world", "[SEP]"], [1, 1, 1, 1])
+
+julia> bert_template(["hello", "world"], ["today", "is", "a", "good", "day"])
+(["[CLS]", "hello", "world", "[SEP]", "today", "is", "a", "good", "day", "[SEP]"], [1, 1, 1, 1, 2, 2, 2, 2, 2, 2])
+
+```
+"""
+struct SequenceTemplate{T, Ts<:Tuple{Vararg{TemplateTerm{T}}}} <: Function
+    terms::Ts
+    function SequenceTemplate(terms::Tuple{Vararg{TemplateTerm{T}}}) where T
+        @assert length(terms) >= 1 "No TemplateTerm provided."
+        @assert count(Base.Fix2(isa, RepeatedTerm), terms) <= 1 "RepeatedTerm can only appear at most once."
+        return new{T, typeof(terms)}(terms)
+    end
+end
+SequenceTemplate(terms::TemplateTerm...) = SequenceTemplate(terms)
+
+# `process_term!(term, output, type_ids, i, j, terms, xs)` appends what `term` (the `i`-th entry of `terms`)
+#  contributes to `output` and `type_ids`. `j` is the index of the next sequence of `xs` that an `InputTerm`
+#  would consume; the updated `j` is returned.
+function process_term!(term::InputTerm, output, type_ids, i, j, terms, xs)
+    @assert j <= length(xs) "InputTerm indexing $j-th input but only get $(length(xs))"
+    x = xs[j]
+    append!(output, x)
+    append!(type_ids, Iterators.repeated(term.type_id, length(x)))
+    return j + 1
+end
+
+function process_term!(term::IndexInputTerm, output, type_ids, i, j, terms, xs)
+    idx = term.idx
+    @assert idx <= length(xs) "IndexInputTerm indexing $idx-th input but only get $(length(xs))"
+    x = xs[idx]
+    append!(output, x)
+    append!(type_ids, Iterators.repeated(term.type_id, length(x)))
+    return idx == j ? j + 1 : j
+end
+
+function process_term!(term::ConstTerm, output, type_ids, i, j, terms, xs)
+    push!(output, term.value)
+    push!(type_ids, term.type_id)
+    return j
+end
+
+function process_term!(term::RepeatedTerm, output, type_ids, i, j, terms, xs)
+    r_terms = term.terms
+    dynamic_type_id = term.dynamic_type_id
+    # leave enough sequences for the `InputTerm`s that come after this term
+    n = count(Base.Fix2(isa, InputTerm), terms[i+1:end])
+    J = length(xs) - n
+    type_id_offset = 0
+    while j <= J
+        type_id_start = length(type_ids) + 1
+        _j = j
+        for (t_i, term_i) in enumerate(r_terms)
+            j = process_term!(term_i, output, type_ids, t_i, j, r_terms, xs)
+        end
+        _j == j && error("RepeatedTerm doesn't seem to terminate")
+        type_id_end = length(type_ids)
+        # shift the type ids emitted by this repetition by the accumulated offset
+        dynamic_type_id != 0 && (type_ids[type_id_start:type_id_end] .+= type_id_offset)
+        type_id_offset += dynamic_type_id
+    end
+    return j
+end
+
+# Combine the sequences in `xs` following the terms of `st` and return `(output, type_ids)`. Calling it with
+#  only `st` returns `Base.Fix1(apply_template, st)`, i.e. a function of `xs`.
+apply_template(st::SequenceTemplate) = Base.Fix1(apply_template, st)
+function apply_template(st::SequenceTemplate{T}, xs) where T
+    terms = st.terms
+    len = length(xs)
+    n_input = count(Base.Fix2(isa, InputTerm), terms)
+    @assert len >= n_input "SequenceTemplate require at least $n_input but only get $len"
+
+    output = Vector{T}()
+    type_ids = Vector{Int}()
+
+    j = 1
+    for (i, term) in enumerate(terms)
+        j = process_term!(term, output, type_ids, i, j, terms, xs)
+    end
+    @assert j > len "SequenceTemplate only take $(j-1) inputs but get $len"
+    return output, type_ids
+end
+
+## static single sample
+(st::SequenceTemplate{T})(xs::AbstractVector{T}...) where T = apply_template(st, xs)
+(st::SequenceTemplate{T})(xs::Tuple{Vararg{AbstractVector{T}}}) where T = apply_template(st, xs)
+(st::SequenceTemplate{T})(xs::AbstractVector{<:AbstractVector{T}}) where T = apply_template(st, xs)
+
+## static multiple sample
+(st::SequenceTemplate{T})(xs::AbstractArray{<:AbstractVector{<:AbstractVector{T}}}) where T = map(apply_template(st), xs)
+
+## dynamic
+function (st::SequenceTemplate{T})(xs::AbstractArray) where T
+    aoa, aov = allany(Base.Fix2(isa, AbstractArray), xs)
+    if aoa
+        if all(Base.Fix1(all, Base.Fix2(isa, T)), xs) # dynamic single sample
+            # xs is an array of sequence
+            return apply_template(st, xs)
+        elseif all(Base.Fix1(all, Base.Fix2(isa, AbstractArray)), xs) # dynamic multiple sample
+            # xs is an array of array of array
+            return map(st, xs)
+        else
+            throw(MethodError(st, (xs,)))
+        end
+    elseif aov # dynamic single sample
+        # xs is a sequence
+        !all(Base.Fix2(isa, T), xs) && throw(MethodError(st, (xs,))) # assert eltype of sequence == T
+        return apply_template(st, (xs,))
+    else
+        throw(MethodError(st, (xs,)))
+    end
+end
+
+_show(io, t::InputTerm) = print(io, "Input:")
+_show(io, t::IndexInputTerm) = print(io, "Input[$(t.idx)]:")
+_show(io, t::ConstTerm) = print(io, "$(t.value):")
+function _show(io, t::RepeatedTerm)
+    print(io, '(')
+    _show(io, first(t.terms))
+    for term in Base.tail(t.terms)
+        print(io, ' ')
+        _show(io, term)
+    end
+    # the suffix is the same whether or not `dynamic_type_id` is set
+    print(io, ")...")
+end
+
+Base.show(io::IO, ::MIME"text/plain", st::SequenceTemplate) = show(io, st)
+function Base.show(io::IO, st::SequenceTemplate{T}) where T
+    print(io, "SequenceTemplate{", T, "}(")
+    _show(io, first(st.terms))
+    for term in Base.tail(st.terms)
+        print(io, ' ')
+        _show(io, term)
+    end
+    print(io, ')')
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 39d0562..b4293e6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -19,6 +19,7 @@ using TextEncodeBase: AbstractTokenizer, AbstractTokenization,
     TokenStages, Document, Sentence, Word, Token, Batch
 using TextEncodeBase: getvalue, getmeta, updatevalue,
     with_head_tail, trunc_and_pad, trunc_or_pad, nested2batch, nestedcall
+using TextEncodeBase: SequenceTemplate, InputTerm, IndexInputTerm, ConstTerm, RepeatedTerm
 
 using WordTokenizers
 
@@ -701,6 +702,63 @@ end
 
         @test_throws DimensionMismatch nested2batch([[1:5], 2:6])
     end
+
+    @testset "SequenceTemplate" begin
+        x = collect(1:5)
+        head_tail_template = SequenceTemplate(ConstTerm(-1), InputTerm{Int}(), ConstTerm(-2))
+        @test head_tail_template(x)[1] == with_head_tail(x, -1, -2)
+        @test nestedcall(x->x[1], head_tail_template(AbstractVector[[x], [1:5], [2:3]])) ==
+            map(x->x[1], with_head_tail(AbstractVector[[x], [1:5], [2:3]], -1, -2))
+        @test nestedcall(x->x[1], head_tail_template(Any[Any[x], [1:5], [2:3]])) ==
+            map(x->x[1], with_head_tail(Any[Any[x], [1:5], [2:3]], -1, -2))
+        @test nestedcall(x->x[1], head_tail_template(Any[Any[Any[0,1,2]]])) ==
+            with_head_tail(Any[Any[Any[0,1,2]]], -1, -2)[1]
+        @test_throws MethodError head_tail_template(Any[Any[x], 1:5, 2:3])
+        @test_throws Exception head_tail_template(Any[1:5, 2:3])
+        bert_template = SequenceTemplate(
+            ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
+            RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2))
+        )
+        @test bert_template(["A"]) == (["[CLS]", "A", "[SEP]"], [1,1,1])
+        @test bert_template(["A"], ["B"]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
+        @test bert_template(["A"], ["B"], ["C"]) ==
+            (["[CLS]", "A", "[SEP]", "B", "[SEP]", "C", "[SEP]"], [1,1,1,2,2,2,2])
+        @test bert_template([["A"], ["B"]]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
+        @test bert_template([[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
+        @test bert_template(Any[[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
+        @test bert_template([Any[["A"], ["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
+        @test bert_template([Any[Any["A"], Any["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
+        @test bert_template(Any[Any[Any["A"], Any["B"]]]) == [(["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])]
+        bert_template2 = SequenceTemplate(
+            ConstTerm("[CLS]", 1), InputTerm{String}(1), ConstTerm("[SEP]", 1),
+            RepeatedTerm(InputTerm{String}(2), ConstTerm("[SEP]", 2); dynamic_type_id = true)
+        )
+        @test bert_template2(["A"]) == (["[CLS]", "A", "[SEP]"], [1,1,1])
+        @test bert_template2(["A"], ["B"]) == (["[CLS]", "A", "[SEP]", "B", "[SEP]"], [1,1,1,2,2])
+        @test bert_template2(["A"], ["B"], ["C"]) ==
+            (["[CLS]", "A", "[SEP]", "B", "[SEP]", "C", "[SEP]"], [1,1,1,2,2,3,3])
+        trail_template = SequenceTemplate(
+            IndexInputTerm{Int}(1, 1), RepeatedTerm(InputTerm{Int}(2)), IndexInputTerm{Int}(1, 1)
+        )
+        @test trail_template([3,5]) == ([3,5,3,5], [1,1,1,1])
+        @test trail_template([3,5],[1,2,4]) == ([3,5,1,2,4,3,5], [1,1,2,2,2,1,1])
+        @test SequenceTemplate(RepeatedTerm(InputTerm{Int}(3); dynamic_type_id = 2))(1:1, 2:2) == ([1,2],[3,5])
+        multi_repeat_template = SequenceTemplate(
+            ConstTerm(0,1),
+            RepeatedTerm(InputTerm{Int}(3), ConstTerm(1, 5), InputTerm{Int}(7); dynamic_type_id = 2),
+            ConstTerm(0,9)
+        )
+        @test multi_repeat_template() == ([0,0],[1,9])
+        @test_throws AssertionError multi_repeat_template(1:2)
+        @test multi_repeat_template(1:2, 3:4) == ([0,1,2,1,3,4,0], [1,3,3,5,7,7,9])
+        @test_throws AssertionError multi_repeat_template(1:2,3:4,5:6)
+        @test multi_repeat_template(1:2, 3:4,5:6,7:8) == ([0,1,2,1,3,4,5,6,1,7,8,0], [1,3,3,5,7,7,5,5,7,9,9,9])
+
+        @test sprint(show, bert_template2) ==
+            "SequenceTemplate{String}([CLS]: Input: [SEP]: (Input: [SEP]:)...)"
+        @test sprint(show, trail_template) ==
+            "SequenceTemplate{Int64}(Input[1]: (Input:)... Input[1]:)"
+    end
 end
 
 @testset "Encoder" begin