-
-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a structural loadparams!
#1875
Conversation
Because no API addition is complete without bikeshedding, might I link to the discussion in beacon-biosignals/LegolasFlux.jl#4 (comment) 😜. ( |
Can you write briefly what this does & doesn't do? E.g. it copies non-trainable parameter arrays, unlike restructure, but not integer sizes/strides, not activation functions, how about dropout rate? At least as a docstring. What's the argument for the name change? What happens with immutable arrays? Maybe it should work like What this custom traverse doesn't allow is recovering all the weights from something like It seems you could still have nice errors, even for custom layers, by just keeping track of the one-layer-below name/summary. Not
Do we need the That's a lot of questions, sorry! |
Yes, this PR still needs docstrings, examples, and doc updates before it's ready.
I had not seen that! "State" might be better than "model" here. I'm open to changing it or keeping the old name.
Flux uses "parameters" to mean "trainable parameters," while this function loads both the trainable parameters and the non-trainable state. I wanted the name to make clear that we are considering the model's structure when loading. My main motivation was so that a user noticed the change, but I'm happy to keep the old name.
Right now they would just fail (though I could give a better error message). It would be easy enough to make this behave like But I am waffling on what's the right behavior. Before calling As it stands, Flux doesn't work with immutable parameters. The gradients can be immutable, but not the parameters. I figure we can make this function behave like
Believe this is asking a similar thing to what I addressed above, but I wasn't sure. Could you rephrase if I misunderstood?
Good point, I'll refactor.
I'm not sure if it is used somewhere, but the old |
I'm not super-sure. At the moment it needs two valid Flux models, matching in most respects (but perhaps one is trained). My thought was that it could easily take a Flux model and anything else with a matching tree of structs. And this might be useful because a nested set of NamedTuples is going to be more robust to save/serialise/etc, as it can be handled without Flux's (or your model's) special structs. In fact, I almost wonder if it should be more like julia> m = Chain(Dense(2,2,tanh));
julia> m0, re = somefun!(m); # creating m0 is approximately free.
julia> m0 # this contains exactly what will be re-loaded, and `nothing` else -- and makes it easy to inspect what "parameters and state" are going to be loaded for layer X
(layers = ((weight = Float32[-1.1553937 1.2085382; -0.27197266 0.09527644], bias = Float32[0.0, 0.0], σ = nothing),),)
julia> re(m0) # this is like `loadparams!`, and could have a one-step method e.g. `somefun!(m, m0)`
Chain(
Dense(2, 2, tanh), # 6 parameters
) |
The two valid models requirement doesn't seem necessary for this PR. Assuming structural similarity, I don't see any obvious barriers to using the nested namedtuple as the second arg instead of an actual model struct (i.e. drop the |
Okay, I see what you mean. That would be a larger step than this PR, but something I am willing to do. This would basically be two pieces:
The "loading" step can already be done by this PR as Brian mentioned (some tweaks necessary of course). Ultimately, the code can be written so that the saving and loading are both just calling In the |
I likewise thought I also think the ability to run the "un-load" half and see exactly which bits of the model are and are not captured is a nice thing. Instead of trying to read the docs for what a parameter is, whether X is trainable, what happens if Y isn't a functor... you can just try it and get the ground truth.
But does it have to? This
Maybe. But moving on from implicit params (& introducing a new name) sounds like a good point at which to figure out the right design, rather than inflict changes later. |
re-documentation, it would be great if the docs made sure to specify the interface for custom layers to participate (e.g. when to define |
Okay let's just make sure we're on the same page, cause all the symbols are making me confused. We have:
|
Personally, I'm not a fan of the |
This seems redundant, I would hope that
I think this is an argument against defining things for every built-in layer, if we can possibly avoid it. The interface is
But it's not necessary for The difference between the two is that |
I agree that between The fact that |
Let me be more precise: In contrast, |
I think you are assuming that most uses of But this "most" doesn't seem absolute. You may want flat parameters to e.g. use them for some regularisation within the loss. You may want Base structs to save them. On the other hand, you may also (as in the present PR) want only the The two seem more and more analogous to me, differing in what the Base-only form looks like (nested or flat) and whether non-trainable parameters are included. I suppose I'm advocating that they have similar user-facing interfaces, more or less, so that there are fewer different things to remember. (Whether they can share any code I don't know -- most of the code in The present implementation of |
This hasn't been discussed so far. You may indeed want to turn the complete state including non-trainable parameters into a flat vector. This could be done like One reason I asked about the ::AbstractVector method above was wondering whether |
I guess I am in agreement at the highest level of this discussion, but I'm confused about what's actually being proposed. So, can we make |
Ultimately, I want to deprecate that path entirely, but I plan on doing it once implicit params are gone. |
Maybe I should think some more and write the options somewhere. But not today. Sorry about derailing the PR! tl;dr is that I'd vote not to introduce a new |
All this makes sense for
To me is a great argument for having the two separate load/save functions, because otherwise you're incurring unnecessary work to generate both of More generally, the two return value API is honestly kinda weird for users, especially those coming from Python. My understanding has always been that it was a necessary evil in order to get acceptable performance for
would've been a less confusing option. Moreover, while But my biggest concern (which I apologize for not thinking of earlier) is more basic. In PyTorch, I can do |
https://github.com/beacon-biosignals/LegolasFlux.jl/blob/main/src/functors.jl has a basic implementation that has worked well and been stable, although I don't think anyone has pushed it too far in terms of variety of layers etc. In terms of
It could be helpful to know what features are missing from a basic implementation like that. |
No, I don't think this is true. The old implementation literally closes over The separated version was what I proposed in FluxML/Functors.jl#31, but I got the impression everyone preferred the |
I had a look back at the original issues, and I think I got some wires crossed reading #986 (comment). For posterity, #799 seems to have the actual design discussion. That said, the point about performance is true now that |
The main missing feature is the kind of structural error checking that you get from this design. Collecting into a flat vector could always silently be wrong, and when it does catch an error, the best you can say is "some parameters are missing." Here, we can be more helpful since we know exactly what structure is being loaded. It also allows for the convenient syntax of Okay, so we've had more discussion on this PR than I expected (which is good)! And I think at a high level the concept of "model -> simple structure" is shared between this and For now, we can let this PR stew if needed. I will say that this function is needed to make Metalhead.jl work with pre-trained models. So, I propose the following path forward:
|
I've also significantly simplified the implementation so that most custom types will participate in the thorough error checking for free. This PR considers any type for which I will add documentation soon too which should hopefully help narrow the discussion as well. |
What are our final thoughts here? I see several transformations we want to do:
There are all kinds of ways to write the implementations for these transforms to share code. Ultimately, I think the code that they really share is the concept of walking a tree from Functors.jl. So my vote here is not to write Given #1882, I would suggest we move forward with this PR and place either |
I'm not sure if |
@ericphanson let me know if the new version clears things up |
src/loading.jl
Outdated
Inactive parameters, encoded by `false` in place on an array, | ||
can be copied to and from all-zero arrays. | ||
Attempting to copy a non-zero array to/from an inactive parameter will throw an error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this is more clear?
Inactive parameters, encoded by `false` in place on an array, | |
can be copied to and from all-zero arrays. | |
Attempting to copy a non-zero array to/from an inactive parameter will throw an error. | |
Inactive parameters can be encoded by using the boolean value `false` instead of an array. | |
If `src` or `dst` has `false` where the other model has an all-zero array, no error will be raised (and no values copied). However, attempting to copy a non-zero array to/from an inactive parameter will throw an error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I used this with a few modifications, because your comment made me realize that the behavior is not 1-1 like the docstring implies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we shorten this somehow? It seems a very tiny edge case about layers with bias=false
vs. models with an actual bias array. Most likely this will never happen in real life. Yet somehow it gets an essay describing all possible paths.
How about just "Zero bias and bias=false
are considered equivalent."?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could put it under extended help? See point 11 of https://docs.julialang.org/en/v1/manual/documentation/.
IMO magic should at least be clearly documented...
src/loading.jl
Outdated
Inactive parameters can be encoded by using the boolean value `false` instead of an array. | ||
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); | ||
however, attempting to copy a non-zero array to an inactive parameter will throw an error. | ||
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tiny nitpick: src
isn't == false
, but rather one of its values:
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. | |
Likewise, copying a `src` value of `false` to any `dst` array is valid, but copying a `src` value of `true` will error. |
src/loading.jl
Outdated
and do not need to match between `dst` and `src`. | ||
Inactive parameters can be encoded by using the boolean value `false` instead of an array. | ||
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); | ||
however, attempting to copy a non-zero array to an inactive parameter will throw an error. | ||
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible shortening:
and do not need to match between `dst` and `src`. | |
Inactive parameters can be encoded by using the boolean value `false` instead of an array. | |
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); | |
however, attempting to copy a non-zero array to an inactive parameter will throw an error. | |
Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. | |
and need not match. | |
Zero-valued arrays and boolean `false` (which is Flux's encoding of absent bias) are considered equivalent. |
(edited not to be so specific to bias)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strictly speaking, the rule applies to anything not just bias. But bias should be the only occurrence of this rule in practice.
@ericphanson how do you like this shortened version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really like it since like you say, it sounds like this only applies to bias
(and to vectors), and it doesn't give the full semantics. Re-#1875 (comment), I think if we want a short docstring, then we should just put more of the details under extended help, so it only shows up in the online docs or if you do ?? loadparams!
in the REPL.
In my view, special casing false
, and allowing it to interop with zero-arrays is a bit magical, and therefore should at least be clearly documented, since it's not something you can really predict from the rest of the behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. [Edit -- crossed in the mail...]
bias=false
is the one official API for making variants of layers, which aims to handle. Others, like affine=false
, are not -- the models must simply match.
The fact that you could, perversely, use false
elsewhere, and trigger the feature, seems like we are now describing ways to hack the code to do other things. There are many others. E.g. loadleaf!(dst, src, err) = dst
means that if dst has an array, and src has any other non-array (like 1.0, or Dense), then nothing will happen. Sufficiently far off the intended track, the source is the only truth.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But this is supposed to work with custom layers, right? So who knows how someone is using false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that works for me :)
my concern is about well-documented unambiguous semantics so this can be a reliable model serialization tool, including for models with custom layers etc. I.e. the Flux as a library of composable building blocks thing.
I think @mcabbott's concerns are about making it simple and keeping Flux self-consistent (but not necessarily worried about interactions outside of Flux itself). I think simple + consistent is important too, and extended help can let us achieve both, to some extent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, didn't see your comment @mcabbott. I would be fine with removing the boolean <-> array special casing altogether.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW I think the reason allowing other mismatches is, I think, that the other half of this "model serialization tool" is something like
function simpletree(m)
fmapstructure(m; prune=nothing) do x
# We know isleaf(x), but further keep only values modelcopy! will accept:
x isa AbstractArray && return x
x === false && return x # if we keep that...
nothing
end
end
which should produce a nested set of NamedTuples, with only the details this thing will load --- no layer types, no activation functions, and tied arrays appear only once. If nothing
is the magic value for this, then we probably want a method to ignore it on loading:
loadleaf!(dst, src::Nothing, err) = dst
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think saving a model trained w/o bias and loading into a model w/ bias that you intend to fine-tune is a pretty reasonable/common use case. This is the pre-trained model flow, not the save my own model and load my own model flow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've mentioned it before, but PyTorch has a pretty compelling model for dealing with these mismatches: load_state_dict
errors by default, but also has a non-strict mode where it returns a symmetric diff of the source and destination model trees. All this to say that the behaviour here need not be set in stone, and that we should strive to be at least as good about telling the user about how/why loading failed when it does.
@@ -0,0 +1,92 @@ | |||
loadleaf!(dst, src, err) = dst |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also wonder if there should be more errors here:
loadleaf!(dst, src, err) = dst | |
loadleaf!(dst, src, err) = dst | |
loadleaf!(dst::AbstractArray, src, err) = error() | |
loadleaf!(dst, src::AbstractArray, err) = error() |
I can imagine that allowing src
to have nothing
means "don't change the existing weight". Which is what #1875 (comment) would generate. But it may also make truncations of branches not just leaves, which aren't allowed right now, but would I think be easy:
loadleaf!(dst, src::Nothing, err) = dst
loadleaf!(dst:: AbstractArray, src::Nothing, err) = dst
loadmodel!(dst, src::Nothing; cache = Base.IdSet()) = dst
Okay I went with the extended help suggestion, but if special casing |
…me other review comments
Co-authored-by: Michael Abbott <[email protected]>
3899c35
to
6b533b8
Compare
This replaces
loadparams!
withloadmodel!
which usesfmap
to structurally walk the model and copy parameters over. Right now it mutates destination model, so fields like the activation are not copied.I opted to have a more verbose implementation than the one-liner
fmap(loadto!, m, mbar)
. It allows us to have more informative error messages for the standard layers. Custom layers will fallback to the error thrown by Functors.jl.PR Checklist