Skip to content

Commit

Permalink
added typing to checklist
Browse files Browse the repository at this point in the history
  • Loading branch information
cecoeco committed Oct 26, 2024
1 parent 57f8933 commit 206f180
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
1 change: 0 additions & 1 deletion src/PRISMA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ using Poppler_jll
using StatsBase
using TextEncodeBase
using Transformers
using Transformers.HuggingFace

import Base.Multimedia.display
import Base.show
Expand Down
30 changes: 15 additions & 15 deletions src/checklist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,16 +419,16 @@ function top_k_sample(probs; k=10)::Vector{Int64}
return index
end

function generate_text(context; max_length=50)
encoded = TextEncodeBase.encode(TEXT_ENCODER, context).token
ids = encoded.onehots
ends_id = TextEncodeBase.lookup(TEXT_ENCODER.vocab, TEXT_ENCODER.endsym)
function generate_text(prompt; max_length=25)::Vector{String}
encoded::OneHotArray = TextEncodeBase.encode(TEXT_ENCODER, prompt).token
ids::Vector = encoded.onehots
ends_id::Int64 = TextEncodeBase.lookup(TEXT_ENCODER.vocab, TEXT_ENCODER.endsym)
for _ in 1:max_length
input = (; token=encoded)
outputs = MODEL(input)
logits = Base.@view outputs.logit[:, end, 1]
input::NamedTuple = (; token=encoded)
outputs::NamedTuple = MODEL(input)
logits::SubArray = Base.@view outputs.logit[:, end, 1]
probs::Vector{Float32} = temp_softmax(logits)
new_id::Vector{Int64} = top_k_sample(probs)[1]
new_id::Int64 = top_k_sample(probs)[1]
Base.push!(ids, new_id)
new_id == ends_id && break
end
Expand Down Expand Up @@ -470,7 +470,7 @@ end

"""
checklist(paper::AbstractString)::Checklist
checklist(bytes::Vector{UInt8})::Checklist
checklist(paper::Vector{UInt8})::Checklist
This function returns a completed PRISMA checklist as the type `Checklist`.
The `Checklist` type includes a completed checklist as a `DataFrame` and the
Expand Down Expand Up @@ -504,7 +504,7 @@ If the parsing fails the value will be an empty string.
## Arguments
- `paper::AbstractString`: a path to a pdf file as a string
- `bytes::Vector{UInt8}`: the pdf data as an array of bytes
- `paper::Vector{UInt8}`: the pdf data as an array of bytes
## Returns
Expand All @@ -518,15 +518,15 @@ function checklist(paper::AbstractString)::Checklist
)
end

function checklist(bytes::Vector{UInt8})::Checklist
paper::String = Base.Filesystem.tempname()
Base.Filesystem.write(paper, bytes)
function checklist(paper::Vector{UInt8})::Checklist
temp_pdf::String = Base.Filesystem.tempname()
Base.Filesystem.write(temp_pdf, paper)
try
return checklist(paper)
return checklist(temp_pdf)
catch ex
Base.rethrow(ex)
finally
Base.Filesystem.rm(paper, force=true)
Base.Filesystem.rm(temp_pdf, force=true)
end
end

Expand Down

0 comments on commit 206f180

Please sign in to comment.