Skip to content

Commit

Permalink
refine pipeline show & add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed May 19, 2022
1 parent 7f70693 commit 8d71424
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 14 deletions.
49 changes: 35 additions & 14 deletions src/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,34 @@ julia> pipes(0.3)
Pipelines

# display
@nospecialize

"""
    show_pipeline_function(io::IO, f1::Base.Fix1)

Print `f1` to `io` as a one-argument lambda of the form `(x->f(v, x))`.

`f` is the wrapped function `f1.f`, rendered recursively through
`show_pipeline_function`, and `v` is the value `f1.x` fixed in first position.
"""
function show_pipeline_function(io::IO, f1::Base.Fix1)
    print(io, "(x->")
    # Recurse so that nested wrappers are rendered with the same notation.
    show_pipeline_function(io, f1.f)
    print(io, '(', f1.x, ", x))")
    return nothing
end
"""
    show_pipeline_function(io::IO, f2::Base.Fix2)

Write `f2` to `io` as a one-argument lambda of the form `(x->f(x, v))`.

`f` is the wrapped function `f2.f`, rendered recursively through
`show_pipeline_function`, and `v` is the value `f2.x` fixed in second position.
"""
function show_pipeline_function(io::IO, f2::Base.Fix2)
    wrapped = f2.f
    print(io, "(x->")
    show_pipeline_function(io, wrapped)
    print(io, "(x, ")
    print(io, f2.x)
    print(io, "))")
    return nothing
end
"""
    show_pipeline_function(io::IO, a::ApplyN)

Write `a` to `io` as a varargs lambda of the form `(args...->f(args[n]))`.

`f` is `a.f`, rendered recursively through `show_pipeline_function`, and `n` is
whatever `_nth(a)` returns.
"""
function show_pipeline_function(io::IO, a::ApplyN)
    inner = a.f
    print(io, "(args...->")
    show_pipeline_function(io, inner)
    print(io, "(args[")
    print(io, _nth(a))
    print(io, "]))")
    return nothing
end
"""
    show_pipeline_function(io::IO, a::ApplySyms)

Write `a` to `io` as a keyword lambda of the form
`((; kwargs...)->f(kwargs[(syms)]...))`.

`f` is `a.f`, rendered recursively through `show_pipeline_function`, and `syms` is
the printed form of `_syms(a)`.
"""
function show_pipeline_function(io::IO, a::ApplySyms)
    inner = a.f
    print(io, "((; kwargs...)->")
    show_pipeline_function(io, inner)
    print(io, "(kwargs[")
    # The selected symbols are always wrapped in one extra pair of parentheses.
    print(io, '(')
    print(io, _syms(a))
    print(io, ')')
    print(io, "]...))")
    return nothing
end
"""
    show_pipeline_function(io::IO, fr::FixRest)

Write `fr` to `io` in call notation: the wrapped function `fr.f` (rendered
recursively through `show_pipeline_function`) followed by the elements of `fr.arg`
separated by `", "` inside parentheses.
"""
function show_pipeline_function(io::IO, fr::FixRest)
    callee, bound = fr.f, fr.arg
    show_pipeline_function(io, callee)
    print(io, '(')
    join(io, bound, ", ")
    print(io, ')')
    return nothing
end
function show_pipeline_function(io::IO, c::ComposedFunction, nested=false)
if nested
show_pipeline_function(io, c.outer, nested)
Expand Down Expand Up @@ -219,24 +243,16 @@ function show_pipeline_function(io::IO, p::Pipeline)
_show_pipeline_fixf(io, g, :source)
elseif n == 2
if g isa ApplySyms
show_pipeline_function(io, g.f)
syms = _syms(g)
if syms isa Tuple
print(io, "(target.")
join(io, syms, ", target.")
print(io, ')')
else
print(io, "(target.$syms)")
end
_show_pipeline_fixf(io, g.f, syms isa Tuple ? join(map(x->"target.$x", syms), ", ") : "target.$syms")
else
_show_pipeline_fixf(io, g, :target)
end
else
print(io, p.f)
end
else
print(io, p.f)
print(io, "(source, target)")
_show_pipeline_fixf(io, p.f, "source, target")
end
end

