From 640d44714c91ae5eb1d3260385cd5e10f7af09d0 Mon Sep 17 00:00:00 2001 From: roxx30198 Date: Fri, 20 Oct 2023 21:49:16 -0600 Subject: [PATCH] rodinia/gaussian sycl and ndpx implementation --- .github/workflows/build_and_run.yml | 3 + .github/workflows/conda-package.yml | 4 + .pre-commit-config.yaml | 3 +- dpbench/benchmarks/CMakeLists.txt | 1 + dpbench/benchmarks/rodinia/CMakeLists.txt | 5 + .../rodinia/gaussian/CMakeLists.txt | 5 + .../benchmarks/rodinia/gaussian/__init__.py | 28 ++++ .../rodinia/gaussian/gaussian_initialize.py | 34 +++++ .../rodinia/gaussian/gaussian_numba_dpex_k.py | 107 ++++++++++++++ .../rodinia/gaussian/gaussian_python.py | 24 ++++ .../gaussian_sycl_native_ext/CMakeLists.txt | 14 ++ .../gaussian_sycl_native_ext/__init__.py | 7 + .../gaussian_sycl/_gaussian_kernel.hpp | 57 ++++++++ .../gaussian_sycl/_gaussian_sycl.cpp | 131 ++++++++++++++++++ dpbench/config/reader.py | 13 ++ .../configs/bench_info/rodinia/gaussian.toml | 55 ++++++++ dpbench/console/_namespace.py | 1 + dpbench/console/config.py | 1 + dpbench/console/run.py | 7 + setup.py | 1 + 20 files changed, 500 insertions(+), 1 deletion(-) create mode 100644 dpbench/benchmarks/rodinia/CMakeLists.txt create mode 100644 dpbench/benchmarks/rodinia/gaussian/CMakeLists.txt create mode 100644 dpbench/benchmarks/rodinia/gaussian/__init__.py create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_initialize.py create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_numba_dpex_k.py create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_python.py create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/CMakeLists.txt create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/__init__.py create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_kernel.hpp create mode 100644 dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_sycl.cpp create mode 100644 dpbench/configs/bench_info/rodinia/gaussian.toml diff --git a/.github/workflows/build_and_run.yml b/.github/workflows/build_and_run.yml index 37cd16f3..e0b2cec4 100644 --- a/.github/workflows/build_and_run.yml +++ b/.github/workflows/build_and_run.yml @@ -168,5 +168,8 @@ jobs: - name: Run benchmarks run: dpbench -i ${{env.WORKLOADS}} run -r2 --no-print-results || exit 1 + - name: Run rodinia benchmarks + run: dpbench -i ${{env.WORKLOADS}} run -r2 --no-print-results --rodinia --no-dpbench|| exit 1 + - name: Generate report run: dpbench -i ${{env.WORKLOADS}} report || exit 1 diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index b6c4ec33..b6292b61 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -195,6 +195,10 @@ jobs: run: | dpbench -i numpy -b azimint_hist run --npbench + - name: Run rodinia benchmark + run: | + dpbench -b gaussian run --rodinia --no-dpbench -r 1 + upload_anaconda: name: Upload dppy/label/dev ['${{ matrix.os }}', python='${{ matrix.python }}'] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ed1b00bd..fcb558d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,8 @@ repos: hooks: - id: pydocstyle # TODO: add packages one by one to enforce pydocstyle eventually - files: (^dpbench/config/|^scripts/|^dpbench/console/|^dpbench/infrastructure/benchmark_runner.py|^dpbench/infrastructure/benchmark_validation.py) + files: (^dpbench/config/|^scripts/|^dpbench/console/|^dpbench/infrastructure/benchmark_runner.py|^dpbench/infrastructure/benchmark_validation.py| + ^dpbench/benchmarks/rodinia) args: ["--convention=google"] # D417 does not work properly: # https://github.com/PyCQA/pydocstyle/issues/459 diff --git a/dpbench/benchmarks/CMakeLists.txt b/dpbench/benchmarks/CMakeLists.txt index 6869220b..c1693bc1 100644 --- a/dpbench/benchmarks/CMakeLists.txt +++ b/dpbench/benchmarks/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(kmeans) add_subdirectory(knn) add_subdirectory(gpairs) add_subdirectory(dbscan) +add_subdirectory(rodinia) # generate dpcpp version into config set(FILE ${CMAKE_SOURCE_DIR}/dpbench/configs/framework_info/dpcpp.toml) diff --git a/dpbench/benchmarks/rodinia/CMakeLists.txt b/dpbench/benchmarks/rodinia/CMakeLists.txt new file mode 100644 index 00000000..eac8f1cf --- /dev/null +++ b/dpbench/benchmarks/rodinia/CMakeLists.txt @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +add_subdirectory(gaussian) diff --git a/dpbench/benchmarks/rodinia/gaussian/CMakeLists.txt b/dpbench/benchmarks/rodinia/gaussian/CMakeLists.txt new file mode 100644 index 00000000..d60c99ed --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/CMakeLists.txt @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +add_subdirectory(gaussian_sycl_native_ext) diff --git a/dpbench/benchmarks/rodinia/gaussian/__init__.py b/dpbench/benchmarks/rodinia/gaussian/__init__.py new file mode 100644 index 00000000..80c92bf6 --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +""" + +Gaussian elimination implementation + +This is sycl and numba-dpex implementation for gaussian elimination + +Input +--------- +size : Forms an input matrix of dimensions (size x size) + +Output + +-------- + +result> : Result of the given set of linear equations using + gaussian elimination. + +Method: + +The gaussian transformations are applied to the input matrix to form the +diagonal matrix in forward elimination, and then the equations are solved +to find the result in back substitution. + +""" diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_initialize.py b/dpbench/benchmarks/rodinia/gaussian/gaussian_initialize.py new file mode 100644 index 00000000..21821337 --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_initialize.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +LAMBDA = -0.01 + + +def initialize(size, types_dict): + import math + + import numpy as np + + dtype = types_dict["float"] + + coe = np.empty((2 * size - 1), dtype=dtype) + a = np.empty((size * size), dtype=dtype) + + for i in range(size): + coe_i = 10 * math.exp(LAMBDA * i) + j = size - 1 + i + coe[j] = coe_i + j = size - 1 - i + coe[j] = coe_i + + for i in range(size): + for j in range(size): + a[i * size + j] = coe[size - 1 - i + j] + + return ( + a, + np.ones(size, dtype=dtype), + np.zeros((size * size), dtype=dtype), + np.zeros(size, dtype=dtype), + ) diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_numba_dpex_k.py b/dpbench/benchmarks/rodinia/gaussian/gaussian_numba_dpex_k.py new file mode 100644 index 00000000..369eebbe --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_numba_dpex_k.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpctl +import numba_dpex + +BLOCK_size_XY = 4 + + +@numba_dpex.kernel() +def gaussian_kernel_1(m, a, size, t): + if ( + numba_dpex.get_local_id(2) + + numba_dpex.get_group_id(2) * numba_dpex.get_local_size(2) + >= size - 1 - t + ): + return + + m[ + size + * ( + numba_dpex.get_local_size(2) * numba_dpex.get_group_id(2) + + numba_dpex.get_local_id(2) + + t + + 1 + ) + + t + ] = ( + a[ + size + * ( + numba_dpex.get_local_size(2) * numba_dpex.get_group_id(2) + + numba_dpex.get_local_id(2) + + t + + 1 + ) + + t + ] + / a[size * t + t] + ) + + +@numba_dpex.kernel() +def gaussian_kernel_2(m, a, b, size, t): + if ( + numba_dpex.get_local_id(2) + + numba_dpex.get_group_id(2) * numba_dpex.get_local_size(2) + >= size - 1 - t + ): + return + + if ( + numba_dpex.get_local_id(1) + + numba_dpex.get_group_id(1) * numba_dpex.get_local_size(1) + >= size - t + ): + return + + xidx = numba_dpex.get_group_id(2) * numba_dpex.get_local_size( + 2 + ) + numba_dpex.get_local_id(2) + yidx = numba_dpex.get_group_id(1) * numba_dpex.get_local_size( + 1 + ) + numba_dpex.get_local_id(1) + + a[size * (xidx + 1 + t) + (yidx + t)] -= ( + m[size * (xidx + 1 + t) + t] * a[size * t + (yidx + t)] + ) + if yidx == 0: + b[xidx + 1 + t] -= m[size * (xidx + 1 + t) + (yidx + t)] * b[t] + + +def gaussian(a, b, m, size, result): + device = dpctl.SyclDevice() + block_size = device.max_work_group_size + grid_size = int((size / block_size) + 0 if not (size % block_size) else 1) + + blocksize2d = BLOCK_size_XY + gridsize2d = (size / blocksize2d) + (0 if not (size % blocksize2d) else 1) + + global_range = numba_dpex.Range(1, 1, grid_size * block_size) + local_range = numba_dpex.Range(1, 1, block_size) + + dim_blockXY = numba_dpex.Range(1, blocksize2d, blocksize2d) + dim_gridXY = numba_dpex.Range( + 1, int(gridsize2d) * blocksize2d, int(gridsize2d) * blocksize2d + ) + + for t in range(size - 1): + gaussian_kernel_1[numba_dpex.NdRange(global_range, local_range)]( + m, a, size, t + ) + + gaussian_kernel_2[numba_dpex.NdRange(dim_gridXY, dim_blockXY)]( + m, a, b, size, t + ) + + for i in range(size): + result[size - i - 1] = b[size - i - 1] + for j in range(i): + result[size - i - 1] -= ( + a[size * (size - i - 1) + (size - j - 1)] * result[size - j - 1] + ) + result[size - i - 1] = ( + result[size - i - 1] / a[size * (size - i - 1) + (size - i - 1)] + ) diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_python.py b/dpbench/benchmarks/rodinia/gaussian/gaussian_python.py new file mode 100644 index 00000000..189609d2 --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_python.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + + +def gaussian(a, b, m, size, result): + # Forward Elimination + for t in range(size - 1): + for i in range(t + 1, size): + m = a[i * size + t] / a[t * size + t] + for j in range(t, size): + a[i * size + j] = a[i * size + j] - m * a[t * size + j] + b[i] = b[i] - m * b[t] + + # Back Substitution + for i in range(size): + result[size - i - 1] = b[size - i - 1] + for j in range(i): + result[size - i - 1] -= ( + a[size * (size - i - 1) + (size - j - 1)] * result[size - j - 1] + ) + result[size - i - 1] = ( + result[size - i - 1] / a[size * (size - i - 1) + (size - i - 1)] + ) diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/CMakeLists.txt b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/CMakeLists.txt new file mode 100644 index 00000000..236f6c45 --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/CMakeLists.txt @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +set(module_name gaussian_sycl) +set(py_module_name _${module_name}) +python_add_library(${py_module_name} MODULE ${module_name}/${py_module_name}.cpp) +add_sycl_to_target(TARGET ${py_module_name} SOURCES ${module_name}/${py_module_name}.cpp) +target_include_directories(${py_module_name} PRIVATE ${Dpctl_INCLUDE_DIRS}) + +file(RELATIVE_PATH py_module_dest ${CMAKE_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS ${py_module_name} + DESTINATION ${py_module_dest}/${module_name} +) diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/__init__.py b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/__init__.py new file mode 100644 index 00000000..e99261dc --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from .gaussian_sycl._gaussian_sycl import gaussian as gaussian_sycl + +__all__ = ["gaussian_sycl"] diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_kernel.hpp b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_kernel.hpp new file mode 100644 index 00000000..fd184c62 --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_kernel.hpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +#include + +using namespace sycl; + +template +void gaussian_kernel_1(FpTy *m_device, + const FpTy *a_device, + int size, + int t, + sycl::nd_item<3> item_ct1) +{ + if (item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) >= + size - 1 - t) + return; + m_device[size * (item_ct1.get_local_range().get(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2) + t + 1) + + t] = a_device[size * (item_ct1.get_local_range().get(2) * + item_ct1.get_group(2) + + item_ct1.get_local_id(2) + t + 1) + + t] / + a_device[size * t + t]; +} + +template +void gaussian_kernel_2(FpTy *m_device, + FpTy *a_device, + FpTy *b_device, + int size, + int j1, + int t, + sycl::nd_item<3> item_ct1) +{ + if (item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) >= + size - 1 - t) + return; + if (item_ct1.get_local_id(1) + + item_ct1.get_group(1) * item_ct1.get_local_range().get(1) >= + size - t) + return; + + int xidx = item_ct1.get_group(2) * item_ct1.get_local_range().get(2) + + item_ct1.get_local_id(2); + int yidx = item_ct1.get_group(1) * item_ct1.get_local_range().get(1) + + item_ct1.get_local_id(1); + + a_device[size * (xidx + 1 + t) + (yidx + t)] -= + m_device[size * (xidx + 1 + t) + t] * a_device[size * t + (yidx + t)]; + if (yidx == 0) { + b_device[xidx + 1 + t] -= + m_device[size * (xidx + 1 + t) + (yidx + t)] * b_device[t]; + } +} diff --git a/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_sycl.cpp b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_sycl.cpp new file mode 100644 index 00000000..1e7e21dc --- /dev/null +++ b/dpbench/benchmarks/rodinia/gaussian/gaussian_sycl_native_ext/gaussian_sycl/_gaussian_sycl.cpp @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +#include "_gaussian_kernel.hpp" +#include +#include + +#define BLOCK_size_XY 4 + +template bool ensure_compatibility(const Args &...args) +{ + std::vector arrays = {args...}; + + auto arr = arrays.at(0); + auto q = arr.get_queue(); + auto type_flag = arr.get_typenum(); + auto arr_size = arr.get_size(); + + for (auto &arr : arrays) { + if (!(arr.get_flags() & (USM_ARRAY_C_CONTIGUOUS))) { + std::cerr << "All arrays need to be C contiguous.\n"; + return false; + } + if (arr.get_typenum() != type_flag) { + std::cerr << "All arrays should be of same elemental type.\n"; + return false; + } + if (arr.get_ndim() > 1) { + std::cerr << "All arrays expected to be single-dimensional.\n"; + return false; + } + } + return true; +} + +void gaussian_sync(dpctl::tensor::usm_ndarray a, + dpctl::tensor::usm_ndarray b, + dpctl::tensor::usm_ndarray m, + int size, + dpctl::tensor::usm_ndarray result) +{ + if (!ensure_compatibility(a, m, b, result)) + throw std::runtime_error("Input arrays are not acceptable."); + + int t; + + sycl::queue q_ct1; + + int block_size, grid_size; + block_size = q_ct1.get_device() + .get_info(); + grid_size = (size / block_size) + (!(size % block_size) ? 0 : 1); + + sycl::range<3> dimBlock(1, 1, block_size); + sycl::range<3> dimGrid(1, 1, grid_size); + + int blocksize2d, gridsize2d; + blocksize2d = BLOCK_size_XY; + gridsize2d = (size / blocksize2d) + (!(size % blocksize2d ? 0 : 1)); + + sycl::range<3> dimBlockXY(1, blocksize2d, blocksize2d); + sycl::range<3> dimGridXY(1, gridsize2d, gridsize2d); + + auto a_value = a.get_data(); + auto b_value = b.get_data(); + auto m_value = m.get_data(); + + for (t = 0; t < (size - 1); t++) { + /* + DPCT1049:7: The workgroup size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the workgroup size if + needed. + */ + q_ct1.submit([&](sycl::handler &cgh) { + auto size_ct2 = size; + cgh.parallel_for(sycl::nd_range<3>(dimGrid * dimBlock, dimBlock), + [=](sycl::nd_item<3> item_ct1) { + gaussian_kernel_1(m_value, a_value, size_ct2, + t, item_ct1); + }); + }); + q_ct1.wait_and_throw(); + /* + DPCT1049:8: The workgroup size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the workgroup size if + needed. + */ + q_ct1.submit([&](sycl::handler &cgh) { + auto size_ct3 = size; + auto size_t_ct4 = size - t; + + cgh.parallel_for( + sycl::nd_range<3>(dimGridXY * dimBlockXY, dimBlockXY), + [=](sycl::nd_item<3> item_ct1) { + gaussian_kernel_2(m_value, a_value, b_value, size_ct3, + size_t_ct4, t, item_ct1); + }); + }); + q_ct1.wait_and_throw(); + } + // Copying the final answer + auto result_value = result.get_data(); + + for (int i = 0; i < size; i++) { + + result_value[size - i - 1] = b_value[size - i - 1]; + + for (int j = 0; j < i; j++) { + result_value[size - i - 1] -= + *(a_value + size * (size - i - 1) + (size - j - 1)) * + result_value[size - j - 1]; + } + + result_value[size - i - 1] = + result_value[size - i - 1] / + *(a_value + size * (size - i - 1) + (size - i - 1)); + } +} + +PYBIND11_MODULE(_gaussian_sycl, m) +{ + // Import the dpctl extensions + import_dpctl(); + + m.def("gaussian", &gaussian_sync, + "DPC++ implementation of the gaussian elimination", py::arg("a"), + py::arg("b"), py::arg("m"), py::arg("size"), py::arg("result")); +} diff --git a/dpbench/config/reader.py b/dpbench/config/reader.py index bc549653..e24d5c53 100644 --- a/dpbench/config/reader.py +++ b/dpbench/config/reader.py @@ -28,6 +28,7 @@ def read_configs( # noqa: C901: TODO: move modules into config no_dpbench: bool = False, with_npbench: bool = False, with_polybench: bool = False, + with_rodinia: bool = False, load_implementations: bool = True, ) -> Config: """Read all configuration files and populate those settings into Config. @@ -85,6 +86,18 @@ def read_configs( # noqa: C901: TODO: move modules into config ) ) + if with_rodinia: + modules.append( + Module( + benchmark_configs_path=os.path.join( + dirname, "../configs/bench_info/rodinia" + ), + benchmark_configs_recursive=True, + benchmarks_module="dpbench.benchmarks.rodinia", + path=os.path.join(dirname, "../benchmarks/rodinia"), + ) + ) + for mod in modules: if mod.benchmark_configs_path != "": read_benchmarks( diff --git a/dpbench/configs/bench_info/rodinia/gaussian.toml b/dpbench/configs/bench_info/rodinia/gaussian.toml new file mode 100644 index 00000000..662ef146 --- /dev/null +++ b/dpbench/configs/bench_info/rodinia/gaussian.toml @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +[benchmark] +name = "Gaussian Elimination" +short_name = "gaussian" +relative_path = "gaussian" +module_name = "gaussian" +func_name = "gaussian" +kind = "microbenchmark" +domain = "Matrix manipulation" +input_args = [ + "a", + "b", + "m", + "size", + "result" +] +array_args = [ + "a", + "b", + "m", + "result" +] +output_args = [ + "result", +] + +[benchmark.parameters.S] +size = 100 + +[benchmark.parameters.M16Gb] +size = 4096 + +[benchmark.parameters.M] +size = 4096 + +[benchmark.parameters.L] +size = 8192 + +[benchmark.init] +func_name = "initialize" +types_dict_name="types_dict" +precision="double" +input_args = [ + "size", + "types_dict", +] +output_args = [ + "a", + "b", + "m", + "result" +] diff --git a/dpbench/console/_namespace.py b/dpbench/console/_namespace.py index 2cefe737..f6c5f63a 100644 --- a/dpbench/console/_namespace.py +++ b/dpbench/console/_namespace.py @@ -19,6 +19,7 @@ class Namespace(argparse.Namespace): dpbench: bool npbench: bool polybench: bool + rodinia: bool print_results: bool validate: bool run_id: Union[int, None] diff --git a/dpbench/console/config.py b/dpbench/console/config.py index fe962f79..f7144367 100644 --- a/dpbench/console/config.py +++ b/dpbench/console/config.py @@ -43,6 +43,7 @@ def execute_config(args: Namespace): implementations=args.implementations, with_npbench=True, with_polybench=True, + with_rodinia=True, ) color_output = args.color diff --git a/dpbench/console/run.py b/dpbench/console/run.py index e22433dd..9e44c8b8 100644 --- a/dpbench/console/run.py +++ b/dpbench/console/run.py @@ -58,6 +58,12 @@ def add_run_arguments(parser: argparse.ArgumentParser): default=False, help="Set if run polybench benchmarks.", ) + parser.add_argument( + "--rodinia", + action=argparse.BooleanOptionalAction, + default=False, + help="Set if run rodinia benchmarks.", + ) parser.add_argument( "-r", "--repeat", @@ -135,6 +141,7 @@ def execute_run(args: Namespace, conn: sqlalchemy.Engine): no_dpbench=not args.dpbench, with_npbench=args.npbench, with_polybench=args.polybench, + with_rodinia=args.rodinia, ) if args.all_implementations: diff --git a/setup.py b/setup.py index 6c55a714..37c39bd5 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "bench_info/polybench/linear-algebra/blas/*.toml", "bench_info/polybench/medley/*.toml", "bench_info/npbench/*.toml", + "bench_info/rodinia/*.toml", "framework_info/*.toml", ], },