diff --git a/iceprod/rest/handlers/tasks.py b/iceprod/rest/handlers/tasks.py index 4fd0fd0a..482ad53a 100644 --- a/iceprod/rest/handlers/tasks.py +++ b/iceprod/rest/handlers/tasks.py @@ -7,6 +7,7 @@ import pymongo import tornado.web +from wipac_dev_tools import strtobool from ..base_handler import APIBase from ..auth import authorization, attr_auth @@ -308,44 +309,6 @@ async def put(self, task_id): self.finish() -class TaskCountsStatusHandler(APIBase): - """ - Handle task summary grouping by status. - """ - @authorization(roles=['admin', 'system']) - async def get(self): - """ - Get the task counts for all tasks, group by status. - - Params (optional): - status: | separated list of task status to filter by - - Returns: - dict: {: num} - """ - match = {} - status = self.get_argument('status', None) - if status: - status_list = status.split('|') - if any(s not in TASK_STATUS for s in status_list): - raise tornado.web.HTTPError(400, reaosn='Unknown task status') - match['status'] = {'$in': status_list} - - ret = {} - cursor = self.db.tasks.aggregate([ - {'$match': match}, - {'$group': {'_id': '$status', 'total': {'$sum': 1}}}, - ]) - ret = {} - async for row in cursor: - ret[row['_id']] = row['total'] - ret2 = {} - for k in sorted(ret, key=task_status_sort): - ret2[k] = ret[k] - self.write(ret2) - self.finish() - - class DatasetMultiTasksHandler(APIBase): """ Handle multi tasks requests. @@ -553,23 +516,11 @@ async def get(self, dataset_id): self.finish() -class DatasetTaskCountsStatusHandler(APIBase): +class TaskCountsStatusHandler(APIBase): """ Handle task summary grouping by status. """ - @authorization(roles=['admin', 'user', 'system']) - @attr_auth(arg='dataset_id', role='read') - async def get(self, dataset_id): - """ - Get the task counts for all tasks in a dataset, group by status. 
- - Args: - dataset_id (str): dataset id - - Returns: - dict: {: num} - """ - match = {'dataset_id': dataset_id} + async def counts(self, match): status = self.get_argument('status', None) if status: status_list = status.split('|') @@ -577,6 +528,13 @@ async def get(self, dataset_id): raise tornado.web.HTTPError(400, reaosn='Unknown task status') match['status'] = {'$in': status_list} + gpu = self.get_argument('gpu', None) + if gpu is not None: + if strtobool(gpu): + match['requirements.gpu'] = {'$gte': 1} + else: + match['$or'] = [{"requirements.gpu": {"$exists": False}}, {"requirements.gpu": {"$lte": 0}}] + ret = {} cursor = self.db.tasks.aggregate([ {'$match': match}, @@ -591,6 +549,44 @@ async def get(self, dataset_id): self.write(ret2) self.finish() + @authorization(roles=['admin', 'system']) + async def get(self): + """ + Get the task counts for all tasks, group by status. + + Params (optional): + status: | separated list of task status to filter by + gpu: bool to select only gpu tasks or non-gpu tasks + + Returns: + dict: {: num} + """ + await self.counts(match={}) + + +class DatasetTaskCountsStatusHandler(TaskCountsStatusHandler): + """ + Handle task summary grouping by status. + """ + @authorization(roles=['admin', 'user', 'system']) + @attr_auth(arg='dataset_id', role='read') + async def get(self, dataset_id): + """ + Get the task counts for all tasks in a dataset, group by status. 
+ + Args: + dataset_id (str): dataset id + + Params (optional): + status: | separated list of task status to filter by + gpu: bool to select only gpu tasks or non-gpu tasks + + Returns: + dict: {: num} + """ + match = {'dataset_id': dataset_id} + await self.counts(match=match) + class DatasetTaskCountsNameStatusHandler(APIBase): """ diff --git a/iceprod/scheduled_tasks/queue_tasks.py b/iceprod/scheduled_tasks/queue_tasks.py index 7234e888..8cf1823b 100644 --- a/iceprod/scheduled_tasks/queue_tasks.py +++ b/iceprod/scheduled_tasks/queue_tasks.py @@ -8,40 +8,48 @@ import argparse import asyncio import logging -import os + +from wipac_dev_tools import from_environment, strtobool from iceprod.client_auth import add_auth_to_argparse, create_rest_client logger = logging.getLogger('queue_tasks') -NTASKS = 250000 -NTASKS_PER_CYCLE = 1000 +default_config = { + 'NTASKS': 250000, + 'NTASKS_PER_CYCLE': 1000, +} -async def run(rest_client, dataset_id=None, ntasks=NTASKS, ntasks_per_cycle=NTASKS_PER_CYCLE, debug=False): +async def run(rest_client, config, dataset_id='', gpus=None, debug=False): """ Actual runtime / loop. 
Args: rest_client (:py:class:`iceprod.core.rest_client.Client`): rest client + config (dict): config dict + dataset_id (str): dataset to queue + gpus (bool): run on gpu tasks, cpu tasks, or both debug (bool): debug flag to propagate exceptions """ try: + num_tasks_idle = 0 num_tasks_waiting = 0 - num_tasks_queued = 0 if dataset_id: route = f'/datasets/{dataset_id}/task_counts/status' else: route = '/task_counts/status' args = {'status': 'idle|waiting'} + if gpus is not None: + args['gpu'] = gpus tasks = await rest_client.request('GET', route, args) if 'idle' in tasks: - num_tasks_waiting = tasks['idle'] + num_tasks_idle = tasks['idle'] if 'waiting' in tasks: - num_tasks_queued = tasks['waiting'] - tasks_to_queue = min(num_tasks_waiting, ntasks - num_tasks_queued, ntasks_per_cycle) - logger.warning(f'num tasks idle: {num_tasks_waiting}') - logger.warning(f'num tasks waiting: {num_tasks_queued}') + num_tasks_waiting = tasks['waiting'] + tasks_to_queue = min(num_tasks_idle, config['NTASKS'] - num_tasks_waiting, config['NTASKS_PER_CYCLE']) + logger.warning(f'num tasks idle: {num_tasks_idle}') + logger.warning(f'num tasks waiting: {num_tasks_waiting}') logger.warning(f'tasks to waiting: {tasks_to_queue}') if tasks_to_queue > 0: @@ -128,24 +136,28 @@ async def check_deps(task): def main(): + config = from_environment(default_config) + parser = argparse.ArgumentParser(description='run a scheduled task once') add_auth_to_argparse(parser) parser.add_argument('--dataset-id', help='dataset id') - parser.add_argument('--ntasks', type=int, default=os.environ.get('NTASKS', NTASKS), + parser.add_argument('--gpus', default=None, type=strtobool, help='whether to select only gpu or non-gpu tasks') + parser.add_argument('--ntasks', type=int, default=config['NTASKS'], help='number of tasks to keep queued') - parser.add_argument('--ntasks_per_cycle', type=int, default=os.environ.get('NTASKS_PER_CYCLE', NTASKS_PER_CYCLE), + parser.add_argument('--ntasks_per_cycle', type=int, 
                        default=config['NTASKS_PER_CYCLE'], help='number of tasks to queue per cycle') parser.add_argument('--log-level', default='info', help='log level') parser.add_argument('--debug', default=False, action='store_true', help='debug enabled') args = parser.parse_args() + config.update({'NTASKS': args.ntasks, 'NTASKS_PER_CYCLE': args.ntasks_per_cycle}) logformat = '%(asctime)s %(levelname)s %(name)s %(module)s:%(lineno)s - %(message)s' logging.basicConfig(format=logformat, level=getattr(logging, args.log_level.upper())) rest_client = create_rest_client(args) - asyncio.run(run(rest_client, dataset_id=args.dataset_id, ntasks=args.ntasks, ntasks_per_cycle=args.ntasks_per_cycle, debug=args.debug)) + asyncio.run(run(rest_client, dataset_id=args.dataset_id, gpus=args.gpus, config=config, debug=args.debug)) if __name__ == '__main__': diff --git a/requirements-docs.txt b/requirements-docs.txt index 6a5df3b9..7256bc04 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -6,7 +6,7 @@ # alabaster==1.0.0 # via sphinx -anyio==4.7.0 +anyio==4.8.0 # via httpx asyncache==0.3.1 # via iceprod (setup.py) @@ -16,9 +16,9 @@ attrs==24.3.0 # referencing babel==2.16.0 # via sphinx -boto3==1.35.92 +boto3==1.35.96 # via iceprod (setup.py) -botocore==1.35.92 +botocore==1.35.96 # via # boto3 # s3transfer @@ -50,7 +50,7 @@ exceptiongroup==1.2.2 # via anyio h11==0.14.0 # via httpcore -htcondor==24.2.1 +htcondor==24.3.0 # via iceprod (setup.py) httpcore==1.0.7 # via httpx @@ -87,7 +87,7 @@ pyasn1==0.6.1 # via ldap3 pycparser==2.22 # via cffi -pygments==2.18.0 +pygments==2.19.1 # via sphinx pyjwt[crypto]==2.10.1 # via wipac-rest-tools @@ -170,7 +170,7 @@ urllib3==2.3.0 # botocore # requests # wipac-rest-tools -wipac-dev-tools==1.13.0 +wipac-dev-tools==1.14.0 # via # iceprod (setup.py) # wipac-rest-tools diff --git a/requirements-tests.txt b/requirements-tests.txt index 98cf1f54..73c2d029 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -4,7 +4,7 @@ # # pip-compile --extra=tests --output-file=requirements-tests.txt # 
-anyio==4.7.0 +anyio==4.8.0 # via httpx asyncache==0.3.1 # via iceprod (setup.py) @@ -14,11 +14,11 @@ attrs==24.3.0 # referencing beautifulsoup4==4.12.3 # via iceprod (setup.py) -boto3==1.35.92 +boto3==1.35.96 # via # iceprod (setup.py) # moto -botocore==1.35.92 +botocore==1.35.96 # via # boto3 # moto @@ -60,7 +60,7 @@ flexmock==0.12.2 # via iceprod (setup.py) h11==0.14.0 # via httpcore -htcondor==24.2.1 +htcondor==24.3.0 # via iceprod (setup.py) httpcore==1.0.7 # via httpx @@ -95,7 +95,7 @@ mccabe==0.7.0 # via flake8 mock==5.1.0 # via iceprod (setup.py) -moto[s3]==5.0.25 +moto[s3]==5.0.26 # via iceprod (setup.py) motor==3.6.0 # via iceprod (setup.py) @@ -169,7 +169,7 @@ requests-mock==1.12.1 # via iceprod (setup.py) requests-toolbelt==1.0.0 # via iceprod (setup.py) -responses==0.25.3 +responses==0.25.5 # via moto respx==0.22.0 # via iceprod (setup.py) @@ -216,7 +216,7 @@ urllib3==2.3.0 # wipac-rest-tools werkzeug==3.1.3 # via moto -wipac-dev-tools==1.13.0 +wipac-dev-tools==1.14.0 # via # iceprod (setup.py) # wipac-rest-tools diff --git a/requirements.txt b/requirements.txt index ff064744..8d1fe05c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ # # pip-compile --output-file=requirements.txt # -anyio==4.7.0 +anyio==4.8.0 # via httpx asyncache==0.3.1 # via iceprod (setup.py) @@ -12,9 +12,9 @@ attrs==24.3.0 # via # jsonschema # referencing -boto3==1.35.92 +boto3==1.35.96 # via iceprod (setup.py) -botocore==1.35.92 +botocore==1.35.96 # via # boto3 # s3transfer @@ -44,7 +44,7 @@ exceptiongroup==1.2.2 # via anyio h11==0.14.0 # via httpcore -htcondor==24.2.1 +htcondor==24.3.0 # via iceprod (setup.py) httpcore==1.0.7 # via httpx @@ -135,7 +135,7 @@ urllib3==2.3.0 # botocore # requests # wipac-rest-tools -wipac-dev-tools==1.13.0 +wipac-dev-tools==1.14.0 # via # iceprod (setup.py) # wipac-rest-tools diff --git a/tests/rest/tasks_test.py b/tests/rest/tasks_test.py index f1963639..d9254c94 100644 --- a/tests/rest/tasks_test.py +++ 
b/tests/rest/tasks_test.py @@ -297,13 +297,30 @@ async def test_rest_tasks_dataset_counts_status(server): 'requirements': {}, } ret = await client.request('POST', '/tasks', data) - task_id = ret['result'] + + data = { + 'dataset_id': 'foo', + 'job_id': 'foo1', + 'task_index': 1, + 'job_index': 0, + 'name': 'baz', + 'depends': [], + 'requirements': {'gpu': 1}, + 'status': 'processing' + } + ret = await client.request('POST', '/tasks', data) ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status') - assert ret == {states.TASK_STATUS_START: 1} + assert ret == {states.TASK_STATUS_START: 1, 'processing': 1} ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?status=complete') assert ret == {} + + ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=false') + assert ret == {states.TASK_STATUS_START: 1} + + ret = await client.request('GET', f'/datasets/{data["dataset_id"]}/task_counts/status?gpu=true') + assert ret == {'processing': 1} async def test_rest_tasks_dataset_counts_name_status(server): diff --git a/tests/scheduled_tasks/queue_tasks_test.py b/tests/scheduled_tasks/queue_tasks_test.py index f66de850..88fa99bd 100644 --- a/tests/scheduled_tasks/queue_tasks_test.py +++ b/tests/scheduled_tasks/queue_tasks_test.py @@ -12,6 +12,7 @@ async def test_200_run(): + config = queue_tasks.default_config.copy() rc = MagicMock() async def client(method, url, args=None): if url == '/datasets/foo': @@ -27,7 +28,7 @@ async def client(method, url, args=None): raise Exception() client.called = False rc.request = client - await queue_tasks.run(rc, debug=True) + await queue_tasks.run(rc, config, debug=True) assert client.called async def client(method, url, args=None): @@ -40,7 +41,7 @@ async def client(method, url, args=None): raise Exception() client.called = False rc.request = client - await queue_tasks.run(rc, debug=True) + await queue_tasks.run(rc, config, debug=True) assert not 
client.called async def client(method, url, args=None): @@ -55,7 +56,7 @@ async def client(method, url, args=None): raise Exception() client.called = False rc.request = client - await queue_tasks.run(rc, debug=True) + await queue_tasks.run(rc, config, debug=True) assert not client.called async def client(method, url, args=None): @@ -72,11 +73,12 @@ async def client(method, url, args=None): raise Exception() client.called = False rc.request = client - await queue_tasks.run(rc, debug=True) + await queue_tasks.run(rc, config, debug=True) assert not client.called async def test_300_run(): + config = queue_tasks.default_config.copy() rc = MagicMock() async def client(method, url, args=None): if url.startswith('/task_counts/status'): @@ -87,8 +89,8 @@ async def client(method, url, args=None): client.called = False rc.request = client with pytest.raises(Exception): - await queue_tasks.run(rc, debug=True) + await queue_tasks.run(rc, config, debug=True) assert client.called # internally catch the error - await queue_tasks.run(rc) + await queue_tasks.run(rc, config)