Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Nov 9, 2024
1 parent 242fdca commit 99ecd8f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion lib/models/src/models/dlrm/dlrm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
tensor_guid_t dense_input = create_input_tensor(
{config.batch_size, config.dense_arch_layer_sizes.front()},
DataType::FLOAT); // TODO: change this to DataType::FLOAT
// after cgb.cast is implemented.
// after cgb.cast is implemented.

// Construct the model
tensor_guid_t bottom_mlp_output = create_dlrm_mlp(
Expand Down
9 changes: 4 additions & 5 deletions lib/pcg/src/pcg/computation_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ static std::string get_default_name(ComputationGraphOpAttrs const &attrs) {
return get_default_name(get_op_type(attrs));
}


ComputationGraphBuilder::ComputationGraphBuilder()
: computation_graph(make_empty_computation_graph()) {}

Expand Down Expand Up @@ -170,10 +169,10 @@ tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input,
this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)}));
}

tensor_guid_t
ComputationGraphBuilder::cast(tensor_guid_t const &input,
DataType dtype,
std::optional<std::string> const &maybe_name) {
tensor_guid_t ComputationGraphBuilder::cast(
tensor_guid_t const &input,
DataType dtype,
std::optional<std::string> const &maybe_name) {

CastAttrs attrs = CastAttrs{dtype};

Expand Down

0 comments on commit 99ecd8f

Please sign in to comment.