Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/lifespan.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@ async def homepage(request: Request[State]) -> PlainTextResponse:
app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)])
```

This also works with WebSockets:

```python
async def websocket_endpoint(websocket: WebSocket[State]) -> None:
await websocket.accept()
client = websocket.state["http_client"]
response = await client.get("https://www.example.com")
await websocket.send_text(response.text)
await websocket.close()


app = Starlette(lifespan=lifespan, routes=[WebSocketRoute("/ws", websocket_endpoint)])
```

!!! note
There were many attempts to make this work with attribute-style access instead of
dictionary-style access, but none were satisfactory, given they would have been
Expand Down
4 changes: 2 additions & 2 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncIterator, Iterable
from typing import Any, cast

from starlette.requests import HTTPConnection
from starlette.requests import HTTPConnection, StateT
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send

Expand All @@ -23,7 +23,7 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.reason = reason or ""


class WebSocket(HTTPConnection):
class WebSocket(HTTPConnection[StateT]):
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
super().__init__(scope)
assert scope["type"] == "websocket"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ async def websocket_raise_custom(websocket: WebSocket) -> None:
raise CustomWSException()


async def websocket_state(websocket: WebSocket[CustomState]) -> None:
await websocket.accept()
await websocket.send_json({"count": websocket.state["count"]})
await websocket.close()


def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)

Expand Down Expand Up @@ -141,6 +147,7 @@ async def state_count(request: Request[CustomState]) -> JSONResponse:
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
WebSocketRoute("/ws-state", endpoint=websocket_state),
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
],
Expand Down Expand Up @@ -247,6 +254,12 @@ def test_websocket_raise_websocket_exception(client: TestClient) -> None:
}


def test_websocket_state(client: TestClient) -> None:
with client.websocket_connect("/ws-state") as session:
response = session.receive_json()
assert response == {"count": 1}


def test_websocket_raise_http_exception(client: TestClient) -> None:
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/ws-raise-http"):
Expand Down Expand Up @@ -283,6 +296,7 @@ def test_routes() -> None:
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
WebSocketRoute("/ws-state", endpoint=websocket_state),
Mount(
"/users",
app=Router(
Expand Down