mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
from typing import Optional, Sequence, Literal
|
|
|
|
from mods.origins import compute_local_origins, normalize_origins
|
|
from starlette.datastructures import Headers
|
|
from starlette.responses import PlainTextResponse
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
|
|
class TrustedOriginMiddleware:
    """ASGI middleware that rejects requests whose ``Origin`` header is not trusted.

    The trusted set is the union of whatever ``compute_local_origins(port)``
    returns and, when supplied, the result of passing ``allowed_origins``
    through ``normalize_origins``.  Only ``http`` and ``websocket`` scopes are
    inspected; any other scope type is forwarded to the wrapped app untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_origins: Optional[Sequence[str]] = None,
        port: Optional[int] = None,
    ) -> None:
        """Wrap *app* and build the set of trusted origins.

        Args:
            app: The downstream ASGI application to delegate to.
            allowed_origins: Extra origins to trust in addition to the local
                ones.  ``None`` means "local origins only".
            port: Forwarded as-is to ``compute_local_origins``.
        """
        trusted: set[str] = set()
        trusted.update(compute_local_origins(port))
        if allowed_origins is not None:
            trusted.update(normalize_origins(allowed_origins))

        self.allowed_origins: set[str] = trusted
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the request downstream unless it carries an untrusted origin.

        A request is answered with a 400 plain-text response only when it is
        an ``http``/``websocket`` scope, has a non-empty ``Origin`` header, and
        that header value is not in ``self.allowed_origins``.
        """
        if scope["type"] in ("http", "websocket"):
            request_origin = Headers(scope=scope).get("origin", "")
            # A missing/empty Origin header is let through; per the original
            # author's note, it is not present for same-origin requests.
            if request_origin and request_origin not in self.allowed_origins:
                rejection = PlainTextResponse(
                    "Invalid origin header", status_code=400
                )
                await rejection(scope, receive, send)
                return

        await self.app(scope, receive, send)
|