diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3b2ea59..df81e2b 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,12 @@ Changelog ========= +[0.2.2] (2024-10-21) +-------------------- +Added +******* +- expose onnxruntime execution provider + [0.2.1] (2023-12-05) -------------------- Added diff --git a/ukis_csmask/__init__.py b/ukis_csmask/__init__.py index c19de65..7977696 100755 --- a/ukis_csmask/__init__.py +++ b/ukis_csmask/__init__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/ukis_csmask/mask.py b/ukis_csmask/mask.py index fd57341..33583b2 100755 --- a/ukis_csmask/mask.py +++ b/ukis_csmask/mask.py @@ -29,6 +29,7 @@ def __init__( invalid_buffer=4, intra_op_num_threads=0, inter_op_num_threads=0, + providers=None, ): """ :param img: Input satellite image of shape (rows, cols, bands). (ndarray). @@ -42,6 +43,8 @@ def __init__( :param invalid_buffer: Number of pixels that should be buffered around invalid areas. (int). :param intra_op_num_threads: Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose. (int). :param inter_op_num_threads: Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose. (int). + :param providers: onnxruntime session providers. Default is None to let onnxruntime choose. (list). + >>> providers = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] """ # consistency checks on input image if isinstance(img, np.ndarray) is False: @@ -98,9 +101,8 @@ def __init__( so = onnxruntime.SessionOptions() so.intra_op_num_threads = intra_op_num_threads so.inter_op_num_threads = inter_op_num_threads - self.sess = onnxruntime.InferenceSession( - model_file, sess_options=so, providers=onnxruntime.get_available_providers() - ) + providers = onnxruntime.get_available_providers() if providers is None else providers + self.sess = onnxruntime.InferenceSession(model_file, sess_options=so, providers=providers) self.img = img self.band_order = band_order