From cf2b6933348d5c60e53cf922f61333267fd6a032 Mon Sep 17 00:00:00 2001 From: Yury Date: Sun, 17 Mar 2024 00:11:16 +0200 Subject: [PATCH] Harden web server security --- server/MMVCServerSIO.py | 34 +++++++++---------- server/restapi/MMVC_Rest.py | 12 +++---- server/restapi/mods/trustedorigin.py | 49 ++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 24 deletions(-) create mode 100644 server/restapi/mods/trustedorigin.py diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py index c1b15157..1e8541fd 100755 --- a/server/MMVCServerSIO.py +++ b/server/MMVCServerSIO.py @@ -65,6 +65,9 @@ def setupArgParser(): parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe") parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx") + parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.") + parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.") + return parser @@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams) printMessage(f"Booting PHASE :{__name__}", level=2) +HOST = args.host PORT = args.p -def localServer(logLevel: str = "critical"): +def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None): try: uvicorn.run( f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", + host=HOST, port=int(PORT), reload=False if hasattr(sys, "_MEIPASS") else True, + ssl_keyfile=key_path, + ssl_certfile=cert_path, log_level=logLevel, ) except Exception as e: @@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO": mp.freeze_support() voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams) - app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams) + app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins) app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager) @@ -220,34 +226,26 @@ if __name__ == "__main__": printMessage("In many cases, it will launch when you access any of the following URLs.", level=2) if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker) if args.https == 1: - printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"https://localhost:{EX_PORT}/", level=1) for ip in EX_IP.strip().split(" "): printMessage(f"https://{ip}:{EX_PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1) + printMessage(f"http://localhost:{EX_PORT}/", level=1) else: # 直接python起動 if args.https == 1: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((args.test_connect, 80)) hostname = s.getsockname()[0] - printMessage(f"https://127.0.0.1:{PORT}/", level=1) + printMessage(f"https://localhost:{PORT}/", level=1) printMessage(f"https://{hostname}:{PORT}/", level=1) else: - printMessage(f"http://127.0.0.1:{PORT}/", level=1) + printMessage(f"http://localhost:{PORT}/", level=1) # サーバ起動 if args.https: # HTTPS サーバ起動 try: - uvicorn.run( - f"{os.path.basename(__file__)[:-3]}:app_socketio", - host="0.0.0.0", - port=int(PORT), - reload=False if hasattr(sys, "_MEIPASS") else True, - ssl_keyfile=key_path, - ssl_certfile=cert_path, - log_level=args.logLevel, - ) + localServer(args.logLevel, key_path, cert_path) except Exception as e: logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}") @@ -256,12 +254,12 @@ if __name__ == "__main__": p.start() try: if sys.platform.startswith("win"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() elif sys.platform.startswith("darwin"): - process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"]) + process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"]) return_code = process.wait() logger.info("client closed.") p.terminate() diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py index 13879ac8..472d5695 100644 --- a/server/restapi/MMVC_Rest.py +++ b/server/restapi/MMVC_Rest.py @@ -1,9 +1,9 @@ import os import sys +from restapi.mods.trustedorigin import TrustedOriginMiddleware from fastapi import FastAPI, Request, Response, HTTPException from fastapi.routing import APIRoute -from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.exceptions import RequestValidationError from typing import Callable @@ -43,17 +43,17 @@ class MMVC_Rest: cls, voiceChangerManager: VoiceChangerManager, voiceChangerParams: VoiceChangerParams, + port: int, + allowedOrigins: list[str], ): if cls._instance is None: logger.info("[Voice Changer] MMVC_Rest initializing...") app_fastapi = FastAPI() app_fastapi.router.route_class = ValidationErrorLoggingRoute app_fastapi.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + TrustedOriginMiddleware, + allowed_origins=allowedOrigins, + port=port ) app_fastapi.mount( diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py new file mode 100644 index 00000000..d7c51d8a --- /dev/null +++ b/server/restapi/mods/trustedorigin.py @@ -0,0 +1,49 @@ +import typing + +from urllib.parse import urlparse +from starlette.datastructures import Headers +from starlette.responses import PlainTextResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com." + + +class TrustedOriginMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_origins: typing.Optional[typing.Sequence[str]] = None, + port: typing.Optional[int] = None, + ) -> None: + schemas = ['http', 'https'] + local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']] + if port is not None: + local_origins = [f'{origin}:{port}' for origin in local_origins] + + if not allowed_origins: + allowed_origins = local_origins + else: + for origin in allowed_origins: + assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT + allowed_origins = local_origins + allowed_origins + + self.app = app + self.allowed_origins = list(allowed_origins) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + origin = headers.get("origin", "") + # Origin header is not present for same origin + if not origin or origin in self.allowed_origins: + await self.app(scope, receive, send) + return + + response = PlainTextResponse("Invalid origin header", status_code=400) + await response(scope, receive, send)