diff --git a/include/mirage/search/search.h b/include/mirage/search/search.h index 4f8f87c..f0dd296 100644 --- a/include/mirage/search/search.h +++ b/include/mirage/search/search.h @@ -20,7 +20,8 @@ class KernelGraphGenerator { public: KernelGraphGenerator(kernel::Graph const &computation_graph, GeneratorConfig const &config, - char const *filename); + char const *filename, + bool verbose = false); void generate_kernel_graphs(); @@ -30,19 +31,24 @@ class KernelGraphGenerator { char const *filename; std::vector generated_graphs; int num_thread; + bool verbose; private: + // Computation graph-related fields std::vector> computation_graph_output_patterns; std::vector computation_graph_output_tensors; std::vector, type::DataType, layout::DmemLayout>> computation_graph_input_attrs; - std::atomic num_total_kernel_graphs; + // Statistics-related fields std::atomic num_total_random_tests; std::atomic num_valid_kernel_graphs; std::atomic num_total_states; + // Time + std::chrono::time_point start_time; + std::mutex fp_mutex; std::mutex generated_graphs_mutex; @@ -70,6 +76,8 @@ class KernelGraphGenerator { bool verify(kernel::Graph const &g); void save_results() const; + double get_elapsed_time_in_sec() const; + void show_statistics() const; }; } // namespace search diff --git a/include/mirage/search/search_c.h b/include/mirage/search/search_c.h index 14d90f0..d56eda1 100644 --- a/include/mirage/search/search_c.h +++ b/include/mirage/search/search_c.h @@ -22,6 +22,7 @@ int cython_search(mirage::kernel::Graph const *input_graph, std::vector block_dim_to_explore, std::vector fmap_to_explore, std::vector frange_to_explore, + bool verbose, char const *default_config); } // namespace search_c } // namespace mirage diff --git a/python/mirage/_cython/CCore.pxd b/python/mirage/_cython/CCore.pxd index 9047e05..d9c0353 100644 --- a/python/mirage/_cython/CCore.pxd +++ b/python/mirage/_cython/CCore.pxd @@ -159,6 +159,7 @@ cdef extern from "mirage/search/search_c.h" namespace "mirage::search_c": vector[MDim3] blockdims, vector[int] fmaps, vector[int] franges, + bool verbose, const char * default_config) cdef extern from "mirage/transpiler/transpile.h" namespace "mirage::transpiler": diff --git a/python/mirage/_cython/core.pyx b/python/mirage/_cython/core.pyx index bacfd4f..d67e637 100644 --- a/python/mirage/_cython/core.pyx +++ b/python/mirage/_cython/core.pyx @@ -442,7 +442,7 @@ cdef class CyTBGraph: t = ctypes.cast(ptr, ctypes.c_void_p) return STensor(t) -def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps = None, list omaps = None, list griddims = None, list blockdims = None, list fmaps = None, list franges = None, str previous_checkpoint = None, str default_config = None): +def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps = None, list omaps = None, list griddims = None, list blockdims = None, list fmaps = None, list franges = None, str previous_checkpoint = None, bool verbose, str default_config = None): # set cimaps cdef vector[MInt3] cimaps cimaps.resize(0) @@ -498,12 +498,14 @@ def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps = # currently support up to 1024 new graphs assert max_num_new_graphs <= 1024 cdef CppKNGraph* cnewgraphs[1024] + # set verbose + cverbose = verbose # convert config description cdef char* cconfig = NULL if default_config is not None: py_byte_string = default_config.encode('UTF-8') cconfig = py_byte_string - num = cython_search(input_graph.p_kgraph, max_num_new_graphs, cnewgraphs, cimaps, comaps, cgriddims, cblockdims, cfmaps, cfranges, cconfig) + num = cython_search(input_graph.p_kgraph, max_num_new_graphs, cnewgraphs, cimaps, comaps, cgriddims, cblockdims, cfmaps, cfranges, cverbose, cconfig) new_graphs = list() for i in range(num): ptr = ctypes.cast(cnewgraphs[i], ctypes.c_void_p) diff --git a/python/mirage/kernel.py b/python/mirage/kernel.py index a152cd1..b6290a1 100644 --- a/python/mirage/kernel.py +++ b/python/mirage/kernel.py @@ -213,8 +213,8 @@ def compile(self, **kwargs): self._cached_results = result return self._cached_results - def superoptimize(self, imaps : list = None, omaps : list = None, griddims : list = None, blockdims : list = None, fmaps : list = None, franges : list = None, config : str = None): - cygraphs = search(self.cygraph, imaps=imaps, omaps=omaps, griddims=griddims, blockdims=blockdims, fmaps=fmaps, franges=franges, default_config=config) + def superoptimize(self, imaps : list = None, omaps : list = None, griddims : list = None, blockdims : list = None, fmaps : list = None, franges : list = None, verbose : bool = False, config : str = None): + cygraphs = search(self.cygraph, imaps=imaps, omaps=omaps, griddims=griddims, blockdims=blockdims, fmaps=fmaps, franges=franges, verbose=verbose, default_config=config) all_graphs = [KNGraph(g) for g in cygraphs] # profile and use the best graph diff --git a/src/search/search.cc b/src/search/search.cc index cd6e133..fcd7c01 100644 --- a/src/search/search.cc +++ b/src/search/search.cc @@ -17,9 +17,10 @@ namespace search { KernelGraphGenerator::KernelGraphGenerator( kernel::Graph const &computation_graph, GeneratorConfig const &config, - char const *filename) + char const *filename, + bool verbose) : config(config), dim_strategy(DimStrategy(config)), filename(filename), - num_thread(config.search_thread), num_total_kernel_graphs(0), + num_thread(config.search_thread), verbose(verbose), num_total_random_tests(0), num_valid_kernel_graphs(0), num_total_states(0) { preprocess(computation_graph); @@ -107,8 +108,8 @@ void KernelGraphGenerator::generate_next_operator( std::function const &verify, std::vector &verified) { ++num_total_states; - if (num_total_states % 1000 == 1) { - printf("Total states explored: %d.\n", num_total_states.load()); + if (num_total_states % 100 == 1) { + show_statistics(); } if (verify(c)) { verified.push_back(SerializedSearchContext(c)); @@ -191,7 +192,7 @@ void KernelGraphGenerator::generate_next_operator( grid_dim, block_dim, forloop_dim)) { - { + if (verbose) { TBGraphConfig cfg{grid_dim, block_dim, input_map, @@ -369,6 +370,7 @@ void KernelGraphGenerator::search_from( } void KernelGraphGenerator::generate_kernel_graphs() { + start_time = std::chrono::steady_clock::now(); SearchContext c; c.level = SearchLevel::LV_KERNEL; c.kn_graph = std::make_shared(); @@ -378,8 +380,6 @@ void KernelGraphGenerator::generate_kernel_graphs() { c.kn_graph->new_input(dim, data_type, layout); } - auto start_time = std::chrono::steady_clock::now(); - std::vector middle_states; generate_next_operator( c, @@ -387,10 +387,9 @@ void KernelGraphGenerator::generate_kernel_graphs() { return c.tb_graph != nullptr || this->verify(*c.kn_graph); }, middle_states); - printf("[Search] First step finished. Time elapsed: %fsec\n", - std::chrono::duration(std::chrono::steady_clock::now() - - start_time) - .count()); + printf("\n"); + printf("[Search] First step finished. Time elapsed: %lfsec\n", + get_elapsed_time_in_sec()); std::vector> split_middle_states( num_thread); for (size_t i = 0; i < middle_states.size(); ++i) { @@ -407,12 +406,12 @@ void KernelGraphGenerator::generate_kernel_graphs() { save_results(); + printf("\n"); printf("[Search] Second step finished. Time elapsed: %fsec\n", std::chrono::duration(std::chrono::steady_clock::now() - start_time) .count()); - printf("[Search] Total kernel graphs explored: %d\n", - num_total_kernel_graphs.load()); + printf("[Search] Total states explored: %d\n", num_total_states.load()); printf("[Search] Random tests performed: %d\n", num_total_random_tests.load()); printf("[Serach] Valid kernel graphs explored: %d\n", @@ -474,12 +473,6 @@ bool KernelGraphGenerator::check_pattern( } bool KernelGraphGenerator::verify(kernel::Graph const &g) { - ++num_total_kernel_graphs; - if (num_total_kernel_graphs % 1000 == 1) { - printf("Total kernel graphs explored: %d.\n", - num_total_kernel_graphs.load()); - } - std::vector outputs = get_output_tensors(g); if (outputs.size() != computation_graph_output_patterns.size()) { @@ -535,5 +528,20 @@ void KernelGraphGenerator::save_results() const { ofs << json(generated_graphs); } +double KernelGraphGenerator::get_elapsed_time_in_sec() const { + return std::chrono::duration(std::chrono::steady_clock::now() - + start_time) + .count(); +} + +void KernelGraphGenerator::show_statistics() const { + printf( + "[Search] States: %d, Random tests: %d, Valid mugraphs: %d, Time: %lf\r", + num_total_states.load(), + num_total_random_tests.load(), + num_valid_kernel_graphs.load(), + get_elapsed_time_in_sec()); +} + } // namespace search } // namespace mirage diff --git a/src/search/search_c.cc b/src/search/search_c.cc index af37880..538dc0f 100644 --- a/src/search/search_c.cc +++ b/src/search/search_c.cc @@ -20,6 +20,7 @@ int cython_search(mirage::kernel::Graph const *input_graph, std::vector block_dim_to_explore, std::vector fmap_to_explore, std::vector frange_to_explore, + bool verbose, char const *default_config) { // NOTE(@wmdi): Checkpointing is disabled for now // Load from a checkpoint @@ -94,7 +95,7 @@ int cython_search(mirage::kernel::Graph const *input_graph, } } search::KernelGraphGenerator gen( - *input_graph, config, "mirage_search_checkpoint.json"); + *input_graph, config, "mirage_search_checkpoint.json", verbose); gen.config.show(); gen.generate_kernel_graphs(); int num = 0;