Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Dec 30, 2024
1 parent 0315160 commit 855a7d5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,17 @@ TEST_SUITE(FF_TEST_SUITE) {
{binary_tree_root_path(), mv2},
}};

printf("Before constructing cost_estimator\n");

auto map1 = std::unordered_map<OpCostEstimateKey, OpCostMetrics>{{
{map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics(1.0, 0)},
{map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics(2.0, 0)},
{map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics(1.5, 0)},
{map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics(2.5, 0)},
{map_unmapped_op_cost_estimate_key(k1, mv1),
OpCostMetrics{/*runtime=*/1.0, /*memory=*/0}},
{map_unmapped_op_cost_estimate_key(k2, mv1),
OpCostMetrics{/*runtime=*/2.0, /*memory=*/0}},
{map_unmapped_op_cost_estimate_key(k1, mv2),
OpCostMetrics{/*runtime=*/1.5, /*memory=*/0}},
{map_unmapped_op_cost_estimate_key(k2, mv2),
OpCostMetrics{/*runtime=*/2.5, /*memory=*/0}},
}};

printf("After constructing map1\n");

CostEstimator cost_estimator = make_fake_cost_estimator(
map1,
std::unordered_map<TensorSetMovement, float>{{
Expand All @@ -169,36 +169,11 @@ TEST_SUITE(FF_TEST_SUITE) {
0.4},
}});

// CostEstimator cost_estimator = make_fake_cost_estimator(
// std::unordered_map<OpCostEstimateKey, OpCostMetrics>{{
// {map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics(1.0,
// 0)}, {map_unmapped_op_cost_estimate_key(k2, mv1),
// OpCostMetrics(2.0, 0)}, {map_unmapped_op_cost_estimate_key(k1,
// mv2), OpCostMetrics(1.5, 0)},
// {map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics(2.5,
// 0)},
// }},
// std::unordered_map<TensorSetMovement, float>{{
// {TensorSetMovement{{}}, 0.0},
// {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1),
// 0.1},
// {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2),
// 0.2},
// {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2),
// 0.3},
// {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1),
// 0.4},
// }});

printf("After constructing cost_estimator\n");

MachineMappingContext context = MachineMappingContext{
cost_estimator,
allowed_machine_views1,
};

printf("After constructing context\n");

MachineMappingCache cache = empty_machine_mapping_cache();

