diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index 1e8541fd..807a0525 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO": mp.freeze_support() voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins) - app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) + app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT) + app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT) if __name__ == "__mp_main__": diff --git a/server/mods/origins.py b/server/mods/origins.py new file mode 100644 index 00000000..98f08f66 --- /dev/null +++ b/server/mods/origins.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence +from urllib.parse import urlparse + +ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." +SCHEMAS = ('http', 'https') +LOCAL_ORIGINS = ('127.0.0.1', 'localhost') + +def compute_local_origins(port: Optional[int] = None) -> list[str]: + local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS] + if port is not None: + local_origins = [f'{origin}:{port}' for origin in local_origins] + return local_origins + + +def normalize_origins(origins: Sequence[str]) -> set[str]: + allowed_origins = set() + for origin in origins: + url = urlparse(origin) + assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT + valid_origin = f'{url.scheme}://{url.hostname}' + if url.port: + valid_origin += f':{url.port}' + allowed_origins.add(valid_origin) + return allowed_origins diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index 472d5695..98a0c3ef 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException from fastapi.routing import APIRoute from fastapi.staticfiles import StaticFiles from fastapi.exceptions import RequestValidationError -from typing import Callable +from typing import Callable, Optional, Sequence, Literal from mods.log_control import VoiceChangaerLogger from voice_changer.VoiceChangerManager import VoiceChangerManager @@ -43,8 +43,8 @@ class MMVC_Rest: cls, voiceChangerManager: VoiceChangerManager, voiceChangerParams: VoiceChangerParams, - port: int, - allowedOrigins: list[str], + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, ): if cls._instance is None: logger.info("[Voice Changer] MMVC_Rest initializing...") diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py index b6799547..582f4264 100644 --- a/server/restapi/mods/trustedorigin.py +++ b/server/restapi/mods/trustedorigin.py @@ -1,35 +1,27 @@ -import typing +from typing import Optional, Sequence, Literal -from urllib.parse import urlparse +from mods.origins import compute_local_origins, normalize_origins from starlette.datastructures import Headers from starlette.responses import PlainTextResponse from starlette.types import ASGIApp, Receive, Scope, Send -ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." - class TrustedOriginMiddleware: def __init__( self, app: ASGIApp, - allowed_origins: typing.Optional[typing.Sequence[str]] = None, - port: typing.Optional[int] = None, + allowed_origins: Optional[Sequence[str]] = None, + port: Optional[int] = None, ) -> None: - schemas = ['http', 'https'] - local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']] - if port is not None: - local_origins = [f'{origin}:{port}' for origin in local_origins] - self.allowed_origins: set[str] = set() - if allowed_origins is not None: - for origin in allowed_origins: - url = urlparse(origin) - assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT - valid_origin = f'{url.scheme}://{url.hostname}' - if url.port: - valid_origin += f':{url.port}' - self.allowed_origins.add(valid_origin) + + local_origins = compute_local_origins(port) self.allowed_origins.update(local_origins) + + if allowed_origins is not None: + normalized_origins = normalize_origins(allowed_origins) + self.allowed_origins.update(normalized_origins) + self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/server/sio/MMVC_SocketIOApp.py b/server/sio/MMVC_SocketIOApp.py index 23205d48..f73c6918 100644 --- a/server/sio/MMVC_SocketIOApp.py +++ b/server/sio/MMVC_SocketIOApp.py @@ -1,6 +1,8 @@ import socketio from mods.log_control import VoiceChangaerLogger +from mods.origins import compute_local_origins, normalize_origins +from typing import Sequence, Optional from sio.MMVC_SocketIOServer import MMVC_SocketIOServer from voice_changer.VoiceChangerManager import VoiceChangerManager from const import getFrontendPath @@ -12,10 +14,24 @@ class MMVC_SocketIOApp: _instance: socketio.ASGIApp | None = None @classmethod - def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + app_fastapi, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, + ): if cls._instance is None: logger.info("[Voice Changer] MMVC_SocketIOApp initializing...") - sio = MMVC_SocketIOServer.get_instance(voiceChangerManager) + + allowed_origins: set[str] = set() + local_origins = compute_local_origins(port) + allowed_origins.update(local_origins) + if allowedOrigins is not None: + normalized_origins = normalize_origins(allowedOrigins) + allowed_origins.update(normalized_origins) + sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins)) + app_socketio = socketio.ASGIApp( sio, other_asgi_app=app_fastapi, diff --git a/server/sio/MMVC_SocketIOServer.py b/server/sio/MMVC_SocketIOServer.py index 9f168515..68000681 100644 --- a/server/sio/MMVC_SocketIOServer.py +++ b/server/sio/MMVC_SocketIOServer.py @@ -8,9 +8,13 @@ class MMVC_SocketIOServer: _instance: socketio.AsyncServer | None = None @classmethod - def get_instance(cls, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: list[str], + ): if cls._instance is None: - sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins) namespace = MMVC_Namespace.get_instance(voiceChangerManager) sio.register_namespace(namespace) cls._instance = sio