Skip to content

Commit

Permalink
[Search] Improve the display of search statistics (#92)
Browse files Browse the repository at this point in the history
* fix & upd range

* adjust heuristics for grid dim enumeration

* better search statistics display

* add verbose flag

---------

Co-authored-by: Mengdi Wu <[email protected]>
  • Loading branch information
wmdi and wmdi authored Sep 30, 2024
1 parent e787562 commit 11e6140
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 26 deletions.
12 changes: 10 additions & 2 deletions include/mirage/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class KernelGraphGenerator {
public:
KernelGraphGenerator(kernel::Graph const &computation_graph,
GeneratorConfig const &config,
char const *filename);
char const *filename,
bool verbose = false);

void generate_kernel_graphs();

Expand All @@ -30,19 +31,24 @@ class KernelGraphGenerator {
char const *filename;
std::vector<json> generated_graphs;
int num_thread;
bool verbose;

private:
// Computation graph-related fields
std::vector<std::shared_ptr<AlgebraicPattern>>
computation_graph_output_patterns;
std::vector<cpu::CTensor> computation_graph_output_tensors;
std::vector<std::tuple<std::vector<int>, type::DataType, layout::DmemLayout>>
computation_graph_input_attrs;

std::atomic<int> num_total_kernel_graphs;
// Statistics-related fields
std::atomic<int> num_total_random_tests;
std::atomic<int> num_valid_kernel_graphs;
std::atomic<int> num_total_states;

// Time
std::chrono::time_point<std::chrono::steady_clock> start_time;

std::mutex fp_mutex;
std::mutex generated_graphs_mutex;

Expand Down Expand Up @@ -70,6 +76,8 @@ class KernelGraphGenerator {
bool verify(kernel::Graph const &g);

void save_results() const;
double get_elapsed_time_in_sec() const;
void show_statistics() const;
};

} // namespace search
Expand Down
1 change: 1 addition & 0 deletions include/mirage/search/search_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ int cython_search(mirage::kernel::Graph const *input_graph,
std::vector<MDim3> block_dim_to_explore,
std::vector<int> fmap_to_explore,
std::vector<int> frange_to_explore,
bool verbose,
char const *default_config);
} // namespace search_c
} // namespace mirage
1 change: 1 addition & 0 deletions python/mirage/_cython/CCore.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ cdef extern from "mirage/search/search_c.h" namespace "mirage::search_c":
vector[MDim3] blockdims,
vector[int] fmaps,
vector[int] franges,
bool verbose,
const char * default_config)

cdef extern from "mirage/transpiler/transpile.h" namespace "mirage::transpiler":
Expand Down
6 changes: 4 additions & 2 deletions python/mirage/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ cdef class CyTBGraph:
t = ctypes.cast(<unsigned long long>ptr, ctypes.c_void_p)
return STensor(t)

