diff --git a/src/loading.jl b/src/loading.jl index 01090d95b0..a4f33a3304 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,4 +1,8 @@ loadleaf!(dst, src, err) = dst +loadleaf!(dst::AbstractArray, src, err) = + error("Tried to copy $src into an array destination; this is not allowed.") +loadleaf!(dst, src::AbstractArray, err) = + error("Tried to copy an array to $dst; this is not allowed.") function loadleaf!(dst::AbstractArray, src::Bool, err) if iszero(src) dst .= src @@ -30,20 +34,12 @@ _bool_tie_check(dst, src) = true Copy all the parameters (trainable and non-trainable) from `src` into `dst`. Recursively walks `dst` and `src` together using [`Functors.children`](@ref), -and calling `copyto!` on parameter arrays. -Non-array elements (such as activation functions) are not copied -and do not need to match between `dst` and `src`. -Inactive parameters can be encoded by using the boolean value `false` instead of an array. -If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); -however, attempting to copy a non-zero array to an inactive parameter will throw an error. -Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. - -Throws an error when: -- `dst` and `src` do not share the same fields (at any level) -- the sizes of leaf nodes are mismatched between `dst` and `src` -- `dst` is a "tied" parameter (i.e. refers to another parameter) and - loaded into multiple times with mismatched source values +and calling `copyto!` on parameter arrays or throwing an error when there is a mismatch. +Non-array elements (such as activation functions) are not copied and need not match. +Zero bias vectors and `bias=false` are considered equivalent +(see extended help for more details). +# Examples ```julia julia> using Flux: loadmodel! @@ -69,6 +65,22 @@ true julia> dst[2].bias == src[2].bias true ``` + +# Extended help + +Throws an error when: +- `dst` and `src` do not share the same fields (at any level) +- the sizes of leaf nodes are mismatched between `dst` and `src` +- copying non-array values to/from an array parameter + (except inactive parameters described below) +- `dst` is a "tied" parameter (i.e. refers to another parameter) and + loaded into multiple times with mismatched source values + +Inactive parameters can be encoded by using the boolean value `false` instead of an array. +If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); +however, attempting to copy a non-zero array to an inactive parameter will throw an error. +Likewise, copying a `src` value of `false` to any `dst` array is valid, +but copying a `src` value of `true` will error. """ function loadmodel!(dst, src; cache = Base.IdSet()) ldsts, _ = functor(dst)