From cf2b6933348d5c60e53cf922f61333267fd6a032 Mon Sep 17 00:00:00 2001
From: Yury <rip95_95@mail.ru>
Date: Sun, 17 Mar 2024 00:11:16 +0200
Subject: [PATCH 1/3] Harden web server security

---
 server/MMVCServerSIO.py              | 34 +++++++++----------
 server/restapi/MMVC_Rest.py          | 12 +++----
 server/restapi/mods/trustedorigin.py | 49 ++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 24 deletions(-)
 create mode 100644 server/restapi/mods/trustedorigin.py

diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py
index c1b15157..1e8541fd 100755
--- a/server/MMVCServerSIO.py
+++ b/server/MMVCServerSIO.py
@@ -65,6 +65,9 @@ def setupArgParser():
     parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
     parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
 
+    parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
+    parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
+
     return parser
 
 
@@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams)
 
 printMessage(f"Booting PHASE :{__name__}", level=2)
 
+HOST = args.host
 PORT = args.p
 
 
-def localServer(logLevel: str = "critical"):
+def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
     try:
         uvicorn.run(
             f"{os.path.basename(__file__)[:-3]}:app_socketio",
-            host="0.0.0.0",
+            host=HOST,
             port=int(PORT),
             reload=False if hasattr(sys, "_MEIPASS") else True,
+            ssl_keyfile=key_path,
+            ssl_certfile=cert_path,
             log_level=logLevel,
         )
     except Exception as e:
@@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO":
     mp.freeze_support()
 
     voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
-    app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
+    app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
     app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
 
 
@@ -220,34 +226,26 @@ if __name__ == "__main__":
     printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
     if "EX_PORT" in locals() and "EX_IP" in locals():  # シェルスクリプト経由起動(docker)
         if args.https == 1:
-            printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
+            printMessage(f"https://localhost:{EX_PORT}/", level=1)
             for ip in EX_IP.strip().split(" "):
                 printMessage(f"https://{ip}:{EX_PORT}/", level=1)
         else:
-            printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
+            printMessage(f"http://localhost:{EX_PORT}/", level=1)
     else:  # 直接python起動
         if args.https == 1:
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
             s.connect((args.test_connect, 80))
             hostname = s.getsockname()[0]
-            printMessage(f"https://127.0.0.1:{PORT}/", level=1)
+            printMessage(f"https://localhost:{PORT}/", level=1)
             printMessage(f"https://{hostname}:{PORT}/", level=1)
         else:
-            printMessage(f"http://127.0.0.1:{PORT}/", level=1)
+            printMessage(f"http://localhost:{PORT}/", level=1)
 
     # サーバ起動
     if args.https:
         # HTTPS サーバ起動
         try:
-            uvicorn.run(
-                f"{os.path.basename(__file__)[:-3]}:app_socketio",
-                host="0.0.0.0",
-                port=int(PORT),
-                reload=False if hasattr(sys, "_MEIPASS") else True,
-                ssl_keyfile=key_path,
-                ssl_certfile=cert_path,
-                log_level=args.logLevel,
-            )
+            localServer(args.logLevel, key_path, cert_path)
         except Exception as e:
             logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
 
@@ -256,12 +254,12 @@ if __name__ == "__main__":
         p.start()
         try:
             if sys.platform.startswith("win"):
-                process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
+                process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
                 return_code = process.wait()
                 logger.info("client closed.")
                 p.terminate()
             elif sys.platform.startswith("darwin"):
-                process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
+                process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
                 return_code = process.wait()
                 logger.info("client closed.")
                 p.terminate()
diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py
index 13879ac8..472d5695 100644
--- a/server/restapi/MMVC_Rest.py
+++ b/server/restapi/MMVC_Rest.py
@@ -1,9 +1,9 @@
 import os
 import sys
 
+from restapi.mods.trustedorigin import TrustedOriginMiddleware
 from fastapi import FastAPI, Request, Response, HTTPException
 from fastapi.routing import APIRoute
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from fastapi.exceptions import RequestValidationError
 from typing import Callable
@@ -43,17 +43,17 @@ class MMVC_Rest:
         cls,
         voiceChangerManager: VoiceChangerManager,
         voiceChangerParams: VoiceChangerParams,
