diff --git a/Project.toml b/Project.toml index d8ff4fb..cdb1043 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SafeTensors" uuid = "eeda0dda-7046-4914-a807-2495fc7abb89" authors = ["pevnak and contributors"] -version = "1.1.0" +version = "1.1.1" [deps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" diff --git a/src/SafeTensors.jl b/src/SafeTensors.jl index 59af07e..14b3826 100644 --- a/src/SafeTensors.jl +++ b/src/SafeTensors.jl @@ -134,6 +134,8 @@ function Base.getindex(x::Metadata, name) index = x.index_map[name] return @inbounds x.tensors[index] end +Base.haskey(x::Metadata, name) = haskey(x.index_map, name) +Base.get(x::Metadata, name, default) = haskey(x, name) ? x[name] : default StructTypes.StructType(::Type{Metadata}) = StructTypes.CustomStruct() function StructTypes.lower(x::Metadata) @@ -187,6 +189,8 @@ function Base.getindex(x::SafeTensor, name) info = getmetadata(x)[name] return _tensorslice(x.data, info) end +Base.haskey(x::SafeTensor, name) = haskey(getmetadata(x), name) +Base.get(x::SafeTensor, name, default) = haskey(x, name) ? x[name] : default _from_le(x) = mappedarray(ltoh, x) function _changemaj(x, shape::NTuple{N}) where N diff --git a/test/runtests.jl b/test/runtests.jl index a70d8b6..4921018 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -184,8 +184,10 @@ end jl_tensors = SafeTensors.deserialize(tfile; mmap = use_mmap) push!(jl_bytes, read(tfile)) @test jl_tensors.metadata == torch_tensors.metadata + @test isnothing(get(torch_tensors, "should_not_exist", nothing)) for (name, tensor) in torch_tensors @test collect(jl_tensors[name]) == collect(tensor) + @test haskey(torch_tensors, name) end end jl_bytes[1] == jl_bytes[2]