From 1c043a22f9e222c65a9bfa022adfce8a47cf292f Mon Sep 17 00:00:00 2001 From: nggit Date: Tue, 28 Jan 2025 07:40:42 +0700 Subject: [PATCH] refactor keepalive --- tremolo/lib/http_protocol.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tremolo/lib/http_protocol.py b/tremolo/lib/http_protocol.py index 3193e3b..53fab7f 100644 --- a/tremolo/lib/http_protocol.py +++ b/tremolo/lib/http_protocol.py @@ -320,7 +320,7 @@ def data_received(self, data): if 1 < header_size <= self.options['client_max_header_size']: # this will keep blocking on bodyless requests forever, unless - # Response.close is called (resumed in _handle_keepalive) + # Response.close is called (resumed in _send_data) self.transport.pause_reading() header = HTTPHeader(self._header_buf, @@ -379,9 +379,16 @@ async def _send_data(self): if data is None: # close the transport, unless keepalive is enabled if self.request is not None: + if self.request.http_continue: + self.request.http_continue = False + self.transport.resume_reading() + continue + if self.request.http_keepalive: self.request.http_keepalive = False + await self._handle_keepalive() + self.transport.resume_reading() continue self.request.clear_body() @@ -431,20 +438,15 @@ async def _handle_keepalive(self): self.close() return - if self.request.http_continue: - self.request.http_continue = False - elif self.request.upgraded: - if 'request' not in self._waiters: - self.close() - return - + if self.request.upgraded: self._waiters.setdefault('receive', self._waiters.pop('request')) else: - # reset. so the next data in data_received will be considered as - # a fresh http request (not a continuation data) if 'receive' in self._waiters: + # waits for all incoming data to enter the queue await self._waiters.pop('receive') + # reset. so the next data in data_received will be considered as + # a fresh http request (not a continuation data) self._header_buf = bytearray() self.request.clear_body() @@ -460,8 +462,6 @@ async def _handle_keepalive(self): # this data is supposed to be the next header self.data_received(self.queue[0].get_nowait()) - self.transport.resume_reading() - def connection_lost(self, _): while self.tasks: task = self.tasks.pop()