From 5e1f3c41cad0145df30a563d852339bc5ccfcf6f Mon Sep 17 00:00:00 2001 From: William Courtney Date: Thu, 7 Nov 2024 01:47:58 +0000 Subject: [PATCH] Limit the number of concurrent jobs when running run_batch. --- cirq-core/cirq/work/sampler.py | 4 ++-- cirq-google/cirq_google/engine/engine_job.py | 23 ++++++++----------- .../cirq_google/engine/processor_sampler.py | 14 ++--------- 3 files changed, 13 insertions(+), 28 deletions(-) diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index 2998492191e..ac46135d81c 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -291,7 +291,7 @@ async def run_batch_async( programs: Sequence['cirq.AbstractCircuit'], params_list: Optional[Sequence['cirq.Sweepable']] = None, repetitions: Union[int, Sequence[int]] = 1, - limiter: duet.Limiter = duet.Limiter(10), + max_concurrent_jobs: int = 10, ) -> Sequence[Sequence['cirq.Result']]: """Runs the supplied circuits asynchronously. @@ -299,7 +299,7 @@ async def run_batch_async( """ params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions) return await duet.pstarmap_async( - self.run_sweep_async, zip(programs, params_list, repetitions, [limiter] * len(programs)) + self.run_sweep_async, zip(programs, params_list, repetitions), max_concurrent_jobs ) def _normalize_batch_args( diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index 1eaeb174a50..5eca36e2840 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -262,14 +262,12 @@ def delete(self) -> None: """Deletes the job and result, if any.""" self.context.client.delete_job(self.project_id, self.program_id, self.job_id) - async def results_async( - self, limiter: duet.Limiter = duet.Limiter(None) - ) -> Sequence[EngineResult]: + async def results_async(self) -> Sequence[EngineResult]: """Returns the job results, blocking until the job is complete.""" import cirq_google.engine.engine as engine_base if self._results is None: - result_response = await self._await_result_async(limiter) + result_response = await self._await_result_async() result = result_response.result result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if ( @@ -288,9 +286,7 @@ async def results_async( raise ValueError(f'invalid result proto version: {result_type}') return self._results - async def _await_result_async( - self, limiter: duet.Limiter = duet.Limiter(None) - ) -> quantum.QuantumResult: + async def _await_result_async(self) -> quantum.QuantumResult: if self._job_result_future is not None: response = await self._job_result_future if isinstance(response, quantum.QuantumResult): @@ -303,13 +299,12 @@ async def _await_result_async( 'Internal error: The job response type is not recognized.' ) # pragma: no cover - async with limiter: - async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type] - while True: - job = await self._refresh_job_async() - if job.execution_status.state in TERMINAL_STATES: - break - await duet.sleep(1) + async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type] + while True: + job = await self._refresh_job_async() + if job.execution_status.state in TERMINAL_STATES: + break + await duet.sleep(1) _raise_on_failure(job) response = await self.context.client.get_job_results_async( self.project_id, self.program_id, self.job_id diff --git a/cirq-google/cirq_google/engine/processor_sampler.py b/cirq-google/cirq_google/engine/processor_sampler.py index a109763c822..daf61b4e17d 100644 --- a/cirq-google/cirq_google/engine/processor_sampler.py +++ b/cirq-google/cirq_google/engine/processor_sampler.py @@ -59,14 +59,12 @@ def __init__( self._run_name = run_name self._snapshot_id = snapshot_id self._device_config_name = device_config_name - self._result_limiter = duet.Limiter(None) async def run_sweep_async( self, program: 'cirq.AbstractCircuit', params: cirq.Sweepable, repetitions: int = 1, - limiter: duet.Limiter = duet.Limiter(None), ) -> Sequence['cg.EngineResult']: job = await self._processor.run_sweep_async( program=program, @@ -77,9 +75,6 @@ async def run_sweep_async( device_config_name=self._device_config_name, ) - if isinstance(job, EngineJob): - return await job.results_async(limiter) - return await job.results_async() run_sweep = duet.sync(run_sweep_async) @@ -89,12 +84,11 @@ async def run_batch_async( programs: Sequence[cirq.AbstractCircuit], params_list: Optional[Sequence[cirq.Sweepable]] = None, repetitions: Union[int, Sequence[int]] = 1, - limiter: duet.Limiter = duet.Limiter(10), + max_concurrent_jobs: int = 10, ) -> Sequence[Sequence['cg.EngineResult']]: - self._result_limiter = limiter return cast( Sequence[Sequence['cg.EngineResult']], - await super().run_batch_async(programs, params_list, repetitions, self._result_limiter), + await super().run_batch_async(programs, params_list, repetitions, max_concurrent_jobs), ) run_batch = duet.sync(run_batch_async) @@ -114,7 +108,3 @@ def snapshot_id(self) -> str: @property def device_config_name(self) -> str: return self._device_config_name - - @property - def result_limiter(self) -> duet.Limiter: - return self._result_limiter