Merge pull request #1153 from deiteris/harden-security

Harden web server security
This commit is contained in:
w-okada 2024-04-02 16:04:02 +09:00 committed by GitHub
commit 621ad25a8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 115 additions and 30 deletions

View File

@ -65,6 +65,9 @@ def setupArgParser():
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
return parser
@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams)
printMessage(f"Booting PHASE :{__name__}", level=2)
HOST = args.host
PORT = args.p
def localServer(logLevel: str = "critical"):
def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
try:
uvicorn.run(
f"{os.path.basename(__file__)[:-3]}:app_socketio",
host="0.0.0.0",
host=HOST,
port=int(PORT),
reload=False if hasattr(sys, "_MEIPASS") else True,
ssl_keyfile=key_path,
ssl_certfile=cert_path,
log_level=logLevel,
)
except Exception as e:
@ -134,8 +140,8 @@ if __name__ == "MMVCServerSIO":
mp.freeze_support()
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, args.allowed_origins, PORT)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager, args.allowed_origins, PORT)
if __name__ == "__mp_main__":
@ -220,34 +226,26 @@ if __name__ == "__main__":
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
if args.https == 1:
printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
printMessage(f"https://localhost:{EX_PORT}/", level=1)
for ip in EX_IP.strip().split(" "):
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
else:
printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
printMessage(f"http://localhost:{EX_PORT}/", level=1)
else: # 直接python起動
if args.https == 1:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect((args.test_connect, 80))
hostname = s.getsockname()[0]
printMessage(f"https://127.0.0.1:{PORT}/", level=1)
printMessage(f"https://localhost:{PORT}/", level=1)
printMessage(f"https://{hostname}:{PORT}/", level=1)
else:
printMessage(f"http://127.0.0.1:{PORT}/", level=1)
printMessage(f"http://localhost:{PORT}/", level=1)
# サーバ起動
if args.https:
# HTTPS サーバ起動
try:
uvicorn.run(
f"{os.path.basename(__file__)[:-3]}:app_socketio",
host="0.0.0.0",
port=int(PORT),
reload=False if hasattr(sys, "_MEIPASS") else True,
ssl_keyfile=key_path,
ssl_certfile=cert_path,
log_level=args.logLevel,
)
localServer(args.logLevel, key_path, cert_path)
except Exception as e:
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
@ -256,12 +254,12 @@ if __name__ == "__main__":
p.start()
try:
if sys.platform.startswith("win"):
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
return_code = process.wait()
logger.info("client closed.")
p.terminate()
elif sys.platform.startswith("darwin"):
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
return_code = process.wait()
logger.info("client closed.")
p.terminate()

24
server/mods/origins.py Normal file
View File

@ -0,0 +1,24 @@
from typing import Optional, Sequence
from urllib.parse import urlparse
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
SCHEMAS = ('http', 'https')
LOCAL_ORIGINS = ('127.0.0.1', 'localhost')
def compute_local_origins(port: Optional[int] = None) -> list[str]:
local_origins = [f'{schema}://{origin}' for schema in SCHEMAS for origin in LOCAL_ORIGINS]
if port is not None:
local_origins = [f'{origin}:{port}' for origin in local_origins]
return local_origins
def normalize_origins(origins: Sequence[str]) -> set[str]:
allowed_origins = set()
for origin in origins:
url = urlparse(origin)
assert url.scheme, ENFORCE_URL_ORIGIN_FORMAT
valid_origin = f'{url.scheme}://{url.hostname}'
if url.port:
valid_origin += f':{url.port}'
allowed_origins.add(valid_origin)
return allowed_origins

View File

@ -1,12 +1,12 @@
import os
import sys
from restapi.mods.trustedorigin import TrustedOriginMiddleware
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from typing import Callable
from typing import Callable, Optional, Sequence, Literal
from mods.log_control import VoiceChangaerLogger
from voice_changer.VoiceChangerManager import VoiceChangerManager
@ -43,17 +43,17 @@ class MMVC_Rest:
cls,
voiceChangerManager: VoiceChangerManager,
voiceChangerParams: VoiceChangerParams,
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_Rest initializing...")
app_fastapi = FastAPI()
app_fastapi.router.route_class = ValidationErrorLoggingRoute
app_fastapi.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
TrustedOriginMiddleware,
allowed_origins=allowedOrigins,
port=port
)
app_fastapi.mount(

View File

@ -0,0 +1,43 @@
from typing import Optional, Sequence, Literal
from mods.origins import compute_local_origins, normalize_origins
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
class TrustedOriginMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_origins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
) -> None:
self.allowed_origins: set[str] = set()
local_origins = compute_local_origins(port)
self.allowed_origins.update(local_origins)
if allowed_origins is not None:
normalized_origins = normalize_origins(allowed_origins)
self.allowed_origins.update(normalized_origins)
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in (
"http",
"websocket",
): # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin", "")
# Origin header is not present for same origin
if not origin or origin in self.allowed_origins:
await self.app(scope, receive, send)
return
response = PlainTextResponse("Invalid origin header", status_code=400)
await response(scope, receive, send)

View File

@ -1,6 +1,8 @@
import socketio
from mods.log_control import VoiceChangaerLogger
from mods.origins import compute_local_origins, normalize_origins
from typing import Sequence, Optional
from sio.MMVC_SocketIOServer import MMVC_SocketIOServer
from voice_changer.VoiceChangerManager import VoiceChangerManager
from const import getFrontendPath
@ -12,10 +14,24 @@ class MMVC_SocketIOApp:
_instance: socketio.ASGIApp | None = None
@classmethod
def get_instance(cls, app_fastapi, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
app_fastapi,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: Optional[Sequence[str]] = None,
port: Optional[int] = None,
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_SocketIOApp initializing...")
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager)
allowed_origins: set[str] = set()
local_origins = compute_local_origins(port)
allowed_origins.update(local_origins)
if allowedOrigins is not None:
normalized_origins = normalize_origins(allowedOrigins)
allowed_origins.update(normalized_origins)
sio = MMVC_SocketIOServer.get_instance(voiceChangerManager, list(allowed_origins))
app_socketio = socketio.ASGIApp(
sio,
other_asgi_app=app_fastapi,

View File

@ -8,9 +8,13 @@ class MMVC_SocketIOServer:
_instance: socketio.AsyncServer | None = None
@classmethod
def get_instance(cls, voiceChangerManager: VoiceChangerManager):
def get_instance(
cls,
voiceChangerManager: VoiceChangerManager,
allowedOrigins: list[str],
):
if cls._instance is None:
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=allowedOrigins)
namespace = MMVC_Namespace.get_instance(voiceChangerManager)
sio.register_namespace(namespace)
cls._instance = sio