+        port: int,
+        allowedOrigins: list[str],
     ):
         if cls._instance is None:
             logger.info("[Voice Changer] MMVC_Rest initializing...")
             app_fastapi = FastAPI()
             app_fastapi.router.route_class = ValidationErrorLoggingRoute
             app_fastapi.add_middleware(
-                CORSMiddleware,
-                allow_origins=["*"],
-                allow_credentials=True,
-                allow_methods=["*"],
-                allow_headers=["*"],
+                TrustedOriginMiddleware,
+                allowed_origins=allowedOrigins,
+                port=port
             )
 
             app_fastapi.mount(
diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py
new file mode 100644
index 00000000..d7c51d8a
--- /dev/null
+++ b/server/restapi/mods/trustedorigin.py
@@ -0,0 +1,49 @@
+import typing
+
+from urllib.parse import urlparse
+from starlette.datastructures import Headers
+from starlette.responses import PlainTextResponse
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
+
+
+class TrustedOriginMiddleware:
+    def __init__(
+        self,
+        app: ASGIApp,
+        allowed_origins: typing.Optional[typing.Sequence[str]] = None,
+        port: typing.Optional[int] = None,
+    ) -> None:
+        schemas = ['http', 'https']
+        local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
+        if port is not None:
+            local_origins = [f'{origin}:{port}' for origin in local_origins]
+
+        if not allowed_origins:
+            allowed_origins = local_origins
+        else:
+            for origin in allowed_origins:
+                assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
+            allowed_origins = local_origins + allowed_origins
+
+        self.app = app
+        self.allowed_origins = list(allowed_origins)
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if scope["type"] not in (
+            "http",
+            "websocket",
+        ):  # pragma: no cover
+            await self.app(scope, receive, send)
+            return
+
+        headers = Headers(scope=scope)
+        origin = headers.get("origin", "")
+        # Origin header is not present for same origin
+        if not origin or origin in self.allowed_origins:
+            await self.app(scope, receive, send)
+            return
+
+        response = PlainTextResponse("Invalid origin header", status_code=400)
+        await response(scope, receive, send)

From ce9b599501a636eaef08918ac13afc881335073f Mon Sep 17 00:00:00 2001
From: Yury <rip95_95@mail.ru>
Date: Sun, 17 Mar 2024 16:26:55 +0200
Subject: [PATCH 2/3] Improve allowed origins input and use set

---
 server/restapi/mods/trustedorigin.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py
index d7c51d8a..b6799547 100644
--- a/server/restapi/mods/trustedorigin.py
+++ b/server/restapi/mods/trustedorigin.py
@@ -20,15 +20,17 @@ class TrustedOriginMiddleware:
         if port is not None:
             local_origins = [f'{origin}:{port}' for origin in local_origins]
 
-        if not allowed_origins:
-            allowed_origins = local_origins
-        else:
+        self.allowed_origins: set[str] = set()
+        if allowed_origins is not None:
             for origin in allowed_origins:
-                assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
-            allowed_origins = local_origins + allowed_origins
-
+                url = urlparse(origin)
+                assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
+                valid_origin = f'{url.scheme}://{url.hostname}'
+                if url.port:
+                    valid_origin += f':{url.port}'
+                self.allowed_origins.add(valid_origin)
+        self.allowed_origins.update(local_origins)
         self.app = app
-        self.allowed_origins = list(allowed_origins)
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] not in (

From 8dd8d7127d16142d84c45b144da9bc1c692b67f5 Mon Sep 17 00:00:00 2001
From: Yury <rip95_95@mail.ru>
Date: Mon, 18 Mar 2024 22:52:22 +0200
Subject: [PATCH 3/3] Refactor and add origin check to SIO

---
 server/MMVCServerSIO.py              |  4 ++--
 server/mods/origins.py               | 24 ++++++++++++++++++++++
 server/restapi/MMVC_Rest.py          |  6 +++---
 server/restapi/mods/trustedorigin.py | 30 ++++++++++------------------
 server/sio/MMVC_SocketIOApp.py       | 20 +++++++++++++++++--
 server/sio/MMVC_SocketIOServer.py    |  8 ++++++--
 6 files changed, 64 insertions(+), 28 deletions(-)
 create mode 100644 server/mods/origins.py

diff --git a/server/MMVCServerSIO.py b/server/MMVCServerSIO.py
index 1e8541fd..807a0525 100755
--- a/server/MMVCServerSIO.py
+++ b/server/MMVCServerSIO.py
@@ -140,8 +140,8 @@ if __name__ == "MMVCServerSIO":
     mp.freeze_support()
 
     voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
-    app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
-    app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
+    app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
+    app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
 
 
 if __name__ == "__mp_main__":
diff --git a/server/mods/origins.py b/server/mods/origins.py
new file mode 100644
index 00000000..98f08f66
--- /dev/null
+++ b/server/mods/origins.py
@@ -0,0 +1,24 @@
+from typing import Optional, Sequence
+from urllib.parse import urlparse
+
+ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
+SCHEMAS = ('http', 'https')
+LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
+
+def compute_local_origins(port: Optional[int] = None) -> list[str]:
+    local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
+    if port is not None:
+        local_origins = [f'{origin}:{port}' for origin in local_origins]
+    return local_origins
+
+
+def normalize_origins(origins: Sequence[str]) -> set[str]:
+    allowed_origins = set()
+    for origin in origins:
+        url = urlparse(origin)
+        assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
+        valid_origin = f'{url.scheme}://{url.hostname}'
+        if url.port:
+            valid_origin += f':{url.port}'
+        allowed_origins.add(valid_origin)
+    return allowed_origins
diff --git a/server/restapi/MMVC_Rest.py b/server/restapi/MMVC_Rest.py
index 472d5695..98a0c3ef 100644
--- a/server/restapi/MMVC_Rest.py
+++ b/server/restapi/MMVC_Rest.py
@@ -6,7 +6,7 @@ from fastapi import FastAPI, Request, Response, HTTPException
 from fastapi.routing import APIRoute
 from fastapi.staticfiles import StaticFiles
 from fastapi.exceptions import RequestValidationError
-from typing import Callable
+from typing import Callable, Optional, Sequence, Literal
 from mods.log_control import VoiceChangaerLogger
 from voice_changer.VoiceChangerManager import VoiceChangerManager
 
@@ -43,8 +43,8 @@ class MMVC_Rest:
         cls,
         voiceChangerManager: VoiceChangerManager,
         voiceChangerParams: VoiceChangerParams,
-        port: int,
-        allowedOrigins: list[str],
+        allowedOrigins: Optional[Sequence[str]] = None,
+        port: Optional[int] = None,
     ):
         if cls._instance is None:
             logger.info("[Voice Changer] MMVC_Rest initializing...")