Expand All @@ -252,7 +268,10 @@ function show_pipeline_function(io::IO, p::PipeGet)
end

"""
    Base.show(io::IO, p::Pipeline)

Print `p` as `Pipeline{name}(body)`.

`name` comes from `_name(p)`: a `Tuple` is printed as its elements joined by `", "`
inside parentheses, anything else is printed directly. `body` is produced by
`show_pipeline_function(io, p)`.
"""
function Base.show(io::IO, p::Pipeline)
    print(io, "Pipeline{")
    name = _name(p)
    if name isa Tuple
        # Join the elements with `print` semantics rather than showing the tuple,
        # so each name is written without `show`-style decoration.
        print(io, '(')
        join(io, name, ", ")
        print(io, ')')
    else
        print(io, name)
    end
    print(io, "}(")
    show_pipeline_function(io, p)
    print(io, ')')
    return nothing
end
Expand Down Expand Up @@ -295,3 +314,5 @@ function Base.show(io::IO, ps::Pipelines)
flat = get(io, :compact, false)
show_pipeline(io, ps; flat, prefix)
end

@specialize
41 changes: 41 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ TextEncodeBase.splittability(::CharTk, x::Word) = TextEncodeBase.Splittable()
end
end
end

@testset "show" begin
    # Render a value exactly as `show` would write it to an IO stream.
    rendered(x) = sprint(show, x)

    # Default tokenizer and plain word tokenization.
    @test rendered(FlatTokenizer()) == "FlatTokenizer(default)"
    @test rendered(FlatTokenizer(WordTokenization(tokenize=poormans_tokenize))) == "FlatTokenizer(WordTokenization(split_sentences = WordTokenizers.split_sentences, tokenize = WordTokenizers.poormans_tokenize))"

    # Indexed and pattern-matching wrappers, alone and nested.
    @test rendered(FlatTokenizer(IndexedTokenization())) == "FlatTokenizer(IndexedTokenization(default))"
    @test rendered(FlatTokenizer(MatchTokenization([r"\d", r"en"]))) == "FlatTokenizer(MatchTokenization(default, patterns = Regex[r\"\\d\", r\"en\"]))"
    @test rendered(FlatTokenizer(IndexedTokenization(MatchTokenization([r"\d", r"en"])))) == "FlatTokenizer(IndexedTokenization(MatchTokenization(default, patterns = Regex[r\"\\d\", r\"en\"])))"
    @test rendered(NestedTokenizer(IndexedTokenization())) == "NestedTokenizer(IndexedTokenization(default))"

    # Wrappers around the custom `CharTk` tokenization.
    @test rendered(FlatTokenizer(IndexedTokenization(CharTk()))) == "FlatTokenizer(IndexedTokenization(CharTk))"
    @test rendered(NestedTokenizer(IndexedTokenization(MatchTokenization(CharTk(), [r"\d", r"en"])))) == "NestedTokenizer(IndexedTokenization(MatchTokenization(CharTk, patterns = Regex[r\"\\d\", r\"en\"])))"
end
end

@testset "Vocabulary" begin
Expand Down Expand Up @@ -429,19 +440,27 @@ TextEncodeBase.splittability(::CharTk, x::Word) = TextEncodeBase.Splittable()
@testset "Pipelines" begin
p1 = Pipeline{:x}((x,_)->x)
p2 = Pipeline{(:sinx, :cosx)}((x, _)->sincos(x))
p3 = Pipeline{:z}(x->3x+5, :x)
ps1 = Pipelines(p1, p2)
ps2 = Pipelines(Pipeline{:x}(identity, 1), Pipeline{(:sinx, :cosx)}(y->sincos(y.x), 2))
ps3 = ps2 |> PipeGet{:x}()
ps4 = ps2 |> PipeGet{(:x, :sinx)}()
ps5 = p1 |> p2 |> Pipeline{:xsinx}(*, (:x, :sinx))

