Commit

refactor
cowanmeg committed Dec 13, 2024
1 parent eb11636 commit bf3de16
Showing 2 changed files with 0 additions and 236 deletions.
231 changes: 0 additions & 231 deletions benchmarks/cpp/transformer.cpp

This file was deleted.

5 changes: 0 additions & 5 deletions tests/cpp/multidevice_transformer.cpp
@@ -111,7 +111,6 @@ MlpResult DistributedTransformer::mlp(
x->axis(0)->parallelize(ParallelType::Serial);
// Reshape back to 2D. This is unnecessary except to keep
// the shapes of linear0 the same for TP and TP+SP.
auto D = w0->axis(0)->extent()->value().as<int64_t>();
x = reshape(x, {D, B * S / D, E}, {B * S, E});
}
// Linear 0
@@ -124,7 +123,6 @@ MlpResult DistributedTransformer::mlp(
if (sequence_parallel) {
// Remove after https://github.com/NVIDIA/Fuser/issues/2563
// Reshape to explicitly pull the sharded axis into the logical domain
auto D = w0->axis(0)->extent()->value().as<int64_t>();
local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E});
}
TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter
@@ -162,7 +160,6 @@ MhaResult DistributedTransformer::mha(
TensorView* b1,
const DeviceMesh& mesh,
bool sequence_parallel) {
const auto D = w0->axis(0)->extent()->value().as<int64_t>();
auto dtype = w0->dtype();

if (sequence_parallel) {
@@ -202,7 +199,6 @@ MhaResult DistributedTransformer::mha(
if (sequence_parallel) {
// Remove after https://github.com/NVIDIA/Fuser/issues/2563
// Reshape to explicitly pull the sharded axis into the logical domain
auto D = w0->axis(0)->extent()->value().as<int64_t>();
local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E});
}
TensorView* matmul1 = sum(local_matmul1, {0}); // allreduce
@@ -301,7 +297,6 @@ std::vector<TensorView*> DistributedTransformer::mha_backwards(
TensorView* linear0,
const DeviceMesh& mesh) {
DataType dtype = w0->dtype();
const auto D = w0->axis(0)->extent()->value().as<int64_t>();
// Reform qkv from linear0 output
TensorView* qkv_cat = reshape(
castOp(DataType::Float, linear0),

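Note for readers skimming the diff: all five deleted lines in tests/cpp/multidevice_transformer.cpp follow the same pattern, re-deriving the device-mesh size D inside each routine from the sharded axis of the first weight tensor. Below is a minimal sketch of that pattern, assembled only from calls that appear in this diff; the surrounding variables B, S, and E are assumed to be the usual batch, sequence, and embedding sizes, and the sketch is illustrative rather than the final state of the code after this commit.

    // Pattern removed by this commit: each routine recomputed the device-mesh
    // size D from the device-parallel axis 0 of the sharded weight w0.
    const auto D = w0->axis(0)->extent()->value().as<int64_t>();

    // D was then used to move the sharded device dimension in and out of the
    // logical domain, e.g. when sequence parallelism is enabled:
    x = reshape(x, {D, B * S / D, E}, {B * S, E});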