diff --git a/server/restapi/mods/trustedorigin.py b/server/restapi/mods/trustedorigin.py
index b6799547..582f4264 100644
--- a/server/restapi/mods/trustedorigin.py
+++ b/server/restapi/mods/trustedorigin.py
@@ -1,35 +1,27 @@
-import typing
+from typing import Optional, Sequence, Literal
 
-from urllib.parse import urlparse
+from mods.origins import compute_local_origins, normalize_origins
 from starlette.datastructures import Headers
 from starlette.responses import PlainTextResponse
 from starlette.types import ASGIApp, Receive, Scope, Send
 
-ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
-
 
 class TrustedOriginMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        allowed_origins: typing.Optional[typing.Sequence[str]] = None,
-        port: typing.Optional[int] = None,
+        allowed_origins: Optional[Sequence[str]] = None,
+        port: Optional[int] = None,
     ) -> None:
-        schemas = ['http', 'https']
-        local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
-        if port is not None:
-            local_origins = [f'{origin}:{port}' for origin in local_origins]
-
         self.allowed_origins: set[str] = set()
-        if allowed_origins is not None:
-            for origin in allowed_origins:
-                url = urlparse(origin)
-                assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
-                valid_origin = f'{url.scheme}://{url.hostname}'
-                if url.port:
-                    valid_origin += f':{url.port}'
-                self.allowed_origins.add(valid_origin)
+
+        local_origins = compute_local_origins(port)
         self.allowed_origins.update(local_origins)
+
+        if allowed_origins is not None:
+            normalized_origins = normalize_origins(allowed_origins)
+            self.allowed_origins.update(normalized_origins)
+
         self.app = app
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
diff --git a/server/sio/MMVC_SocketIOApp.py b/server/sio/MMVC_SocketIOApp.py
index 23205d48..f73c6918 100644
--- a/server/sio/MMVC_SocketIOApp.py
+++ b/server/sio/MMVC_SocketIOApp.py
@@ -1,6 +1,8 @@
 import socketio
 from mods.log_control import VoiceChangaerLogger
+from mods.origins import compute_local_origins, normalize_origins
 
+from typing import Sequence, Optional
 from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
 from voice_changer.VoiceChangerManager import VoiceChangerManager
 from const import getFrontendPath
@@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
     _instance: socketio.ASGIApp | None = None
 
     @classmethod
-    def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
+    def get_instance(
+        cls,
+        app_fastapi,
+        voiceChangerManager: VoiceChangerManager,
+        allowedOrigins: Optional[Sequence[str]] = None,
+        port: Optional[int] = None,
+    ):
         if cls._instance is None:
             logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
-            sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
+
+            allowed_origins: set[str] = set()
+            local_origins = compute_local_origins(port)
+            allowed_origins.update(local_origins)
+            if allowedOrigins is not None:
+                normalized_origins = normalize_origins(allowedOrigins)
+                allowed_origins.update(normalized_origins)
+            sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
+
             app_socketio = socketio.ASGIApp(
                 sio,
                 other_asgi_app=app_fastapi,
diff --git a/server/sio/MMVC_SocketIOServer.py b/server/sio/MMVC_SocketIOServer.py
index 9f168515..68000681 100644
--- a/server/sio/MMVC_SocketIOServer.py
+++ b/server/sio/MMVC_SocketIOServer.py
@@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
     _instance: socketio.AsyncServer | None = None
 
     @classmethod
-    def get_instance(cls, voiceChangerManager: VoiceChangerManager):
+    def get_instance(
+        cls,
+        voiceChangerManager: VoiceChangerManager,
+        allowedOrigins: list[str],
+    ):
         if cls._instance is None:
-            sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
+            sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
             namespace = MMVC_Namespace.get_instance(voiceChangerManager)
             sio.register_namespace(namespace)
             cls._instance = sio