Skip to content

Commit

Permalink
Add socket path to scope["server"]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Jan 12, 2025
1 parent ae8253f commit e561072
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
15 changes: 10 additions & 5 deletions tests/protocols/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


class MockSocket:
def __init__(self, family, peername=None, sockname=None):
def __init__(
self,
family: socket.AddressFamily,
peername: tuple[str, int] | None = None,
sockname: tuple[str, int] | str | None = None,
):
self.peername = peername
self.sockname = sockname
self.family = family
Expand Down Expand Up @@ -41,8 +46,8 @@ def test_get_local_addr_with_socket():
assert get_local_addr(transport) == ("123.45.6.7", 123)

if hasattr(socket, "AF_UNIX"): # pragma: no cover
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))})
assert get_local_addr(transport) == ("127.0.0.1", 8000)
transport = MockTransport({"socket": MockSocket(family=socket.AF_UNIX, sockname="/tmp/test.sock")})
assert get_local_addr(transport) == ("/tmp/test.sock", None)


def test_get_remote_addr_with_socket():
Expand All @@ -62,7 +67,7 @@ def test_get_remote_addr_with_socket():

def test_get_local_addr():
transport = MockTransport({"sockname": "path/to/unix-domain-socket"})
assert get_local_addr(transport) is None
assert get_local_addr(transport) == ("path/to/unix-domain-socket", None)

transport = MockTransport({"sockname": ("123.45.6.7", 123)})
assert get_local_addr(transport) == ("123.45.6.7", 123)
Expand All @@ -81,5 +86,5 @@ def test_get_remote_addr():
[({"client": ("127.0.0.1", 36000)}, "127.0.0.1:36000"), ({"client": None}, "")],
ids=["ip:port client", "None client"],
)
def test_get_client_addr(scope, expected_client):
def test_get_client_addr(scope: Any, expected_client: str):
assert get_client_addr(scope) == expected_client
14 changes: 10 additions & 4 deletions uvicorn/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import socket
import urllib.parse

from uvicorn._types import WWWScope
Expand All @@ -10,7 +11,7 @@ class ClientDisconnected(OSError): ...


def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
socket_info: socket.socket | None = transport.get_extra_info("socket")
if socket_info is not None:
try:
info = socket_info.getpeername()
Expand All @@ -27,14 +28,19 @@ def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:


def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
socket_info: socket.socket | None = transport.get_extra_info("socket")
if socket_info is not None:
info = socket_info.getsockname()

return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
if isinstance(info, tuple):
return (str(info[0]), int(info[1]))
elif isinstance(info, str):
return (info, None)
return None
info = transport.get_extra_info("sockname")
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
return (str(info[0]), int(info[1]))
elif isinstance(info, str):
return (info, None)
return None


Expand Down

0 comments on commit e561072

Please sign in to comment.