Harden web server security

This commit is contained in:
Yury 2024-03-17 00:11:16 +02:00
parent 11672e9653
commit cf2b693334
3 changed files with 71 additions and 24 deletions

View File

@ -65,6 +65,9 @@ def setupArgParser():
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
return parser
@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams)
printMessage(f"Booting PHASE :{__name__}", level=2)
HOST = args.host
PORT = args.p
def localServer(logLevel: str = "critical"):
def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
try:
uvicorn.run(
f"{os.path.basename(__file__)[:-3]}:app_socketio",
host="0.0.0.0",
host=HOST,
port=int(PORT),
reload=False if hasattr(sys, "_MEIPASS") else True,
ssl_keyfile=key_path,
ssl_certfile=cert_path,
log_level=logLevel,
)
except Exception as e:
@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO":
mp.freeze_support()
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
@ -220,34 +226,26 @@ if __name__ == "__main__":
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
if args.https == 1:
printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
printMessage(f"https://localhost:{EX_PORT}/", level=1)
for ip in EX_IP.strip().split(" "):
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
else:
printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
printMessage(f"http://localhost:{EX_PORT}/", level=1)
else: # 直接python起動
if args.https == 1:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect((args.test_connect, 80))
hostname = s.getsockname()[0]
printMessage(f"https://127.0.0.1:{PORT}/", level=1)
printMessage(f"https://localhost:{PORT}/", level=1)
printMessage(f"https://{hostname}:{PORT}/", level=1)
else:
printMessage(f"http://127.0.0.1:{PORT}/", level=1)
printMessage(f"http://localhost:{PORT}/", level=1)
# サーバ起動
if args.https:
# HTTPS サーバ起動
try:
uvicorn.run(
f"{os.path.basename(__file__)[:-3]}:app_socketio",
host="0.0.0.0",
port=int(PORT),
reload=False if hasattr(sys, "_MEIPASS") else True,
ssl_keyfile=key_path,
ssl_certfile=cert_path,
log_level=args.logLevel,
)
localServer(args.logLevel, key_path, cert_path)
except Exception as e:
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
@ -256,12 +254,12 @@ if __name__ == "__main__":
p.start()
try:
if sys.platform.startswith("win"):
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
return_code = process.wait()
logger.info("client closed.")
p.terminate()
elif sys.platform.startswith("darwin"):
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
return_code = process.wait()
logger.info("client closed.")
p.terminate()

View File

@ -1,9 +1,9 @@
import os
import sys
from restapi.mods.trustedorigin import TrustedOriginMiddleware
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from typing import Callable
@ -43,17 +43,17 @@ class MMVC_Rest:
cls,
voiceChangerManager: VoiceChangerManager,
voiceChangerParams: VoiceChangerParams,
port: int,
allowedOrigins: list[str],
):
if cls._instance is None:
logger.info("[Voice Changer] MMVC_Rest initializing...")
app_fastapi = FastAPI()
app_fastapi.router.route_class = ValidationErrorLoggingRoute
app_fastapi.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
TrustedOriginMiddleware,
allowed_origins=allowedOrigins,
port=port
)
app_fastapi.mount(

View File

@ -0,0 +1,49 @@
import typing
from urllib.parse import urlparse
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Receive, Scope, Send
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
class TrustedOriginMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
port: typing.Optional[int] = None,
) -> None:
schemas = ['http', 'https']
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
if port is not None:
local_origins = [f'{origin}:{port}' for origin in local_origins]
if not allowed_origins:
allowed_origins = local_origins
else:
for origin in allowed_origins:
assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
allowed_origins = local_origins + allowed_origins
self.app = app
self.allowed_origins = list(allowed_origins)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in (
"http",
"websocket",
): # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin", "")
# Origin header is not present for same origin
if not origin or origin in self.allowed_origins:
await self.app(scope, receive, send)
return
response = PlainTextResponse("Invalid origin header", status_code=400)
await response(scope, receive, send)