Skip to content

Commit

Permalink
fix typing on thread_pool and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mjurbanski-reef committed Feb 26, 2024
1 parent 2c7ce95 commit ed48c42
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 11 deletions.
28 changes: 25 additions & 3 deletions b2sdk/utils/thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,29 @@
######################################################################
from __future__ import annotations

import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable

try:
from typing_extensions import Protocol
except ImportError:
from typing import Protocol

from b2sdk.utils import B2TraceMetaAbstract


class DynamicThreadPoolExecutorProtocol(Protocol):
def submit(self, fn: Callable, *args, **kwargs) -> Future:
...

def set_size(self, max_workers: int) -> None:
"""Set the size of the thread pool."""

def get_size(self) -> int:
"""Return the current size of the thread pool."""


class LazyThreadPool:
"""
Lazily initialized thread pool.
Expand All @@ -23,6 +40,10 @@ class LazyThreadPool:
_THREAD_POOL_FACTORY = ThreadPoolExecutor

def __init__(self, max_workers: int | None = None, **kwargs):
if max_workers is None:
max_workers = min(
32, (os.cpu_count() or 1) + 4
) # same default as in ThreadPoolExecutor
self._max_workers = max_workers
self._thread_pool: ThreadPoolExecutor | None = None
super().__init__(**kwargs)
Expand All @@ -49,7 +70,8 @@ def set_size(self, max_workers: int) -> None:
old_thread_pool.shutdown(wait=True)
self._max_workers = max_workers

def get_size(self) -> int | None:
def get_size(self) -> int:
"""Return the current size of the thread pool."""
return self._max_workers


Expand All @@ -62,7 +84,7 @@ class ThreadPoolMixin(metaclass=B2TraceMetaAbstract):

def __init__(
self,
thread_pool: ThreadPoolExecutor | None = None,
thread_pool: DynamicThreadPoolExecutorProtocol | None = None,
max_workers: int | None = None,
**kwargs,
):
Expand All @@ -88,5 +110,5 @@ def set_thread_pool_size(self, max_workers: int) -> None:
"""
return self._thread_pool.set_size(max_workers)

def get_thread_pool_size(self) -> int | None:
def get_thread_pool_size(self) -> int:
return self._thread_pool.get_size()
43 changes: 43 additions & 0 deletions test/unit/utils/test_thread_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
######################################################################
#
# File: test/unit/utils/test_thread_pool.py
#
# Copyright 2024 Backblaze Inc. All Rights Reserved.
#
# License https://www.backblaze.com/using_b2_code.html
#
######################################################################
from concurrent.futures import Future

import pytest

from b2sdk.utils.thread_pool import LazyThreadPool


class TestLazyThreadPool:
@pytest.fixture
def thread_pool(self):
return LazyThreadPool()

def test_submit(self, thread_pool):

future = thread_pool.submit(sum, (1, 2))
assert isinstance(future, Future)
assert future.result() == 3

def test_set_size(self, thread_pool):
thread_pool.set_size(10)
assert thread_pool.get_size() == 10

def test_get_size(self, thread_pool):
assert thread_pool.get_size() > 0

def test_set_size__after_submit(self, thread_pool):
future = thread_pool.submit(sum, (1, 2))

thread_pool.set_size(7)
assert thread_pool.get_size() == 7

assert future.result() == 3

assert thread_pool.submit(sum, (1,)).result() == 1
18 changes: 10 additions & 8 deletions test/unit/v_all/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@
class TestDownloadManager(TestBase):
def test_set_thread_pool_size(self) -> None:
download_manager = DownloadManager(services=Mock())
assert download_manager.get_thread_pool_size() is None
download_manager.set_thread_pool_size(21)
assert download_manager._thread_pool._max_workers == 21
assert download_manager.get_thread_pool_size() == 21
assert download_manager.get_thread_pool_size() > 0

pool_size = 21
download_manager.set_thread_pool_size(pool_size)
assert download_manager.get_thread_pool_size() == pool_size


class TestUploadManager(TestBase):
def test_set_thread_pool_size(self) -> None:
upload_manager = UploadManager(services=Mock())
assert upload_manager.get_thread_pool_size() is None
upload_manager.set_thread_pool_size(37)
assert upload_manager._thread_pool._max_workers == 37
assert upload_manager.get_thread_pool_size() == 37
assert upload_manager.get_thread_pool_size() > 0

pool_size = 37
upload_manager.set_thread_pool_size(pool_size)
assert upload_manager.get_thread_pool_size() == pool_size

0 comments on commit ed48c42

Please sign in to comment.