def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps = None, list omaps = None, list griddims = None, list blockdims = None, list fmaps = None, list franges = None, str previous_checkpoint = None, str default_config = None):
def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps = None, list omaps = None, list griddims = None, list blockdims = None, list fmaps = None, list franges = None, str previous_checkpoint = None, bool verbose, str default_config = None):
# set cimaps
cdef vector[MInt3] cimaps
cimaps.resize(0)
Expand Down Expand Up @@ -498,12 +498,14 @@ def search(CyKNGraph input_graph, *, int max_num_new_graphs = 1024, list imaps =
# currently support up to 1024 new graphs
assert max_num_new_graphs <= 1024
cdef CppKNGraph* cnewgraphs[1024]
# set verbose
cverbose = verbose
# convert config description
cdef char* cconfig = NULL
if default_config is not None:
py_byte_string = default_config.encode('UTF-8')
cconfig = py_byte_string
num = cython_search(input_graph.p_kgraph, max_num_new_graphs, cnewgraphs, cimaps, comaps, cgriddims, cblockdims, cfmaps, cfranges, cconfig)
num = cython_search(input_graph.p_kgraph, max_num_new_graphs, cnewgraphs, cimaps, comaps, cgriddims, cblockdims, cfmaps, cfranges, cverbose, cconfig)
new_graphs = list()
for i in range(num):
ptr = ctypes.cast(<unsigned long long>cnewgraphs[i], ctypes.c_void_p)
Expand Down
4 changes: 2 additions & 2 deletions python/mirage/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def compile(self, **kwargs):
self._cached_results = result
return self._cached_results

def superoptimize(self, imaps : list = None, omaps : list = None, griddims : list = None, blockdims : list = None, fmaps : list = None, franges : list = None, config : str = None):
cygraphs = search(self.cygraph, imaps=imaps, omaps=omaps, griddims=griddims, blockdims=blockdims, fmaps=fmaps, franges=franges, default_config=config)
def superoptimize(self, imaps : list = None, omaps : list = None, griddims : list = None, blockdims : list = None, fmaps : list = None, franges : list = None, verbose : bool = False, config : str = None):
cygraphs = search(self.cygraph, imaps=imaps, omaps=omaps, griddims=griddims, blockdims=blockdims, fmaps=fmaps, franges=franges, verbose=verbose, default_config=config)
all_graphs = [KNGraph(g) for g in cygraphs]

# profile and use the best graph
Expand Down
46 changes: 27 additions & 19 deletions src/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ namespace search {
KernelGraphGenerator::KernelGraphGenerator(
kernel::Graph const &computation_graph,
GeneratorConfig const &config,
char const *filename)
char const *filename,
bool verbose)
: config(config), dim_strategy(DimStrategy(config)), filename(filename),
num_thread(config.search_thread), num_total_kernel_graphs(0),
num_thread(config.search_thread), verbose(verbose),
num_total_random_tests(0), num_valid_kernel_graphs(0),
num_total_states(0) {
preprocess(computation_graph);
Expand Down Expand Up @@ -107,8 +108,8 @@ void KernelGraphGenerator::generate_next_operator(
std::function<bool(SearchContext const &)> const &verify,
std::vector<SerializedSearchContext> &verified) {
++num_total_states;
if (num_total_states % 1000 == 1) {
printf("Total states explored: %d.\n", num_total_states.load());
if (num_total_states % 100 == 1) {
show_statistics();
}
if (verify(c)) {
verified.push_back(SerializedSearchContext(c));
Expand Down Expand Up @@ -191,7 +192,7 @@ void KernelGraphGenerator::generate_next_operator(
grid_dim,
block_dim,
forloop_dim)) {
{
if (verbose) {
TBGraphConfig cfg{grid_dim,
block_dim,
input_map,
Expand Down Expand Up @@ -369,6 +370,7 @@ void KernelGraphGenerator::search_from(
}

void KernelGraphGenerator::generate_kernel_graphs() {
start_time = std::chrono::steady_clock::now();
SearchContext c;
c.level = SearchLevel::LV_KERNEL;
c.kn_graph = std::make_shared<kernel::Graph>();
Expand All @@ -378,19 +380,16 @@ void KernelGraphGenerator::generate_kernel_graphs() {
c.kn_graph->new_input(dim, data_type, layout);
}

auto start_time = std::chrono::steady_clock::now();

std::vector<SerializedSearchContext> middle_states;
generate_next_operator(
c,
[this](SearchContext const &c) {
return c.tb_graph != nullptr || this->verify(*c.kn_graph);
},
middle_states);
printf("[Search] First step finished. Time elapsed: %fsec\n",
std::chrono::duration<double>(std::chrono::steady_clock::now() -
start_time)
.count());
printf("\n");
printf("[Search] First step finished. Time elapsed: %lfsec\n",
get_elapsed_time_in_sec());
std::vector<std::vector<SerializedSearchContext>> split_middle_states(
num_thread);
for (size_t i = 0; i < middle_states.size(); ++i) {
Expand All @@ -407,12 +406,12 @@ void KernelGraphGenerator::generate_kernel_graphs() {

save_results();

printf("\n");
printf("[Search] Second step finished. Time elapsed: %fsec\n",
std::chrono::duration<double>(std::chrono::steady_clock::now() -
start_time)
.count());
printf("[Search] Total kernel graphs explored: %d\n",
num_total_kernel_graphs.load());
printf("[Search] Total states explored: %d\n", num_total_states.load());
printf("[Search] Random tests performed: %d\n",
num_total_random_tests.load());
printf("[Serach] Valid kernel graphs explored: %d\n",
Expand Down Expand Up @@ -474,12 +473,6 @@ bool KernelGraphGenerator::check_pattern(
}

bool KernelGraphGenerator::verify(kernel::Graph const &g) {
++num_total_kernel_graphs;
if (num_total_kernel_graphs % 1000 == 1) {
printf("Total kernel graphs explored: %d.\n",
num_total_kernel_graphs.load());
}

std::vector<DTensor> outputs = get_output_tensors(g);

if (outputs.size() != computation_graph_output_patterns.size()) {
Expand Down Expand Up @@ -535,5 +528,20 @@ void KernelGraphGenerator::save_results() const {
ofs << json(generated_graphs);
}

// Returns the time, in seconds, between the stored start_time and the current
// std::chrono::steady_clock reading, as a double.
double KernelGraphGenerator::get_elapsed_time_in_sec() const {
  auto const now = std::chrono::steady_clock::now();
  // A floating-point duration with the default period counts in seconds.
  std::chrono::duration<double> const elapsed_seconds = now - start_time;
  return elapsed_seconds.count();
}

// Prints a single-line progress summary of the search to stdout: the number
// of states explored, random tests performed, valid mugraphs found, and the
// elapsed time in seconds as reported by get_elapsed_time_in_sec().
//
// The line is terminated with '\r' instead of '\n' so that the next call
// overwrites it in place on a terminal.
void KernelGraphGenerator::show_statistics() const {
  printf(
      "[Search] States: %d, Random tests: %d, Valid mugraphs: %d, Time: %lf\r",
      num_total_states.load(),
      num_total_random_tests.load(),
      num_valid_kernel_graphs.load(),
      get_elapsed_time_in_sec());
  // No newline is written, so a line-buffered stdout would keep this text in
  // its buffer; flush explicitly so the progress line is actually displayed.
  fflush(stdout);
}

} // namespace search
} // namespace mirage
3 changes: 2 additions & 1 deletion src/search/search_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ int cython_search(mirage::kernel::Graph const *input_graph,
std::vector<MDim3> block_dim_to_explore,
std::vector<int> fmap_to_explore,
std::vector<int> frange_to_explore,
bool verbose,
char const *default_config) {
// NOTE(@wmdi): Checkpointing is disabled for now
// Load from a checkpoint
Expand Down Expand Up @@ -94,7 +95,7 @@ int cython_search(mirage::kernel::Graph const *input_graph,
}
}
search::KernelGraphGenerator gen(
*input_graph, config, "mirage_search_checkpoint.json");
*input_graph, config, "mirage_search_checkpoint.json", verbose);
gen.config.show();
gen.generate_kernel_graphs();
int num = 0;
Expand Down

0 comments on commit 11e6140

Please sign in to comment.