diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index c1b15157..807a0525 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -65,6 +65,9 @@ def setupArgParser(): parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe") parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx") + parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.") + parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.") + return parser @@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams) printMessage(f"Booting PHASE :{__name__}", level=2) +HOST = args.host PORT = args.p -def localServer(logLevel: str = "critical"): +def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None): try: uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", + host=HOST, port=int(PORT), reload=False if hasattr(sys, "_MEIPASS") else True, + ssl_keyfile=key_path, + ssl_certfile=cert_path, log_level=logLevel, ) except Exception as e: @@ -134,8 +140,8 @@ if __name__ == "MMVCServerSIO": mp.freeze_support() voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams) - app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) + app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT) + app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT) if __name__ == "__mp_main__": @@ -220,34 +226,26 @@ if __name__ == "__main__": printMessage("In many cases, it will launch when you access any of the following URLs.", level=2) if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker) if args.https == 1: - printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"https://localhost:{EX_PORT}/", level=1) for ip in EX_IP.strip().split(" "): printMessage(f"https://{ip}:{EX_PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"http://localhost:{EX_PORT}/", level=1) else: # 直接python起動 if args.https == 1: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((args.test_connect, 80)) hostname = s.getsockname()[0] - printMessage(f"https://127.0.0.1:{PORT}/", level=1) + printMessage(f"https://localhost:{PORT}/", level=1) printMessage(f"https://{hostname}:{PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{PORT}/", level=1) + printMessage(f"http://localhost:{PORT}/", level=1) # サーバ起動 if args.https: # HTTPS サーバ起動 try: - uvicorn.run( - f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", - port=int(PORT), - reload=False if hasattr(sys, "_MEIPASS") else True, - ssl_keyfile=key_path, - ssl_certfile=cert_path, - log_level=args.logLevel, - ) + localServer(args.logLevel, key_path, cert_path) except Exception as e: logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}") @@ -256,12 +254,12 @@ if __name__ == "__main__": p.start() try: if sys.platform.startswith("win"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() elif sys.platform.startswith("darwin"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() diff --git a/server/mods/origins.py b/server/mods/origins.py new file mode 100644 index 00000000..98f08f66 --- /dev/null +++ b/server/mods/origins.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence +from urllib.parse import urlparse + +ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." +SCHEMAS = ('http', 'https') +LOCAL_ORIGINS = ('127.0.0.1', 'localhost') + +def compute_local_origins(port: Optional[int] = None) -> list[str]: + local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS] + if port is not None: + local_origins = [f'{origin}:{port}' for origin in local_origins] + return local_origins + + +def normalize_origins(origins: Sequence[str]) -> set[str]: + allowed_origins = set() + for origin in origins: + url = urlparse(origin) + assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT + valid_origin = f'{url.scheme}://{url.hostname}' + if url.port: + valid_origin += f':{url.port}' + allowed_origins.add(valid_origin) + return allowed_origins diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index 13879ac8..98a0c3ef 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -1,12 +1,12 @@ import os import sys +from restapi.mods.trustedorigin import TrustedOriginMiddleware from fastapi import FastAPI, Request, Response, HTTPException from fastapi.routing import APIRoute -from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.exceptions import RequestValidationError -from typing import Callable +from typing import Callable, Optional, Sequence, Literal from mods.log_control import VoiceChangaerLogger from voice_changer.VoiceChangerManager import VoiceChangerManager @@ -43,17 +43,17 @@ class MMVC_Rest: cls, voiceChangerManager: VoiceChangerManager, voiceChangerParams: VoiceChangerParams, + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, ): if cls._instance is None: logger.info("[Voice Changer] MMVC_Rest initializing...") app_fastapi = FastAPI() app_fastapi.router.route_class = ValidationErrorLoggingRoute app_fastapi.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + TrustedOriginMiddleware, + allowed_origins=allowedOrigins, + port=port ) app_fastapi.mount( diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py new file mode 100644 index 00000000..582f4264 --- /dev/null +++ b/server/restapi/mods/trustedorigin.py @@ -0,0 +1,43 @@ +from typing import Optional, Sequence, Literal + +from mods.origins import compute_local_origins, normalize_origins +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse +from starlette.types import ASGIApp, Receive, Scope, Send + + +class TrustedOriginMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_origins: Optional[Sequence[str]] = None, + port: Optional[int] = None, + ) -> None: + self.allowed_origins: set[str] = set() + + local_origins = compute_local_origins(port) + self.allowed_origins.update(local_origins) + + if allowed_origins is not None: + normalized_origins = normalize_origins(allowed_origins) + self.allowed_origins.update(normalized_origins) + + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + origin = headers.get("origin", "") + # Origin header is not present for same origin + if not origin or origin in self.allowed_origins: + await self.app(scope, receive, send) + return + + response = PlainTextResponse("Invalid origin header", status_code=400) + await response(scope, receive, send) diff --git a/server/sio/MMVC_SocketIOApp.py b/server/sio/MMVC_SocketIOApp.py index 23205d48..f73c6918 100644 --- a/server/sio/MMVC_SocketIOApp.py +++ b/server/sio/MMVC_SocketIOApp.py @@ -1,6 +1,8 @@ import socketio from mods.log_control import VoiceChangaerLogger +from mods.origins import compute_local_origins, normalize_origins +from typing import Sequence, Optional from sio.MMVC_SocketIOServer import MMVC_SocketIOServer from voice_changer.VoiceChangerManager import VoiceChangerManager from const import getFrontendPath @@ -12,10 +14,24 @@ class MMVC_SocketIOApp: _instance: socketio.ASGIApp | None = None @classmethod - def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + app_fastapi, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, + ): if cls._instance is None: logger.info("[Voice Changer] MMVC_SocketIOApp initializing...") - sio = MMVC_SocketIOServer.get_instance(voiceChangerManager) + + allowed_origins: set[str] = set() + local_origins = compute_local_origins(port) + allowed_origins.update(local_origins) + if allowedOrigins is not None: + normalized_origins = normalize_origins(allowedOrigins) + allowed_origins.update(normalized_origins) + sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins)) + app_socketio = socketio.ASGIApp( sio, other_asgi_app=app_fastapi, diff --git a/server/sio/MMVC_SocketIOServer.py b/server/sio/MMVC_SocketIOServer.py index 9f168515..68000681 100644 --- a/server/sio/MMVC_SocketIOServer.py +++ b/server/sio/MMVC_SocketIOServer.py @@ -8,9 +8,13 @@ class MMVC_SocketIOServer: _instance: socketio.AsyncServer | None = None @classmethod - def get_instance(cls, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: list[str], + ): if cls._instance is None: - sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins) namespace = MMVC_Namespace.get_instance(voiceChangerManager) sio.register_namespace(namespace) cls._instance = sio