Skip to content

Commit

Permalink
Merge pull request #22 from dlr-eoc/provider
Browse files Browse the repository at this point in the history
exposed onnxruntime execution provider
  • Loading branch information
MWieland authored Oct 22, 2024
2 parents 91938f2 + 13cafef commit bdf735d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

[0.2.2] (2024-10-21)
--------------------
Added
*******
- expose onnxruntime execution provider

[0.2.1] (2023-12-05)
--------------------
Added
Expand Down
2 changes: 1 addition & 1 deletion ukis_csmask/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.2.2"
8 changes: 5 additions & 3 deletions ukis_csmask/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
invalid_buffer=4,
intra_op_num_threads=0,
inter_op_num_threads=0,
providers=None,
):
"""
:param img: Input satellite image of shape (rows, cols, bands). (ndarray).
Expand All @@ -42,6 +43,8 @@ def __init__(
:param invalid_buffer: Number of pixels that should be buffered around invalid areas. (int).
:param intra_op_num_threads: Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose. (int).
:param inter_op_num_threads: Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose. (int).
:param providers: onnxruntime session providers. Default is None to let onnxruntime choose. (list).
>>> providers = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
"""
# consistency checks on input image
if isinstance(img, np.ndarray) is False:
Expand Down Expand Up @@ -98,9 +101,8 @@ def __init__(
so = onnxruntime.SessionOptions()
so.intra_op_num_threads = intra_op_num_threads
so.inter_op_num_threads = inter_op_num_threads
self.sess = onnxruntime.InferenceSession(
model_file, sess_options=so, providers=onnxruntime.get_available_providers()
)
providers = onnxruntime.get_available_providers() if providers is None else providers
self.sess = onnxruntime.InferenceSession(model_file, sess_options=so, providers=providers)

self.img = img
self.band_order = band_order
Expand Down

0 comments on commit bdf735d

Please sign in to comment.