Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SimVP model computational graph #1549

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/models/include/models/candle_uno/candle_uno.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ CandleUnoConfig get_default_candle_uno_config();
* this model.
*
* @param CandleUnoConfig The config of the Candle Uno model.
* @return ComputationGraph The PCG of a Transformer model.
* @return ComputationGraph The computation graph of a Candle Uno model.
*/
ComputationGraph get_candle_uno_computation_graph(CandleUnoConfig const &);

Expand Down
75 changes: 75 additions & 0 deletions lib/models/include/models/simvp/simvp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SIMVP_H
#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SIMVP_H

#include "pcg/computation_graph_builder.h"
#include "simvp_config.dtg.h"

namespace FlexFlow {

// Helper functions to construct the SimVP model

/**
* @brief Get the default configs of SimVP model.
*/
SimVPConfig get_default_simvp_config();

std::vector<bool> create_simvp_samplings(size_t N_S, bool reverse = false);

tensor_guid_t create_simvp_convsc(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
size_t in_dim,
size_t out_dim,
int kernel_size = 3,
bool downsampling = false,
bool upsampling = false);

tensor_guid_t create_simvp_gsta_meta_block(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
int in_channels,
int out_channels,
float mlp_ratio = 8.0,
float drop = 0.0,
float drop_path = 0.0);

tensor_guid_t create_simvp_ga_sub_block(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
int dim,
int kernel_size = 21,
float mlp_ratio = 4.0,
float drop = 0.0,
float drop_path = 0.1,
float init_value = 1e-2);

std::pair<tensor_guid_t, tensor_guid_t>
create_simvp_encoder(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input);

tensor_guid_t create_simvp_middle_net(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &embed,
int channel_in,
int channel_hid,
float mlp_ratio = 4.0,
float drop = 0.0,
float drop_path = 0.1);

tensor_guid_t create_simvp_decoder(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &hid,
tensor_guid_t const &skip);

/**
* @brief Get the SimVP computation graph.
*
* @param SimVPConfig The config of the SimVP model.
* @return ComputationGraph The computation graph of a SimVP model.
*/
ComputationGraph get_simvp_computation_graph(SimVPConfig const &config);

} // namespace FlexFlow

#endif
72 changes: 72 additions & 0 deletions lib/models/include/models/simvp/simvp_config.struct.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
namespace = "FlexFlow"
name = "SimVPConfig"

features = [
"eq",
"ord",
"hash",
"json",
"rapidcheck",
"fmt",
]

includes = [
"<vector>",
"<map>",
"<string>",
]

src_includes = [
"utils/fmt/vector.h",
"utils/fmt/map.h",
"utils/hash/vector.h",
"utils/hash/map.h",
]

[[fields]]
name = "batch_size"
type = "size_t"

[[fields]]
name = "hid_S"
type = "size_t"

[[fields]]
name = "hid_T"
type = "size_t"

[[fields]]
name = "N_S"
type = "size_t"

[[fields]]
name = "N_T"
type = "size_t"

[[fields]]
name = "model_type"
type = "std::string"

[[fields]]
name = "mlp_ratio"
type = "float"

[[fields]]
name = "drop"
type = "float"

[[fields]]
name = "drop_path"
type = "float"

[[fields]]
name = "spatio_kernel_enc"
type = "size_t"

[[fields]]
name = "spatio_kernel_dec"
type = "size_t"

[[fields]]
name = "in_shape"
type = "std::vector<size_t>"
Loading
Loading