Skip to content

Commit

Permalink
try to drop it...
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 14, 2024
1 parent 70562c7 commit 90f536e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 35 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ filterwarnings = [
"ignore:Uvicorn's native WSGI implementation is deprecated.*:DeprecationWarning",
"ignore: 'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning",
"ignore: remove second argument of ws_handler:DeprecationWarning:websockets",
"ignore: websockets.legacy is deprecated.*:DeprecationWarning:websockets",
"ignore: websockets.server.WebSocketServerProtocol is deprecated.*:DeprecationWarning:websockets",
"ignore: websockets.exceptions.InvalidStatusCode.*:DeprecationWarning",
]

[tool.coverage.run]
Expand Down
61 changes: 29 additions & 32 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import httpx
import pytest
import websockets
import websockets.exceptions
from typing_extensions import TypedDict
from websockets import WebSocketClientProtocol
from websockets.asyncio.client import connect
from websockets.asyncio.client import ClientConnection, connect
from websockets.exceptions import ConnectionClosed, ConnectionClosedError, InvalidHandshake, InvalidStatus
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.typing import Subprotocol

Expand Down Expand Up @@ -203,7 +202,7 @@ async def websocket_connect(self, message: WebSocketConnectEvent):
async def open_connection(url: str):
try:
await connect(url)
except websockets.exceptions.InvalidHandshake:
except InvalidHandshake:
return False
return True # pragma: no cover

Expand Down Expand Up @@ -412,9 +411,9 @@ async def open_connection(url: str):

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert exc_info.value.status_code == 500
assert exc_info.value.response.status_code == 500


async def test_send_before_handshake(
Expand All @@ -428,9 +427,9 @@ async def open_connection(url: str):

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert exc_info.value.status_code == 500
assert exc_info.value.response.status_code == 500


async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
Expand All @@ -441,9 +440,9 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
with pytest.raises(ConnectionClosed):
_ = await websocket.recv()
assert exc_info.value.code == 1006
assert websocket.protocol.close_code == 1006


async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
Expand All @@ -459,7 +458,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info:
with pytest.raises(ConnectionClosed) as exc_info:
_ = await websocket.recv()
assert exc_info.value.code == 1006

Expand Down Expand Up @@ -493,13 +492,13 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
await websocket.ping()
await websocket.send("abc")
with pytest.raises(websockets.exceptions.ConnectionClosed):
with pytest.raises(ConnectionClosed) as exc_info:
await websocket.recv()
assert websocket.close_code == (code or 1000)
assert websocket.close_reason == (reason or "")
assert exc_info.value.code == (code or 1000)
assert exc_info.value.reason == (reason or "")


async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
Expand Down Expand Up @@ -555,7 +554,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
port=unused_tcp_port,
)
async with run_server(config):
async with websockets.client.connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
async with connect(f"ws://127.0.0.1:{unused_tcp_port}") as websocket:
websocket.transport.close()
await asyncio.sleep(0.1)
got_disconnect_event_before_shutdown = got_disconnect_event
Expand Down Expand Up @@ -583,7 +582,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
url = f"ws://127.0.0.1:{unused_tcp_port}"
async with websockets.client.connect(url):
async with connect(url):
await asyncio.sleep(0.1)
disconnect.set()

Expand Down Expand Up @@ -648,11 +647,11 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
disconnect_message = message
break

websocket: WebSocketClientProtocol | None = None
websocket: ClientConnection | None = None

async def websocket_session(uri: str):
nonlocal websocket
async with websockets.client.connect(uri) as ws_connection:
async with connect(uri) as ws_connection:
websocket = ws_connection
await server_shutdown_event.wait()

Expand Down Expand Up @@ -682,9 +681,7 @@ async def websocket_connect(self, message: WebSocketConnectEvent):
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})

async def get_subprotocol(url: str):
async with websockets.client.connect(
url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]
) as websocket:
async with connect(url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]) as websocket:
return websocket.subprotocol

config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
Expand Down Expand Up @@ -743,9 +740,9 @@ async def websocket_receive(self, message: WebSocketReceiveEvent):
data = await ws.recv()
assert data == b"\x01" * client_size_sent
else:
with pytest.raises(websockets.exceptions.ConnectionClosedError) as exc_info:
with pytest.raises(ConnectionClosedError):
await ws.recv()
assert exc_info.value.code == expected_result
assert ws.protocol.close_code == expected_result


async def test_server_reject_connection(
Expand All @@ -770,10 +767,10 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
disconnected_message = await receive()

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 403
assert exc_info.value.response.status_code == 403

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand Down Expand Up @@ -940,10 +937,10 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
await send(message)

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 404
assert exc_info.value.response.status_code == 404

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand Down Expand Up @@ -971,10 +968,10 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
# no further message

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 404
assert exc_info.value.response.status_code == 404

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand Down Expand Up @@ -1012,10 +1009,10 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
exception_message = str(exc)

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
with pytest.raises(InvalidStatus) as exc_info:
async with connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 404
assert exc_info.value.response.status_code == 404

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand Down
5 changes: 2 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.legacy.server import HTTPResponse, WebSocketServerProtocol
from websockets.typing import Subprotocol

from uvicorn._types import (
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Subprotocol | None = None

self.ws_server = Server()
self.ws_server: Server = Server()

extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
Expand Down

0 comments on commit 90f536e

Please sign in to comment.