Add SimVP model computational graph #1549
base: master
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## repo-refactor #1549 +/- ##
=================================================
- Coverage 78.16% 78.00% -0.16%
=================================================
Files 860 862 +2
Lines 27994 28299 +305
Branches 770 775 +5
=================================================
+ Hits 21881 22076 +195
- Misses 6113 6223 +110
Have you double checked the model using the model viewer to make sure everything looks as expected?
Also, there need to be a lot more comments and links referencing the original code. The goal is to make it trivial for someone a year or two from now to look through the code without you and check everything against the reference implementation, without having to track a bunch of stuff down. Ideally, also add some assertions on tensor shapes.
Reviewed 5 of 5 files at r1, all commit messages.
Reviewable status: all files reviewed, 21 unresolved discussions (waiting on @easyeasydev and @reyna-abhyankar)
lib/models/src/models/simvp/simvp.cc
line 24 at r1 (raw file):
std::vector<bool> create_simvp_samplings(size_t N_S, bool reverse) {
  size_t N_S_even_floor = (N_S / 2) * 2;
Pull out a separate function named round_down_to_nearest_even to improve readability.
Code quote:
size_t N_S_even_floor = (N_S / 2) * 2;
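A minimal sketch of what such a helper might look like. The name comes from the review comment; the signature and the use of std::size_t are assumptions for illustration and not taken from the PR.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper suggested in the review: rounds n down to the
// nearest even number using integer division, mirroring the quoted
// expression (N_S / 2) * 2.
constexpr std::size_t round_down_to_nearest_even(std::size_t n) {
  return (n / 2) * 2;
}
```

With a helper like this, the quoted line would read as a call such as round_down_to_nearest_even(N_S), which states the intent directly instead of leaving the reader to decode the arithmetic.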
lib/models/src/models/simvp/simvp.cc
line 26 at r1 (raw file):
size_t N_S_even_floor = (N_S / 2) * 2;
auto const change_to_true = [&](size_t idx) -> bool {
Is this actually the name used in the reference implementation? It seems rather weird.
Code quote:
change_to_true
lib/models/src/models/simvp/simvp.cc
line 37 at r1 (raw file):
}
return samplings;
In all of these PRs, use functions from utils/containers more frequently to improve readability.
Suggestion:
return transform(range(N_S_even_floor), change_to_true);
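To illustrate the shape of the suggested rewrite, here is a self-contained sketch. The functions range_sketch and transform_sketch are stand-ins written for this example and are not the project's utils/containers API. The predicate (odd indices map to true) is also an assumption, since the body of the quoted lambda is not shown.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for a range helper: returns {0, 1, ..., n - 1}.
std::vector<std::size_t> range_sketch(std::size_t n) {
  std::vector<std::size_t> result;
  for (std::size_t i = 0; i < n; i++) {
    result.push_back(i);
  }
  return result;
}

// Stand-in for a transform helper: applies f to each element of v
// and collects the results in a new vector.
template <typename T, typename F>
auto transform_sketch(std::vector<T> const &v, F const &f)
    -> std::vector<decltype(f(v.front()))> {
  std::vector<decltype(f(v.front()))> result;
  for (T const &x : v) {
    result.push_back(f(x));
  }
  return result;
}

// Assumed predicate: odd indices become true, even indices false.
std::vector<bool> alternating_flags(std::size_t n) {
  auto const is_odd = [](std::size_t idx) -> bool { return idx % 2 == 1; };
  return transform_sketch(range_sketch(n), is_odd);
}
```

The point of the pattern is that the index loop and the mutable output vector disappear, leaving a single expression that maps each index to its flag.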
lib/models/src/models/simvp/simvp.cc
line 43 at r1 (raw file):
SimVPConfig const &config,
tensor_guid_t const &input,
size_t in_dim,
Prefer int over size_t
lib/models/src/models/simvp/simvp.cc
line 68 at r1 (raw file):
SimVPConfig const &config, tensor_guid_t const &input) {
  size_t C = config.in_shape.at(1); // Channel
It seems like in_shape has a fixed number of dimensions, so it would probably be better to use four named fields (or even an additional struct) rather than a vector.
Code quote:
size_t C = config.in_shape.at(1); // Channel
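A sketch of what the named-field version might look like. The struct name, the field names, and the (T, C, H, W) ordering are assumptions for illustration; the quoted code only confirms that index 1 is the channel count and index 3 is the width.

```cpp
#include <cassert>
#include <vector>

// Hypothetical replacement for a 4-element in_shape vector.
// Field names and ordering are assumed, not taken from the PR.
struct SimVPInShape {
  int num_frames; // T (assumed to be index 0)
  int channels;   // C (index 1 in the quoted code)
  int height;     // H (assumed to be index 2)
  int width;      // W (index 3 in the quoted code)
};

// Converts the existing vector form into the named-field form,
// checking that exactly four dimensions are present.
SimVPInShape in_shape_from_vector(std::vector<int> const &v) {
  assert(v.size() == 4);
  return SimVPInShape{v.at(0), v.at(1), v.at(2), v.at(3)};
}
```

With named fields, a call site would read in_shape.channels rather than in_shape.at(1), so the trailing "// Channel" comment is no longer needed to explain the index.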
lib/models/src/models/simvp/simvp.cc
line 81 at r1 (raw file):
tensor_guid_t latent = enc1;
for (size_t i = 1; i < samplings.size(); i++) {
Suggestion:
for (bool sampling : subvec(samplings, 1, std::nullopt)) {
lib/models/src/models/simvp/simvp.cc
line 88 at r1 (raw file):
config.hid_S, config.spatio_kernel_enc, samplings[i],
Suggestion:
samplings.at(i)
lib/models/src/models/simvp/simvp.cc
line 105 at r1 (raw file):
float drop_path,
float init_value) {
  return input;
Seems to be missing an actual implementation?
lib/models/src/models/simvp/simvp.cc
line 134 at r1 (raw file):
float drop,
float drop_path) {
  if (config.model_type != "gSTA") {
Prefer a dtgen enum instead of a string for the model type.
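As a rough illustration of the idea in plain C++ (the real version would presumably be generated by dtgen, whose spec format is not shown here), the sketch below replaces the string comparison with an enum. The enum name, its single value, and parse_model_type are all hypothetical; only the "gSTA" string appears in the quoted code.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical enum for the model type. Only gSTA is listed because it
// is the only value visible in the quoted code.
enum class SimVPModelType { GSTA };

// Hypothetical parser for configs that still arrive as strings:
// returns std::nullopt for any unrecognized model type.
std::optional<SimVPModelType> parse_model_type(std::string const &s) {
  if (s == "gSTA") {
    return SimVPModelType::GSTA;
  }
  return std::nullopt;
}
```

An enum turns a misspelled model type into a compile-time error (or a single parse failure at the boundary) instead of a silent mismatch inside the graph-building code.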
lib/models/src/models/simvp/simvp.cc
line 147 at r1 (raw file):
// Downsample
z = create_simvp_gsta_meta_block(
    cgb, config, z, channel_in, channel_hid, mlp_ratio, drop, drop_path);
Add argument name comments for all invocations of these many-argument functions for readability
Suggestion:
z = create_simvp_gsta_meta_block(
/*cgb=*/cgb,
/*config=*/config,
/*input=*/z,
/*in_channels=*/channel_in,
/*out_channels=*/channel_hid,
/*mlp_ratio=*/mlp_ratio,
/*drop=*/drop,
/*drop_path=*/drop_path);
lib/models/src/models/simvp/simvp.cc
line 150 at r1 (raw file):
// Middle layers
for (size_t i = 1; i < config.N_T - 1; i++) {
Suggestion:
for (int i : range(1, config.N_T - 1)) {
lib/models/src/models/simvp/simvp.cc
line 153 at r1 (raw file):
z = create_simvp_gsta_meta_block(
    cgb, config, z, channel_hid, channel_hid, mlp_ratio, drop, drop_path);
}
Considering that this f(f(f(f(...f(x_0)...)))) pattern keeps showing up in model definitions, it might be nice to pull it out into a separate function in utils/containers named something like "primitive_recurse" or "recurse_n". Thoughts? I'm sure @Marsella8 would be happy to add such a function.
Code quote:
for (size_t i = 1; i < config.N_T - 1; i++) {
z = create_simvp_gsta_meta_block(
cgb, config, z, channel_hid, channel_hid, mlp_ratio, drop, drop_path);
}
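One possible shape for such a helper, as a sketch only. The name recurse_n is taken from the comment above, but the signature and argument order are assumptions; the function eventually added to utils/containers may differ.

```cpp
#include <cassert>

// Hypothetical helper: applies f to x exactly n times, i.e. computes
// f(f(...f(x)...)) with n nested applications. n == 0 returns x unchanged.
template <typename T, typename F>
T recurse_n(F const &f, int n, T x) {
  assert(n >= 0);
  for (int i = 0; i < n; i++) {
    x = f(x);
  }
  return x;
}
```

The quoted loop runs for i = 1 up to N_T - 2, so for N_T >= 2 it would correspond to n = N_T - 2 applications of a lambda that wraps create_simvp_gsta_meta_block with the fixed arguments captured.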
lib/models/src/models/simvp/simvp.cc
line 168 at r1 (raw file):
std::cout << "hid shape: " << cgb.get_shape(hid) << std::endl;
std::cout << "skip shape: " << cgb.get_shape(skip) << std::endl;
Remove prints.
Also, I assume that if you were looking at the shapes then you know what they should be, so add asserts on the tensor shapes where possible to improve readability and reduce the possibility of bugs.
Code quote:
std::cout << "hid shape: " << cgb.get_shape(hid) << std::endl;
std::cout << "skip shape: " << cgb.get_shape(skip) << std::endl;
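A sketch of how a debug print might be turned into a shape check. ShapeSketch and require_shape are hypothetical stand-ins: the real type returned by cgb.get_shape and the project's preferred assertion mechanism are not shown in this review, so a plain vector of dimensions and an exception are used instead.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for a tensor shape: just a list of dimension sizes.
using ShapeSketch = std::vector<int>;

// Hypothetical check: throws if the actual shape differs from the
// expected one, naming the tensor so the failure is easy to locate.
void require_shape(std::string const &name,
                   ShapeSketch const &actual,
                   ShapeSketch const &expected) {
  if (actual != expected) {
    std::ostringstream oss;
    oss << name << ": unexpected tensor shape (got rank " << actual.size()
        << ", expected rank " << expected.size() << ")";
    throw std::runtime_error(oss.str());
  }
}
```

Unlike a print, a check of this kind documents the expected shape at the point of use and fails loudly if a later change breaks it.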
lib/models/src/models/simvp/simvp.cc
line 174 at r1 (raw file):
tensor_guid_t latent = hid;
for (size_t i = 0; i < samplings.size() - 1; i++) {
Suggestion:
for (bool sampling : subvec(samplings, 1, std::nullopt)) {
lib/models/src/models/simvp/simvp.cc
line 192 at r1 (raw file):
config.spatio_kernel_dec, false, samplings[samplings.size() - 1]);
Suggestion:
samplings.back());
lib/models/src/models/simvp/simvp.cc
line 207 at r1 (raw file):
size_t W = config.in_shape.at(3); // Width
// std::cout << "B T C H W: " << B << " " << T << " " << C << " " << H << " "
Remove prints
lib/models/src/models/simvp/simvp.cc
line 219 at r1 (raw file):
auto [embed, skip] = create_simvp_encoder(cgb, config, input);
// std::cout << "embed shape: " << cgb.get_shape(embed) << std::endl;
Remove prints, add assertions on shape
lib/models/src/models/simvp/simvp.cc
line 232 at r1 (raw file):
config.drop_path);
// TODO: need to reshape hid here
What's the plan for all of these TODOs? Are you waiting on something else to be implemented?
lib/models/include/models/simvp/simvp_config.struct.toml
line 28 at r1 (raw file):
[[fields]]
name = "batch_size"
type = "size_t"
Prefer int over size_t
lib/models/include/models/simvp/simvp_config.struct.toml
line 43 at r1 (raw file):
[[fields]]
name = "N_T"
Let's keep variable names lower-case
Suggestion:
name = "n_t"
lib/models/include/models/simvp/simvp.h
line 18 at r1 (raw file):
std::vector<bool> create_simvp_samplings(size_t N_S, bool reverse = false);
tensor_guid_t create_simvp_convsc(ComputationGraphBuilder &cgb,
It would be nice to have links in the docstrings here to the equivalent code in OpenSTL for each of these functions.
Reviewable status: all files reviewed, 21 unresolved discussions (waiting on @easyeasydev, @lockshaw, and @reyna-abhyankar)
lib/models/src/models/simvp/simvp.cc
line 153 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Considering that this f(f(f(f(...f(x_0)...)))) pattern keeps showing up in model definitions, it might be nice to pull it out into a separate function in utils/containers named something like "primitive_recurse" or "recurse_n" or something. Thoughts? I'm sure @Marsella8 would be happy to add such a function
See #1563
Description of changes:
This PR adds the computational graph of the SimVP model.