@test ps5[begin:end] == ps5.pipes

@test p3(0, (x = 2,)) == (x = 2, z = 11)
@test ps1(0.5) == ps2(0.5)
@test ps3(0.2) == 0.2
@test ps4(0.3) == (x = 0.3, sinx = sin(0.3))
@test ps5(0.7) == (x = 0.7, sinx = sin(0.7), cosx = cos(0.7), xsinx = 0.7*sin(0.7))
@test_inferred p1(0.3)
@test_inferred p2(0.5)
@test_inferred p3(0, (x = 2,))
@test_inferred ps1(0.5)
@test_inferred ps2(0.5)
@test_inferred ps3(0.5)
@test_inferred ps5(0.5)

@test p1 |> p2 == ps1
@test ps1 |> p1 == Pipelines(p1, p2, p1)
Expand All @@ -451,5 +470,27 @@ TextEncodeBase.splittability(::CharTk, x::Word) = TextEncodeBase.Splittable()
@test_throws Exception Pipeline{:x}(identity, 3)
@test_throws Exception Pipeline{()}(identity)
@test_throws Exception Pipelines(())

@testset "show" begin
    # Single pipelines: plain functions applied to the source or to a target field.
    @test sprint(show, Pipeline{:x}(-, 1)) == "Pipeline{x}(-(source))"
    @test sprint(show, Pipeline{(:sinx, :cosx)}(sincos, 1)) == "Pipeline{(sinx, cosx)}(sincos(source))"
    @test sprint(show, Pipeline{(:sinx, :cosx)}(sincos, :x)) == "Pipeline{(sinx, cosx)}(sincos(target.x))"

    # Partially applied functions.
    @test sprint(show, Pipeline{(:tanx, :tany)}(Base.Fix1(map, tan), 2)) == "Pipeline{(tanx, tany)}(map(tan, target))"
    @test sprint(show, Pipeline{:x2}(Base.Fix2(/, 2), :x)) == "Pipeline{x2}(/(target.x, 2))"

    # Composed functions: the `∘` operator is required here; the expected strings
    # render the composition as `(outer ∘ inner)`.
    @test sprint(show, Pipeline{:z}(sin∘cos, :x)) == "Pipeline{z}((sin ∘ cos)(target.x))"
    @test sprint(show, Pipeline{:z}(Base.Fix1(*, 2) ∘ Base.Fix2(+, 1))) == "Pipeline{z}(((x->*(2, x)) ∘ (x->+(x, 1)))(source, target))"

    @test sprint(show, Pipeline{:tok}(trunc_and_pad(nothing, 0), :tok)) == "Pipeline{tok}(trunc_and_pad(nothing, 0)(target.tok))"

    # Field selection.
    @test sprint(show, PipeGet{:x}()) == "Pipeline{x}((target.x))"
    @test sprint(show, PipeGet{(:a, :b)}()) == "Pipeline{(a, b)}((target.a, target.b))"

    # Compact rendering of a chain of pipelines.
    foo(x, y) = x * y.sinx
    @test sprint(
        show,
        Pipeline{:x}(identity, 1) |> Pipeline{(:sinx, :cosx)}(sincos, :x) |>
        Pipeline{:xsinx}(foo) |> PipeGet{(:cosx, :xsinx)}()
        ; context=:compact=>true
    ) ==
        "Pipelines(target[x] := identity(source); target[(sinx, cosx)] := sincos(target.x); target[xsinx] := foo(source, target); target := (target.cosx, target.xsinx))"
end
end
end

0 comments on commit 8d71424

Please sign in to comment.