This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
remat.patch
2094 lines (2013 loc) · 85.2 KB
/
remat.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index 93ae3d2..39a6266 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -59,6 +59,16 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_allow_excess_precision(true);
opts.set_xla_force_host_platform_device_count(1);
+
+ opts.set_xla_use_hlo_rematerialization(false);
+ opts.set_xla_rematerialization_mem_limit("0");
+ opts.set_xla_rematerialization_scheduler("default");
+ opts.set_xla_rematerialization_algorithm("standard");
+ opts.set_xla_rematerialization_small_node_limit(1);
+ opts.set_xla_rematerialization_disable_cuda(false);
+ opts.set_xla_rematerialization_dump_dot(false);
+ opts.set_xla_rematerialization_dump_memlog(false);
+
return opts;
}
@@ -440,6 +450,54 @@ static void AllocateFlags() {
"--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."),
tensorflow::Flag(
+ "xla_use_hlo_rematerialization",
+ bool_setter_for(&DebugOptions::set_xla_use_hlo_rematerialization),
+ flag_values->xla_use_hlo_rematerialization(),
+ "Enables HLO rematerialization heuristic which tries either to reduce"
+ " memory consunpution as much as possible or until below a limit "
+ "setted by --xla_rematerialization_mem_limit"),
+ tensorflow::Flag(
+ "xla_rematerialization_mem_limit",
+ string_setter_for(&DebugOptions::set_xla_rematerialization_mem_limit),
+ flag_values->xla_rematerialization_mem_limit(),
+ "Sets a memory limit goal (in bytes) to the HLO rematerialization "
+ "heuristic."),
+ tensorflow::Flag(
+ "xla_rematerialization_scheduler",
+ string_setter_for(&DebugOptions::set_xla_rematerialization_scheduler),
+ flag_values->xla_rematerialization_scheduler(),
+ "Sets the scheduler to be used just before rematerialization."
+ " Options are: default, postorder, DFS, and list."),
+ tensorflow::Flag(
+ "xla_rematerialization_algorithm",
+ string_setter_for(&DebugOptions::set_xla_rematerialization_algorithm),
+ flag_values->xla_rematerialization_algorithm(),
+ "Sets the rematerialization or compression technique to be used."
+ " Options are: standard, compress, standardcompress, and path."),
+ tensorflow::Flag(
+ "xla_rematerialization_small_node_limit",
+ int32_setter_for(&DebugOptions::set_xla_rematerialization_small_node_limit),
+ flag_values->xla_rematerialization_small_node_limit(),
+ "Sets the minimum size (in MiB) that a candidate to rematerialization"
+ " needs to have."),
+ tensorflow::Flag(
+ "xla_rematerialization_disable_cuda",
+ bool_setter_for(&DebugOptions::set_xla_rematerialization_disable_cuda),
+ flag_values->xla_rematerialization_disable_cuda(),
+ "Disable cuda picking fusion optimization (this can improve remat)."),
+ tensorflow::Flag(
+ "xla_rematerialization_dump_dot",
+ bool_setter_for(&DebugOptions::set_xla_rematerialization_dump_dot),
+ flag_values->xla_rematerialization_dump_dot(),
+ "Dump dot representation of the HLO graph."),
+ tensorflow::Flag(
+ "xla_rematerialization_dump_memlog",
+ bool_setter_for(&DebugOptions::set_xla_rematerialization_dump_memlog),
+ flag_values->xla_rematerialization_dump_memlog(),
+ "Dump mem log about memory usage after the rematerialization."),
+
+
+ tensorflow::Flag(
"xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
flag_values->xla_dump_to(),
"Directory into which debugging data is written. If not specified "
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 581d358..bab48f1 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2957,6 +2957,7 @@ cc_library(
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_cost_analysis",
":hlo_memory_scheduler",
":hlo_ordering",
":logical_buffer",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 7d65624..9e8db31 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -130,6 +130,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
+ "//tensorflow/compiler/xla/service:hlo_rematerialization",
"//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index acafa2c..85d388f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -601,6 +602,68 @@ struct OrcJITPostCompilationHook {
} // namespace
+// Return the byte size of the top-level buffer of the given shape.
+static int64 ByteSizeOf(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
+ Shape result = shape;
+ Layout layout = result.layout();
+ int64 most_minor_index = layout.minor_to_major()[0];
+ int64 second_minor_index = layout.minor_to_major()[1];
+ int64 most_minor = result.dimensions(most_minor_index);
+ int64 second_minor = result.dimensions(second_minor_index);
+ if (most_minor < second_minor) {
+ result.set_dimensions(most_minor_index, second_minor);
+ result.set_dimensions(second_minor_index, most_minor);
+ }
+ return result;
+}
+
+StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+ HloModule* module) {
+
+ auto sch = DefaultMemoryScheduler;
+ string scheduler_option =
+ module->config().debug_options().xla_rematerialization_scheduler();
+
+ if (scheduler_option == "postorder") {
+ sch = PostOrderMemoryScheduler;
+ } else if (scheduler_option == "DFS") {
+ sch = DFSMemoryScheduler;
+ } else if (scheduler_option == "list") {
+ sch = ListMemoryScheduler;
+ }
+
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ ComputationSchedulerToModuleScheduler(
+ sch
+ ));
+
+ TF_RETURN_IF_ERROR(scheduler.Run(module).status());
+
+ RematerializationAlg alg = kStandardAlg;
+ string algorithm_option =
+ module->config().debug_options().xla_rematerialization_algorithm();
+
+ if (algorithm_option == "compress") {
+ alg = kCompressAlg;
+ } else if (algorithm_option == "standardcompress") {
+ alg = kStandardAndCompressAlg;
+ } else if (algorithm_option == "path") {
+ alg = kPathAlg;
+ }
+
+ DumpHloModuleIfEnabled(*module, "before_remat");
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr, ChooseCompactLayoutForShape);
+ remat.setAlgorithm(alg);
+ return remat.Run(module);
+}
+
+
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* /*device_allocator*/) {
@@ -613,6 +676,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::call_once(llvm_command_line_options_initialized,
&llvm_ir::InitializeLLVMCommandLineOptions, module->config());
+ // Rematerialization needs to be apply after all optimizations
+ if (module->config().debug_options().xla_use_hlo_rematerialization()) {
+ string mem_limit_s =
+ module->config().debug_options().xla_rematerialization_mem_limit();
+
+ LOG(WARNING) << "Starting rematerialization of "<<
+ module->name() << " with " << mem_limit_s << " bytes as mem limit";
+
+ int64_t mem_limit_u = std::stoull(mem_limit_s);
+
+ StatusOr<bool> remat_result =
+ RunHloRematerialization(mem_limit_u, module.get());
+
+ TF_RETURN_IF_ERROR(remat_result.status());
+ }
+
ModuleHook pre_optimization_ir_hook;
ModuleHook post_optimization_ir_hook;
std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 866df46..59f0fe0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1008,6 +1008,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
+ "//tensorflow/compiler/xla/service:hlo_rematerialization",
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 9dda327..f5ca615 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -78,6 +78,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -259,6 +260,67 @@ bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) {
} // namespace
+// Return the byte size of the top-level buffer of the given shape.
+static int64 ByteSizeOf(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
+ Shape result = shape;
+ Layout layout = result.layout();
+ int64 most_minor_index = layout.minor_to_major()[0];
+ int64 second_minor_index = layout.minor_to_major()[1];
+ int64 most_minor = result.dimensions(most_minor_index);
+ int64 second_minor = result.dimensions(second_minor_index);
+ if (most_minor < second_minor) {
+ result.set_dimensions(most_minor_index, second_minor);
+ result.set_dimensions(second_minor_index, most_minor);
+ }
+ return result;
+}
+
+StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+ HloModule* module) {
+
+ auto sch = DefaultMemoryScheduler;
+ string scheduler_option =
+ module->config().debug_options().xla_rematerialization_scheduler();
+
+ if (scheduler_option == "postorder") {
+ sch = PostOrderMemoryScheduler;
+ } else if (scheduler_option == "DFS") {
+ sch = DFSMemoryScheduler;
+ } else if (scheduler_option == "list") {
+ sch = ListMemoryScheduler;
+ }
+
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ ComputationSchedulerToModuleScheduler(
+ sch
+ ));
+
+ TF_RETURN_IF_ERROR(scheduler.Run(module).status());
+
+ RematerializationAlg alg = kStandardAlg;
+ string algorithm_option =
+ module->config().debug_options().xla_rematerialization_algorithm();
+
+ if (algorithm_option == "compress") {
+ alg = kCompressAlg;
+ } else if (algorithm_option == "standardcompress") {
+ alg = kStandardAndCompressAlg;
+ } else if (algorithm_option == "path") {
+ alg = kPathAlg;
+ }
+
+ DumpHloModuleIfEnabled(*module, "before_remat");
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr, ChooseCompactLayoutForShape);
+ remat.setAlgorithm(alg);
+ return remat.Run(module);
+}
+
// Runs optimization passes on the given HLO module.
Status impl::OptimizeHloModule(HloModule* hlo_module,
se::StreamExecutor* stream_exec,
@@ -379,6 +441,7 @@ Status impl::OptimizeHloModule(HloModule* hlo_module,
// tuple/get-tuple-element pairs that TupleSimplifier fixes.
pipeline.AddPass<TupleSimplifier>();
}
+
// CudnnConvRewriter, CudnnConvPaddingLegalization and
// CudnnConvPadForTensorCores may add instructions which can be simplified
// by constant folding.
@@ -399,7 +462,8 @@ Status impl::OptimizeHloModule(HloModule* hlo_module,
LayoutAssignment::InstructionCanChangeLayout, stream_exec);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
-
+
+ if (!hlo_module->config().debug_options().xla_rematerialization_disable_cuda())
{
HloPassPipeline pipeline("post-layout_assignment");
/* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
@@ -558,6 +622,22 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
TF_RET_CHECK(stream_exec != nullptr);
+ // Rematerialization needs to be apply after all optimizations
+ if (module->config().debug_options().xla_use_hlo_rematerialization()) {
+ string mem_limit_s =
+ module->config().debug_options().xla_rematerialization_mem_limit();
+
+ LOG(WARNING) << "Starting rematerialization of "<<
+ module->name() << " with " << mem_limit_s << " bytes as mem limit";
+
+ int64_t mem_limit_u = std::stoull(mem_limit_s);
+
+ StatusOr<bool> remat_result =
+ RunHloRematerialization(mem_limit_u, module.get());
+
+ TF_RETURN_IF_ERROR(remat_result.status());
+ }
+
llvm::LLVMContext llvm_context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
@@ -605,7 +685,6 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
&ir_emitter_context);
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
-
{
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission");
TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 603371d..e3f1386 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
+#include <fstream>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -66,14 +67,24 @@ bool IsRematerializable(const HloInstruction* instruction) {
}
}
+ // Don`t rematerialize instructions that are smaller than 1 MB. This improves
+ // rematerialization stability over different mem_limits budgets.
+ int small_node_limit = instruction->parent()->parent()
+ ->config().debug_options().xla_rematerialization_small_node_limit();
+ if (small_node_limit !=0 &&
+ ShapeUtil::ByteSizeOf(instruction->shape(),sizeof(void*))
+ <= small_node_limit*1024*1024) {
+ return false;
+ }
+
// Don't rematerialize instructions with side effects or instructions which
// cannot be cloned safely.
switch (instruction->opcode()) {
case HloOpcode::kCall:
+ case HloOpcode::kCustomCall:
case HloOpcode::kConstant:
case HloOpcode::kConditional:
case HloOpcode::kAllReduce:
- case HloOpcode::kCustomCall:
case HloOpcode::kParameter:
case HloOpcode::kWhile:
return false;
@@ -100,6 +111,17 @@ bool CanBeRematerialized(
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;
+struct RematStrategy {
+ enum {
+ // Recompute the node at a later program point.
+ kRecompute,
+ // Change the layout into a compact form and uncompress it back at a later
+ // program point.
+ kCompress,
+ } kind;
+ Shape compact_shape;
+};
+
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
@@ -117,6 +139,10 @@ struct Item {
// The buffers defined by this instruction.
BufferIdList buffers_defined;
+ // Output buffers of this instruction. This is used to track outputs by GTE
+ // instructions (where the instruction doesn't define a buffer).
+ BufferIdList buffers_output;
+
// The buffers used by this instruction.
BufferIdList buffers_used;
@@ -251,6 +277,32 @@ class InstructionList {
return InsertBefore(to_insert, min_position_item);
}
+ void InsertAfterInstructions(Item* to_insert,
+ absl::Span<Item* const> after_instructions) {
+ VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
+ << " after {"
+ << absl::StrJoin(after_instructions, ", ",
+ [](string* out, Item* item) {
+ absl::StrAppend(out, item->instruction->name());
+ })
+ << "}";
+
+ // Find the max position number of any instruction in
+ // 'after_instructions'.
+ CHECK(!after_instructions.empty());
+ Item* max_position_item = nullptr;
+ for (Item* item : after_instructions) {
+ if (max_position_item == nullptr ||
+ item->position > max_position_item->position) {
+ max_position_item = item;
+ }
+ }
+ // No rematerializable instruction should be inserted at the end of the
+ // computation.
+ CHECK(max_position_item->next != nullptr);
+ InsertBeforeInstructions(to_insert, {max_position_item->next});
+ }
+
void Blacklist(const HloInstruction* inst) {
GetItem(inst)->blacklisted = true;
}
@@ -327,6 +379,7 @@ class MemoryUsageTracker {
MemoryUsageTracker(
const HloComputation* computation,
const HloRematerialization::ShapeSizeFunction& size_function,
+ const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list);
@@ -338,6 +391,22 @@ class MemoryUsageTracker {
// EndInstruction memory for dead operand(s) is freed.
Status BeginInstruction(Item* item);
+ int64 RematerializationCost(const HloInstruction* instruction,
+ int64 memory_reduced, int64 memory_limit_bytes) {
+ // If none of the users of 'instruction' have been placed in the sequence
+ // (as tracked by memory_tracker), then rematerialization of 'instruction'
+ // is a zero-cost move of 'instruction' in the sequence.
+ if (!absl::c_any_of(
+ instruction->users(),
+ [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
+ return 0;
+ }
+
+ CHECK_GT(memory_reduced, 0);
+ // Return the inverse of the benefit of rematerialization.
+ return memory_limit_bytes / memory_reduced;
+ }
+
// Finishes the placement of the current instruction. This frees any dead
// operands or dead result of the instruction. This must be called after
// each call to BeginInstruction.
@@ -348,16 +417,28 @@ class MemoryUsageTracker {
int64 MemoryReducedIfRematerialized(Item* item) const;
// Returns the number of bytes that the current memory usage will be reduced
+ // if the given instruction is compact.
+ int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
+
+ // Returns the number of bytes that the current memory usage will be reduced
// by if the given sequence of instructions is rematerialized.
int64 MemoryReducedIfRematerialized(const absl::Span<Item*>& items) const;
+ Status AddCompressInstructions(Item* original_item, Item* compressed_item,
+ Item* uncompressed_item);
+
// Adjusts memory usage to account for the rematerialization of
// original_item for all remaining unplaced uses. The rematerialization
// is remat_item. This method should be called after the HLO graph has
- // been transformed (rematerialization instruction created and connected to
- // uses).
+ // been transformed (rematerialization instruction created and connected
+ // to uses).
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
+ std::pair<Item*, RematStrategy> PickRematerializationCandidate(
+ const RematerializationAlg,
+ const InstructionList& instruction_list, int64 memory_limit_bytes,
+ absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
+
// Returns whether the given instruction has been placed (BeginInstruction
// has been called with 'instruction' as the argument).
bool IsPlaced(const HloInstruction* instruction) const {
@@ -390,6 +471,9 @@ class MemoryUsageTracker {
// The materialized size of the buffer in bytes.
const int64 size;
+ // Shape of the buffer.
+ Shape shape;
+
// Whether this buffer is live-out of the computation.
bool live_out;
@@ -412,19 +496,21 @@ class MemoryUsageTracker {
}
};
+ // Get the compact shape of given hlo instruction. An internal cache is used
+ // to avoid computing the shape multiple times.
+ StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
+
// Creates a Buffer representing the given logical buffer. The buffer is added
// to buffers_ and a reference is returned.
Buffer& CreateBufferFromLogicalBuffer(
const LogicalBuffer* logical_buffer,
- const TuplePointsToAnalysis& points_to_analysis,
- const HloRematerialization::ShapeSizeFunction& size_function,
- bool live_out) {
+ const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
bool has_indirect_uses = false;
ItemList users = GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &has_indirect_uses);
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
- size_function(logical_buffer->shape()), std::move(users),
- live_out, has_indirect_uses);
+ logical_buffer->shape(), std::move(users), live_out,
+ has_indirect_uses);
}
// Create a new buffer representing a rematerialization of given buffer for
@@ -438,7 +524,7 @@ class MemoryUsageTracker {
for (Item* use : rematerialized_uses) {
CHECK(!use->placed) << use->instruction->name();
}
- return NewBuffer(remat_item, original_buffer.size,
+ return NewBuffer(remat_item, original_buffer.shape,
std::move(rematerialized_uses), /*live_out=*/false,
/*has_indirect_uses=*/false);
}
@@ -449,7 +535,8 @@ class MemoryUsageTracker {
// different computation.
int64 AllocatedSize(BufferId buffer_id) const {
const Buffer& buffer = buffers_.at(buffer_id);
- HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
+ HloInstruction* inst = buffer.defining_instruction->instruction;
+ HloOpcode def_opcode = inst->opcode();
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
return 0;
} else {
@@ -473,7 +560,7 @@ class MemoryUsageTracker {
return absl::c_linear_search(in_progress_uses, buffer_id);
}
- // Returns whether the given instruction is live at the current program
+ // Returns whether the given buffer is live at the current program
// point.
bool IsCurrentlyLive(BufferId buffer_id) const {
const Buffer& buffer = buffers_[buffer_id];
@@ -481,13 +568,30 @@ class MemoryUsageTracker {
buffer.unfinished_user_count > 0);
}
+ // Returns whether the given instruction is live at the current program
+ // point.
+ bool IsInstructionCurrentlyLive(Item* instruction) const {
+ // If the instruction has not started yet, it is not alive.
+ if (!IsPlaced(instruction->instruction)) {
+ return false;
+ }
+ for (const HloInstruction* user : instruction->instruction->users()) {
+ if (!IsPlaced(user)) {
+ // If there is an unplaced user, consider this instruction currently
+ // live.
+ return true;
+ }
+ }
+ return false;
+ }
+
// Create a new buffer, add it to buffers_, and return a reference.
- Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
- bool live_out, bool has_indirect_uses) {
+ Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
+ ItemList&& users, bool live_out, bool has_indirect_uses) {
int buffer_id = buffers_.size();
- buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
- has_indirect_uses, users,
- static_cast<int64>(users.size())});
+ buffers_.push_back(Buffer{
+ buffer_id, defining_instruction, size_function_(shape), shape, live_out,
+ has_indirect_uses, users, static_cast<int64>(users.size())});
return buffers_.back();
}
@@ -498,6 +602,16 @@ class MemoryUsageTracker {
// (BeginInstruction/EndInstruction calls).
const InstructionList& instruction_list_;
+ // Size function returns the bytes of a given buffer.
+ const HloRematerialization::ShapeSizeFunction& size_function_;
+
+ // Converts a shape into compact form, returns the same shape if a shape is
+ // already considered compact.
+ const HloRematerialization::CompactShapeFunction& compact_shape_function_;
+
+ // A map that caches existing known compact shape for each instruction.
+ absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
+
// Memory usage at the currently placed instruction.
int64 memory_usage_ = 0;
@@ -512,9 +626,13 @@ class MemoryUsageTracker {
MemoryUsageTracker::MemoryUsageTracker(
const HloComputation* computation,
const HloRematerialization::ShapeSizeFunction& size_function,
+ const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
const InstructionList& instruction_list)
- : computation_(computation), instruction_list_(instruction_list) {
+ : computation_(computation),
+ instruction_list_(instruction_list),
+ size_function_(size_function),
+ compact_shape_function_(compact_shape_function) {
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
@@ -556,7 +674,7 @@ MemoryUsageTracker::MemoryUsageTracker(
}
} else {
buffer = &CreateBufferFromLogicalBuffer(
- logical_buffer, points_to_analysis, size_function,
+ logical_buffer, points_to_analysis,
ContainsKey(live_out_set, logical_buffer));
item->buffers_defined.push_back(buffer->id);
for (Item* user : buffer->users) {
@@ -566,6 +684,14 @@ MemoryUsageTracker::MemoryUsageTracker(
logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
}
+
+ // Trace the output of each instruction. This is so that we can properly
+ // track which outputs does GTEs have.
+ for (const LogicalBuffer* logical_buffer :
+ points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
+ item->buffers_output.push_back(
+ logical_buffer_to_buffer_id[logical_buffer]);
+ }
}
XLA_VLOG_LINES(10, ToString());
DCHECK(Check());
@@ -609,9 +735,9 @@ Status MemoryUsageTracker::EndInstruction() {
<< buffer.ToString() << " has negative unfinished use count.";
if (buffer.unfinished_user_count == 0) {
// Buffer is now dead.
- VLOG(3) << " " << buffer.ToString() << " is now dead.";
memory_usage_ -= AllocatedSize(buffer_id);
- CHECK_GE(memory_usage_, 0);
+ // The memory usage can become negative inside the computation as we can
+ // free up the parameter space and reuse it for other tensors.
}
}
@@ -620,9 +746,9 @@ Status MemoryUsageTracker::EndInstruction() {
for (BufferId buffer_id : in_progress_item_->buffers_defined) {
const Buffer& buffer = buffers_.at(buffer_id);
if (buffer.unfinished_user_count == 0) {
- VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
memory_usage_ -= AllocatedSize(buffer_id);
- CHECK_GE(memory_usage_, 0);
+ // The memory usage can become negative inside the computation as we can
+ // free up the parameter space and reuse it for other tensors.
}
}
@@ -637,6 +763,30 @@ Status MemoryUsageTracker::EndInstruction() {
return Status::OK();
}
+int64 MemoryUsageTracker::MemoryReducedIfCompressed(
+ Item* item, const Shape& compact_shape) const {
+ CHECK_NE(in_progress_item_, nullptr);
+ if (!item->placed || item == in_progress_item_) {
+ return 0;
+ }
+
+ int64 memory_reduced = 0;
+
+ // We only compress a single piece of an output at one time.
+ CHECK_EQ(item->buffers_output.size(), 1);
+ BufferId buffer_id = item->buffers_output[0];
+ if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
+ IsInstructionCurrentlyLive(item)) {
+ const Buffer& buffer = buffers_.at(buffer_id);
+ memory_reduced += buffer.size;
+
+ int64 compact_shape_size = size_function_(compact_shape);
+ // Account for buffers that are compressed after instruction.
+ memory_reduced -= compact_shape_size;
+ }
+ return memory_reduced;
+}
+
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
CHECK_NE(in_progress_item_, nullptr);
if (!item->placed || item == in_progress_item_) {
@@ -736,6 +886,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
return memory_reduced;
}
+Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
+ Item* compressed_item,
+ Item* uncompressed_item) {
+ // Original buffer is now dead.
+ memory_usage_ -= size_function_(original_item->instruction->shape());
+ // Compressed buffer is now alive.
+ memory_usage_ += size_function_(compressed_item->instruction->shape());
+
+ ItemList placed_users;
+ ItemList unplaced_users;
+ CHECK_EQ(original_item->buffers_output.size(), 1);
+ BufferId original_buffer_id = original_item->buffers_output[0];
+ Buffer& original_buffer = buffers_.at(original_buffer_id);
+ for (Item* user : original_buffer.users) {
+ if (user->placed) {
+ CHECK(IsFinished(user)) << user->instruction->name();
+ placed_users.push_back(user);
+ } else {
+ unplaced_users.push_back(user);
+ }
+ }
+ original_buffer.users = std::move(placed_users);
+ original_buffer.unfinished_user_count = 0;
+ original_buffer.users.push_back(compressed_item);
+ Buffer& compressed_buffer =
+ NewBuffer(compressed_item, compressed_item->instruction->shape(),
+ {uncompressed_item}, /*live_out=*/false,
+ /*has_indirect_uses=*/false);
+ compressed_item->buffers_used = original_item->buffers_output;
+ compressed_item->buffers_output = {compressed_buffer.id};
+ compressed_item->buffers_defined.push_back(compressed_buffer.id);
+
+ Buffer& uncompressed_buffer =
+ NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
+ std::move(unplaced_users), /*live_out=*/false,
+ /*has_indirect_uses=*/false);
+
+ uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
+ uncompressed_item->buffers_output = {uncompressed_buffer.id};
+ uncompressed_item->buffers_defined = {uncompressed_buffer.id};
+
+ for (Item* user : uncompressed_buffer.users) {
+ BufferIdList& buffers_used = user->buffers_used;
+ std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
+ uncompressed_buffer.id);
+ }
+
+ return Status::OK();
+}
+
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
Item* remat_item) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
@@ -831,6 +1031,17 @@ string MemoryUsageTracker::ToString() const {
return output;
}
+StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
+ auto it = compact_shape_.find(hlo);
+ if (it != compact_shape_.end()) {
+ return it->second;
+ }
+ const Shape& original_shape = hlo->shape();
+ TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
+ compact_shape_[hlo] = min_shape;
+ return min_shape;
+}
+
bool MemoryUsageTracker::Check() const {
auto elements_are_unique = [](const BufferIdList& vec) {
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
@@ -917,12 +1128,16 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
-Item* PickRematerializationCandidate(
- const MemoryUsageTracker& memory_tracker,
+std::pair<Item*, RematStrategy>
+MemoryUsageTracker::PickRematerializationCandidate(
+ const RematerializationAlg algorithm,
const InstructionList& instruction_list, int64 memory_limit_bytes,
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
+ RematStrategy best_strategy;
+
+ VLOG(5) << "Picking candidate";
// TODO(b/35244891): This is currently quadratic in the number of HLO
// instructions.
@@ -947,68 +1162,520 @@ Item* PickRematerializationCandidate(
if (!CanBeRematerialized(candidate, remat_able)) {
VLOG(5) << "candidate " << candidate->name()
<< " not viable: is not rematerializable";
+
continue;
}
- // If any of the candidate's control successor has been placed, we need to
- // skip this candidate. Otherwise we will violate control dependency.
- bool control_successor_placed =
- std::any_of(candidate->control_successors().begin(),
- candidate->control_successors().end(),
- [&memory_tracker](const HloInstruction* inst) {
- return memory_tracker.IsPlaced(inst);
- });
+ if (item->buffers_output.size() == 1 &&
+ (algorithm == RematerializationAlg::kCompressAlg ||
+ algorithm == RematerializationAlg::kStandardAndCompressAlg)) {
+ // Only consider compressing single output instruction.
+ const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
+
+ if (item->placed && item != in_progress_item_ &&
+ !output_buffer.live_out) {
+ const Shape& original_shape = item->instruction->shape();
+ if (original_shape.IsArray()) {
+ Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
+ const int64 memory_reduced =
+ MemoryReducedIfCompressed(item, compact_shape);
+ if (memory_reduced > 0) {
+ const int64 cost = memory_limit_bytes / memory_reduced;
+ if (best_item == nullptr || cost < best_cost) {
+ VLOG(3) << "candidate " << candidate->name() << "("
+ << candidate->ToShortString() << ")"
+ << " now best when compressed into "
+ << compact_shape.ToString(true);
+ RematStrategy strategy;
+ strategy.kind = RematStrategy::kCompress;
+ best_strategy = strategy;
+ best_strategy.compact_shape = compact_shape;
+ best_item = item;
+ best_cost = cost;
+ }
+ }
+ }
+ }
+ }
+
+ // If any of the candidate's control successor has been placed, we need
+ // to skip this candidate. Otherwise we will violate control dependency.
+ bool control_successor_placed = std::any_of(
+ candidate->control_successors().begin(),
+ candidate->control_successors().end(),
+ [this](const HloInstruction* inst) { return IsPlaced(inst); });
if (control_successor_placed) {
continue;
}
- const int64 memory_reduced =
- memory_tracker.MemoryReducedIfRematerialized(item);
+ if (algorithm == RematerializationAlg::kStandardAlg ||
+ algorithm == RematerializationAlg::kStandardAndCompressAlg) {
+ const int64 memory_reduced = MemoryReducedIfRematerialized(item);
- if (memory_reduced <= 0) {
- VLOG(5) << "candidate " << candidate->name()
- << " memory reduced = " << memory_reduced << " <= 0";
- continue;
+ if (memory_reduced > 0) {
+ const int cost =
+ RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
+
+ VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
+ << memory_reduced << ", cost per byte " << cost;
+
+ if (best_item == nullptr || cost < best_cost) {
+ VLOG(5) << "candidate " << candidate->name() << " now best";
+ best_strategy.kind = RematStrategy::kRecompute;
+ best_item = item;
+ best_cost = cost;
+ }
+ }
+ }
+ }
+ return {best_item, best_strategy};
+}
+
+StatusOr<int64> DerematerializeInstruction(HloComputation* computation,
+ HloInstruction* source_node) {
+
+ for (auto inst : computation->instructions()) {
+ if (inst->name().find(source_node->name() + ".remat") == 0) {
+ std::vector<HloInstruction*> users = inst->users();
+ for (HloInstruction* user : users) {
+ TF_RETURN_IF_ERROR(inst->ReplaceUseWith(user, source_node));
+ }
+ }
+ }
+ return true;
+}
+
+// Rematerialize the instruction source_node and change its use in target_user:
+// before remat:
+// ---> targe_user
+// /
+// source_node -------|
+// \
+// ----> other users
+//
+// after remat:
+//
+// remat_copy ------> target_user
+//
+// source_node -----> other users
+//
+StatusOr<int64> RematerializeInstructionPath(
+ HloComputation* computation, Item* source_node, Item* target_user,
+ absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+ InstructionList* instruction_list, int path_size,
+ std::vector<HloInstruction*> &articulations_vector) {
+
+ HloInstruction* source_node_inst = source_node->instruction;
+
+ if (!IsRematerializable(source_node_inst) ||
+ path_size == 0 ||
+ source_node->blacklisted) {
+ return false;
+ }
+
+ HloInstruction* remat_copy_inst =
+ computation->AddInstruction(source_node_inst->Clone("remat"));
+
+ Item* remat_copy_item = instruction_list->CreateItem(remat_copy_inst);
+
+ TF_RETURN_IF_ERROR(source_node_inst->ReplaceUseWith(
+ target_user->instruction, remat_copy_inst));
+
+ ItemList place_before;
+ place_before.push_back(source_node);
+
+ instruction_list->InsertAfterInstructions(remat_copy_item, place_before);
+ remat_copy_item->placed = true;
+
+ if (source_node_inst->users().empty()) {
+ if (ContainsKey(*remat_move_instructions, source_node_inst)) {
+ remat_copy_item->blacklisted = true;
+ }
+ remat_move_instructions->insert(remat_copy_inst);
+ }
+
+ auto* inst_item = instruction_list->first();
+ for (; inst_item != nullptr; inst_item = instruction_list->next(inst_item)) {
+ for (auto inst_item_use : inst_item->instruction->users()) {
+ if (inst_item_use == remat_copy_inst) {
+ RematerializeInstructionPath(computation, inst_item, remat_copy_item,
+ remat_move_instructions, instruction_list, path_size-1,
+ articulations_vector);
+ }
+ }
+ }
+
+ return true;
+}
+
+StatusOr<int64> RematerializeInstruction(
+ MemoryUsageTracker* memory_tracker, Item* best_item,
+ absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
+ InstructionList* instruction_list) {
+ HloInstruction* best = best_item->instruction;
+ VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
+ << HumanReadableNumBytes(
+ memory_tracker->MemoryReducedIfRematerialized(best_item))
+ << ")";
+
+ int64 net_instructions_added = 0;
+
+ HloComputation* computation = best->parent();
+
+ HloInstruction* remat =
+ computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+
+ // Add control dependencies to the new operation.
+ for (auto successor : best->control_successors()) {
+ TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
+ }
+ for (auto predecessor : best->control_predecessors()) {