SUBCASE("single layer") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ TEST_SUITE(FF_TEST_SUITE) {
cache, context, problem_tree, full_machine_spec, constraints);
MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{
SingleMachineMapping{
OpCostMetrics{1.0 + 2.0 + 0.1, 2 + 3},
OpCostMetrics{
/*runtime=*/1.0 + 2.0 + 0.1,
/*memory=*/2 + 3,
},
ParallelLayerGuidObliviousMachineMapping{{
{
BinaryTreePath{{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ TEST_SUITE(FF_TEST_SUITE) {
};

OpCostMetrics cost1 = OpCostMetrics{
2.0,
2,
/*runtime=*/2.0,
/*memory=*/2,
};
OpCostMetrics cost2 = OpCostMetrics{
4.0,
1,
/*runtime=*/4.0,
/*memory=*/1,
};
OpCostMetrics cost3 = OpCostMetrics{
2.0,
3,
/*runtime=*/2.0,
/*memory=*/3,
};

SingleMachineMapping mm1 = SingleMachineMapping{
Expand Down Expand Up @@ -101,40 +101,42 @@ TEST_SUITE(FF_TEST_SUITE) {
};

SUBCASE("empty") {
MachineMappingWithMemoryResult to_remove =
MachineMappingWithMemoryResult before_remove =
empty_machine_mapping_with_memory_result();
MachineMappingWithMemoryResult result =
remove_non_pareto_optimal_machine_mapping_result(to_remove);
remove_non_pareto_optimal_machine_mapping_result(before_remove);
MachineMappingWithMemoryResult correct =
empty_machine_mapping_with_memory_result();

CHECK(result == correct);
}

SUBCASE("no non-pareto_optimal") {
MachineMappingWithMemoryResult to_remove = MachineMappingWithMemoryResult{
{
mm1,
mm2,
},
};
SUBCASE("all solutions are pareto-optimal") {
MachineMappingWithMemoryResult before_remove =
MachineMappingWithMemoryResult{
{
mm1,
mm2,
},
};
MachineMappingWithMemoryResult result =
remove_non_pareto_optimal_machine_mapping_result(to_remove);
MachineMappingWithMemoryResult correct = to_remove;
remove_non_pareto_optimal_machine_mapping_result(before_remove);
MachineMappingWithMemoryResult correct = before_remove;

CHECK(result == correct);
}

SUBCASE("non-pareto_optimal") {
MachineMappingWithMemoryResult to_remove = MachineMappingWithMemoryResult{
{
mm1,
mm2,
mm3,
},
};
SUBCASE("there exists a non-pareto-optimal solution") {
MachineMappingWithMemoryResult before_remove =
MachineMappingWithMemoryResult{
{
mm1,
mm2,
mm3,
},
};
MachineMappingWithMemoryResult result =
remove_non_pareto_optimal_machine_mapping_result(to_remove);
remove_non_pareto_optimal_machine_mapping_result(before_remove);
MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{
{
mm1,
Expand All @@ -146,7 +148,9 @@ TEST_SUITE(FF_TEST_SUITE) {
}
}

TEST_CASE("series_combine(memory)") {
TEST_CASE("series_combine(float, MachineMappingWithMemoryResult const &, "
"MachineMappingWithMemoryResult const &, "
"std::optional<ParallelSplitTransformation> const&)") {
MachineView machine_view_0 = MachineView{
/*start=*/MachineSpaceCoordinate{
/*node_idx=*/0,
Expand Down Expand Up @@ -178,8 +182,8 @@ TEST_SUITE(FF_TEST_SUITE) {
};

OpCostMetrics pre_cost = OpCostMetrics{
2.0,
2,
/*runtime=*/2.0,
/*memory=*/2,
};
MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{
SingleMachineMapping{
Expand All @@ -204,8 +208,8 @@ TEST_SUITE(FF_TEST_SUITE) {
}};

OpCostMetrics post_cost = OpCostMetrics{
4.0,
1,
/*runtime=*/4.0,
/*memory=*/1,
};

MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{
Expand Down Expand Up @@ -249,8 +253,9 @@ TEST_SUITE(FF_TEST_SUITE) {
{
SingleMachineMapping{
/*cost=*/OpCostMetrics{
pre_cost.runtime + comm_cost + post_cost.runtime,
pre_cost.memory + post_cost.memory,
/*runtime=*/pre_cost.runtime + comm_cost +
post_cost.runtime,
/*memory=*/pre_cost.memory + post_cost.memory,
},
/*machine_mapping=*/
ParallelLayerGuidObliviousMachineMapping{{
Expand Down Expand Up @@ -302,8 +307,9 @@ TEST_SUITE(FF_TEST_SUITE) {
{
SingleMachineMapping{
/*cost=*/OpCostMetrics{
pre_cost.runtime + comm_cost + post_cost.runtime,
pre_cost.memory + post_cost.memory,
/*runtime=*/pre_cost.runtime + comm_cost +
post_cost.runtime,
/*memory=*/pre_cost.memory + post_cost.memory,
},
/*machine_mapping=*/
ParallelLayerGuidObliviousMachineMapping{{
Expand Down Expand Up @@ -337,7 +343,9 @@ TEST_SUITE(FF_TEST_SUITE) {
}
}

TEST_CASE("parallel_combine(memory)") {
TEST_CASE("parallel_combine(float, MachineMappingWithMemoryResult const &, "
"MachineMappingWithMemoryResult const &, "
"std::optional<ParallelSplitTransformation> const&)") {
MachineView machine_view_0 = MachineView{
/*start=*/MachineSpaceCoordinate{
/*node_idx=*/0,
Expand Down Expand Up @@ -369,8 +377,8 @@ TEST_SUITE(FF_TEST_SUITE) {
};

OpCostMetrics lhs_cost = OpCostMetrics{
2.0,
2,
/*runtime=*/2.0,
/*memory=*/2,
};
MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{
SingleMachineMapping{
Expand All @@ -395,8 +403,8 @@ TEST_SUITE(FF_TEST_SUITE) {
}};

OpCostMetrics rhs_cost = OpCostMetrics{
4.0,
1,
/*runtime=*/4.0,
/*memory=*/1,
};
MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{
SingleMachineMapping{
Expand Down Expand Up @@ -434,8 +442,8 @@ TEST_SUITE(FF_TEST_SUITE) {
MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{
SingleMachineMapping{
/*cost=*/OpCostMetrics{
std::max(lhs_cost.runtime, rhs_cost.runtime),
std::max(lhs_cost.memory, rhs_cost.memory),
/*runtime=*/std::max(lhs_cost.runtime, rhs_cost.runtime),
/*memory=*/std::max(lhs_cost.memory, rhs_cost.memory),
},
/*machine_mapping=*/
ParallelLayerGuidObliviousMachineMapping{
Expand Down Expand Up @@ -510,16 +518,16 @@ TEST_SUITE(FF_TEST_SUITE) {
};

OpCostMetrics cost1 = OpCostMetrics{
2.0,
2,
/*runtime=*/2.0,
/*memory=*/2,
};
OpCostMetrics cost2 = OpCostMetrics{
4.0,
1,
/*runtime=*/4.0,
/*memory=*/1,
};
OpCostMetrics cost3 = OpCostMetrics{
2.0,
3,
/*runtime=*/2.0,
/*memory=*/3,
};

SingleMachineMapping mm1 = SingleMachineMapping{
Expand Down

0 comments on commit 855a7d5

Please sign in to comment.