From cf2b6933348d5c60e53cf922f61333267fd6a032 Mon Sep 17 00:00:00 2001 From: Yury <rip95_95@mail.ru> Date: Sun, 17 Mar 2024 00:11:16 +0200 Subject: [PATCH 1/3] Harden web server security --- server/MMVCServerSIO.py | 34 +++++++++---------- server/restapi/MMVC_Rest.py | 12 +++---- server/restapi/mods/trustedorigin.py | 49 ++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 24 deletions(-) create mode 100644 server/restapi/mods/trustedorigin.py diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index c1b15157..1e8541fd 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -65,6 +65,9 @@ def setupArgParser(): parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe") parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx") + parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.") + parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.") + return parser @@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams) printMessage(f"Booting PHASE :{__name__}", level=2) +HOST = args.host PORT = args.p -def localServer(logLevel: str = "critical"): +def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None): try: uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", + host=HOST, port=int(PORT), reload=False if hasattr(sys, "_MEIPASS") else True, + ssl_keyfile=key_path, + ssl_certfile=cert_path, log_level=logLevel, ) except Exception as e: @@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO": mp.freeze_support() voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams) + app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins) app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) @@ -220,34 +226,26 @@ if __name__ == "__main__": printMessage("In many cases, it will launch when you access any of the following URLs.", level=2) if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker) if args.https == 1: - printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"https://localhost:{EX_PORT}/", level=1) for ip in EX_IP.strip().split(" "): printMessage(f"https://{ip}:{EX_PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"http://localhost:{EX_PORT}/", level=1) else: # 直接python起動 if args.https == 1: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((args.test_connect, 80)) hostname = s.getsockname()[0] - printMessage(f"https://127.0.0.1:{PORT}/", level=1) + printMessage(f"https://localhost:{PORT}/", level=1) printMessage(f"https://{hostname}:{PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{PORT}/", level=1) + printMessage(f"http://localhost:{PORT}/", level=1) # サーバ起動 if args.https: # HTTPS サーバ起動 try: - uvicorn.run( - f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", - port=int(PORT), - reload=False if hasattr(sys, "_MEIPASS") else True, - ssl_keyfile=key_path, - ssl_certfile=cert_path, - log_level=args.logLevel, - ) + localServer(args.logLevel, key_path, cert_path) except Exception as e: logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}") @@ -256,12 +254,12 @@ if __name__ == "__main__": p.start() try: if sys.platform.startswith("win"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() elif sys.platform.startswith("darwin"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index 13879ac8..472d5695 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -1,9 +1,9 @@ import os import sys +from restapi.mods.trustedorigin import TrustedOriginMiddleware from fastapi import FastAPI, Request, Response, HTTPException from fastapi.routing import APIRoute -from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.exceptions import RequestValidationError from typing import Callable @@ -43,17 +43,17 @@ class MMVC_Rest: cls, voiceChangerManager: VoiceChangerManager, voiceChangerParams: VoiceChangerParams, + port: int, + allowedOrigins: list[str], ): if cls._instance is None: logger.info("[Voice Changer] MMVC_Rest initializing...") app_fastapi = FastAPI() app_fastapi.router.route_class = ValidationErrorLoggingRoute app_fastapi.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + TrustedOriginMiddleware, + allowed_origins=allowedOrigins, + port=port ) app_fastapi.mount( diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py new file mode 100644 index 00000000..d7c51d8a --- /dev/null +++ b/server/restapi/mods/trustedorigin.py @@ -0,0 +1,49 @@ +import typing + +from urllib.parse import urlparse +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." + + +class TrustedOriginMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_origins: typing.Optional[typing.Sequence[str]] = None, + port: typing.Optional[int] = None, + ) -> None: + schemas = ['http', 'https'] + local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']] + if port is not None: + local_origins = [f'{origin}:{port}' for origin in local_origins] + + if not allowed_origins: + allowed_origins = local_origins + else: + for origin in allowed_origins: + assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT + allowed_origins = local_origins + allowed_origins + + self.app = app + self.allowed_origins = list(allowed_origins) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + origin = headers.get("origin", "") + # Origin header is not present for same origin + if not origin or origin in self.allowed_origins: + await self.app(scope, receive, send) + return + + response = PlainTextResponse("Invalid origin header", status_code=400) + await response(scope, receive, send) From ce9b599501a636eaef08918ac13afc881335073f Mon Sep 17 00:00:00 2001 From: Yury <rip95_95@mail.ru> Date: Sun, 17 Mar 2024 16:26:55 +0200 Subject: [PATCH 2/3] Improve allowed origins input and use set --- server/restapi/mods/trustedorigin.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py index d7c51d8a..b6799547 100644 --- a/server/restapi/mods/trustedorigin.py +++ b/server/restapi/mods/trustedorigin.py @@ -20,15 +20,17 @@ class TrustedOriginMiddleware: if port is not None: local_origins = [f'{origin}:{port}' for origin in local_origins] - if not allowed_origins: - allowed_origins = local_origins - else: + self.allowed_origins: set[str] = set() + if allowed_origins is not None: for origin in allowed_origins: - assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT - allowed_origins = local_origins + allowed_origins - + url = urlparse(origin) + assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT + valid_origin = f'{url.scheme}://{url.hostname}' + if url.port: + valid_origin += f':{url.port}' + self.allowed_origins.add(valid_origin) + self.allowed_origins.update(local_origins) self.app = app - self.allowed_origins = list(allowed_origins) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ( From 8dd8d7127d16142d84c45b144da9bc1c692b67f5 Mon Sep 17 00:00:00 2001 From: Yury <rip95_95@mail.ru> Date: Mon, 18 Mar 2024 22:52:22 +0200 Subject: [PATCH 3/3] Refactor and add origin check to SIO --- server/MMVCServerSIO.py | 4 ++-- server/mods/origins.py | 24 ++++++++++++++++++++++ server/restapi/MMVC_Rest.py | 6 +++--- server/restapi/mods/trustedorigin.py | 30 ++++++++++------------------ server/sio/MMVC_SocketIOApp.py | 20 +++++++++++++++++-- server/sio/MMVC_SocketIOServer.py | 8 ++++++-- 6 files changed, 64 insertions(+), 28 deletions(-) create mode 100644 server/mods/origins.py diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index 1e8541fd..807a0525 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO": mp.freeze_support() voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins) - app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) + app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT) + app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT) if __name__ == "__mp_main__": diff --git a/server/mods/origins.py b/server/mods/origins.py new file mode 100644 index 00000000..98f08f66 --- /dev/null +++ b/server/mods/origins.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence +from urllib.parse import urlparse + +ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." +SCHEMAS = ('http', 'https') +LOCAL_ORIGINS = ('127.0.0.1', 'localhost') + +def compute_local_origins(port: Optional[int] = None) -> list[str]: + local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS] + if port is not None: + local_origins = [f'{origin}:{port}' for origin in local_origins] + return local_origins + + +def normalize_origins(origins: Sequence[str]) -> set[str]: + allowed_origins = set() + for origin in origins: + url = urlparse(origin) + assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT + valid_origin = f'{url.scheme}://{url.hostname}' + if url.port: + valid_origin += f':{url.port}' + allowed_origins.add(valid_origin) + return allowed_origins diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index 472d5695..98a0c3ef 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException from fastapi.routing import APIRoute from fastapi.staticfiles import StaticFiles from fastapi.exceptions import RequestValidationError -from typing import Callable +from typing import Callable, Optional, Sequence, Literal from mods.log_control import VoiceChangaerLogger from voice_changer.VoiceChangerManager import VoiceChangerManager @@ -43,8 +43,8 @@ class MMVC_Rest: cls, voiceChangerManager: VoiceChangerManager, voiceChangerParams: VoiceChangerParams, - port: int, - allowedOrigins: list[str], + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, ): if cls._instance is None: logger.info("[Voice Changer] MMVC_Rest initializing...") diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py index b6799547..582f4264 100644 --- a/server/restapi/mods/trustedorigin.py +++ b/server/restapi/mods/trustedorigin.py @@ -1,35 +1,27 @@ -import typing +from typing import Optional, Sequence, Literal -from urllib.parse import urlparse +from mods.origins import compute_local_origins, normalize_origins from starlette.datastructures import Headers from starlette.responses import PlainTextResponse from starlette.types import ASGIApp, Receive, Scope, Send -ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." - class TrustedOriginMiddleware: def __init__( self, app: ASGIApp, - allowed_origins: typing.Optional[typing.Sequence[str]] = None, - port: typing.Optional[int] = None, + allowed_origins: Optional[Sequence[str]] = None, + port: Optional[int] = None, ) -> None: - schemas = ['http', 'https'] - local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']] - if port is not None: - local_origins = [f'{origin}:{port}' for origin in local_origins] - self.allowed_origins: set[str] = set() - if allowed_origins is not None: - for origin in allowed_origins: - url = urlparse(origin) - assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT - valid_origin = f'{url.scheme}://{url.hostname}' - if url.port: - valid_origin += f':{url.port}' - self.allowed_origins.add(valid_origin) + + local_origins = compute_local_origins(port) self.allowed_origins.update(local_origins) + + if allowed_origins is not None: + normalized_origins = normalize_origins(allowed_origins) + self.allowed_origins.update(normalized_origins) + self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/server/sio/MMVC_SocketIOApp.py b/server/sio/MMVC_SocketIOApp.py index 23205d48..f73c6918 100644 --- a/server/sio/MMVC_SocketIOApp.py +++ b/server/sio/MMVC_SocketIOApp.py @@ -1,6 +1,8 @@ import socketio from mods.log_control import VoiceChangaerLogger +from mods.origins import compute_local_origins, normalize_origins +from typing import Sequence, Optional from sio.MMVC_SocketIOServer import MMVC_SocketIOServer from voice_changer.VoiceChangerManager import VoiceChangerManager from const import getFrontendPath @@ -12,10 +14,24 @@ class MMVC_SocketIOApp: _instance: socketio.ASGIApp | None = None @classmethod - def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + app_fastapi, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: Optional[Sequence[str]] = None, + port: Optional[int] = None, + ): if cls._instance is None: logger.info("[Voice Changer] MMVC_SocketIOApp initializing...") - sio = MMVC_SocketIOServer.get_instance(voiceChangerManager) + + allowed_origins: set[str] = set() + local_origins = compute_local_origins(port) + allowed_origins.update(local_origins) + if allowedOrigins is not None: + normalized_origins = normalize_origins(allowedOrigins) + allowed_origins.update(normalized_origins) + sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins)) + app_socketio = socketio.ASGIApp( sio, other_asgi_app=app_fastapi, diff --git a/server/sio/MMVC_SocketIOServer.py b/server/sio/MMVC_SocketIOServer.py index 9f168515..68000681 100644 --- a/server/sio/MMVC_SocketIOServer.py +++ b/server/sio/MMVC_SocketIOServer.py @@ -8,9 +8,13 @@ class MMVC_SocketIOServer: _instance: socketio.AsyncServer | None = None @classmethod - def get_instance(cls, voiceChangerManager: VoiceChangerManager): + def get_instance( + cls, + voiceChangerManager: VoiceChangerManager, + allowedOrigins: list[str], + ): if cls._instance is None: - sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins) namespace = MMVC_Namespace.get_instance(voiceChangerManager) sio.register_namespace(namespace) cls._instance = sio