voice-changer/server/restapi/mods/trustedorigin.py

52 lines
1.9 KiB
Python
Raw Normal View History

2024-03-17 01:11:16 +03:00
import typing
from urllib.parse import urlparse
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
# Message reported when a configured trusted origin is not a well-formed URL.
ENFORCE_URL_ORIGIN_FORMAT = (
    "Input origins must be well-formed URLs, "
    "i.e. https://google.com or https://www.google.com."
)
class TrustedOriginMiddleware:
    """ASGI middleware that rejects requests carrying an untrusted ``Origin`` header.

    Scopes other than ``http`` and ``websocket`` (e.g. ``lifespan``) are passed
    straight to the wrapped app. For ``http``/``websocket`` scopes, a request
    with no ``Origin`` header is allowed through; otherwise the header must
    exactly match one of the trusted origins, or a ``400 Invalid origin header``
    plain-text response is sent instead of calling the app.

    Loopback origins (``127.0.0.1`` and ``localhost`` over ``http`` and
    ``https``) are always trusted.
    """

    def __init__(
        self,
        app: ASGIApp,
        allowed_origins: typing.Optional[typing.Sequence[str]] = None,
        port: typing.Optional[int] = None,
    ) -> None:
        """Build the set of trusted origins.

        Args:
            app: The wrapped ASGI application.
            allowed_origins: Extra origins to trust, each a well-formed URL such
                as ``https://example.com`` or ``http://example.com:8080``. Only
                scheme, host and port are kept; path, query and credentials are
                discarded. A single string is treated as one origin.
            port: If given, the always-trusted loopback origins are registered
                with this port (``http://localhost:<port>`` etc.); otherwise
                they are registered without a port.

        Raises:
            ValueError: If an entry of ``allowed_origins`` has no scheme or no
                host, or has a non-numeric / out-of-range port.
        """
        port_suffix = f":{port}" if port is not None else ""
        local_origins = [
            f"{scheme}://{host}{port_suffix}"
            for scheme in ("http", "https")
            for host in ("127.0.0.1", "localhost")
        ]

        self.allowed_origins: set[str] = set()
        if allowed_origins is not None:
            if isinstance(allowed_origins, str):
                # A bare string is itself a Sequence[str] of characters; treat
                # it as one origin rather than validating each character.
                allowed_origins = [allowed_origins]
            self.allowed_origins.update(
                self._normalize_origin(origin) for origin in allowed_origins
            )
        self.allowed_origins.update(local_origins)

        self.app = app

    @staticmethod
    def _normalize_origin(origin: str) -> str:
        """Reduce a configured URL to the ``scheme://host[:port]`` form of an Origin header.

        Raises:
            ValueError: If *origin* has no scheme or no host, or an invalid port.
        """
        url = urlparse(origin)
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``. Requiring a host also rejects inputs such as
        # "https://" or "localhost:8080" (parsed with scheme "localhost"),
        # which would otherwise be stored as "https://None"/"localhost://None".
        if not url.scheme or not url.hostname:
            raise ValueError(ENFORCE_URL_ORIGIN_FORMAT)

        host = url.hostname
        # urlparse strips the brackets from IPv6 literals, but the Origin header
        # keeps them (e.g. "http://[::1]:8080"), so put them back.
        if ":" in host:
            host = f"[{host}]"

        normalized = f"{url.scheme}://{host}"
        # Accessing url.port raises ValueError for a non-numeric or
        # out-of-range port.
        if url.port:
            normalized += f":{url.port}"
        return normalized

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward trusted requests to ``self.app``; answer untrusted ones with HTTP 400."""
        if scope["type"] not in ("http", "websocket"):  # pragma: no cover
            await self.app(scope, receive, send)
            return

        origin = Headers(scope=scope).get("origin", "")
        # A missing Origin header is let through: browsers omit it on many
        # same-origin requests, and non-browser clients usually don't send it.
        if not origin or origin in self.allowed_origins:
            await self.app(scope, receive, send)
            return

        response = PlainTextResponse("Invalid origin header", status_code=400)
        await response(scope, receive, send)